use std::future::Future;
use std::sync::Arc;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use http::Response;
use hyper::Request;
use hyper::body::Incoming;
use hyper::upgrade::Upgraded;
use hyper_tungstenite::HyperWebsocket;
use hyper_util::rt::TokioIo;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite;
use crate::error::Error;
use crate::extract::FromRequest;
use crate::extract::PathParams;
use crate::response::BoxBody;
use crate::state::AppState;
/// Concrete WebSocket stream type used throughout this module: a
/// `tokio_tungstenite` stream running over a hyper `Upgraded` connection
/// wrapped in `TokioIo`.
type WsStream = WebSocketStream<TokioIo<Upgraded>>;
/// Extractor representing an accepted-but-not-yet-completed WebSocket upgrade.
///
/// Both fields are the pair returned by `hyper_tungstenite::upgrade`; call
/// `on_upgrade` to obtain the response and register the connection handler.
pub struct WebSocketUpgrade {
// HTTP response produced by `hyper_tungstenite::upgrade`; returned to the
// caller by `on_upgrade`.
response: Response<BoxBody>,
// Pending WebSocket; awaited inside the task spawned by `on_upgrade`.
websocket: HyperWebsocket,
}
impl FromRequest for WebSocketUpgrade {
async fn from_request(
req: Request<Incoming>,
_params: &PathParams,
_state: &Arc<AppState>,
) -> Result<Self, Error> {
let mut req = req;
if !hyper_tungstenite::is_upgrade_request(&req) {
return Err(Error::bad_request("Not a WebSocket upgrade request"));
}
let (response, websocket) = hyper_tungstenite::upgrade(&mut req, None)
.map_err(|e| Error::internal(format!("WebSocket upgrade failed: {e}")))?;
Ok(Self {
response,
websocket,
})
}
}
impl WebSocketUpgrade {
pub fn on_upgrade<F, Fut>(self, callback: F) -> Response<BoxBody>
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let websocket = self.websocket;
tokio::spawn(async move {
match websocket.await {
Ok(stream) => callback(WebSocket { inner: stream }).await,
Err(e) => tracing::error!("WebSocket upgrade failed: {e}"),
}
});
self.response
}
}
/// An established WebSocket connection that can both send and receive
/// [`Message`]s; use `split` to obtain independent sender/receiver halves.
pub struct WebSocket {
// Underlying tungstenite stream that all operations delegate to.
inner: WsStream,
}
impl WebSocket {
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.inner.next().await.map(|r| {
r.map(Message::from_tungstenite)
.map_err(|e| Error::internal(format!("WebSocket recv error: {e}")))
})
}
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.inner
.send(msg.into_tungstenite())
.await
.map_err(|e| Error::internal(format!("WebSocket send error: {e}")))
}
pub async fn close(mut self) -> Result<(), Error> {
self.inner
.send(tungstenite::Message::Close(None))
.await
.map_err(|e| Error::internal(format!("WebSocket close error: {e}")))
}
pub fn split(self) -> (WsSender, WsReceiver) {
let (sink, stream) = self.inner.split();
(WsSender { inner: sink }, WsReceiver { inner: stream })
}
}
/// Sending half of a [`WebSocket`], produced by `WebSocket::split`.
pub struct WsSender {
// Sink half of the split underlying stream.
inner: SplitSink<WsStream, tungstenite::Message>,
}
impl WsSender {
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.inner
.send(msg.into_tungstenite())
.await
.map_err(|e| Error::internal(format!("WebSocket send error: {e}")))
}
}
/// Receiving half of a [`WebSocket`], produced by `WebSocket::split`.
pub struct WsReceiver {
// Stream half of the split underlying stream.
inner: SplitStream<WsStream>,
}
impl WsReceiver {
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.inner.next().await.map(|r| {
r.map(Message::from_tungstenite)
.map_err(|e| Error::internal(format!("WebSocket recv error: {e}")))
})
}
}
/// A WebSocket message with owned payloads, converted to and from
/// `tungstenite::Message` at the connection boundary.
#[derive(Debug, Clone)]
pub enum Message {
/// Text payload.
Text(String),
/// Binary payload.
Binary(Vec<u8>),
/// Ping message with its payload bytes.
Ping(Vec<u8>),
/// Pong message with its payload bytes.
Pong(Vec<u8>),
/// Close message, optionally carrying a [`CloseFrame`].
Close(Option<CloseFrame>),
}
/// Code and reason carried by a [`Message::Close`].
#[derive(Debug, Clone)]
pub struct CloseFrame {
/// Numeric close code, converted to/from tungstenite's close code type.
pub code: u16,
/// Textual reason accompanying the close code.
pub reason: String,
}
impl Message {
    /// Returns `true` for [`Message::Text`].
    pub fn is_text(&self) -> bool {
        match self {
            Self::Text(_) => true,
            _ => false,
        }
    }

