use std::io::{self, Read, Write};
use crate::encode::to_base64;
use crate::http::HttpRequest;
use crate::sha1::sha1;
/// Fixed GUID from RFC 6455 that is appended to the client's
/// `Sec-WebSocket-Key` before SHA-1 hashing to form `Sec-WebSocket-Accept`.
const WS_MAGIC: &str = "258EAFA5-E914-47DA-95CA-5AB5DC11D045";
/// Opcode of a text data frame.
pub(crate) const OPCODE_TEXT: u8 = 0x1;
/// Opcode of a connection-close control frame.
pub(crate) const OPCODE_CLOSE: u8 = 0x8;
/// Opcode of a ping control frame.
pub(crate) const OPCODE_PING: u8 = 0x9;
/// Opcode of a pong control frame, sent in reply to a ping.
pub(crate) const OPCODE_PONG: u8 = 0xA;
/// A single decoded WebSocket frame.
///
/// `payload` has already had any client mask removed. The FIN bit is not
/// recorded, so the parts of a fragmented message show up as separate frames.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsFrame {
    /// Frame opcode: the low four bits of the first header byte.
    pub opcode: u8,
    /// Application payload bytes, already unmasked.
    pub payload: Vec<u8>,
}
/// Reports whether `req` asks to switch to the WebSocket protocol.
///
/// Returns `true` only when an `upgrade` header is present and its value equals
/// `websocket`, compared ASCII case-insensitively. The `Connection` header is
/// not inspected.
pub fn is_upgrade(req: &HttpRequest) -> bool {
    match req.headers.get("upgrade") {
        Some(value) => value.eq_ignore_ascii_case("websocket"),
        None => false,
    }
}
/// Completes the server side of the WebSocket opening handshake.
///
/// Looks up the `sec-websocket-key` header on `req`, derives the matching
/// accept value with `compute_accept`, then writes a `101 Switching Protocols`
/// response to `stream` and flushes it.
///
/// # Errors
///
/// Returns `InvalidData` when the request has no `sec-websocket-key` header.
/// Any write or flush error from `stream` is propagated.
pub fn do_handshake(stream: &mut dyn Write, req: &HttpRequest) -> io::Result<()> {
    let client_key = match req.headers.get("sec-websocket-key") {
        Some(value) => value,
        None => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Missing Sec-WebSocket-Key",
            ))
        }
    };
    let accept_value = compute_accept(client_key);
    write!(
        stream,
        "HTTP/1.1 101 Switching Protocols\r\n\
         Upgrade: websocket\r\n\
         Connection: Upgrade\r\n\
         Sec-WebSocket-Accept: {}\r\n\
         \r\n",
        accept_value
    )?;
    stream.flush()
}
/// Derives the `Sec-WebSocket-Accept` value for a client-supplied key.
///
/// Appends `WS_MAGIC` to `key`, hashes the result with `sha1`, and encodes the
/// digest with `to_base64`.
fn compute_accept(key: &str) -> String {
    let mut material = String::with_capacity(key.len() + WS_MAGIC.len());
    material.push_str(key);
    material.push_str(WS_MAGIC);
    let digest = sha1(material.as_bytes());
    to_base64(&digest)
}
/// Reads exactly one WebSocket frame from `stream`, blocking until it is complete.
///
/// The 7-bit, 16-bit and 64-bit length encodings are all supported. A masked
/// payload is unmasked before it is returned.
///
/// The FIN bit is not inspected. Continuation frames are returned as-is, and
/// fragmented messages are not reassembled here.
///
/// # Errors
///
/// * `UnexpectedEof` if the stream ends before the header or the full payload
///   has been read.
/// * `InvalidData` if the 64-bit length has its most significant bit set
///   (forbidden by RFC 6455 §5.2), or does not fit in `usize` on this platform.
/// * Any other I/O error from `stream`.
pub fn read_frame(stream: &mut dyn Read) -> io::Result<WsFrame> {
    let mut head = [0u8; 2];
    stream.read_exact(&mut head)?;
    let opcode = head[0] & 0x0F;
    let masked = head[1] & 0x80 != 0;

    let payload_len: u64 = match head[1] & 0x7F {
        len @ 0..=125 => u64::from(len),
        126 => {
            let mut ext = [0u8; 2];
            stream.read_exact(&mut ext)?;
            u64::from(u16::from_be_bytes(ext))
        }
        _ => {
            let mut ext = [0u8; 8];
            stream.read_exact(&mut ext)?;
            let len = u64::from_be_bytes(ext);
            if len >> 63 != 0 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "WebSocket payload length has its most significant bit set",
                ));
            }
            len
        }
    };
    // Only reachable on targets where usize is narrower than 64 bits; a plain
    // `as` cast would silently truncate and desynchronise the stream.
    if payload_len > usize::MAX as u64 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "WebSocket payload length does not fit in memory on this platform",
        ));
    }

    let mask_key = if masked {
        let mut key = [0u8; 4];
        stream.read_exact(&mut key)?;
        Some(key)
    } else {
        None
    };

    // The length comes from the peer, so it must not drive a single up-front
    // allocation. Reading through `take` grows the buffer only as bytes
    // actually arrive. The initial capacity is capped to keep small frames cheap.
    let mut payload = Vec::with_capacity(payload_len.min(64 * 1024) as usize);
    Read::take(&mut *stream, payload_len).read_to_end(&mut payload)?;
    if payload.len() as u64 != payload_len {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "WebSocket frame ended before its declared payload length",
        ));
    }

    if let Some(key) = mask_key {
        for (byte, k) in payload.iter_mut().zip(key.iter().cycle()) {
            *byte ^= *k;
        }
    }
    Ok(WsFrame { opcode, payload })
}
/// Sends `text` to `stream` as one unfragmented, unmasked text frame.
pub fn write_text_frame(stream: &mut dyn Write, text: &str) -> io::Result<()> {
    let utf8_bytes = text.as_bytes();
    write_frame(stream, OPCODE_TEXT, utf8_bytes)
}
/// Sends a close frame with an empty payload, so no status code is included.
pub fn write_close_frame(stream: &mut dyn Write) -> io::Result<()> {
    let no_payload: &[u8] = b"";
    write_frame(stream, OPCODE_CLOSE, no_payload)
}
/// Sends a pong frame whose body is `payload`. When replying to a ping,
/// `payload` should be the ping's payload.
pub fn write_pong_frame(stream: &mut dyn Write, payload: &[u8]) -> io::Result<()> {
    let opcode = OPCODE_PONG;
    write_frame(stream, opcode, payload)
}
/// Writes one final (FIN set), unmasked frame with the given `opcode` and body.
///
/// The length uses the shortest encoding: 7-bit for up to 125 bytes, 16-bit
/// for up to 65535, and 64-bit otherwise. The stream is flushed afterwards.
///
/// The header and payload are assembled into one buffer and handed to
/// `stream` in a single `write_all`. On an unbuffered socket this avoids one
/// syscall per header piece. It also avoids a small header segment being held
/// back by Nagle's algorithm while the payload waits behind it.
fn write_frame(stream: &mut dyn Write, opcode: u8, data: &[u8]) -> io::Result<()> {
    let len = data.len();
    // Largest header: 1 flags/opcode byte + 1 length byte + 8 extended-length bytes.
    let mut frame = Vec::with_capacity(10 + len);
    frame.push(0x80 | opcode);
    if len <= 125 {
        frame.push(len as u8);
    } else if len <= usize::from(u16::MAX) {
        frame.push(126);
        frame.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        frame.push(127);
        frame.extend_from_slice(&(len as u64).to_be_bytes());
    }
    frame.extend_from_slice(data);
    stream.write_all(&frame)?;
    stream.flush()
}
/// Runs a blocking receive loop until the peer closes the connection.
///
/// Each frame read from `read_stream` is dispatched by opcode:
/// * text: the payload is passed to `on_text` if it is valid UTF-8. Invalid
///   UTF-8 frames are skipped.
/// * ping: a pong echoing the same payload is written to `write_stream`.
/// * close: a close frame is sent back and the loop ends.
/// * anything else (binary, continuation, pong): ignored.
///
/// End-of-stream (`UnexpectedEof`) also ends the loop normally.
///
/// # Errors
///
/// Returns any read error other than `UnexpectedEof`, and any error writing
/// the pong reply. A failed pong means the write side is broken, so
/// continuing to read would only hide the failure.
pub fn run_ws_loop(
    read_stream: &mut dyn Read,
    write_stream: &mut dyn Write,
    mut on_text: impl FnMut(&str),
) -> io::Result<()> {
    loop {
        let frame = match read_frame(read_stream) {
            Ok(f) => f,
            // Peer went away without a close handshake: treat as a normal end.
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(e),
        };
        match frame.opcode {
            OPCODE_TEXT => {
                if let Ok(text) = std::str::from_utf8(&frame.payload) {
                    on_text(text);
                }
            }
            OPCODE_PING => write_pong_frame(write_stream, &frame.payload)?,
            OPCODE_CLOSE => {
                // Best-effort reply: the loop ends either way, and the peer
                // may already have torn down its side of the connection.
                let _ = write_close_frame(write_stream);
                break;
            }
            _ => {}
        }
    }
    Ok(())
}
/// Tries to decode one complete frame from the front of `buf`.
///
/// Returns `Some((frame, consumed))` when `buf` starts with a whole frame.
/// `consumed` is the total number of header and payload bytes the frame
/// occupies. A masked payload is unmasked before it is returned.
///
/// Returns `None` when more bytes are needed. NOTE: a frame whose declared
/// length cannot be represented (it overflows `usize` or `consumed`) is also
/// reported as `None` rather than as an error, because this signature has no
/// error channel. Such a frame can never complete.
fn parse_frame_from_buf(buf: &[u8]) -> Option<(WsFrame, usize)> {
    let (&first, &second) = (buf.first()?, buf.get(1)?);
    let opcode = first & 0x0F;
    let masked = second & 0x80 != 0;
    let mut pos = 2;

    let payload_len: usize = match second & 0x7F {
        len @ 0..=125 => usize::from(len),
        126 => {
            let ext = buf.get(pos..pos + 2)?;
            pos += 2;
            usize::from(u16::from_be_bytes([ext[0], ext[1]]))
        }
        _ => {
            let ext = buf.get(pos..pos + 8)?;
            pos += 8;
            let mut arr = [0u8; 8];
            arr.copy_from_slice(ext);
            let len = u64::from_be_bytes(arr);
            // Avoid silent truncation on targets with a 32-bit usize.
            if len > usize::MAX as u64 {
                return None;
            }
            len as usize
        }
    };

    let mask_key = if masked {
        let key = buf.get(pos..pos + 4)?;
        pos += 4;
        Some([key[0], key[1], key[2], key[3]])
    } else {
        None
    };

    // The length is peer-controlled, so `pos + payload_len` must not wrap.
    let end = pos.checked_add(payload_len)?;
    let body = buf.get(pos..end)?;
    let payload: Vec<u8> = match mask_key {
        Some(key) => body
            .iter()
            .zip(key.iter().cycle())
            .map(|(b, k)| b ^ k)
            .collect(),
        None => body.to_vec(),
    };
    Some((WsFrame { opcode, payload }, end))
}
/// Incremental frame reader for streams that may deliver partial data, block,
/// or time out.
///
/// Bytes read from the stream accumulate in an internal buffer until
/// `parse_frame_from_buf` can decode a complete frame from its front.
#[derive(Debug)]
pub struct WsBuf {
    /// Bytes received from the stream that have not yet been consumed as part
    /// of a complete frame.
    buf: Vec<u8>,
}

