use std::{any::Any, collections::HashMap, io, net::SocketAddr, sync::Arc, time::Duration};
use async_trait::async_trait;
use futures_util::sink::SinkExt;
#[cfg(feature = "locktick")]
use locktick::parking_lot::RwLock;
#[cfg(not(feature = "locktick"))]
use parking_lot::RwLock;
use tokio::{
io::AsyncWrite,
sync::{mpsc, oneshot},
time::timeout,
};
use tokio_util::codec::{Encoder, FramedWrite};
use tracing::*;
#[cfg(doc)]
use crate::{Config, Tcp, protocols::Handshake};
use crate::{
Connection,
ConnectionSide,
P2P,
connections::create_connection_span,
protocols::{Protocol, ProtocolHandler, ReturnableConnection},
};
/// Outbound message senders of every connection that currently has a writer task,
/// keyed by the connection's address.
type WritingSenders =
    Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
/// Enables a node to send outbound messages to its connections.
///
/// Once [`Writing::enable_writing`] has been called, every new connection gets a dedicated
/// writer task that takes messages from a bounded queue, encodes them with the codec returned
/// by [`Writing::codec`] and writes them to the connection's stream. Messages are queued with
/// [`Writing::unicast`] or [`Writing::broadcast`].
#[async_trait]
pub trait Writing: P2P
where
    Self: Clone + Send + Sync + 'static,
{
    /// The capacity of each connection's queue of outbound messages.
    ///
    /// Messages are queued without waiting, so a message sent while the queue is full is
    /// rejected. Should be greater than zero. Defaults to 1024.
    fn message_queue_depth(&self) -> usize {
        1024
    }

    /// The time limit for writing a single message to a connection's stream; exceeding it
    /// fails that write with [`io::ErrorKind::TimedOut`].
    const TIMEOUT: Duration = Duration::from_secs(5);

    /// The type of the outbound messages.
    type Message: Send;

    /// The encoder used to turn [`Writing::Message`]s into bytes.
    type Codec: Encoder<Self::Message, Error = io::Error> + Send;

    /// Enables the Writing protocol.
    ///
    /// Spawns the task that sets up a writer for each new connection, waits until that task
    /// is running and registers the protocol with the node's [`Tcp`].
    ///
    /// # Panics
    ///
    /// Panics if the Writing protocol was already enabled.
    async fn enable_writing(&self) {
        // `mpsc::channel` panics on a zero capacity; guard against `max_connections == 0`.
        let conn_queue_depth = (self.tcp().config().max_connections as usize).max(1);
        let (conn_sender, mut conn_receiver) = mpsc::channel(conn_queue_depth);

        let conn_senders: WritingSenders = Default::default();
        let senders = conn_senders.clone();

        let (tx_writing, rx_writing) = oneshot::channel();
        let self_clone = self.clone();
        let writing_task = tokio::spawn(async move {
            trace!(parent: self_clone.tcp().span(), "spawned the Writing handler task");
            // The receiver is only gone if `enable_writing` itself was cancelled; there is
            // nobody to notify then, so don't panic the task over it.
            let _ = tx_writing.send(());

            while let Some(returnable_conn) = conn_receiver.recv().await {
                self_clone.handle_new_connection(returnable_conn, &conn_senders).await;
            }
        });
        let _ = rx_writing.await;
        self.tcp().tasks.lock().push(writing_task);

        let hdl = Box::new(WritingHandler { handler: ProtocolHandler(conn_sender), senders });
        assert!(self.tcp().protocols.writing.set(hdl).is_ok(), "the Writing protocol was enabled more than once!");
    }

    /// Creates the codec used to encode outbound messages for the connection with `addr`.
    ///
    /// It is called once per connection; `side` is the negation of that connection's own
    /// `side()` value.
    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;

    /// Queues `message` to be sent to the connection with `addr`.
    ///
    /// On success, returns a receiver that yields the result of writing the message once the
    /// connection's writer task has processed it.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::Unsupported`] if the Writing protocol is not enabled,
    /// - [`io::ErrorKind::NotConnected`] if there is no writer for `addr`,
    /// - [`io::ErrorKind::Other`] if the message could not be queued (e.g. the queue is full
    ///   or the writer is gone); the error carries a description of the cause.
    fn unicast(&self, addr: SocketAddr, message: Self::Message) -> io::Result<oneshot::Receiver<io::Result<()>>> {
        let Some(handler) = self.tcp().protocols.writing.get() else {
            return Err(io::ErrorKind::Unsupported.into());
        };
        // Clone the sender out so the read lock is released before queueing and logging.
        let Some(sender) = handler.senders.read().get(&addr).cloned() else {
            return Err(io::ErrorKind::NotConnected.into());
        };

        let (msg, delivery) = WrappedMessage::new(Box::new(message));
        sender.try_send(msg).map(|()| delivery).map_err(|e| {
            let conn_span = create_connection_span(addr, self.tcp().span());
            error!(parent: conn_span, "can't send a message: {e}");
            self.tcp().stats().register_failure();
            io::Error::other(e.to_string())
        })
    }

    /// Queues a clone of `message` for every connection that has a writer.
    ///
    /// Failures to queue for individual connections are logged and counted, but do not
    /// make the call fail.
    ///
    /// # Errors
    ///
    /// [`io::ErrorKind::Unsupported`] if the Writing protocol is not enabled.
    fn broadcast(&self, message: Self::Message) -> io::Result<()>
    where
        Self::Message: Clone,
    {
        let Some(handler) = self.tcp().protocols.writing.get() else {
            return Err(io::ErrorKind::Unsupported.into());
        };

        // Snapshot the senders so the lock isn't held while queueing and logging.
        let senders = handler.senders.read().clone();
        for (addr, message_sender) in senders {
            let (msg, _delivery) = WrappedMessage::new(Box::new(message.clone()));
            if let Err(e) = message_sender.try_send(msg) {
                let conn_span = create_connection_span(addr, self.tcp().span());
                error!(parent: conn_span, "can't send a message: {e}");
                self.tcp().stats().register_failure();
            }
        }

        Ok(())
    }
}
/// Crate-internal machinery behind [`Writing`], implemented for every [`Writing`] type.
#[async_trait]
trait WritingInternal: Writing {
    /// Sets up the writer task for a newly established connection and hands the connection
    /// back through the returner contained in `returnable_conn`.
    async fn handle_new_connection(&self, returnable_conn: ReturnableConnection, conn_senders: &WritingSenders);

