use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};
use futures::FutureExt;
use crate::{
Error, Method, Status, Version,
server::{IncomingRequest, OutgoingResponse},
ws::{AgentRole, WebSocket},
};
/// Server side of a WebSocket opening handshake whose request has passed
/// validation.
///
/// Created by [`ServerHandshake::new`] from an incoming upgrade request and
/// consumed by [`ServerHandshake::complete`].
#[derive(Debug)]
pub struct ServerHandshake {
    /// Value for the `Sec-WebSocket-Accept` response header, derived from the
    /// client's `Sec-WebSocket-Key`.
    accept: String,
    /// Sub-protocols offered by the client via `Sec-WebSocket-Protocol`, in
    /// the order they were received.
    protocols: Vec<String>,
}
impl ServerHandshake {
pub fn new(request: IncomingRequest) -> Result<Self, Error> {
if request.method() != Method::Get {
return Err(Error::from_static_msg(
"invalid HTTP method for WS handshake",
));
} else if request.version() == Version::Version10 {
return Err(Error::from_static_msg(
"this HTTP version is not supported for WS",
));
}
let is_upgrade = request
.get_header_fields("connection")
.flat_map(|field| {
field
.value()
.map(|v| v.as_ref())
.unwrap_or(b"")
.split(|&b| b == b',')
.map(|kw| kw.trim_ascii())
.filter(|kw| !kw.is_empty())
})
.any(|kw| kw.eq_ignore_ascii_case(b"upgrade"));
if !is_upgrade {
return Err(Error::from_static_msg("not a connection upgrade"));
}
let is_websocket = request
.get_header_fields("upgrade")
.flat_map(|field| {
field
.value()
.map(|v| v.as_ref())
.unwrap_or(b"")
.split(|&b| b == b',')
.map(|kw| kw.trim_ascii())
.filter(|kw| !kw.is_empty())
})
.any(|kw| kw.eq_ignore_ascii_case(b"websocket"));
if !is_websocket {
return Err(Error::from_static_msg("not a WebSocket upgrade"));
}
let version = request
.get_header_field_value("sec-websocket-version")
.ok_or_else(|| Error::from_static_msg("missing WS version"))?
.trim_ascii();
if version != b"13" {
return Err(Error::from_static_msg("unsupported WS version"));
}
let key = request
.get_header_field_value("sec-websocket-key")
.ok_or_else(|| Error::from_static_msg("missing WS key"))?
.trim_ascii();
let protocols = request
.get_header_fields("sec-websocket-protocol")
.flat_map(|field| {
field
.value()
.map(|v| v.as_ref())
.unwrap_or(b"")
.split(|&b| b == b',')
.map(|p| p.trim_ascii())
.filter(|p| !p.is_empty())
})
.map(str::from_utf8)
.filter_map(|res| res.ok())
.map(|s| s.to_string())
.collect::<Vec<_>>();
let res = Self {
accept: super::create_accept_token(key),
protocols,
};
Ok(res)
}
#[inline]
pub fn protocols(&self) -> &[String] {
&self.protocols
}
pub fn complete(
self,
protocol: Option<&str>,
input_buffer_capacity: usize,
) -> (FutureServer, OutgoingResponse) {
let is_valid_protocol = if let Some(protocol) = protocol {
self.protocols.iter().any(|p| p == protocol)
} else {
true
};
assert!(is_valid_protocol);
let mut builder = OutgoingResponse::builder()
.set_status(Status::SWITCHING_PROTOCOLS)
.add_header_field(("Connection", "upgrade"))
.add_header_field(("Upgrade", "websocket"))
.add_header_field(("Sec-WebSocket-Accept", self.accept));
if let Some(protocol) = protocol {
builder = builder.add_header_field(("Sec-WebSocket-Protocol", protocol.to_string()));
}
let (response, upgrade) = builder.upgrade();
let server = async move {
upgrade
.await
.map(|upgraded| WebSocket::new(upgraded, AgentRole::Server, input_buffer_capacity))
};
let future = FutureServer {
inner: Box::pin(server),
};
(future, response)
}
}
/// Future returned by [`ServerHandshake::complete`] that resolves to a
/// server-side [`WebSocket`] once the HTTP connection has been upgraded.
///
/// Resolves to the upgrade's I/O error if the upgrade itself fails.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FutureServer {
    /// The boxed upgrade-then-wrap future. Boxing hides the concrete async
    /// block type and keeps `FutureServer` itself `Unpin`.
    inner: Pin<Box<dyn Future<Output = io::Result<WebSocket>> + Send>>,
}
impl Future for FutureServer {
type Output = io::Result<WebSocket>;
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner.poll_unpin(cx)
}
}