use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use purecrypto::hash::{Digest, Sha1};
use crate::error::{Error, Result};
use crate::tls::TlsStream;
use crate::url::Url;
/// Fixed GUID that RFC 6455 §1.3 appends to the client's `Sec-WebSocket-Key`
/// before hashing, to produce the expected `Sec-WebSocket-Accept` value.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

// Frame opcodes (RFC 6455 §5.2): 0x0-0x2 are data frames, 0x8-0xA control frames.
/// Continuation of a fragmented data message.
const OPCODE_CONT: u8 = 0x00;
/// Text data frame.
const OPCODE_TEXT: u8 = 0x01;
/// Binary data frame.
const OPCODE_BINARY: u8 = 0x02;
/// Connection close (control frame).
const OPCODE_CLOSE: u8 = 0x08;
/// Ping (control frame).
const OPCODE_PING: u8 = 0x09;
/// Pong (control frame).
const OPCODE_PONG: u8 = 0x0A;

/// Largest payload length, in bytes, accepted for an incoming frame (64 MiB).
const MAX_PAYLOAD_BYTES: u64 = 1 << 26;
pub fn fetch(url: &Url) -> Result<Vec<u8>> {
match url.scheme.as_str() {
"ws" => {
let mut sock = tcp_connect(url)?;
handshake(&mut sock, url)?;
read_data_and_close(&mut sock)
}
"wss" => {
let tcp = tcp_connect(url)?;
let mut tls = crate::tls::connect_over(tcp, &url.host)?;
handshake(&mut tls, url)?;
read_data_and_close(&mut tls)
}
other => Err(Error::UnsupportedScheme(other.to_string())),
}
}
fn tcp_connect(url: &Url) -> Result<TcpStream> {
let addr = format!("{}:{}", url.host, url.port);
let addrs: Vec<_> = std::net::ToSocketAddrs::to_socket_addrs(&addr)?.collect();
let first = addrs
.into_iter()
.next()
.ok_or_else(|| Error::InvalidUrl(url.host.clone()))?;
let stream = TcpStream::connect_timeout(&first, Duration::from_secs(30))?;
stream.set_read_timeout(Some(Duration::from_secs(60)))?;
stream.set_write_timeout(Some(Duration::from_secs(60)))?;
Ok(stream)
}
fn handshake<S: Read + Write>(stream: &mut S, url: &Url) -> Result<()> {
let key_bytes: [u8; 16] = random_16();
let key_b64 = base64_encode(&key_bytes);
let host_header =
if (url.scheme == "ws" && url.port == 80) || (url.scheme == "wss" && url.port == 443) {
url.host.clone()
} else {
format!("{}:{}", url.host, url.port)
};
let path = if url.path.is_empty() {
"/"
} else {
url.path.as_str()
};
let req = format!(
"GET {path} HTTP/1.1\r\n\
Host: {host_header}\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: {key_b64}\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n"
);
stream.write_all(req.as_bytes())?;
stream.flush()?;
let mut buf: Vec<u8> = Vec::with_capacity(512);
loop {
let mut b = [0u8; 1];
let n = stream.read(&mut b)?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
buf.push(b[0]);
if buf.len() >= 4 && &buf[buf.len() - 4..] == b"\r\n\r\n" {
break;
}
if buf.len() > 64 * 1024 {
return Err(Error::BadResponse("handshake response too large".into()));
}
}
let head = std::str::from_utf8(&buf)
.map_err(|_| Error::BadResponse("non-utf8 handshake response".into()))?;
let mut lines = head.split("\r\n");
let status_line = lines
.next()
.ok_or_else(|| Error::BadResponse("empty handshake response".into()))?;
if !(status_line.starts_with("HTTP/1.1 101") || status_line.starts_with("HTTP/1.0 101")) {
return Err(Error::BadResponse(format!(
"expected 101 Switching Protocols, got: {status_line:?}"
)));
}
let mut upgrade_ok = false;
let mut connection_ok = false;
let mut accept_value: Option<String> = None;
for line in lines {
if line.is_empty() {
break;
}
let (k, v) = match line.split_once(':') {
Some((k, v)) => (k.trim(), v.trim()),
None => continue,
};
if k.eq_ignore_ascii_case("upgrade") {
if v.eq_ignore_ascii_case("websocket") {
upgrade_ok = true;
}
} else if k.eq_ignore_ascii_case("connection") {
if v.split(',')
.any(|t| t.trim().eq_ignore_ascii_case("upgrade"))
{
connection_ok = true;
}
} else if k.eq_ignore_ascii_case("sec-websocket-accept") {
accept_value = Some(v.to_string());
}
}
if !upgrade_ok {
return Err(Error::BadResponse(
"missing or wrong Upgrade header in handshake response".into(),
));
}
if !connection_ok {
return Err(Error::BadResponse(
"missing or wrong Connection header in handshake response".into(),
));
}
let accept = accept_value
.ok_or_else(|| Error::BadResponse("missing Sec-WebSocket-Accept header".into()))?;
let expected = derive_accept(&key_b64);
if accept != expected {
return Err(Error::BadResponse(format!(
"Sec-WebSocket-Accept mismatch: got {accept:?}, expected {expected:?}"
)));
}
Ok(())
}
fn read_data_and_close<S: Read + Write>(stream: &mut S) -> Result<Vec<u8>> {
let payload = loop {
let frame = read_frame(stream)?;
match frame.opcode {
OPCODE_TEXT | OPCODE_BINARY => break frame.payload,
OPCODE_PING => {
let pong = build_client_frame(OPCODE_PONG, &frame.payload);
stream.write_all(&pong)?;
stream.flush()?;
}
OPCODE_PONG => continue,
OPCODE_CLOSE => {
let _ = stream.write_all(&[0x88, 0x00]);
let _ = stream.flush();
return Ok(Vec::new());
}
OPCODE_CONT => {
return Err(Error::BadResponse(
"unexpected continuation frame before any data frame".into(),
));
}
other => {
return Err(Error::BadResponse(format!("unknown WS opcode 0x{other:x}")));
}
}
};
let close = build_client_frame(OPCODE_CLOSE, &[]);
let _ = stream.write_all(&close);
let _ = stream.flush();
Ok(payload)
}
/// One decoded WebSocket frame as received from the server.
#[derive(Clone, Debug, Eq, PartialEq)]
struct Frame {
    /// FIN bit: `true` when this frame is the final fragment of its message.
    fin: bool,
    /// Low four bits of the first header byte (compare with the `OPCODE_*` constants).
    opcode: u8,
    /// Application payload (server frames are never masked).
    payload: Vec<u8>,
}
/// Reads and decodes one server-to-client frame from `stream`.
///
/// Handles the 7-bit, 16-bit and 64-bit payload length encodings of
/// RFC 6455 §5.2. The payload is returned as-is, because server frames must
/// not be masked.
///
/// # Errors
/// - [`Error::UnexpectedEof`] (via `read_exact`) if the stream ends mid-frame.
/// - [`Error::BadResponse`] in any of these cases:
///   - an RSV bit is set (this client negotiates no extensions);
///   - the frame is masked (RFC 6455 §5.1);
///   - it is a control frame that is fragmented or longer than 125 bytes
///     (RFC 6455 §5.5);
///   - the payload exceeds `MAX_PAYLOAD_BYTES`.
fn read_frame<S: Read>(stream: &mut S) -> Result<Frame> {
    let mut header = [0u8; 2];
    read_exact(stream, &mut header)?;
    let [first, second] = header;
    let fin = (first & 0x80) != 0;
    if (first & 0x70) != 0 {
        return Err(Error::BadResponse(
            "non-zero RSV bits on incoming WS frame".into(),
        ));
    }
    let opcode = first & 0x0F;
    if (second & 0x80) != 0 {
        return Err(Error::BadResponse(
            "server-to-client frame is masked".into(),
        ));
    }
    let payload_len: u64 = match second & 0x7F {
        126 => {
            let mut ext = [0u8; 2];
            read_exact(stream, &mut ext)?;
            u64::from(u16::from_be_bytes(ext))
        }
        127 => {
            let mut ext = [0u8; 8];
            read_exact(stream, &mut ext)?;
            u64::from_be_bytes(ext)
        }
        short => u64::from(short),
    };
    // Opcodes 0x8-0xF are control frames. They MUST NOT be fragmented and
    // carry at most 125 payload bytes (RFC 6455 §5.5). Check this before
    // allocating the payload.
    if (opcode & 0x08) != 0 && (!fin || payload_len > 125) {
        return Err(Error::BadResponse(format!(
            "invalid WS control frame: opcode 0x{opcode:x}, fin={fin}, {payload_len} bytes"
        )));
    }
    if payload_len > MAX_PAYLOAD_BYTES {
        return Err(Error::BadResponse(format!(
            "WS payload too large: {payload_len} bytes"
        )));
    }
    // Fits in usize: bounded by MAX_PAYLOAD_BYTES (64 MiB) just above.
    let mut payload = vec![0u8; payload_len as usize];
    if !payload.is_empty() {
        read_exact(stream, &mut payload)?;
    }
    Ok(Frame {
        fin,
        opcode,
        payload,
    })
}
/// Encodes `payload` as a single, final (FIN set) client-to-server frame.
///
/// Every client frame is masked, as RFC 6455 §5.3 requires. The masking key
/// is the first four bytes of a fresh `random_16()` call. It is written after
/// the length field and XORed over the payload. The length field uses the
/// 7-bit form below 126 bytes, the 16-bit form up to 65535 bytes, and the
/// 64-bit form otherwise.
fn build_client_frame(opcode: u8, payload: &[u8]) -> Vec<u8> {
    let seed = random_16();
    let mask = [seed[0], seed[1], seed[2], seed[3]];
    let len = payload.len();
    let mut frame = Vec::with_capacity(2 + 8 + 4 + len);
    frame.push(0x80 | (opcode & 0x0F));
    // The MASK bit (0x80) is set in the second byte for every length form.
    match len {
        0..=125 => frame.push(0x80 | len as u8),
        126..=0xFFFF => {
            frame.push(0x80 | 126);
            frame.extend_from_slice(&(len as u16).to_be_bytes());
        }
        _ => {
            frame.push(0x80 | 127);
            frame.extend_from_slice(&(len as u64).to_be_bytes());
        }
    }
    frame.extend_from_slice(&mask);
    frame.extend(payload.iter().zip(mask.iter().cycle()).map(|(b, m)| b ^ m));
    frame
}
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
let mut got = 0;
while got < buf.len() {
let n = r.read(&mut buf[got..])?;
if n == 0 {
return Err(Error::UnexpectedEof);
}
got += n;
}
Ok(())
}
/// Computes the `Sec-WebSocket-Accept` value a server must return for the
/// client key `key_b64`: `base64(SHA-1(key_b64 ++ WS_GUID))` (RFC 6455 §4.2.2).
fn derive_accept(key_b64: &str) -> String {
    let mut hasher = Sha1::new();
    for part in [key_b64, WS_GUID] {
        hasher.update(part.as_bytes());
    }
    base64_encode(hasher.finalize().as_ref())
}
fn random_16() -> [u8; 16] {
use purecrypto::rng::{OsRng, RngCore};
let mut out = [0u8; 16];
if std::fs::File::open("/dev/urandom")
.and_then(|mut f| std::io::Read::read_exact(&mut f, &mut out))
.is_ok()
{
return out;
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(0));
let nanos = now.as_nanos() as u64;
let pid = std::process::id() as u64;
let mut state = nanos ^ (pid.wrapping_mul(0x9E37_79B9_7F4A_7C15));
for chunk in out.chunks_mut(8) {
state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
let mut z = state;
z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
z ^= z >> 31;
let bytes = z.to_le_bytes();
chunk.copy_from_slice(&bytes[..chunk.len()]);
}
let _ = OsRng;
let _ = <OsRng as RngCore>::next_u32;
out
}
/// Encodes `input` as standard base64 (RFC 4648 §4 alphabet) with `=` padding.
///
/// Each group of up to three input bytes becomes four output characters. A
/// trailing group of one or two bytes is padded with `==` or `=`.
pub(crate) fn base64_encode(input: &[u8]) -> String {
    const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Picks the 6-bit group of `word` that starts `shift` bits from the right.
    let sextet = |word: u32, shift: u32| TABLE[((word >> shift) & 0x3F) as usize] as char;
    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        let b0 = u32::from(group[0]);
        let b1 = u32::from(group.get(1).copied().unwrap_or(0));
        let b2 = u32::from(group.get(2).copied().unwrap_or(0));
        let word = (b0 << 16) | (b1 << 8) | b2;
        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        encoded.push(if group.len() > 1 { sextet(word, 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(word, 0) } else { '=' });
    }
    encoded
}
// Not called anywhere in this file. NOTE(review): it appears to exist only so
// the `TlsStream` import stays referenced (the `wss` path in `fetch` goes
// through `crate::tls::connect_over`). Hence the dead_code allowance.
#[allow(dead_code)]
fn _tlsstream_in_scope_for_docs<S>(_stream: TlsStream<S>)
where
    S: Read + Write,
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Decodes `bytes` as one server frame, panicking if it does not parse.
    fn parse_frame(bytes: &[u8]) -> Frame {
        read_frame(&mut Cursor::new(bytes)).expect("frame parses")
    }

    #[test]
    fn base64_encode_hello() {
        assert_eq!(base64_encode(b"hello"), "aGVsbG8=");
    }

    #[test]
    fn base64_encode_rfc4648_vectors() {
        // Test vectors from RFC 4648 §10.
        let vectors = [
            ("", ""),
            ("f", "Zg=="),
            ("fo", "Zm8="),
            ("foo", "Zm9v"),
            ("foob", "Zm9vYg=="),
            ("fooba", "Zm9vYmE="),
            ("foobar", "Zm9vYmFy"),
        ];
        for (input, expected) in vectors {
            assert_eq!(base64_encode(input.as_bytes()), expected, "input {input:?}");
        }
    }

    #[test]
    fn rfc6455_accept_derivation() {
        // Worked example from RFC 6455 §1.3.
        assert_eq!(
            derive_accept("dGhlIHNhbXBsZSBub25jZQ=="),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn parse_short_text_frame() {
        let frame = parse_frame(&[0x81, 0x05, b'h', b'e', b'l', b'l', b'o']);
        assert!(frame.fin);
        assert_eq!(frame.opcode, OPCODE_TEXT);
        assert_eq!(frame.payload, b"hello");
    }

    #[test]
    fn parse_16bit_length_frame() {
        // Length field 126 means a 16-bit big-endian length follows.
        let mut bytes: Vec<u8> = vec![0x82, 126, 0x00, 200];
        bytes.resize(bytes.len() + 200, b'A');
        let frame = parse_frame(&bytes);
        assert_eq!(frame.opcode, OPCODE_BINARY);
        assert_eq!(frame.payload.len(), 200);
        assert!(frame.payload.iter().all(|&b| b == b'A'));
    }

    #[test]
    fn reject_masked_server_frame() {
        let bytes: [u8; 6] = [0x81, 0x80, 0, 0, 0, 0];
        let err = read_frame(&mut Cursor::new(&bytes[..]))
            .expect_err("masked server frame must be rejected");
        assert!(matches!(err, Error::BadResponse(_)), "wrong error: {err:?}");
    }

    #[test]
    fn build_close_frame_short_payload() {
        let frame = build_client_frame(OPCODE_CLOSE, &[]);
        // Two header bytes plus the four-byte masking key, no payload.
        assert_eq!(frame.len(), 6);
        assert_eq!(frame[0], 0x88);
        assert_eq!(frame[1], 0x80);
    }

    #[test]
    fn build_text_frame_masks_payload() {
        let payload = b"hi";
        let frame = build_client_frame(OPCODE_TEXT, payload);
        assert_eq!(frame.len(), 8);
        assert_eq!(frame[0], 0x81);
        assert_eq!(frame[1], 0x82);
        let (mask, masked) = frame[2..].split_at(4);
        let unmasked: Vec<u8> = masked
            .iter()
            .zip(mask.iter().cycle())
            .map(|(b, m)| b ^ m)
            .collect();
        assert_eq!(unmasked, payload);
    }

    #[test]
    fn build_frame_uses_16bit_length_for_medium_payload() {
        let frame = build_client_frame(OPCODE_BINARY, &[0u8; 200]);
        assert_eq!(frame[0], 0x82);
        assert_eq!(frame[1], 0x80 | 126);
        assert_eq!(u16::from_be_bytes([frame[2], frame[3]]), 200);
        // 2 header + 2 length + 4 mask + 200 payload bytes.
        assert_eq!(frame.len(), 208);
    }

    #[test]
    fn build_frame_uses_64bit_length_for_large_payload() {
        let frame = build_client_frame(OPCODE_BINARY, &vec![0u8; 70_000]);
        assert_eq!(frame[1], 0x80 | 127);
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&frame[2..10]);
        assert_eq!(u64::from_be_bytes(len_bytes), 70_000);
    }

    #[test]
    fn random_16_is_nonzero() {
        assert_ne!(random_16(), [0u8; 16]);
    }
}