#![warn(
clippy::all,
clippy::dbg_macro,
clippy::todo,
clippy::empty_enum,
clippy::enum_glob_use,
clippy::mem_forget,
clippy::unused_self,
clippy::filter_map_next,
clippy::needless_continue,
clippy::needless_borrow,
clippy::match_wildcard_for_single_variants,
clippy::if_let_mutex,
clippy::mismatched_target_os,
clippy::await_holding_lock,
clippy::match_on_vec_items,
clippy::imprecise_flops,
clippy::suboptimal_flops,
clippy::lossy_float_literal,
clippy::rest_pat_in_fully_bound_structs,
clippy::fn_params_excessive_bools,
clippy::exit,
clippy::inefficient_to_string,
clippy::linkedlist,
clippy::macro_use_imports,
clippy::option_option,
clippy::verbose_file_reads,
clippy::unnested_or_patterns,
rust_2018_idioms,
future_incompatible,
nonstandard_style,
missing_debug_implementations,
missing_docs
)]
#![deny(unreachable_pub, private_in_public)]
#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
#![forbid(unsafe_code)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(test, allow(clippy::float_cmp))]
use axum::{
async_trait,
extract::{ws, FromRequest, RequestParts},
response::IntoResponse,
};
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use std::{
error::Error as StdError,
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
// Declares a public generic struct `$name<S, R, C>` twice, gated on the `json`
// cargo feature:
//
// * with `json` enabled, the third type parameter defaults to `JsonCodec`;
// * with `json` disabled, the third type parameter has no default.
//
// Any attributes written on the struct are forwarded to both definitions.
// Fields must be plain `name: Type,` entries: the matcher below does not accept
// attributes on individual fields.
#[allow(unused_macros)]
macro_rules! with_and_without_json {
(
$(#[$m:meta])*
pub struct $name:ident<S, R, C = JsonCodec> {
$(
$ident:ident : $ty:ty,
)*
}
) => {
// Variant compiled when the `json` feature is on: `C` defaults to `JsonCodec`.
$(#[$m])*
#[cfg(feature = "json")]
pub struct $name<S, R, C = JsonCodec> {
$(
$ident : $ty,
)*
}
// Variant compiled when the `json` feature is off: no default for `C`.
$(#[$m])*
#[cfg(not(feature = "json"))]
pub struct $name<S, R, C> {
$(
$ident : $ty,
)*
}
}
}
with_and_without_json! {
/// A typed wrapper around axum's [`ws::WebSocketUpgrade`] extractor.
///
/// The type parameters are carried over to the [`WebSocket`] handed to the
/// `on_upgrade` callback:
///
/// * `S` - type of the items sent through the socket.
/// * `R` - type of the items received from the socket.
/// * `C` - the [`Codec`] used to encode and decode items. Defaults to
///   `JsonCodec` when the `json` feature is enabled.
pub struct WebSocketUpgrade<S, R, C = JsonCodec> {
// The untyped axum upgrade this value wraps.
upgrade: ws::WebSocketUpgrade,
// Ties `S`, `R` and `C` to the struct without storing values of those types.
_marker: PhantomData<fn() -> (S, R, C)>,
}
}
#[async_trait]
impl<S, R, C, B> FromRequest<B> for WebSocketUpgrade<S, R, C>
where
    B: Send,
{
    /// Extraction fails with exactly the rejection of the wrapped axum extractor.
    type Rejection = <ws::WebSocketUpgrade as FromRequest<B>>::Rejection;

    /// Runs axum's own `WebSocketUpgrade` extractor on `req` and wraps the
    /// result, attaching the `S`, `R` and `C` type parameters.
    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
        let inner = <ws::WebSocketUpgrade as FromRequest<B>>::from_request(req).await?;
        Ok(Self {
            upgrade: inner,
            _marker: PhantomData,
        })
    }
}
impl<S, R, C> WebSocketUpgrade<S, R, C> {
    /// Finalizes the upgrade and invokes `callback` with a typed [`WebSocket`].
    ///
    /// The wrapped axum upgrade is given a handler that wraps the raw socket it
    /// receives into a `WebSocket<S, R, C>` and then awaits `callback` on it.
    /// The return value is the response produced by the wrapped upgrade.
    pub fn on_upgrade<F, Fut>(self, callback: F) -> impl IntoResponse
    where
        F: FnOnce(WebSocket<S, R, C>) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
        S: Send,
        R: Send,
    {
        let Self { upgrade, .. } = self;
        let handler = |raw: ws::WebSocket| async move {
            let typed = WebSocket {
                socket: raw,
                _marker: PhantomData,
            };
            callback(typed).await
        };
        upgrade.on_upgrade(handler).into_response()
    }

    /// Applies `f` to the wrapped [`ws::WebSocketUpgrade`] and keeps its result
    /// as the new inner value, leaving the type parameters untouched.
    pub fn map<F>(self, f: F) -> Self
    where
        F: FnOnce(ws::WebSocketUpgrade) -> ws::WebSocketUpgrade,
    {
        let Self { upgrade, _marker } = self;
        Self {
            upgrade: f(upgrade),
            _marker,
        }
    }

    /// Consumes the wrapper and returns the untyped [`ws::WebSocketUpgrade`].
    pub fn into_inner(self) -> ws::WebSocketUpgrade {
        let Self { upgrade, .. } = self;
        upgrade
    }
}
// Hand-written so that `S`, `R` and `C` do not need to implement `Debug`;
// only the wrapped upgrade is printed, the marker field is omitted.
impl<S, R, C> fmt::Debug for WebSocketUpgrade<S, R, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WebSocketUpgrade");
        builder.field("upgrade", &self.upgrade);
        builder.finish()
    }
}
with_and_without_json! {
/// A typed wrapper around axum's [`ws::WebSocket`].
///
/// * `S` - type of the items that can be sent (must be `Serialize` to send).
/// * `R` - type of the items that are received (must be `DeserializeOwned`
///   to receive).
/// * `C` - the [`Codec`] used to turn items into bytes and back. Defaults to
///   `JsonCodec` when the `json` feature is enabled.
pub struct WebSocket<S, R, C = JsonCodec> {
// The untyped axum socket all traffic goes through.
socket: ws::WebSocket,
// Ties `S`, `R` and `C` to the struct without storing values of those types.
_marker: PhantomData<fn() -> (S, R, C)>,
}
}
impl<S, R, C> WebSocket<S, R, C> {
pub async fn recv(&mut self) -> Option<Result<Message<R>, Error<C::Error>>>
where
R: DeserializeOwned,
C: Codec,
{
self.next().await
}
pub async fn send(&mut self, msg: Message<S>) -> Result<(), Error<C::Error>>
where
S: Serialize,
C: Codec,
{
SinkExt::send(self, msg).await
}
pub async fn close(self) -> Result<(), Error<C::Error>>
where
C: Codec,
{
self.socket.close().await.map_err(Error::Ws)
}
pub fn into_inner(self) -> ws::WebSocket {
self.socket
}
}
// Hand-written so that `S`, `R` and `C` do not need to implement `Debug`;
// only the wrapped socket is printed, the marker field is omitted.
impl<S, R, C> fmt::Debug for WebSocket<S, R, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WebSocket");
        builder.field("socket", &self.socket);
        builder.finish()
    }
}
impl<S, R, C> Stream for WebSocket<S, R, C>
where
R: DeserializeOwned,
C: Codec,
{
type Item = Result<Message<R>, Error<C::Error>>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let msg = futures_util::ready!(Pin::new(&mut self.socket)
.poll_next(cx)
.map_err(Error::Ws)?);
if let Some(msg) = msg {
let msg = match msg {
ws::Message::Text(msg) => msg.into_bytes(),
ws::Message::Binary(bytes) => bytes,
ws::Message::Close(frame) => {
return Poll::Ready(Some(Ok(Message::Close(frame))));
}
ws::Message::Ping(buf) => {
return Poll::Ready(Some(Ok(Message::Ping(buf))));
}
ws::Message::Pong(buf) => {
return Poll::Ready(Some(Ok(Message::Pong(buf))));
}
};
let msg = C::decode(msg).map(Message::Item).map_err(Error::Codec);
Poll::Ready(Some(msg))
} else {
Poll::Ready(None)
}
}
}
impl<S, R, C> Sink<Message<S>> for WebSocket<S, R, C>
where
S: Serialize,
C: Codec,
{
type Error = Error<C::Error>;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.socket).poll_ready(cx).map_err(Error::Ws)
}
fn start_send(mut self: Pin<&mut Self>, msg: Message<S>) -> Result<(), Self::Error> {
let msg = match msg {
Message::Item(buf) => ws::Message::Binary(C::encode(buf).map_err(Error::Codec)?),
Message::Ping(buf) => ws::Message::Ping(buf),
Message::Pong(buf) => ws::Message::Pong(buf),
Message::Close(frame) => ws::Message::Close(frame),
};
Pin::new(&mut self.socket)
.start_send(msg)
.map_err(Error::Ws)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.socket).poll_flush(cx).map_err(Error::Ws)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.socket).poll_close(cx).map_err(Error::Ws)
}
}
/// Converts typed messages to and from the raw bytes carried by a WebSocket
/// frame.
///
/// Both operations are associated functions (there is no `self`), so a codec
/// is selected purely through a type parameter.
pub trait Codec {
/// The error produced when encoding or decoding fails.
type Error;
/// Serializes `msg` into a byte buffer.
///
/// # Errors
///
/// Returns `Self::Error` if `msg` cannot be encoded.
fn encode<S>(msg: S) -> Result<Vec<u8>, Self::Error>
where
S: Serialize;
/// Deserializes a value of type `R` from `buf`.
///
/// # Errors
///
/// Returns `Self::Error` if `buf` cannot be decoded as an `R`.
fn decode<R>(buf: Vec<u8>) -> Result<R, Self::Error>
where
R: DeserializeOwned;
}
/// A [`Codec`] that encodes and decodes messages as JSON using `serde_json`.
///
/// Only available when the `json` feature is enabled.
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(Debug)]
#[non_exhaustive]
pub struct JsonCodec;
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
impl Codec for JsonCodec {
    type Error = serde_json::Error;

