use crate::Node;
use bytes::Bytes;
use fxhash::FxHashMap;
use parking_lot::RwLock;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
},
sync::mpsc::Sender,
task::JoinHandle,
};
use tracing::*;
use std::{
io::{self, ErrorKind},
net::SocketAddr,
ops::Not,
};
/// The set of a node's active connections, keyed by each connection's `addr`.
///
/// All access goes through a `parking_lot::RwLock`: lookups take a shared read lock, while
/// insertions and removals take the exclusive write lock.
#[derive(Default)]
pub(crate) struct Connections(RwLock<FxHashMap<SocketAddr, Connection>>);

impl Connections {
    /// Returns a clone of the outbound message sender of the connection with `addr`.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::NotConnected` if there is no connection with `addr`.
    /// - `ErrorKind::Other` if that connection has no outbound message sender.
    pub(crate) fn sender(&self, addr: SocketAddr) -> io::Result<Sender<Bytes>> {
        match self.0.read().get(&addr) {
            Some(conn) => conn.sender(),
            None => Err(ErrorKind::NotConnected.into()),
        }
    }

    /// Registers `conn` under its `addr`, replacing (and dropping) any connection that was
    /// already registered under the same address.
    pub(crate) fn add(&self, conn: Connection) {
        let displaced = self.0.write().insert(conn.addr, conn);
        // The write guard is a temporary of the statement above, so it has already been
        // released here. Dropping a displaced connection runs `Connection`'s `Drop` impl
        // (logging, aborting its tasks, touching `known_peers`); doing that outside the lock
        // keeps the critical section short and avoids acquiring another lock while holding
        // this one.
        drop(displaced);
    }

    /// Returns clones of the outbound message senders of all registered connections.
    ///
    /// # Errors
    ///
    /// Stops at, and returns the error of, the first connection that has no outbound
    /// message sender.
    pub(crate) fn senders(&self) -> io::Result<Vec<Sender<Bytes>>> {
        self.0.read().values().map(|conn| conn.sender()).collect()
    }

    /// Returns `true` if a connection with `addr` is registered.
    pub(crate) fn is_connected(&self, addr: SocketAddr) -> bool {
        self.0.read().contains_key(&addr)
    }

    /// Unregisters the connection with `addr`, returning whether one was registered.
    pub(crate) fn remove(&self, addr: SocketAddr) -> bool {
        // Bind the removed connection first so the write guard (a temporary of this
        // statement) is released before the connection is dropped at the end of the function.
        let removed = self.0.write().remove(&addr);
        removed.is_some()
    }

    /// Returns the number of registered connections.
    pub(crate) fn num_connected(&self) -> usize {
        self.0.read().len()
    }

    /// Returns the addresses of all registered connections, in unspecified order.
    pub(crate) fn addrs(&self) -> Vec<SocketAddr> {
        self.0.read().keys().copied().collect()
    }
}
/// Identifies one of the two ends of a connection: the one that initiated it, or the one
/// that responded to it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionSide {
    /// The end that initiated the connection.
    Initiator,
    /// The end that responded to the connection.
    Responder,
}

impl Not for ConnectionSide {
    type Output = Self;

    /// Returns the opposite side: `!Initiator` is `Responder` and `!Responder` is `Initiator`.
    fn not(self) -> Self::Output {
        match self {
            Self::Initiator => Self::Responder,
            Self::Responder => Self::Initiator,
        }
    }
}
pub struct ConnectionReader {
pub node: Node,
pub addr: SocketAddr,
pub buffer: Box<[u8]>,
pub carry: usize,
pub reader: OwnedReadHalf,
}
impl ConnectionReader {
pub async fn read_queued_bytes(&mut self) -> io::Result<&[u8]> {
let len = self.reader.read(&mut self.buffer).await?;
trace!(parent: self.node.span(), "read {}B from {}", len, self.addr);
Ok(&self.buffer[..len])
}
pub async fn read_exact(&mut self, num: usize) -> io::Result<&[u8]> {
let buffer = &mut self.buffer;
if num > buffer.len() {
error!(parent: self.node.span(), "can' read {}B from the stream; the buffer is too small ({}B)", num, buffer.len());
return Err(ErrorKind::Other.into());
}
self.reader.read_exact(&mut buffer[..num]).await?;
trace!(parent: self.node.span(), "read {}B from {}", num, self.addr);
Ok(&buffer[..num])
}
}
/// The write half of a connection, along with a buffer for outbound data.
pub struct ConnectionWriter {
    /// The node that owns this connection; its span is the parent of the log events below.
    pub node: Node,
    /// The address of the connected peer.
    pub addr: SocketAddr,
    /// A buffer for outbound bytes. `write_all` does not use it; presumably callers fill it
    /// before writing (TODO confirm).
    pub buffer: Box<[u8]>,
    /// Not read or written by the methods below; presumably tracks pending bytes in `buffer`
    /// on the caller's side (TODO confirm).
    pub carry: usize,
    /// The write half of the underlying TCP stream.
    pub writer: OwnedWriteHalf,
}

