use axum::extract::ws::{Message, WebSocket};
use futures_util::stream::SplitSink;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Newtype around the sink half of a split axum [`WebSocket`].
///
/// Wrapping the foreign [`SplitSink`] in a local type allows this module to
/// implement foreign traits (such as `futures_util::Sink<Vec<u8>>`) for it.
/// `#[repr(transparent)]` guarantees the wrapper has exactly the same memory
/// layout as the wrapped [`SplitSink`].
#[derive(Debug)]
#[repr(transparent)]
pub struct AxumSink(SplitSink<WebSocket, Message>);
impl From<SplitSink<WebSocket, Message>> for AxumSink {
fn from(sink: SplitSink<WebSocket, Message>) -> Self {
AxumSink(sink)
}
}
/// Unwraps an [`AxumSink`] back into the underlying [`SplitSink`].
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// `impl<T, U: From<T>> Into<U> for T` still provides `axum_sink.into()`, and
/// callers additionally gain `SplitSink::from(axum_sink)`.
impl From<AxumSink> for SplitSink<WebSocket, Message> {
    fn from(sink: AxumSink) -> Self {
        sink.0
    }
}
impl futures_util::Sink<Vec<u8>> for AxumSink {
type Error = y_sync::sync::Error;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
match Pin::new(&mut self.0).poll_ready(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
Poll::Ready(Err(y_sync::sync::Error::Other(e.into())))
}
Poll::Ready(_) => Poll::Ready(Ok(())),
}
}
fn start_send(
mut self: Pin<&mut Self>,
item: Vec<u8>,
) -> Result<(), Self::Error> {
if let Err(e) = Pin::new(&mut self.0)
.start_send(axum::extract::ws::Message::Binary(item))
{
Err(y_sync::sync::Error::Other(e.into()))
} else {
Ok(())
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
match Pin::new(&mut self.0).poll_flush(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
Poll::Ready(Err(y_sync::sync::Error::Other(e.into())))
}
Poll::Ready(_) => Poll::Ready(Ok(())),
}
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
match Pin::new(&mut self.0).poll_close(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
Poll::Ready(Err(y_sync::sync::Error::Other(e.into())))
}
Poll::Ready(_) => Poll::Ready(Ok(())),
}
}
}