    /// Serializes `msg` to a JSON byte vector.
    fn encode<S>(msg: S) -> Result<Vec<u8>, Self::Error>
    where
        S: Serialize,
    {
        let value = &msg;
        serde_json::to_vec(value)
    }

    /// Parses `buf` as JSON into an `R`.
    fn decode<R>(buf: Vec<u8>) -> Result<R, Self::Error>
    where
        R: DeserializeOwned,
    {
        let bytes: &[u8] = &buf;
        serde_json::from_slice(bytes)
    }
}
/// Errors that can occur when sending or receiving on a typed [`WebSocket`].
///
/// `E` is the error type of the codec in use.
#[derive(Debug)]
pub enum Error<E> {
/// The underlying axum WebSocket reported an error.
Ws(axum::Error),
/// The codec failed to encode or decode a message.
Codec(E),
}
impl<E> fmt::Display for Error<E>
where
E: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::Ws(inner) => inner.fmt(f),
Error::Codec(inner) => inner.fmt(f),
}
}
}
impl<E> StdError for Error<E>
where
E: StdError + 'static,
{
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match self {
Error::Ws(inner) => Some(inner),
Error::Codec(inner) => Some(inner),
}
}
}
/// A message sent or received over a typed [`WebSocket`].
///
/// `T` is the application-level payload type; the remaining variants mirror
/// the control frames of [`ws::Message`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message<T> {
/// An application-level item, encoded or decoded by the socket's codec.
Item(T),
/// A ping frame with its raw payload.
Ping(Vec<u8>),
/// A pong frame with its raw payload.
Pong(Vec<u8>),
/// A close frame, optionally carrying the close details.
Close(Option<ws::CloseFrame<'static>>),
}