oc_http/
websocket.rs

1use std::{
2    io,
3    convert::TryFrom,
4    fmt,
5};
6
7use sha1::{Sha1, Digest};
8use crate::{respond, Request, Response};
9use nom::{
10    IResult,
11    bits::{
12        bits,
13        complete::take,
14    },
15};
16use futures::{
17    AsyncRead,
18    AsyncWrite,
19    AsyncWriteExt,
20    AsyncReadExt,
21};
22
/// Upper bound, in bytes, on the payload `WebSocketReader::recv` will read; going over it
/// yields `WebSocketError::TooBig`.
const MAX_PAYLOAD_SIZE: u64 = 16 * 1000;
24
/// Errors produced while upgrading to, reading from, or writing to a WebSocket.
#[derive(Clone, Debug)]
pub enum WebSocketError {
    /// The `Connection` header does not ask for an upgrade.
    ConnectionNotUpgrade,
    /// The request has no `Connection` header.
    NoConnectionHeader,
    /// The request has no `Upgrade` header.
    NoUpgradeHeader,
    /// The `Upgrade` header does not name `websocket`.
    UpgradeNotToWebSocket,
    /// `Sec-WebSocket-Version` is missing or is not `13`.
    WrongVersion,
    /// The request has no `Sec-WebSocket-Key` header.
    NoKey,
    /// A payload exceeded `MAX_PAYLOAD_SIZE`.
    TooBig,
    /// The peer violated the framing protocol (e.g. a frame header that failed to parse).
    ProtocolError,
    /// An I/O failure other than end-of-stream; holds the `Debug` rendering of the error.
    IOError(String),
    /// A frame carried an unknown or unexpected opcode.
    BadOpcode,
    /// The peer closed the connection (end of stream or a close message).
    ConnectionClosed,
}

impl From<io::Error> for WebSocketError {
    /// Treats an unexpected end of stream as the peer closing the connection; every other
    /// I/O error is preserved as text.
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => Self::ConnectionClosed,
            _ => Self::IOError(format!("{:?}", err)),
        }
    }
}
49
50impl<E> From<nom::Err<E>> for WebSocketError {
51    fn from(_err: nom::Err<E>) -> Self {
52        WebSocketError::ProtocolError
53    }
54}
55
56impl fmt::Display for WebSocketError {
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58        write!(f, "problem establishing websocket connection")
59    }
60}
61
/// Kind of a WebSocket frame or message, one variant per opcode this module understands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Continuation of a fragmented data message (opcode 0x0).
    Continuation,
    /// Text data message (opcode 0x1).
    Text,
    /// Binary data message (opcode 0x2).
    Binary,
    /// Close control message (opcode 0x8).
    Close,
    /// Ping control message (opcode 0x9).
    Ping,
    /// Pong control message (opcode 0xA).
    Pong,
}

