rouille_ng/websocket/
websocket.rs

1// Copyright (c) 2016 The Rouille developers
2// Licensed under the Apache License, Version 2.0
3// <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
6// at your option. All files in the project carrying such
7// notice may not be copied, modified, or distributed except
8// according to those terms.
9
10use std::io;
11use std::io::Write;
12use std::mem;
13use std::sync::mpsc::Sender;
14use ReadWrite;
15use Upgrade;
16
17use websocket::low_level;
18
19/// A successful websocket. An open channel of communication. Implements `Read` and `Write`.
20pub struct Websocket {
21    // The socket. `None` if closed.
22    socket: Option<Box<dyn ReadWrite + Send>>,
23    // The websocket state machine.
24    state_machine: low_level::StateMachine,
25    // True if the fragmented message currently being processed is binary. False if string. Pings
26    // are excluded.
27    current_message_binary: bool,
28    // Buffer for the fragmented message currently being processed. Pings are excluded.
29    current_message_payload: Vec<u8>,
30    // Opcode of the fragment currently being processed.
31    current_frame_opcode: u8,
32    // Fin flag of the fragment currently being processed.
33    current_frame_fin: bool,
34    // Data of the fragment currently being processed.
35    current_frame_payload: Vec<u8>,
36    // Queue of the messages that are going to be returned by `next()`.
37    messages_in_queue: Vec<Message>,
38}
39
/// A complete message received over a websocket connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A text message (always valid UTF-8, since it is a `String`). A Javascript client
    /// produces one of these when it calls `send()` with a string.
    Text(String),

    /// A binary message. A Javascript client produces one of these when it calls `send()` with
    /// a blob or an arraybuffer.
    Binary(Vec<u8>),
}
51
/// Reasons why sending a message to the client can fail.
#[derive(Debug)]
pub enum SendError {
    /// Writing the frame to the underlying socket returned an I/O error.
    IoError(io::Error),

