1use std::fmt;
2use std::io::Write;
3use std::net::SocketAddr;
4use std::str::from_utf8;
5
6use httparse;
7use rand;
8use sha1::{self, Digest};
9use url;
10
11use result::{Error, Kind, Result};
12
/// GUID that `hash_key` feeds into the hash right after the key bytes.
static WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Alphabet used by `encode_base64`; a byte's index is the 6-bit value it encodes.
static BASE64: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Number of header slots handed to `httparse` when parsing a request or response.
const MAX_HEADERS: usize = 124;
16
17fn generate_key() -> String {
18 let key: [u8; 16] = rand::random();
19 encode_base64(&key)
20}
21
22pub fn hash_key(key: &[u8]) -> String {
23 let mut hasher = sha1::Sha1::new();
24
25 hasher.input(key);
26 hasher.input(WS_GUID.as_bytes());
27
28 encode_base64(&hasher.result())
29}
30
31fn encode_base64(data: &[u8]) -> String {
33 let len = data.len();
34 let mod_len = len % 3;
35
36 let mut encoded = vec![b'='; (len + 2) / 3 * 4];
37 {
38 let mut in_iter = data[..len - mod_len].iter().map(|&c| u32::from(c));
39 let mut out_iter = encoded.iter_mut();
40
41 let enc = |val| BASE64[val as usize];
42 let mut write = |val| *out_iter.next().unwrap() = val;
43
44 while let (Some(one), Some(two), Some(three)) =
45 (in_iter.next(), in_iter.next(), in_iter.next())
46 {
47 let g24 = one << 16 | two << 8 | three;
48 write(enc((g24 >> 18) & 63));
49 write(enc((g24 >> 12) & 63));
50 write(enc((g24 >> 6) & 63));
51 write(enc(g24 & 63));
52 }
53
54 match mod_len {
55 1 => {
56 let pad = (u32::from(data[len - 1])) << 16;
57 write(enc((pad >> 18) & 63));
58 write(enc((pad >> 12) & 63));
59 }
60 2 => {
61 let pad = (u32::from(data[len - 2])) << 16 | (u32::from(data[len - 1])) << 8;
62 write(enc((pad >> 18) & 63));
63 write(enc((pad >> 12) & 63));
64 write(enc((pad >> 6) & 63));
65 }
66 _ => (),
67 }
68 }
69
70 String::from_utf8(encoded).unwrap()
71}
72
73#[derive(Debug)]
75pub struct Handshake {
76 pub request: Request,
78 pub response: Response,
80 pub peer_addr: Option<SocketAddr>,
83 pub local_addr: Option<SocketAddr>,
85}
86
87impl Handshake {
88 #[allow(dead_code)]
101 pub fn remote_addr(&self) -> Result<Option<String>> {
102 Ok(self.request.client_addr()?.map(String::from).or_else(|| {
103 if let Some(addr) = self.peer_addr {
104 Some(addr.ip().to_string())
105 } else {
106 None
107 }
108 }))
109 }
110}
111
/// An HTTP request line plus its headers.
#[derive(Debug)]
pub struct Request {
    /// Request target written on the request line.
    path: String,
    /// HTTP method written on the request line.
    method: String,
    /// Headers in order; names are strings, values are kept as raw bytes.
    headers: Vec<(String, Vec<u8>)>,
}
119
120impl Request {
121 pub fn header(&self, header: &str) -> Option<&Vec<u8>> {
123 self.headers
124 .iter()
125 .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
126 .map(|&(_, ref val)| val)
127 }
128
129 pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
131 self.headers
132 .iter_mut()
133 .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
134 .map(|&mut (_, ref mut val)| val)
135 }
136
137 #[allow(dead_code)]
139 #[inline]
140 pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
141 &self.headers
142 }
143
144 #[allow(dead_code)]
146 #[inline]
147 pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
148 &mut self.headers
149 }
150
151 #[allow(dead_code)]
153 pub fn origin(&self) -> Result<Option<&str>> {
154 if let Some(origin) = self.header("origin") {
155 Ok(Some(from_utf8(origin)?))
156 } else {
157 Ok(None)
158 }
159 }
160
161 pub fn key(&self) -> Result<&Vec<u8>> {
163 self.header("sec-websocket-key")
164 .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
165 }
166
167 pub fn hashed_key(&self) -> Result<String> {
169 Ok(hash_key(self.key()?))
170 }
171
172 #[allow(dead_code)]
174 pub fn version(&self) -> Result<&str> {
175 if let Some(version) = self.header("sec-websocket-version") {
176 from_utf8(version).map_err(Error::from)
177 } else {
178 Err(Error::new(
179 Kind::Protocol,
180 "The Sec-WebSocket-Version header is missing.",
181 ))
182 }
183 }
184
185 #[inline]
187 pub fn method(&self) -> &str {
188 &self.method
189 }
190
191 #[allow(dead_code)]
193 #[inline]
194 pub fn resource(&self) -> &str {
195 &self.path
196 }
197
198 #[allow(dead_code)]
200 pub fn protocols(&self) -> Result<Vec<&str>> {
201 if let Some(protos) = self.header("sec-websocket-protocol") {
202 Ok(from_utf8(protos)?
203 .split(',')
204 .map(|proto| proto.trim())
205 .collect())
206 } else {
207 Ok(Vec::new())
208 }
209 }
210
211 #[allow(dead_code)]
214 pub fn add_protocol(&mut self, protocol: &str) {
215 if let Some(protos) = self.header_mut("sec-websocket-protocol") {
216 protos.push(b","[0]);
217 protos.extend(protocol.as_bytes());
218 return;
219 }
220 self.headers_mut()
221 .push(("Sec-WebSocket-Protocol".into(), protocol.into()))
222 }
223
224 #[allow(dead_code)]
226 pub fn remove_protocol(&mut self, protocol: &str) {
227 if let Some(protos) = self.header_mut("sec-websocket-protocol") {
228 let mut new_protos = Vec::with_capacity(protos.len());
229
230 if let Ok(protos_str) = from_utf8(protos) {
231 new_protos = protos_str
232 .split(',')
233 .filter(|proto| proto.trim() == protocol)
234 .collect::<Vec<&str>>()
235 .join(",")
236 .into();
237 }
238 if new_protos.len() < protos.len() {
239 *protos = new_protos
240 }
241 }
242 }
243
244 #[allow(dead_code)]
246 pub fn extensions(&self) -> Result<Vec<&str>> {
247 if let Some(exts) = self.header("sec-websocket-extensions") {
248 Ok(from_utf8(exts)?.split(',').map(|ext| ext.trim()).collect())
249 } else {
250 Ok(Vec::new())
251 }
252 }
253
254 #[allow(dead_code)]
259 pub fn add_extension(&mut self, ext: &str) {
260 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
261 exts.push(b","[0]);
262 exts.extend(ext.as_bytes());
263 return;
264 }
265 self.headers_mut()
266 .push(("Sec-WebSocket-Extensions".into(), ext.into()))
267 }
268
269 #[allow(dead_code)]
272 pub fn remove_extension(&mut self, ext: &str) {
273 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
274 let mut new_exts = Vec::with_capacity(exts.len());
275
276 if let Ok(exts_str) = from_utf8(exts) {
277 new_exts = exts_str
278 .split(',')
279 .filter(|e| e.trim().starts_with(ext))
280 .collect::<Vec<&str>>()
281 .join(",")
282 .into();
283 }
284 if new_exts.len() < exts.len() {
285 *exts = new_exts
286 }
287 }
288 }
289
290 #[allow(dead_code)]
305 pub fn client_addr(&self) -> Result<Option<&str>> {
306 if let Some(x_forward) = self.header("x-forwarded-for") {
307 return Ok(from_utf8(x_forward)?.split(',').next());
308 }
309
310 if let Some(forward) = self.header("forwarded") {
312 if let Some(_for) = from_utf8(forward)?
313 .split(';')
314 .find(|f| f.trim().starts_with("for"))
315 {
316 if let Some(_for_eq) = _for.trim().split(',').next() {
317 let mut it = _for_eq.split('=');
318 it.next();
319 return Ok(it.next());
320 }
321 }
322 }
323 Ok(None)
324 }
325
326 pub fn parse(buf: &[u8]) -> Result<Option<Request>> {
329 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
330 let mut req = httparse::Request::new(&mut headers);
331 let parsed = req.parse(buf)?;
332 if !parsed.is_partial() {
333 Ok(Some(Request {
334 path: req.path.unwrap().into(),
335 method: req.method.unwrap().into(),
336 headers: req.headers
337 .iter()
338 .map(|h| (h.name.into(), h.value.into()))
339 .collect(),
340 }))
341 } else {
342 Ok(None)
343 }
344 }
345
346 pub fn from_url(url: &url::Url) -> Result<Request> {
348 let query = if let Some(q) = url.query() {
349 format!("?{}", q)
350 } else {
351 "".into()
352 };
353
354 let mut headers = vec![
355 ("Connection".into(), "Upgrade".into()),
356 (
357 "Host".into(),
358 format!(
359 "{}:{}",
360 url.host_str().ok_or_else(|| Error::new(
361 Kind::Internal,
362 "No host passed for WebSocket connection.",
363 ))?,
364 url.port_or_known_default().unwrap_or(80)
365 ).into(),
366 ),
367 ("Sec-WebSocket-Version".into(), "13".into()),
368 ("Sec-WebSocket-Key".into(), generate_key().into()),
369 ("Upgrade".into(), "websocket".into()),
370 ];
371
372 if url.password().is_some() || url.username() != "" {
373 let basic = encode_base64(format!("{}:{}", url.username(), url.password().unwrap_or("")).as_bytes());
374 headers.push(("Authorization".into(), format!("Basic {}", basic).into()))
375 }
376
377 let req = Request {
378 path: format!("{}{}", url.path(), query),
379 method: "GET".to_owned(),
380 headers: headers,
381 };
382
383 debug!("Built request from URL:\n{}", req);
384
385 Ok(req)
386 }
387
388 pub fn format<W>(&self, w: &mut W) -> Result<()>
390 where
391 W: Write,
392 {
393 write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path)?;
394 for &(ref key, ref val) in &self.headers {
395 write!(w, "{}: ", key)?;
396 w.write_all(val)?;
397 write!(w, "\r\n")?;
398 }
399 write!(w, "\r\n")?;
400 Ok(())
401 }
402}
403
404impl fmt::Display for Request {
405 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
406 let mut s = Vec::with_capacity(2048);
407 self.format(&mut s).map_err(|err| {
408 error!("{:?}", err);
409 fmt::Error
410 })?;
411 write!(
412 f,
413 "{}",
414 from_utf8(&s).map_err(|err| {
415 error!("Unable to format request as utf8: {:?}", err);
416 fmt::Error
417 })?
418 )
419 }
420}
421
/// An HTTP status line, headers and body.
#[derive(Debug)]
pub struct Response {
    /// Numeric status code written on the status line.
    status: u16,
    /// Reason phrase written after the status code.
    reason: String,
    /// Headers in order; names are strings, values are kept as raw bytes.
    headers: Vec<(String, Vec<u8>)>,
    /// Body bytes written after the blank line that ends the headers.
    body: Vec<u8>,
}
430
431impl Response {
432 pub fn new<R>(status: u16, reason: R, body: Vec<u8>) -> Response
436 where
437 R: Into<String>,
438 {
439 Response {
440 status,
441 reason: reason.into(),
442 headers: vec![("Content-Length".into(), body.len().to_string().into())],
443 body,
444 }
445 }
446
447 #[inline]
449 pub fn body(&self) -> &[u8] {
450 &self.body
451 }
452
453 fn header(&self, header: &str) -> Option<&Vec<u8>> {
455 self.headers
456 .iter()
457 .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
458 .map(|&(_, ref val)| val)
459 }
460 pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
462 self.headers
463 .iter_mut()
464 .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
465 .map(|&mut (_, ref mut val)| val)
466 }
467
468 #[allow(dead_code)]
470 #[inline]
471 pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
472 &self.headers
473 }
474
475 #[allow(dead_code)]
477 #[inline]
478 pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
479 &mut self.headers
480 }
481
482 #[allow(dead_code)]
484 #[inline]
485 pub fn status(&self) -> u16 {
486 self.status
487 }
488
489 #[allow(dead_code)]
491 #[inline]
492 pub fn set_status(&mut self, status: u16) {
493 self.status = status
494 }
495
496 #[allow(dead_code)]
498 #[inline]
499 pub fn reason(&self) -> &str {
500 &self.reason
501 }
502
503 #[allow(dead_code)]
505 #[inline]
506 pub fn set_reason<R>(&mut self, reason: R)
507 where
508 R: Into<String>,
509 {
510 self.reason = reason.into()
511 }
512
513 pub fn key(&self) -> Result<&Vec<u8>> {
515 self.header("sec-websocket-accept")
516 .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
517 }
518
519 #[allow(dead_code)]
521 pub fn protocol(&self) -> Result<Option<&str>> {
522 if let Some(proto) = self.header("sec-websocket-protocol") {
523 Ok(Some(from_utf8(proto)?))
524 } else {
525 Ok(None)
526 }
527 }
528
529 #[allow(dead_code)]
531 pub fn set_protocol(&mut self, protocol: &str) {
532 if let Some(proto) = self.header_mut("sec-websocket-protocol") {
533 *proto = protocol.into();
534 return;
535 }
536 self.headers_mut()
537 .push(("Sec-WebSocket-Protocol".into(), protocol.into()))
538 }
539
540 #[allow(dead_code)]
543 pub fn extensions(&self) -> Result<Vec<&str>> {
544 if let Some(exts) = self.header("sec-websocket-extensions") {
545 Ok(from_utf8(exts)?
546 .split(',')
547 .map(|proto| proto.trim())
548 .collect())
549 } else {
550 Ok(Vec::new())
551 }
552 }
553
554 #[allow(dead_code)]
557 pub fn add_extension(&mut self, ext: &str) {
558 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
559 exts.push(b","[0]);
560 exts.extend(ext.as_bytes());
561 return;
562 }
563 self.headers_mut()
564 .push(("Sec-WebSocket-Extensions".into(), ext.into()))
565 }
566
567 #[allow(dead_code)]
570 pub fn remove_extension(&mut self, ext: &str) {
571 if let Some(exts) = self.header_mut("sec-websocket-extensions") {
572 let mut new_exts = Vec::with_capacity(exts.len());
573
574 if let Ok(exts_str) = from_utf8(exts) {
575 new_exts = exts_str
576 .split(',')
577 .filter(|e| e.trim().starts_with(ext))
578 .collect::<Vec<&str>>()
579 .join(",")
580 .into();
581 }
582 if new_exts.len() < exts.len() {
583 *exts = new_exts
584 }
585 }
586 }
587
588 pub fn parse(buf: &[u8]) -> Result<Option<Response>> {
591 let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
592 let mut res = httparse::Response::new(&mut headers);
593
594 let parsed = res.parse(buf)?;
595 if !parsed.is_partial() {
596 Ok(Some(Response {
597 status: res.code.unwrap(),
598 reason: res.reason.unwrap().into(),
599 headers: res.headers
600 .iter()
601 .map(|h| (h.name.into(), h.value.into()))
602 .collect(),
603 body: Vec::new(),
604 }))
605 } else {
606 Ok(None)
607 }
608 }
609
610 pub fn from_request(req: &Request) -> Result<Response> {
614 let res = Response {
615 status: 101,
616 reason: "Switching Protocols".into(),
617 headers: vec![
618 ("Connection".into(), "Upgrade".into()),
619 ("Sec-WebSocket-Accept".into(), req.hashed_key()?.into()),
620 ("Upgrade".into(), "websocket".into()),
621 ],
622 body: Vec::new(),
623 };
624
625 debug!("Built response from request:\n{}", res);
626 Ok(res)
627 }
628
629 pub fn format<W>(&self, w: &mut W) -> Result<()>
631 where
632 W: Write,
633 {
634 write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason)?;
635 for &(ref key, ref val) in &self.headers {
636 write!(w, "{}: ", key)?;
637 w.write_all(val)?;
638 write!(w, "\r\n")?;
639 }
640 write!(w, "\r\n")?;
641 w.write_all(&self.body)?;
642 Ok(())
643 }
644}
645
646impl fmt::Display for Response {
647 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
648 let mut s = Vec::with_capacity(2048);
649 self.format(&mut s).map_err(|err| {
650 error!("{:?}", err);
651 fmt::Error
652 })?;
653 write!(
654 f,
655 "{}",
656 from_utf8(&s).map_err(|err| {
657 error!("Unable to format response as utf8: {:?}", err);
658 fmt::Error
659 })?
660 )
661 }
662}
663
// Gated on `cfg(test)` so the module and its imports are not compiled into
// regular builds.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use std::io::Write;
    use std::net::SocketAddr;
    use std::str::FromStr;

    /// Build a handshake from a minimal upgrade request.
    ///
    /// `extra_headers` is spliced in verbatim between the `Upgrade` and
    /// `Sec-WebSocket-Version` headers; it must be empty or consist of
    /// complete header lines, each ending in `\r\n`.
    fn handshake_with(extra_headers: &str, peer_addr: Option<SocketAddr>) -> Handshake {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
             Connection: Upgrade\r\n\
             Upgrade: websocket\r\n\
             {}\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n",
            extra_headers
        ).unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        Handshake {
            request: req,
            response: res,
            peer_addr,
            local_addr: None,
        }
    }

    /// Without forwarding headers the peer's IP is reported.
    #[test]
    fn remote_addr() {
        let peer = SocketAddr::from_str("127.0.0.1:8888").unwrap();
        let shake = handshake_with("", Some(peer));
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1");
    }

    /// The first entry of `X-Forwarded-For` is reported.
    #[test]
    fn remote_addr_x_forwarded_for() {
        let shake = handshake_with(
            "X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n",
            None,
        );
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1");
    }

    /// The first `for=` entry of `Forwarded` is reported.
    #[test]
    fn remote_addr_forwarded() {
        let shake = handshake_with(
            "Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n",
            None,
        );
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43");
    }
}