use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use std::future::Future;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::task::JoinHandle;
use tungstenite::Message;
/// An inbound WebSocket message behind an `Arc`, so every broadcast subscriber
/// receives a cheap handle instead of its own copy of the frame.
pub type SharedMessage = Arc<Message>;
use crate::connection::WsStream;
use crate::error::{ExtensionError, ReceiveError, SendError};
/// Tunables for [`MessageDispatcher`].
///
/// Start from [`DispatcherConfig::new`] (or `Default`) and adjust with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct DispatcherConfig {
    /// How long a receive loop waits for the next inbound frame before it stops with
    /// `ReceiveError::Timeout`.
    pub receive_timeout: Duration,
    /// Capacity of the broadcast channel that fans inbound messages out to subscribers.
    pub broadcast_capacity: usize,
    /// Capacity of the queue between `send` callers and the background send task.
    pub send_buffer_capacity: usize,
    /// What `receive_loop_with_processor` does when the processor returns an error.
    pub processor_error_policy: ProcessorErrorPolicy,
}
/// How `MessageDispatcher::receive_loop_with_processor` reacts when the message
/// processor returns an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessorErrorPolicy {
    /// Log the error at `warn` level and keep receiving.
    Ignore,
    /// Stop the receive loop and return the error to the caller.
    Disconnect,
}
impl Default for DispatcherConfig {
    /// Baseline configuration: a 30 s receive timeout, a 1024-slot broadcast channel,
    /// a 256-slot send queue, and processor errors ignored.
    fn default() -> Self {
        const RECEIVE_TIMEOUT_SECS: u64 = 30;
        const BROADCAST_CAPACITY: usize = 1024;
        const SEND_BUFFER_CAPACITY: usize = 256;

        Self {
            receive_timeout: Duration::from_secs(RECEIVE_TIMEOUT_SECS),
            broadcast_capacity: BROADCAST_CAPACITY,
            send_buffer_capacity: SEND_BUFFER_CAPACITY,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }
}
impl DispatcherConfig {
    /// Creates a configuration holding the `Default` values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns a copy of `self` with the per-frame receive timeout replaced.
    #[must_use]
    pub const fn with_receive_timeout(self, timeout: Duration) -> Self {
        Self {
            receive_timeout: timeout,
            ..self
        }
    }

    /// Returns a copy of `self` with the inbound broadcast channel capacity replaced.
    #[must_use]
    pub const fn with_broadcast_capacity(self, capacity: usize) -> Self {
        Self {
            broadcast_capacity: capacity,
            ..self
        }
    }

    /// Returns a copy of `self` with the outbound send-queue capacity replaced.
    #[must_use]
    pub const fn with_send_buffer_capacity(self, capacity: usize) -> Self {
        Self {
            send_buffer_capacity: capacity,
            ..self
        }
    }

