use crate::Node;
use bytes::Bytes;
use once_cell::sync::OnceCell;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
sync::mpsc::{channel, Sender},
task::JoinHandle,
};
use tracing::*;
use std::{
io::{self, ErrorKind},
net::SocketAddr,
ops::Not,
sync::Arc,
};
/// Which party set up a connection: the one that initiated it or the one that accepted it.
///
/// Whether a particular value describes the local node or the remote peer is up to the code
/// storing it; the `Not` impl flips between the two perspectives.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionSide {
    /// The side that initiated the connection.
    Initiator,
    /// The side that accepted (responded to) the connection.
    Responder,
}
impl Not for ConnectionSide {
type Output = Self;
fn not(self) -> Self::Output {
match self {
Self::Initiator => Self::Responder,
Self::Responder => Self::Initiator,
}
}
}
/// The read half of a connection to a peer, together with a fixed-size buffer that the
/// read methods fill and return slices of.
pub struct ConnectionReader {
    /// Handle to the owning node; used for its config (buffer size) and its tracing span.
    pub node: Arc<Node>,
    /// Address of the peer this reader reads from.
    pub addr: SocketAddr,
    /// Read buffer; `new` allocates it zero-filled with `node.config.conn_read_buffer_size` bytes.
    pub buffer: Box<[u8]>,
    /// Initialized to 0 by `new` and not read by the methods visible here.
    /// NOTE(review): presumably the number of leftover bytes carried over in `buffer` between
    /// reads — confirm against callers.
    pub carry: usize,
    /// The read half of the peer's TCP stream.
    pub reader: OwnedReadHalf,
}
impl ConnectionReader {
pub(crate) fn new(addr: SocketAddr, reader: OwnedReadHalf, node: Arc<Node>) -> Self {
Self {
addr,
buffer: vec![0; node.config.conn_read_buffer_size].into(),
carry: 0,
reader,
node,
}
}
pub async fn read_queued_bytes(&mut self) -> io::Result<&[u8]> {
let len = self.reader.read(&mut self.buffer).await?;
trace!(parent: self.node.span(), "read {}B from {}", len, self.addr);
Ok(&self.buffer[..len])
}
pub async fn read_exact(&mut self, num: usize) -> io::Result<&[u8]> {
let buffer = &mut self.buffer;
if num > buffer.len() {
error!(parent: self.node.span(), "can' read {}B from the stream; the buffer is too small ({}B)", num, buffer.len());
return Err(ErrorKind::Other.into());
}
self.reader.read_exact(&mut buffer[..num]).await?;
trace!(parent: self.node.span(), "read {}B from {}", num, self.addr);
Ok(&buffer[..num])
}
}
/// A live connection to a peer: owns the queue feeding the writer task and holds handles to
/// the tasks associated with the connection.
pub struct Connection {
    /// Handle to the owning node; used for tracing and known-peer bookkeeping on drop.
    node: Arc<Node>,
    /// Address of the peer.
    pub(crate) addr: SocketAddr,
    /// Handle to the inbound reader task. `new` leaves this empty.
    /// NOTE(review): presumably set by the code that spawns that task — confirm.
    pub inbound_reader_task: OnceCell<JoinHandle<()>>,
    /// Handle to the inbound processing task. `new` leaves this empty.
    /// NOTE(review): presumably set by the code that spawns that task — confirm.
    pub inbound_processing_task: OnceCell<JoinHandle<()>>,
    /// Handle to the task spawned in `new` that writes queued messages to the stream.
    _writer_task: JoinHandle<()>,
    /// Sending half of the outbound message queue consumed by the writer task.
    pub(crate) message_sender: Sender<Bytes>,
    /// Which side of the connection this is. `Drop` removes the peer from the node's known
    /// peers only when this is `Initiator`.
    pub side: ConnectionSide,
}
impl Connection {
    /// Wraps the write half of a peer's stream in a `Connection` and spawns the task that
    /// writes outbound messages to it.
    ///
    /// Messages enqueued via `message_sender` (for example through `send_message`) are
    /// written in order by a dedicated task. The queue holds up to
    /// `node.config.outbound_message_queue_depth` messages.
    ///
    /// - Each successful write is recorded in the node's known peers and stats.
    /// - Each failed write registers a failure for `addr`, is logged, and the task moves on
    ///   to the next message.
    ///
    /// The writer task exits once the queue is closed and drained, meaning every sender for
    /// it has been dropped (including the one stored in the returned `Connection`). It
    /// flushes the stream one final time before exiting.
    pub(crate) fn new(
        addr: SocketAddr,
        mut writer: OwnedWriteHalf,
        node: Arc<Node>,
        side: ConnectionSide,
    ) -> Self {
        let (message_sender, mut message_receiver) =
            channel::<Bytes>(node.config.outbound_message_queue_depth);

        let node_clone = Arc::clone(&node);
        let _writer_task = tokio::spawn(async move {
            // `recv` returns `None` only once the channel is closed and drained; that is the
            // task's sole exit condition. Do not wrap this in an outer `loop`: a closed
            // channel keeps returning `None` immediately, which would spin the task forever.
            while let Some(msg) = message_receiver.recv().await {
                if let Err(e) = writer.write_all(&msg).await {
                    node_clone.known_peers().register_failure(addr);
                    error!(parent: node_clone.span(), "couldn't send {}B to {}: {}", msg.len(), addr, e);
                } else {
                    node_clone
                        .known_peers()
                        .register_sent_message(addr, msg.len());
                    node_clone.stats.register_sent_message(msg.len());
                    trace!(parent: node_clone.span(), "sent {}B to {}", msg.len(), addr);
                }
            }

            if let Err(e) = writer.flush().await {
                node_clone.known_peers().register_failure(addr);
                error!(parent: node_clone.span(), "couldn't flush the stream to {}: {}", addr, e);
            }
        });
        trace!(parent: node.span(), "spawned a task for writing messages to {}", addr);

        Self {
            node,
            addr,
            inbound_reader_task: Default::default(),
            inbound_processing_task: Default::default(),
            _writer_task,
            message_sender,
            side,
        }
    }

    /// Enqueues `message` for the writer task. Waits if the outbound queue is full.
    ///
    /// # Panics
    /// Panics if the writer task's receiving end has been dropped, i.e. the writer task is no
    /// longer running.
    pub async fn send_message(&self, message: Bytes) {
        self.message_sender
            .send(message)
            .await
            .expect("the connection writer task is closed");
    }
}
impl Drop for Connection {
    /// Logs the disconnect. If `side` is `Initiator`, also removes the peer's address from
    /// the node's known peers.
    fn drop(&mut self) {
        let peer = self.addr;
        debug!(parent: self.node.span(), "disconnecting from {}", peer);

        match self.side {
            ConnectionSide::Initiator => {
                self.node.known_peers().remove(peer);
            }
            ConnectionSide::Responder => {}
        }
    }
}