Skip to main content

rns_ctl/
ws.rs

1use std::io::{self, Read, Write};
2
3use crate::encode::to_base64;
4use crate::http::HttpRequest;
5use crate::sha1::sha1;
6
/// GUID that RFC 6455 §1.3 requires the server to append to the client's
/// `Sec-WebSocket-Key` before hashing. It is a fixed protocol constant: any
/// other value yields a `Sec-WebSocket-Accept` that conforming clients reject.
const WS_MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
8
// WebSocket frame opcodes (low nibble of the first header byte).

/// Opcode of a text data frame.
const OPCODE_TEXT: u8 = 0x01;
/// Opcode of a connection-close control frame.
const OPCODE_CLOSE: u8 = 0x08;
/// Opcode of a ping control frame.
const OPCODE_PING: u8 = 0x09;
/// Opcode of a pong control frame.
const OPCODE_PONG: u8 = 0x0A;
14
/// A decoded WebSocket frame.
///
/// Only the opcode and the payload bytes are kept; the header flags and the
/// masking key are consumed while the frame is read.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsFrame {
    /// Frame opcode (low four bits of the first header byte).
    pub opcode: u8,
    /// Payload bytes, already unmasked if the frame was masked on the wire.
    pub payload: Vec<u8>,
}
20
21/// Check if an HTTP request is a WebSocket upgrade.
22pub fn is_upgrade(req: &HttpRequest) -> bool {
23    req.headers
24        .get("upgrade")
25        .map(|v| v.eq_ignore_ascii_case("websocket"))
26        .unwrap_or(false)
27}
28
29/// Complete the WebSocket handshake (write 101 response).
30pub fn do_handshake(stream: &mut dyn Write, req: &HttpRequest) -> io::Result<()> {
31    let key = req
32        .headers
33        .get("sec-websocket-key")
34        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing Sec-WebSocket-Key"))?;
35
36    let accept = compute_accept(key);
37
38    write!(
39        stream,
40        "HTTP/1.1 101 Switching Protocols\r\n\
41         Upgrade: websocket\r\n\
42         Connection: Upgrade\r\n\
43         Sec-WebSocket-Accept: {}\r\n\
44         \r\n",
45        accept
46    )?;
47    stream.flush()
48}
49
50fn compute_accept(key: &str) -> String {
51    let combined = format!("{}{}", key, WS_MAGIC);
52    let hash = sha1(combined.as_bytes());
53    to_base64(&hash)
54}
55
56/// Read a single WebSocket frame. Handles masking.
57pub fn read_frame(stream: &mut dyn Read) -> io::Result<WsFrame> {
58    let mut head = [0u8; 2];
59    stream.read_exact(&mut head)?;
60
61    let _fin = head[0] & 0x80 != 0;
62    let opcode = head[0] & 0x0F;
63    let masked = head[1] & 0x80 != 0;
64    let len_byte = head[1] & 0x7F;
65
66    let payload_len: usize = if len_byte <= 125 {
67        len_byte as usize
68    } else if len_byte == 126 {
69        let mut buf = [0u8; 2];
70        stream.read_exact(&mut buf)?;
71        u16::from_be_bytes(buf) as usize
72    } else {
73        let mut buf = [0u8; 8];
74        stream.read_exact(&mut buf)?;
75        u64::from_be_bytes(buf) as usize
76    };
77
78    let mask_key = if masked {
79        let mut key = [0u8; 4];
80        stream.read_exact(&mut key)?;
81        Some(key)
82    } else {
83        None
84    };
85
86    let mut payload = vec![0u8; payload_len];
87    if payload_len > 0 {
88        stream.read_exact(&mut payload)?;
89    }
90
91    // Unmask
92    if let Some(key) = mask_key {
93        for i in 0..payload.len() {
94            payload[i] ^= key[i % 4];
95        }
96    }
97
98    Ok(WsFrame { opcode, payload })
99}
100
101/// Write a text frame (server→client, unmasked).
102pub fn write_text_frame(stream: &mut dyn Write, text: &str) -> io::Result<()> {
103    write_frame(stream, OPCODE_TEXT, text.as_bytes())
104}
105
106/// Write a close frame.
107pub fn write_close_frame(stream: &mut dyn Write) -> io::Result<()> {
108    write_frame(stream, OPCODE_CLOSE, &[])
109}
110
111/// Write a pong frame.
112pub fn write_pong_frame(stream: &mut dyn Write, payload: &[u8]) -> io::Result<()> {
113    write_frame(stream, OPCODE_PONG, payload)
114}
115
/// Serialize one unfragmented, unmasked frame (FIN bit set) and send it.
///
/// The header and payload are assembled in memory and handed to the stream in
/// a single `write_all`, followed by a flush. This keeps a frame from being
/// split into several tiny writes on an unbuffered socket.
///
/// The payload length uses the shortest encoding: the 7-bit field for
/// lengths up to 125, a 16-bit big-endian extension up to 65535, and a 64-bit
/// big-endian extension beyond that.
///
/// # Errors
///
/// * `InvalidInput` if `opcode` is a control opcode (bit `0x8` set) and
///   `data` is longer than 125 bytes; RFC 6455 §5.5 forbids such frames.
///   Nothing is written in that case.
/// * Any I/O error raised while writing or flushing.
fn write_frame(stream: &mut dyn Write, opcode: u8, data: &[u8]) -> io::Result<()> {
    let len = data.len();

    if opcode & 0x08 != 0 && len > 125 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "WebSocket control frame payload exceeds 125 bytes",
        ));
    }

    // Largest possible header: 1 byte flags/opcode + 1 length byte + 8 bytes
    // of extended length.
    let mut frame = Vec::with_capacity(len + 10);

    // FIN bit set, given opcode
    frame.push(0x80 | opcode);

    if len <= 125 {
        frame.push(len as u8);
    } else if len <= u16::MAX as usize {
        frame.push(126);
        frame.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        frame.push(127);
        frame.extend_from_slice(&(len as u64).to_be_bytes());
    }

    frame.extend_from_slice(data);

    stream.write_all(&frame)?;
    stream.flush()
}
134
135/// Handle a WebSocket connection: read frames, respond to control frames,
136/// dispatch text messages.
137///
138/// `on_text` is called for each text frame received.
139/// Returns when the connection is closed (by either side).
140pub fn run_ws_loop(
141    read_stream: &mut dyn Read,
142    write_stream: &mut dyn Write,
143    mut on_text: impl FnMut(&str),
144) -> io::Result<()> {
145    loop {
146        let frame = match read_frame(read_stream) {
147            Ok(f) => f,
148            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
149            Err(e) => return Err(e),
150        };
151
152        match frame.opcode {
153            OPCODE_TEXT => {
154                if let Ok(text) = std::str::from_utf8(&frame.payload) {
155                    on_text(text);
156                }
157            }
158            OPCODE_PING => {
159                let _ = write_pong_frame(write_stream, &frame.payload);
160            }
161            OPCODE_CLOSE => {
162                let _ = write_close_frame(write_stream);
163                break;
164            }
165            _ => {}
166        }
167    }
168    Ok(())
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// The accept value must match the worked example in RFC 6455 §1.3.
    #[test]
    fn compute_accept_rfc() {
        let accept = compute_accept("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    /// A short server frame round-trips through `read_frame`.
    #[test]
    fn write_read_text_frame() {
        let mut buf = Vec::new();
        write_text_frame(&mut buf, "hello").unwrap();

        // Server frames are unmasked; simulate client reading by reading as-is
        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
    }

    /// A payload above 125 bytes exercises the 16-bit extended length.
    #[test]
    fn write_read_large_frame() {
        let text = "x".repeat(300);
        let mut buf = Vec::new();
        write_text_frame(&mut buf, &text).unwrap();

        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload.len(), 300);
        assert_eq!(frame.payload, text.as_bytes());
    }

    /// Masked client frame: the single-frame masked "Hello" example from
    /// RFC 6455 §5.7.
    #[test]
    fn read_masked_text_frame() {
        let wire = [
            0x81u8, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
        ];
        let frame = read_frame(&mut &wire[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"Hello");
    }
}