use std::{future::Future, io, net::SocketAddr, sync::Arc};
use bytes::BytesMut;
use futures_util::StreamExt;
use tokio::{
io::AsyncRead,
sync::{mpsc, oneshot},
};
use tokio_util::codec::{Decoder, FramedRead};
use tracing::*;
#[cfg(doc)]
use crate::{Config, protocols::Handshake};
use crate::{
ConnectionInfo, ConnectionSide, Node, Pea2Pea, Stats,
node::NodeTask,
protocols::{ProtocolHandler, ReturnableConnection},
};
pub trait Reading: Pea2Pea
where
Self: Clone + Send + Sync + 'static,
{
const MESSAGE_QUEUE_DEPTH: usize = 64;
const BACKPRESSURE: bool = true;
const INITIAL_BUFFER_SIZE: usize = 64 * 1024;
type Message: Send;
type Codec: Decoder<Item = Self::Message, Error = io::Error> + Send;
fn enable_reading(&self) -> impl Future<Output = ()> {
async {
let (conn_sender, mut conn_receiver) = mpsc::unbounded_channel();
let (tx_reading, rx_reading) = oneshot::channel();
let self_clone = self.clone();
let reading_task = tokio::spawn(async move {
trace!(parent: self_clone.node().span(), "spawned the Reading handler task");
if tx_reading.send(()).is_err() {
error!(parent: self_clone.node().span(), "Reading handler creation interrupted! shutting down the node");
self_clone.node().shut_down().await;
return;
}
while let Some(returnable_conn) = conn_receiver.recv().await {
self_clone.handle_new_connection(returnable_conn).await;
}
});
let _ = rx_reading.await;
self.node()
.tasks
.lock()
.insert(NodeTask::Reading, reading_task);
let hdl = ProtocolHandler(conn_sender);
assert!(
self.node().protocols.reading.set(hdl).is_ok(),
"the Reading protocol was enabled more than once!"
);
}
}
fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
fn process_message(
&self,
source: SocketAddr,
message: Self::Message,
) -> impl Future<Output = ()> + Send;
}
/// Internal plumbing of the [`Reading`] protocol; implemented for every [`Reading`] type.
trait ReadingInternal: Reading {
    /// Replaces the decoder inside `framed` with a [`CountingCodec`] wrapping it, using the
    /// address and stats taken from `info`.
    fn map_codec<T>(
        &self,
        framed: FramedRead<T, Self::Codec>,
        info: &ConnectionInfo,
    ) -> FramedRead<T, CountingCodec<Self::Codec>>
    where
        T: AsyncRead;

