use std::fmt::Write;
use crate::{
Body, Error, Version,
client::{IncomingResponse, OutgoingRequest},
url::Url,
ws::{AgentRole, WebSocket},
};
/// Builder collecting the options of a WebSocket client handshake.
///
/// Obtained via `ClientHandshake::builder()` and consumed by `build`, which
/// produces the handshake state together with the HTTP upgrade request.
pub struct ClientHandshakeBuilder {
/// Value sent in the `Sec-WebSocket-Key` request header; the expected accept
/// token is derived from it when the handshake is built.
key: String,
/// Sub-protocols offered to the server, in the order they were added.
protocols: Vec<String>,
/// Input buffer capacity handed over to the resulting `WebSocket`.
input_buffer_capacity: usize,
}
impl ClientHandshakeBuilder {
    /// Create a builder with a freshly generated key, no offered
    /// sub-protocols and an input buffer capacity of 65 536.
    fn new() -> Self {
        let key = super::create_key();

        Self {
            key,
            protocols: Vec::new(),
            input_buffer_capacity: 65_536,
        }
    }

    /// Offer an additional sub-protocol to the server.
    ///
    /// Sub-protocols are advertised in the order in which they were added.
    pub fn protocol<T>(mut self, protocol: T) -> Self
    where
        T: Into<String>,
    {
        let name: String = protocol.into();
        self.protocols.push(name);
        self
    }

    /// Set the input buffer capacity that will be passed to the WebSocket
    /// created once the handshake completes.
    #[inline]
    pub fn input_buffer_capacity(self, capacity: usize) -> Self {
        Self {
            input_buffer_capacity: capacity,
            ..self
        }
    }

    /// Consume the builder and produce the handshake state together with the
    /// HTTP/1.1 `GET` upgrade request (empty body) for the given URL.
    ///
    /// The `Sec-WebSocket-Protocol` header is only added when at least one
    /// sub-protocol has been offered.
    ///
    /// # Panics
    /// Panics if `OutgoingRequest::get` rejects the given URL.
    pub fn build(self, url: Url) -> (ClientHandshake, OutgoingRequest) {
        let Self {
            key,
            protocols,
            input_buffer_capacity,
        } = self;

        // The accept token must be derived before the key is moved into the
        // request header below.
        let accept = super::create_accept_token(key.as_bytes());

        let request = OutgoingRequest::get(url)
            .unwrap()
            .set_version(Version::Version11)
            .add_header_field(("Connection", "upgrade"))
            .add_header_field(("Upgrade", "websocket"))
            .add_header_field(("Sec-WebSocket-Version", "13"))
            .add_header_field(("Sec-WebSocket-Key", key));

        let request = if protocols.is_empty() {
            request
        } else {
            request.add_header_field(("Sec-WebSocket-Protocol", join_tokens(&protocols)))
        };

        let handshake = ClientHandshake {
            accept,
            protocols,
            input_buffer_capacity,
        };

        (handshake, request.body(Body::empty()))
    }
}
/// State of a pending WebSocket client handshake.
///
/// Produced by `ClientHandshakeBuilder::build` and consumed by `complete`
/// once the server response is available.
pub struct ClientHandshake {
/// Expected value of the `Sec-WebSocket-Accept` response header.
accept: String,
/// Sub-protocols that were offered to the server.
protocols: Vec<String>,
/// Input buffer capacity handed over to the resulting `WebSocket`.
input_buffer_capacity: usize,
}
impl ClientHandshake {
#[inline]
pub fn builder() -> ClientHandshakeBuilder {
ClientHandshakeBuilder::new()
}
pub fn complete(self, response: IncomingResponse) -> Result<WebSocket, Error> {
assert_eq!(response.status_code(), 101);
let is_upgrade = response
.get_header_field_value("connection")
.map(|v| v.as_ref())
.unwrap_or(b"")
.split(|&b| b == b',')
.map(|kw| kw.trim_ascii())
.filter(|kw| !kw.is_empty())
.any(|kw| kw.eq_ignore_ascii_case(b"upgrade"));
if !is_upgrade {
return Err(Error::from_static_msg("not a connection upgrade"));
}
let is_websocket = response
.get_header_field_value("upgrade")
.map(|v| v.as_ref())
.unwrap_or(b"")
.trim_ascii()
.eq_ignore_ascii_case(b"websocket");
if !is_websocket {
return Err(Error::from_static_msg("not a WebSocket upgrade"));
}
let accept = response
.get_header_field_value("sec-websocket-accept")
.ok_or_else(|| Error::from_static_msg("missing WS accept header"))?
.as_ref();
if accept != self.accept.as_bytes() {
return Err(Error::from_static_msg("invalid WS accept header"));
}
let protocol = response
.get_header_field_value("sec-websocket-protocol")
.map(|p| p.trim_ascii());
if let Some(protocol) = protocol {
let is_valid_protocol = self.protocols.iter().any(|p| p.as_bytes() == protocol);
if !is_valid_protocol {
return Err(Error::from_static_msg("invalid WS sub-protocol"));
}
}
let extensions = response
.get_header_fields("sec-websocket-extensions")
.flat_map(|field| {
field
.value()
.map(|v| v.as_ref())
.unwrap_or(b"")
.split(|&b| b == b',')
.map(|p| p.trim_ascii())
.filter(|p| !p.is_empty())
})
.count();
if extensions != 0 {
return Err(Error::from_static_msg("unknown WS extensions"));
}
let upgraded = response
.upgrade()
.ok_or_else(|| Error::from_static_msg("unable to upgrade the HTTP connection"))?;
let res = WebSocket::new(upgraded, AgentRole::Client, self.input_buffer_capacity);
Ok(res)
}
}
/// Join the given tokens into a comma-separated header list.
///
/// Every token is trimmed first. Tokens that are empty after trimming are
/// skipped so that the resulting list never contains empty elements (e.g. a
/// leading, trailing or doubled comma). An empty string is returned when
/// there is no non-blank token.
fn join_tokens(tokens: &[String]) -> String {
    let mut res = String::new();

    let tokens = tokens
        .iter()
        .map(|token| token.trim())
        .filter(|token| !token.is_empty());

    for token in tokens {
        if res.is_empty() {
            res.push_str(token);
        } else {
            // Writing into a `String` cannot fail.
            let _ = write!(res, ",{token}");
        }
    }

    res
}