1use std::io::{self, Read, Write};
2
3use crate::encode::to_base64;
4use crate::http::HttpRequest;
5use crate::sha1::sha1;
6
/// GUID that RFC 6455 (section 1.3) requires the server to append to the
/// client's `Sec-WebSocket-Key` before hashing it into the
/// `Sec-WebSocket-Accept` response header.
///
/// The value is fixed by the specification; any other string produces an
/// accept token that compliant clients reject, failing the handshake.
const WS_MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

// Frame opcodes. They occupy the low four bits of the first header byte.

/// Opcode of a text data frame.
const OPCODE_TEXT: u8 = 0x1;
/// Opcode of a connection-close control frame.
const OPCODE_CLOSE: u8 = 0x8;
/// Opcode of a ping control frame.
const OPCODE_PING: u8 = 0x9;
/// Opcode of a pong control frame.
const OPCODE_PONG: u8 = 0xA;
14
/// A single WebSocket frame reduced to the two parts the rest of this module
/// works with.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsFrame {
    /// Frame opcode, taken from the low four bits of the first header byte.
    pub opcode: u8,
    /// Payload bytes; already unmasked when the frame arrived with a mask.
    pub payload: Vec<u8>,
}
20
21pub fn is_upgrade(req: &HttpRequest) -> bool {
23 req.headers
24 .get("upgrade")
25 .map(|v| v.eq_ignore_ascii_case("websocket"))
26 .unwrap_or(false)
27}
28
29pub fn do_handshake(stream: &mut dyn Write, req: &HttpRequest) -> io::Result<()> {
31 let key = req
32 .headers
33 .get("sec-websocket-key")
34 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing Sec-WebSocket-Key"))?;
35
36 let accept = compute_accept(key);
37
38 write!(
39 stream,
40 "HTTP/1.1 101 Switching Protocols\r\n\
41 Upgrade: websocket\r\n\
42 Connection: Upgrade\r\n\
43 Sec-WebSocket-Accept: {}\r\n\
44 \r\n",
45 accept
46 )?;
47 stream.flush()
48}
49
50fn compute_accept(key: &str) -> String {
51 let combined = format!("{}{}", key, WS_MAGIC);
52 let hash = sha1(combined.as_bytes());
53 to_base64(&hash)
54}
55
56pub fn read_frame(stream: &mut dyn Read) -> io::Result<WsFrame> {
58 let mut head = [0u8; 2];
59 stream.read_exact(&mut head)?;
60
61 let _fin = head[0] & 0x80 != 0;
62 let opcode = head[0] & 0x0F;
63 let masked = head[1] & 0x80 != 0;
64 let len_byte = head[1] & 0x7F;
65
66 let payload_len: usize = if len_byte <= 125 {
67 len_byte as usize
68 } else if len_byte == 126 {
69 let mut buf = [0u8; 2];
70 stream.read_exact(&mut buf)?;
71 u16::from_be_bytes(buf) as usize
72 } else {
73 let mut buf = [0u8; 8];
74 stream.read_exact(&mut buf)?;
75 u64::from_be_bytes(buf) as usize
76 };
77
78 let mask_key = if masked {
79 let mut key = [0u8; 4];
80 stream.read_exact(&mut key)?;
81 Some(key)
82 } else {
83 None
84 };
85
86 let mut payload = vec![0u8; payload_len];
87 if payload_len > 0 {
88 stream.read_exact(&mut payload)?;
89 }
90
91 if let Some(key) = mask_key {
93 for i in 0..payload.len() {
94 payload[i] ^= key[i % 4];
95 }
96 }
97
98 Ok(WsFrame { opcode, payload })
99}
100
101pub fn write_text_frame(stream: &mut dyn Write, text: &str) -> io::Result<()> {
103 write_frame(stream, OPCODE_TEXT, text.as_bytes())
104}
105
106pub fn write_close_frame(stream: &mut dyn Write) -> io::Result<()> {
108 write_frame(stream, OPCODE_CLOSE, &[])
109}
110
111pub fn write_pong_frame(stream: &mut dyn Write, payload: &[u8]) -> io::Result<()> {
113 write_frame(stream, OPCODE_PONG, payload)
114}
115
/// Writes one unmasked, final (FIN = 1) frame with the given `opcode` and
/// payload `data`, then flushes `stream`.
///
/// The payload length is encoded in the shortest form that fits: a single
/// byte up to 125, marker `126` plus a big-endian `u16` up to 65535, and
/// marker `127` plus a big-endian `u64` beyond that.
///
/// The header is assembled on the stack and emitted with a single write, so
/// an unbuffered stream sees two writes per frame (header, payload) rather
/// than one per header field.
///
/// # Errors
///
/// Returns any error reported while writing to or flushing `stream`.
fn write_frame(stream: &mut dyn Write, opcode: u8, data: &[u8]) -> io::Result<()> {
    // Largest unmasked header: FIN/opcode byte + length marker + 8 length bytes.
    let mut header = [0u8; 10];
    header[0] = 0x80 | opcode;

    let len = data.len();
    let header_len = if len <= 125 {
        header[1] = len as u8;
        2
    } else if len <= usize::from(u16::MAX) {
        header[1] = 126;
        header[2..4].copy_from_slice(&(len as u16).to_be_bytes());
        4
    } else {
        header[1] = 127;
        header[2..10].copy_from_slice(&(len as u64).to_be_bytes());
        10
    };

    stream.write_all(&header[..header_len])?;
    stream.write_all(data)?;
    stream.flush()
}
134
135pub fn run_ws_loop(
141 read_stream: &mut dyn Read,
142 write_stream: &mut dyn Write,
143 mut on_text: impl FnMut(&str),
144) -> io::Result<()> {
145 loop {
146 let frame = match read_frame(read_stream) {
147 Ok(f) => f,
148 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
149 Err(e) => return Err(e),
150 };
151
152 match frame.opcode {
153 OPCODE_TEXT => {
154 if let Ok(text) = std::str::from_utf8(&frame.payload) {
155 on_text(text);
156 }
157 }
158 OPCODE_PING => {
159 let _ = write_pong_frame(write_stream, &frame.payload);
160 }
161 OPCODE_CLOSE => {
162 let _ = write_close_frame(write_stream);
163 break;
164 }
165 _ => {}
166 }
167 }
168 Ok(())
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample handshake from RFC 6455 section 1.3: the nonce
    /// `dGhlIHNhbXBsZSBub25jZQ==` must yield `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`.
    #[test]
    fn compute_accept_rfc() {
        let accept = compute_accept("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    /// A short text frame survives a write/read round trip.
    #[test]
    fn write_read_text_frame() {
        let mut wire = Vec::new();
        write_text_frame(&mut wire, "hello").unwrap();

        let mut reader: &[u8] = &wire;
        let frame = read_frame(&mut reader).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
    }

    /// A 300-byte payload exercises the 16-bit extended length encoding.
    #[test]
    fn write_read_large_frame() {
        let text = "x".repeat(300);
        let mut wire = Vec::new();
        write_text_frame(&mut wire, &text).unwrap();

        let mut reader: &[u8] = &wire;
        let frame = read_frame(&mut reader).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload.len(), 300);
        assert_eq!(frame.payload, text.as_bytes());
    }

    /// Masked single-frame text message "Hello" from RFC 6455 section 5.7;
    /// covers the unmasking path that every client-to-server frame takes.
    #[test]
    fn read_masked_text_frame() {
        let wire = [
            0x81u8, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
        ];

        let mut reader: &[u8] = &wire;
        let frame = read_frame(&mut reader).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"Hello");
    }
}
203}