#![cfg(feature = "ws")]
mod common;
#[path = "support/ws_frame_io.rs"]
mod ws_frame_io;
#[path = "support/ws_text_helpers.rs"]
mod ws_text_helpers;
use camber::http::{self, Request, Response, Router, WsConn};
use camber::runtime;
use std::io::Write;
use std::net::TcpStream;
use std::time::Duration;
use ws_frame_io::read_until_double_crlf;
use ws_text_helpers::{read_ws_text_frame, write_ws_close_frame, write_ws_text_frame};
/// Upgrades a raw TCP client on `/ws/echo` through a proxy mounted at `/ws`,
/// sends one text frame, and expects the echo backend's reply to match it.
#[test]
fn websocket_proxy_forwards_text_messages() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: send every received message straight back until
            // `recv` yields nothing or a send fails.
            let mut backend = Router::new();
            backend.ws("/echo", |_req: &Request, mut conn: WsConn| {
                while let Some(incoming) = conn.recv() {
                    if conn.send(&incoming).is_err() {
                        break;
                    }
                }
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: mount a route at `/ws` that targets the backend.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new();
            proxy.proxy("/ws", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            // Client: plain TCP socket with a bounded read timeout.
            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            // Hand-written upgrade request sent to the proxy.
            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let upgraded = resp.contains("101");
            assert!(upgraded, "expected 101 switching protocols: {resp}");

            // Round trip a single text frame.
            let payload = "hello";
            write_ws_text_frame(&mut stream, payload);
            let echoed = read_ws_text_frame(&mut stream);
            assert_eq!(echoed, payload);

            write_ws_close_frame(&mut stream);
            runtime::request_shutdown();
        })
        .unwrap();
}
/// The backend pushes three text messages and then performs one `recv`;
/// the client reads the three frames through the proxy in order and then
/// sends a close frame.
#[test]
fn websocket_proxy_handles_client_close() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: emit "one", "two", "three", then do a single recv
            // whose result is discarded.
            let mut backend = Router::new();
            backend.ws("/chat", |_req: &Request, mut conn: WsConn| {
                conn.send("one")?;
                conn.send("two")?;
                conn.send("three")?;
                let _ = conn.recv();
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: mount a route at `/ws` that targets the backend.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new();
            proxy.proxy("/ws", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            // Client: plain TCP socket with a bounded read timeout.
            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/chat HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let upgraded = resp.contains("101");
            assert!(upgraded, "expected 101: {resp}");

            // Read the three pushed frames in arrival order.
            let first = read_ws_text_frame(&mut stream);
            let second = read_ws_text_frame(&mut stream);
            let third = read_ws_text_frame(&mut stream);
            let received = [first.as_str(), second.as_str(), third.as_str()];
            let expected = ["one", "two", "three"];
            assert_eq!(received, expected);

            write_ws_close_frame(&mut stream);
            runtime::request_shutdown();
        })
        .unwrap();
}
/// One proxy mount (`/api`) serves both a plain HTTP GET and a WebSocket
/// upgrade against the same backend.
#[test]
fn websocket_proxy_coexists_with_http_proxy() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: one HTTP route plus one WebSocket echo route.
            let mut backend = Router::new();
            backend.get("/hello", |_req: &Request| async {
                Response::text(200, "http-ok")
            });
            backend.ws("/echo", |_req: &Request, mut conn: WsConn| {
                while let Some(incoming) = conn.recv() {
                    if conn.send(&incoming).is_err() {
                        break;
                    }
                }
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: mount a route at `/api` that targets the backend.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new();
            proxy.proxy("/api", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            // Plain HTTP request through the proxy first.
            let hello_url = format!("http://{proxy_addr}/api/hello");
            let http_resp = common::block_on(http::get(&hello_url)).unwrap();
            assert_eq!(http_resp.status(), 200);
            assert_eq!(http_resp.body(), "http-ok");

            // Then a WebSocket upgrade through the same mount.
            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /api/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let upgraded = resp.contains("101");
            assert!(upgraded, "expected 101 for WS through proxy: {resp}");

            let payload = "ping";
            write_ws_text_frame(&mut stream, payload);
            let echoed = read_ws_text_frame(&mut stream);
            assert_eq!(echoed, payload);

            write_ws_close_frame(&mut stream);
            runtime::request_shutdown();
        })
        .unwrap();
}
/// An upgrade request carrying `Origin: http://evil.example.com` with
/// `Host: localhost` must be answered by the proxy with a 403.
///
/// NOTE(review): the test only inspects the response head for "403"; it does
/// not independently observe whether the backend handler ran.
#[test]
fn websocket_proxy_rejects_cross_host_origin_before_upstream_upgrade() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: would send a marker message if a connection arrived.
            let mut backend = Router::new();
            backend.ws("/echo", |_req: &Request, mut conn: WsConn| {
                conn.send("should not reach")?;
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: mount a route at `/ws` that targets the backend.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new();
            proxy.proxy("/ws", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            // Upgrade request whose Origin host differs from the Host header.
            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Origin: http://evil.example.com\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let forbidden = resp.contains("403");
            assert!(
                forbidden,
                "expected 403 for cross-host origin on proxied WS, got: {resp}"
            );

            runtime::request_shutdown();
        })
        .unwrap();
}
/// Checks `Sec-WebSocket-Protocol` in both directions: the backend reports
/// the header value it received, and the 101 response seen by the client
/// must contain the same subprotocol.
#[test]
fn ws_proxy_forwards_sec_websocket_protocol() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: reply with the subprotocol header it saw, or "none".
            let mut backend = Router::new();
            backend.ws("/echo", |req: &Request, mut conn: WsConn| {
                let subprotocol = req
                    .headers()
                    .find(|(name, _)| name.eq_ignore_ascii_case("sec-websocket-protocol"))
                    .map_or_else(|| "none".to_owned(), |(_, value)| value.to_owned());
                conn.send(&subprotocol)?;
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: mount a route at `/ws` that targets the backend.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new();
            proxy.proxy("/ws", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            // Upgrade request that offers the `graphql-ws` subprotocol.
            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 Sec-WebSocket-Protocol: graphql-ws\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let upgraded = resp.contains("101");
            assert!(upgraded, "expected 101 switching protocols: {resp}");

            // Header names are matched case-insensitively by lowercasing the
            // whole response head first.
            let lowered = resp.to_lowercase();
            let echoes_protocol = lowered.contains("sec-websocket-protocol: graphql-ws");
            assert!(
                echoes_protocol,
                "expected Sec-WebSocket-Protocol in 101 response: {resp}"
            );

            let reported = read_ws_text_frame(&mut stream);
            assert_eq!(
                reported, "graphql-ws",
                "backend should receive Sec-WebSocket-Protocol header"
            );

            write_ws_close_frame(&mut stream);
            runtime::request_shutdown();
        })
        .unwrap();
}
/// A client-supplied `X-Forwarded-For` header must not reach the backend:
/// the backend reports the header value it saw and the test expects "none".
#[test]
fn ws_proxy_strips_spoofed_forwarded_headers() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: reply with the X-Forwarded-For value it saw, or "none".
            let mut backend = Router::new();
            backend.ws("/echo", |req: &Request, mut conn: WsConn| {
                let seen_forwarded_for = req
                    .headers()
                    .find(|(name, _)| name.eq_ignore_ascii_case("x-forwarded-for"))
                    .map_or_else(|| "none".to_owned(), |(_, value)| value.to_owned());
                conn.send(&seen_forwarded_for)?;
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: mount a route at `/ws` that targets the backend.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new();
            proxy.proxy("/ws", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            // Upgrade request carrying a forged forwarding header.
            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 X-Forwarded-For: 6.6.6.6\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let upgraded = resp.contains("101");
            assert!(upgraded, "expected 101 switching protocols: {resp}");

            let reported = read_ws_text_frame(&mut stream);
            assert_eq!(reported, "none", "spoofed forwarding header reached backend");

            write_ws_close_frame(&mut stream);
            runtime::request_shutdown();
        })
        .unwrap();
}
/// A proxy route whose target uses the `ftp` scheme must answer a WebSocket
/// upgrade attempt with a 502.
#[test]
fn websocket_proxy_rejects_invalid_backend_scheme() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Proxy only; no backend server is started for this test.
            let bad_upstream = "ftp://127.0.0.1:1";
            let mut proxy = Router::new();
            proxy.proxy("/ws", bad_upstream);
            let proxy_addr = common::spawn_server(proxy);

            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let bad_gateway = resp.contains("502");
            assert!(
                bad_gateway,
                "unsupported scheme should produce 502, got: {resp}"
            );

            runtime::request_shutdown();
        })
        .unwrap();
}
/// With the proxy router capped at `max_request_body(10)` and the route
/// registered via `proxy_stream`, an upgrade request declaring
/// `Content-Length: 99999` must still get a 101 and echo a text frame.
#[test]
fn websocket_proxy_stream_upgrade_ignores_request_body_limit() {
    common::test_runtime()
        .keepalive_timeout(Duration::from_millis(200))
        .shutdown_timeout(Duration::from_secs(2))
        .run(|| {
            // Backend: send every received message straight back until
            // `recv` yields nothing or a send fails.
            let mut backend = Router::new();
            backend.ws("/echo", |_req: &Request, mut conn: WsConn| {
                while let Some(incoming) = conn.recv() {
                    if conn.send(&incoming).is_err() {
                        break;
                    }
                }
                Ok(())
            });
            let backend_addr = common::spawn_server(backend);

            // Proxy: small request-body cap plus a streaming proxy route.
            let upstream = format!("http://{backend_addr}");
            let mut proxy = Router::new().max_request_body(10);
            proxy.proxy_stream("/ws", &upstream);
            let proxy_addr = common::spawn_server(proxy);

            let read_timeout = Some(Duration::from_secs(5));
            let mut stream = TcpStream::connect(proxy_addr).unwrap();
            stream.set_read_timeout(read_timeout).unwrap();

            // Upgrade request that declares a body far above the 10-byte cap.
            let key = "dGhlIHNhbXBsZSBub25jZQ==";
            let handshake = format!(
                "GET /ws/echo HTTP/1.1\r\n\
                 Host: localhost\r\n\
                 Upgrade: websocket\r\n\
                 Connection: Upgrade\r\n\
                 Content-Length: 99999\r\n\
                 Sec-WebSocket-Key: {key}\r\n\
                 Sec-WebSocket-Version: 13\r\n\
                 \r\n"
            );
            stream.write_all(handshake.as_bytes()).unwrap();

            let resp = read_until_double_crlf(&mut stream);
            let upgraded = resp.contains("101");
            assert!(
                upgraded,
                "expected 101 for proxied WS through proxy_stream, got: {resp}"
            );

            let payload = "hello";
            write_ws_text_frame(&mut stream, payload);
            let echoed = read_ws_text_frame(&mut stream);
            assert_eq!(echoed, payload);

            write_ws_close_frame(&mut stream);
            runtime::request_shutdown();
        })
        .unwrap();
}