    /// The websocket connection has already been closed.
    Closed,
}
61
62impl From<io::Error> for SendError {
63    #[inline]
64    fn from(err: io::Error) -> SendError {
65        SendError::IoError(err)
66    }
67}
68
69impl Websocket {
70    /// Sends text data over the websocket.
71    ///
72    /// Returns an error if the message didn't send correctly or if the connection is closed.
73    ///
74    /// If the client is in javascript, the message will contain a string.
75    #[inline]
76    pub fn send_text(&mut self, data: &str) -> Result<(), SendError> {
77        let socket = match self.socket {
78            Some(ref mut s) => s,
79            None => return Err(SendError::Closed),
80        };
81
82        try!(send(data.as_bytes(), Write::by_ref(socket), 0x1));
83        Ok(())
84    }
85
86    /// Sends binary data over the websocket.
87    ///
88    /// Returns an error if the message didn't send correctly or if the connection is closed.
89    ///
90    /// If the client is in javascript, the message will contain a blob or an arraybuffer.
91    #[inline]
92    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), SendError> {
93        let socket = match self.socket {
94            Some(ref mut s) => s,
95            None => return Err(SendError::Closed),
96        };
97
98        try!(send(data, Write::by_ref(socket), 0x2));
99        Ok(())
100    }
101
102    /// Returns `true` if the websocket has been closed by either the client (voluntarily or not)
103    /// or by the server (if the websocket protocol was violated).
104    #[inline]
105    pub fn is_closed(&self) -> bool {
106        self.socket.is_none()
107    }
108
109    // TODO: give access to close reason
110}
111
112impl Upgrade for Sender<Websocket> {
113    fn build(&mut self, socket: Box<dyn ReadWrite + Send>) {
114        let websocket = Websocket {
115            socket: Some(socket),
116            state_machine: low_level::StateMachine::new(),
117            current_message_binary: false,
118            current_message_payload: Vec::new(),
119            current_frame_opcode: 0,
120            current_frame_fin: false,
121            current_frame_payload: Vec::new(),
122            messages_in_queue: Vec::new(),
123        };
124
125        let _ = self.send(websocket);
126    }
127}
128
129impl Iterator for Websocket {
130    type Item = Message;
131
132    fn next(&mut self) -> Option<Message> {
133        loop {
134            // If the socket is `None`, the connection has been closed.
135            if self.socket.is_none() {
136                return None;
137            }
138
139            // There may be some messages waiting to be processed.
140            if !self.messages_in_queue.is_empty() {
141                return Some(self.messages_in_queue.remove(0));
142            }
143
144            // Read `n` bytes in `buf`.
145            let mut buf = [0; 256];
146            let n = match self.socket.as_mut().unwrap().read(&mut buf) {
147                Ok(n) if n == 0 => {
148                    // Read returning zero means EOF
149                    self.socket = None;
150                    return None;
151                }
152                Ok(n) => n,
153                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => 0,
154                Err(_) => {
155                    self.socket = None;
156                    return None;
157                }
158            };
159
160            // Fill `messages_in_queue` by analyzing the packets.
161            for element in self.state_machine.feed(&buf[0..n]) {
162                match element {
163                    low_level::Element::FrameStart { fin, opcode, .. } => {
164                        debug_assert!(self.current_frame_payload.is_empty());
165                        self.current_frame_fin = fin;
166                        self.current_frame_opcode = opcode;
167                    }
168
169                    low_level::Element::Data {
170                        data,
171                        last_in_frame,
172                    } => {
173                        // Under normal circumstances we just handle data by pushing it to
174                        // `current_frame_payload`.
175                        self.current_frame_payload.extend(data);
176
177                        // But if the frame is finished we additionnally need to dispatch it.
178                        if last_in_frame {
179                            match self.current_frame_opcode {
180                                // Frame is a continuation of the current message.
181                                0x0 => {
182                                    self.current_message_payload
183                                        .append(&mut self.current_frame_payload);
184
185                                    // If the message is finished, dispatch it.
186                                    if self.current_frame_fin {
187                                        let binary = mem::replace(
188                                            &mut self.current_message_payload,
189                                            Vec::new(),
190                                        );
191
192                                        if self.current_message_binary {
193                                            self.messages_in_queue.push(Message::Binary(binary));
194                                        } else {
195                                            let string = match String::from_utf8(binary) {
196                                                Ok(s) => s,
197                                                Err(_) => {
198                                                    // Closing connection because text wasn't UTF-8
199                                                    let _ = send(
200                                                        b"1007 Invalid UTF-8 encoding",
201                                                        Write::by_ref(
202                                                            self.socket.as_mut().unwrap(),
203                                                        ),
204                                                        0x8,
205                                                    );
206                                                    self.socket = None;
207                                                    return None;
208                                                }
209                                            };
210
211                                            self.messages_in_queue.push(Message::Text(string));
212                                        }
213                                    }
214                                }
215
216                                // Frame is an individual text frame.
217                                0x1 => {
218                                    // If we're in the middle of a message, this frame is invalid
219                                    // and we need to close.
220                                    if !self.current_message_payload.is_empty() {
221                                        let _ = send(
222                                            b"1002 Expected continuation frame",
223                                            Write::by_ref(self.socket.as_mut().unwrap()),
224                                            0x8,
225                                        );
226                                        self.socket = None;
227                                        return None;
228                                    }
229
230                                    if self.current_frame_fin {
231                                        // There's only one frame in this message.
232                                        let binary = mem::replace(
233                                            &mut self.current_frame_payload,
234                                            Vec::new(),
235                                        );
236                                        let string = match String::from_utf8(binary) {
237                                            Ok(s) => s,
238                                            Err(_err) => {
239                                                // Closing connection because text wasn't UTF-8
240                                                let _ = send(
241                                                    b"1007 Invalid UTF-8 encoding",
242                                                    Write::by_ref(self.socket.as_mut().unwrap()),
243                                                    0x8,
244                                                );
245                                                self.socket = None;
246                                                return None;
247                                            }
248                                        };
249
250                                        self.messages_in_queue.push(Message::Text(string));
251                                    } else {
252                                        // Start of a fragmented message.
253                                        self.current_message_binary = false;
254                                        self.current_message_payload
255                                            .append(&mut self.current_frame_payload);
256                                    }
257                                }
258
259                                // Frame is an individual binary frame.
260                                0x2 => {
261                                    // If we're in the middle of a message, this frame is invalid
262                                    // and we need to close.
263                                    if !self.current_message_payload.is_empty() {
264                                        let _ = send(
265                                            b"1002 Expected continuation frame",
266                                            Write::by_ref(self.socket.as_mut().unwrap()),
267                                            0x8,
268                                        );
269                                        self.socket = None;
270                                        return None;
271                                    }
272
273                                    if self.current_frame_fin {
274                                        let binary = mem::replace(
275                                            &mut self.current_frame_payload,
276                                            Vec::new(),
277                                        );
278                                        self.messages_in_queue.push(Message::Binary(binary));
279                                    } else {
280                                        // Start of a fragmented message.
281                                        self.current_message_binary = true;
282                                        self.current_message_payload
283                                            .append(&mut self.current_frame_payload);
284                                    }
285                                }
286
287                                // Close request.
288                                0x8 => {
289                                    // We need to send a confirmation.
290                                    let _ = send(
291                                        &self.current_frame_payload,
292                                        Write::by_ref(self.socket.as_mut().unwrap()),
293                                        0x8,
294                                    );
295                                    // Since the packets are always received in order, and since
296                                    // the server is considered dead as soon as it sends the
297                                    // confirmation, we have no risk of losing packets.
298                                    self.socket = None;
299                                    return None;
300                                }
301
302                                // Ping.
303                                0x9 => {
304                                    // Send the pong.
305                                    let _ = send(
306                                        &self.current_frame_payload,
307                                        Write::by_ref(self.socket.as_mut().unwrap()),
308                                        0xA,
309                                    );
310                                }
311
312                                // Pong. We ignore this as there's nothing to do.
313                                0xA => {}
314
315                                // Unknown opcode means error and close.
316                                _ => {
317                                    let _ = send(
318                                        b"Unknown opcode",
319                                        Write::by_ref(self.socket.as_mut().unwrap()),
320                                        0x8,
321                                    );
322                                    self.socket = None;
323                                    return None;
324                                }
325                            }
326
327                            self.current_frame_payload.clear();
328                        }
329                    }
330
331                    low_level::Element::Error { desc } => {
332                        // The low level layer signaled an error. Sending it to client and closing.
333                        let _ = send(
334                            desc.as_bytes(),
335                            Write::by_ref(self.socket.as_mut().unwrap()),
336                            0x8,
337                        );
338                        self.socket = None;
339                        return None;
340                    }
341                }
342            }
343        }
344    }
345}
346
/// Writes `data` to `dest` as a single, unfragmented websocket frame with the given `opcode`.
///
/// The frame has the FIN bit set and no mask bit (server-to-client frames are not masked). It
/// uses the shortest length encoding: 7 bits for payloads up to 125 bytes, 16 bits up to
/// 65535 bytes, and 64 bits beyond that.
///
/// The whole frame is assembled in memory and written with one `write_all`, then `dest` is
/// flushed. Emitting the header as several tiny writes on an unbuffered TCP socket would
/// produce several small segments and can stall on Nagle/delayed-ACK interaction.
///
/// Control frames (opcode `0x8` and above) must carry at most 125 bytes of payload. This
/// function does not enforce that; callers must.
///
/// # Panics
///
/// Panics if `opcode` does not fit in 4 bits, or if the payload length does not fit in 63 bits.
///
/// # Errors
///
/// Returns any I/O error raised while writing to or flushing `dest`.
// TODO: message fragmentation?
fn send<W: Write>(data: &[u8], mut dest: W, opcode: u8) -> io::Result<()> {
    assert!(opcode <= 0xf);

    let len = data.len();
    // Worst-case header: 1 byte FIN/opcode + 1 byte length marker + 8 bytes extended length.
    let mut frame = Vec::with_capacity(10 + len);
    frame.push(0x80 | opcode);

    if len >= 65536 {
        let len = len as u64;
        // The most significant bit of the 64-bit length must be zero.
        assert!(len < 0x8000_0000_0000_0000);
        frame.push(127);
        // Big-endian (network byte order).
        for shift in (0..8).rev() {
            frame.push((len >> (shift * 8)) as u8);
        }
    } else if len >= 126 {
        // 126..=65535: the guard above guarantees the length fits in 16 bits.
        frame.push(126);
        frame.push((len >> 8) as u8);
        frame.push(len as u8);
    } else {
        frame.push(len as u8);
    }

    frame.extend_from_slice(data);
    dest.write_all(&frame)?;
    dest.flush()
}