use crate::{Conn, WebSocketConfig, WebSocketConn};
use std::{
borrow::Cow,
error::Error,
fmt::{self, Display},
ops::{Deref, DerefMut},
};
use trillium_http::{
KnownHeaderName::{
Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion,
Upgrade as UpgradeHeader,
},
Method, Status, Upgrade, Version,
};
pub use trillium_websockets::Message;
use trillium_websockets::{Role, websocket_accept_hash, websocket_key};
impl Conn {
    /// Adds the HTTP/1.1 websocket upgrade request headers: `Upgrade: websocket`,
    /// `Connection: upgrade`, `Sec-WebSocket-Version: 13` and a generated `Sec-WebSocket-Key`.
    ///
    /// Uses `try_insert`, so a header the caller already set is not overwritten.
    fn set_websocket_upgrade_headers_h1(&mut self) {
        let headers = self.request_headers_mut();
        headers.try_insert(UpgradeHeader, "websocket");
        headers.try_insert(Connection, "upgrade");
        headers.try_insert(SecWebsocketVersion, "13");
        headers.try_insert(SecWebsocketKey, websocket_key());
    }

    /// Executes this conn as a websocket upgrade request using the default
    /// [`WebSocketConfig`].
    ///
    /// Equivalent to `self.into_websocket_with_config(WebSocketConfig::default())`.
    ///
    /// # Errors
    ///
    /// Same as [`Conn::into_websocket_with_config`].
    pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
        self.into_websocket_with_config(WebSocketConfig::default()).await
    }

    /// Executes this conn as a websocket upgrade request using `config`.
    ///
    /// HTTP/2 and HTTP/3 conns are upgraded with extended CONNECT. Every other HTTP version
    /// uses the HTTP/1.1 `Upgrade: websocket` handshake.
    ///
    /// The conn must not have been awaited beforehand. Build it, then await this method.
    ///
    /// # Errors
    ///
    /// Returns a [`WebSocketUpgradeError`] in any of these cases:
    /// - the conn was already executed;
    /// - executing the request fails;
    /// - the peer does not support extended CONNECT;
    /// - the response status is not the one the handshake requires;
    /// - (HTTP/1.1 only) `Sec-WebSocket-Accept` is missing or does not match.
    ///
    /// The error still owns the conn, which can be recovered with `Conn::from`.
    pub async fn into_websocket_with_config(
        self,
        config: WebSocketConfig,
    ) -> Result<WebSocketConn, WebSocketUpgradeError> {
        // The handshake has to be part of the request itself, so a conn that already carries
        // a response status cannot be upgraded after the fact.
        if self.status().is_some() {
            return Err(WebSocketUpgradeError::new(self, ErrorKind::AlreadyExecuted));
        }

        match self.http_version() {
            Version::Http2 | Version::Http3 => self.into_websocket_extended_connect(config).await,
            _ => self.into_websocket_h1(config).await,
        }
    }

    /// HTTP/1.1 handshake (RFC 6455).
    ///
    /// Sends the upgrade headers and requires `101 Switching Protocols`. It then checks that
    /// `Sec-WebSocket-Accept` equals the hash of the `Sec-WebSocket-Key` that was sent.
    async fn into_websocket_h1(
        mut self,
        config: WebSocketConfig,
    ) -> Result<WebSocketConn, WebSocketUpgradeError> {
        self.set_websocket_upgrade_headers_h1();
        if let Err(e) = (&mut self).await {
            return Err(WebSocketUpgradeError::new(self, e.into()));
        }

        // Invariant assumed here: a conn that executed without error has a response status.
        let status = self.status().expect("Response did not include status");
        if status != Status::SwitchingProtocols {
            return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
        }

        // Always present: `set_websocket_upgrade_headers_h1` inserts one when the caller did not.
        let key = self
            .request_headers()
            .get_str(SecWebsocketKey)
            .expect("Request did not include Sec-WebSocket-Key");
        let accept_key = websocket_accept_hash(key);
        if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
            return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
        }

        Ok(self.finish_websocket_upgrade(config).await)
    }

    /// HTTP/2 and HTTP/3 handshake using extended CONNECT (RFC 8441 / RFC 9220).
    ///
    /// Sends `CONNECT` with protocol `websocket` and requires `200 OK`.
    async fn into_websocket_extended_connect(
        mut self,
        config: WebSocketConfig,
    ) -> Result<WebSocketConn, WebSocketUpgradeError> {
        // Extended CONNECT has no Sec-WebSocket-Key / Sec-WebSocket-Accept exchange; only the
        // version header is sent.
        self.request_headers_mut()
            .try_insert(SecWebsocketVersion, "13");
        self.set_method(Method::Connect);
        self.protocol = Some(Cow::Borrowed("websocket"));

        if let Err(e) = (&mut self).await {
            let kind = match e {
                trillium_http::Error::ExtendedConnectUnsupported => {
                    ErrorKind::ExtendedConnectUnsupported
                }
                other => other.into(),
            };
            return Err(WebSocketUpgradeError::new(self, kind));
        }

        // Invariant assumed here: a conn that executed without error has a response status.
        let status = self.status().expect("Response did not include status");
        if status != Status::Ok {
            return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
        }

        Ok(self.finish_websocket_upgrade(config).await)
    }

    /// Final step shared by both handshake paths.
    ///
    /// Turns the successfully upgraded conn into a client-role [`WebSocketConn`] and carries
    /// over the peer IP.
    async fn finish_websocket_upgrade(self, config: WebSocketConfig) -> WebSocketConn {
        // Capture the peer address before `Upgrade::from` consumes `self`.
        let peer_ip = self.peer_addr().map(|addr| addr.ip());
        let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
        conn.set_peer_ip(peer_ip);
        conn
    }
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum ErrorKind {
#[error(transparent)]
Http(#[from] trillium_http::Error),
#[error("Unexpected response status {0} for websocket upgrade")]
Status(Status),
#[error("Response Sec-WebSocket-Accept was missing or invalid")]
InvalidAccept,
#[error(
"Conn::into_websocket called after execution — build the conn and await into_websocket \
instead of awaiting the conn separately"
)]
AlreadyExecuted,
#[error("peer does not support extended CONNECT")]
ExtendedConnectUnsupported,
}
/// Error returned when upgrading a conn to a websocket fails.
///
/// It carries the [`ErrorKind`] describing the failure and the conn itself. The conn can be
/// inspected through `Deref`/`DerefMut` (for example to read the response status or headers),
/// or recovered with `Conn::from`.
#[derive(Debug)]
pub struct WebSocketUpgradeError {
    /// Why the upgrade failed.
    pub kind: ErrorKind,
    // The conn that failed to upgrade. Boxed, presumably to keep the error type small;
    // TODO confirm.
    conn: Box<Conn>,
}
impl WebSocketUpgradeError {
    /// Pairs a conn whose upgrade failed with the reason it failed.
    fn new(conn: Conn, kind: ErrorKind) -> Self {
        Self {
            kind,
            conn: Box::new(conn),
        }
    }
}
impl From<WebSocketUpgradeError> for Conn {
fn from(value: WebSocketUpgradeError) -> Self {
*value.conn
}
}
// Read access to the conn that failed to upgrade, e.g. to inspect its response.
impl Deref for WebSocketUpgradeError {
    type Target = Conn;

    fn deref(&self) -> &Conn {
        &*self.conn
    }
}
impl DerefMut for WebSocketUpgradeError {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.conn
}
}
impl Error for WebSocketUpgradeError {}
// Renders exactly the message of the underlying `ErrorKind`, forwarding the formatter as-is.
impl Display for WebSocketUpgradeError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        Display::fmt(&self.kind, formatter)
    }
}