    /// Encodes `message` into `writer` and flushes it, returning the number of bytes that
    /// were buffered for the write.
    async fn write_to_stream<S: AsyncWrite + Unpin + Send>(
        &self,
        message: Self::Message,
        writer: &mut FramedWrite<S, Self::Codec>,
    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;
}
#[async_trait]
impl<W: Writing> WritingInternal for W {
    /// Encodes `message` into `writer` and flushes it, all within [`Writing::TIMEOUT`].
    ///
    /// Returns the number of bytes buffered for the write, or an error if encoding or
    /// writing fails or the time limit is exceeded ([`io::ErrorKind::TimedOut`]).
    async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
        &self,
        message: Self::Message,
        writer: &mut FramedWrite<A, Self::Codec>,
    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
        // Bound the whole write, not just the flush: `feed` may block too, since the sink
        // flushes first when its buffer is over the backpressure boundary (e.g. because of
        // leftovers from a previous write that timed out).
        let write = async move {
            writer.feed(message).await?;
            let len = writer.write_buffer().len();
            writer.flush().await?;
            Ok::<_, io::Error>(len)
        };

        match timeout(W::TIMEOUT, write).await {
            Ok(result) => result,
            Err(_elapsed) => Err(io::Error::new(io::ErrorKind::TimedOut, "write timed out")),
        }
    }

