use std::collections::HashMap;
#[cfg(feature = "rustls")]
use std::sync::Arc;
use std::thread;
use async_net::TcpListener;
#[cfg(feature = "rustls")]
use async_tls::TlsAcceptor;
use bytes::Bytes;
use futures_lite::future::block_on;
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(feature = "rustls")]
use rcgen::generate_simple_self_signed;
#[cfg(feature = "rustls")]
use rustls::ServerConfig;
use sha1::{Digest, Sha1};
use ugi::{Client, CookieJar, Url, WebSocketMessage};
/// GUID defined by RFC 6455. It is appended to the client's `Sec-WebSocket-Key`
/// before SHA-1 hashing to derive the `Sec-WebSocket-Accept` value (see
/// `websocket_accept_value`).
const WS_ACCEPT_MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Echo round trip over plain `ws://`.
///
/// The client negotiates the `chat.v1` subprotocol on `/echo` and sends one text
/// message. It expects the server to echo that message back before the client
/// closes the connection.
#[test]
fn websocket_echo_roundtrip_works() {
// Server half: runs on the background thread started by `spawn_ws_server`.
let url = block_on(spawn_ws_server(|mut stream, request| async move {
// The upgrade request must target `/echo` and offer the `chat.v1` subprotocol.
assert_eq!(request.request_line, "GET /echo HTTP/1.1");
assert_eq!(request.header("sec-websocket-protocol"), Some("chat.v1"));
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, Some("chat.v1"), None)
.await
.unwrap();
// Expect one text frame (opcode 0x1) and echo its payload back unchanged.
let frame = read_client_frame(&mut stream).await.unwrap();
assert_eq!(frame.opcode, 0x1);
assert_eq!(String::from_utf8(frame.payload).unwrap(), "hello websocket");
send_server_frame(&mut stream, 0x1, b"hello websocket")
.await
.unwrap();
// The client's `close()` must arrive as a close frame (opcode 0x8).
let close = read_client_frame(&mut stream).await.unwrap();
assert_eq!(close.opcode, 0x8);
}))
.unwrap();
// Client half: the subprotocol chosen by the server must be reported on the socket.
let mut ws =
block_on(async { ugi::ws(format!("{url}/echo")).protocol("chat.v1").await }).unwrap();
assert_eq!(ws.selected_protocol(), Some("chat.v1"));
block_on(ws.send_text("hello websocket")).unwrap();
let message = block_on(ws.next()).unwrap().unwrap();
assert_eq!(
message,
WebSocketMessage::Text("hello websocket".to_owned())
);
block_on(ws.close()).unwrap();
}
/// Checks what a WebSocket opened via `Client::ws` inherits from the `Client`.
///
/// The base URL, default headers, explicit cookies and cookie-jar cookies must
/// all apply to the upgrade request. A `Set-Cookie` on the 101 response must be
/// stored back into the jar.
#[test]
fn client_ws_inherits_base_url_headers_cookies_and_cookie_jar() {
// Server half: the relative path must be resolved against the base URL, and the
// default header plus both cookie sources must be present on the upgrade request.
let url = block_on(spawn_ws_server(|mut stream, request| async move {
assert_eq!(request.request_line, "GET /chat HTTP/1.1");
assert_eq!(request.header("authorization"), Some("Bearer token-123"));
let cookie = request.header("cookie").unwrap_or_default();
assert!(cookie.contains("manual=1"));
assert!(cookie.contains("jar_cookie=jar-value"));
let key = request.header("sec-websocket-key").unwrap().to_owned();
// Accept with a `Set-Cookie`, so the client has a cookie to persist into the jar.
send_handshake_response(
&mut stream,
&key,
Some("chat.v1"),
Some("from_server=ok; Path=/"),
)
.await
.unwrap();
// After the handshake, only the client's close frame is expected.
let close = read_client_frame(&mut stream).await.unwrap();
assert_eq!(close.opcode, 0x8);
}))
.unwrap();
// Pre-seed the jar with a cookie scoped to `/` on the server's origin.
let jar = CookieJar::new();
let cookie_url = Url::parse(format!("{url}/chat")).unwrap();
jar.store_set_cookie(&cookie_url, "jar_cookie=jar-value; Path=/")
.unwrap();
// The client receives `jar.clone()`. The assertions at the end read through `jar`,
// so the test expects the clone to share storage with the original.
let client = Client::builder()
.base_url(&url)
.unwrap()
.header("authorization", "Bearer token-123")
.unwrap()
.cookie("manual", "1")
.cookie_jar(jar.clone())
.build()
.unwrap();
let mut ws = block_on(async { client.ws("/chat").protocol("chat.v1").await }).unwrap();
assert_eq!(ws.selected_protocol(), Some("chat.v1"));
block_on(ws.close()).unwrap();
// Both the pre-seeded cookie and the one set by the 101 response must now be in the jar.
let stored = jar
.get_cookie_header(&Url::parse(format!("{url}/anything")).unwrap())
.unwrap_or_default();
assert!(stored.contains("jar_cookie=jar-value"));
assert!(stored.contains("from_server=ok"));
}
/// Checks how server control frames are handled.
///
/// A ping must be surfaced to the caller and answered with a pong that carries
/// the same payload. A server close must be surfaced with its code and reason and
/// answered with a matching close frame. Afterwards the stream must end.
#[test]
fn websocket_handles_ping_and_close() {
let url = block_on(spawn_ws_server(|mut stream, request| async move {
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, None, None)
.await
.unwrap();
// Ping (opcode 0x9). The server waits for the pong (0xA) before doing anything else.
send_server_frame(&mut stream, 0x9, b"abc").await.unwrap();
let pong = read_client_frame(&mut stream).await.unwrap();
assert_eq!(pong.opcode, 0xA);
assert_eq!(pong.payload, b"abc");
// Close (opcode 0x8) with code 1000 and reason "bye". The reply must echo both.
send_server_frame(&mut stream, 0x8, &close_payload(1000, "bye"))
.await
.unwrap();
let close_reply = read_client_frame(&mut stream).await.unwrap();
assert_eq!(close_reply.opcode, 0x8);
assert_eq!(close_reply.payload, close_payload(1000, "bye"));
}))
.unwrap();
let mut ws = block_on(async { ugi::ws(url).await }).unwrap();
let ping = block_on(ws.next()).unwrap().unwrap();
assert_eq!(ping, WebSocketMessage::Ping(Bytes::from_static(b"abc")));
let close = block_on(ws.next()).unwrap().unwrap();
assert_eq!(
close,
WebSocketMessage::Close(Some(ugi::CloseFrame {
code: 1000,
reason: "bye".to_owned(),
}))
);
// Once the close has been delivered, the message stream is exhausted.
assert!(block_on(ws.next()).is_none());
}
/// A text message split into a non-final text frame and a final continuation
/// frame must be delivered as a single `Text` message.
#[test]
fn websocket_reassembles_fragmented_text_message() {
let url = block_on(spawn_ws_server(|mut stream, request| async move {
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, None, None)
.await
.unwrap();
// First fragment: opcode 0x1 (text) with FIN clear.
send_server_frame_fin(&mut stream, false, 0x1, b"hello ")
.await
.unwrap();
// Last fragment: opcode 0x0 (continuation) with FIN set.
send_server_frame_fin(&mut stream, true, 0x0, b"world")
.await
.unwrap();
// Drain the client's close frame. Its contents are not asserted here.
let _close = read_client_frame(&mut stream).await.unwrap();
}))
.unwrap();
let mut ws = block_on(async { ugi::ws(url).await }).unwrap();
let message = block_on(ws.next()).unwrap().unwrap();
assert_eq!(
message,
WebSocketMessage::Text("hello world".to_owned()),
"fragmented text frames must be reassembled into a single message"
);
block_on(ws.close()).unwrap();
}
/// A binary message split across three frames must be delivered as a single
/// `Binary` message. The frames are a binary start, a middle continuation and a
/// final continuation.
#[test]
fn websocket_reassembles_fragmented_binary_message() {
let url = block_on(spawn_ws_server(|mut stream, request| async move {
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, None, None)
.await
.unwrap();
// Opcode 0x2 (binary) with FIN clear starts the message.
send_server_frame_fin(&mut stream, false, 0x2, &[0x01, 0x02])
.await
.unwrap();
// Continuation frames (opcode 0x0); only the last one has FIN set.
send_server_frame_fin(&mut stream, false, 0x0, &[0x03, 0x04])
.await
.unwrap();
send_server_frame_fin(&mut stream, true, 0x0, &[0x05])
.await
.unwrap();
// Drain the client's close frame. Its contents are not asserted here.
let _close = read_client_frame(&mut stream).await.unwrap();
}))
.unwrap();
let mut ws = block_on(async { ugi::ws(url).await }).unwrap();
let message = block_on(ws.next()).unwrap().unwrap();
assert_eq!(
message,
WebSocketMessage::Binary(Bytes::from_static(&[0x01, 0x02, 0x03, 0x04, 0x05])),
"fragmented binary frames must be reassembled into a single message"
);
block_on(ws.close()).unwrap();
}
/// A ping arriving between the fragments of a text message must be handled
/// without breaking the message.
///
/// The ping is surfaced and answered with a pong while the message is still
/// incomplete. The fragments are then joined into one `Text` message.
#[test]
fn websocket_interleaves_control_frame_in_fragmented_message() {
let url = block_on(spawn_ws_server(|mut stream, request| async move {
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, None, None)
.await
.unwrap();
// First text fragment, FIN clear: the message is now in progress.
send_server_frame_fin(&mut stream, false, 0x1, b"part-one-")
.await
.unwrap();
// A ping (0x9) in the middle of the message.
send_server_frame_fin(&mut stream, true, 0x9, b"ping-data")
.await
.unwrap();
// The server blocks here until the pong arrives, so the client must answer the
// ping before it has seen the final fragment.
let pong = read_client_frame(&mut stream).await.unwrap();
assert_eq!(pong.opcode, 0xA, "client must respond to ping with pong");
assert_eq!(pong.payload, b"ping-data");
// Final continuation fragment completes the text message.
send_server_frame_fin(&mut stream, true, 0x0, b"part-two")
.await
.unwrap();
let _close = read_client_frame(&mut stream).await.unwrap();
}))
.unwrap();
let mut ws = block_on(async { ugi::ws(url).await }).unwrap();
// The control frame is delivered first, before the (still incomplete) data message.
let ping = block_on(ws.next()).unwrap().unwrap();
assert_eq!(
ping,
WebSocketMessage::Ping(Bytes::from_static(b"ping-data"))
);
let message = block_on(ws.next()).unwrap().unwrap();
assert_eq!(
message,
WebSocketMessage::Text("part-one-part-two".to_owned()),
"reassembled message must combine all fragments"
);
block_on(ws.close()).unwrap();
}
/// Echo round trip over `wss://` against a server using a certificate generated
/// at test time.
///
/// The client connects only because `danger_accept_invalid_certs(true)`
/// disables certificate validation.
#[test]
#[cfg(feature = "rustls")]
fn wss_roundtrip_works_with_invalid_cert_override() {
let url = block_on(spawn_wss_server(|mut stream, request| async move {
assert_eq!(request.request_line, "GET /secure HTTP/1.1");
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, None, None)
.await
.unwrap();
// Echo one text frame back over the TLS stream, then expect the client's close.
let frame = read_client_frame(&mut stream).await.unwrap();
assert_eq!(frame.opcode, 0x1);
assert_eq!(String::from_utf8(frame.payload).unwrap(), "secure hello");
send_server_frame(&mut stream, 0x1, b"secure hello")
.await
.unwrap();
let close = read_client_frame(&mut stream).await.unwrap();
assert_eq!(close.opcode, 0x8);
}))
.unwrap();
let mut ws = block_on(async {
ugi::ws(format!("{url}/secure"))
.danger_accept_invalid_certs(true)
.await
})
.unwrap();
block_on(ws.send_text("secure hello")).unwrap();
let message = block_on(ws.next()).unwrap().unwrap();
assert_eq!(message, WebSocketMessage::Text("secure hello".to_owned()));
block_on(ws.close()).unwrap();
}
/// Same `wss://` echo round trip as the rustls test, but the client explicitly
/// selects `ugi::TlsBackend::Boring`. The server side still uses the rustls-based
/// helper.
#[test]
#[cfg(all(feature = "rustls", feature = "btls-backend"))]
fn wss_roundtrip_works_with_boring_backend() {
let url = block_on(spawn_wss_server(|mut stream, request| async move {
assert_eq!(request.request_line, "GET /secure HTTP/1.1");
let key = request.header("sec-websocket-key").unwrap().to_owned();
send_handshake_response(&mut stream, &key, None, None)
.await
.unwrap();
// Echo one text frame back, then expect the client's close frame.
let frame = read_client_frame(&mut stream).await.unwrap();
assert_eq!(frame.opcode, 0x1);
assert_eq!(String::from_utf8(frame.payload).unwrap(), "boring hello");
send_server_frame(&mut stream, 0x1, b"boring hello")
.await
.unwrap();
let close = read_client_frame(&mut stream).await.unwrap();
assert_eq!(close.opcode, 0x8);
}))
.unwrap();
// Certificate validation must still be disabled for the generated test certificate.
let mut ws = block_on(async {
ugi::ws(format!("{url}/secure"))
.tls_backend(ugi::TlsBackend::Boring)
.danger_accept_invalid_certs(true)
.await
})
.unwrap();
block_on(ws.send_text("boring hello")).unwrap();
let message = block_on(ws.next()).unwrap().unwrap();
assert_eq!(message, WebSocketMessage::Text("boring hello".to_owned()));
block_on(ws.close()).unwrap();
}
/// The request line and headers of a client's WebSocket upgrade request, as
/// captured by the test server.
#[derive(Clone, Debug)]
struct HandshakeRequest {
    /// First line of the request, e.g. `GET /echo HTTP/1.1`.
    request_line: String,
    /// Header values keyed by lowercased header name.
    headers: HashMap<String, String>,
}

