use crate::upgrade::{validate_upgrade_request, WebSocketUpgrade};
use hyper::upgrade::OnUpgrade;
use rustapi_core::{ApiError, FromRequest, Request, Result};
use rustapi_openapi::{Operation, OperationModifier};
/// Request extractor holding the pieces of a WebSocket handshake.
///
/// Built by the [`FromRequest`] implementation and turned into a
/// [`WebSocketUpgrade`] by [`WebSocket::on_upgrade`].
pub struct WebSocket {
    /// Key string produced by `validate_upgrade_request` for this request.
    sec_key: String,
    /// Subprotocol names parsed from the comma-separated
    /// `Sec-WebSocket-Protocol` header, in the order the client listed them.
    protocols: Vec<String>,
    /// Raw `Sec-WebSocket-Extensions` header value, if one was readable.
    extensions: Option<String>,
    /// Upgrade handle taken out of the request extensions, if one was present.
    on_upgrade: Option<OnUpgrade>,
}
impl WebSocket {
    /// Consumes the extractor and builds the [`WebSocketUpgrade`] for it.
    ///
    /// The key, extensions and upgrade handle captured from the request are
    /// handed to `WebSocketUpgrade::new`. When the client offered at least one
    /// subprotocol, the first offered one is set on the upgrade via `protocol`.
    /// `callback` is then registered through `WebSocketUpgrade::on_upgrade`.
    // NOTE(review): `callback` is presumably run with the established stream
    // once the handshake completes — confirm against `WebSocketUpgrade`.
    pub fn on_upgrade<F, Fut>(mut self, callback: F) -> WebSocketUpgrade
    where
        F: FnOnce(crate::WebSocketStream) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = ()> + Send + 'static,
    {
        let pending = self.on_upgrade.take();
        let builder = WebSocketUpgrade::new(self.sec_key, self.extensions, pending);

        // Default to the first subprotocol the client listed, if any.
        let builder = match self.protocols.first() {
            Some(first) => builder.protocol(first),
            None => builder,
        };

        builder.on_upgrade(callback)
    }

    /// Returns the subprotocols the client requested, in header order.
    pub fn protocols(&self) -> &[String] {
        self.protocols.as_slice()
    }

    /// Returns `true` if `protocol` exactly matches one of the requested
    /// subprotocols (the comparison is case-sensitive).
    pub fn has_protocol(&self, protocol: &str) -> bool {
        self.protocols
            .iter()
            .map(String::as_str)
            .any(|offered| offered == protocol)
    }
}
impl FromRequest for WebSocket {
    /// Extracts the WebSocket handshake data from `req`.
    ///
    /// Steps, in order:
    /// 1. The method and headers are checked by `validate_upgrade_request`,
    ///    which yields the key stored as `sec_key`.
    /// 2. `Sec-WebSocket-Protocol` is split on commas; each entry is trimmed
    ///    and empty entries are dropped.
    /// 3. `Sec-WebSocket-Extensions` is kept as the raw header string.
    /// 4. The `OnUpgrade` handle is removed from the request extensions.
    /// 5. Any remaining request body stream is collected and discarded.
    ///
    /// A header whose value cannot be read as a string (`to_str` fails) is
    /// treated as absent.
    ///
    /// # Errors
    ///
    /// Returns the `ApiError` converted from the validation error when
    /// `validate_upgrade_request` rejects the request.
    async fn from_request(req: &mut Request) -> Result<Self> {
        let headers = req.headers();
        let method = req.method();

        let sec_key = validate_upgrade_request(method, headers)
            .map_err(ApiError::from)?
            .to_string();

        // Empty list elements (from an empty value, a trailing comma or
        // `a,,b`) are skipped so that `""` is never stored as a subprotocol
        // and later picked as the default by `on_upgrade`.
        let protocols: Vec<String> = headers
            .get("Sec-WebSocket-Protocol")
            .and_then(|v| v.to_str().ok())
            .map(|s| {
                s.split(',')
                    .map(str::trim)
                    .filter(|p| !p.is_empty())
                    .map(str::to_owned)
                    .collect()
            })
            .unwrap_or_default();

        let extensions = headers
            .get("Sec-WebSocket-Extensions")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());

        let on_upgrade = req.extensions_mut().remove::<OnUpgrade>();

        // Best-effort drain: the result is deliberately ignored.
        // NOTE(review): presumably the body must be consumed before the
        // connection can be upgraded — confirm against the server internals.
        if let Some(stream) = req.take_stream() {
            use http_body_util::BodyExt;
            let _ = stream.collect().await;
        }

        Ok(Self {
            sec_key,
            protocols,
            extensions,
            on_upgrade,
        })
    }
}
impl OperationModifier for WebSocket {
    /// Intentionally a no-op: this extractor leaves the OpenAPI operation
    /// untouched.
    fn update_operation(_op: &mut Operation) {}
}