use std::sync::Arc;
pub use tokio::time::Duration;
use tokio::{sync::mpsc::error::SendError, task::JoinHandle, time::timeout};
use crate::{
error::{NotifierError, UnexpectedErrorKind},
notifier::{Sender, SmartChannelId},
};
/// Join handle of a spawned send task; resolves to the result of the channel send.
type Handler<M> = JoinHandle<Result<(), SendError<M>>>;

/// A batch of send tasks, one per destination channel.
///
/// The broadcast constructors in this module spawn each task with `tokio::spawn`, so the
/// sends start running as soon as the handler is built; `wait` only collects their outcomes.
pub struct WritingHandler<M: Send + 'static> {
    handlers: Vec<Handler<M>>,
}

// Written by hand because `#[derive(Default)]` would add an unnecessary `M: Default`
// bound (an empty handler never needs to construct an `M`).
impl<M: Send + 'static> Default for WritingHandler<M> {
    fn default() -> Self {
        Self {
            handlers: Vec::new(),
        }
    }
}
/// Spawns a task that pushes `msg` into `sender` and returns the task's join handle.
///
/// Awaiting the returned handle yields the outcome of the send.
fn get_handler<M: Send + 'static>(sender: Sender<M, SmartChannelId>, msg: M) -> Handler<M> {
    let send_task = async move { sender.send(msg).await };
    tokio::spawn(send_task)
}
impl<M: Send + 'static + Sync> WritingHandler<Arc<M>> {
    /// Broadcasts `msg` to every sender by sharing a single `Arc` instead of cloning the payload.
    ///
    /// One send task is spawned per sender, in slice order; an empty `senders` slice
    /// produces an empty handler.
    pub(crate) fn new_arc_broadcast(msg: M, senders: &[Sender<Arc<M>, SmartChannelId>]) -> Self {
        let shared = Arc::new(msg);
        let mut handlers = Vec::with_capacity(senders.len());
        for sender in senders {
            let payload = Arc::clone(&shared);
            handlers.push(get_handler(sender.clone(), payload));
        }
        Self { handlers }
    }
}
impl<M: Send + 'static + Clone> WritingHandler<M> {
    /// Broadcasts `msg` to every sender, giving each channel its own clone.
    ///
    /// Tasks are spawned and stored in the same order as `senders`, so errors reported by
    /// `wait` line up with the input order. The last sender receives the original `msg`,
    /// which saves one clone. An empty `senders` slice produces an empty handler.
    pub(crate) fn new_cloning_broadcast(msg: M, senders: &[Sender<M, SmartChannelId>]) -> Self {
        let (last, rest) = match senders.split_last() {
            Some(split) => split,
            None => return Self::empty(),
        };
        let mut handlers = Vec::with_capacity(senders.len());
        handlers.extend(
            rest.iter()
                .map(|sender| get_handler(sender.clone(), msg.clone())),
        );
        // Every other sender has already cloned `msg`, so it can be moved here.
        handlers.push(get_handler(last.clone(), msg));
        WritingHandler { handlers }
    }
}
impl<M: Send + 'static> WritingHandler<M> {
    /// Creates a handler with no pending sends; `wait` on it returns `Ok(0)`.
    pub fn empty() -> Self {
        Self {
            handlers: Vec::new(),
        }
    }

    /// Returns the number of send tasks tracked by this handler.
    pub fn len(&self) -> usize {
        self.handlers.len()
    }

    /// Returns `true` when no send tasks are tracked.
    pub fn is_empty(&self) -> bool {
        self.handlers.is_empty()
    }

    /// Waits for every send task and, if all succeeded, returns how many there were.
    ///
    /// `duration` is one budget shared by the whole batch, not a per-task limit, so the
    /// time spent here stays bounded by roughly `duration` no matter how many tasks there
    /// are. `None` waits without a limit. A task still running when the budget runs out is
    /// aborted, so a message reported as timed out is not delivered afterwards.
    ///
    /// # Errors
    ///
    /// Returns `NotifierError::WritingSendError` with one entry per failed task:
    /// `SendingError` if the send itself failed, `JoiningError` if the task could not be
    /// joined, or `WritingTimeout(duration)` if the budget was exhausted first.
    pub async fn wait(self, duration: Option<Duration>) -> Result<usize, NotifierError<M, ()>> {
        let n = self.handlers.len();
        let mut errors = Vec::new();
        let started = std::time::Instant::now();
        for mut handler in self.handlers {
            let result = match duration {
                // Borrow the handle so it is still available to abort if the timeout fires;
                // `saturating_sub` turns an exhausted budget into a zero timeout, which still
                // accepts tasks that have already finished.
                Some(budget) => {
                    timeout(budget.saturating_sub(started.elapsed()), &mut handler).await
                }
                None => Ok((&mut handler).await),
            };
            match result {
                Ok(Ok(Ok(()))) => {}
                Ok(Ok(Err(e))) => errors.push(NotifierError::SendingError(e)),
                Ok(Err(e)) => errors.push(NotifierError::JoiningError(e)),
                Err(_elapsed) => {
                    // Dropping a `JoinHandle` only detaches the task; abort it explicitly so
                    // the send does not complete after we have reported it as timed out.
                    handler.abort();
                    match duration {
                        Some(budget) => errors.push(NotifierError::WritingTimeout(budget)),
                        // `timeout` is only used when `duration` is `Some`, so this cannot
                        // happen; report it as an error rather than panicking.
                        None => {
                            return Err(NotifierError::UnexpectedError(
                                UnexpectedErrorKind::DurationIsMissing,
                            ));
                        }
                    }
                }
            }
        }
        if errors.is_empty() {
            Ok(n)
        } else {
            Err(NotifierError::WritingSendError(errors))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use smart_channel::channel;

    const TEST_ID: SmartChannelId = SmartChannelId {
        channel_counter: 1,
        notifier_address: 1,
    };

    #[tokio::test]
    async fn test_empty_handler_wait() {
        let handler: WritingHandler<String> = WritingHandler::empty();
        let result = handler.wait(None).await;
        assert_eq!(result.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_handler_len() {
        let handler: WritingHandler<String> = WritingHandler::empty();
        assert_eq!(handler.len(), 0);
        assert!(handler.is_empty());
        let (tx1, _) = channel(10, TEST_ID);
        let (tx2, _) = channel(10, TEST_ID);
        let message = "Hello from Arc!";
        let handler = WritingHandler::new_arc_broadcast(message, &[tx1, tx2]);
        assert_eq!(handler.len(), 2);
        assert!(!handler.is_empty());
    }

    #[tokio::test]
    async fn test_arc_broadcast_success() {
        let (tx1, mut rx1) = channel(10, TEST_ID);
        let (tx2, mut rx2) = channel(10, TEST_ID);
        let message = "Hello from Arc!";
        let handler = WritingHandler::new_arc_broadcast(message, &[tx1, tx2]);
        handler.wait(None).await.unwrap();
        assert_eq!(rx1.recv().await.unwrap(), Arc::new("Hello from Arc!"));
        assert_eq!(rx2.recv().await.unwrap(), Arc::new("Hello from Arc!"));
    }

    #[tokio::test]
    async fn test_cloning_broadcast_success() {
        let (tx1, mut rx1) = channel(10, TEST_ID);
        let (tx2, mut rx2) = channel(10, TEST_ID);
        let message = "Hello from Arc!".to_string();
        let handler = WritingHandler::new_cloning_broadcast(message, &[tx1, tx2]);
        handler.wait(None).await.unwrap();
        assert_eq!(*rx1.recv().await.unwrap(), String::from("Hello from Arc!"));
        assert_eq!(*rx2.recv().await.unwrap(), String::from("Hello from Arc!"));
    }

    #[tokio::test]
    async fn test_cloning_broadcast_without_senders() {
        let handler = WritingHandler::new_cloning_broadcast("Nobody listens".to_string(), &[]);
        assert!(handler.is_empty());
        assert_eq!(handler.wait(None).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_timeout_error() {
        // Capacity 1: the first message fills the channel, so the second send blocks.
        let (tx1, _rx1) = channel(1, TEST_ID);
        let valid_handler = WritingHandler::new_cloning_broadcast(
            "Message should pass".to_string(),
            &[tx1.clone()],
        );
        valid_handler.wait(None).await.unwrap();
        let err_handler = WritingHandler::new_cloning_broadcast(
            "Message should not pass".to_string(),
            &[tx1.clone()],
        );
        let result = err_handler.wait(Some(Duration::from_millis(500))).await;
        assert!(result.is_err());
        if let Err(NotifierError::WritingSendError(errors)) = result {
            assert_eq!(errors.len(), 1);
            assert!(matches!(errors[0], NotifierError::WritingTimeout(_)));
        } else {
            panic!("Expected timeout error.");
        }
    }

    #[tokio::test]
    async fn test_send_error() {
        // The receiver is dropped immediately, so the send fails.
        let (tx, _) = channel(10, TEST_ID);
        let handler = WritingHandler::new_cloning_broadcast("Join test".to_string(), &[tx]);
        let result = handler.wait(None).await;
        assert!(result.is_err());
        if let Err(NotifierError::WritingSendError(errors)) = result {
            assert!(matches!(errors[0], NotifierError::SendingError(_)));
        } else {
            panic!("Expected send error.");
        }
    }

    #[tokio::test]
    async fn test_multiple_errors() {
        let (tx1, _) = channel(10, TEST_ID);
        let (tx2, _) = channel(10, TEST_ID);
        let handler =
            WritingHandler::new_cloning_broadcast("Multi-error test".to_string(), &[tx1, tx2]);
        let result = handler.wait(None).await;
        assert!(result.is_err());
        if let Err(NotifierError::WritingSendError(errors)) = result {
            assert_eq!(errors.len(), 2);
        } else {
            panic!("Expected multiple send errors.");
        }
    }

    #[tokio::test]
    async fn test_no_error_with_successful_senders() {
        let (tx, mut rx) = channel(10, TEST_ID);
        let handler = WritingHandler::new_cloning_broadcast("Success message".to_string(), &[tx]);
        tokio::spawn(async move {
            let _ = rx.recv().await;
        });
        let result = handler.wait(None).await;
        assert!(result.is_ok());
    }
}