use super::connection::{ConnectionHandler, WebSocket};
use super::frame::Message;
use super::pubsub::ChannelManager;
use super::WebSocketConfig;
use bytes::Bytes;
use http_body_util::Full;
use hyper::header::{
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
};
use hyper::{Request as HyperRequest, Response as HyperResponse, StatusCode};
use sha1::{Digest, Sha1};
use std::future::Future;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Fixed GUID defined by RFC 6455 for the opening handshake. It is hashed
/// together with the client's `Sec-WebSocket-Key` to produce the
/// `Sec-WebSocket-Accept` value (see `calculate_accept_key`).
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Builder that turns an incoming HTTP request into a WebSocket connection.
///
/// Configure it with the `with_*` methods, then call one of the upgrade
/// methods (`on_upgrade`, `on_upgrade_with_receiver`, `build`) to obtain the
/// HTTP response that must be sent back to the client.
///
/// `T` is the type of the per-connection user data handed to the `WebSocket`.
pub struct WebSocketUpgrade<T = ()> {
/// The HTTP request to validate and later hand to `hyper::upgrade::on`.
request: HyperRequest<hyper::body::Incoming>,
/// Per-connection user data; `None` until `with_data` is called.
data: Option<T>,
/// Extra `(name, value)` headers appended to the `101` handshake response.
headers: Vec<(String, String)>,
/// Sub-protocols set via `with_protocols`.
/// NOTE(review): this field is stored but never read anywhere in this file.
protocols: Vec<String>,
/// Connection configuration; shared (via `Arc`) with the handler and socket.
config: WebSocketConfig,
/// Channel manager passed to both the connection handler and the socket.
channel_manager: Arc<ChannelManager>,
}
impl<T> WebSocketUpgrade<T>
where
T: Send + 'static,
{
pub fn new(request: HyperRequest<hyper::body::Incoming>) -> Self {
Self {
request,
data: None,
headers: Vec::new(),
protocols: Vec::new(),
config: WebSocketConfig::default(),
channel_manager: Arc::new(ChannelManager::new()),
}
}
pub fn with_data(mut self, data: T) -> Self {
self.data = Some(data);
self
}
pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.headers.push((key.into(), value.into()));
self
}
pub fn with_protocols(mut self, protocols: Vec<String>) -> Self {
self.protocols = protocols;
self
}
pub fn with_config(mut self, config: WebSocketConfig) -> Self {
self.config = config;
self
}
pub fn with_channel_manager(mut self, channel_manager: Arc<ChannelManager>) -> Self {
self.channel_manager = channel_manager;
self
}
pub fn on_upgrade<F, Fut>(self, callback: F) -> HyperResponse<Full<Bytes>>
where
F: FnOnce(WebSocket<T>) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
T: Send + 'static,
{
if !is_valid_upgrade_request(&self.request) {
return HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Invalid WebSocket upgrade request")))
.unwrap();
}
let key = match self.request.headers().get(SEC_WEBSOCKET_KEY) {
Some(key) => key.to_str().unwrap_or(""),
None => {
return HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Missing Sec-WebSocket-Key header")))
.unwrap();
}
};
let accept_key = calculate_accept_key(key);
let mut response = HyperResponse::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(UPGRADE, "websocket")
.header(CONNECTION, "Upgrade")
.header(SEC_WEBSOCKET_ACCEPT, accept_key);
for (key, value) in self.headers {
response = response.header(key, value);
}
let response = response.body(Full::new(Bytes::new())).unwrap();
let data = self.data.expect("WebSocket data not set");
let channel_manager = self.channel_manager;
let config = Arc::new(self.config);
tokio::spawn(async move {
match hyper::upgrade::on(self.request).await {
Ok(upgraded) => {
let (handler, sender, mut incoming_rx, mut _drain_rx) =
ConnectionHandler::new(upgraded, channel_manager.clone(), config.clone());
let connection_id = uuid::Uuid::new_v4();
let remote_addr = None;
let ws = WebSocket::new(
data,
sender,
channel_manager,
connection_id,
remote_addr,
config.clone(),
);
let handler_task = tokio::spawn(async move {
if let Err(e) = handler.handle().await {
tracing::error!("WebSocket handler error: {}", e);
}
});
let callback_task = tokio::spawn(async move {
callback(ws).await;
while incoming_rx.recv().await.is_some() {
}
});
let _ = tokio::join!(handler_task, callback_task);
}
Err(e) => {
tracing::error!("WebSocket upgrade error: {}", e);
}
}
});
response
}
pub fn on_upgrade_with_receiver<F, Fut>(self, callback: F) -> HyperResponse<Full<Bytes>>
where
F: FnOnce(
WebSocket<T>,
mpsc::UnboundedReceiver<Message>,
mpsc::UnboundedReceiver<()>,
) -> Fut
+ Send
+ 'static,
Fut: Future<Output = ()> + Send + 'static,
T: Send + 'static,
{
if !is_valid_upgrade_request(&self.request) {
return HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Invalid WebSocket upgrade request")))
.unwrap();
}
let key = match self.request.headers().get(SEC_WEBSOCKET_KEY) {
Some(key) => key.to_str().unwrap_or(""),
None => {
return HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Missing Sec-WebSocket-Key header")))
.unwrap();
}
};
let accept_key = calculate_accept_key(key);
let mut response = HyperResponse::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(UPGRADE, "websocket")
.header(CONNECTION, "Upgrade")
.header(SEC_WEBSOCKET_ACCEPT, accept_key);
for (key, value) in self.headers {
response = response.header(key, value);
}
let response = response.body(Full::new(Bytes::new())).unwrap();
let data = self.data.expect("WebSocket data not set");
let channel_manager = self.channel_manager;
let config = Arc::new(self.config);
tokio::spawn(async move {
match hyper::upgrade::on(self.request).await {
Ok(upgraded) => {
let (handler, sender, incoming_rx, drain_rx) =
ConnectionHandler::new(upgraded, channel_manager.clone(), config.clone());
let connection_id = uuid::Uuid::new_v4();
let remote_addr = None;
let ws = WebSocket::new(
data,
sender,
channel_manager,
connection_id,
remote_addr,
config.clone(),
);
let handler_task = tokio::spawn(async move {
if let Err(e) = handler.handle().await {
tracing::error!("WebSocket handler error: {}", e);
}
});
let callback_task = tokio::spawn(async move {
callback(ws, incoming_rx, drain_rx).await;
});
let _ = tokio::join!(handler_task, callback_task);
}
Err(e) => {
tracing::error!("WebSocket upgrade error: {}", e);
}
}
});
response
}
pub fn build(self) -> HyperResponse<Full<Bytes>>
where
T: Default,
{
self.on_upgrade(|_ws| async {
})
}
}
/// Returns `true` if `req` satisfies the checks this module requires of a
/// WebSocket opening handshake:
///
/// * the method is `GET`,
/// * `Upgrade` is `websocket` (ASCII case-insensitive),
/// * `Connection` contains the `upgrade` token (ASCII case-insensitive; the
///   header may be a comma-separated list and may appear more than once),
/// * `Sec-WebSocket-Version` is `13`,
/// * `Sec-WebSocket-Key` is present.
fn is_valid_upgrade_request(req: &HyperRequest<hyper::body::Incoming>) -> bool {
    if req.method() != hyper::Method::GET {
        return false;
    }
    let headers = req.headers();

    let upgrade_is_websocket = headers
        .get(UPGRADE)
        .and_then(|value| value.to_str().ok())
        .map_or(false, |value| value.eq_ignore_ascii_case("websocket"));

    // Presence alone is not enough: `Connection: keep-alive` must not pass.
    let connection_has_upgrade = headers
        .get_all(CONNECTION)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("upgrade"));

    let version_is_13 = headers
        .get(SEC_WEBSOCKET_VERSION)
        .map_or(false, |value| value == "13");

    upgrade_is_websocket
        && connection_has_upgrade
        && version_is_13
        && headers.contains_key(SEC_WEBSOCKET_KEY)
}
/// Computes the `Sec-WebSocket-Accept` value for a client key.
///
/// The result is the standard Base64 encoding of the SHA-1 digest of `key`
/// followed by [`WEBSOCKET_GUID`].
fn calculate_accept_key(key: &str) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    let mut sha = Sha1::new();
    // Feeding the parts one after another hashes their concatenation.
    for part in [key, WEBSOCKET_GUID] {
        sha.update(part.as_bytes());
    }
    STANDARD.encode(sha.finalize())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the accept key against the sample handshake from RFC 6455.
    #[test]
    fn test_calculate_accept_key() {
        let client_nonce = "dGhlIHNhbXBsZSBub25jZQ==";
        assert_eq!(
            calculate_accept_key(client_nonce),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }
}