websocket_simple/
handshake.rs

1use std::fmt;
2use std::io::Write;
3use std::mem::transmute;
4use std::str::from_utf8;
5use std::net::SocketAddr;
6
7use sha1;
8use rand;
9use url;
10use httparse;
11
12use result::{Result, Error, Kind};
13
/// GUID defined by RFC 6455 that is appended to the client's `Sec-WebSocket-Key`
/// before hashing (see `hash_key`).
const WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Base64 alphabet indexed by 6-bit value (`+` and `/` for 62 and 63).
const BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Upper bound on the number of headers `httparse` may parse from a handshake
/// request or response.
const MAX_HEADERS: usize = 124;
17
18fn generate_key() -> String {
19    let key: [u8; 16] = unsafe {
20        transmute(rand::random::<(u64, u64)>())
21    };
22    encode_base64(&key)
23}
24
25pub fn hash_key(key: &[u8]) -> String {
26    let mut hasher = sha1::Sha1::new();
27
28    hasher.update(key);
29    hasher.update(WS_GUID.as_bytes());
30
31    encode_base64(&hasher.digest().bytes())
32}
33
34// This code is based on rustc_serialize base64 STANDARD
35fn encode_base64(data: &[u8]) -> String {
36    let len = data.len();
37    let mod_len = len % 3;
38
39    let mut encoded = vec![b'='; (len + 2) / 3 * 4];
40    {
41        let mut in_iter = data[..len - mod_len].iter().map(|&c| c as u32);
42        let mut out_iter = encoded.iter_mut();
43
44        let enc = |val| BASE64[val as usize];
45        let mut write = |val| *out_iter.next().unwrap() = val;
46
47        while let (Some(one), Some(two), Some(three)) = (in_iter.next(), in_iter.next(), in_iter.next()) {
48            let g24 = one << 16 | two << 8 | three;
49            write(enc((g24 >> 18) & 63));
50            write(enc((g24 >> 12) & 63));
51            write(enc((g24 >> 6 ) & 63));
52            write(enc(g24 & 63));
53        }
54
55        match mod_len {
56            1 => {
57                let pad = (data[len-1] as u32) << 16;
58                write(enc((pad >> 18) & 63));
59                write(enc((pad >> 12) & 63));
60            }
61            2 => {
62                let pad = (data[len-2] as u32) << 16 | (data[len-1] as u32) << 8;
63                write(enc((pad >> 18) & 63));
64                write(enc((pad >> 12) & 63));
65                write(enc((pad >> 6) & 63));
66            }
67            _ => (),
68        }
69    }
70
71    String::from_utf8(encoded).unwrap()
72}
73
/// A struct representing the two halves of the WebSocket handshake.
#[derive(Debug)]
pub struct Handshake {
    /// The HTTP request sent to begin the handshake.
    pub request: Request,
    /// The HTTP response from the server confirming the handshake.
    pub response: Response,
    /// The socket address of the other endpoint. This address may
    /// be an intermediary such as a proxy server.
    ///
    /// `remote_addr` falls back to this address when the request carries no
    /// forwarding headers.
    pub peer_addr: Option<SocketAddr>,
    /// The socket address of this endpoint.
    pub local_addr: Option<SocketAddr>,
}
87
88impl Handshake {
89
90    /// Get the IP address of the remote connection.
91    ///
92    /// This is the preferred method of obtaining the client's IP address.
93    /// It will attempt to retrieve the most likely IP address based on request
94    /// headers, falling back to the address of the peer.
95    ///
96    /// # Note
97    /// This assumes that the peer is a client. If you are implementing a
98    /// WebSocket client and want to obtain the address of the server, use
99    /// `Handshake::peer_addr` instead.
100    ///
101    /// This method does not ensure that the address is a valid IP address.
102    #[allow(dead_code)]
103    pub fn remote_addr(&self) -> Result<Option<String>> {
104        Ok(try!(self.request.client_addr()).map(String::from).or_else(|| {
105            if let Some(addr) = self.peer_addr {
106                Some(addr.ip().to_string())
107            } else {
108                None
109            }
110        }))
111    }
112
113}
114
115
/// The handshake request.
///
/// Fields are private; use the accessor methods on `Request`.
#[derive(Debug)]
pub struct Request {
    /// Request target: the URL path plus an optional `?query` (see `from_url`).
    path: String,
    /// HTTP method, e.g. `GET` as set by `from_url`.
    method: String,
    /// Headers in the order they were parsed or added, as `(name, raw value)` pairs.
    /// Names may repeat; `header`/`header_mut` return the first case-insensitive match.
    headers: Vec<(String, Vec<u8>)>,
}
123
124impl Request {
125
126    /// Get the value of the first instance of an HTTP header.
127    pub fn header(&self, header: &str) -> Option<&Vec<u8>> {
128        self.headers
129            .iter()
130            .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
131            .map(|&(_, ref val)| val)
132    }
133
134    /// Edit the value of the first instance of an HTTP header.
135    pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
136        self.headers
137            .iter_mut()
138            .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
139            .map(|&mut (_, ref mut val)| val)
140    }
141
142    /// Access the request headers.
143    #[allow(dead_code)]
144    #[inline]
145    pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
146        &self.headers
147    }
148
149    /// Edit the request headers.
150    #[allow(dead_code)]
151    #[inline]
152    pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
153        &mut self.headers
154    }
155
156    /// Get the origin of the request if it comes from a browser.
157    #[allow(dead_code)]
158    pub fn origin(&self) -> Result<Option<&str>> {
159        if let Some(origin) = self.header("origin") {
160            Ok(Some(try!(from_utf8(origin))))
161        } else {
162            Ok(None)
163        }
164    }
165
166    /// Get the unhashed WebSocket key sent in the request.
167    pub fn key(&self) -> Result<&Vec<u8>> {
168        self.header("sec-websocket-key")
169            .ok_or(Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
170    }
171
172    /// Get the hashed WebSocket key from this request.
173    pub fn hashed_key(&self) -> Result<String> {
174        Ok(hash_key(try!(self.key())))
175    }
176
177    /// Get the WebSocket protocol version from the request (should be 13).
178    #[allow(dead_code)]
179    pub fn version(&self) -> Result<&str> {
180        if let Some(version) = self.header("sec-websocket-version") {
181            from_utf8(version).map_err(Error::from)
182        } else {
183            Err(Error::new(Kind::Protocol, "The Sec-WebSocket-Version header is missing."))
184        }
185    }
186
187    /// Get the path of the request.
188    #[allow(dead_code)]
189    #[inline]
190    pub fn resource(&self) -> &str {
191        &self.path
192    }
193
194    /// Get the possible protocols for the WebSocket connection.
195    #[allow(dead_code)]
196    pub fn protocols(&self) -> Result<Vec<&str>> {
197        if let Some(protos) = self.header("sec-websocket-protocol") {
198            Ok(try!(from_utf8(protos)).split(',').map(|proto| proto.trim()).collect())
199        } else {
200            Ok(Vec::new())
201        }
202    }
203
204    /// Add a possible protocol to this request.
205    /// This may result in duplicate protocols listed.
206    #[allow(dead_code)]
207    pub fn add_protocol(&mut self, protocol: &str) {
208        if let Some(protos) = self.header_mut("sec-websocket-protocol") {
209            protos.push(b","[0]);
210            protos.extend(protocol.as_bytes());
211            return
212        }
213        self.headers_mut().push(("Sec-WebSocket-Protocol".into(), protocol.into()))
214    }
215
216    /// Remove a possible protocol from this request.
217    #[allow(dead_code)]
218    pub fn remove_protocol(&mut self, protocol: &str) {
219        if let Some(protos) = self.header_mut("sec-websocket-protocol") {
220            let mut new_protos = Vec::with_capacity(protos.len());
221
222            if let Ok(protos_str) = from_utf8(protos) {
223                new_protos = protos_str
224                    .split(',')
225                    .filter(|proto| proto.trim() == protocol)
226                    .collect::<Vec<&str>>()
227                    .join(",").into();
228            }
229            if new_protos.len() < protos.len() {
230                *protos = new_protos
231            }
232        }
233    }
234
235    /// Get the possible extensions for the WebSocket connection.
236    #[allow(dead_code)]
237    pub fn extensions(&self) -> Result<Vec<&str>> {
238        if let Some(exts) = self.header("sec-websocket-extensions") {
239            Ok(try!(from_utf8(exts)).split(',').map(|ext| ext.trim()).collect())
240        } else {
241            Ok(Vec::new())
242        }
243    }
244
245    /// Add a possible extension to this request.
246    /// This may result in duplicate extensions listed. Also, the order of extensions
247    /// indicates preference, so if the preference matters, consider using the
248    /// `Sec-WebSocket-Protocol` header directly.
249    #[allow(dead_code)]
250    pub fn add_extension(&mut self, ext: &str) {
251        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
252            exts.push(b","[0]);
253            exts.extend(ext.as_bytes());
254            return
255        }
256        self.headers_mut().push(("Sec-WebSocket-Extensions".into(), ext.into()))
257    }
258
259    /// Remove a possible extension from this request.
260    /// This will remove all configurations of the extension.
261    #[allow(dead_code)]
262    pub fn remove_extension(&mut self, ext: &str) {
263        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
264            let mut new_exts = Vec::with_capacity(exts.len());
265
266            if let Ok(exts_str) = from_utf8(exts) {
267                new_exts = exts_str
268                    .split(',')
269                    .filter(|e| e.trim().starts_with(ext))
270                    .collect::<Vec<&str>>()
271                    .join(",").into();
272            }
273            if new_exts.len() < exts.len() {
274                *exts = new_exts
275            }
276        }
277    }
278
279    /// Get the IP address of the client.
280    ///
281    /// This method will attempt to retrieve the most likely IP address of the requester
282    /// in the following manner:
283    ///
284    /// If the `X-Forwarded-For` header exists, this method will return the left most
285    /// address in the list.
286    ///
287    /// If the [Forwarded HTTP Header Field](https://tools.ietf.org/html/rfc7239) exits,
288    /// this method will return the left most address indicated by the `for` parameter,
289    /// if it exists.
290    ///
291    /// # Note
292    /// This method does not ensure that the address is a valid IP address.
293    #[allow(dead_code)]
294    pub fn client_addr(&self) -> Result<Option<&str>> {
295        if let Some(x_forward) = self.header("x-forwarded-for") {
296            return Ok(try!(from_utf8(x_forward)).split(',').next())
297        }
298
299        // We only care about the first forwarded header, so header is ok
300        if let Some(forward) = self.header("forwarded") {
301            if let Some(_for) = try!(from_utf8(forward))
302                .split(';')
303                .find(|f| f.trim().starts_with("for"))
304            {
305                if let Some(_for_eq) = _for.trim().split(',').next() {
306                    let mut it = _for_eq.split('=');
307                    it.next();
308                    return Ok(it.next())
309                }
310            }
311        }
312        Ok(None)
313    }
314
315    /// Attempt to parse an HTTP request from a buffer. If the buffer does not contain a complete
316    /// request, this will return `Ok(None)`.
317    pub fn parse(buf: &[u8]) -> Result<Option<Request>> {
318        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
319        let mut req = httparse::Request::new(&mut headers);
320        let parsed = try!(req.parse(buf));
321        if !parsed.is_partial() {
322            Ok(Some(Request {
323                path: req.path.unwrap().into(),
324                method: req.method.unwrap().into(),
325                headers: req.headers.iter().map(|h| (h.name.into(), h.value.into())).collect(),
326            }))
327        } else {
328            Ok(None)
329        }
330    }
331
332    /// Construct a new WebSocket handshake HTTP request from a url.
333    pub fn from_url(url: &url::Url) -> Result<Request> {
334
335        let query = if let Some(q) = url.query() {
336            format!("?{}", q)
337        } else {
338            "".into()
339        };
340
341        let req = Request {
342            path: format!(
343                "{}{}",
344                url.path(),
345                query),
346            method: "GET".to_owned(),
347            headers: vec![
348                ("Connection".into(), "Upgrade".into()),
349                (
350                    "Host".into(),
351                    format!(
352                        "{}:{}",
353                        try!(url.host_str().ok_or(
354                            Error::new(Kind::Internal, "No host passed for WebSocket connection."))),
355                        url.port_or_known_default().unwrap_or(80)).into(),
356                ),
357                ("Sec-WebSocket-Version".into(), "13".into()),
358                ("Sec-WebSocket-Key".into(), generate_key().into()),
359                ("Upgrade".into(), "websocket".into()),
360            ],
361        };
362
363        debug!("Built request from URL:\n{}", req);
364
365        Ok(req)
366    }
367
368    /// Write a request out to a buffer
369    pub fn format<W>(&self, w: &mut W) -> Result<()>
370        where W: Write,
371    {
372        try!(write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path));
373        for &(ref key, ref val) in self.headers.iter() {
374            try!(write!(w, "{}: ", key));
375            try!(w.write(val));
376            try!(write!(w, "\r\n"));
377        }
378        try!(write!(w, "\r\n"));
379        Ok(())
380    }
381
382}
383
384impl fmt::Display for Request {
385
386    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
387        let mut s = Vec::with_capacity(2048);
388        try!(self.format(&mut s).map_err(|err| {
389            error!("{:?}", err);
390            fmt::Error
391        }));
392        write!(f, "{}", try!(from_utf8(&s).map_err(|err| {
393            error!("Unable to format request as utf8: {:?}", err);
394            fmt::Error
395        })))
396    }
397
398}
399
/// The handshake response.
///
/// Fields are private; use the accessor methods on `Response`.
#[derive(Debug)]
pub struct Response {
    /// HTTP status code, e.g. `101` as set by `from_request`.
    status: u16,
    /// HTTP reason phrase, e.g. `Switching Protocols` as set by `from_request`.
    reason: String,
    /// Headers in the order they were parsed or added, as `(name, raw value)` pairs;
    /// looked up case-insensitively by `header`/`header_mut`.
    headers: Vec<(String, Vec<u8>)>,
}
407
408impl Response {
409    // TODO: resolve the overlap with Request
410
411    /// Get the value of the first instance of an HTTP header.
412    fn header(&self, header: &str) -> Option<&Vec<u8>> {
413        self.headers
414            .iter()
415            .find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
416            .map(|&(_, ref val)| val)
417    }
418    /// Edit the value of the first instance of an HTTP header.
419    pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
420        self.headers
421            .iter_mut()
422            .find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
423            .map(|&mut (_, ref mut val)| val)
424    }
425
426    /// Access the request headers.
427    #[allow(dead_code)]
428    #[inline]
429    pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
430        &self.headers
431    }
432
433    /// Edit the request headers.
434    #[allow(dead_code)]
435    #[inline]
436    pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
437        &mut self.headers
438    }
439
440    /// Get the HTTP status code.
441    #[allow(dead_code)]
442    #[inline]
443    pub fn status(&self) -> u16 {
444        self.status
445    }
446
447    /// Set the HTTP status code.
448    #[allow(dead_code)]
449    #[inline]
450    pub fn set_status(&mut self, status: u16)  {
451        self.status = status
452    }
453
454    /// Get the HTTP status reason.
455    #[allow(dead_code)]
456    #[inline]
457    pub fn reason(&self) -> &str {
458        &self.reason
459    }
460
461
462    /// Set the HTTP status reason.
463    #[allow(dead_code)]
464    #[inline]
465    pub fn set_reason<R>(&mut self, reason: R)
466        where R: Into<String>
467    {
468        self.reason = reason.into()
469    }
470
471    /// Get the hashed WebSocket key.
472    pub fn key(&self) -> Result<&Vec<u8>> {
473        self.header("sec-websocket-accept").ok_or(Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
474    }
475
476    /// Get the protocol that the server has decided to use.
477    #[allow(dead_code)]
478    pub fn protocol(&self) -> Result<Option<&str>> {
479        if let Some(proto) = self.header("sec-websocket-protocol") {
480            Ok(Some(try!(from_utf8(proto))))
481        } else {
482            Ok(None)
483        }
484    }
485
486    /// Set the protocol that the server has decided to use.
487    #[allow(dead_code)]
488    pub fn set_protocol(&mut self, protocol: &str) {
489        if let Some(proto) = self.header_mut("sec-websocket-protocol") {
490            *proto = protocol.into();
491            return
492        }
493        self.headers_mut().push(("Sec-WebSocket-Protocol".into(), protocol.into()))
494    }
495
496    /// Get the extensions that the server has decided to use. If these are unacceptable, it is
497    /// appropriate to send an Extension close code.
498    #[allow(dead_code)]
499    pub fn extensions(&self) -> Result<Vec<&str>> {
500        if let Some(exts) = self.header("sec-websocket-extensions") {
501            Ok(try!(from_utf8(exts)).split(',').map(|proto| proto.trim()).collect())
502        } else {
503            Ok(Vec::new())
504        }
505    }
506
507    /// Add an accepted extension to this response.
508    /// This may result in duplicate extensions listed.
509    #[allow(dead_code)]
510    pub fn add_extension(&mut self, ext: &str) {
511        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
512            exts.push(b","[0]);
513            exts.extend(ext.as_bytes());
514            return
515        }
516        self.headers_mut().push(("Sec-WebSocket-Extensions".into(), ext.into()))
517    }
518
519    /// Remove an accepted extension from this response.
520    /// This will remove all configurations of the extension.
521    #[allow(dead_code)]
522    pub fn remove_extension(&mut self, ext: &str) {
523        if let Some(exts) = self.header_mut("sec-websocket-extensions") {
524            let mut new_exts = Vec::with_capacity(exts.len());
525
526            if let Ok(exts_str) = from_utf8(exts) {
527                new_exts = exts_str
528                    .split(',')
529                    .filter(|e| e.trim().starts_with(ext))
530                    .collect::<Vec<&str>>()
531                    .join(",").into();
532            }
533            if new_exts.len() < exts.len() {
534                *exts = new_exts
535            }
536        }
537    }
538
539    /// Attempt to parse an HTTP response from a buffer. If the buffer does not contain a complete
540    /// response, thiw will return `Ok(None)`.
541    pub fn parse(buf: &[u8]) -> Result<Option<Response>> {
542        let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
543        let mut res = httparse::Response::new(&mut headers);
544
545        let parsed = try!(res.parse(buf));
546        if !parsed.is_partial() {
547            Ok(Some(Response {
548                status: res.code.unwrap(),
549                reason: res.reason.unwrap().into(),
550                headers: res.headers.iter().map(|h| (h.name.into(), h.value.into())).collect(),
551            }))
552        } else {
553            Ok(None)
554        }
555    }
556
557    /// Construct a new WebSocket handshake HTTP response from a request.
558    /// This will create a response that ignores protocols and extensions. Edit this response to
559    /// accept a protocol and extensions as necessary.
560    pub fn from_request(req: &Request) -> Result<Response> {
561        let res = Response {
562            status: 101,
563            reason: "Switching Protocols".into(),
564            headers: vec![
565                ("Connection".into(), "Upgrade".into()),
566                ("Sec-WebSocket-Accept".into(), try!(req.hashed_key()).into()),
567                ("Upgrade".into(), "websocket".into()),
568            ],
569        };
570
571        debug!("Built response from request:\n{}", res);
572        Ok(res)
573    }
574
575    /// Write a response out to a buffer
576    pub fn format<W>(&self, w: &mut W) -> Result<()>
577        where W: Write
578    {
579        try!(write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason));
580        for &(ref key, ref val) in self.headers.iter() {
581            try!(write!(w, "{}: ", key));
582            try!(w.write(val));
583            try!(write!(w, "\r\n"));
584        }
585        try!(write!(w, "\r\n"));
586        Ok(())
587    }
588}
589
590impl fmt::Display for Response {
591
592    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
593        let mut s = Vec::with_capacity(2048);
594        try!(self.format(&mut s).map_err(|err| {
595            error!("{:?}", err);
596            fmt::Error
597        }));
598        write!(f, "{}", try!(from_utf8(&s).map_err(|err| {
599            error!("Unable to format response as utf8: {:?}", err);
600            fmt::Error
601        })))
602    }
603
604}
605
// Only compiled for `cargo test`; previously the module was built into every binary.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use std::io::Write;
    use std::net::SocketAddr;
    use std::str::FromStr;
    use super::*;

    #[test]
    fn remote_addr() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: Some(SocketAddr::from_str("127.0.0.1:8888").unwrap()),
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1");
    }

    #[test]
    fn remote_addr_x_forwarded_for() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();

        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1");
    }

    #[test]
    fn remote_addr_forwarded() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();
        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43");
    }

    #[test]
    fn hash_key_rfc6455_example() {
        // Sample key and accept value from the opening handshake in RFC 6455, section 1.3.
        assert_eq!(hash_key(b"dGhlIHNhbXBsZSBub25jZQ=="), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn encode_base64_rfc4648_vectors() {
        // Test vectors from RFC 4648, section 10.
        assert_eq!(encode_base64(b""), "");
        assert_eq!(encode_base64(b"f"), "Zg==");
        assert_eq!(encode_base64(b"fo"), "Zm8=");
        assert_eq!(encode_base64(b"foo"), "Zm9v");
        assert_eq!(encode_base64(b"foob"), "Zm9vYg==");
        assert_eq!(encode_base64(b"fooba"), "Zm9vYmE=");
        assert_eq!(encode_base64(b"foobar"), "Zm9vYmFy");
    }
}