// Source file: rns_ctl/ws.rs

1use std::io::{self, Read, Write};
2
3use crate::encode::to_base64;
4use crate::http::HttpRequest;
5use crate::sha1::sha1;
6
/// GUID appended to the client's `Sec-WebSocket-Key` before hashing.
///
/// The value is fixed by RFC 6455 §1.3. Clients recompute the accept token with
/// this exact string and drop the connection when it does not match, so it must
/// agree with the RFC byte for byte.
const WS_MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Opcode of a text data frame.
pub(crate) const OPCODE_TEXT: u8 = 0x1;
/// Opcode of a connection-close control frame.
pub(crate) const OPCODE_CLOSE: u8 = 0x8;
/// Opcode of a ping control frame.
pub(crate) const OPCODE_PING: u8 = 0x9;
/// Opcode of a pong control frame.
// NOTE(review): `write_pong_frame` in this file references this constant, so the
// `dead_code` allowance looks stale — confirm and drop it.
#[allow(dead_code)]
pub(crate) const OPCODE_PONG: u8 = 0xA;
15
/// A single decoded WebSocket frame.
///
/// `opcode` is the low four bits of the first header byte. The frame readers in
/// this module store `payload` after removing any masking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsFrame {
    pub opcode: u8,
    pub payload: Vec<u8>,
}
21
22/// Check if an HTTP request is a WebSocket upgrade.
23pub fn is_upgrade(req: &HttpRequest) -> bool {
24    req.headers
25        .get("upgrade")
26        .map(|v| v.eq_ignore_ascii_case("websocket"))
27        .unwrap_or(false)
28}
29
30/// Complete the WebSocket handshake (write 101 response).
31pub fn do_handshake(stream: &mut dyn Write, req: &HttpRequest) -> io::Result<()> {
32    let key = req
33        .headers
34        .get("sec-websocket-key")
35        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing Sec-WebSocket-Key"))?;
36
37    let accept = compute_accept(key);
38
39    write!(
40        stream,
41        "HTTP/1.1 101 Switching Protocols\r\n\
42         Upgrade: websocket\r\n\
43         Connection: Upgrade\r\n\
44         Sec-WebSocket-Accept: {}\r\n\
45         \r\n",
46        accept
47    )?;
48    stream.flush()
49}
50
51fn compute_accept(key: &str) -> String {
52    let combined = format!("{}{}", key, WS_MAGIC);
53    let hash = sha1(combined.as_bytes());
54    to_base64(&hash)
55}
56
57/// Read a single WebSocket frame. Handles masking.
58pub fn read_frame(stream: &mut dyn Read) -> io::Result<WsFrame> {
59    let mut head = [0u8; 2];
60    stream.read_exact(&mut head)?;
61
62    let _fin = head[0] & 0x80 != 0;
63    let opcode = head[0] & 0x0F;
64    let masked = head[1] & 0x80 != 0;
65    let len_byte = head[1] & 0x7F;
66
67    let payload_len: usize = if len_byte <= 125 {
68        len_byte as usize
69    } else if len_byte == 126 {
70        let mut buf = [0u8; 2];
71        stream.read_exact(&mut buf)?;
72        u16::from_be_bytes(buf) as usize
73    } else {
74        let mut buf = [0u8; 8];
75        stream.read_exact(&mut buf)?;
76        u64::from_be_bytes(buf) as usize
77    };
78
79    let mask_key = if masked {
80        let mut key = [0u8; 4];
81        stream.read_exact(&mut key)?;
82        Some(key)
83    } else {
84        None
85    };
86
87    let mut payload = vec![0u8; payload_len];
88    if payload_len > 0 {
89        stream.read_exact(&mut payload)?;
90    }
91
92    // Unmask
93    if let Some(key) = mask_key {
94        for i in 0..payload.len() {
95            payload[i] ^= key[i % 4];
96        }
97    }
98
99    Ok(WsFrame { opcode, payload })
100}
101
102/// Write a text frame (server→client, unmasked).
103pub fn write_text_frame(stream: &mut dyn Write, text: &str) -> io::Result<()> {
104    write_frame(stream, OPCODE_TEXT, text.as_bytes())
105}
106
107/// Write a close frame.
108pub fn write_close_frame(stream: &mut dyn Write) -> io::Result<()> {
109    write_frame(stream, OPCODE_CLOSE, &[])
110}
111
112/// Write a pong frame.
113pub fn write_pong_frame(stream: &mut dyn Write, payload: &[u8]) -> io::Result<()> {
114    write_frame(stream, OPCODE_PONG, payload)
115}
116
/// Write one unmasked, final (FIN set) frame with the given opcode and payload.
///
/// The length is encoded in the shortest form: a single byte up to 125, a
/// 16-bit big-endian extension up to 65535, and a 64-bit big-endian extension
/// beyond that. The stream is flushed after the payload.
///
/// The header is assembled in a stack buffer and sent with one `write_all`, so
/// an unbuffered stream sees two writes per frame (header, payload) rather than
/// one per header field.
///
/// # Errors
///
/// Propagates any error from writing to or flushing `stream`.
fn write_frame(stream: &mut dyn Write, opcode: u8, data: &[u8]) -> io::Result<()> {
    // Longest header: 1 opcode byte + 1 length byte + 8 extended-length bytes.
    let mut header = [0u8; 10];
    header[0] = 0x80 | opcode;

    let len = data.len();
    let header_len = if len <= 125 {
        // Fits the 7-bit length field; the cast cannot truncate.
        header[1] = len as u8;
        2
    } else if len <= usize::from(u16::MAX) {
        header[1] = 126;
        // Range-checked by the branch condition.
        header[2..4].copy_from_slice(&(len as u16).to_be_bytes());
        4
    } else {
        header[1] = 127;
        // `usize` is at most 64 bits wide, so widening is lossless.
        header[2..10].copy_from_slice(&(len as u64).to_be_bytes());
        10
    };

    stream.write_all(&header[..header_len])?;
    stream.write_all(data)?;
    stream.flush()
}
135
136/// Handle a WebSocket connection: read frames, respond to control frames,
137/// dispatch text messages.
138///
139/// `on_text` is called for each text frame received.
140/// Returns when the connection is closed (by either side).
141pub fn run_ws_loop(
142    read_stream: &mut dyn Read,
143    write_stream: &mut dyn Write,
144    mut on_text: impl FnMut(&str),
145) -> io::Result<()> {
146    loop {
147        let frame = match read_frame(read_stream) {
148            Ok(f) => f,
149            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
150            Err(e) => return Err(e),
151        };
152
153        match frame.opcode {
154            OPCODE_TEXT => {
155                if let Ok(text) = std::str::from_utf8(&frame.payload) {
156                    on_text(text);
157                }
158            }
159            OPCODE_PING => {
160                let _ = write_pong_frame(write_stream, &frame.payload);
161            }
162            OPCODE_CLOSE => {
163                let _ = write_close_frame(write_stream);
164                break;
165            }
166            _ => {}
167        }
168    }
169    Ok(())
170}
171
172/// Try to parse a complete WebSocket frame from a byte buffer.
173/// Returns `Some((frame, bytes_consumed))` if a complete frame is available.
174fn parse_frame_from_buf(buf: &[u8]) -> Option<(WsFrame, usize)> {
175    if buf.len() < 2 {
176        return None;
177    }
178
179    let opcode = buf[0] & 0x0F;
180    let masked = buf[1] & 0x80 != 0;
181    let len_byte = buf[1] & 0x7F;
182
183    let mut pos = 2;
184
185    let payload_len: usize = if len_byte <= 125 {
186        len_byte as usize
187    } else if len_byte == 126 {
188        if buf.len() < pos + 2 {
189            return None;
190        }
191        let len = u16::from_be_bytes([buf[pos], buf[pos + 1]]) as usize;
192        pos += 2;
193        len
194    } else {
195        if buf.len() < pos + 8 {
196            return None;
197        }
198        let mut arr = [0u8; 8];
199        arr.copy_from_slice(&buf[pos..pos + 8]);
200        let len = u64::from_be_bytes(arr) as usize;
201        pos += 8;
202        len
203    };
204
205    let mask_key = if masked {
206        if buf.len() < pos + 4 {
207            return None;
208        }
209        let key = [buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]];
210        pos += 4;
211        Some(key)
212    } else {
213        None
214    };
215
216    if buf.len() < pos + payload_len {
217        return None;
218    }
219
220    let mut payload = buf[pos..pos + payload_len].to_vec();
221    pos += payload_len;
222
223    if let Some(key) = mask_key {
224        for i in 0..payload.len() {
225            payload[i] ^= key[i % 4];
226        }
227    }
228
229    Some((WsFrame { opcode, payload }, pos))
230}
231
232/// Buffered non-blocking WebSocket frame reader.
233///
234/// Accumulates bytes from a non-blocking or timeout-based stream and
235/// yields complete frames without requiring `read_exact`.
236pub struct WsBuf {
237    buf: Vec<u8>,
238}
239
240impl WsBuf {
241    pub fn new() -> Self {
242        WsBuf {
243            buf: Vec::with_capacity(4096),
244        }
245    }
246
247    /// Try to read a complete frame. Returns:
248    /// - `Ok(Some(frame))` if a complete frame was parsed
249    /// - `Ok(None)` if not enough data yet (WouldBlock/TimedOut)
250    /// - `Err(e)` on connection error (EOF, etc.)
251    pub fn try_read_frame(&mut self, stream: &mut dyn Read) -> io::Result<Option<WsFrame>> {
252        // First, try to parse from existing buffer data
253        if let Some((frame, consumed)) = parse_frame_from_buf(&self.buf) {
254            self.buf.drain(..consumed);
255            return Ok(Some(frame));
256        }
257
258        // Read more data from the stream
259        let mut tmp = [0u8; 4096];
260        match stream.read(&mut tmp) {
261            Ok(0) => {
262                return Err(io::Error::new(
263                    io::ErrorKind::UnexpectedEof,
264                    "connection closed",
265                ));
266            }
267            Ok(n) => {
268                self.buf.extend_from_slice(&tmp[..n]);
269            }
270            Err(e)
271                if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut =>
272            {
273                return Ok(None);
274            }
275            Err(e) => return Err(e),
276        }
277
278        // Try to parse again with new data
279        if let Some((frame, consumed)) = parse_frame_from_buf(&self.buf) {
280            self.buf.drain(..consumed);
281            return Ok(Some(frame));
282        }
283
284        Ok(None)
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// The worked example from RFC 6455 §1.3: this key must yield this accept
    /// token. A mismatch means the handshake GUID or the hashing is wrong.
    #[test]
    fn compute_accept_rfc() {
        let accept = compute_accept("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    /// A short text frame written by the server side reads back unchanged.
    #[test]
    fn write_read_text_frame() {
        let mut buf = Vec::new();
        write_text_frame(&mut buf, "hello").unwrap();

        // Server frames are unmasked; simulate client reading by reading as-is
        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
    }

    /// A 300-byte payload exercises the 16-bit extended length encoding.
    #[test]
    fn write_read_large_frame() {
        let text = "x".repeat(300);
        let mut buf = Vec::new();
        write_text_frame(&mut buf, &text).unwrap();

        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload.len(), 300);
    }

    /// A buffer holding exactly one frame parses and is consumed entirely.
    #[test]
    fn parse_frame_from_buf_complete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "hello").unwrap();

        let result = parse_frame_from_buf(&data);
        assert!(result.is_some());
        let (frame, consumed) = result.unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
        assert_eq!(consumed, data.len());
    }

    /// Truncated input yields `None` rather than a partial frame.
    #[test]
    fn parse_frame_from_buf_incomplete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "hello").unwrap();

        // Only provide first byte — not enough
        assert!(parse_frame_from_buf(&data[..1]).is_none());
        // Provide header but truncate payload
        assert!(parse_frame_from_buf(&data[..3]).is_none());
    }

    /// A stream that reports `WouldBlock` yields `Ok(None)`, not an error.
    #[test]
    fn wsbuf_try_read_frame_wouldblock() {
        use std::io;

        struct WouldBlockReader;
        impl Read for WouldBlockReader {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
            }
        }

        let mut ws_buf = WsBuf::new();
        let result = ws_buf.try_read_frame(&mut WouldBlockReader);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    /// A whole frame delivered in one read is returned by the first call.
    #[test]
    fn wsbuf_try_read_frame_complete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "test").unwrap();

        let mut ws_buf = WsBuf::new();
        let mut cursor = io::Cursor::new(data);
        let result = ws_buf.try_read_frame(&mut cursor).unwrap();
        assert!(result.is_some());
        let frame = result.unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"test");
    }
}
375}