1use std::fmt;
2use std::io::Write;
3use std::mem::transmute;
4use std::str::from_utf8;
5use std::net::SocketAddr;
6
7use sha1;
8use rand;
9use url;
10use httparse;
11
12use result::{Result, Error, Kind};
13
/// GUID that RFC 6455 appends to the client's `Sec-WebSocket-Key` before hashing.
const WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// The standard 64-character base64 alphabet, indexed by 6-bit value.
const BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Size of the header array handed to `httparse` when parsing a request or response.
const MAX_HEADERS: usize = 124;
17
18fn generate_key() -> String {
19 let key: [u8; 16] = unsafe {
20 transmute(rand::random::<(u64, u64)>())
21 };
22 encode_base64(&key)
23}
24
25pub fn hash_key(key: &[u8]) -> String {
26 let mut hasher = sha1::Sha1::new();
27
28 hasher.update(key);
29 hasher.update(WS_GUID.as_bytes());
30
31 encode_base64(&hasher.digest().bytes())
32}
33
34fn encode_base64(data: &[u8]) -> String {
36 let len = data.len();
37 let mod_len = len % 3;
38
39 let mut encoded = vec![b'='; (len + 2) / 3 * 4];
40 {
41 let mut in_iter = data[..len - mod_len].iter().map(|&c| c as u32);
42 let mut out_iter = encoded.iter_mut();
43
44 let enc = |val| BASE64[val as usize];
45 let mut write = |val| *out_iter.next().unwrap() = val;
46
47 while let (Some(one), Some(two), Some(three)) = (in_iter.next(), in_iter.next(), in_iter.next()) {
48 let g24 = one << 16 | two << 8 | three;
49 write(enc((g24 >> 18) & 63));
50 write(enc((g24 >> 12) & 63));
51 write(enc((g24 >> 6 ) & 63));
52 write(enc(g24 & 63));
53 }
54
55 match mod_len {
56 1 => {
57 let pad = (data[len-1] as u32) << 16;
58 write(enc((pad >> 18) & 63));
59 write(enc((pad >> 12) & 63));
60 }
61 2 => {
62 let pad = (data[len-2] as u32) << 16 | (data[len-1] as u32) << 8;
63 write(enc((pad >> 18) & 63));
64 write(enc((pad >> 12) & 63));
65 write(enc((pad >> 6) & 63));
66 }
67 _ => (),
68 }
69 }
70
71 String::from_utf8(encoded).unwrap()
72}
73
/// The request/response pair exchanged during a WebSocket opening handshake,
/// together with the socket addresses of the connection, when known.
#[derive(Debug)]
pub struct Handshake {
    /// The HTTP request of the handshake.
    pub request: Request,
    /// The HTTP response of the handshake.
    pub response: Response,
    /// Address of the remote endpoint, if known. `remote_addr` falls back to
    /// its IP when the request carries no forwarding headers.
    pub peer_addr: Option<SocketAddr>,
    /// Address of the local endpoint, if known (not read by the code in this file).
    pub local_addr: Option<SocketAddr>,
}
87
88impl Handshake {
89
90 #[allow(dead_code)]
103 pub fn remote_addr(&self) -> Result<Option<String>> {
104 Ok(try!(self.request.client_addr()).map(String::from).or_else(|| {
105 if let Some(addr) = self.peer_addr {
106 Some(addr.ip().to_string())
107 } else {
108 None
109 }
110 }))
111 }
112
113}
114
115
/// An HTTP request taking part in a WebSocket opening handshake. It is either
/// parsed from raw bytes (`Request::parse`) or built for a URL
/// (`Request::from_url`).
#[derive(Debug)]
pub struct Request {
    /// Request target: the URL path plus an optional `?query`.
    path: String,
    /// HTTP method, e.g. `GET`.
    method: String,
    /// Header name / raw value pairs, in the order they were parsed or added.
    /// Values are raw bytes and are not guaranteed to be UTF-8.
    headers: Vec<(String, Vec<u8>)>,
}
123
124impl Request {
125
126 pub fn header(&self, header: &str) -> Option<&Vec<u8>> {
128 self.headers
129 .iter()
130 .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
131 .map(|&(_, ref val)| val)
132 }
133
134 pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
136 self.headers
137 .iter_mut()
138 .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
139 .map(|&mut (_, ref mut val)| val)
140 }
141
142 #[allow(dead_code)]
144 #[inline]
145 pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
146 &self.headers
147 }
148
149 #[allow(dead_code)]
151 #[inline]
152 pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
153 &mut self.headers
154 }
155
156 #[allow(dead_code)]
158 pub fn origin(&self) -> Result<Option<&str>> {
159 if let Some(origin) = self.header("origin") {
160 Ok(Some(try!(from_utf8(origin))))
161 } else {
162 Ok(None)
163 }
164 }
165
166 pub fn key(&self) -> Result<&Vec<u8>> {
168 self.header("sec-websocket-key")
169 .ok_or(Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
170 }
171
172 pub fn hashed_key(&self) -> Result<String> {
174 Ok(hash_key(try!(self.key())))
175 }
176
177 #[allow(dead_code)]
179 pub fn version(&self) -> Result<&str> {
180 if let Some(version) = self.header("sec-websocket-version") {
181 from_utf8(version).map_err(Error::from)
182 } else {
183 Err(Error::new(Kind::Protocol, "The Sec-WebSocket-Version header is missing."))
184 }
185 }
186
187 #[allow(dead_code)]
189 #[inline]
190 pub fn resource(&self) -> &str {
191 &self.path
192 }
193
194 #[allow(dead_code)]
196 pub fn protocols(&self) -> Result<Vec<&str>> {
197 if let Some(protos) = self.header("sec-websocket-protocol") {
198 Ok(try!(from_utf8(protos)).split(',').map(|proto| proto.trim()).collect())
199 } else {
200 Ok(Vec::new())
201 }
202 }
203
204 #[allow(dead_code)]
207 pub fn add_protocol(&mut self, protocol: &str) {
208 if let Some(protos) = self.header_mut("sec-websocket-protocol") {
209 protos.push(b","[0]);
210 protos.extend(protocol.as_bytes());
211 return
212 }
213 self.headers_mut().push(("Sec-WebSocket-Protocol".into(), protocol.into()))
214 }
215
216 #[allow(dead_code)]
218 pub fn remove_protocol(&mut self, protocol: &str) {
219 if let Some(protos) = self.header_mut("sec-websocket-protocol") {
220 let mut new_protos = Vec::with_capacity(protos.len());
221
222 if let Ok(protos_str) = from_utf8(protos) {
223 new_protos = protos_str
224 .split(',')
225 .filter(|proto| proto.trim() == protocol)
226 .collect::<Vec<&str>>()
227 .join(",").into();
228 }
229 if new_protos.len() < protos.len() {
230 *protos = new_protos
231 }
232 }
233 }
234
235 #[allow(dead_code)]
237 pub fn extensions(&self) -> Result<Vec<&str>> {
238 if let Some(exts) = self.header("sec-websocket-extensions") {
239 Ok(try!(from_utf8(exts)).split(',').map(|ext| ext.trim()).collect())
240 } else {
241 Ok(Vec::new())
242 }
243 }
244
245 #[allow(dead_code)]
250 pub fn add_extension(&mut self, ext: &str) {
251 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
252 exts.push(b","[0]);
253 exts.extend(ext.as_bytes());
254 return
255 }
256 self.headers_mut().push(("Sec-WebSocket-Extensions".into(), ext.into()))
257 }
258
259 #[allow(dead_code)]
262 pub fn remove_extension(&mut self, ext: &str) {
263 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
264 let mut new_exts = Vec::with_capacity(exts.len());
265
266 if let Ok(exts_str) = from_utf8(exts) {
267 new_exts = exts_str
268 .split(',')
269 .filter(|e| e.trim().starts_with(ext))
270 .collect::<Vec<&str>>()
271 .join(",").into();
272 }
273 if new_exts.len() < exts.len() {
274 *exts = new_exts
275 }
276 }
277 }
278
279 #[allow(dead_code)]
294 pub fn client_addr(&self) -> Result<Option<&str>> {
295 if let Some(x_forward) = self.header("x-forwarded-for") {
296 return Ok(try!(from_utf8(x_forward)).split(',').next())
297 }
298
299 if let Some(forward) = self.header("forwarded") {
301 if let Some(_for) = try!(from_utf8(forward))
302 .split(';')
303 .find(|f| f.trim().starts_with("for"))
304 {
305 if let Some(_for_eq) = _for.trim().split(',').next() {
306 let mut it = _for_eq.split('=');
307 it.next();
308 return Ok(it.next())
309 }
310 }
311 }
312 Ok(None)
313 }
314
315 pub fn parse(buf: &[u8]) -> Result<Option<Request>> {
318 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
319 let mut req = httparse::Request::new(&mut headers);
320 let parsed = try!(req.parse(buf));
321 if !parsed.is_partial() {
322 Ok(Some(Request {
323 path: req.path.unwrap().into(),
324 method: req.method.unwrap().into(),
325 headers: req.headers.iter().map(|h| (h.name.into(), h.value.into())).collect(),
326 }))
327 } else {
328 Ok(None)
329 }
330 }
331
332 pub fn from_url(url: &url::Url) -> Result<Request> {
334
335 let query = if let Some(q) = url.query() {
336 format!("?{}", q)
337 } else {
338 "".into()
339 };
340
341 let req = Request {
342 path: format!(
343 "{}{}",
344 url.path(),
345 query),
346 method: "GET".to_owned(),
347 headers: vec![
348 ("Connection".into(), "Upgrade".into()),
349 (
350 "Host".into(),
351 format!(
352 "{}:{}",
353 try!(url.host_str().ok_or(
354 Error::new(Kind::Internal, "No host passed for WebSocket connection."))),
355 url.port_or_known_default().unwrap_or(80)).into(),
356 ),
357 ("Sec-WebSocket-Version".into(), "13".into()),
358 ("Sec-WebSocket-Key".into(), generate_key().into()),
359 ("Upgrade".into(), "websocket".into()),
360 ],
361 };
362
363 debug!("Built request from URL:\n{}", req);
364
365 Ok(req)
366 }
367
368 pub fn format<W>(&self, w: &mut W) -> Result<()>
370 where W: Write,
371 {
372 try!(write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path));
373 for &(ref key, ref val) in self.headers.iter() {
374 try!(write!(w, "{}: ", key));
375 try!(w.write(val));
376 try!(write!(w, "\r\n"));
377 }
378 try!(write!(w, "\r\n"));
379 Ok(())
380 }
381
382}
383
384impl fmt::Display for Request {
385
386 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
387 let mut s = Vec::with_capacity(2048);
388 try!(self.format(&mut s).map_err(|err| {
389 error!("{:?}", err);
390 fmt::Error
391 }));
392 write!(f, "{}", try!(from_utf8(&s).map_err(|err| {
393 error!("Unable to format request as utf8: {:?}", err);
394 fmt::Error
395 })))
396 }
397
398}
399
/// An HTTP response taking part in a WebSocket opening handshake. It is either
/// parsed from raw bytes (`Response::parse`) or built from a request
/// (`Response::from_request`).
#[derive(Debug)]
pub struct Response {
    /// HTTP status code, e.g. `101`.
    status: u16,
    /// Reason phrase following the status code, e.g. `Switching Protocols`.
    reason: String,
    /// Header name / raw value pairs, in order. Values are raw bytes and are
    /// not guaranteed to be UTF-8.
    headers: Vec<(String, Vec<u8>)>,
}
407
408impl Response {
409 fn header(&self, header: &str) -> Option<&Vec<u8>> {
413 self.headers
414 .iter()
415 .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
416 .map(|&(_, ref val)| val)
417 }
418 pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
420 self.headers
421 .iter_mut()
422 .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
423 .map(|&mut (_, ref mut val)| val)
424 }
425
426 #[allow(dead_code)]
428 #[inline]
429 pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
430 &self.headers
431 }
432
433 #[allow(dead_code)]
435 #[inline]
436 pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
437 &mut self.headers
438 }
439
440 #[allow(dead_code)]
442 #[inline]
443 pub fn status(&self) -> u16 {
444 self.status
445 }
446
447 #[allow(dead_code)]
449 #[inline]
450 pub fn set_status(&mut self, status: u16) {
451 self.status = status
452 }
453
454 #[allow(dead_code)]
456 #[inline]
457 pub fn reason(&self) -> &str {
458 &self.reason
459 }
460
461
462 #[allow(dead_code)]
464 #[inline]
465 pub fn set_reason<R>(&mut self, reason: R)
466 where R: Into<String>
467 {
468 self.reason = reason.into()
469 }
470
471 pub fn key(&self) -> Result<&Vec<u8>> {
473 self.header("sec-websocket-accept").ok_or(Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
474 }
475
476 #[allow(dead_code)]
478 pub fn protocol(&self) -> Result<Option<&str>> {
479 if let Some(proto) = self.header("sec-websocket-protocol") {
480 Ok(Some(try!(from_utf8(proto))))
481 } else {
482 Ok(None)
483 }
484 }
485
486 #[allow(dead_code)]
488 pub fn set_protocol(&mut self, protocol: &str) {
489 if let Some(proto) = self.header_mut("sec-websocket-protocol") {
490 *proto = protocol.into();
491 return
492 }
493 self.headers_mut().push(("Sec-WebSocket-Protocol".into(), protocol.into()))
494 }
495
496 #[allow(dead_code)]
499 pub fn extensions(&self) -> Result<Vec<&str>> {
500 if let Some(exts) = self.header("sec-websocket-extensions") {
501 Ok(try!(from_utf8(exts)).split(',').map(|proto| proto.trim()).collect())
502 } else {
503 Ok(Vec::new())
504 }
505 }
506
507 #[allow(dead_code)]
510 pub fn add_extension(&mut self, ext: &str) {
511 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
512 exts.push(b","[0]);
513 exts.extend(ext.as_bytes());
514 return
515 }
516 self.headers_mut().push(("Sec-WebSocket-Extensions".into(), ext.into()))
517 }
518
519 #[allow(dead_code)]
522 pub fn remove_extension(&mut self, ext: &str) {
523 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
524 let mut new_exts = Vec::with_capacity(exts.len());
525
526 if let Ok(exts_str) = from_utf8(exts) {
527 new_exts = exts_str
528 .split(',')
529 .filter(|e| e.trim().starts_with(ext))
530 .collect::<Vec<&str>>()
531 .join(",").into();
532 }
533 if new_exts.len() < exts.len() {
534 *exts = new_exts
535 }
536 }
537 }
538
539 pub fn parse(buf: &[u8]) -> Result<Option<Response>> {
542 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
543 let mut res = httparse::Response::new(&mut headers);
544
545 let parsed = try!(res.parse(buf));
546 if !parsed.is_partial() {
547 Ok(Some(Response {
548 status: res.code.unwrap(),
549 reason: res.reason.unwrap().into(),
550 headers: res.headers.iter().map(|h| (h.name.into(), h.value.into())).collect(),
551 }))
552 } else {
553 Ok(None)
554 }
555 }
556
557 pub fn from_request(req: &Request) -> Result<Response> {
561 let res = Response {
562 status: 101,
563 reason: "Switching Protocols".into(),
564 headers: vec![
565 ("Connection".into(), "Upgrade".into()),
566 ("Sec-WebSocket-Accept".into(), try!(req.hashed_key()).into()),
567 ("Upgrade".into(), "websocket".into()),
568 ],
569 };
570
571 debug!("Built response from request:\n{}", res);
572 Ok(res)
573 }
574
575 pub fn format<W>(&self, w: &mut W) -> Result<()>
577 where W: Write
578 {
579 try!(write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason));
580 for &(ref key, ref val) in self.headers.iter() {
581 try!(write!(w, "{}: ", key));
582 try!(w.write(val));
583 try!(write!(w, "\r\n"));
584 }
585 try!(write!(w, "\r\n"));
586 Ok(())
587 }
588}
589
590impl fmt::Display for Response {
591
592 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
593 let mut s = Vec::with_capacity(2048);
594 try!(self.format(&mut s).map_err(|err| {
595 error!("{:?}", err);
596 fmt::Error
597 }));
598 write!(f, "{}", try!(from_utf8(&s).map_err(|err| {
599 error!("Unable to format response as utf8: {:?}", err);
600 fmt::Error
601 })))
602 }
603
604}
605
// Unit tests for client address resolution. `#[cfg(test)]` keeps this module
// (and its imports) out of non-test builds.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use std::io::Write;
    use std::net::SocketAddr;
    use std::str::FromStr;
    use super::*;

    // No forwarding headers: falls back to the peer socket address.
    #[test]
    fn remote_addr() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: Some(SocketAddr::from_str("127.0.0.1:8888").unwrap()),
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1");
    }

    // X-Forwarded-For: the first (originating) entry wins.
    #[test]
    fn remote_addr_x_forwarded_for() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1");
    }

    // Forwarded (RFC 7239): the first `for=` value wins.
    #[test]
    fn remote_addr_forwarded() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();
        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43");
    }
}