use std::io;
use futures::{SinkExt, StreamExt, future, future::Either};
use log::*;
use tari_shutdown::ShutdownSignal;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{broadcast, mpsc},
};
#[cfg(feature = "metrics")]
use super::metrics;
use super::{MessagingEvent, MessagingProtocol};
use crate::{PeerConnection, message::InboundMessage};
const LOG_TARGET: &str = "comms::protocol::messaging::inbound";
/// Handler for the inbound side of the messaging protocol on a single peer connection.
///
/// Constructed with [`InboundMessaging::new`] and driven to completion by
/// [`InboundMessaging::run`].
pub struct InboundMessaging {
    /// Connection to the remote peer. `run` reads the peer's node id from it and stops reading
    /// once its disconnect notification (`on_disconnect`) resolves.
    connection: PeerConnection,
    /// Destination for every frame successfully read from the substream, wrapped as an
    /// [`InboundMessage`]. If this channel is closed, `run` stops.
    inbound_message_tx: mpsc::Sender<InboundMessage>,
    /// Broadcast channel used for `MessageReceived`, `ProtocolViolation` and
    /// `InboundProtocolExited` events.
    messaging_events_tx: broadcast::Sender<MessagingEvent>,
    /// When `true`, a `MessageReceived` event is broadcast after each forwarded message.
    enable_message_received_event: bool,
    /// Ends the read loop in `run` when triggered.
    shutdown_signal: ShutdownSignal,
}
impl InboundMessaging {
    /// Builds a handler for the inbound messaging substream of `connection`.
    ///
    /// The arguments are only stored; no I/O happens until [`InboundMessaging::run`] is awaited.
    pub fn new(
        connection: PeerConnection,
        inbound_message_tx: mpsc::Sender<InboundMessage>,
        messaging_events_tx: broadcast::Sender<MessagingEvent>,
        enable_message_received_event: bool,
        shutdown_signal: ShutdownSignal,
    ) -> Self {
        Self {
            shutdown_signal,
            enable_message_received_event,
            messaging_events_tx,
            inbound_message_tx,
            connection,
        }
    }

    /// Reads framed messages from `socket` and forwards them until something stops the session.
    ///
    /// Each frame is wrapped in an [`InboundMessage`] and sent on the inbound message channel.
    /// If `enable_message_received_event` is set, a [`MessagingEvent::MessageReceived`] event is
    /// then broadcast.
    ///
    /// Reading stops when any of the following happens:
    /// - the shutdown signal is triggered;
    /// - the framed stream ends (including when the peer connection reports a disconnect);
    /// - the inbound message channel is closed;
    /// - the stream yields an error. `io::ErrorKind::InvalidData` errors are broadcast as
    ///   [`MessagingEvent::ProtocolViolation`]; any other error is only logged.
    ///
    /// On exit the stream is closed (best effort) and
    /// [`MessagingEvent::InboundProtocolExited`] is broadcast.
    pub async fn run<S>(mut self, socket: S)
    where S: AsyncRead + AsyncWrite + Unpin {
        let peer = self.connection.peer_node_id();

        #[cfg(feature = "metrics")]
        metrics::num_sessions().inc();

        debug!(
            target: LOG_TARGET,
            "Starting inbound messaging protocol for peer '{}'",
            peer.short_str()
        );

        // The framed stream terminates as soon as the connection's disconnect future resolves.
        let framed = MessagingProtocol::framed(socket).take_until(self.connection.on_disconnect());
        tokio::pin!(framed);

        loop {
            let received = match future::select(self.shutdown_signal.wait(), framed.next()).await {
                Either::Right((Some(item), _)) => item,
                // Shutdown was signalled, or the framed stream has no more items.
                _ => break,
            };

            let raw_msg = match received {
                Ok(frame) => frame,
                Err(err) => {
                    #[cfg(feature = "metrics")]
                    metrics::error_count().inc();

                    if err.kind() == io::ErrorKind::InvalidData {
                        // InvalidData is treated as the peer violating the protocol.
                        debug!(
                            target: LOG_TARGET,
                            "Failed to receive from peer '{}' because '{}'",
                            peer.short_str(),
                            err
                        );
                        let _ = self.messaging_events_tx.send(MessagingEvent::ProtocolViolation {
                            peer_node_id: peer.clone(),
                            details: err.to_string(),
                        });
                    } else {
                        error!(
                            target: LOG_TARGET,
                            "Failed to receive from peer '{}' because '{}'",
                            peer.short_str(),
                            err
                        );
                    }
                    break;
                },
            };

            #[cfg(feature = "metrics")]
            metrics::inbound_message_count().inc();

            let num_bytes = raw_msg.len();
            let message = InboundMessage::new(peer.clone(), raw_msg.freeze());
            // Copy the tag out before `message` is moved into the channel.
            let tag = message.tag;
            debug!(
                target: LOG_TARGET,
                "Received message {} from peer '{}' ({} bytes)",
                tag,
                peer.short_str(),
                num_bytes
            );

            if self.inbound_message_tx.send(message).await.is_err() {
                warn!(
                    target: LOG_TARGET,
                    "Failed to send InboundMessage {} for peer '{}' because inbound message channel closed",
                    tag,
                    peer.short_str(),
                );
                break;
            }

            if self.enable_message_received_event {
                // Ignored: a broadcast send only fails when there are no active receivers.
                let _ = self
                    .messaging_events_tx
                    .send(MessagingEvent::MessageReceived(peer.clone(), tag));
            }
        }

        // Best-effort close; a failure to close is deliberately not reported.
        let _ = framed.close().await;
        let _ = self
            .messaging_events_tx
            .send(MessagingEvent::InboundProtocolExited(peer.clone()));

        #[cfg(feature = "metrics")]
        metrics::num_sessions().dec();

        debug!(
            target: LOG_TARGET,
            "Inbound messaging handler exited for peer `{}`",
            peer.short_str()
        );
    }
}