use axum::extract::ws::WebSocket;
use futures_util::stream::SplitStream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_stream::Stream;
/// Wrapper around the receiving half (`SplitStream`) of an axum [`WebSocket`].
///
/// Its `Stream` implementation yields the payloads of binary messages as
/// `Vec<u8>` and reports close frames, other message kinds and transport
/// failures as `axum::Error`.
#[derive(Debug)]
pub struct AxumStream(SplitStream<WebSocket>);
impl From<SplitStream<WebSocket>> for AxumStream {
fn from(stream: SplitStream<WebSocket>) -> Self {
AxumStream(stream)
}
}
/// Unwraps an [`AxumStream`] back into the receiving half of the socket.
///
/// Implementing `From` rather than `Into` is the idiomatic direction: the
/// standard library's blanket `impl<T, U: From<T>> Into<U> for T` still
/// provides `Into<SplitStream<WebSocket>> for AxumStream`, so existing
/// `stream.into()` callers keep working, and `SplitStream::from(stream)`
/// now works as well.
impl From<AxumStream> for SplitStream<WebSocket> {
    fn from(stream: AxumStream) -> Self {
        stream.0
    }
}
impl Stream for AxumStream {
type Item = Result<Vec<u8>, axum::Error>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.0).poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(res)) => match res {
Ok(axum::extract::ws::Message::Binary(bin)) => {
Poll::Ready(Some(Ok(bin)))
}
Ok(axum::extract::ws::Message::Close(Some(
axum::extract::ws::CloseFrame { code, reason },
))) => Poll::Ready(Some(Err(axum::Error::new(
std::io::Error::new(
std::io::ErrorKind::Other,
format!("Client Disconnected {:?} {:?}", code, {
if reason.is_empty() {
None
} else {
Some(reason)
}
}),
),
)))),
Ok(_) => Poll::Ready(Some(Err(axum::Error::new(
std::io::Error::new(
std::io::ErrorKind::Other,
format!("non-binary message received {:?}", res),
),
)))),
Err(e) => Poll::Ready(Some(Err(e))),
},
}
}
}