impl MessageType {
    /// Returns `true` for the control types (`Close`, `Ping`, `Pong`) and `false` for data
    /// types (`Continuation`, `Text`, `Binary`).
    pub fn is_control(&self) -> bool {
        matches!(self, MessageType::Close | MessageType::Ping | MessageType::Pong)
    }
}
80
81impl TryFrom<u8> for MessageType {
82    type Error = WebSocketError;
83
84    fn try_from(b: u8) -> Result<Self, Self::Error> {
85        Ok(match b {
86            0x0 => MessageType::Continuation,
87            0x1 => MessageType::Text,
88            0x2 => MessageType::Binary,
89            0x8 => MessageType::Close,
90            0x9 => MessageType::Ping,
91            0xA => MessageType::Pong,
92            _ => return Err(WebSocketError::BadOpcode),
93        })
94    }
95}
96
97impl Into<u8> for MessageType {
98    fn into(self) -> u8 {
99        match self {
100            MessageType::Continuation => 0x0,
101            MessageType::Text => 0x1,
102            MessageType::Binary => 0x2,
103            MessageType::Close => 0x8,
104            MessageType::Ping => 0x9,
105            MessageType::Pong => 0xA,
106        }
107    }
108}
109
110#[derive(Debug, Clone)]
111pub struct Message{
112    pub typ: MessageType,
113    pub contents: Vec<u8>,
114}
115
116pub async fn upgrade<S>(req: &Request, mut stream: S) -> Result<(WebSocketReader<S>, WebSocketWriter<S>), WebSocketError>
117where S: AsyncRead + AsyncWrite + Clone + Unpin
118{
119    // sanity check that required headers are in place
120    match req.headers.get("Connection") {
121        Some(header) => {
122            let mut ok = false;
123            if let Ok(txt) = std::str::from_utf8(header) {
124                if let Some(_) = txt.find("Upgrade") {
125                    ok = true;
126                }
127            }
128            if !ok {
129                Err(WebSocketError::ConnectionNotUpgrade)?
130            }
131        },
132        None => Err(WebSocketError::NoConnectionHeader)?,
133    };
134    match req.headers.get("Upgrade") {
135        Some(header) => if header != b"websocket" { Err(WebSocketError::UpgradeNotToWebSocket)? },
136        None => Err(WebSocketError::NoUpgradeHeader)?,
137    };
138    match req.headers.get("Sec-WebSocket-Version") {
139        Some(header) => if header != b"13" { Err(WebSocketError::WrongVersion)? },
140        None => Err(WebSocketError::WrongVersion)?,
141    };
142    // get the key we need to hash in the response
143    let key = match req.headers.get("Sec-WebSocket-Key") {
144        Some(k) => k,
145        None => Err(WebSocketError::NoKey)?,
146    };
147    let mut hasher = Sha1::new();
148    hasher.update(&key);
149    // magic string from the interwebs
150    hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
151    let result = hasher.finalize();
152    let mut headers = vec!();
153    headers.push(("Upgrade".into(), Vec::from("websocket")));
154    headers.push(("Connection".into(), Vec::from("Upgrade")));
155    headers.push(("Sec-WebSocket-Accept".into(), base64::encode(&result[..]).into()));
156    // complete the handshake
157    respond(&mut stream, Response{
158        code: 101,
159        reason: "Switching Protocols",
160        headers,
161    }).await?;
162    stream.flush().await?;
163    Ok((WebSocketReader{
164        stream: stream.clone(),
165        buffered_message: None,
166    }, WebSocketWriter{
167        stream,
168    }))
169}
170
171pub struct WebSocketReader<S>
172where S: AsyncRead + Unpin
173{
174    stream: S,
175    buffered_message: Option<(MessageType, Vec<u8>)>,
176}
177
178impl<S> WebSocketReader<S>
179where S: AsyncRead + Unpin
180{
181    pub async fn recv(&mut self) -> Result<Message, WebSocketError> {
182        loop {
183            let header = read_header(&mut self.stream).await?;
184            if header.payload_len > MAX_PAYLOAD_SIZE {
185                Err(WebSocketError::TooBig)?;
186            }
187            // read the body
188            let mut contents = vec![0u8; header.payload_len as usize];
189            self.stream.read_exact(&mut contents).await?;
190            // unmask the value in-place
191            let len = contents.len();
192            for i in 0..len {
193                contents[i] = contents[i] ^ header.masking_key[i % header.masking_key.len()];
194            }
195            let typ = MessageType::try_from(header.opcode)?;
196            if typ.is_control() {
197                return Ok(Message{contents, typ});
198            }
199            // if this is a new fragment chain, start it
200            if header.fin == 0 && typ != MessageType::Continuation {
201                self.buffered_message = Some((typ, contents));
202            } else if header.fin == 0 {
203                match &mut self.buffered_message {
204                    Some((_, old)) => {
205                        old.append(&mut contents);
206                    },
207                    None => return Err(WebSocketError::BadOpcode),
208                }
209            } else {
210                let (typ, contents) = self.buffered_message.take().unwrap_or((typ, contents));
211                return Ok(Message{typ, contents});
212            }
213        }
214    }
215}
216
217pub struct WebSocketWriter<S>
218where S: AsyncWrite + Unpin
219{
220    stream: S,
221}
222
223impl<S> WebSocketWriter<S>
224where S: AsyncWrite + Unpin
225{
226    pub async fn write(&mut self, msg: &Message) -> Result<(), WebSocketError> {
227        let res = WebSocketHeader{
228            fin: 1,
229            opcode: msg.typ.into(),
230            mask: 0,
231            payload_len: msg.contents.len() as u64,
232            masking_key: vec!(),
233        };
234        self.stream.write_all(&mut res.to_vec()).await?;
235        self.stream.write_all(&msg.contents).await?;
236        self.stream.flush().await?;
237        Ok(())
238    }
239}
240
/// Decoded form of a WebSocket frame header.
#[derive(Debug, Clone)]
struct WebSocketHeader {
    /// 1 if this frame is the final fragment of its message, 0 otherwise.
    fin: u8,
    /// 4-bit frame opcode (see `MessageType`).
    opcode: u8,
    /// 1 if the payload is masked, 0 otherwise.
    mask: u8,
    /// Payload length in bytes.
    payload_len: u64,
    /// Masking key when `mask` is set; empty otherwise.
    masking_key: Vec<u8>,
}