    /// Prepares a single connection for reading and hands it back through the returner that
    /// accompanies it.
    fn handle_new_connection(
        &self,
        conn_with_returner: ReturnableConnection,
    ) -> impl Future<Output = ()> + Send;
}
/// Blanket implementation of the internal Reading plumbing for every [`Reading`] type.
impl<R: Reading> ReadingInternal for R {
/// Sets up inbound message handling for one connection and returns the connection through
/// `conn_returner`.
///
/// Two tasks are spawned and pushed onto `conn.tasks`: one that feeds queued messages to
/// [`Reading::process_message`], and one that decodes messages from the connection's reader
/// and places them in that queue. Each task's start is awaited before moving on.
async fn handle_new_connection(&self, (mut conn, conn_returner): ReturnableConnection) {
let addr = conn.addr();
// The codec is requested for the negation of the side reported by the connection.
let codec = self.codec(addr, !conn.side());
// Without a reader there is nothing to decode; `conn_returner` is dropped here without
// anything having been sent through it.
let Some(reader) = conn.reader.take() else {
// NOTE(review): unlike the other log calls in this block, this one is not attached to the
// node's span.
error!("The stream was not returned during the handshake with {addr}!");
return;
};
let framed = FramedRead::with_capacity(reader, codec, Self::INITIAL_BUFFER_SIZE);
// Wrap the user codec so that the received bytes get counted.
let mut framed = self.map_codec(framed, conn.info());
// The reader task does not start reading until this channel resolves; the sending half is
// stored in the connection.
// NOTE(review): presumably fired once the connection is fully established — confirm at the
// site that consumes `readiness_notifier`.
let (tx_conn_ready, rx_conn_ready) = oneshot::channel();
conn.readiness_notifier = Some(tx_conn_ready);
// Bounded queue between the reader task (producer) and the processing task (consumer).
let (inbound_message_sender, mut inbound_message_receiver) =
mpsc::channel(Self::MESSAGE_QUEUE_DEPTH);
// Used to learn that the processing task has started.
let (tx_processing, rx_processing) = oneshot::channel::<()>();
let self_clone = self.clone();
// The task that processes decoded messages, one at a time.
let inbound_processing_task = tokio::spawn(Box::pin(async move {
let node = self_clone.node();
trace!(parent: node.span(), "spawned a task for processing messages from {addr}");
if tx_processing.send(()).is_err() {
error!(parent: node.span(), "Reading (processing) for {addr} was interrupted; shutting down its task");
return;
}
// Runs until the queue is closed and drained.
while let Some(msg) = inbound_message_receiver.recv().await {
self_clone.process_message(addr, msg).await;
}
}));
// The result is ignored: a dropped signal is handled inside the task itself.
let _ = rx_processing.await;
conn.tasks.push(inbound_processing_task);
// Used to learn that the reader task has started.
let (tx_reader, rx_reader) = oneshot::channel::<()>();
let node = self.node().clone();
// The task that decodes messages from the stream and queues them for processing.
let reader_task = tokio::spawn(Box::pin(async move {
trace!(parent: node.span(), "spawned a task for reading messages from {addr}");
if tx_reader.send(()).is_err() {
error!(parent: node.span(), "Reading (IO) for {addr} was interrupted; shutting down its task");
return;
}
// Hold off reading until the readiness channel resolves; the outcome is ignored, so reading
// also proceeds if the sending half is dropped.
let _ = rx_conn_ready.await;
while let Some(bytes) = framed.next().await {
match bytes {
Ok(msg) => {
match Self::BACKPRESSURE {
true => {
// Wait for free queue capacity; sending fails only if the receiving side is gone,
// in which case reading stops.
if let Err(e) = inbound_message_sender.send(msg).await {
error!(parent: node.span(), "can't process a message from {addr}: {e}");
break;
}
}
false => {
// Never wait: if the queue is full the message is dropped and reading goes on;
// reading stops only if the queue has been closed.
if let Err(e) = inbound_message_sender.try_send(msg) {
error!(parent: node.span(), "can't process a message from {addr}: {e}");
if matches!(e, mpsc::error::TrySendError::Closed(_)) {
break;
}
}
}
}
}
Err(e) => {
// A read/decode error is only logged; the loop is not left here.
// NOTE(review): termination after an error relies on the stream yielding `None`
// afterwards — confirm for the tokio-util version in use.
error!(parent: node.span(), "can't read from {addr}: {e}");
}
}
}
// The stream is finished or the queue is gone: drop the connection. The result of
// the disconnect is ignored.
let _ = node.disconnect(addr).await;
}));
// The result is ignored: a dropped signal is handled inside the task itself.
let _ = rx_reader.await;
conn.tasks.push(reader_task);
// Hand the connection, now carrying both tasks, back to the party that sent it.
if conn_returner.send(Ok(conn)).is_err() {
error!(parent: self.node().span(), "couldn't return a Connection with {addr} from the Reading handler");
}
}
/// Swaps the decoder of `framed` for a [`CountingCodec`] that wraps the original decoder,
/// starting with an empty byte tally and using the address and stats found in `info`.
fn map_codec<T: AsyncRead>(
&self,
framed: FramedRead<T, Self::Codec>,
info: &ConnectionInfo,
) -> FramedRead<T, CountingCodec<Self::Codec>> {
framed.map_decoder(|codec| CountingCodec {
codec,
node: self.node().clone(),
addr: info.addr(),
stats: info.stats().clone(),
acc: 0,
})
}
}
/// A [`Decoder`] wrapper that keeps track of how many bytes each decoded message took up.
struct CountingCodec<D>
where
    D: Decoder,
{
    /// The wrapped decoder doing the actual decoding.
    codec: D,
    /// The node whose span is used for logging and whose stats are updated.
    node: Node,
    /// The address of the peer the bytes come from.
    addr: SocketAddr,
    /// The stats of the connection the bytes come from.
    stats: Arc<Stats>,
    /// The number of bytes consumed so far for a message that has not been completed yet.
    acc: usize,
}
impl<D: Decoder> Decoder for CountingCodec<D> {
type Item = D::Item;
type Error = D::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let initial_buf_len = src.len();
let ret = self.codec.decode(src)?;
let final_buf_len = src.len();
let read_len = initial_buf_len - final_buf_len + self.acc;
if read_len != 0 {
trace!(parent: self.node.span(), "read {read_len}B from {}", self.addr);
if ret.is_some() {
self.acc = 0;
self.stats.register_received_message(read_len);
self.node.stats().register_received_message(read_len);
} else {
self.acc = read_len;
}
}
Ok(ret)
}
}