use std::{collections::HashMap, marker::PhantomData, sync::Arc};
use futures::future::join_all;
use tokio::sync::{Mutex, broadcast, mpsc};
use crate::node::NodeId;
use super::information_packet::Content;
/// A set of output channels, keyed by the [`NodeId`] of the receiving node.
#[derive(Default)]
pub struct OutChannels(pub(crate) HashMap<NodeId, Arc<Mutex<OutChannel>>>);

impl OutChannels {
    /// Sends `content` through the channel registered for `id`.
    ///
    /// # Errors
    ///
    /// Returns [`SendErr::NoSuchChannel`] if no channel is registered for `id`,
    /// or [`SendErr::ClosedChannel`] (handing `content` back) if the underlying
    /// channel rejected the value.
    pub async fn send_to(&self, id: &NodeId, content: Content) -> Result<(), SendErr> {
        match self.get(id) {
            Some(channel) => channel.lock().await.send(content).await,
            None => Err(SendErr::NoSuchChannel),
        }
    }

    /// Sends a clone of `content` through every registered channel, driving all
    /// sends concurrently.
    ///
    /// Returns one result per channel in the map's (unspecified) iteration order;
    /// a failure on one channel does not prevent sends on the others.
    pub async fn broadcast(&self, content: Content) -> Vec<Result<(), SendErr>> {
        let futures = self
            .0
            .values()
            .map(|c| async { c.lock().await.send(content.clone()).await });
        join_all(futures).await
    }

    /// Removes the channel registered for `id`, if there is one.
    pub async fn close(&mut self, id: &NodeId) {
        // `HashMap::remove` is already a no-op for absent keys, so a prior
        // lookup (which also cloned the `Arc`) is unnecessary.
        self.0.remove(id);
    }

    /// Removes every registered channel.
    pub(crate) async fn close_all(&mut self) {
        self.0.clear();
    }

    /// Returns a shared handle to the channel registered for `id`, if any.
    fn get(&self, id: &NodeId) -> Option<Arc<Mutex<OutChannel>>> {
        self.0.get(id).cloned()
    }

    /// Registers `channel` for `node_id`, replacing any channel previously
    /// registered under that id.
    pub(crate) fn insert(&mut self, node_id: NodeId, channel: Arc<Mutex<OutChannel>>) {
        self.0.insert(node_id, channel);
    }

    /// Returns the ids of all nodes that currently have a registered channel,
    /// in unspecified order.
    pub fn get_receiver_ids(&self) -> Vec<NodeId> {
        self.0.keys().copied().collect()
    }
}
/// The sending half of a single output channel.
pub enum OutChannel {
    /// A tokio bounded multi-producer, single-consumer sender.
    Mpsc(mpsc::Sender<Content>),
    /// A tokio broadcast sender.
    Bcst(broadcast::Sender<Content>),
}

impl OutChannel {
    /// Sends `value` through this channel.
    ///
    /// For `Mpsc` this waits until the channel has capacity; for `Bcst` the
    /// underlying send does not wait.
    ///
    /// # Errors
    ///
    /// Returns [`SendErr::ClosedChannel`] carrying `value` back when the
    /// underlying sender rejects it.
    async fn send(&self, value: Content) -> Result<(), SendErr> {
        // Normalise both sender kinds to `Result<(), Content>` first, then wrap
        // the rejected value in a single place.
        let outcome = match self {
            Self::Mpsc(tx) => tx.send(value).await.map_err(|rejected| rejected.0),
            Self::Bcst(tx) => tx
                .send(value)
                .map(|_receiver_count| ())
                .map_err(|rejected| rejected.0),
        };
        outcome.map_err(SendErr::ClosedChannel)
    }
}
/// Error returned when sending through an output channel fails.
#[derive(Debug)]
pub enum SendErr {
    /// No channel is registered for the requested node id.
    NoSuchChannel,
    /// The underlying channel rejected the value; the unsent content is handed
    /// back to the caller.
    ClosedChannel(Content),
}
/// Typed counterpart of [`OutChannels`]: values of type `T` are wrapped in a
/// [`Content`] before being sent.
pub struct TypedOutChannels<T: Send + Sync + 'static>(
    pub(crate) HashMap<NodeId, Arc<Mutex<OutChannel>>>,
    pub(crate) PhantomData<T>,
);

// Implemented by hand: `#[derive(Default)]` would add a `T: Default` bound on the
// impl, even though neither field needs it (`PhantomData<T>` is `Default` for all `T`).
impl<T: Send + Sync + 'static> Default for TypedOutChannels<T> {
    fn default() -> Self {
        Self(HashMap::new(), PhantomData)
    }
}

impl<T: Send + Sync + 'static> TypedOutChannels<T> {
    /// Wraps `content` in a [`Content`] and sends it through the channel
    /// registered for `id`.
    ///
    /// # Errors
    ///
    /// Returns [`SendErr::NoSuchChannel`] if no channel is registered for `id`,
    /// or [`SendErr::ClosedChannel`] (handing the wrapped content back) if the
    /// underlying channel rejected the value.
    pub async fn send_to(&self, id: &NodeId, content: T) -> Result<(), SendErr> {
        match self.get(id) {
            Some(channel) => channel.lock().await.send(Content::new(content)).await,
            None => Err(SendErr::NoSuchChannel),
        }
    }

    /// Wraps `content` once, then sends a clone of it through every registered
    /// channel, driving all sends concurrently.
    ///
    /// Returns one result per channel in the map's (unspecified) iteration order;
    /// a failure on one channel does not prevent sends on the others.
    pub async fn broadcast(&self, content: T) -> Vec<Result<(), SendErr>> {
        let content = Content::new(content);
        let futures = self
            .0
            .values()
            .map(|c| async { c.lock().await.send(content.clone()).await });
        join_all(futures).await
    }

    /// Removes the channel registered for `id`, if there is one.
    pub async fn close(&mut self, id: &NodeId) {
        // `HashMap::remove` is already a no-op for absent keys, so a prior
        // lookup (which also cloned the `Arc`) is unnecessary.
        self.0.remove(id);
    }

    /// Returns a shared handle to the channel registered for `id`, if any.
    fn get(&self, id: &NodeId) -> Option<Arc<Mutex<OutChannel>>> {
        self.0.get(id).cloned()
    }

    /// Returns the ids of all nodes that currently have a registered channel,
    /// in unspecified order.
    pub fn get_receiver_ids(&self) -> Vec<NodeId> {
        self.0.keys().copied().collect()
    }
}