use bytes::Bytes;
use commonware_cryptography::{Digest, PublicKey};
use commonware_utils::{channel::mpsc, sync::Mutex};
use std::collections::{btree_map::Entry, BTreeMap};
use tracing::{error, warn};
/// The listeners registered under a single public key; each message addressed to
/// that key is sent to every one of them.
type Listeners<D> = Vec<mpsc::UnboundedSender<(D, Bytes)>>;

/// In-memory relay that fans `(digest, bytes)` messages out to listeners
/// registered under public keys.
pub struct Relay<D: Digest, P: PublicKey> {
    /// Registered listeners, keyed by the public key they were registered under.
    recipients: Mutex<BTreeMap<P, Listeners<D>>>,
}
impl<D: Digest, P: PublicKey> Relay<D, P> {
    /// Creates a relay with no registered listeners.
    ///
    /// This is a `const fn`, so a relay can be constructed in constant contexts.
    pub const fn new() -> Self {
        Self {
            recipients: Mutex::new(BTreeMap::new()),
        }
    }

    /// Removes every registered listener, for every public key.
    ///
    /// Broadcasts that start after this returns no longer reach any receiver
    /// that was handed out by [`Relay::register`] before the call.
    pub fn deregister_all(&self) {
        let mut recipients = self.recipients.lock();
        recipients.clear();
    }

    /// Registers a new listener for `public_key` and returns its receiving end.
    ///
    /// Registering the same key more than once is allowed. A warning is logged,
    /// and every listener registered for that key receives each message sent
    /// to it.
    pub fn register(&self, public_key: P) -> mpsc::UnboundedReceiver<(D, Bytes)> {
        let (sender, receiver) = mpsc::unbounded_channel();
        {
            let mut recipients = self.recipients.lock();
            match recipients.entry(public_key.clone()) {
                Entry::Vacant(vacant) => {
                    vacant.insert(vec![sender]);
                }
                Entry::Occupied(mut occupied) => {
                    warn!(?public_key, "duplicate registration");
                    occupied.get_mut().push(sender);
                }
            }
        }
        receiver
    }

    /// Sends `(payload, data)` to every listener of every registered key except
    /// `sender`.
    ///
    /// The listener lists are cloned while the lock is held. The sends happen
    /// after the lock is released, so the mutex is never held while sending.
    ///
    /// A failed send is logged and does not stop delivery to the remaining
    /// listeners. Failed listeners are not removed. A listener whose receiver
    /// has gone away (presumably the usual cause of failure) is therefore
    /// reported again on every later broadcast, until [`Relay::deregister_all`]
    /// is called.
    pub fn broadcast(&self, sender: &P, (payload, data): (D, Bytes)) {
        // Snapshot in a single statement: the lock guard is a temporary and is
        // dropped at the end of this `let`, before any message is sent.
        let channels: Vec<_> = self
            .recipients
            .lock()
            .iter()
            .filter(|(public_key, _)| *public_key != sender)
            .map(|(public_key, listeners)| (public_key.clone(), listeners.clone()))
            .collect();

        for (recipient, listeners) in channels {
            for listener in listeners {
                if let Err(e) = listener.send((payload, data.clone())) {
                    error!(?e, ?recipient, "failed to send message to recipient");
                }
            }
        }
    }
}
impl<D: Digest, P: PublicKey> Default for Relay<D, P> {
fn default() -> Self {
Self::new()
}
}