use std::collections::HashMap;
use tokio::sync::mpsc;
use crate::{
Substream,
peer_manager::NodeId,
protocol::{ProtocolError, ProtocolExtension, ProtocolExtensionContext, ProtocolExtensionError, ProtocolId},
};
/// Sending half of the bounded channel on which [`ProtocolNotification`]s are delivered.
pub type ProtocolNotificationTx<S> = mpsc::Sender<ProtocolNotification<S>>;
/// Receiving half of the bounded channel on which [`ProtocolNotification`]s are delivered.
pub type ProtocolNotificationRx<S> = mpsc::Receiver<ProtocolNotification<S>>;
/// An event raised for a registered protocol.
#[derive(Clone, Debug)]
pub enum ProtocolEvent<S> {
    /// A new inbound substream of type `S` from the peer identified by the [`NodeId`].
    NewInboundSubstream(NodeId, S),
}
/// A [`ProtocolEvent`] tagged with the [`ProtocolId`] it was raised for.
#[derive(Clone, Debug)]
pub struct ProtocolNotification<S> {
    /// The event being delivered.
    pub event: ProtocolEvent<S>,
    /// The protocol the event relates to.
    pub protocol: ProtocolId,
}
impl<TSubstream> ProtocolNotification<TSubstream> {
pub fn new(protocol: ProtocolId, event: ProtocolEvent<TSubstream>) -> Self {
Self { event, protocol }
}
}
/// Registry that maps each supported [`ProtocolId`] to the channel its notifications are sent on.
pub struct Protocols<S> {
    protocols: HashMap<ProtocolId, ProtocolNotificationTx<S>>,
}
impl<TSubstream> Clone for Protocols<TSubstream> {
fn clone(&self) -> Self {
Self {
protocols: self.protocols.clone(),
}
}
}
impl<TSubstream> Default for Protocols<TSubstream> {
fn default() -> Self {
Self {
protocols: HashMap::default(),
}
}
}
impl<TSubstream> Protocols<TSubstream> {
pub fn new() -> Self {
Default::default()
}
pub fn empty() -> Self {
Default::default()
}
pub fn add<I: AsRef<[ProtocolId]>>(
&mut self,
protocols: I,
notifier: &ProtocolNotificationTx<TSubstream>,
) -> &mut Self {
self.protocols
.extend(protocols.as_ref().iter().map(|p| (p.clone(), notifier.clone())));
self
}
pub fn extend(&mut self, protocols: Self) -> &mut Self {
self.protocols.extend(protocols.protocols);
self
}
pub fn get_supported_protocols(&self) -> Vec<ProtocolId> {
self.protocols.keys().cloned().collect()
}
pub async fn notify(
&mut self,
protocol: &ProtocolId,
event: ProtocolEvent<TSubstream>,
) -> Result<(), ProtocolError> {
match self.protocols.get_mut(protocol) {
Some(sender) => {
sender
.send(ProtocolNotification::new(protocol.clone(), event))
.await
.map_err(|_| ProtocolError::NotificationSenderDisconnected)?;
Ok(())
},
None => Err(ProtocolError::ProtocolNotRegistered),
}
}
pub fn iter(&self) -> impl Iterator<Item = &ProtocolId> {
self.protocols.keys()
}
}
/// Lets a [`Protocols`] registry be installed as a protocol extension.
impl ProtocolExtension for Protocols<Substream> {
    /// Hands every registered protocol, together with its notification sender, to
    /// `context.add_protocol`. The registry is drained in the process. This always returns `Ok(())`.
    fn install(mut self: Box<Self>, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> {
        // `drain` yields owned (protocol, sender) pairs and leaves this registry empty.
        for (protocol, notifier) in self.protocols.drain() {
            context.add_protocol(&[protocol], &notifier);
        }
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use tari_test_utils::unpack_enum;

    use super::*;

    #[test]
    fn add() {
        let (tx, _) = mpsc::channel(1);
        let protos = [
            ProtocolId::from_static(b"/tari/test/1"),
            ProtocolId::from_static(b"/tari/test/2"),
        ];
        let mut protocols = Protocols::<()>::new();
        protocols.add(&protos, &tx);
        // Check both size and containment: an `all` over an empty list is vacuously true, so on its
        // own it would pass even if `add` registered nothing.
        let supported = protocols.get_supported_protocols();
        assert_eq!(supported.len(), protos.len());
        assert!(protos.iter().all(|p| supported.contains(p)));
    }

    #[tokio::test]
    async fn notify() {
        let (tx, mut rx) = mpsc::channel(1);
        let protos = [ProtocolId::from_static(b"/tari/test/1")];
        let mut protocols = Protocols::<()>::new();
        protocols.add(&protos, &tx);
        protocols
            .notify(&protos[0], ProtocolEvent::NewInboundSubstream(NodeId::new(), ()))
            .await
            .unwrap();
        let notification = rx.recv().await.unwrap();
        unpack_enum!(ProtocolEvent::NewInboundSubstream(peer_id, _s) = notification.event);
        assert_eq!(peer_id, NodeId::new());
    }

    #[tokio::test]
    async fn notify_fail_not_registered() {
        let mut protocols = Protocols::<()>::new();
        let err = protocols
            .notify(
                &ProtocolId::from_static(b"/tari/test/0"),
                ProtocolEvent::NewInboundSubstream(NodeId::new(), ()),
            )
            .await
            .unwrap_err();
        unpack_enum!(ProtocolError::ProtocolNotRegistered = err);
    }

    #[tokio::test]
    async fn notify_fail_sender_disconnected() {
        let (tx, rx) = mpsc::channel(1);
        // Dropping the receiver closes the channel, so the send in `notify` must fail.
        drop(rx);
        let protos = [ProtocolId::from_static(b"/tari/test/1")];
        let mut protocols = Protocols::<()>::new();
        protocols.add(&protos, &tx);
        let err = protocols
            .notify(&protos[0], ProtocolEvent::NewInboundSubstream(NodeId::new(), ()))
            .await
            .unwrap_err();
        unpack_enum!(ProtocolError::NotificationSenderDisconnected = err);
    }
}