#[cfg(doc)]
use crate::{Config, protocols::Handshake};
use crate::{
Connection,
ConnectionSide,
P2P,
Tcp,
protocols::{ProtocolHandler, ReturnableConnection},
};
use async_trait::async_trait;
use bytes::BytesMut;
use futures_util::StreamExt;
use std::{
io,
net::SocketAddr,
time::{Duration, Instant},
};
use tokio::{
io::AsyncRead,
sync::{mpsc, oneshot},
time::timeout,
};
use tokio_util::codec::{Decoder, FramedRead};
use tracing::*;
/// Enables reading, i.e. decoding and processing of inbound messages, for a node's connections.
///
/// Once [`Reading::enable_reading`] has been called, every connection handed to the reading
/// handler gets two tasks: one that decodes frames with [`Reading::Codec`] and enqueues them,
/// and one that passes the queued messages to [`Reading::process_message`].
#[async_trait]
pub trait Reading: P2P
where
    Self: Clone + Send + Sync + 'static,
{
    /// The capacity of each connection's queue of decoded messages awaiting
    /// [`Reading::process_message`]; messages decoded while the queue is full are dropped.
    ///
    /// Must be non-zero: the queue is a bounded tokio `mpsc` channel, which panics when
    /// created with a zero capacity.
    fn message_queue_depth(&self) -> usize {
        1024
    }

    /// The capacity reserved up front in each connection's read buffer (1 MiB by default);
    /// `0` skips the reservation.
    const INITIAL_BUFFER_SIZE: usize = 1024 * 1024;

    /// How long the reader waits for the next frame before disconnecting the peer as idle.
    const IDLE_TIMEOUT: Duration = Duration::from_secs(150);

    /// The type of the messages produced by [`Reading::Codec`].
    type Message: Send;

    /// The decoder that turns a connection's inbound bytes into [`Reading::Message`]s.
    type Codec: Decoder<Item = Self::Message, Error = io::Error> + Send;

    /// Spawns the task that applies [`Reading`] to every connection sent to it, then registers
    /// the corresponding handler with the node.
    ///
    /// The channel feeding connections to that task has a capacity of `max_connections` from
    /// the node's [`Config`].
    ///
    /// # Panics
    ///
    /// Panics if the [`Reading`] protocol has already been enabled.
    async fn enable_reading(&self) {
        let handler_capacity = self.tcp().config().max_connections as usize;
        let (handler_tx, mut handler_rx) = mpsc::channel(handler_capacity);

        // Spawn the handler task and only continue once it is actually running.
        let (started_tx, started_rx) = oneshot::channel();
        let node = self.clone();
        let handler_task = tokio::spawn(async move {
            trace!(parent: node.tcp().span(), "spawned the Reading handler task");
            started_tx.send(()).unwrap();

            while let Some(new_conn) = handler_rx.recv().await {
                node.handle_new_connection(new_conn).await;
            }
        });
        let _ = started_rx.await;
        self.tcp().tasks.lock().push(handler_task);

        let handler = Box::new(ProtocolHandler(handler_tx));
        let first_registration = self.tcp().protocols.reading.set(handler).is_ok();
        assert!(first_registration, "the Reading protocol was enabled more than once!");
    }

    /// Creates the codec used to decode messages arriving from `addr`.
    ///
    /// This module's connection handler passes the negation of the connection's own `side()`
    /// as `side`.
    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;

    /// Processes a single inbound message from `source`.
    ///
    /// An `Err` is logged and registered as a failure for the peer's IP address. It does not
    /// by itself close the connection.
    async fn process_message(&self, source: SocketAddr, message: Self::Message) -> io::Result<()>;
}
/// Crate-internal extension of [`Reading`] that holds the per-connection machinery; it is
/// implemented for every type implementing [`Reading`].
#[async_trait]
trait ReadingInternal: Reading {
    /// Wraps the decoder of `framed` in a [`CountingCodec`], which reports the size of each
    /// received message to the node's stats and traces under `conn`'s span.
    fn map_codec<T: AsyncRead>(
        &self,
        framed: FramedRead<T, Self::Codec>,
        conn: &Connection,
    ) -> FramedRead<T, CountingCodec<Self::Codec>>;

