use std::{fmt::Debug, sync::Arc};
use futures_util::stream::{SplitSink, SplitStream};
use crate::transport::{BoxedWsTransport, Message};
/// Write half of a split [`BoxedWsTransport`]; a sink that accepts outgoing [`Message`]s.
pub(crate) type MessageWriter = SplitSink<BoxedWsTransport, Message>;
/// Read half of a split [`BoxedWsTransport`]; a stream of the transport's incoming items.
pub type MessageReader = SplitStream<BoxedWsTransport>;
/// Shared, thread-safe callback that is handed an owned [`Message`].
///
/// [`channel_message_handler`] builds one of these that forwards into an mpsc channel.
pub type MessageHandler = Arc<dyn Fn(Message) + Send + Sync>;
/// Shared, thread-safe callback that is handed an owned byte payload.
///
/// NOTE(review): presumably invoked with the payload of incoming ping frames — confirm
/// against the call site, which is not visible in this module.
pub type PingHandler = Arc<dyn Fn(Vec<u8>) + Send + Sync>;
/// Builds a [`MessageHandler`] that pushes every message into an unbounded channel.
///
/// Each [`Message`] passed to the handler is first converted into a
/// [`tokio_tungstenite::tungstenite::Message`]. If that conversion fails, the message
/// is logged at debug level and discarded. If the receiving end has already been
/// dropped, the send error is also logged at debug level and the message discarded.
/// The handler never panics on either path.
///
/// Returns the handler paired with the channel's receiving end.
#[must_use]
pub fn channel_message_handler() -> (
    MessageHandler,
    tokio::sync::mpsc::UnboundedReceiver<tokio_tungstenite::tungstenite::Message>,
) {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();

    let forward = move |msg: Message| {
        // Conversion can fail (e.g. text frames whose bytes are not valid UTF-8);
        // such messages are skipped instead of being propagated.
        let legacy = match tokio_tungstenite::tungstenite::Message::try_from(msg) {
            Ok(converted) => converted,
            Err(e) => {
                log::debug!("Dropping message that failed legacy conversion: {e}");
                return;
            }
        };

        // A closed receiver is not the caller's problem; just note it and move on.
        if let Err(e) = sender.send(legacy) {
            log::debug!("Failed to send message to channel: {e}");
        }
    };

    let handler: MessageHandler = Arc::new(forward);
    (handler, receiver)
}
/// Instruction for the code that owns the active [`MessageWriter`].
///
/// NOTE(review): the consumer of these commands is not visible in this module;
/// presumably a dedicated writer task — confirm.
pub(crate) enum WriterCommand {
    /// Carries a new [`MessageWriter`] (presumably to replace the current one) plus a
    /// oneshot sender for a `bool` reply — presumably success/failure of the swap; verify
    /// against the consumer.
    Update(MessageWriter, tokio::sync::oneshot::Sender<bool>),
    /// Carries a single [`Message`] to be written.
    Send(Message),
}
impl Debug for WriterCommand {
    /// Formats the command as a one-field tuple named after its variant.
    ///
    /// `Send` shows its message. `Update` shows the placeholder string `"<writer>"`
    /// in place of the writer, and omits the oneshot reply sender entirely.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, payload): (&str, &dyn Debug) = match self {
            Self::Update(..) => ("Update", &"<writer>"),
            Self::Send(message) => ("Send", message),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    use tokio_tungstenite::tungstenite::Message as TgMessage;

    use super::*;

    /// Text whose bytes are not valid UTF-8 fails legacy conversion; it must be
    /// dropped silently without blocking delivery of later, valid messages.
    #[rstest]
    fn channel_handler_drops_invalid_utf8_text_without_panic() {
        let (handler, mut rx) = channel_message_handler();
        handler(Message::Text(Bytes::from_static(&[0xFF, 0xFE])));
        handler(Message::Binary(Bytes::from_static(b"ok")));
        let received = rx.try_recv().expect("binary should arrive");
        assert!(matches!(received, TgMessage::Binary(ref b) if b.as_ref() == b"ok"));
        assert!(rx.try_recv().is_err(), "no further messages expected");
    }

    /// Valid text is converted and delivered with its payload intact.
    #[rstest]
    fn channel_handler_forwards_valid_text() {
        let (handler, mut rx) = channel_message_handler();
        handler(Message::text("hello"));
        let received = rx.try_recv().expect("text should arrive");
        match received {
            TgMessage::Text(t) => assert_eq!(t.as_str(), "hello"),
            other => panic!("expected text, was {other:?}"),
        }
    }

    /// Once the receiver is dropped, `send` fails; the handler must log and
    /// swallow that error rather than panic, for every subsequent message.
    #[rstest]
    fn channel_handler_tolerates_dropped_receiver() {
        let (handler, rx) = channel_message_handler();
        drop(rx);
        handler(Message::text("late"));
        handler(Message::Binary(Bytes::from_static(b"late")));
    }

    /// The `Send` variant's debug output keeps both the variant name and the payload.
    #[rstest]
    fn writer_command_send_debug_includes_message() {
        let cmd = WriterCommand::Send(Message::text("hi"));
        let formatted = format!("{cmd:?}");
        assert!(
            formatted.starts_with("Send("),
            "unexpected debug output: {formatted}"
        );
        assert!(
            formatted.contains("Text"),
            "debug output should retain the message variant: {formatted}"
        );
        assert!(
            formatted.contains("hi"),
            "debug output should retain the payload bytes: {formatted}"
        );
    }
}