use std::future::Future;
use http::StatusCode;
use hyper::upgrade::OnUpgrade;
use tokio_tungstenite::tungstenite::protocol::Role;
use tokio_tungstenite::WebSocketStream;
use crate::body::{empty_body, BoxBody};
use crate::extract::FromRequestParts;
/// Extractor for a WebSocket upgrade request.
///
/// Produced by the [`FromRequestParts`] implementation after the request's
/// `Connection`, `Upgrade` and `Sec-WebSocket-Key` headers have been validated.
/// Call [`WebSocketUpgrade::on_upgrade`] (or `on_upgrade_typed`) to build the
/// handshake response and run a handler on the upgraded connection.
pub struct WebSocketUpgrade {
    /// Raw value of the client's `Sec-WebSocket-Key` header; hashed to produce
    /// the `Sec-WebSocket-Accept` response header.
    sec_websocket_key: String,
    /// hyper's upgrade future, taken from the request extensions; awaited by
    /// the spawned handler task to obtain the upgraded connection.
    on_upgrade: OnUpgrade,
}
impl FromRequestParts for WebSocketUpgrade {
    type Error = (StatusCode, String);

    /// Validates that `parts` describe a WebSocket upgrade request and captures
    /// the state needed to complete the handshake later.
    ///
    /// The request must carry:
    /// - a `Connection` header whose token list includes `upgrade`,
    /// - an `Upgrade` header whose token list includes `websocket`,
    /// - a `Sec-WebSocket-Key` header with a visible-ASCII value,
    /// - an [`OnUpgrade`] entry in the request extensions.
    ///
    /// `Connection` and `Upgrade` are comma-separated token lists and may appear
    /// on several header lines, so every occurrence is split on `,` and each
    /// trimmed token is compared ASCII-case-insensitively. A plain substring or
    /// first-line-only check would misclassify e.g. split `Connection` lines or
    /// `Upgrade: websocket, h2c`.
    ///
    /// # Errors
    ///
    /// - `400 Bad Request` if the `upgrade`/`websocket` tokens are missing, or if
    ///   `Sec-WebSocket-Key` is absent or not representable as visible ASCII.
    /// - `500 Internal Server Error` if no [`OnUpgrade`] is present in the
    ///   request extensions.
    fn from_request_parts(parts: &http::request::Parts) -> Result<Self, Self::Error> {
        // True if any occurrence of header `name` lists `token` among its
        // comma-separated elements (ASCII case-insensitive, whitespace-trimmed).
        fn has_token(
            headers: &http::HeaderMap,
            name: http::header::HeaderName,
            token: &str,
        ) -> bool {
            headers
                .get_all(name)
                .iter()
                .filter_map(|value| value.to_str().ok())
                .flat_map(|value| value.split(','))
                .any(|item| item.trim().eq_ignore_ascii_case(token))
        }

        let headers = &parts.headers;

        if !has_token(headers, http::header::CONNECTION, "upgrade")
            || !has_token(headers, http::header::UPGRADE, "websocket")
        {
            return Err((
                StatusCode::BAD_REQUEST,
                "not a valid WebSocket upgrade request".to_string(),
            ));
        }

        let key = headers
            .get(http::header::SEC_WEBSOCKET_KEY)
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| {
                (
                    StatusCode::BAD_REQUEST,
                    "missing Sec-WebSocket-Key header".to_string(),
                )
            })?
            .to_string();

        let on_upgrade = parts
            .extensions
            .get::<OnUpgrade>()
            .cloned()
            .ok_or_else(|| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "upgrade not available — is this a hyper connection?".to_string(),
                )
            })?;

        Ok(WebSocketUpgrade {
            on_upgrade,
            sec_websocket_key: key,
        })
    }
}
impl WebSocketUpgrade {
    /// Like [`WebSocketUpgrade::on_upgrade`], but hands `callback` a session-typed
    /// [`crate::typed_ws::TypedWebSocket<S>`] instead of the raw stream.
    ///
    /// The raw stream is wrapped with `TypedWebSocket::new` before `callback`
    /// runs. The returned response is exactly the one `on_upgrade` builds.
    pub fn on_upgrade_typed<S, F, Fut>(self, callback: F) -> http::Response<BoxBody>
    where
        S: typeway_core::session::SessionType + std::marker::Send + 'static,
        F: FnOnce(crate::typed_ws::TypedWebSocket<S>) -> Fut + std::marker::Send + 'static,
        Fut: Future<Output = ()> + std::marker::Send + 'static,
    {
        self.on_upgrade(move |stream| async move {
            callback(crate::typed_ws::TypedWebSocket::new(stream)).await
        })
    }

    /// Builds the `101 Switching Protocols` handshake response and spawns a
    /// Tokio task that runs `callback` on the upgraded connection.
    ///
    /// The spawned task awaits the stored [`OnUpgrade`] future. On success, the
    /// upgraded I/O is wrapped in a server-role [`WebSocketStream`] (default
    /// config) and passed to `callback`. On failure, the error is printed to
    /// stderr and `callback` is never invoked.
    ///
    /// The response carries `Connection: upgrade`, `Upgrade: websocket` and a
    /// `Sec-WebSocket-Accept` value derived from the client's key. Presumably the
    /// caller must return this response for hyper to complete the upgrade;
    /// confirm against the hyper upgrade docs.
    pub fn on_upgrade<F, Fut>(self, callback: F) -> http::Response<BoxBody>
    where
        F: FnOnce(WebSocketStream<hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>>) -> Fut
            + Send
            + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let Self {
            on_upgrade,
            sec_websocket_key,
        } = self;
        let accept = tungstenite_accept_key(&sec_websocket_key);

        tokio::spawn(async move {
            let upgraded = match on_upgrade.await {
                Ok(upgraded) => upgraded,
                Err(e) => {
                    eprintln!("WebSocket upgrade failed: {e}");
                    return;
                }
            };
            let stream = WebSocketStream::from_raw_socket(
                hyper_util::rt::TokioIo::new(upgraded),
                Role::Server,
                None,
            )
            .await;
            callback(stream).await;
        });

        let mut response = http::Response::new(empty_body());
        *response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;

        let headers = response.headers_mut();
        headers.insert(
            http::header::CONNECTION,
            http::HeaderValue::from_static("upgrade"),
        );
        headers.insert(
            http::header::UPGRADE,
            http::HeaderValue::from_static("websocket"),
        );
        // Base64 output is plain ASCII, so this conversion is not expected to fail;
        // if it ever did, the header is simply omitted (as before).
        if let Ok(value) = http::HeaderValue::from_str(&accept) {
            headers.insert("sec-websocket-accept", value);
        }

        response
    }
}
fn tungstenite_accept_key(key: &str) -> String {
let mut hasher = sha1_smol::Sha1::new();
hasher.update(key.as_bytes());
hasher.update(b"258EAFA5-E914-47DA-95CA-5AB5DC11CE56");
base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
hasher.digest().bytes(),
)
}