impl HandshakeRequest {
    /// Returns the value of header `name`, ignoring ASCII case, or `None` if absent.
    ///
    /// Keys are expected to be stored lowercased, so `name` is lowercased before
    /// the lookup.
    fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .get(&name.to_ascii_lowercase())
            .map(String::as_str)
    }
}
/// One frame read from the client side of the connection, after unmasking.
#[derive(Debug)]
struct ClientFrame {
    /// The low four bits of the first header byte (e.g. 0x1 text, 0x8 close, 0xA pong).
    opcode: u8,
    /// Payload bytes with the client mask already removed.
    payload: Vec<u8>,
}
/// Starts a one-shot plain-TCP WebSocket test server and returns its `ws://` base URL.
///
/// Binds `127.0.0.1` on an OS-assigned port. A background thread then accepts a
/// single connection, parses its upgrade request with `read_handshake_request`,
/// and passes the stream and parsed request to `handler`.
///
/// NOTE(review): the thread's `JoinHandle` is dropped. A panic inside `handler`
/// (for example a failed server-side assertion) is therefore not propagated to the
/// calling test. It only shows up indirectly, if the client then sees a broken
/// connection.
async fn spawn_ws_server<F, Fut>(handler: F) -> ugi::Result<String>
where
    F: FnOnce(async_net::TcpStream, HandshakeRequest) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to bind ws test server",
                err,
            )
        })?;
    let bound_addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect ws test server",
            err,
        )
    })?;

    // Everything below runs on the server thread; the caller only gets the URL.
    let serve_one_connection = move || {
        block_on(async move {
            let (mut connection, _peer) = listener.accept().await.unwrap();
            let upgrade = read_handshake_request(&mut connection).await.unwrap();
            handler(connection, upgrade).await;
        });
    };
    thread::spawn(serve_one_connection);

    Ok(format!("ws://{bound_addr}"))
}
/// Starts a one-shot TLS WebSocket test server and returns `wss://localhost:<port>`.
///
/// A certificate for `localhost` is generated at call time with
/// `generate_simple_self_signed`, so clients must disable certificate validation
/// to connect. The server binds `127.0.0.1` on an OS-assigned port. A background
/// thread then accepts one TCP connection, completes the TLS handshake, parses the
/// upgrade request, and passes the TLS stream and request to `handler`.
///
/// NOTE(review): as with `spawn_ws_server`, the thread's `JoinHandle` is dropped,
/// so panics inside `handler` are not propagated to the calling test.
#[cfg(feature = "rustls")]
async fn spawn_wss_server<F, Fut>(handler: F) -> ugi::Result<String>
where
    F: FnOnce(async_tls::server::TlsStream<async_net::TcpStream>, HandshakeRequest) -> Fut
        + Send
        + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    // TLS material: a throwaway certificate and its private key, both DER-encoded.
    let generated = generate_simple_self_signed(vec!["localhost".into()]).map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to generate ws tls certificate",
            err,
        )
    })?;
    let chain = vec![rustls::Certificate(generated.cert.der().to_vec())];
    let private_key = rustls::PrivateKey(generated.signing_key.serialize_der());
    let tls_config = Arc::new(
        ServerConfig::builder()
            .with_safe_defaults()
            .with_no_client_auth()
            .with_single_cert(chain, private_key)
            .map_err(|err| {
                ugi::Error::with_source(
                    ugi::ErrorKind::Transport,
                    "failed to build ws tls server config",
                    err,
                )
            })?,
    );
    let acceptor = TlsAcceptor::from(tls_config);

    let listener = TcpListener::bind(("127.0.0.1", 0))
        .await
        .map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to bind wss test server",
                err,
            )
        })?;
    let bound_addr = listener.local_addr().map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to inspect wss test server",
            err,
        )
    })?;

    let serve_one_connection = move || {
        block_on(async move {
            let (tcp, _peer) = listener.accept().await.unwrap();
            let mut tls = acceptor.accept(tcp).await.unwrap();
            let upgrade = read_handshake_request(&mut tls).await.unwrap();
            handler(tls, upgrade).await;
        });
    };
    thread::spawn(serve_one_connection);

    // Use the `localhost` name (the certificate's subject) rather than the raw IP.
    Ok(format!("wss://localhost:{}", bound_addr.port()))
}
/// Reads a client's HTTP/1.1 upgrade request from `stream` and parses it into a
/// [`HandshakeRequest`].
///
/// Reading stops once the blank line (`\r\n\r\n`) that ends the header section has
/// been received. Parsing works as follows:
///
/// - The first line becomes `request_line`.
/// - Header names are trimmed and lowercased; values are trimmed.
/// - A header name that appears more than once has its values joined in arrival
///   order: `"; "` for `cookie`, `", "` otherwise. Later lines no longer silently
///   replace earlier ones.
/// - Lines without a `:` are ignored.
/// - Any bytes received after the blank line are discarded.
///
/// # Errors
/// Returns a `Transport` error if reading fails or the stream ends before the blank
/// line. Returns a `Decode` error if the header section is not valid UTF-8.
async fn read_handshake_request<S>(stream: &mut S) -> ugi::Result<HandshakeRequest>
where
    S: AsyncRead + Unpin,
{
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    let mut buffer = Vec::new();
    let mut scratch = [0_u8; 1024];
    let end = loop {
        let read = stream.read(&mut scratch).await.map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to read ws handshake request",
                err,
            )
        })?;
        if read == 0 {
            return Err(ugi::Error::new(
                ugi::ErrorKind::Transport,
                "incomplete ws handshake request",
            ));
        }
        // Earlier bytes are already known not to contain the terminator. Only the
        // tail that could start one spanning the previous and current reads needs
        // rescanning.
        let search_from = buffer.len().saturating_sub(TERMINATOR.len() - 1);
        buffer.extend_from_slice(&scratch[..read]);
        if let Some(offset) = buffer[search_from..]
            .windows(TERMINATOR.len())
            .position(|window| window == TERMINATOR)
        {
            break search_from + offset;
        }
    };
    let text = std::str::from_utf8(&buffer[..end]).map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Decode,
            "ws handshake request is not utf-8",
            err,
        )
    })?;
    let mut lines = text.split("\r\n");
    let request_line = lines.next().unwrap_or_default().to_owned();
    let mut headers: HashMap<String, String> = HashMap::new();
    for (name, value) in lines.filter_map(|line| line.split_once(':')) {
        let name = name.trim().to_ascii_lowercase();
        let value = value.trim();
        let separator = if name == "cookie" { "; " } else { ", " };
        headers
            .entry(name)
            .and_modify(|existing| {
                existing.push_str(separator);
                existing.push_str(value);
            })
            .or_insert_with(|| value.to_owned());
    }
    Ok(HandshakeRequest {
        request_line,
        headers,
    })
}
/// Writes a `101 Switching Protocols` response that completes the WebSocket
/// handshake for `websocket_key`, then flushes the stream.
///
/// `Sec-WebSocket-Protocol` is included only when `selected_protocol` is `Some`,
/// and `Set-Cookie` only when `set_cookie` is `Some`, in that order.
///
/// # Errors
/// Returns a `Transport` error if writing or flushing fails.
async fn send_handshake_response<S>(
    stream: &mut S,
    websocket_key: &str,
    selected_protocol: Option<&str>,
    set_cookie: Option<&str>,
) -> ugi::Result<()>
where
    S: AsyncWrite + Unpin,
{
    let accept = websocket_accept_value(websocket_key);
    let optional_headers = [
        selected_protocol.map(|protocol| format!("Sec-WebSocket-Protocol: {protocol}\r\n")),
        set_cookie.map(|set_cookie| format!("Set-Cookie: {set_cookie}\r\n")),
    ];

    let mut response = format!(
        "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\nSec-WebSocket-Accept: {accept}\r\n"
    );
    for header_line in optional_headers.iter().flatten() {
        response.push_str(header_line);
    }
    // The empty line that terminates the header section.
    response.push_str("\r\n");

    stream
        .write_all(response.as_bytes())
        .await
        .map_err(|err| {
            ugi::Error::with_source(
                ugi::ErrorKind::Transport,
                "failed to write ws handshake response",
                err,
            )
        })?;
    stream.flush().await.map_err(|err| {
        ugi::Error::with_source(
            ugi::ErrorKind::Transport,
            "failed to flush ws handshake response",
            err,
        )
    })
}
/// Sends `payload` as one complete (unfragmented) server frame with the given
/// `opcode`. This is `send_server_frame_fin` with the FIN bit set.
async fn send_server_frame<S>(stream: &mut S, opcode: u8, payload: &[u8]) -> ugi::Result<()>
where
    S: AsyncWrite + Unpin,
{
    const FINAL_FRAGMENT: bool = true;
    send_server_frame_fin(stream, FINAL_FRAGMENT, opcode, payload).await
}
/// Writes a single unmasked server frame and flushes the stream.
///
/// `fin` controls the FIN bit. Only the low four bits of `opcode` are used. The
/// payload length is encoded in the 7-bit, 16-bit (marker 126) or 64-bit
/// (marker 127) form, depending on its size.
///
/// # Errors
/// Returns a `Transport` error if writing or flushing fails.
async fn send_server_frame_fin<S>(
    stream: &mut S,
    fin: bool,
    opcode: u8,
    payload: &[u8],
) -> ugi::Result<()>
where
    S: AsyncWrite + Unpin,
{
    let len = payload.len();
    let fin_flag: u8 = if fin { 0x80 } else { 0x00 };

    // Header is at most 2 bytes plus an 8-byte extended length.
    let mut frame = Vec::with_capacity(len + 10);
    frame.push((opcode & 0x0F) | fin_flag);
    // The mask bit in the length byte stays clear: server frames are sent unmasked.
    match len {
        0..=125 => frame.push(len as u8),
        126..=0xFFFF => {
            frame.push(126);
            frame.extend_from_slice(&(len as u16).to_be_bytes());
        }
        _ => {
            frame.push(127);
            frame.extend_from_slice(&(len as u64).to_be_bytes());
        }
    }
    frame.extend_from_slice(payload);

    stream.write_all(&frame).await.map_err(|err| {
        ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to write ws frame", err)
    })?;
    stream.flush().await.map_err(|err| {
        ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to flush ws frame", err)
    })
}
async fn read_client_frame<S>(stream: &mut S) -> ugi::Result<ClientFrame>
where
S: AsyncRead + Unpin,
{
let mut header = [0_u8; 2];
stream.read_exact(&mut header).await.map_err(|err| {
ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to read ws frame", err)
})?;
let opcode = header[0] & 0x0F;
let masked = header[1] & 0x80 != 0;
assert!(masked, "client websocket frames must be masked");
let payload_len = match header[1] & 0x7F {
len @ 0..=125 => len as usize,
126 => {
let mut extended = [0_u8; 2];
stream.read_exact(&mut extended).await.map_err(|err| {
ugi::Error::with_source(
ugi::ErrorKind::Transport,
"failed to read ws extended frame length",
err,
)
})?;
u16::from_be_bytes(extended) as usize
}
127 => {
let mut extended = [0_u8; 8];
stream.read_exact(&mut extended).await.map_err(|err| {
ugi::Error::with_source(
ugi::ErrorKind::Transport,
"failed to read ws extended frame length",
err,
)
})?;
u64::from_be_bytes(extended) as usize
}
_ => unreachable!(),
};
let mut mask = [0_u8; 4];
stream.read_exact(&mut mask).await.map_err(|err| {
ugi::Error::with_source(ugi::ErrorKind::Transport, "failed to read ws mask", err)
})?;
let mut payload = vec![0_u8; payload_len];
stream.read_exact(&mut payload).await.map_err(|err| {
ugi::Error::with_source(
ugi::ErrorKind::Transport,
"failed to read ws frame payload",
err,
)
})?;
for (index, byte) in payload.iter_mut().enumerate() {
*byte ^= mask[index % mask.len()];
}
Ok(ClientFrame { opcode, payload })
}
/// Builds a close-frame payload: the status `code` as two big-endian bytes,
/// followed by the UTF-8 bytes of `reason`.
fn close_payload(code: u16, reason: &str) -> Vec<u8> {
    code.to_be_bytes()
        .iter()
        .chain(reason.as_bytes())
        .copied()
        .collect()
}
/// Derives the `Sec-WebSocket-Accept` value for a client key.
///
/// The value is the Base64 encoding of `SHA-1(websocket_key ++ WS_ACCEPT_MAGIC)`.
fn websocket_accept_value(websocket_key: &str) -> String {
    let keyed_input = [websocket_key, WS_ACCEPT_MAGIC].concat();
    let digest = Sha1::digest(keyed_input.as_bytes());
    encode_base64(&digest)
}
/// Encodes `bytes` as standard Base64 (`+`/`/` alphabet) with `=` padding.
fn encode_base64(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);
    for group in bytes.chunks(3) {
        // Missing trailing bytes of a short final group count as zero.
        let byte_at = |i: usize| group.get(i).map_or(0_u32, |&b| u32::from(b));
        let triple = (byte_at(0) << 16) | (byte_at(1) << 8) | byte_at(2);
        let sextet = |shift: u32| char::from(ALPHABET[((triple >> shift) & 0x3F) as usize]);
        encoded.push(sextet(18));
        encoded.push(sextet(12));
        encoded.push(if group.len() > 1 { sextet(6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(0) } else { '=' });
    }
    encoded
}