impl ConnectionWriter {
    /// Writes the whole of `buffer` to the stream and emits a trace event once it is written.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from the underlying write unchanged; no trace event is emitted
    /// in that case.
    pub async fn write_all(&mut self, buffer: &[u8]) -> io::Result<()> {
        let outcome = self.writer.write_all(buffer).await;

        if outcome.is_ok() {
            trace!(parent: self.node.span(), "wrote {}B to {}", buffer.len(), self.addr);
        }

        outcome
    }
}
/// A connection with a single peer: its split stream halves, the handles of tasks working on
/// it (aborted when the connection is dropped), and an optional outbound message channel.
pub struct Connection {
    /// The node that owns this connection.
    pub node: Node,
    /// The address of the connected peer.
    pub addr: SocketAddr,
    /// The read half; `reader()` panics if this is `None`.
    pub reader: Option<ConnectionReader>,
    /// The write half; `writer()` panics if this is `None`.
    pub writer: Option<ConnectionWriter>,
    /// Handles of tasks associated with this connection.
    pub tasks: Vec<JoinHandle<()>>,
    /// Sender for outbound messages; `None` until one is assigned.
    pub outbound_message_sender: Option<Sender<Bytes>>,
    /// Which end of the connection this is.
    pub side: ConnectionSide,
}

impl Connection {
    /// Builds a connection with `addr` over `stream`, splitting the stream into a reader and a
    /// writer. Their zero-filled buffers are sized by the node config's `conn_read_buffer_size`
    /// and `conn_write_buffer_size`.
    ///
    /// The new connection has no tasks and no outbound message sender yet.
    pub(crate) fn new(
        addr: SocketAddr,
        stream: TcpStream,
        side: ConnectionSide,
        node: &Node,
    ) -> Self {
        let (read_half, write_half) = stream.into_split();
        let read_buffer_len = node.config.conn_read_buffer_size;
        let write_buffer_len = node.config.conn_write_buffer_size;

        Self {
            node: node.clone(),
            addr,
            reader: Some(ConnectionReader {
                node: node.clone(),
                addr,
                buffer: vec![0; read_buffer_len].into_boxed_slice(),
                carry: 0,
                reader: read_half,
            }),
            writer: Some(ConnectionWriter {
                node: node.clone(),
                addr,
                buffer: vec![0; write_buffer_len].into_boxed_slice(),
                carry: 0,
                writer: write_half,
            }),
            tasks: Vec::new(),
            outbound_message_sender: None,
            side,
        }
    }

    /// Returns the connection's reader.
    ///
    /// # Panics
    ///
    /// If the `reader` field is `None`.
    pub fn reader(&mut self) -> &mut ConnectionReader {
        self.reader.as_mut().expect("ConnectionReader is not available!")
    }

    /// Returns the connection's writer.
    ///
    /// # Panics
    ///
    /// If the `writer` field is `None`.
    pub fn writer(&mut self) -> &mut ConnectionWriter {
        self.writer.as_mut().expect("ConnectionWriter is not available!")
    }

    /// Returns a clone of the outbound message sender.
    ///
    /// # Errors
    ///
    /// `ErrorKind::Other` (after logging an error) if no sender has been assigned.
    fn sender(&self) -> io::Result<Sender<Bytes>> {
        match &self.outbound_message_sender {
            Some(sender) => Ok(sender.clone()),
            None => {
                error!(parent: self.node.span(), "can't send messages: the Writing protocol is disabled");
                Err(ErrorKind::Other.into())
            }
        }
    }
}
impl Drop for Connection {
fn drop(&mut self) {
debug!(parent: self.node.span(), "disconnecting from {}", self.addr);
for task in &self.tasks {
task.abort();
}
if matches!(self.side, ConnectionSide::Initiator) {
self.node.known_peers().remove(self.addr);
}
}
}