#[cfg(feature = "router")]
use axum::extract::ws::{Message, Utf8Bytes};
#[cfg(feature = "router")]
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use serde::{Serialize, de::DeserializeOwned};
#[cfg(feature = "router")]
use std::fmt::Display;
use std::ops::Deref;
use tracing::debug;
/// Extension methods for reading messages off a WebSocket receive stream.
///
/// `Text` is the string-like payload type produced by
/// [`next_text`](Self::next_text); `T` is the type that
/// [`next_json`](Self::next_json) deserializes that payload into.
pub(super) trait WebsocketStreamExt<T, Text>
where
    T: DeserializeOwned,
    Text: Deref<Target = str>,
{
    /// Returns the next text payload, or `None` when no further text is available.
    async fn next_text(&mut self) -> Option<Text>;

    /// Reads the next text payload and parses it as JSON into `T`.
    ///
    /// Returns `None` if [`next_text`](Self::next_text) yields nothing or if the
    /// payload does not deserialize into `T`; parse errors are logged at debug level.
    async fn next_json(&mut self) -> Option<T> {
        let raw = self.next_text().await?;
        serde_json::from_slice(raw.as_bytes())
            .map_err(|parse_err| debug!("{}", parse_err))
            .ok()
    }
}
/// Implements [`WebsocketStreamExt`] for any stream yielding
/// `Result<Message, axum::Error>` items, producing `Utf8Bytes` payloads.
#[cfg(feature = "router")]
impl<S, T> WebsocketStreamExt<T, Utf8Bytes> for S
where
    S: Stream<Item = Result<Message, axum::Error>> + Unpin,
    T: DeserializeOwned,
{
    /// Waits for the next data message and returns its UTF-8 payload.
    ///
    /// Ping and Pong control frames are skipped instead of being surfaced as text.
    /// Returns `None` when the stream ends, when it yields a transport error, when
    /// a Close frame arrives, or when a binary message is not valid UTF-8.
    /// Errors are logged at debug level.
    async fn next_text(&mut self) -> Option<Utf8Bytes> {
        loop {
            let msg = match self.next().await? {
                Ok(msg) => msg,
                Err(err) => {
                    debug!("{}", err);
                    return None;
                }
            };
            let data_msg = match msg {
                // Control frames carry no application data; `Message::into_text`
                // would convert their payload bytes into a bogus "text" message.
                Message::Ping(_) | Message::Pong(_) => continue,
                // The peer is closing: end the stream rather than returning the
                // close reason as if it were a message.
                Message::Close(_) => return None,
                data_msg => data_msg,
            };
            return match data_msg.into_text() {
                Ok(text) => Some(text),
                Err(err) => {
                    debug!("{}", err);
                    None
                }
            };
        }
    }
}
/// Extension methods for writing JSON values to a WebSocket send sink.
pub(super) trait WebsocketSinkExt<T>
where
    T: Serialize,
{
    /// Serializes `value` as JSON and sends it.
    ///
    /// Returns `Some(())` on success and `None` on failure.
    async fn send_json(&mut self, value: &T) -> Option<()>;
}
/// Implements [`WebsocketSinkExt`] for any sink that accepts axum [`Message`]s
/// and whose error type can be displayed for logging.
#[cfg(feature = "router")]
impl<S, T> WebsocketSinkExt<T> for S
where
    S: Sink<Message> + Unpin,
    S::Error: Display,
    T: Serialize,
{
    /// Serializes `value` to a JSON string and sends it as a single text message.
    ///
    /// Returns `Some(())` once the send completes. Returns `None` if serialization
    /// fails or the sink reports an error; either error is logged at debug level.
    async fn send_json(&mut self, value: &T) -> Option<()> {
        // `to_string` already yields `Result<String, serde_json::Error>`; no
        // conversion is needed before matching on it.
        let json_str = match serde_json::to_string(value) {
            Ok(json_str) => json_str,
            Err(err) => {
                debug!("{}", err);
                return None;
            }
        };
        match self.send(Message::Text(json_str.into())).await {
            Ok(()) => Some(()),
            Err(err) => {
                debug!("{}", err);
                None
            }
        }
    }
}