    /// Returns a copy of `self` with the processor error policy replaced.
    #[must_use]
    pub const fn with_processor_error_policy(self, policy: ProcessorErrorPolicy) -> Self {
        Self {
            processor_error_policy: policy,
            ..self
        }
    }
}
/// Outbound half of the dispatcher: the queue that feeds the send task, and the task itself.
///
/// Both fields start as `None` and are cleared again by `detach`.
struct SenderState<S: WsStream> {
    /// Handle of the spawned task that forwards queued messages to the sink.
    send_task: Option<JoinHandle<()>>,
    /// Producer side of the queue drained by `send_task`.
    send_tx: Option<mpsc::Sender<Message>>,
    /// Ties this state to the stream type without storing an `S`.
    ///
    /// `fn() -> S` is used instead of `S` because no `S` is owned here. With a plain `S`,
    /// `Send`/`Sync` (and therefore the dispatcher's own `Sync`-ness) and drop checking
    /// would depend on `S` for no reason.
    _marker: PhantomData<fn() -> S>,
}
/// Fans inbound WebSocket messages out to subscribers and funnels outbound messages
/// through one background send task.
///
/// A dispatcher starts detached. [`MessageDispatcher::attach`] installs a write half,
/// and one of the `receive_loop*` methods consumes a read half.
pub struct MessageDispatcher<S: WsStream = crate::connection::DefaultWsStream> {
    /// Settings captured at construction.
    config: DispatcherConfig,
    /// Outbound queue and the task draining it; replaced on every attach.
    sender_state: Arc<RwLock<SenderState<S>>>,
    /// Set by `attach`; cleared by `detach` or when the send task fails to write.
    is_connected: Arc<AtomicBool>,
    /// Broadcast side on which inbound messages are published to subscribers.
    message_tx: broadcast::Sender<SharedMessage>,
}
#[allow(clippy::future_not_send)]
impl<S: WsStream> MessageDispatcher<S> {
#[must_use]
pub fn new(config: DispatcherConfig) -> Self {
let (message_tx, _) = broadcast::channel(config.broadcast_capacity);
Self {
config,
sender_state: Arc::new(RwLock::new(SenderState::<S> {
send_task: None,
send_tx: None,
_marker: PhantomData,
})),
is_connected: Arc::new(AtomicBool::new(false)),
message_tx,
}
}
pub async fn attach(&self, sender: SplitSink<S, Message>) {
let (tx, mut rx) = mpsc::channel::<Message>(self.config.send_buffer_capacity);
let connected = self.is_connected.clone();
let send_task = tokio::spawn(async move {
let mut sink = sender;
while let Some(msg) = rx.recv().await {
if let Err(e) = sink.send(msg).await {
tracing::debug!(error = ?e, "Dispatcher send task encountered error");
connected.store(false, Ordering::Release);
break;
}
}
});
{
let mut state = self.sender_state.write().await;
if let Some(handle) = state.send_task.take() {
handle.abort();
}
state.send_tx = Some(tx);
state.send_task = Some(send_task);
}
self.is_connected.store(true, Ordering::Release);
tracing::debug!("Message dispatcher attached");
}
pub async fn detach(&self) {
self.is_connected.store(false, Ordering::Release);
{
let mut state = self.sender_state.write().await;
state.send_tx = None;
if let Some(handle) = state.send_task.take() {
handle.abort();
}
}
tracing::debug!("Message dispatcher detached");
}
#[must_use]
pub fn is_connected(&self) -> bool {
self.is_connected.load(Ordering::Acquire)
}
pub async fn send(&self, msg: Message) -> Result<(), SendError> {
if !self.is_connected() {
return Err(SendError::NotConnected);
}
let tx = {
let state = self.sender_state.read().await;
state.send_tx.clone()
};
match tx {
Some(tx) => tx.send(msg).await.map_err(|_| SendError::ChannelClosed),
None => Err(SendError::NotConnected),
}
}
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
self.message_tx.subscribe()
}
#[must_use]
pub fn subscriber_count(&self) -> usize {
self.message_tx.receiver_count()
}
pub async fn receive_loop(&self, mut receiver: SplitStream<S>) -> Result<(), ReceiveError> {
let timeout = self.config.receive_timeout;
loop {
let result = tokio::time::timeout(timeout, receiver.next()).await;
match result {
Ok(Some(Ok(msg))) => {
let _ = self.message_tx.send(Arc::new(msg));
}
Ok(Some(Err(e))) => {
tracing::debug!(error = ?e, "WebSocket receive error");
return Err(ReceiveError::WebSocket(e.to_string()));
}
Ok(None) => {
tracing::debug!("WebSocket stream closed");
return Err(ReceiveError::StreamClosed);
}
Err(_) => {
tracing::debug!(timeout = ?timeout, "Receive timeout");
return Err(ReceiveError::Timeout(timeout));
}
}
}
}
pub async fn receive_loop_with_activity<F>(
&self,
mut receiver: SplitStream<S>,
on_activity: F,
) -> Result<(), ReceiveError>
where
F: Fn() + Send + Sync,
{
let timeout = self.config.receive_timeout;
loop {
let result = tokio::time::timeout(timeout, receiver.next()).await;
match result {
Ok(Some(Ok(msg))) => {
on_activity();
let _ = self.message_tx.send(Arc::new(msg));
}
Ok(Some(Err(e))) => {
return Err(ReceiveError::WebSocket(e.to_string()));
}
Ok(None) => {
return Err(ReceiveError::StreamClosed);
}
Err(_) => {
return Err(ReceiveError::Timeout(timeout));
}
}
}
}
pub async fn receive_loop_with_processor<FAct, FActFut, FProc, FProcFut>(
&self,
mut receiver: SplitStream<S>,
on_activity: FAct,
processor: FProc,
) -> Result<(), ReceiveError>
where
FAct: Fn() -> FActFut + Send + Sync,
FActFut: Future<Output = ()> + Send,
FProc: Fn(Message) -> FProcFut + Send + Sync,
FProcFut: Future<Output = Result<Option<Message>, ExtensionError>> + Send,
{
let timeout = self.config.receive_timeout;
loop {
let result = tokio::time::timeout(timeout, receiver.next()).await;
match result {
Ok(Some(Ok(msg))) => {
on_activity().await;
match processor(msg).await {
Ok(Some(broadcast_msg)) => {
let _ = self.message_tx.send(Arc::new(broadcast_msg));
}
Ok(None) => {
}
Err(e) => match self.config.processor_error_policy {
ProcessorErrorPolicy::Ignore => {
tracing::warn!(error = ?e, "Message processor failed");
}
ProcessorErrorPolicy::Disconnect => {
return Err(ReceiveError::WebSocket(e.to_string()));
}
},
}
}
Ok(Some(Err(e))) => {
return Err(ReceiveError::WebSocket(e.to_string()));
}
Ok(None) => {
return Err(ReceiveError::StreamClosed);
}
Err(_) => {
return Err(ReceiveError::Timeout(timeout));
}
}
}
}
}
impl<S: WsStream> Default for MessageDispatcher<S> {
fn default() -> Self {
Self::new(DispatcherConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatcher_config() {
        let config = DispatcherConfig::new()
            .with_receive_timeout(Duration::from_secs(60))
            .with_broadcast_capacity(2048);
        assert_eq!(config.receive_timeout, Duration::from_secs(60));
        assert_eq!(config.broadcast_capacity, 2048);
    }

    /// Pins the documented default values.
    #[test]
    fn test_dispatcher_config_defaults() {
        let config = DispatcherConfig::default();
        assert_eq!(config.receive_timeout, Duration::from_secs(30));
        assert_eq!(config.broadcast_capacity, 1024);
        assert_eq!(config.send_buffer_capacity, 256);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Ignore
        ));
    }

    /// Covers the builder methods not exercised by `test_dispatcher_config`; the fields
    /// they don't touch must keep their defaults.
    #[test]
    fn test_dispatcher_config_remaining_builders() {
        let config = DispatcherConfig::new()
            .with_send_buffer_capacity(8)
            .with_processor_error_policy(ProcessorErrorPolicy::Disconnect);
        assert_eq!(config.send_buffer_capacity, 8);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Disconnect
        ));
        assert_eq!(config.broadcast_capacity, 1024);
    }

    #[tokio::test]
    async fn test_dispatcher_not_connected() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        let result = dispatcher.send(Message::Text("test".into())).await;
        assert!(matches!(result, Err(SendError::NotConnected)));
    }

    /// A fresh dispatcher is detached and subscriber counts follow receiver lifetimes.
    #[test]
    fn test_subscriber_count_tracks_receivers() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        assert!(!dispatcher.is_connected());
        assert_eq!(dispatcher.subscriber_count(), 0);

        let rx = dispatcher.subscribe();
        assert_eq!(dispatcher.subscriber_count(), 1);

        drop(rx);
        assert_eq!(dispatcher.subscriber_count(), 0);
    }
}