    /// Returns `true` for [`Message::Binary`].
    pub fn is_binary(&self) -> bool {
        match self {
            Self::Binary(_) => true,
            _ => false,
        }
    }

    /// Returns `true` for [`Message::Close`].
    pub fn is_close(&self) -> bool {
        match self {
            Self::Close(_) => true,
            _ => false,
        }
    }

    /// Returns `true` for [`Message::Ping`].
    pub fn is_ping(&self) -> bool {
        match self {
            Self::Ping(_) => true,
            _ => false,
        }
    }

    /// Returns `true` for [`Message::Pong`].
    pub fn is_pong(&self) -> bool {
        match self {
            Self::Pong(_) => true,
            _ => false,
        }
    }

    /// Borrows the payload of a [`Message::Text`]; `None` for every other
    /// variant.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Borrows the payload of a [`Message::Binary`]; `None` for every other
    /// variant, including `Ping` and `Pong`.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        if let Self::Binary(bytes) = self {
            Some(bytes.as_slice())
        } else {
            None
        }
    }

    /// Converts a tungstenite message into this crate's [`Message`], copying
    /// payloads into owned `String`/`Vec<u8>` values.
    fn from_tungstenite(msg: tungstenite::Message) -> Self {
        match msg {
            tungstenite::Message::Text(s) => Self::Text(s.to_string()),
            tungstenite::Message::Binary(b) => Self::Binary(b.to_vec()),
            tungstenite::Message::Ping(b) => Self::Ping(b.to_vec()),
            tungstenite::Message::Pong(b) => Self::Pong(b.to_vec()),
            tungstenite::Message::Close(None) => Self::Close(None),
            tungstenite::Message::Close(Some(f)) => Self::Close(Some(CloseFrame {
                code: f.code.into(),
                reason: f.reason.to_string(),
            })),
            // NOTE(review): a raw frame is surfaced as an empty `Ping` and its
            // contents are dropped — confirm this placeholder is intended.
            tungstenite::Message::Frame(_) => Self::Ping(Vec::new()),
        }
    }

    /// Converts this message into the tungstenite representation, moving the
    /// payload into the corresponding tungstenite variant.
    pub(crate) fn into_tungstenite(self) -> tungstenite::Message {
        match self {
            Self::Text(s) => tungstenite::Message::Text(s.into()),
            Self::Binary(b) => tungstenite::Message::Binary(b.into()),
            Self::Ping(b) => tungstenite::Message::Ping(b.into()),
            Self::Pong(b) => tungstenite::Message::Pong(b.into()),
            Self::Close(None) => tungstenite::Message::Close(None),
            Self::Close(Some(f)) => {
                tungstenite::Message::Close(Some(tungstenite::protocol::CloseFrame {
                    code: f.code.into(),
                    reason: f.reason.into(),
                }))
            }
        }
    }
}
impl From<String> for Message {
fn from(s: String) -> Self {
Self::Text(s)
}
}
impl From<&str> for Message {
fn from(s: &str) -> Self {
Self::Text(s.to_owned())
}
}
impl From<Vec<u8>> for Message {
fn from(b: Vec<u8>) -> Self {
Self::Binary(b)
}
}