impl WebSocketHeader {
    /// Serializes the header into its wire format (RFC 6455 §5.2).
    ///
    /// The length uses the minimal encoding the RFC requires:
    /// - lengths up to 125 go directly in the 7-bit field;
    /// - lengths up to `u16::MAX` use marker 126 plus a 16-bit big-endian length;
    /// - anything larger uses marker 127 plus a full 64-bit big-endian length.
    ///
    /// When `mask` is set, the MASK bit is written and `masking_key` follows the length.
    fn to_vec(&self) -> Vec<u8> {
        // 2 fixed bytes + up to 8 extended-length bytes + up to 4 masking-key bytes.
        let mut out = Vec::with_capacity(14);
        // Keep the opcode to its 4 bits so it cannot clobber the FIN/RSV bits.
        out.push((self.fin << 7) | (self.opcode & 0x0F));
        let mask_bit = if self.mask != 0 { 0x80 } else { 0 };
        if self.payload_len < 126 {
            out.push(mask_bit | self.payload_len as u8);
        } else if let Ok(len) = u16::try_from(self.payload_len) {
            out.push(mask_bit | 126);
            out.extend_from_slice(&len.to_be_bytes());
        } else {
            // Previously this branch wrote only 2 length bytes, corrupting large frames.
            out.push(mask_bit | 127);
            out.extend_from_slice(&self.payload_len.to_be_bytes());
        }
        if self.mask != 0 {
            out.extend_from_slice(&self.masking_key);
        }
        out
    }
}
268
269/// handles control message (ping, pong) to make sure the socket stays open
270pub async fn handle_control<S>(msg: &Message, wrt: &mut WebSocketWriter<S>) -> Result<bool, WebSocketError>
271where S: AsyncWrite + Unpin
272{
273    match msg.typ {
274        MessageType::Pong => {
275            let msg = Message{
276                typ: MessageType::Pong,
277                contents: msg.contents.clone(),
278            };
279            wrt.write(&msg).await?;
280            Ok(true)
281        },
282        MessageType::Close => {
283            Err(WebSocketError::ConnectionClosed)
284        }
285        _ => {
286            Ok(false)
287        },
288    }
289}
290
291async fn read_header<S>(stream: &mut S) -> Result<WebSocketHeader, WebSocketError>
292where S: AsyncRead + Unpin
293{
294    // fixed-length header size is 2 bytes, followed by optional extended length
295    // and finally mask
296    let mut header_fixed = vec![0u8; 2];
297    stream.read_exact(&mut header_fixed).await?;
298    let (_, mut res) = read_header_internal(&header_fixed)?;
299    header_fixed[1] &= 0b01111111;
300    if res.payload_len == 126 {
301        // read 16 bites, 2 bytes
302        let mut len = [0u8; 2];
303        stream.read_exact(&mut len).await?;
304        res.payload_len = u16::from_be_bytes(len) as u64;
305    } else if res.payload_len == 127 {
306        // read 64 bits, 8 bytes
307        let mut len = [0u8; 8];
308        stream.read_exact(&mut len).await?;
309        res.payload_len = u64::from_be_bytes(len) as u64;
310    }
311    if res.mask != 0 {
312        let mut mask_key = vec![0u8; 4];
313        stream.read_exact(&mut mask_key).await?;
314        res.masking_key =  mask_key;
315    }
316    Ok(res)
317}
318
319fn read_header_internal(input: &[u8]) -> IResult<&[u8], WebSocketHeader> {
320    bits(read_header_internal_bits)(input)
321}
322
323fn read_header_internal_bits(input: (&[u8], usize)) -> IResult<(&[u8], usize), WebSocketHeader>
324{
325    let (input, fin) = take(1usize)(input)?;
326    let (input, _rsv1): ((&[u8], usize), u8) = take(1usize)(input)?;
327    let (input, _rsv2): ((&[u8], usize), u8) = take(1usize)(input)?;
328    let (input, _rsv3): ((&[u8], usize), u8) = take(1usize)(input)?;
329    let (input, opcode) = take(4usize)(input)?;
330    let (input, mask) = take(1usize)(input)?;
331    let (input, payload_len) = take(7usize)(input)?;
332    Ok((input, WebSocketHeader{fin, opcode, mask, payload_len, masking_key: vec!()}))
333}