impl Default for WsBuf {
    fn default() -> Self {
        Self::new()
    }
}

impl WsBuf {
    /// Creates an empty buffer with an initial capacity of 4096 bytes.
    pub fn new() -> Self {
        WsBuf {
            buf: Vec::with_capacity(4096),
        }
    }

    /// Returns the next complete frame, calling `stream.read` at most once.
    ///
    /// A frame already sitting in the buffer is returned without touching
    /// `stream`. Otherwise one read of up to 4096 bytes is attempted and the
    /// buffer is parsed again.
    ///
    /// Returns `Ok(None)` when no complete frame is available yet: the read
    /// would block, timed out, was interrupted, or delivered only part of a
    /// frame. The caller should retry later.
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` when the read returns 0 bytes (end of stream),
    /// and any other I/O error reported by `stream`.
    pub fn try_read_frame(&mut self, stream: &mut dyn Read) -> io::Result<Option<WsFrame>> {
        if let Some(frame) = self.pop_buffered_frame() {
            return Ok(Some(frame));
        }
        let mut chunk = [0u8; 4096];
        match stream.read(&mut chunk) {
            Ok(0) => {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            Ok(n) => self.buf.extend_from_slice(&chunk[..n]),
            // `Interrupted` is transient, like WouldBlock/TimedOut (std's own
            // read loops retry it). Report "nothing yet" instead of failing
            // the connection.
            Err(e)
                if matches!(
                    e.kind(),
                    io::ErrorKind::WouldBlock
                        | io::ErrorKind::TimedOut
                        | io::ErrorKind::Interrupted
                ) =>
            {
                return Ok(None);
            }
            Err(e) => return Err(e),
        }
        Ok(self.pop_buffered_frame())
    }

    /// Decodes one complete frame from the front of the buffer, if there is
    /// one, and removes its bytes from the buffer.
    fn pop_buffered_frame(&mut self) -> Option<WsFrame> {
        let (frame, consumed) = parse_frame_from_buf(&self.buf)?;
        self.buf.drain(..consumed);
        Some(frame)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 6455 §5.7 example: a single-frame, masked text message "Hello".
    const MASKED_HELLO: [u8; 11] = [
        0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
    ];

    /// RFC 6455 §1.3 worked example for the accept-key derivation.
    #[test]
    fn compute_accept_rfc() {
        let accept = compute_accept("dGhlIHNhbXBsZSBub25jZQ==");
        assert_eq!(accept, "RyVTkfbvgIu+vAZLbuzyhbcrH/0=");
    }

    #[test]
    fn write_read_text_frame() {
        let mut buf = Vec::new();
        write_text_frame(&mut buf, "hello").unwrap();
        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
    }

    /// 300 bytes exercises the 16-bit extended length.
    #[test]
    fn write_read_large_frame() {
        let text = "x".repeat(300);
        let mut buf = Vec::new();
        write_text_frame(&mut buf, &text).unwrap();
        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload.len(), 300);
    }

    /// More than 65535 bytes exercises the 64-bit extended length.
    #[test]
    fn write_read_64bit_length_frame() {
        let text = "y".repeat(70_000);
        let mut buf = Vec::new();
        write_text_frame(&mut buf, &text).unwrap();
        assert_eq!(buf[1], 127);
        let frame = read_frame(&mut &buf[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload.len(), 70_000);
    }

    /// Client-to-server frames are always masked; make sure unmasking works.
    #[test]
    fn read_masked_frame_rfc_example() {
        let frame = read_frame(&mut &MASKED_HELLO[..]).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"Hello");
    }

    #[test]
    fn parse_frame_from_buf_complete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "hello").unwrap();
        let result = parse_frame_from_buf(&data);
        assert!(result.is_some());
        let (frame, consumed) = result.unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
        assert_eq!(consumed, data.len());
    }

    #[test]
    fn parse_masked_frame_rfc_example() {
        let (frame, consumed) = parse_frame_from_buf(&MASKED_HELLO).unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"Hello");
        assert_eq!(consumed, MASKED_HELLO.len());
        assert!(parse_frame_from_buf(&MASKED_HELLO[..MASKED_HELLO.len() - 1]).is_none());
    }

    #[test]
    fn parse_frame_from_buf_incomplete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "hello").unwrap();
        assert!(parse_frame_from_buf(&data[..1]).is_none());
        assert!(parse_frame_from_buf(&data[..3]).is_none());
    }

    #[test]
    fn wsbuf_try_read_frame_wouldblock() {
        struct WouldBlockReader;
        impl Read for WouldBlockReader {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                Err(io::Error::new(io::ErrorKind::WouldBlock, "would block"))
            }
        }
        let mut ws_buf = WsBuf::new();
        let result = ws_buf.try_read_frame(&mut WouldBlockReader);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn wsbuf_try_read_frame_complete() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "test").unwrap();
        let mut ws_buf = WsBuf::new();
        let mut cursor = io::Cursor::new(data);
        let result = ws_buf.try_read_frame(&mut cursor).unwrap();
        assert!(result.is_some());
        let frame = result.unwrap();
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"test");
    }

    /// Two frames delivered in a single read are both yielded, the second
    /// from the buffer alone. After that, end-of-stream surfaces as EOF.
    #[test]
    fn wsbuf_yields_multiple_frames_from_one_read() {
        let mut data = Vec::new();
        write_text_frame(&mut data, "a").unwrap();
        write_text_frame(&mut data, "bc").unwrap();
        let mut ws_buf = WsBuf::new();
        let mut cursor = io::Cursor::new(data);

        let first = ws_buf.try_read_frame(&mut cursor).unwrap().unwrap();
        assert_eq!(first.payload, b"a");
        let second = ws_buf.try_read_frame(&mut cursor).unwrap().unwrap();
        assert_eq!(second.payload, b"bc");
        match ws_buf.try_read_frame(&mut cursor) {
            Err(e) => assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof),
            Ok(_) => panic!("expected UnexpectedEof after the stream is drained"),
        }
    }

    /// Text goes to the callback, a ping gets a pong with the same payload,
    /// and a close gets a close reply and stops the loop before later frames.
    #[test]
    fn run_ws_loop_dispatches_text_ping_close() {
        let mut input = Vec::new();
        write_text_frame(&mut input, "one").unwrap();
        write_frame(&mut input, OPCODE_PING, b"hi").unwrap();
        write_close_frame(&mut input).unwrap();
        write_text_frame(&mut input, "after-close").unwrap();

        let mut output = Vec::new();
        let mut seen: Vec<String> = Vec::new();
        run_ws_loop(&mut &input[..], &mut output, |t| seen.push(t.to_string())).unwrap();

        assert_eq!(seen, vec!["one".to_string()]);
        assert_eq!(output, vec![0x8Au8, 2, b'h', b'i', 0x88, 0]);
    }
}