use shiguredo_http11::{HttpHead, ResponseDecoder, ResponseHead};
use crate::error::Error;
use crate::websocket_extension::Extension;
use crate::websocket_handshake::calculate_accept;
/// Builder-style description of what a server puts in its WebSocket handshake response:
/// the selected subprotocol, the accepted extensions, and any extra headers.
///
/// How these fields are serialized onto the wire is handled elsewhere.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerHandshakeResponse {
    /// Subprotocol to announce, if any.
    pub protocol: Option<String>,
    /// Extension entries to announce, in insertion order.
    pub extensions: Vec<String>,
    /// Extra `(name, value)` header pairs, in insertion order.
    pub additional_headers: Vec<(String, String)>,
}

impl ServerHandshakeResponse {
    /// Creates an empty response description (no protocol, extensions, or extra headers).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the subprotocol, replacing any previously set value.
    // The builder methods consume `self`; ignoring their result silently drops the
    // configuration, hence `#[must_use]`.
    #[must_use]
    pub fn protocol(mut self, protocol: &str) -> Self {
        self.protocol = Some(protocol.to_string());
        self
    }

    /// Appends one extension entry.
    #[must_use]
    pub fn extension(mut self, extension: &str) -> Self {
        self.extensions.push(extension.to_string());
        self
    }

    /// Appends one extra header; duplicates are kept as separate entries.
    #[must_use]
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.additional_headers
            .push((name.to_string(), value.to_string()));
        self
    }
}
/// Result of a successfully validated server handshake response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HandshakeResponse {
    /// Subprotocol chosen by the server via `Sec-WebSocket-Protocol`, if it sent one.
    pub protocol: Option<String>,
    /// Entries of `Sec-WebSocket-Extensions`, split on commas, trimmed, with empty
    /// entries dropped.
    pub extensions: Vec<String>,
}
pub struct HandshakeValidator {
decoder: ResponseDecoder,
expected_accept: String,
decode_error: Option<String>,
}
impl HandshakeValidator {
pub fn new(nonce: [u8; 16]) -> Self {
let expected_accept = calculate_accept(&nonce);
Self {
decoder: ResponseDecoder::new(),
expected_accept,
decode_error: None,
}
}
pub fn feed(&mut self, data: &[u8]) {
if self.decode_error.is_none()
&& let Err(err) = self.decoder.feed(data)
{
self.decode_error = Some(err.to_string());
}
}
pub fn remaining(&self) -> &[u8] {
self.decoder.remaining()
}
pub fn validate(&mut self) -> Result<Option<HandshakeResponse>, Error> {
if let Some(reason) = self.decode_error.as_deref() {
return Err(Error::invalid_data(reason));
}
let head = match self
.decoder
.decode_headers()
.map_err(|err| Error::invalid_data(err.to_string()))?
{
Some((head, _body_kind)) => head,
None => return Ok(None),
};
self.validate_response(&head)
}
fn validate_response(
&self,
response: &ResponseHead,
) -> Result<Option<HandshakeResponse>, Error> {
if response.status_code() != 101 {
return Err(Error::http_response(crate::error::HttpResponseInfo {
status_code: response.status_code(),
reason_phrase: response.reason_phrase().to_string(),
headers: response
.headers()
.iter()
.map(|(name, value)| (name.as_str().to_string(), value.clone()))
.collect(),
}));
}
{
let upgrade_values = response.get_headers("Upgrade");
if upgrade_values.is_empty() {
return Err(Error::handshake_rejected("missing Upgrade header"));
}
let has_websocket = upgrade_values.iter().any(|v| {
v.split(',')
.any(|token| token.trim().eq_ignore_ascii_case("websocket"))
});
if !has_websocket {
return Err(Error::handshake_rejected(format!(
"invalid Upgrade header: {}",
upgrade_values.join(", ")
)));
}
}
{
let connection_values = response.get_headers("Connection");
if connection_values.is_empty() {
return Err(Error::handshake_rejected("missing Connection header"));
}
let has_upgrade = connection_values.iter().any(|v| {
v.split(',')
.any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
});
if !has_upgrade {
return Err(Error::handshake_rejected(format!(
"invalid Connection header: {}",
connection_values.join(", ")
)));
}
}
{
let accept_values = response.get_headers("Sec-WebSocket-Accept");
if accept_values.len() > 1 {
return Err(Error::handshake_rejected(
"duplicate Sec-WebSocket-Accept header",
));
}
}
match response.get_header("Sec-WebSocket-Accept") {
Some(v) if v == self.expected_accept => {}
Some(v) => {
return Err(Error::handshake_rejected(format!(
"invalid Sec-WebSocket-Accept: expected {}, got {}",
self.expected_accept, v
)));
}
None => {
return Err(Error::handshake_rejected(
"missing Sec-WebSocket-Accept header",
));
}
}
{
let protocol_values = response.get_headers("Sec-WebSocket-Protocol");
if protocol_values.len() > 1 {
return Err(Error::handshake_rejected(
"duplicate Sec-WebSocket-Protocol header",
));
}
}
let protocol = response
.get_header("Sec-WebSocket-Protocol")
.map(String::from);
let extension_values = response.get_headers("Sec-WebSocket-Extensions");
let extensions: Vec<String> = extension_values
.iter()
.flat_map(|v| v.split(','))
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
if !extension_values.is_empty() && extensions.is_empty() {
return Err(Error::handshake_rejected(
"malformed Sec-WebSocket-Extensions header: no valid extensions",
));
}
for ext in &extensions {
Extension::parse_strict(ext).map_err(|e| {
Error::handshake_rejected(format!("invalid Sec-WebSocket-Extensions value: {e}"))
})?;
}
Ok(Some(HandshakeResponse {
protocol,
extensions,
}))
}
}