wsq/
handshake.rs

1use std::fmt;
2use std::io::Write;
3use std::net::SocketAddr;
4use std::str::from_utf8;
5
6use httparse;
7use rand;
8use sha1::{self, Digest};
9use url;
10
11use result::{Error, Kind, Result};
12
// GUID appended to the client's key before SHA-1 hashing in `hash_key`
// (the fixed value from the WebSocket opening handshake, RFC 6455).
static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// Alphabet used by `encode_base64`: index N holds the character for the 6-bit value N.
static BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
// Size of the header scratch array handed to httparse when parsing a request or response.
const MAX_HEADERS: usize = 124;
16
17fn generate_key() -> String {
18    let key: [u8; 16] = rand::random();
19    encode_base64(&key)
20}
21
22pub fn hash_key(key: &[u8]) -> String {
23    let mut hasher = sha1::Sha1::new();
24
25    hasher.input(key);
26    hasher.input(WS_GUID.as_bytes());
27
28    encode_base64(&hasher.result())
29}
30
31// This code is based on rustc_serialize base64 STANDARD
32fn encode_base64(data: &[u8]) -> String {
33    let len = data.len();
34    let mod_len = len % 3;
35
36    let mut encoded = vec![b'='; (len + 2) / 3 * 4];
37    {
38        let mut in_iter = data[..len - mod_len].iter().map(|&c| u32::from(c));
39        let mut out_iter = encoded.iter_mut();
40
41        let enc = |val| BASE64[val as usize];
42        let mut write = |val| *out_iter.next().unwrap() = val;
43
44        while let (Some(one), Some(two), Some(three)) =
45            (in_iter.next(), in_iter.next(), in_iter.next())
46        {
47            let g24 = one << 16 | two << 8 | three;
48            write(enc((g24 >> 18) & 63));
49            write(enc((g24 >> 12) & 63));
50            write(enc((g24 >> 6) & 63));
51            write(enc(g24 & 63));
52        }
53
54        match mod_len {
55            1 => {
56                let pad = (u32::from(data[len - 1])) << 16;
57                write(enc((pad >> 18) & 63));
58                write(enc((pad >> 12) & 63));
59            }
60            2 => {
61                let pad = (u32::from(data[len - 2])) << 16 | (u32::from(data[len - 1])) << 8;
62                write(enc((pad >> 18) & 63));
63                write(enc((pad >> 12) & 63));
64                write(enc((pad >> 6) & 63));
65            }
66            _ => (),
67        }
68    }
69
70    String::from_utf8(encoded).unwrap()
71}
72
/// A struct representing the two halves of the WebSocket handshake:
/// the HTTP request, the HTTP response, and the socket addresses involved.
#[derive(Debug)]
pub struct Handshake {
    /// The HTTP request sent to begin the handshake.
    pub request: Request,
    /// The HTTP response from the server confirming the handshake.
    pub response: Response,
    /// The socket address of the other endpoint. This address may
    /// be an intermediary such as a proxy server. `None` when no
    /// address was recorded.
    pub peer_addr: Option<SocketAddr>,
    /// The socket address of this endpoint. `None` when no address
    /// was recorded.
    pub local_addr: Option<SocketAddr>,
}
86
87impl Handshake {
88    /// Get the IP address of the remote connection.
89    ///
90    /// This is the preferred method of obtaining the client's IP address.
91    /// It will attempt to retrieve the most likely IP address based on request
92    /// headers, falling back to the address of the peer.
93    ///
94    /// # Note
95    /// This assumes that the peer is a client. If you are implementing a
96    /// WebSocket client and want to obtain the address of the server, use
97    /// `Handshake::peer_addr` instead.
98    ///
99    /// This method does not ensure that the address is a valid IP address.
100    #[allow(dead_code)]
101    pub fn remote_addr(&self) -> Result<Option<String>> {
102        Ok(self.request.client_addr()?.map(String::from).or_else(|| {
103            if let Some(addr) = self.peer_addr {
104                Some(addr.ip().to_string())
105            } else {
106                None
107            }
108        }))
109    }
110}
111
/// The handshake request.
#[derive(Debug)]
pub struct Request {
    // Request target written on the request line (URL path plus `?query` when built from a URL).
    path: String,
    // HTTP method written on the request line ("GET" for requests built by `from_url`).
    method: String,
    // Header name/value pairs in order; values are kept as raw bytes and may not be UTF-8.
    headers: Vec<(String, Vec<u8>)>,
}
119
120impl Request {
121    /// Get the value of the first instance of an HTTP header.
122    pub fn header(&self, header: &str) -> Option<&Vec<u8>> {
123        self.headers
124            .iter()
125            .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
126            .map(|&(_, ref val)| val)
127    }
128
129    /// Edit the value of the first instance of an HTTP header.
130    pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
131        self.headers
132            .iter_mut()
133            .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
134            .map(|&mut (_, ref mut val)| val)
135    }
136
137    /// Access the request headers.
138    #[allow(dead_code)]
139    #[inline]
140    pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
141        &self.headers
142    }
143
144    /// Edit the request headers.
145    #[allow(dead_code)]
146    #[inline]
147    pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
148        &mut self.headers
149    }
150
151    /// Get the origin of the request if it comes from a browser.
152    #[allow(dead_code)]
153    pub fn origin(&self) -> Result<Option<&str>> {
154        if let Some(origin) = self.header("origin") {
155            Ok(Some(from_utf8(origin)?))
156        } else {
157            Ok(None)
158        }
159    }
160
161    /// Get the unhashed WebSocket key sent in the request.
162    pub fn key(&self) -> Result<&Vec<u8>> {
163        self.header("sec-websocket-key")
164            .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
165    }
166
167    /// Get the hashed WebSocket key from this request.
168    pub fn hashed_key(&self) -> Result<String> {
169        Ok(hash_key(self.key()?))
170    }
171
172    /// Get the WebSocket protocol version from the request (should be 13).
173    #[allow(dead_code)]
174    pub fn version(&self) -> Result<&str> {
175        if let Some(version) = self.header("sec-websocket-version") {
176            from_utf8(version).map_err(Error::from)
177        } else {
178            Err(Error::new(
179                Kind::Protocol,
180                "The Sec-WebSocket-Version header is missing.",
181            ))
182        }
183    }
184
185    /// Get the request method.
186    #[inline]
187    pub fn method(&self) -> &str {
188        &self.method
189    }
190
191    /// Get the path of the request.
192    #[allow(dead_code)]
193    #[inline]
194    pub fn resource(&self) -> &str {
195        &self.path
196    }
197
198    /// Get the possible protocols for the WebSocket connection.
199    #[allow(dead_code)]
200    pub fn protocols(&self) -> Result<Vec<&str>> {
201        if let Some(protos) = self.header("sec-websocket-protocol") {
202            Ok(from_utf8(protos)?
203                .split(',')
204                .map(|proto| proto.trim())
205                .collect())
206        } else {
207            Ok(Vec::new())
208        }
209    }
210
211    /// Add a possible protocol to this request.
212    /// This may result in duplicate protocols listed.
213    #[allow(dead_code)]
214    pub fn add_protocol(&mut self, protocol: &str) {
215        if let Some(protos) = self.header_mut("sec-websocket-protocol") {
216            protos.push(b","[0]);
217            protos.extend(protocol.as_bytes());
218            return;
219        }
220        self.headers_mut()
221            .push(("Sec-WebSocket-Protocol".into(), protocol.into()))
222    }
223
224    /// Remove a possible protocol from this request.
225    #[allow(dead_code)]
226    pub fn remove_protocol(&mut self, protocol: &str) {
227        if let Some(protos) = self.header_mut("sec-websocket-protocol") {
228            let mut new_protos = Vec::with_capacity(protos.len());
229
230            if let Ok(protos_str) = from_utf8(protos) {
231                new_protos = protos_str
232                    .split(',')
233                    .filter(|proto| proto.trim() == protocol)
234                    .collect::<Vec<&str>>()
235                    .join(",")
236                    .into();
237            }
238            if new_protos.len() < protos.len() {
239                *protos = new_protos
240            }
241        }
242    }
243
244    /// Get the possible extensions for the WebSocket connection.
245    #[allow(dead_code)]
246    pub fn extensions(&self) -> Result<Vec<&str>> {
247        if let Some(exts) = self.header("sec-websocket-extensions") {
248            Ok(from_utf8(exts)?.split(',').map(|ext| ext.trim()).collect())
249        } else {
250            Ok(Vec::new())
251        }
252    }
253
254    /// Add a possible extension to this request.
255    /// This may result in duplicate extensions listed. Also, the order of extensions
256    /// indicates preference, so if the preference matters, consider using the
257    /// `Sec-WebSocket-Protocol` header directly.
258    #[allow(dead_code)]
259    pub fn add_extension(&mut self, ext: &str) {
260        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
261            exts.push(b","[0]);
262            exts.extend(ext.as_bytes());
263            return;
264        }
265        self.headers_mut()
266            .push(("Sec-WebSocket-Extensions".into(), ext.into()))
267    }
268
269    /// Remove a possible extension from this request.
270    /// This will remove all configurations of the extension.
271    #[allow(dead_code)]
272    pub fn remove_extension(&mut self, ext: &str) {
273        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
274            let mut new_exts = Vec::with_capacity(exts.len());
275
276            if let Ok(exts_str) = from_utf8(exts) {
277                new_exts = exts_str
278                    .split(',')
279                    .filter(|e| e.trim().starts_with(ext))
280                    .collect::<Vec<&str>>()
281                    .join(",")
282                    .into();
283            }
284            if new_exts.len() < exts.len() {
285                *exts = new_exts
286            }
287        }
288    }
289
290    /// Get the IP address of the client.
291    ///
292    /// This method will attempt to retrieve the most likely IP address of the requester
293    /// in the following manner:
294    ///
295    /// If the `X-Forwarded-For` header exists, this method will return the left most
296    /// address in the list.
297    ///
298    /// If the [Forwarded HTTP Header Field](https://tools.ietf.org/html/rfc7239) exits,
299    /// this method will return the left most address indicated by the `for` parameter,
300    /// if it exists.
301    ///
302    /// # Note
303    /// This method does not ensure that the address is a valid IP address.
304    #[allow(dead_code)]
305    pub fn client_addr(&self) -> Result<Option<&str>> {
306        if let Some(x_forward) = self.header("x-forwarded-for") {
307            return Ok(from_utf8(x_forward)?.split(',').next());
308        }
309
310        // We only care about the first forwarded header, so header is ok
311        if let Some(forward) = self.header("forwarded") {
312            if let Some(_for) = from_utf8(forward)?
313                .split(';')
314                .find(|f| f.trim().starts_with("for"))
315            {
316                if let Some(_for_eq) = _for.trim().split(',').next() {
317                    let mut it = _for_eq.split('=');
318                    it.next();
319                    return Ok(it.next());
320                }
321            }
322        }
323        Ok(None)
324    }
325
326    /// Attempt to parse an HTTP request from a buffer. If the buffer does not contain a complete
327    /// request, this will return `Ok(None)`.
328    pub fn parse(buf: &[u8]) -> Result<Option<Request>> {
329        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
330        let mut req = httparse::Request::new(&mut headers);
331        let parsed = req.parse(buf)?;
332        if !parsed.is_partial() {
333            Ok(Some(Request {
334                path: req.path.unwrap().into(),
335                method: req.method.unwrap().into(),
336                headers: req.headers
337                    .iter()
338                    .map(|h| (h.name.into(), h.value.into()))
339                    .collect(),
340            }))
341        } else {
342            Ok(None)
343        }
344    }
345
346    /// Construct a new WebSocket handshake HTTP request from a url.
347    pub fn from_url(url: &url::Url) -> Result<Request> {
348        let query = if let Some(q) = url.query() {
349            format!("?{}", q)
350        } else {
351            "".into()
352        };
353
354        let mut headers = vec![
355            ("Connection".into(), "Upgrade".into()),
356            (
357                "Host".into(),
358                format!(
359                    "{}:{}",
360                    url.host_str().ok_or_else(|| Error::new(
361                        Kind::Internal,
362                        "No host passed for WebSocket connection.",
363                    ))?,
364                    url.port_or_known_default().unwrap_or(80)
365                ).into(),
366            ),
367            ("Sec-WebSocket-Version".into(), "13".into()),
368            ("Sec-WebSocket-Key".into(), generate_key().into()),
369            ("Upgrade".into(), "websocket".into()),
370        ];
371
372        if url.password().is_some() || url.username() != "" {
373            let basic = encode_base64(format!("{}:{}", url.username(), url.password().unwrap_or("")).as_bytes());
374            headers.push(("Authorization".into(), format!("Basic {}", basic).into()))
375        }
376
377        let req = Request {
378            path: format!("{}{}", url.path(), query),
379            method: "GET".to_owned(),
380            headers: headers,
381        };
382
383        debug!("Built request from URL:\n{}", req);
384
385        Ok(req)
386    }
387
388    /// Write a request out to a buffer
389    pub fn format<W>(&self, w: &mut W) -> Result<()>
390    where
391        W: Write,
392    {
393        write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path)?;
394        for &(ref key, ref val) in &self.headers {
395            write!(w, "{}: ", key)?;
396            w.write_all(val)?;
397            write!(w, "\r\n")?;
398        }
399        write!(w, "\r\n")?;
400        Ok(())
401    }
402}
403
404impl fmt::Display for Request {
405    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
406        let mut s = Vec::with_capacity(2048);
407        self.format(&mut s).map_err(|err| {
408            error!("{:?}", err);
409            fmt::Error
410        })?;
411        write!(
412            f,
413            "{}",
414            from_utf8(&s).map_err(|err| {
415                error!("Unable to format request as utf8: {:?}", err);
416                fmt::Error
417            })?
418        )
419    }
420}
421
/// The handshake response.
#[derive(Debug)]
pub struct Response {
    // Numeric HTTP status code written on the status line.
    status: u16,
    // Reason phrase written after the status code.
    reason: String,
    // Header name/value pairs in order; values are kept as raw bytes and may not be UTF-8.
    headers: Vec<(String, Vec<u8>)>,
    // Bytes written after the header section; left empty by `parse` and `from_request`.
    body: Vec<u8>,
}
430
431impl Response {
432    // TODO: resolve the overlap with Request
433
434    /// Construct a generic HTTP response with a body.
435    pub fn new<R>(status: u16, reason: R, body: Vec<u8>) -> Response
436    where
437        R: Into<String>,
438    {
439        Response {
440            status,
441            reason: reason.into(),
442            headers: vec![("Content-Length".into(), body.len().to_string().into())],
443            body,
444        }
445    }
446
447    /// Get the response body.
448    #[inline]
449    pub fn body(&self) -> &[u8] {
450        &self.body
451    }
452
453    /// Get the value of the first instance of an HTTP header.
454    fn header(&self, header: &str) -> Option<&Vec<u8>> {
455        self.headers
456            .iter()
457            .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
458            .map(|&(_, ref val)| val)
459    }
460    /// Edit the value of the first instance of an HTTP header.
461    pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
462        self.headers
463            .iter_mut()
464            .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
465            .map(|&mut (_, ref mut val)| val)
466    }
467
468    /// Access the request headers.
469    #[allow(dead_code)]
470    #[inline]
471    pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
472        &self.headers
473    }
474
475    /// Edit the request headers.
476    #[allow(dead_code)]
477    #[inline]
478    pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
479        &mut self.headers
480    }
481
482    /// Get the HTTP status code.
483    #[allow(dead_code)]
484    #[inline]
485    pub fn status(&self) -> u16 {
486        self.status
487    }
488
489    /// Set the HTTP status code.
490    #[allow(dead_code)]
491    #[inline]
492    pub fn set_status(&mut self, status: u16) {
493        self.status = status
494    }
495
496    /// Get the HTTP status reason.
497    #[allow(dead_code)]
498    #[inline]
499    pub fn reason(&self) -> &str {
500        &self.reason
501    }
502
503    /// Set the HTTP status reason.
504    #[allow(dead_code)]
505    #[inline]
506    pub fn set_reason<R>(&mut self, reason: R)
507    where
508        R: Into<String>,
509    {
510        self.reason = reason.into()
511    }
512
513    /// Get the hashed WebSocket key.
514    pub fn key(&self) -> Result<&Vec<u8>> {
515        self.header("sec-websocket-accept")
516            .ok_or_else(|| Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
517    }
518
519    /// Get the protocol that the server has decided to use.
520    #[allow(dead_code)]
521    pub fn protocol(&self) -> Result<Option<&str>> {
522        if let Some(proto) = self.header("sec-websocket-protocol") {
523            Ok(Some(from_utf8(proto)?))
524        } else {
525            Ok(None)
526        }
527    }
528
529    /// Set the protocol that the server has decided to use.
530    #[allow(dead_code)]
531    pub fn set_protocol(&mut self, protocol: &str) {
532        if let Some(proto) = self.header_mut("sec-websocket-protocol") {
533            *proto = protocol.into();
534            return;
535        }
536        self.headers_mut()
537            .push(("Sec-WebSocket-Protocol".into(), protocol.into()))
538    }
539
540    /// Get the extensions that the server has decided to use. If these are unacceptable, it is
541    /// appropriate to send an Extension close code.
542    #[allow(dead_code)]
543    pub fn extensions(&self) -> Result<Vec<&str>> {
544        if let Some(exts) = self.header("sec-websocket-extensions") {
545            Ok(from_utf8(exts)?
546                .split(',')
547                .map(|proto| proto.trim())
548                .collect())
549        } else {
550            Ok(Vec::new())
551        }
552    }
553
554    /// Add an accepted extension to this response.
555    /// This may result in duplicate extensions listed.
556    #[allow(dead_code)]
557    pub fn add_extension(&mut self, ext: &str) {
558        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
559            exts.push(b","[0]);
560            exts.extend(ext.as_bytes());
561            return;
562        }
563        self.headers_mut()
564            .push(("Sec-WebSocket-Extensions".into(), ext.into()))
565    }
566
567    /// Remove an accepted extension from this response.
568    /// This will remove all configurations of the extension.
569    #[allow(dead_code)]
570    pub fn remove_extension(&mut self, ext: &str) {
571        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
572            let mut new_exts = Vec::with_capacity(exts.len());
573
574            if let Ok(exts_str) = from_utf8(exts) {
575                new_exts = exts_str
576                    .split(',')
577                    .filter(|e| e.trim().starts_with(ext))
578                    .collect::<Vec<&str>>()
579                    .join(",")
580                    .into();
581            }
582            if new_exts.len() < exts.len() {
583                *exts = new_exts
584            }
585        }
586    }
587
588    /// Attempt to parse an HTTP response from a buffer. If the buffer does not contain a complete
589    /// response, thiw will return `Ok(None)`.
590    pub fn parse(buf: &[u8]) -> Result<Option<Response>> {
591        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
592        let mut res = httparse::Response::new(&mut headers);
593
594        let parsed = res.parse(buf)?;
595        if !parsed.is_partial() {
596            Ok(Some(Response {
597                status: res.code.unwrap(),
598                reason: res.reason.unwrap().into(),
599                headers: res.headers
600                    .iter()
601                    .map(|h| (h.name.into(), h.value.into()))
602                    .collect(),
603                body: Vec::new(),
604            }))
605        } else {
606            Ok(None)
607        }
608    }
609
610    /// Construct a new WebSocket handshake HTTP response from a request.
611    /// This will create a response that ignores protocols and extensions. Edit this response to
612    /// accept a protocol and extensions as necessary.
613    pub fn from_request(req: &Request) -> Result<Response> {
614        let res = Response {
615            status: 101,
616            reason: "Switching Protocols".into(),
617            headers: vec![
618                ("Connection".into(), "Upgrade".into()),
619                ("Sec-WebSocket-Accept".into(), req.hashed_key()?.into()),
620                ("Upgrade".into(), "websocket".into()),
621            ],
622            body: Vec::new(),
623        };
624
625        debug!("Built response from request:\n{}", res);
626        Ok(res)
627    }
628
629    /// Write a response out to a buffer
630    pub fn format<W>(&self, w: &mut W) -> Result<()>
631    where
632        W: Write,
633    {
634        write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason)?;
635        for &(ref key, ref val) in &self.headers {
636            write!(w, "{}: ", key)?;
637            w.write_all(val)?;
638            write!(w, "\r\n")?;
639        }
640        write!(w, "\r\n")?;
641        w.write_all(&self.body)?;
642        Ok(())
643    }
644}
645
646impl fmt::Display for Response {
647    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
648        let mut s = Vec::with_capacity(2048);
649        self.format(&mut s).map_err(|err| {
650            error!("{:?}", err);
651            fmt::Error
652        })?;
653        write!(
654            f,
655            "{}",
656            from_utf8(&s).map_err(|err| {
657                error!("Unable to format response as utf8: {:?}", err);
658                fmt::Error
659            })?
660        )
661    }
662}
663
// Unit tests for `Handshake::remote_addr`. Gated on `cfg(test)` so the module
// is only compiled when the test harness is built.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use super::*;
    use std::io::Write;
    use std::net::SocketAddr;
    use std::str::FromStr;

    // Without forwarding headers the peer's socket address supplies the IP.
    #[test]
    fn remote_addr() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
             Connection: Upgrade\r\n\
             Upgrade: websocket\r\n\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n"
        ).unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: Some(SocketAddr::from_str("127.0.0.1:8888").unwrap()),
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1");
    }

    // The left-most `X-Forwarded-For` entry is reported.
    #[test]
    fn remote_addr_x_forwarded_for() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
             Connection: Upgrade\r\n\
             Upgrade: websocket\r\n\
             X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n"
        ).unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1");
    }

    // The left-most `for=` parameter of the `Forwarded` header is reported.
    #[test]
    fn remote_addr_forwarded() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n")
            .unwrap();
        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43");
    }
}