1use std::io::{Read, Write, BufRead, BufReader};
7use std::net::TcpStream;
8
9#[cfg(feature = "websocket")]
10use native_tls::TlsConnector;
11
/// Error type returned by every fallible operation in this module.
///
/// It carries only a human-readable description; the `Display` output is the
/// description prefixed with `WebSocket error: `.
#[derive(Debug)]
pub struct WebSocketError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl WebSocketError {
    /// Builds an error from anything that converts into a `String`.
    pub fn new(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self { message }
    }
}

impl std::fmt::Display for WebSocketError {
    /// Writes `WebSocket error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WebSocket error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for WebSocketError {}
31
/// Frame opcodes understood by this client; the discriminant of each variant
/// is the 4-bit opcode value placed in the first byte of a frame.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum Opcode {
    /// Continuation of a fragmented message.
    Continuation = 0x0,
    /// Text data frame.
    Text = 0x1,
    /// Binary data frame.
    Binary = 0x2,
    /// Connection close control frame.
    Close = 0x8,
    /// Ping control frame.
    Ping = 0x9,
    /// Pong control frame.
    Pong = 0xA,
}

impl Opcode {
    /// Maps a raw opcode value to its variant, or `None` when the value is not
    /// one of the six opcodes declared above.
    fn from_u8(val: u8) -> Option<Self> {
        const KNOWN: [Opcode; 6] = [
            Opcode::Continuation,
            Opcode::Text,
            Opcode::Binary,
            Opcode::Close,
            Opcode::Ping,
            Opcode::Pong,
        ];

        KNOWN.iter().copied().find(|candidate| *candidate as u8 == val)
    }
}
57
/// A message as seen by callers of this module.
#[derive(Debug, Clone)]
pub enum Message {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// The peer sent a close frame.
    Close,
    /// A ping carrying the given payload.
    Ping(Vec<u8>),
    /// A pong carrying the given payload.
    Pong(Vec<u8>),
}
67
/// Transport under a WebSocket connection: a bare TCP stream, or a TLS
/// session over TCP when the `websocket` feature is compiled in.
enum Stream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS session wrapping a TCP connection.
    #[cfg(feature = "websocket")]
    Tls(native_tls::TlsStream<TcpStream>),
}
74
75impl Read for Stream {
76 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
77 match self {
78 Stream::Plain(s) => s.read(buf),
79 #[cfg(feature = "websocket")]
80 Stream::Tls(s) => s.read(buf),
81 }
82 }
83}
84
85impl Write for Stream {
86 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
87 match self {
88 Stream::Plain(s) => s.write(buf),
89 #[cfg(feature = "websocket")]
90 Stream::Tls(s) => s.write(buf),
91 }
92 }
93
94 fn flush(&mut self) -> std::io::Result<()> {
95 match self {
96 Stream::Plain(s) => s.flush(),
97 #[cfg(feature = "websocket")]
98 Stream::Tls(s) => s.flush(),
99 }
100 }
101}
102
103pub struct WebSocket {
105 stream: Stream,
106}
107
108impl WebSocket {
109 pub fn connect(url: &str) -> Result<Self, WebSocketError> {
113 let (secure, host, port, path) = Self::parse_url(url)?;
115
116 let addr = format!("{}:{}", host, port);
118 let tcp_stream = TcpStream::connect(&addr)
119 .map_err(|e| WebSocketError::new(format!("TCP connection failed: {}", e)))?;
120
121 tcp_stream.set_read_timeout(Some(std::time::Duration::from_secs(30))).ok();
123 tcp_stream.set_write_timeout(Some(std::time::Duration::from_secs(30))).ok();
124
125 let stream = if secure {
127 #[cfg(feature = "websocket")]
128 {
129 let connector = TlsConnector::new()
130 .map_err(|e| WebSocketError::new(format!("TLS setup failed: {}", e)))?;
131 let tls_stream = connector.connect(&host, tcp_stream)
132 .map_err(|e| WebSocketError::new(format!("TLS handshake failed: {}", e)))?;
133 Stream::Tls(tls_stream)
134 }
135 #[cfg(not(feature = "websocket"))]
136 {
137 return Err(WebSocketError::new("TLS support not compiled in"));
138 }
139 } else {
140 Stream::Plain(tcp_stream)
141 };
142
143 let mut ws = WebSocket { stream };
144
145 ws.handshake(&host, port, &path)?;
147
148 Ok(ws)
149 }
150
151 fn parse_url(url: &str) -> Result<(bool, String, u16, String), WebSocketError> {
153 let (secure, rest) = if url.starts_with("wss://") {
154 (true, &url[6..])
155 } else if url.starts_with("ws://") {
156 (false, &url[5..])
157 } else {
158 return Err(WebSocketError::new("URL must start with ws:// or wss://"));
159 };
160
161 let (host_port, path) = match rest.find('/') {
163 Some(idx) => (&rest[..idx], &rest[idx..]),
164 None => (rest, "/"),
165 };
166
167 let (host, port) = match host_port.find(':') {
169 Some(idx) => {
170 let port_str = &host_port[idx + 1..];
171 let port = port_str.parse::<u16>()
172 .map_err(|_| WebSocketError::new("Invalid port number"))?;
173 (host_port[..idx].to_string(), port)
174 }
175 None => (host_port.to_string(), if secure { 443 } else { 80 }),
176 };
177
178 Ok((secure, host, port, path.to_string()))
179 }
180
181 fn handshake(&mut self, host: &str, port: u16, path: &str) -> Result<(), WebSocketError> {
183 let key_bytes: [u8; 16] = rand::random();
185 let key = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, key_bytes);
186
187 let host_header = if port == 80 || port == 443 {
189 host.to_string()
190 } else {
191 format!("{}:{}", host, port)
192 };
193
194 let request = format!(
195 "GET {} HTTP/1.1\r\n\
196 Host: {}\r\n\
197 Upgrade: websocket\r\n\
198 Connection: Upgrade\r\n\
199 Sec-WebSocket-Key: {}\r\n\
200 Sec-WebSocket-Version: 13\r\n\
201 \r\n",
202 path, host_header, key
203 );
204
205 self.stream.write_all(request.as_bytes())
207 .map_err(|e| WebSocketError::new(format!("Failed to send handshake: {}", e)))?;
208 self.stream.flush()
209 .map_err(|e| WebSocketError::new(format!("Failed to flush handshake: {}", e)))?;
210
211 let mut reader = BufReader::new(&mut self.stream);
213 let mut response_line = String::new();
214 reader.read_line(&mut response_line)
215 .map_err(|e| WebSocketError::new(format!("Failed to read response: {}", e)))?;
216
217 if !response_line.starts_with("HTTP/1.1 101") {
219 return Err(WebSocketError::new(format!("Handshake failed: {}", response_line.trim())));
220 }
221
222 let expected_accept = Self::compute_accept_key(&key);
224 let mut found_accept = false;
225
226 loop {
227 let mut line = String::new();
228 reader.read_line(&mut line)
229 .map_err(|e| WebSocketError::new(format!("Failed to read headers: {}", e)))?;
230
231 let line = line.trim();
232 if line.is_empty() {
233 break; }
235
236 if let Some((name, value)) = line.split_once(':') {
237 let name = name.trim().to_lowercase();
238 let value = value.trim();
239
240 if name == "sec-websocket-accept" {
241 if value != expected_accept {
242 return Err(WebSocketError::new("Invalid Sec-WebSocket-Accept"));
243 }
244 found_accept = true;
245 }
246 }
247 }
248
249 if !found_accept {
250 return Err(WebSocketError::new("Missing Sec-WebSocket-Accept header"));
251 }
252
253 Ok(())
254 }
255
256 fn compute_accept_key(key: &str) -> String {
258 use sha1::{Sha1, Digest};
259
260 let magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
262 let combined = format!("{}{}", key, magic);
263
264 let mut hasher = Sha1::new();
266 hasher.update(combined.as_bytes());
267 let hash = hasher.finalize();
268
269 base64::Engine::encode(&base64::engine::general_purpose::STANDARD, hash)
271 }
272
273 pub fn send_text(&mut self, text: &str) -> Result<(), WebSocketError> {
275 self.send_frame(Opcode::Text, text.as_bytes())
276 }
277
278 pub fn send_binary(&mut self, data: &[u8]) -> Result<(), WebSocketError> {
280 self.send_frame(Opcode::Binary, data)
281 }
282
283 pub fn send_close(&mut self) -> Result<(), WebSocketError> {
285 self.send_frame(Opcode::Close, &[])
286 }
287
288 fn send_frame(&mut self, opcode: Opcode, payload: &[u8]) -> Result<(), WebSocketError> {
292 let mut frame = Vec::with_capacity(14 + payload.len());
293
294 frame.push(0x80 | (opcode as u8)); let len = payload.len();
299 if len < 126 {
300 frame.push(0x80 | len as u8); } else if len < 65536 {
302 frame.push(0x80 | 126);
303 frame.push((len >> 8) as u8);
304 frame.push(len as u8);
305 } else {
306 frame.push(0x80 | 127);
307 for i in (0..8).rev() {
308 frame.push((len >> (i * 8)) as u8);
309 }
310 }
311
312 let mask: [u8; 4] = rand::random();
314 frame.extend_from_slice(&mask);
315
316 for (i, &byte) in payload.iter().enumerate() {
318 frame.push(byte ^ mask[i % 4]);
319 }
320
321 self.stream.write_all(&frame)
322 .map_err(|e| WebSocketError::new(format!("Failed to send frame: {}", e)))?;
323 self.stream.flush()
324 .map_err(|e| WebSocketError::new(format!("Failed to flush frame: {}", e)))?;
325
326 Ok(())
327 }
328
329 pub fn receive(&mut self) -> Result<Message, WebSocketError> {
331 loop {
332 let (opcode, payload) = self.receive_frame()?;
333
334 match opcode {
335 Opcode::Text => {
336 let text = String::from_utf8(payload)
337 .map_err(|e| WebSocketError::new(format!("Invalid UTF-8: {}", e)))?;
338 return Ok(Message::Text(text));
339 }
340 Opcode::Binary => {
341 return Ok(Message::Binary(payload));
342 }
343 Opcode::Close => {
344 return Ok(Message::Close);
345 }
346 Opcode::Ping => {
347 self.send_frame(Opcode::Pong, &payload)?;
349 }
351 Opcode::Pong => {
352 }
354 Opcode::Continuation => {
355 let text = String::from_utf8_lossy(&payload).to_string();
357 return Ok(Message::Text(text));
358 }
359 }
360 }
361 }
362
363 fn receive_frame(&mut self) -> Result<(Opcode, Vec<u8>), WebSocketError> {
365 let mut header = [0u8; 2];
367 self.read_exact(&mut header)?;
368
369 let _fin = (header[0] & 0x80) != 0;
370 let opcode = Opcode::from_u8(header[0] & 0x0F)
371 .ok_or_else(|| WebSocketError::new("Invalid opcode"))?;
372
373 let masked = (header[1] & 0x80) != 0;
374 let mut len = (header[1] & 0x7F) as usize;
375
376 if len == 126 {
378 let mut ext = [0u8; 2];
379 self.read_exact(&mut ext)?;
380 len = ((ext[0] as usize) << 8) | (ext[1] as usize);
381 } else if len == 127 {
382 let mut ext = [0u8; 8];
383 self.read_exact(&mut ext)?;
384 len = 0;
385 for &b in &ext {
386 len = (len << 8) | (b as usize);
387 }
388 }
389
390 let mask = if masked {
392 let mut m = [0u8; 4];
393 self.read_exact(&mut m)?;
394 Some(m)
395 } else {
396 None
397 };
398
399 let mut payload = vec![0u8; len];
401 if len > 0 {
402 self.read_exact(&mut payload)?;
403 }
404
405 if let Some(mask) = mask {
407 for (i, byte) in payload.iter_mut().enumerate() {
408 *byte ^= mask[i % 4];
409 }
410 }
411
412 Ok((opcode, payload))
413 }
414
415 fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), WebSocketError> {
417 let mut total = 0;
418 while total < buf.len() {
419 match self.stream.read(&mut buf[total..]) {
420 Ok(0) => return Err(WebSocketError::new("Connection closed")),
421 Ok(n) => total += n,
422 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
423 Err(e) => return Err(WebSocketError::new(format!("Read error: {}", e))),
424 }
425 }
426 Ok(())
427 }
428
429 pub fn close(&mut self) -> Result<(), WebSocketError> {
431 let _ = self.send_close();
433 Ok(())
434 }
435}
436
437pub fn send_and_receive(url: &str, message: &str) -> Result<String, WebSocketError> {
441 let mut ws = WebSocket::connect(url)?;
442 ws.send_text(message)?;
443
444 let response = match ws.receive()? {
445 Message::Text(t) => t,
446 Message::Binary(b) => String::from_utf8_lossy(&b).to_string(),
447 Message::Close => String::new(),
448 Message::Ping(_) | Message::Pong(_) => {
449 match ws.receive()? {
451 Message::Text(t) => t,
452 Message::Binary(b) => String::from_utf8_lossy(&b).to_string(),
453 _ => String::new(),
454 }
455 }
456 };
457
458 ws.close()?;
459 Ok(response)
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// URL parsing: default port for `ws://` and an explicit port for `wss://`.
    #[test]
    fn test_parse_url() {
        let plain = WebSocket::parse_url("ws://example.com/path").unwrap();
        assert_eq!(
            plain,
            (false, "example.com".to_string(), 80, "/path".to_string())
        );

        let tls = WebSocket::parse_url("wss://example.com:8443/api").unwrap();
        assert_eq!(
            tls,
            (true, "example.com".to_string(), 8443, "/api".to_string())
        );
    }

    /// Accept-key derivation for a fixed sample key.
    #[test]
    fn test_compute_accept_key() {
        let accept = WebSocket::compute_accept_key("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }
}