use actus_reply::ReplyData;
use futures_util::future::BoxFuture;
use http::{HeaderMap, HeaderValue, Method, header};
use hyper::upgrade::{OnUpgrade, Upgraded};
use hyper_util::rt::TokioIo;
use std::future::Future;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
use tokio_tungstenite::tungstenite::protocol::Role;
pub use tokio_tungstenite::tungstenite::Message;
/// A WebSocket stream running over an HTTP connection that hyper has upgraded.
pub type WebSocket = WebSocketStream<TokioIo<Upgraded>>;
/// Type-erased handler for an upgraded connection: a one-shot boxed closure that
/// takes the [`WebSocket`] and returns a boxed `'static` future driving the session.
pub(crate) type UpgradeTask = Box<dyn FnOnce(WebSocket) -> BoxFuture<'static, ()> + Send>;
/// Creates a reply that requests a WebSocket upgrade and carries `handler`.
///
/// The handler's concrete closure and future types are erased into an
/// [`UpgradeTask`], which is boxed into [`ReplyData::Upgrade`]. Because it is an
/// `FnOnce`, the handler can be invoked at most once, with the upgraded
/// [`WebSocket`]. The future it returns is expected to run the whole session.
pub fn upgrade<F, Fut>(handler: F) -> ReplyData
where
    F: FnOnce(WebSocket) -> Fut + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let task: UpgradeTask = Box::new(move |socket: WebSocket| {
        // Box the future so every handler shares one concrete return type.
        let session: BoxFuture<'static, ()> = Box::pin(handler(socket));
        session
    });
    ReplyData::Upgrade(Box::new(task))
}
/// Reports whether a request looks like a WebSocket opening handshake.
///
/// The request qualifies when all of the following hold:
/// - the method is `GET`,
/// - `Connection` lists the `upgrade` token,
/// - `Upgrade` lists the `websocket` token,
/// - a `Sec-WebSocket-Key` header is present, and
/// - `Sec-WebSocket-Version` is exactly `13`.
///
/// Token comparisons ignore ASCII case. List-valued headers are checked across
/// *every* field line with that name. HTTP allows a list to be split over
/// repeated headers (e.g. `Connection: keep-alive` followed by
/// `Connection: Upgrade`), which means the same as one comma-joined line.
pub(crate) fn is_upgrade_request(method: &Method, headers: &HeaderMap) -> bool {
    /// Returns `true` if any comma-separated token in any `name` field line
    /// equals `needle` (ASCII case-insensitive). Lines that are not visible
    /// ASCII (`to_str` fails) are skipped.
    fn list_contains(headers: &HeaderMap, name: header::HeaderName, needle: &str) -> bool {
        headers
            .get_all(name)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .any(|token| token.trim().eq_ignore_ascii_case(needle))
    }

    *method == Method::GET
        && list_contains(headers, header::CONNECTION, "upgrade")
        && list_contains(headers, header::UPGRADE, "websocket")
        && headers.contains_key(header::SEC_WEBSOCKET_KEY)
        && headers
            .get(header::SEC_WEBSOCKET_VERSION)
            .and_then(|v| v.to_str().ok())
            == Some("13")
}
/// Computes the `Sec-WebSocket-Accept` response value from the request's
/// `Sec-WebSocket-Key`, using tungstenite's `derive_accept_key`.
///
/// Returns `None` if the request has no `Sec-WebSocket-Key` header, or if the
/// derived string cannot be turned into a [`HeaderValue`].
pub(crate) fn accept_key(request_headers: &HeaderMap) -> Option<HeaderValue> {
    request_headers
        .get(header::SEC_WEBSOCKET_KEY)
        .map(|client_key| derive_accept_key(client_key.as_bytes()))
        .and_then(|accept| HeaderValue::from_str(&accept).ok())
}
/// Waits for hyper to finish upgrading the connection, then runs `task` on it.
///
/// The upgraded I/O is wrapped in a [`WebSocket`] acting in the server role,
/// and `task` is awaited to completion. If the upgrade fails, the error is
/// logged at `warn` level and `task` is dropped without being called.
pub(crate) async fn run_upgrade(on_upgrade: OnUpgrade, task: UpgradeTask) {
    let upgraded = match on_upgrade.await {
        Ok(io) => io,
        Err(err) => {
            tracing::warn!("websocket upgrade failed: {}", err);
            return;
        }
    };
    // The handshake response was already sent over HTTP, so the raw socket is
    // adopted as-is. `None` means no custom tungstenite configuration.
    let io = TokioIo::new(upgraded);
    let socket = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
    task(socket).await;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HeaderMap` from `(name, value)` pairs. Entries are `insert`ed,
    /// so a repeated name keeps only its last value.
    fn headers(pairs: &[(header::HeaderName, &str)]) -> HeaderMap {
        pairs.iter().fold(HeaderMap::new(), |mut map, (name, value)| {
            map.insert(name.clone(), HeaderValue::from_str(value).unwrap());
            map
        })
    }

    #[test]
    fn recognizes_a_valid_handshake() {
        let handshake = headers(&[
            (header::CONNECTION, "keep-alive, Upgrade"),
            (header::UPGRADE, "websocket"),
            (header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ=="),
            (header::SEC_WEBSOCKET_VERSION, "13"),
        ]);
        assert!(is_upgrade_request(&Method::GET, &handshake));

        // Same headers, but the method is not GET.
        assert!(!is_upgrade_request(&Method::POST, &handshake));

        // GET, but only the `Upgrade` header is present.
        let upgrade_only = headers(&[(header::UPGRADE, "websocket")]);
        assert!(!is_upgrade_request(&Method::GET, &upgrade_only));
    }

    #[test]
    fn derives_the_rfc6455_accept_key() {
        // Sample key/accept pair from RFC 6455 section 1.3.
        let with_key = headers(&[(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")]);
        let accept = accept_key(&with_key).unwrap();
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");

        // No key header means no accept value.
        assert!(accept_key(&HeaderMap::new()).is_none());
    }
}