use std::{fmt, time::Duration};
use tokio::sync::mpsc;
use tower::Service;
use super::MessagingProtocol;
use crate::{
bounded_executor::BoundedExecutor,
message::InboundMessage,
pipeline,
protocol::{
ProtocolExtension,
ProtocolExtensionContext,
ProtocolExtensionError,
ProtocolId,
messaging::MessagingEventSender,
},
};
/// Capacity of the bounded `mpsc` channel that carries inbound messages from the
/// messaging protocol task to the inbound pipeline task.
pub const INBOUND_MESSAGE_BUFFER_SIZE: usize = 10;
/// Capacity of the bounded `mpsc` channel that is registered with the protocol
/// extension context and read by the messaging protocol task.
pub const MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE: usize = 30;
/// Protocol extension that wires a messaging protocol, together with its inbound
/// and outbound pipelines, into a `ProtocolExtensionContext`.
///
/// Build it with `new`, optionally tweak it with the builder-style methods, then
/// let the extension machinery call `ProtocolExtension::install`.
pub struct MessagingProtocolExtension<TInPipe, TOutPipe, TOutReq> {
// Sender handed to `MessagingProtocol::new` during `install`.
event_tx: MessagingEventSender,
// Inbound/outbound pipeline configuration; consumed piece by piece in `install`.
pipeline: pipeline::Config<TInPipe, TOutPipe, TOutReq>,
// Forwarded to `set_message_received_event_enabled`; `false` unless enabled via the builder.
enable_message_received_event: bool,
// Forwarded to `MessagingProtocol::with_ban_duration`; defaults to 10 minutes.
// NOTE(review): presumably the length of time an offending peer is banned — confirm in `MessagingProtocol`.
ban_duration: Duration,
// Identifier under which the protocol is registered with the context.
protocol_id: ProtocolId,
}
impl<TInPipe, TOutPipe, TOutReq> MessagingProtocolExtension<TInPipe, TOutPipe, TOutReq> {
    /// Creates an extension for `protocol_id` that uses `event_tx` and the given
    /// `pipeline` configuration.
    ///
    /// Defaults: the message-received event is disabled and the ban duration is
    /// ten minutes. Use the builder methods below to change either.
    pub fn new(
        protocol_id: ProtocolId,
        event_tx: MessagingEventSender,
        pipeline: pipeline::Config<TInPipe, TOutPipe, TOutReq>,
    ) -> Self {
        let ten_minutes = Duration::from_secs(10 * 60);
        Self {
            event_tx,
            pipeline,
            enable_message_received_event: false,
            ban_duration: ten_minutes,
            protocol_id,
        }
    }

    /// Returns the extension with the message-received event switched on.
    pub fn enable_message_received_event(self) -> Self {
        Self {
            enable_message_received_event: true,
            ..self
        }
    }

    /// Returns the extension with its ban duration replaced by `ban_duration`.
    pub fn with_ban_duration(self, ban_duration: Duration) -> Self {
        Self { ban_duration, ..self }
    }
}
impl<TInPipe, TOutPipe, TOutReq> ProtocolExtension for MessagingProtocolExtension<TInPipe, TOutPipe, TOutReq>
where
    TOutPipe: Service<TOutReq, Response = ()> + Clone + Send + 'static,
    TOutPipe::Error: fmt::Display + Send,
    TOutPipe::Future: Send + 'static,
    TInPipe: Service<InboundMessage> + Clone + Send + 'static,
    TInPipe::Error: fmt::Display + Send,
    TInPipe::Future: Send + 'static,
    TOutReq: Send + 'static,
{
    /// Installs the messaging protocol and its pipelines into `context`.
    ///
    /// In order, this:
    /// 1. registers `protocol_id` with the context, backed by a channel of
    ///    capacity `MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE`;
    /// 2. builds a `MessagingProtocol`, registers its completion signal with the
    ///    context and spawns it;
    /// 3. spawns the inbound pipeline, fed through a channel of capacity
    ///    `INBOUND_MESSAGE_BUFFER_SIZE`;
    /// 4. spawns the outbound pipeline.
    ///
    /// # Errors
    /// This implementation currently always returns `Ok(())`.
    ///
    /// # Panics
    /// - If `pipeline.outbound.out_receiver` is `None`, i.e. the outbound
    ///   receiver has already been taken out of the pipeline configuration.
    /// - `tokio::spawn` panics when called outside of a Tokio runtime.
    fn install(mut self: Box<Self>, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> {
        // Channel registered with the context for this protocol id; the messaging
        // task reads from the receiving half.
        let (proto_tx, proto_rx) = mpsc::channel(MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE);
        context.add_protocol(&[self.protocol_id.clone()], &proto_tx);

        // Channel linking the messaging task (sender) to the inbound pipeline (receiver).
        let (inbound_message_tx, inbound_message_rx) = mpsc::channel(INBOUND_MESSAGE_BUFFER_SIZE);

        // The messaging task takes ownership of the outbound receiver, so it is
        // removed from the pipeline config before the config is handed to
        // `pipeline::Outbound` below. A missing receiver is a wiring bug.
        let message_receiver = self
            .pipeline
            .outbound
            .out_receiver
            .take()
            .expect("pipeline.outbound.out_receiver must be present when installing MessagingProtocolExtension");

        let messaging = MessagingProtocol::new(
            self.protocol_id.clone(),
            context.connectivity(),
            proto_rx,
            message_receiver,
            self.event_tx,
            inbound_message_tx,
            context.shutdown_signal(),
        )
        .set_message_received_event_enabled(self.enable_message_received_event)
        .with_ban_duration(self.ban_duration);

        // Register the completion signal before the task is moved into `spawn`.
        context.register_complete_signal(messaging.complete_signal());
        tokio::spawn(messaging.run());

        // Inbound pipeline, bounded by the configured number of concurrent tasks.
        let bounded_executor = BoundedExecutor::new(self.pipeline.max_concurrent_inbound_tasks);
        let inbound = pipeline::Inbound::new(
            bounded_executor,
            inbound_message_rx,
            self.pipeline.inbound,
            context.shutdown_signal(),
        );
        tokio::spawn(inbound.run());

        // Outbound pipeline; with no configured limit, fall back to the executor's
        // theoretical maximum.
        let executor = BoundedExecutor::new(
            self.pipeline
                .max_concurrent_outbound_tasks
                .unwrap_or_else(BoundedExecutor::max_theoretical_tasks),
        );
        let outbound = pipeline::Outbound::new(executor, self.pipeline.outbound);
        tokio::spawn(outbound.run());

        Ok(())
    }
}