    /// Registers an outbound message queue for the connection, spawns its writer task and
    /// returns the connection (with the task attached) through `conn_returner`.
    ///
    /// The writer task writes queued messages until the queue closes or a write fails with
    /// an error kind listed in the config's `fatal_io_errors`, then disconnects `addr`.
    async fn handle_new_connection(
        &self,
        (mut conn, conn_returner): ReturnableConnection,
        conn_senders: &WritingSenders,
    ) {
        let addr = conn.addr();
        let codec = self.codec(addr, !conn.side());
        let writer = conn.writer.take().expect("missing connection writer!");
        let mut framed = FramedWrite::new(writer, codec);

        // `mpsc::channel` panics on a zero capacity, so clamp a misconfigured queue depth.
        let queue_depth = self.message_queue_depth().max(1);
        let (outbound_message_sender, mut outbound_message_receiver) = mpsc::channel(queue_depth);
        conn_senders.write().insert(addr, outbound_message_sender);
        // Unregisters the sender from `conn_senders` when dropped together with the task.
        let auto_cleanup = SenderCleanup { addr, senders: Arc::clone(conn_senders) };

        let (tx_writer, rx_writer) = oneshot::channel();
        let self_clone = self.clone();
        let conn_span = conn.span().clone();
        let writer_task = tokio::spawn(Box::pin(async move {
            let node = self_clone.tcp();
            trace!(parent: &conn_span, "spawned a task for writing messages");
            // Don't panic the writer over a caller that stopped waiting for this signal.
            let _ = tx_writer.send(());
            // Tie the cleanup guard to the lifetime of this task.
            let _auto_cleanup = auto_cleanup;

            while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
                // Messages are only enqueued by `unicast`/`broadcast`, which box a `W::Message`.
                let msg = wrapped_msg
                    .msg
                    .downcast::<W::Message>()
                    .expect("an outbound message should always be a Writing::Message");

                match self_clone.write_to_stream(*msg, &mut framed).await {
                    Ok(len) => {
                        let _ = wrapped_msg.delivery_notification.send(Ok(()));
                        node.stats().register_sent_message(len);
                        trace!(parent: &conn_span, "sent {len}B");
                    }
                    Err(e) => {
                        node.known_peers().register_failure(addr.ip());
                        error!(parent: &conn_span, "couldn't send a message: {e}");
                        let is_fatal = node.config().fatal_io_errors.contains(&e.kind());
                        let _ = wrapped_msg.delivery_notification.send(Err(e));
                        if is_fatal {
                            break;
                        }
                    }
                }
            }

            node.disconnect(addr).await;
        }));

        let _ = rx_writer.await;
        conn.tasks.push(writer_task);

        if conn_returner.send(Ok(conn)).is_err() {
            unreachable!("couldn't return a Connection to the Tcp");
        }
    }
}
/// A type-erased outbound message, stored so that the sender map does not need to be generic
/// over the message type, together with the channel used to report the outcome of its write.
struct WrappedMessage {
    /// The boxed message, in practice the `Writing::Message` passed to `unicast`/`broadcast`.
    msg: Box<dyn Any + Send>,
    /// Receives the result of writing `msg` to the stream.
    delivery_notification: oneshot::Sender<io::Result<()>>,
}

impl WrappedMessage {
    /// Wraps `msg`, returning the wrapper together with the receiving end of its delivery
    /// notification channel.
    fn new(msg: Box<dyn Any + Send>) -> (Self, oneshot::Receiver<io::Result<()>>) {
        let (delivery_notification, delivery) = oneshot::channel();
        (Self { msg, delivery_notification }, delivery)
    }
}
/// The registered state of the Writing protocol: the handler used to pass new connections
/// to the Writing handler task, and the per-connection outbound message senders.
pub(crate) struct WritingHandler {
    handler: ProtocolHandler<Connection, io::Result<Connection>>,
    senders: WritingSenders,
}

impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
    /// Delegates the new connection to the inner [`ProtocolHandler`].
    async fn trigger(&self, returnable_conn: ReturnableConnection) {
        let inner = &self.handler;
        inner.trigger(returnable_conn).await;
    }
}
/// Drop guard that removes a connection's outbound message sender from the shared map.
struct SenderCleanup {
    addr: SocketAddr,
    senders: WritingSenders,
}

impl Drop for SenderCleanup {
    fn drop(&mut self) {
        let Self { addr, senders } = &*self;
        senders.write().remove(addr);
    }
}