use std::sync::Arc;
use tari_shutdown::ShutdownSignal;
use crate::{
PeerManager,
Substream,
connectivity::ConnectivityRequester,
protocol::{ProtocolId, ProtocolNotificationTx, Protocols},
};
/// Error produced when a [`ProtocolExtension`] fails to install.
pub type ProtocolExtensionError = anyhow::Error;

/// A one-shot installer that configures a [`ProtocolExtensionContext`].
///
/// Implementors must be `Send`. Installation consumes the boxed extension, so each extension is
/// installed at most once.
pub trait ProtocolExtension: Send {
    /// Consumes this extension and applies it to `context`.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolExtensionError`] if the extension could not be installed.
    fn install(
        self: Box<Self>,
        context: &mut ProtocolExtensionContext,
    ) -> Result<(), ProtocolExtensionError>;
}
/// Any `FnOnce(&mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError>` closure is a
/// protocol extension.
///
/// Only `Send` is required, to match the `ProtocolExtension: Send` supertrait. Installation
/// consumes the boxed closure and never shares it between threads, so a `Sync` bound would only
/// reject closures that capture `!Sync` state (e.g. a `Cell` or `std::sync::mpsc::Receiver`).
impl<F> ProtocolExtension for F
where F: FnOnce(&mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> + Send
{
    fn install(self: Box<Self>, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> {
        // `Box<F>` is itself `FnOnce` when `F: FnOnce`, so the boxed closure can be called directly.
        (self)(context)
    }
}
/// An ordered collection of boxed [`ProtocolExtension`]s waiting to be installed.
///
/// `Default` yields an empty collection, the same as `ProtocolExtensions::new`.
#[derive(Default)]
pub struct ProtocolExtensions {
    // Extensions in the order they were added; `install_all` installs them in this order.
    inner: Vec<Box<dyn ProtocolExtension>>,
}
impl ProtocolExtensions {
    /// Returns an empty set of extensions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of extensions currently held.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when no extension has been added.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Boxes `ext` and appends it after any previously added extensions.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn add<T: ProtocolExtension + 'static>(&mut self, ext: T) -> &mut Self {
        let boxed: Box<dyn ProtocolExtension> = Box::new(ext);
        self.inner.push(boxed);
        self
    }

    /// Installs every held extension into `context`, in insertion order.
    ///
    /// # Errors
    ///
    /// Stops at the first extension whose `install` fails and returns that error. Extensions
    /// after it are not installed.
    pub(crate) fn install_all(self, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> {
        self.inner
            .into_iter()
            .try_for_each(|extension| extension.install(context))
    }
}
impl Extend<Box<dyn ProtocolExtension>> for ProtocolExtensions {
    /// Appends each already-boxed extension yielded by `iter`, keeping their order.
    fn extend<I>(&mut self, iter: I)
    where I: IntoIterator<Item = Box<dyn ProtocolExtension>> {
        self.inner.extend(iter);
    }
}
impl From<Protocols<Substream>> for ProtocolExtensions {
    /// Builds a collection whose only extension is the given `Protocols` registry.
    fn from(protocols: Protocols<Substream>) -> Self {
        let mut extensions = Self::default();
        extensions.add(protocols);
        extensions
    }
}
impl IntoIterator for ProtocolExtensions {
    type Item = Box<dyn ProtocolExtension>;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Consumes the collection and yields the boxed extensions in insertion order.
    fn into_iter(self) -> Self::IntoIter {
        let Self { inner } = self;
        inner.into_iter()
    }
}
/// State handed to each [`ProtocolExtension`] while it is being installed.
///
/// Extensions may register protocols and completion signals here, and obtain handles to
/// connectivity, the peer manager and the shutdown signal.
pub struct ProtocolExtensionContext {
    // Handle returned (cloned) by `connectivity()`.
    connectivity: ConnectivityRequester,
    // Shared peer manager; `peer_manager()` hands out clones of this `Arc`.
    peer_manager: Arc<PeerManager>,
    // `Some` until `take_protocols` moves it out; `add_protocol` panics once it is `None`.
    protocols: Option<Protocols<Substream>>,
    // Signals registered via `register_complete_signal`, later collected by `drain_complete_signals`.
    complete_signals: Vec<ShutdownSignal>,
    // Signal returned (cloned) by `shutdown_signal()`.
    shutdown_signal: ShutdownSignal,
}
impl ProtocolExtensionContext {
    /// Creates a context with an empty `Protocols` registry and no completion signals.
    pub(crate) fn new(
        connectivity: ConnectivityRequester,
        peer_manager: Arc<PeerManager>,
        shutdown_signal: ShutdownSignal,
    ) -> Self {
        Self {
            connectivity,
            peer_manager,
            protocols: Some(Protocols::new()),
            complete_signals: Vec::new(),
            shutdown_signal,
        }
    }

    /// Forwards `protocols` and `notifier` to the underlying `Protocols` registry's `add`.
    ///
    /// Returns `&mut Self` so calls can be chained.
    ///
    /// # Panics
    ///
    /// Panics if the registry has already been moved out via `take_protocols`.
    pub fn add_protocol<I: AsRef<[ProtocolId]>>(
        &mut self,
        protocols: I,
        notifier: &ProtocolNotificationTx<Substream>,
    ) -> &mut Self {
        self.protocols
            .as_mut()
            .expect("ProtocolExtensionContext::protocols taken!")
            .add(protocols, notifier);
        self
    }

    /// Records `signal` so it is included in the next `drain_complete_signals` call.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn register_complete_signal(&mut self, signal: ShutdownSignal) -> &mut Self {
        self.complete_signals.push(signal);
        self
    }

    /// Returns a clone of the connectivity requester.
    pub fn connectivity(&self) -> ConnectivityRequester {
        self.connectivity.clone()
    }

    /// Returns a new `Arc` handle to the shared peer manager.
    pub fn peer_manager(&self) -> Arc<PeerManager> {
        Arc::clone(&self.peer_manager)
    }

    /// Returns a clone of the shutdown signal.
    pub fn shutdown_signal(&self) -> ShutdownSignal {
        self.shutdown_signal.clone()
    }

    /// Removes and returns all registered completion signals, in registration order.
    ///
    /// The context is left with an empty list afterwards.
    pub(crate) fn drain_complete_signals(&mut self) -> Vec<ShutdownSignal> {
        // Move the Vec out wholesale instead of `drain(..).collect()`, which would allocate a new
        // Vec and copy every element across.
        std::mem::take(&mut self.complete_signals)
    }

    /// Moves the `Protocols` registry out, returning `None` if it was already taken.
    ///
    /// After this call, `add_protocol` will panic.
    pub(crate) fn take_protocols(&mut self) -> Option<Protocols<Substream>> {
        self.protocols.take()
    }
}