use crate::util::EntropySource;
use base64::Engine;
use sha1::{Digest, Sha1};
use std::collections::BTreeMap;
use std::fmt;
/// Fixed GUID that RFC 6455 appends to the client key before hashing.
const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Derives the `Sec-WebSocket-Accept` value for a `Sec-WebSocket-Key`.
///
/// The key text is followed by [`WS_GUID`], the concatenation is hashed with
/// SHA-1, and the digest is returned base64-encoded (standard, padded).
#[must_use]
pub fn compute_accept_key(client_key: &str) -> String {
    let mut sha = Sha1::new();
    for piece in [client_key, WS_GUID] {
        sha.update(piece.as_bytes());
    }
    let digest = sha.finalize();
    base64::engine::general_purpose::STANDARD.encode(digest)
}
/// Produces a fresh `Sec-WebSocket-Key`: 16 bytes drawn from `entropy`,
/// base64-encoded with the standard padded alphabet (24 characters).
fn generate_client_key(entropy: &dyn EntropySource) -> String {
    let mut nonce = [0u8; 16];
    entropy.fill_bytes(&mut nonce);
    let engine = &base64::engine::general_purpose::STANDARD;
    engine.encode(nonce)
}
/// Splits a comma-separated header value into its trimmed, non-empty items.
///
/// Items are kept verbatim, including any `;`-separated parameters, so
/// `"a; x=1, , b"` yields `["a; x=1", "b"]`. Commas inside quoted parameter
/// values are not special-cased.
fn parse_extension_offers(header_value: &str) -> Vec<String> {
    let mut offers = Vec::new();
    for raw in header_value.split(',') {
        let item = raw.trim();
        if !item.is_empty() {
            offers.push(item.to_owned());
        }
    }
    offers
}
/// Returns the extension name of an offer: the trimmed text before the first
/// `;`, or the whole trimmed offer when it has no parameters.
fn extension_token(offer: &str) -> &str {
    match offer.split_once(';') {
        Some((name, _params)) => name.trim(),
        None => offer.trim(),
    }
}
/// Reports whether a comma-separated header value lists `token`, comparing
/// each trimmed element ASCII-case-insensitively (so `notupgrade` does not
/// match `upgrade`).
fn header_has_token(value: &str, token: &str) -> bool {
    for item in value.split(',') {
        if item.trim().eq_ignore_ascii_case(token) {
            return true;
        }
    }
    false
}
/// Splits raw bytes into the HTTP head (including its terminating blank line)
/// and whatever follows it.
///
/// Both `\r\n\r\n` and bare `\n\n` terminate the head; whichever ends first
/// wins.
///
/// # Errors
///
/// [`HandshakeError::InvalidRequest`] when neither terminator is present.
fn split_http_header_block(data: &[u8]) -> Result<(&[u8], &[u8]), HandshakeError> {
    // Offset just past the first occurrence of `needle`, if any.
    let end_of = |needle: &[u8]| {
        data.windows(needle.len())
            .position(|window| window == needle)
            .map(|start| start + needle.len())
    };
    let header_end = end_of(&b"\r\n\r\n"[..])
        .into_iter()
        .chain(end_of(&b"\n\n"[..]))
        .min();
    match header_end {
        Some(pos) => Ok(data.split_at(pos)),
        None => Err(HandshakeError::InvalidRequest(
            "incomplete HTTP headers".into(),
        )),
    }
}
/// Components of a WebSocket URL as produced by [`WsUrl::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WsUrl {
    /// Host name or IP literal; IPv6 literals are stored without brackets.
    pub host: String,
    /// Port to connect to (80 for `ws`, 443 for `wss` unless given explicitly).
    pub port: u16,
    /// Resource name for the request line: path plus any `?query`.
    pub path: String,
    /// Whether the scheme was `wss` (TLS).
    pub tls: bool,
}
impl WsUrl {
    /// Parses `ws[s]://host[:port][/path][?query][#fragment]`.
    ///
    /// - The scheme is matched case-insensitively.
    /// - The authority ends at the first `/` or `?`.
    /// - A query without a path (`ws://host?x=1`) yields the resource name `/?x=1`.
    /// - A fragment is discarded because it is never sent to the server.
    /// - IPv6 literals must be bracketed to carry a port (`ws://[::1]:8080/`).
    /// - An unbracketed host containing several colons is taken as a bare IPv6
    ///   address on the default port.
    ///
    /// # Errors
    ///
    /// [`HandshakeError::InvalidUrl`] in any of these cases:
    /// - the scheme is missing or unsupported;
    /// - a bracketed IPv6 literal is unterminated or followed by anything
    ///   other than `:port`;
    /// - the port is not a valid `u16`;
    /// - the host is empty.
    pub fn parse(url: &str) -> Result<Self, HandshakeError> {
        let (scheme, rest) = url
            .split_once("://")
            .ok_or_else(|| HandshakeError::InvalidUrl("missing scheme".into()))?;
        // Schemes are case-insensitive (RFC 3986 §3.1).
        let tls = if scheme.eq_ignore_ascii_case("ws") {
            false
        } else if scheme.eq_ignore_ascii_case("wss") {
            true
        } else {
            return Err(HandshakeError::InvalidUrl(format!(
                "unsupported scheme: {scheme}"
            )));
        };
        let default_port: u16 = if tls { 443 } else { 80 };
        let invalid_port = || HandshakeError::InvalidUrl("invalid port".into());
        // A fragment is never part of an HTTP request target; drop it up front.
        let rest = rest.split_once('#').map_or(rest, |(before, _)| before);
        // The authority ends at the first '/' or '?'. A bare query still needs
        // a leading '/' to form a valid request target.
        let (host_port, path) = match rest.find(['/', '?']) {
            Some(idx) if rest[idx..].starts_with('?') => {
                (&rest[..idx], format!("/{}", &rest[idx..]))
            }
            Some(idx) => (&rest[..idx], rest[idx..].to_string()),
            None => (rest, "/".to_string()),
        };
        let (host, port) = if let Some(bracketed) = host_port.strip_prefix('[') {
            let (literal, after) = bracketed.split_once(']').ok_or_else(|| {
                HandshakeError::InvalidUrl("missing closing bracket for IPv6 address".into())
            })?;
            let port = if after.is_empty() {
                default_port
            } else if let Some(digits) = after.strip_prefix(':') {
                digits.parse().map_err(|_| invalid_port())?
            } else {
                return Err(HandshakeError::InvalidUrl(
                    "unexpected characters after IPv6 address".into(),
                ));
            };
            (literal.to_string(), port)
        } else if host_port.matches(':').count() > 1 {
            // Unbracketed IPv6 literal: it cannot carry a port, use the default.
            (host_port.to_string(), default_port)
        } else if let Some((name, digits)) = host_port.rsplit_once(':') {
            (name.to_string(), digits.parse().map_err(|_| invalid_port())?)
        } else {
            (host_port.to_string(), default_port)
        };
        if host.is_empty() {
            return Err(HandshakeError::InvalidUrl("empty host".into()));
        }
        Ok(Self {
            host,
            port,
            path,
            tls,
        })
    }
    /// Value for the `Host` request header.
    ///
    /// Hosts containing `:` (IPv6) are bracketed. The port is appended only
    /// when it differs from the scheme's default.
    #[must_use]
    pub fn host_header(&self) -> String {
        let default_port = if self.tls { 443 } else { 80 };
        let host_str = if self.host.contains(':') {
            format!("[{}]", self.host)
        } else {
            self.host.clone()
        };
        if self.port == default_port {
            host_str
        } else {
            format!("{}:{}", host_str, self.port)
        }
    }
}
/// Failure modes of the WebSocket opening handshake, on both client and server side.
#[derive(Debug)]
pub enum HandshakeError {
    /// The URL passed to `WsUrl::parse` was rejected; the string says why.
    InvalidUrl(String),
    /// Malformed or non-conforming HTTP message.
    ///
    /// Despite the name, this variant is also used for problems in the server's
    /// *response* (`HttpResponse::parse`, `ClientHandshake::validate_response`).
    InvalidRequest(String),
    /// A required header was absent; holds its canonical name.
    MissingHeader(&'static str),
    /// `Sec-WebSocket-Key` did not base64-decode to exactly 16 bytes.
    InvalidKey,
    /// `Sec-WebSocket-Accept` did not match the value derived from the client key.
    InvalidAccept {
        /// Value computed from the client key.
        expected: String,
        /// Value the server sent.
        actual: String,
    },
    /// `Sec-WebSocket-Version` was not `13`; holds the value that was sent.
    UnsupportedVersion(String),
    /// Sub-protocol negotiation failed.
    ProtocolMismatch {
        /// Protocols offered by the client.
        requested: Vec<String>,
        /// Protocol the server picked (client side).
        ///
        /// `None` when the server found no acceptable protocol (server side).
        offered: Option<String>,
    },
    /// The server's `Sec-WebSocket-Extensions` named extensions the client did
    /// not request.
    ExtensionMismatch {
        /// Extension offers the client sent.
        requested: Vec<String>,
        /// Server-returned entries that matched no client offer.
        offered: Vec<String>,
    },
    /// The upgrade was refused with the given status and reason.
    ///
    /// NOTE(review): this variant is not constructed in this file; it is
    /// presumably produced by callers. Confirm.
    Rejected {
        status: u16,
        reason: String,
    },
    /// The response status was not 101; holds the status received.
    NotSwitchingProtocols(u16),
    /// An I/O error, usually converted via `From<std::io::Error>`.
    Io(std::io::Error),
}
impl fmt::Display for HandshakeError {
    /// Formats a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidUrl(detail) => write!(f, "invalid URL: {detail}"),
            Self::InvalidRequest(detail) => write!(f, "invalid HTTP request: {detail}"),
            Self::MissingHeader(header) => write!(f, "missing required header: {header}"),
            Self::InvalidKey => f.write_str("invalid Sec-WebSocket-Key"),
            Self::InvalidAccept {
                expected: want,
                actual: got,
            } => write!(f, "invalid Sec-WebSocket-Accept: expected {want}, got {got}"),
            Self::UnsupportedVersion(version) => {
                write!(f, "unsupported WebSocket version: {version}")
            }
            Self::ProtocolMismatch { requested, offered } => write!(
                f,
                "protocol mismatch: requested {requested:?}, offered {offered:?}"
            ),
            Self::ExtensionMismatch { requested, offered } => write!(
                f,
                "extension mismatch: requested {requested:?}, offered {offered:?}"
            ),
            Self::Rejected { status, reason } => {
                write!(f, "server rejected upgrade: {status} {reason}")
            }
            Self::NotSwitchingProtocols(code) => {
                write!(f, "expected 101 Switching Protocols, got {code}")
            }
            Self::Io(inner) => write!(f, "I/O error: {inner}"),
        }
    }
}
impl std::error::Error for HandshakeError {
    /// Exposes the wrapped I/O error as the cause; every other variant is a leaf.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
/// Lets `?` convert a `std::io::Error` into `HandshakeError::Io`.
impl From<std::io::Error> for HandshakeError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
/// Client side of the WebSocket opening handshake.
///
/// Built with `ClientHandshake::new` plus the `protocol` / `extension` /
/// `header` builder methods. It produces the upgrade request (`request_bytes`)
/// and checks the server's reply (`validate_response`).
#[derive(Debug, Clone)]
pub struct ClientHandshake {
    /// Target URL; supplies the request path and `Host` header.
    url: WsUrl,
    /// Value sent as `Sec-WebSocket-Key` (generated by `new`).
    key: String,
    /// Sub-protocols offered via `Sec-WebSocket-Protocol`, in order.
    protocols: Vec<String>,
    /// Extension offers sent via `Sec-WebSocket-Extensions`, in order.
    extensions: Vec<String>,
    /// Extra request headers; a `BTreeMap`, so they are emitted sorted by name.
    headers: BTreeMap<String, String>,
}
impl ClientHandshake {
    /// Assembles a handshake from explicit parts, bypassing key generation.
    ///
    /// Hidden from docs; lets tests pin `Sec-WebSocket-Key` to a known value.
    #[doc(hidden)]
    pub fn new_for_test(
        url: WsUrl,
        key: String,
        protocols: Vec<String>,
        extensions: Vec<String>,
        headers: BTreeMap<String, String>,
    ) -> Self {
        Self {
            url,
            key,
            protocols,
            extensions,
            headers,
        }
    }
    /// Starts a handshake for `url`, generating the key from `entropy`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`WsUrl::parse`] when `url` is malformed.
    pub fn new(url: &str, entropy: &dyn EntropySource) -> Result<Self, HandshakeError> {
        let parsed_url = WsUrl::parse(url)?;
        Ok(Self {
            url: parsed_url,
            key: generate_client_key(entropy),
            protocols: Vec::new(),
            extensions: Vec::new(),
            headers: BTreeMap::new(),
        })
    }
    /// Appends a sub-protocol to offer in `Sec-WebSocket-Protocol`.
    #[must_use]
    pub fn protocol(mut self, protocol: impl Into<String>) -> Self {
        self.protocols.push(protocol.into());
        self
    }
    /// Appends an extension offer to send in `Sec-WebSocket-Extensions`.
    ///
    /// An offer is a name plus optional `; params`. Only the name is used when
    /// checking the server's answer.
    #[must_use]
    pub fn extension(mut self, extension: impl Into<String>) -> Self {
        self.extensions.push(extension.into());
        self
    }
    /// Adds an extra request header.
    ///
    /// A later call with exactly the same name replaces the earlier value.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }
    /// The parsed target URL.
    #[must_use]
    pub fn url(&self) -> &WsUrl {
        &self.url
    }
    /// The `Sec-WebSocket-Key` this handshake sends.
    #[must_use]
    pub fn key(&self) -> &str {
        &self.key
    }
    /// Serializes the HTTP/1.1 upgrade request, including the final blank line.
    ///
    /// CR and LF are removed from every interpolated value (path, host, key,
    /// protocols, extensions, extra header names and values). This keeps any
    /// of them from ending a line early and injecting extra headers.
    #[must_use]
    pub fn request_bytes(&self) -> Vec<u8> {
        let mut request = format!(
            "GET {} HTTP/1.1\r\n\
             Host: {}\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Key: {}\r\n\
             Sec-WebSocket-Version: 13\r\n",
            Self::strip_crlf(&self.url.path),
            Self::strip_crlf(&self.url.host_header()),
            Self::strip_crlf(&self.key)
        );
        if !self.protocols.is_empty() {
            let sanitized: Vec<String> = self
                .protocols
                .iter()
                .map(String::as_str)
                .map(Self::strip_crlf)
                .collect();
            request.push_str("Sec-WebSocket-Protocol: ");
            request.push_str(&sanitized.join(", "));
            request.push_str("\r\n");
        }
        if !self.extensions.is_empty() {
            let sanitized: Vec<String> = self
                .extensions
                .iter()
                .map(String::as_str)
                .map(Self::strip_crlf)
                .collect();
            request.push_str("Sec-WebSocket-Extensions: ");
            request.push_str(&sanitized.join(", "));
            request.push_str("\r\n");
        }
        for (name, value) in &self.headers {
            request.push_str(&Self::strip_crlf(name));
            request.push_str(": ");
            request.push_str(&Self::strip_crlf(value));
            request.push_str("\r\n");
        }
        request.push_str("\r\n");
        request.into_bytes()
    }
    /// Checks that `response` completes this handshake.
    ///
    /// # Errors
    ///
    /// - [`HandshakeError::NotSwitchingProtocols`] if the status is not 101.
    /// - [`HandshakeError::MissingHeader`] if `Upgrade`, `Connection` or
    ///   `Sec-WebSocket-Accept` is absent.
    /// - [`HandshakeError::InvalidRequest`] if `Upgrade` lacks `websocket` or
    ///   `Connection` lacks `upgrade`.
    /// - [`HandshakeError::InvalidAccept`] if the accept value does not match
    ///   this handshake's key.
    /// - [`HandshakeError::ProtocolMismatch`] if the server picked a
    ///   sub-protocol the client did not offer.
    /// - [`HandshakeError::ExtensionMismatch`] if the server named an extension
    ///   the client did not offer. Names are compared case-insensitively;
    ///   parameters are ignored.
    pub fn validate_response(&self, response: &HttpResponse) -> Result<(), HandshakeError> {
        if response.status != 101 {
            return Err(HandshakeError::NotSwitchingProtocols(response.status));
        }
        let upgrade = response
            .header("upgrade")
            .ok_or(HandshakeError::MissingHeader("Upgrade"))?;
        if !header_has_token(upgrade, "websocket") {
            return Err(HandshakeError::InvalidRequest(format!(
                "Upgrade header must contain 'websocket', got '{upgrade}'"
            )));
        }
        let connection = response
            .header("connection")
            .ok_or(HandshakeError::MissingHeader("Connection"))?;
        if !header_has_token(connection, "upgrade") {
            return Err(HandshakeError::InvalidRequest(format!(
                "Connection header must contain 'Upgrade', got '{connection}'"
            )));
        }
        let accept = response
            .header("sec-websocket-accept")
            .ok_or(HandshakeError::MissingHeader("Sec-WebSocket-Accept"))?;
        let expected = compute_accept_key(&self.key);
        if accept != expected {
            return Err(HandshakeError::InvalidAccept {
                expected,
                actual: accept.to_string(),
            });
        }
        if let Some(offered_protocol) = response.header("sec-websocket-protocol") {
            let offered = offered_protocol.trim().to_string();
            if !self.protocols.iter().any(|requested| requested == &offered) {
                return Err(HandshakeError::ProtocolMismatch {
                    requested: self.protocols.clone(),
                    offered: Some(offered),
                });
            }
        }
        if let Some(offered_extensions) = response.header("sec-websocket-extensions") {
            // Compare extension *names* only. The client may have sent
            // parameters (e.g. `permessage-deflate; client_max_window_bits`)
            // and the server is allowed to answer with different ones.
            let requested_tokens: Vec<String> = self
                .extensions
                .iter()
                .flat_map(|offer| parse_extension_offers(offer))
                .map(|offer| extension_token(&offer).to_string())
                .filter(|token| !token.is_empty())
                .collect();
            let invalid: Vec<String> = parse_extension_offers(offered_extensions)
                .into_iter()
                .filter(|offer| {
                    let token = extension_token(offer);
                    token.is_empty()
                        || !requested_tokens
                            .iter()
                            .any(|requested| requested.eq_ignore_ascii_case(token))
                })
                .collect();
            if !invalid.is_empty() {
                return Err(HandshakeError::ExtensionMismatch {
                    requested: self.extensions.clone(),
                    offered: invalid,
                });
            }
        }
        Ok(())
    }
    /// Removes CR and LF so caller-supplied text cannot terminate a line early.
    fn strip_crlf(text: &str) -> String {
        text.replace(['\r', '\n'], "")
    }
}
/// Server side of the WebSocket opening handshake.
///
/// Configured with the sub-protocols and extensions it is willing to accept,
/// then used to validate incoming upgrade requests via `accept`.
#[derive(Debug, Clone, Default)]
pub struct ServerHandshake {
    /// Sub-protocols the server may select; compared exactly (case-sensitive).
    supported_protocols: Vec<String>,
    /// Extensions the server may accept; compared ASCII-case-insensitively.
    supported_extensions: Vec<String>,
}
impl ServerHandshake {
    /// Creates a server handshake that supports no sub-protocols or extensions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Adds a sub-protocol the server may select; matching is exact.
    #[must_use]
    pub fn protocol(mut self, protocol: impl Into<String>) -> Self {
        self.supported_protocols.push(protocol.into());
        self
    }
    /// Adds an extension the server may accept.
    ///
    /// Only the name before any `;` is used for matching, compared
    /// ASCII-case-insensitively. So `"permessage-deflate"` and
    /// `"permessage-deflate; server_no_context_takeover"` enable the same
    /// extension.
    #[must_use]
    pub fn extension(mut self, extension: impl Into<String>) -> Self {
        self.supported_extensions.push(extension.into());
        self
    }
    /// Validates a client upgrade request and negotiates its options.
    ///
    /// Checks, in order:
    /// 1. the method is `GET`;
    /// 2. `Upgrade` lists `websocket`;
    /// 3. `Connection` lists `upgrade`;
    /// 4. `Sec-WebSocket-Version` is `13`;
    /// 5. `Sec-WebSocket-Key` is base64 for exactly 16 bytes.
    ///
    /// Negotiation then works as follows:
    /// - The first client-offered sub-protocol the server supports is selected.
    /// - For each supported extension name, the first matching client offer is
    ///   echoed back with its parameters.
    ///
    /// # Errors
    ///
    /// - [`HandshakeError::InvalidRequest`] for a wrong method or a bad
    ///   `Upgrade`/`Connection` value.
    /// - [`HandshakeError::MissingHeader`] for a missing required header.
    /// - [`HandshakeError::UnsupportedVersion`] for a version other than `13`.
    /// - [`HandshakeError::InvalidKey`] for a malformed key.
    /// - [`HandshakeError::ProtocolMismatch`] when both sides list
    ///   sub-protocols but none is shared.
    pub fn accept(&self, request: &HttpRequest) -> Result<AcceptResponse, HandshakeError> {
        if request.method != "GET" {
            return Err(HandshakeError::InvalidRequest(format!(
                "method must be GET, got '{}'",
                request.method
            )));
        }
        let upgrade = request
            .header("upgrade")
            .ok_or(HandshakeError::MissingHeader("Upgrade"))?;
        if !header_has_token(upgrade, "websocket") {
            return Err(HandshakeError::InvalidRequest(format!(
                "Upgrade header must contain 'websocket', got '{upgrade}'"
            )));
        }
        let connection = request
            .header("connection")
            .ok_or(HandshakeError::MissingHeader("Connection"))?;
        if !header_has_token(connection, "upgrade") {
            return Err(HandshakeError::InvalidRequest(format!(
                "Connection header must contain 'Upgrade', got '{connection}'"
            )));
        }
        let version = request
            .header("sec-websocket-version")
            .ok_or(HandshakeError::MissingHeader("Sec-WebSocket-Version"))?;
        if version != "13" {
            return Err(HandshakeError::UnsupportedVersion(version.to_string()));
        }
        let client_key = request
            .header("sec-websocket-key")
            .ok_or(HandshakeError::MissingHeader("Sec-WebSocket-Key"))?;
        match base64::engine::general_purpose::STANDARD.decode(client_key) {
            Ok(decoded) if decoded.len() == 16 => {}
            _ => return Err(HandshakeError::InvalidKey),
        }
        let accept_key = compute_accept_key(client_key);
        let selected_protocol = match request.header("sec-websocket-protocol") {
            Some(requested) => self.select_protocol(requested)?,
            None => None,
        };
        let negotiated_extensions = request
            .header("sec-websocket-extensions")
            .map_or_else(Vec::new, |requested| self.negotiate_extensions(requested));
        Ok(AcceptResponse {
            accept_key,
            protocol: selected_protocol,
            extensions: negotiated_extensions,
        })
    }
    /// Builds a minimal HTTP response that refuses the upgrade.
    ///
    /// CR/LF are stripped from `reason` so it cannot inject header lines. The
    /// response carries `Connection: close` and no body.
    #[must_use]
    pub fn reject(status: u16, reason: &str) -> Vec<u8> {
        let reason = reason.replace(['\r', '\n'], "");
        format!(
            "HTTP/1.1 {status} {reason}\r\n\
             Connection: close\r\n\
             \r\n"
        )
        .into_bytes()
    }
    /// Chooses the first client-offered sub-protocol that the server supports.
    ///
    /// Returns `Ok(None)` when nothing matches and either list is empty.
    /// Fails with `ProtocolMismatch` when both lists are non-empty but disjoint.
    fn select_protocol(&self, requested: &str) -> Result<Option<String>, HandshakeError> {
        // Same comma-list splitting as extension offers: trimmed, empties dropped.
        let offered = parse_extension_offers(requested);
        if let Some(choice) = offered
            .iter()
            .find(|candidate| self.supported_protocols.contains(*candidate))
        {
            return Ok(Some(choice.clone()));
        }
        if self.supported_protocols.is_empty() || offered.is_empty() {
            Ok(None)
        } else {
            Err(HandshakeError::ProtocolMismatch {
                requested: offered,
                offered: None,
            })
        }
    }
    /// Accepts the first client offer for each supported extension name.
    ///
    /// Client order and parameters are preserved. CR/LF are stripped from
    /// echoed offers.
    fn negotiate_extensions(&self, requested: &str) -> Vec<String> {
        let mut accepted = Vec::new();
        let mut accepted_tokens = std::collections::BTreeSet::new();
        for offer in parse_extension_offers(requested) {
            let token = extension_token(&offer);
            if token.is_empty() {
                continue;
            }
            // Compare against the configured *name*, so a supported entry
            // written with parameters still matches.
            let supported = self
                .supported_extensions
                .iter()
                .any(|entry| extension_token(entry).eq_ignore_ascii_case(token));
            if supported && accepted_tokens.insert(token.to_ascii_lowercase()) {
                accepted.push(offer.replace(['\r', '\n'], ""));
            }
        }
        accepted
    }
}
/// Result of a successful server-side negotiation, ready to serialize as the
/// `101 Switching Protocols` reply.
#[derive(Debug, Clone)]
pub struct AcceptResponse {
    /// Value for `Sec-WebSocket-Accept`.
    pub accept_key: String,
    /// Selected sub-protocol, if any.
    pub protocol: Option<String>,
    /// Accepted extension offers, emitted comma-separated in order.
    pub extensions: Vec<String>,
}
impl AcceptResponse {
    /// Serializes the `101 Switching Protocols` response head, including the
    /// terminating blank line.
    ///
    /// The fields are public, so CR and LF are removed from each one before it
    /// is written. This keeps a hand-built value from splitting the response
    /// into extra headers.
    #[must_use]
    pub fn response_bytes(&self) -> Vec<u8> {
        fn clean(text: &str) -> String {
            text.replace(['\r', '\n'], "")
        }
        let mut response = String::from(
            "HTTP/1.1 101 Switching Protocols\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n",
        );
        response.push_str("Sec-WebSocket-Accept: ");
        response.push_str(&clean(&self.accept_key));
        response.push_str("\r\n");
        if let Some(protocol) = &self.protocol {
            response.push_str("Sec-WebSocket-Protocol: ");
            response.push_str(&clean(protocol));
            response.push_str("\r\n");
        }
        if !self.extensions.is_empty() {
            let list: Vec<String> = self
                .extensions
                .iter()
                .map(String::as_str)
                .map(clean)
                .collect();
            response.push_str("Sec-WebSocket-Extensions: ");
            response.push_str(&list.join(", "));
            response.push_str("\r\n");
        }
        response.push_str("\r\n");
        response.into_bytes()
    }
}
/// Minimal parsed HTTP request head: method, request target and headers.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// Request method exactly as sent (e.g. `GET`).
    pub method: String,
    /// Request target exactly as sent (e.g. `/chat`).
    pub path: String,
    /// Header values keyed by lowercased name; repeated headers are joined
    /// with `", "`.
    headers: BTreeMap<String, String>,
}
impl HttpRequest {
    /// Parses a request head and returns it together with the bytes after it.
    ///
    /// The head ends at the first blank line (`\r\n\r\n` or `\n\n`). Anything
    /// after it is returned untouched as the second tuple element. Header
    /// lines without a `:` are skipped.
    ///
    /// # Errors
    ///
    /// [`HandshakeError::InvalidRequest`] if the head is incomplete, is not
    /// UTF-8, or the request line lacks a method or path.
    // NOTE(review): this lint allowance predates the current body; confirm it is still needed.
    #[allow(clippy::option_if_let_else)]
    pub fn parse_with_trailing(data: &[u8]) -> Result<(Self, &[u8]), HandshakeError> {
        let (header_bytes, trailing) = split_http_header_block(data)?;
        let text = std::str::from_utf8(header_bytes)
            .map_err(|_| HandshakeError::InvalidRequest("invalid UTF-8".into()))?;
        let mut lines = text.lines();
        let request_line = lines
            .next()
            .ok_or_else(|| HandshakeError::InvalidRequest("empty request".into()))?;
        let mut parts = request_line.split_whitespace();
        let method = parts
            .next()
            .ok_or_else(|| HandshakeError::InvalidRequest("missing method".into()))?
            .to_string();
        let path = parts
            .next()
            .ok_or_else(|| HandshakeError::InvalidRequest("missing path".into()))?
            .to_string();
        let mut headers = BTreeMap::new();
        for line in lines {
            if line.is_empty() {
                break;
            }
            if let Some((name, value)) = line.split_once(':') {
                Self::merge_header(&mut headers, name, value);
            }
        }
        Ok((
            Self {
                method,
                path,
                headers,
            },
            trailing,
        ))
    }
    /// Parses a request head, discarding any bytes that follow it.
    ///
    /// # Errors
    ///
    /// Same as [`HttpRequest::parse_with_trailing`].
    pub fn parse(data: &[u8]) -> Result<Self, HandshakeError> {
        Self::parse_with_trailing(data).map(|(req, _)| req)
    }
    /// Looks up a header by name, case-insensitively.
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .get(&name.to_ascii_lowercase())
            .map(String::as_str)
    }
    /// Records one header line, lowercasing the name and trimming both parts.
    ///
    /// A repeated name has its value appended as `", value"`. RFC 7230 §3.2.2
    /// treats repeated fields as one comma-separated list, and RFC 6455 lets
    /// `Sec-WebSocket-Protocol`/`-Extensions` span several lines, so earlier
    /// occurrences must not be overwritten.
    fn merge_header(headers: &mut BTreeMap<String, String>, name: &str, value: &str) {
        let value = value.trim();
        headers
            .entry(name.trim().to_ascii_lowercase())
            .and_modify(|joined| {
                if value.is_empty() {
                    return;
                }
                if !joined.is_empty() {
                    joined.push_str(", ");
                }
                joined.push_str(value);
            })
            .or_insert_with(|| value.to_string());
    }
}
/// Minimal parsed HTTP response head: status code, reason phrase and headers.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// Numeric status code from the status line.
    pub status: u16,
    /// Reason phrase: everything after the status code (may be empty).
    pub reason: String,
    /// Header values keyed by lowercased name; repeated headers are joined
    /// with `", "`.
    headers: BTreeMap<String, String>,
}
impl HttpResponse {
    /// Parses a response head; bytes after the terminating blank line are ignored.
    ///
    /// Header lines without a `:` are skipped.
    ///
    /// # Errors
    ///
    /// [`HandshakeError::InvalidRequest`] if the head is incomplete, is not
    /// UTF-8, or the status line lacks a version or a numeric status code.
    pub fn parse(data: &[u8]) -> Result<Self, HandshakeError> {
        let (header_bytes, _trailing) = split_http_header_block(data)?;
        let text = std::str::from_utf8(header_bytes)
            .map_err(|_| HandshakeError::InvalidRequest("invalid UTF-8".into()))?;
        let mut lines = text.lines();
        let status_line = lines
            .next()
            .ok_or_else(|| HandshakeError::InvalidRequest("empty response".into()))?;
        let mut parts = status_line.splitn(3, ' ');
        let _version = parts
            .next()
            .ok_or_else(|| HandshakeError::InvalidRequest("missing HTTP version".into()))?;
        let status: u16 = parts
            .next()
            .ok_or_else(|| HandshakeError::InvalidRequest("missing status code".into()))?
            .parse()
            .map_err(|_| HandshakeError::InvalidRequest("invalid status code".into()))?;
        let reason = parts.next().unwrap_or("").to_string();
        let mut headers = BTreeMap::new();
        for line in lines {
            if line.is_empty() {
                break;
            }
            if let Some((name, value)) = line.split_once(':') {
                Self::merge_header(&mut headers, name, value);
            }
        }
        Ok(Self {
            status,
            reason,
            headers,
        })
    }
    /// Looks up a header by name, case-insensitively.
    #[must_use]
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .get(&name.to_ascii_lowercase())
            .map(String::as_str)
    }
    /// Records one header line, lowercasing the name and trimming both parts.
    ///
    /// A repeated name has its value appended as `", value"` (RFC 7230 §3.2.2).
    /// This way a server that sends `Sec-WebSocket-Extensions` on several
    /// lines is validated against all of them, not just the last one.
    fn merge_header(headers: &mut BTreeMap<String, String>, name: &str, value: &str) {
        let value = value.trim();
        headers
            .entry(name.trim().to_ascii_lowercase())
            .and_modify(|joined| {
                if value.is_empty() {
                    return;
                }
                if !joined.is_empty() {
                    joined.push_str(", ");
                }
                joined.push_str(value);
            })
            .or_insert_with(|| value.to_string());
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::util::DetEntropy;
#[test]
fn test_compute_accept_key() {
let client_key = "dGhlIHNhbXBsZSBub25jZQ==";
let accept = compute_accept_key(client_key);
assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
#[test]
fn test_ws_url_parse() {
let url = WsUrl::parse("ws://example.com/chat").unwrap();
assert_eq!(url.host, "example.com");
assert_eq!(url.port, 80);
assert_eq!(url.path, "/chat");
assert!(!url.tls);
let url = WsUrl::parse("wss://example.com:8443/ws").unwrap();
assert_eq!(url.host, "example.com");
assert_eq!(url.port, 8443);
assert_eq!(url.path, "/ws");
assert!(url.tls);
let url = WsUrl::parse("ws://localhost:9000").unwrap();
assert_eq!(url.host, "localhost");
assert_eq!(url.port, 9000);
assert_eq!(url.path, "/");
let url = WsUrl::parse("ws://[::1]:8080/test").unwrap();
assert_eq!(url.host, "::1");
assert_eq!(url.port, 8080);
assert_eq!(url.path, "/test");
}
#[test]
fn test_ws_url_host_header() {
let url = WsUrl::parse("ws://example.com/chat").unwrap();
assert_eq!(url.host_header(), "example.com");
let url = WsUrl::parse("ws://example.com:8080/chat").unwrap();
assert_eq!(url.host_header(), "example.com:8080");
let url = WsUrl::parse("wss://example.com/chat").unwrap();
assert_eq!(url.host_header(), "example.com");
let url = WsUrl::parse("wss://example.com:443/chat").unwrap();
assert_eq!(url.host_header(), "example.com");
}
#[test]
fn test_client_handshake_request() {
let entropy = DetEntropy::new(7);
let handshake = ClientHandshake::new("ws://example.com/chat", &entropy)
.unwrap()
.protocol("chat");
let request = handshake.request_bytes();
let text = String::from_utf8(request).unwrap();
assert!(text.starts_with("GET /chat HTTP/1.1\r\n"));
assert!(text.contains("Host: example.com\r\n"));
assert!(text.contains("Upgrade: websocket\r\n"));
assert!(text.contains("Connection: Upgrade\r\n"));
assert!(text.contains("Sec-WebSocket-Key: "));
assert!(text.contains("Sec-WebSocket-Version: 13\r\n"));
assert!(text.contains("Sec-WebSocket-Protocol: chat\r\n"));
assert!(text.ends_with("\r\n\r\n"));
}
#[test]
fn test_client_validate_response() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").unwrap(),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
\r\n",
)
.unwrap();
assert!(handshake.validate_response(&response).is_ok());
}
#[test]
fn test_client_validate_response_rejects_connection_substring_false_positive() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").unwrap(),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: notupgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
\r\n",
)
.unwrap();
let err = handshake.validate_response(&response).unwrap_err();
assert!(matches!(err, HandshakeError::InvalidRequest(_)));
}
#[test]
fn test_client_validate_response_allows_upgrade_header_token_list() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").unwrap(),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: h2c, websocket\r\n\
Connection: keep-alive, Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
\r\n",
)
.unwrap();
assert!(handshake.validate_response(&response).is_ok());
}
#[test]
fn test_client_validate_response_bad_accept() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").unwrap(),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: wrong-accept-key\r\n\
\r\n",
)
.unwrap();
let err = handshake.validate_response(&response).unwrap_err();
assert!(matches!(err, HandshakeError::InvalidAccept { .. }));
}
#[test]
fn test_client_validate_response_unsolicited_protocol_rejected() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").expect("valid url"),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
Sec-WebSocket-Protocol: chat\r\n\
\r\n",
)
.expect("response must parse");
let err = handshake
.validate_response(&response)
.expect_err("unsolicited protocol must be rejected");
assert!(matches!(err, HandshakeError::ProtocolMismatch { .. }));
}
#[test]
fn test_client_validate_response_unrequested_protocol_rejected() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").expect("valid url"),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec!["chat".to_string()],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
Sec-WebSocket-Protocol: superchat\r\n\
\r\n",
)
.expect("response must parse");
let err = handshake
.validate_response(&response)
.expect_err("protocol not in request must be rejected");
assert!(matches!(err, HandshakeError::ProtocolMismatch { .. }));
}
#[test]
fn test_client_validate_response_requested_protocol_accepted() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").expect("valid url"),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec!["chat".to_string(), "superchat".to_string()],
extensions: vec![],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
Sec-WebSocket-Protocol: superchat\r\n\
\r\n",
)
.expect("response must parse");
assert!(handshake.validate_response(&response).is_ok());
}
#[test]
fn test_server_accept() {
let server = ServerHandshake::new().protocol("chat");
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Protocol: chat\r\n\
\r\n",
)
.unwrap();
let accept = server.accept(&request).unwrap();
assert_eq!(accept.accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
assert_eq!(accept.protocol, Some("chat".to_string()));
}
#[test]
fn test_server_accept_allows_upgrade_header_token_list() {
let server = ServerHandshake::new();
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: h2c, websocket\r\n\
Connection: keep-alive, Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n",
)
.unwrap();
let accept = server.accept(&request).unwrap();
assert_eq!(accept.accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
#[test]
fn test_server_accept_rejects_connection_substring_false_positive() {
let server = ServerHandshake::new();
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: notupgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
\r\n",
)
.unwrap();
let err = server.accept(&request).unwrap_err();
assert!(matches!(err, HandshakeError::InvalidRequest(_)));
}
#[test]
fn test_server_accept_negotiates_extensions() {
let server = ServerHandshake::new()
.extension("permessage-deflate")
.extension("x-webkit-deflate-frame");
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits, x-ignored\r\n\
\r\n",
)
.unwrap();
let accept = server.accept(&request).unwrap();
assert_eq!(
accept.extensions,
vec!["permessage-deflate; client_max_window_bits".to_string()]
);
}
#[test]
fn test_server_reject_bad_version() {
let server = ServerHandshake::new();
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 8\r\n\
\r\n",
)
.unwrap();
let err = server.accept(&request).unwrap_err();
assert!(matches!(err, HandshakeError::UnsupportedVersion(_)));
}
#[test]
fn test_accept_response_bytes() {
let accept = AcceptResponse {
accept_key: "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_string(),
protocol: Some("chat".to_string()),
extensions: vec![],
};
let response = accept.response_bytes();
let text = String::from_utf8(response).unwrap();
assert!(text.starts_with("HTTP/1.1 101 Switching Protocols\r\n"));
assert!(text.contains("Upgrade: websocket\r\n"));
assert!(text.contains("Connection: Upgrade\r\n"));
assert!(text.contains("Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"));
assert!(text.contains("Sec-WebSocket-Protocol: chat\r\n"));
assert!(text.ends_with("\r\n\r\n"));
}
#[test]
fn test_accept_response_snapshot_negotiated_protocol_and_extension() {
let server = ServerHandshake::new()
.protocol("superchat")
.extension("permessage-deflate");
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: keep-alive, Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\
Sec-WebSocket-Protocol: chat, superchat\r\n\
Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits, x-ignored\r\n\
\r\n",
)
.unwrap();
let accept = server.accept(&request).unwrap();
let response = String::from_utf8(accept.response_bytes()).unwrap();
insta::assert_snapshot!(
"accept_response_negotiated_protocol_and_extension",
response
);
}
#[test]
fn test_client_validate_response_rejects_unsolicited_extensions() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").expect("valid url"),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec!["permessage-deflate".to_string()],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
Sec-WebSocket-Extensions: x-unrequested\r\n\
\r\n",
)
.expect("response must parse");
let err = handshake
.validate_response(&response)
.expect_err("unrequested extension must be rejected");
assert!(matches!(err, HandshakeError::ExtensionMismatch { .. }));
}
#[test]
fn test_client_validate_response_accepts_requested_extensions() {
let handshake = ClientHandshake {
url: WsUrl::parse("ws://example.com/chat").expect("valid url"),
key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
protocols: vec![],
extensions: vec!["permessage-deflate".to_string()],
headers: BTreeMap::new(),
};
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\
Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n\
\r\n",
)
.expect("response must parse");
assert!(handshake.validate_response(&response).is_ok());
}
#[test]
fn test_http_request_parse() {
let request = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: WebSocket\r\n\
Connection: Upgrade\r\n\
\r\n",
)
.unwrap();
assert_eq!(request.method, "GET");
assert_eq!(request.path, "/chat");
assert_eq!(request.header("host"), Some("example.com"));
assert_eq!(request.header("upgrade"), Some("WebSocket"));
assert_eq!(request.header("connection"), Some("Upgrade"));
}
#[test]
fn test_http_request_parse_rejects_incomplete_headers() {
let err = HttpRequest::parse(
b"GET /chat HTTP/1.1\r\n\
Host: example.com\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n",
)
.expect_err("missing blank line must be treated as an incomplete request");
assert!(matches!(err, HandshakeError::InvalidRequest(_)));
}
#[test]
fn test_http_response_parse() {
let response = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: xyz\r\n\
\r\n",
)
.unwrap();
assert_eq!(response.status, 101);
assert_eq!(response.reason, "Switching Protocols");
assert_eq!(response.header("upgrade"), Some("websocket"));
assert_eq!(response.header("sec-websocket-accept"), Some("xyz"));
}
#[test]
fn test_http_response_parse_rejects_incomplete_headers() {
let err = HttpResponse::parse(
b"HTTP/1.1 101 Switching Protocols\r\n\
Upgrade: websocket\r\n\
Connection: Upgrade\r\n\
Sec-WebSocket-Accept: xyz\r\n",
)
.expect_err("missing blank line must be treated as an incomplete response");
assert!(matches!(err, HandshakeError::InvalidRequest(_)));
}
#[test]
fn test_split_http_header_block_prefers_earliest_complete_terminator() {
let data = b"GET /chat HTTP/1.1\n\
Host: example.com\n\
Upgrade: websocket\n\
Connection: Upgrade\n\
\n\
body-prefix\r\n\r\nstill-body";
let (header, trailing) = split_http_header_block(data).unwrap();
assert_eq!(
header,
b"GET /chat HTTP/1.1\n\
Host: example.com\n\
Upgrade: websocket\n\
Connection: Upgrade\n\
\n"
);
assert_eq!(trailing, b"body-prefix\r\n\r\nstill-body");
}
#[test]
fn test_generate_client_key() {
let entropy = DetEntropy::new(42);
let key = generate_client_key(&entropy);
let decoded = base64::engine::general_purpose::STANDARD
.decode(&key)
.unwrap();
assert_eq!(decoded.len(), 16);
}
#[test]
fn ws_url_debug_clone_eq() {
let u = WsUrl {
host: "example.com".into(),
port: 80,
path: "/chat".into(),
tls: false,
};
let dbg = format!("{u:?}");
assert!(dbg.contains("WsUrl"));
assert!(dbg.contains("example.com"));
let u2 = u.clone();
assert_eq!(u, u2);
let u3 = WsUrl {
host: "other.com".into(),
port: 443,
path: "/".into(),
tls: true,
};
assert_ne!(u, u3);
}
#[test]
fn server_handshake_debug_clone_default() {
let s = ServerHandshake::default();
let dbg = format!("{s:?}");
assert!(dbg.contains("ServerHandshake"));
let s2 = s;
let dbg2 = format!("{s2:?}");
assert_eq!(dbg, dbg2);
}
#[test]
fn http_request_debug_clone() {
let r = HttpRequest::parse(b"GET /test HTTP/1.1\r\nHost: localhost\r\n\r\n").unwrap();
let dbg = format!("{r:?}");
assert!(dbg.contains("HttpRequest"));
let r2 = r;
assert_eq!(r2.method, "GET");
assert_eq!(r2.path, "/test");
}
#[test]
fn server_accept_strips_crlf_from_extension_offers() {
let raw_request = "GET / HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\n\
Connection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\r\n";
let mut request = HttpRequest::parse(raw_request.as_bytes()).unwrap();
request.headers.insert(
"sec-websocket-extensions".to_string(),
"permessage-deflate; x\r\nX-Injected: evil".to_string(),
);
let server = ServerHandshake::new().extension("permessage-deflate");
let accept = server.accept(&request).unwrap();
let response = accept.response_bytes();
let response_str = String::from_utf8_lossy(&response);
let line_count = response_str.lines().count();
assert!(
line_count <= 7,
"response splitting injected extra header lines: {response_str}"
);
for line in response_str.lines() {
if line.starts_with("Sec-WebSocket-Extensions:") {
assert!(
!line.contains('\r') && !line.contains('\n'),
"extension header must not contain embedded CRLF: {line}"
);
}
}
}
/// Golden test for `Sec-WebSocket-Key` validation: the server accepts keys that
/// are base64 decoding to exactly 16 bytes and rejects everything else with
/// `HandshakeError::InvalidKey` (RFC 6455 §4.2.1).
#[test]
fn golden_16_byte_base64_key_validation() {
    let server = ServerHandshake::new();

    // Minimal version-13 upgrade request carrying `key` verbatim.
    let upgrade_request = |key: &str| {
        format!(
            "GET /test HTTP/1.1\r\n\
             Host: localhost\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Key: {key}\r\n\
             Sec-WebSocket-Version: 13\r\n\r\n"
        )
    };

    let valid_keys = [
        "dGhlIHNhbXBsZSBub25jZQ==", // RFC 6455 sample nonce ("the sample nonce")
        "AQIDBAUGBwgJCgsMDQ4PEA==",
        "/////////////////////w==",
        "AAAAAAAAAAAAAAAAAAAAAA==",
        "MTIzNDU2Nzg5YWJjZGVmZw==",
    ];
    for (i, key) in valid_keys.iter().enumerate() {
        let request_data = upgrade_request(key);
        let request = HttpRequest::parse(request_data.as_bytes())
            .unwrap_or_else(|err| panic!("Failed to parse request {i}: {err:?}"));
        if let Err(err) = server.accept(&request) {
            panic!("Valid 16-byte key #{i} should be accepted: '{key}', error: {err:?}");
        }
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(key)
            .expect("Key should decode properly");
        assert_eq!(
            decoded.len(),
            16,
            "Key #{i} should decode to exactly 16 bytes: '{key}'"
        );
    }

    let invalid_keys = [
        ("", "empty key"),
        ("dGhlIHNhbXBsZSBub25jZQ", "missing padding"),
        ("dGhlIHNhbXBsZSBub25jZQ====", "too much padding"),
        // "the sample nonc": well-formed base64, one byte short.
        ("dGhlIHNhbXBsZSBub25j", "15 bytes (one short)"),
        ("dGhlIHNhbXBsZSBub25jZ===", "triple padding"),
        ("dGhlIHNhbXBsZSBub25jZGQ=", "17 bytes (one too many)"),
        ("dGhlIHNhbXBsZSBub25jZGRk", "18 bytes"),
        ("MTIzNA==", "only 4 bytes"),
        ("!@#$%^&*()_+{}|:<>?", "invalid base64 characters"),
        ("dGhlIHNhbXBsZSBub25jZQ=", "invalid padding"),
        ("AAAAAAAAAAAAAAAAAAAAAAAAAAAA", "oversized key (more than 16 bytes)"),
    ];
    for (key, description) in invalid_keys {
        let request_data = upgrade_request(key);
        let request = HttpRequest::parse(request_data.as_bytes()).unwrap_or_else(|err| {
            panic!("Failed to parse request for {description}: {err:?}")
        });
        match server.accept(&request) {
            Ok(_) => panic!("Invalid key should be rejected: {key} ({description})"),
            Err(error) => assert!(
                matches!(error, HandshakeError::InvalidKey),
                "Should fail with InvalidKey error for {description}: got {error:?}"
            ),
        }
    }
}
/// Golden test for the `Sec-WebSocket-Accept` derivation:
/// base64(SHA-1(client key followed by `WS_GUID`)).
///
/// Checks the RFC 6455 §1.3 sample via the library and by hand, three more
/// fixed vectors (each computed twice), the GUID literal itself, and that a
/// one-digit change to the GUID changes the result.
#[test]
fn golden_sha1_fixed_guid_concatenation() {
    let sample_key = "dGhlIHNhbXBsZSBub25jZQ==";
    let sample_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    // SHA-1 the already-concatenated input and base64-encode the digest.
    let sha1_base64 = |input: &str| {
        let mut digest = Sha1::new();
        digest.update(input.as_bytes());
        base64::engine::general_purpose::STANDARD.encode(digest.finalize())
    };

    let via_library = compute_accept_key(sample_key);
    assert_eq!(
        via_library, sample_accept,
        "RFC 6455 test vector must match exactly"
    );

    let joined = format!("{sample_key}{WS_GUID}");
    assert_eq!(
        joined,
        "dGhlIHNhbXBsZSBub25jZQ==258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    );
    let via_hand = sha1_base64(&joined);
    assert_eq!(
        via_hand, sample_accept,
        "Manual computation should match library computation"
    );

    let fixed_vectors = [
        ("AQIDBAUGBwgJCgsMDQ4PEA==", "C/0nmHhBztSRGR1CwL6Tf4ZjwpY="),
        ("AAAAAAAAAAAAAAAAAAAAAA==", "ICX+Yqv66kxgM0FcWaLWlFLwTAI="),
        ("/////////////////////w==", "XXpj4jYzLM2yUE0C7TIgMwTQh2g="),
    ];
    for (key, expected) in fixed_vectors {
        let first = compute_accept_key(key);
        assert_eq!(
            first, expected,
            "Accept key computation failed for test vector: key={key}, expected={expected}, got={first}"
        );
        let second = compute_accept_key(key);
        assert_eq!(
            first, second,
            "Accept key computation should be deterministic"
        );
    }

    assert_eq!(WS_GUID, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11");

    // Same key, but the GUID's leading '2' replaced by '3'.
    let tampered_guid = "358EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    let tampered = sha1_base64(&format!("{sample_key}{tampered_guid}"));
    assert_ne!(
        tampered, sample_accept,
        "Wrong GUID should produce different result"
    );
}
/// Golden test for repeated `Sec-WebSocket-Key` values.
///
/// Despite the name, nothing is rejected here. The test asserts that the
/// server accepts a reused key every time and always derives the same
/// accept value from it. It also asserts that distinct keys yield pairwise
/// distinct accept values.
#[test]
fn golden_key_reuse_detection() {
    let server = ServerHandshake::new().protocol("chat");
    let reused_key = "dGhlIHNhbXBsZSBub25jZQ==";
    let rfc_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    // Version-13 upgrade request for `path` and `key`; `extra` is inserted
    // verbatim as additional header line(s) just before the terminating CRLF.
    let upgrade_request = |path: &str, key: &str, extra: &str| {
        format!(
            "GET {path} HTTP/1.1\r\n\
             Host: localhost\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Key: {key}\r\n\
             Sec-WebSocket-Version: 13\r\n\
             {extra}\r\n"
        )
    };
    let chat_protocol = "Sec-WebSocket-Protocol: chat\r\n";

    let request1_data = upgrade_request("/test1", reused_key, chat_protocol);
    let request1 =
        HttpRequest::parse(request1_data.as_bytes()).expect("First request should parse");
    let accept1 = server
        .accept(&request1)
        .expect("First connection should be accepted");

    let request2_data = upgrade_request("/test2", reused_key, chat_protocol);
    let request2 =
        HttpRequest::parse(request2_data.as_bytes()).expect("Second request should parse");
    let accept2 = server
        .accept(&request2)
        .expect("Second connection should be accepted");

    assert_eq!(
        accept1.accept_key, accept2.accept_key,
        "Same client key should always produce same accept key"
    );
    assert_eq!(accept1.accept_key, rfc_accept);

    for i in 0..10 {
        let request_data = upgrade_request(&format!("/test{i}"), reused_key, "");
        let request = HttpRequest::parse(request_data.as_bytes())
            .unwrap_or_else(|err| panic!("Request {i} should parse: {err:?}"));
        let accept = server
            .accept(&request)
            .unwrap_or_else(|err| panic!("Connection {i} should be accepted: {err:?}"));
        assert_eq!(
            accept.accept_key, rfc_accept,
            "Connection {i} should have consistent accept key"
        );
    }

    let different_keys = [
        "AQIDBAUGBwgJCgsMDQ4PEA==",
        "AAAAAAAAAAAAAAAAAAAAAA==",
        "/////////////////////w==",
    ];
    let mut accept_keys = vec![accept1.accept_key.clone()];
    for (i, key) in different_keys.iter().enumerate() {
        let request_data = upgrade_request(&format!("/unique{i}"), key, "");
        let request = HttpRequest::parse(request_data.as_bytes())
            .unwrap_or_else(|err| panic!("Request for key {i} should parse: {err:?}"));
        let accept = server.accept(&request).unwrap_or_else(|err| {
            panic!("Connection for key {i} should be accepted: {err:?}")
        });
        accept_keys.push(accept.accept_key.clone());
    }

    // Every pair (i < j) must differ.
    for (i, a) in accept_keys.iter().enumerate() {
        for (j, b) in accept_keys.iter().enumerate().skip(i + 1) {
            assert_ne!(
                a, b,
                "Accept keys {i} and {j} should be different: '{a}' vs '{b}'"
            );
        }
    }
}
/// Golden test for `Sec-WebSocket-Protocol` negotiation. It covers:
///
/// - the server picking the first client-listed protocol it supports;
/// - rejection with `ProtocolMismatch` when the lists do not overlap;
/// - an exact single-protocol match;
/// - no selection when the client offers none;
/// - whitespace and ordering variations in the header value;
/// - case-sensitive matching.
#[test]
fn golden_multiple_sec_websocket_protocol_negotiation() {
    let server = ServerHandshake::new()
        .protocol("chat")
        .protocol("superchat")
        .protocol("echo");
    // Client order (superchat first) differs from the server's registration order.
    let request_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\
        Sec-WebSocket-Protocol: superchat, chat, echo\r\n\r\n";
    let request = HttpRequest::parse(request_data.as_bytes())
        .expect("Multiple protocol request should parse");
    let accept = server
        .accept(&request)
        .expect("Multiple protocol negotiation should succeed");
    assert_eq!(
        accept.protocol,
        Some("superchat".to_string()),
        "Should select first matching protocol from client preference order"
    );

    let server_limited = ServerHandshake::new().protocol("private-protocol");
    let request_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\
        Sec-WebSocket-Protocol: chat, superchat, echo\r\n\r\n";
    let request = HttpRequest::parse(request_data.as_bytes())
        .expect("Unsupported protocol request should parse");
    match server_limited.accept(&request) {
        Ok(_) => panic!("Should reject when no protocols match"),
        Err(error) => assert!(
            matches!(error, HandshakeError::ProtocolMismatch { .. }),
            "Should fail with ProtocolMismatch error: {error:?}"
        ),
    }

    let server_single = ServerHandshake::new().protocol("websocket-chat");
    let request_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\
        Sec-WebSocket-Protocol: websocket-chat\r\n\r\n";
    let request = HttpRequest::parse(request_data.as_bytes())
        .expect("Single protocol request should parse");
    let accept = server_single
        .accept(&request)
        .expect("Single protocol negotiation should succeed");
    assert_eq!(
        accept.protocol,
        Some("websocket-chat".to_string()),
        "Should accept exact protocol match"
    );

    // No Sec-WebSocket-Protocol header at all.
    let request_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\r\n";
    let request =
        HttpRequest::parse(request_data.as_bytes()).expect("No protocol request should parse");
    let accept = server
        .accept(&request)
        .expect("Should accept connection without protocol when client doesn't request any");
    assert_eq!(
        accept.protocol, None,
        "Should not select protocol when client doesn't request any"
    );

    // (header value sent by the client, protocol the server must select)
    let protocol_test_cases = [
        ("chat", "chat"),
        ("chat, superchat", "chat"),
        (" chat , superchat ", "chat"),
        ("superchat,chat,echo", "chat"),
        ("unknown, chat, unknown2", "chat"),
    ];
    let server_chat = ServerHandshake::new().protocol("chat");
    for (protocol_header, expected) in protocol_test_cases {
        let request_data = format!(
            "GET /test HTTP/1.1\r\n\
             Host: localhost\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Protocol: {protocol_header}\r\n\r\n"
        );
        let request = HttpRequest::parse(request_data.as_bytes()).unwrap_or_else(|err| {
            panic!("Protocol header '{protocol_header}' should parse: {err:?}")
        });
        let accept = server_chat.accept(&request).unwrap_or_else(|err| {
            panic!("Protocol negotiation should succeed for '{protocol_header}': {err:?}")
        });
        assert_eq!(
            accept.protocol,
            Some(expected.to_string()),
            "Protocol header '{protocol_header}' should select '{expected}'"
        );
    }

    let server_case = ServerHandshake::new().protocol("Chat");
    let request_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\
        Sec-WebSocket-Protocol: chat\r\n\r\n";
    let request =
        HttpRequest::parse(request_data.as_bytes()).expect("Case test request should parse");
    assert!(
        server_case.accept(&request).is_err(),
        "Protocol matching should be case-sensitive: 'Chat' != 'chat'"
    );
}
/// Golden test for RFC 6455 status codes and required headers on both sides
/// of the handshake.
///
/// Server side:
/// - a valid upgrade yields `101 Switching Protocols` with the required headers;
/// - requests missing a required header fail with an error whose `Debug` text
///   contains `MissingHeader`;
/// - version 12 fails with `UnsupportedVersion`;
/// - malformed request lines are never upgraded.
///
/// Client side: any non-101 status fails with `NotSwitchingProtocols`.
#[test]
fn golden_rfc6455_compliant_status_codes() {
    let server = ServerHandshake::new();
    let valid_request_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 13\r\n\r\n";
    let request =
        HttpRequest::parse(valid_request_data.as_bytes()).expect("Valid request should parse");
    let accept = server
        .accept(&request)
        .expect("Valid request should be accepted");
    let response_bytes = accept.response_bytes();
    let response_str = String::from_utf8_lossy(&response_bytes);
    assert!(
        response_str.starts_with("HTTP/1.1 101 Switching Protocols"),
        "Successful handshake should return 101 Switching Protocols"
    );
    assert!(
        response_str.contains("Upgrade: websocket"),
        "Response should contain Upgrade header"
    );
    assert!(
        response_str.contains("Connection: Upgrade"),
        "Response should contain Connection header"
    );
    assert!(
        response_str.contains("Sec-WebSocket-Accept: "),
        "Response should contain Sec-WebSocket-Accept header"
    );
    assert!(
        response_str.ends_with("\r\n\r\n"),
        "Response should end with CRLF CRLF"
    );
    let line_count = response_str.lines().count();
    assert!(
        line_count <= 6,
        "Response should not have extra headers: {response_str}"
    );

    // Each request drops one required header; the Sec-WebSocket-Key case
    // also lacks Sec-WebSocket-Version.
    let missing_header_tests = [
        (
            "GET /test HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\n\r\n",
            "Missing Connection header",
        ),
        (
            "GET /test HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\n\r\n",
            "Missing Upgrade header",
        ),
        (
            "GET /test HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n",
            "Missing Sec-WebSocket-Key header",
        ),
        (
            "GET /test HTTP/1.1\r\nHost: localhost\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\r\n",
            "Missing Sec-WebSocket-Version header",
        ),
    ];
    for (request_data, description) in missing_header_tests {
        let request = HttpRequest::parse(request_data.as_bytes())
            .unwrap_or_else(|err| panic!("Request should parse: {description}: {err:?}"));
        match server.accept(&request) {
            Ok(_) => panic!("Should reject request: {description}"),
            Err(error) => {
                // Matched on Debug text so the test does not depend on the
                // variant's payload shape.
                assert!(
                    format!("{error:?}").contains("MissingHeader"),
                    "Should fail with MissingHeader: {description} - got {error:?}"
                );
            }
        }
    }

    let invalid_version_data = "GET /test HTTP/1.1\r\n\
        Host: localhost\r\n\
        Upgrade: websocket\r\n\
        Connection: Upgrade\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        Sec-WebSocket-Version: 12\r\n\r\n";
    let request = HttpRequest::parse(invalid_version_data.as_bytes())
        .expect("Invalid version request should parse");
    match server.accept(&request) {
        Ok(_) => panic!("Should reject invalid WebSocket version"),
        Err(error) => assert!(
            matches!(error, HandshakeError::UnsupportedVersion(_)),
            "Should fail with UnsupportedVersion error: {error:?}"
        ),
    }

    let handshake = ClientHandshake {
        url: crate::net::websocket::handshake::WsUrl::parse("ws://example.com/test").unwrap(),
        key: "dGhlIHNhbXBsZSBub25jZQ==".to_string(),
        protocols: vec![],
        extensions: vec![],
        headers: std::collections::BTreeMap::new(),
    };
    let invalid_status_tests = [
        (200, "200 OK"),
        (400, "400 Bad Request"),
        (404, "404 Not Found"),
        (426, "426 Upgrade Required"),
        (500, "500 Internal Server Error"),
    ];
    for (status_code, status_text) in invalid_status_tests {
        // Otherwise-valid upgrade headers, so only the status line is wrong.
        // `status_text` already begins with the numeric code.
        let response_data = format!(
            "HTTP/1.1 {status_code} {status_text}\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n"
        );
        let response = HttpResponse::parse(response_data.as_bytes()).unwrap_or_else(|err| {
            panic!("Response with status {status_code} should parse: {err:?}")
        });
        match handshake.validate_response(&response) {
            Ok(_) => panic!("Should reject response with status code {status_code}"),
            Err(error) => assert!(
                matches!(error, HandshakeError::NotSwitchingProtocols(_)),
                "Should fail with NotSwitchingProtocols for status {status_code}: {error:?}"
            ),
        }
    }

    // Structurally malformed request lines: if the parser tolerates them, the
    // server must still refuse the upgrade.
    let malformed_requests: [&[u8]; 3] = [
        b"NOT HTTP\r\n\r\n",
        b"GET /test\r\n\r\n",
        b"GET /test HTTP/1.0\r\n\r\n",
    ];
    for (i, malformed) in malformed_requests.iter().enumerate() {
        if let Ok(request) = HttpRequest::parse(malformed) {
            assert!(
                server.accept(&request).is_err(),
                "Malformed request {i} should be rejected"
            );
        }
    }
    assert!(
        HttpRequest::parse(b"").is_err(),
        "Empty request should fail to parse"
    );
}
/// End-to-end golden test. A request built by `ClientHandshake` is parsed and
/// accepted by `ServerHandshake`, and the server's response is then parsed and
/// validated by the same client.
#[test]
fn golden_end_to_end_handshake_validation() {
    let entropy = crate::util::entropy::DetEntropy::new(12345);
    let client = ClientHandshake::new("ws://localhost:8080/socket", &entropy)
        .expect("Client handshake should initialize")
        .protocol("chat")
        .protocol("echo")
        .extension("permessage-deflate");
    // The server registers the same protocols in the opposite order.
    let server = ServerHandshake::new()
        .protocol("echo")
        .protocol("chat")
        .extension("permessage-deflate");

    // Client -> server.
    let request_bytes = client.request_bytes();
    let request_text = String::from_utf8_lossy(&request_bytes);
    for fragment in [
        "GET /socket HTTP/1.1",
        "Host: localhost:8080",
        "Upgrade: websocket",
        "Connection: Upgrade",
        "Sec-WebSocket-Key: ",
        "Sec-WebSocket-Version: 13",
        "Sec-WebSocket-Protocol: chat, echo",
    ] {
        assert!(request_text.contains(fragment));
    }

    let parsed_request =
        HttpRequest::parse(&request_bytes).expect("Client request should parse on server");
    let accept = server
        .accept(&parsed_request)
        .expect("Server should accept valid client request");
    assert_eq!(
        accept.protocol,
        Some("chat".to_string()),
        "Server should select first client protocol it supports"
    );

    // Server -> client.
    let response_bytes = accept.response_bytes();
    let response_text = String::from_utf8_lossy(&response_bytes);
    let accept_header = format!("Sec-WebSocket-Accept: {}", accept.accept_key);
    for fragment in [
        "HTTP/1.1 101 Switching Protocols",
        accept_header.as_str(),
        "Sec-WebSocket-Protocol: chat",
    ] {
        assert!(response_text.contains(fragment));
    }

    let parsed_response =
        HttpResponse::parse(&response_bytes).expect("Server response should parse on client");
    if let Err(err) = client.validate_response(&parsed_response) {
        panic!("Client should validate server response: {err:?}");
    }

    let recomputed = compute_accept_key(&client.key);
    assert_eq!(
        accept.accept_key, recomputed,
        "Server accept key should match computed value"
    );

    // If any extension was negotiated, the response must advertise it.
    let negotiated_extensions = !accept.extensions.is_empty();
    assert!(
        !negotiated_extensions || response_text.contains("Sec-WebSocket-Extensions:"),
        "Response should include extension header when extensions are negotiated"
    );
}
}