1use std::io::{self, Read, Write};
2
3use crate::encode::to_base64;
4use crate::http::HttpRequest;
5use crate::sha1::sha1;
6
/// GUID fixed by RFC 6455 §1.3. The server appends it to the client's
/// `Sec-WebSocket-Key`, hashes the result with SHA-1 and base64-encodes the
/// digest to form `Sec-WebSocket-Accept`. Any other value yields an accept
/// token that conforming clients reject, so the handshake never completes.
const WS_MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Opcode of a text data frame.
pub(crate) const OPCODE_TEXT: u8 = 0x1;
/// Opcode of a connection-close control frame.
pub(crate) const OPCODE_CLOSE: u8 = 0x8;
/// Opcode of a ping control frame.
pub(crate) const OPCODE_PING: u8 = 0x9;
/// Opcode of a pong control frame, sent in reply to a ping.
pub(crate) const OPCODE_PONG: u8 = 0xA;
15
/// One decoded WebSocket frame as produced by this module's readers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsFrame {
    /// Low four bits of the first header byte, e.g. `OPCODE_TEXT`.
    pub opcode: u8,
    /// Frame payload; if the frame was masked on the wire it has already
    /// been unmasked.
    pub payload: Vec<u8>,
}
21
22pub fn is_upgrade(req: &HttpRequest) -> bool {
24 req.headers
25 .get("upgrade")
26 .map(|v| v.eq_ignore_ascii_case("websocket"))
27 .unwrap_or(false)
28}
29
30pub fn do_handshake(stream: &mut dyn Write, req: &HttpRequest) -> io::Result<()> {
32 let key = req
33 .headers
34 .get("sec-websocket-key")
35 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing Sec-WebSocket-Key"))?;
36
37 let accept = compute_accept(key);
38
39 write!(
40 stream,
41 "HTTP/1.1 101 Switching Protocols\r\n\
42 Upgrade: websocket\r\n\
43 Connection: Upgrade\r\n\
44 Sec-WebSocket-Accept: {}\r\n\
45 \r\n",
46 accept
47 )?;
48 stream.flush()
49}
50
51fn compute_accept(key: &str) -> String {
52 let combined = format!("{}{}", key, WS_MAGIC);
53 let hash = sha1(combined.as_bytes());
54 to_base64(&hash)
55}
56
57pub fn read_frame(stream: &mut dyn Read) -> io::Result<WsFrame> {
59 let mut head = [0u8; 2];
60 stream.read_exact(&mut head)?;
61
62 let _fin = head[0] & 0x80 != 0;
63 let opcode = head[0] & 0x0F;
64 let masked = head[1] & 0x80 != 0;
65 let len_byte = head[1] & 0x7F;
66
67 let payload_len: usize = if len_byte <= 125 {
68 len_byte as usize
69 } else if len_byte == 126 {
70 let mut buf = [0u8; 2];
71 stream.read_exact(&mut buf)?;
72 u16::from_be_bytes(buf) as usize
73 } else {
74 let mut buf = [0u8; 8];
75 stream.read_exact(&mut buf)?;
76 u64::from_be_bytes(buf) as usize
77 };
78
79 let mask_key = if masked {
80 let mut key = [0u8; 4];
81 stream.read_exact(&mut key)?;
82 Some(key)
83 } else {
84 None
85 };
86
87 let mut payload = vec![0u8; payload_len];
88 if payload_len > 0 {
89 stream.read_exact(&mut payload)?;
90 }
91
92 if let Some(key) = mask_key {
94 for i in 0..payload.len() {
95 payload[i] ^= key[i % 4];
96 }
97 }
98
99 Ok(WsFrame { opcode, payload })
100}
101
102pub fn write_text_frame(stream: &mut dyn Write, text: &str) -> io::Result<()> {
104 write_frame(stream, OPCODE_TEXT, text.as_bytes())
105}
106
107pub fn write_close_frame(stream: &mut dyn Write) -> io::Result<()> {
109 write_frame(stream, OPCODE_CLOSE, &[])
110}
111
112pub fn write_pong_frame(stream: &mut dyn Write, payload: &[u8]) -> io::Result<()> {
114 write_frame(stream, OPCODE_PONG, payload)
115}
116
/// Encodes one final (FIN) frame with the given `opcode` and payload `data`,
/// writes it to `stream`, and flushes.
///
/// The payload length uses the shortest RFC 6455 form:
/// * one byte for 0..=125,
/// * `126` followed by a 16-bit big-endian length up to 65535,
/// * `127` followed by a 64-bit big-endian length otherwise.
///
/// The mask bit is never set, so frames go out unmasked.
///
/// # Errors
///
/// Any error returned while writing to or flushing `stream`.
fn write_frame(stream: &mut dyn Write, opcode: u8, data: &[u8]) -> io::Result<()> {
    let len = data.len();

    // Assemble header and payload first, then issue a single write. Separate
    // tiny writes on an unbuffered socket each cost a syscall and can be held
    // back by Nagle's algorithm, delaying the frame.
    let mut frame = Vec::with_capacity(10 + len);
    frame.push(0x80 | opcode);
    if len <= 125 {
        frame.push(len as u8);
    } else if len <= usize::from(u16::MAX) {
        frame.push(126);
        // Lossless: the branch guarantees `len` fits in 16 bits.
        frame.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        frame.push(127);
        frame.extend_from_slice(&(len as u64).to_be_bytes());
    }
    frame.extend_from_slice(data);

    stream.write_all(&frame)?;
    stream.flush()
}
135
136pub fn run_ws_loop(
142 read_stream: &mut dyn Read,
143 write_stream: &mut dyn Write,
144 mut on_text: impl FnMut(&str),
145) -> io::Result<()> {
146 loop {
147 let frame = match read_frame(read_stream) {
148 Ok(f) => f,
149 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
150 Err(e) => return Err(e),
151 };
152
153 match frame.opcode {
154 OPCODE_TEXT => {
155 if let Ok(text) = std::str::from_utf8(&frame.payload) {
156 on_text(text);
157 }
158 }
159 OPCODE_PING => {
160 let _ = write_pong_frame(write_stream, &frame.payload);
161 }
162 OPCODE_CLOSE => {
163 let _ = write_close_frame(write_stream);
164 break;
165 }
166 _ => {}
167 }
168 }
169 Ok(())
170}
171
172fn parse_frame_from_buf(buf: &[u8]) -> Option<(WsFrame, usize)> {
175 if buf.len() < 2 {
176 return None;
177 }
178
179 let opcode = buf[0] & 0x0F;
180 let masked = buf[1] & 0x80 != 0;
181 let len_byte = buf[1] & 0x7F;
182
183 let mut pos = 2;
184
185 let payload_len: usize = if len_byte <= 125 {
186 len_byte as usize
187 } else if len_byte == 126 {
188 if buf.len() < pos + 2 {
189 return None;
190 }
191 let len = u16::from_be_bytes([buf[pos], buf[pos + 1]]) as usize;
192 pos += 2;
193 len
194 } else {
195 if buf.len() < pos + 8 {
196 return None;
197 }
198 let mut arr = [0u8; 8];
199 arr.copy_from_slice(&buf[pos..pos + 8]);
200 let len = u64::from_be_bytes(arr) as usize;
201 pos += 8;
202 len
203 };
204
205 let mask_key = if masked {
206 if buf.len() < pos + 4 {
207 return None;
208 }
209 let key = [buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]];
210 pos += 4;
211 Some(key)
212 } else {
213 None
214 };
215
216 if buf.len() < pos + payload_len {
217 return None;
218 }
219
220 let mut payload = buf[pos..pos + payload_len].to_vec();
221 pos += payload_len;
222
223 if let Some(key) = mask_key {
224 for i in 0..payload.len() {
225 payload[i] ^= key[i % 4];
226 }
227 }
228
229 Some((WsFrame { opcode, payload }, pos))
230}
231
232pub struct WsBuf {
237 buf: Vec<u8>,
238}
239
240impl WsBuf {
241 pub fn new() -> Self {
242 WsBuf {
243 buf: Vec::with_capacity(4096),
244 }
245 }
246
247 pub fn try_read_frame(&mut self, stream: &mut dyn Read) -> io::Result<Option<WsFrame>> {
252 if let Some((frame, consumed)) = parse_frame_from_buf(&self.buf) {
254 self.buf.drain(..consumed);
255 return Ok(Some(frame));
256 }
257
258 let mut tmp = [0u8; 4096];
260 match stream.read(&mut tmp) {
261 Ok(0) => {
262 return Err(io::Error::new(
263 io::ErrorKind::UnexpectedEof,
264 "connection closed",
265 ));
266 }
267 Ok(n) => {
268 self.buf.extend_from_slice(&tmp[..n]);
269 }
270 Err(e)
271 if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut =>
272 {
273 return Ok(None);
274 }
275 Err(e) => return Err(e),
276 }
277
278 if let Some((frame, consumed)) = parse_frame_from_buf(&self.buf) {
280 self.buf.drain(..consumed);
281 return Ok(Some(frame));
282 }
283
284 Ok(None)
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_accept_rfc() {
        // Sample handshake from RFC 6455 §1.3.
        let accept = compute_accept("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn write_read_text_frame() {
        let mut buf = Vec::new();
        write_text_frame(&mut buf, "hello").unwrap();

        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
    }

    #[test]
    fn write_read_large_frame() {
        let text = "x".repeat(300);
        let mut buf = Vec::new();
        write_text_frame(&mut buf, &text).unwrap();

        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload.len(), 300);
    }

    #[test]
    fn write_frame_length_encoding_boundaries() {
        // (payload length, expected header length)
        for &(len, header_len) in &[
            (0usize, 2usize),
            (125, 2),
            (126, 4),
            (65_535, 4),
            (65_536, 10),
        ] {
            let mut buf = Vec::new();
            write_frame(&mut buf, OPCODE_TEXT, &vec![b'a'; len]).unwrap();
            assert_eq!(buf.len(), header_len + len);

            let frame = read_frame(&mut &buf[..]).unwrap();
            assert_eq!(frame.opcode, OPCODE_TEXT);
            assert_eq!(frame.payload.len(), len);
        }
    }

    #[test]
    fn pong_and_close_frame_encoding() {
        let mut buf = Vec::new();
        write_pong_frame(&mut buf, b"abc").unwrap();
        write_close_frame(&mut buf).unwrap();
        assert_eq!(buf, vec![0x8A, 3, b'a', b'b', b'c', 0x88, 0]);
    }

    #[test]
    fn read_masked_frame_rfc_example() {
        // RFC 6455 §5.7: single-frame masked text message containing "Hello".
        let data = [0x81u8, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58];

        let frame = read_frame(&mut &data[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"Hello");

        let (frame, consumed) = parse_frame_from_buf(&data).unwrap();
        assert_eq!(frame.payload, b"Hello");
        assert_eq!(consumed, data.len());
    }

    #[test]
    fn parse_frame_from_buf_complete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "hello").unwrap();

        let result = parse_frame_from_buf(&data);
        assert!(result.is_some());
        let (frame, consumed) = result.unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
        assert_eq!(consumed, data.len());
    }

    #[test]
    fn parse_frame_from_buf_incomplete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "hello").unwrap();

        // Header not yet complete.
        assert!(parse_frame_from_buf(&data[..1]).is_none());
        // Header complete, payload truncated.
        assert!(parse_frame_from_buf(&data[..3]).is_none());
    }

    #[test]
    fn wsbuf_try_read_frame_wouldblock() {
        struct WouldBlockReader;
        impl Read for WouldBlockReader {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
            }
        }

        let mut ws_buf = WsBuf::new();
        let result = ws_buf.try_read_frame(&mut WouldBlockReader);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn wsbuf_try_read_frame_complete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "test").unwrap();

        let mut ws_buf = WsBuf::new();
        let mut cursor = io::Cursor::new(data);
        let result = ws_buf.try_read_frame(&mut cursor).unwrap();
        assert!(result.is_some());
        let frame = result.unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"test");
    }

    #[test]
    fn wsbuf_reassembles_frame_split_across_reads() {
        // Hands out one pre-cut chunk per read, then reports WouldBlock.
        struct ChunkedReader {
            chunks: Vec<Vec<u8>>,
        }
        impl Read for ChunkedReader {
            fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
                if self.chunks.is_empty() {
                    return Err(io::Error::new(io::ErrorKind::WouldBlock, "no data yet"));
                }
                let chunk = self.chunks.remove(0);
                out[..chunk.len()].copy_from_slice(&chunk);
                Ok(chunk.len())
            }
        }

        let mut data = Vec::new();
        write_text_frame(&mut data, "split").unwrap();
        let (head, tail) = data.split_at(3);
        let mut reader = ChunkedReader {
            chunks: vec![head.to_vec(), tail.to_vec()],
        };

        let mut ws_buf = WsBuf::new();
        assert!(ws_buf.try_read_frame(&mut reader).unwrap().is_none());
        let frame = ws_buf.try_read_frame(&mut reader).unwrap().unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"split");
        assert!(ws_buf.try_read_frame(&mut reader).unwrap().is_none());
    }

    #[test]
    fn wsbuf_returns_buffered_frames_before_reading_again() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "one").unwrap();
        write_text_frame(&mut data, "two").unwrap();

        let mut cursor = io::Cursor::new(data);
        let mut ws_buf = WsBuf::new();
        assert_eq!(ws_buf.try_read_frame(&mut cursor).unwrap().unwrap().payload, b"one");
        assert_eq!(ws_buf.try_read_frame(&mut cursor).unwrap().unwrap().payload, b"two");

        let err = match ws_buf.try_read_frame(&mut cursor) {
            Err(e) => e,
            Ok(_) => panic!("expected end of stream"),
        };
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn run_ws_loop_answers_ping_and_stops_at_close() {
        let mut input = Vec::new();
        write_text_frame(&mut input, "hi").unwrap();
        write_frame(&mut input, OPCODE_PING, b"p").unwrap();
        write_close_frame(&mut input).unwrap();
        // Must never be delivered: the loop stops at the close frame.
        write_text_frame(&mut input, "after close").unwrap();

        let mut output = Vec::new();
        let mut texts = Vec::new();
        run_ws_loop(&mut &input[..], &mut output, |t| texts.push(t.to_string())).unwrap();

        assert_eq!(texts, vec!["hi".to_string()]);
        assert_eq!(output, vec![0x8A, 1, b'p', 0x88, 0]);
    }

    #[test]
    fn run_ws_loop_skips_invalid_utf8_and_stops_at_eof() {
        let mut input = Vec::new();
        write_frame(&mut input, OPCODE_TEXT, &[0xFF, 0xFE]).unwrap();
        write_text_frame(&mut input, "ok").unwrap();

        let mut output = Vec::new();
        let mut texts = Vec::new();
        run_ws_loop(&mut &input[..], &mut output, |t| texts.push(t.to_string())).unwrap();

        assert_eq!(texts, vec!["ok".to_string()]);
        assert!(output.is_empty());
    }
}
375}