    /// Applies the [`Reading`] protocol to a single connection and hands it back to its owner.
    async fn handle_new_connection(&self, returnable_conn: ReturnableConnection);
}
#[async_trait]
impl<R: Reading> ReadingInternal for R {
async fn handle_new_connection(&self, (mut conn, conn_returner): ReturnableConnection) {
let addr = conn.addr();
let codec = self.codec(addr, !conn.side());
let reader = conn.reader.take().expect("missing connection reader!");
let framed = FramedRead::new(reader, codec);
let mut framed = self.map_codec(framed, &conn);
let (tx_conn_ready, rx_conn_ready) = oneshot::channel();
conn.readiness_notifier = Some(tx_conn_ready);
if Self::INITIAL_BUFFER_SIZE != 0 {
framed.read_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
}
let (inbound_message_sender, mut inbound_message_receiver) =
mpsc::channel::<(R::Message, QueuedMessageGuard)>(self.message_queue_depth());
let (tx_processing, rx_processing) = oneshot::channel::<()>();
let self_clone = self.clone();
let conn_span = conn.span().clone();
let inbound_processing_task = tokio::spawn(Box::pin(async move {
let node = self_clone.tcp();
trace!(parent: &conn_span, "spawned a task for processing messages");
tx_processing.send(()).unwrap();
while let Some((msg, _guard)) = inbound_message_receiver.recv().await {
if let Err(e) = self_clone.process_message(addr, msg).await {
error!(parent: &conn_span, "can't process a message: {e}");
node.known_peers().register_failure(addr.ip());
}
}
}));
let _ = rx_processing.await;
conn.tasks.push(inbound_processing_task);
let (tx_reader, rx_reader) = oneshot::channel::<()>();
let node = self.tcp().clone();
let conn_span = conn.span().clone();
let reader_task = tokio::spawn(Box::pin(async move {
trace!(parent: &conn_span, "spawned a task for reading messages");
tx_reader.send(()).unwrap();
let _ = rx_conn_ready.await;
let mut dropped_count: usize = 0;
let mut last_drop_log = Instant::now();
loop {
let next_frame_future = framed.next();
let read_result = match timeout(Self::IDLE_TIMEOUT, next_frame_future).await {
Ok(res) => res, Err(_) => {
debug!(parent: &conn_span, "connection timed out due to inactivity");
break;
}
};
match read_result {
Some(Ok(msg)) => {
if let Err(e) = inbound_message_sender.try_send((msg, QueuedMessageGuard::new())) {
node.stats().register_failure();
match e {
mpsc::error::TrySendError::Full(_) => {
dropped_count += 1;
if last_drop_log.elapsed() >= Duration::from_secs(1) {
warn_about_dropped_messages(&conn_span, &mut dropped_count, &mut last_drop_log);
}
}
mpsc::error::TrySendError::Closed(_) => {
error!(parent: &conn_span, "inbound channel closed");
break;
}
}
} else if dropped_count != 0 {
warn_about_dropped_messages(&conn_span, &mut dropped_count, &mut last_drop_log);
debug!(parent: &conn_span, "the inbound queue is no longer saturated");
}
#[cfg(feature = "metrics")]
metrics::increment_gauge(metrics::tcp::TCP_TASKS, 1f64);
}
Some(Err(e)) => {
error!(parent: &conn_span, "can't read: {e}");
node.known_peers().register_failure(addr.ip());
if node.config().fatal_io_errors.contains(&e.kind()) {
break;
}
}
None => break, }
}
let _ = node.disconnect(addr).await;
}));
let _ = rx_reader.await;
conn.tasks.push(reader_task);
if conn_returner.send(Ok(conn)).is_err() {
unreachable!("couldn't return a Connection to the Tcp");
}
}
fn map_codec<T: AsyncRead>(
&self,
framed: FramedRead<T, Self::Codec>,
conn: &Connection,
) -> FramedRead<T, CountingCodec<Self::Codec>> {
framed.map_decoder(|codec| CountingCodec { codec, node: self.tcp().clone(), acc: 0, span: conn.span().clone() })
}
}
struct CountingCodec<D: Decoder> {
codec: D,
node: Tcp,
acc: usize,
span: Span,
}
impl<D: Decoder> Decoder for CountingCodec<D> {
type Error = D::Error;
type Item = D::Item;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let initial_buf_len = src.len();
let ret = self.codec.decode(src)?;
let final_buf_len = src.len();
let read_len = initial_buf_len - final_buf_len + self.acc;
if read_len != 0 {
trace!(parent: &self.span, "read {read_len}B");
if ret.is_some() {
self.acc = 0;
self.node.stats().register_received_message(read_len);
} else {
self.acc = read_len;
}
}
Ok(ret)
}
}
/// A marker that travels with each message placed on a connection's inbound queue.
///
/// With the `metrics` feature enabled, it increments the `TCP_TASKS` gauge when created and
/// decrements it when dropped. Without that feature, it is a zero-sized no-op.
struct QueuedMessageGuard;

impl QueuedMessageGuard {
    /// Creates a guard, incrementing the gauge when metrics are enabled.
    fn new() -> QueuedMessageGuard {
        #[cfg(feature = "metrics")]
        metrics::increment_gauge(metrics::tcp::TCP_TASKS, 1f64);

        QueuedMessageGuard
    }
}

impl Drop for QueuedMessageGuard {
    /// Undoes the increment performed by `new` when metrics are enabled.
    fn drop(&mut self) {
        #[cfg(feature = "metrics")]
        metrics::decrement_gauge(metrics::tcp::TCP_TASKS, 1f64);
    }
}
/// Logs a warning on `span` with the number of inbound messages dropped because the
/// connection's inbound queue was full.
///
/// Afterwards it resets `dropped_count` to zero and sets `last_drop_log` to now. The reader
/// task uses `last_drop_log` to rate-limit these warnings.
fn warn_about_dropped_messages(span: &Span, dropped_count: &mut usize, last_drop_log: &mut Instant) {
    // Kept on one line on purpose: a `\` line continuation inside a string literal also
    // strips the next line's leading whitespace, which would glue "due" and "to" together.
    warn!(parent: span, "dropped {dropped_count} messages due to inbound queue saturation");
    *dropped_count = 0;
    *last_drop_log = Instant::now();
}