use std::{
collections::HashSet,
io,
net::{IpAddr, SocketAddr},
ops::Deref,
sync::{
atomic::{AtomicUsize, Ordering::*},
Arc,
},
};
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
use tokio::{
io::split,
net::{TcpListener, TcpSocket, TcpStream},
sync::oneshot,
task::JoinHandle,
};
use tracing::*;
use crate::{
connections::{Connection, ConnectionSide, Connections},
protocols::{Protocol, Protocols},
Config, KnownPeers, Stats,
};
/// Hands `$conn` to the `$handler_type` protocol handler of `$node` (if one is registered)
/// and evaluates to the connection that the handler sends back.
///
/// The connection is passed to the handler together with a oneshot sender, and the reply is
/// awaited. The *enclosing* function returns early with:
/// - `Err(BrokenPipe)` if the handler dropped the sender without replying;
/// - the handler's own error if it replied with an `Err`.
///
/// When no such handler is registered, `$conn` is returned unchanged.
macro_rules! enable_protocol {
    ($handler_type: ident, $node:expr, $conn: expr) => {
        match $node.protocols.$handler_type.get() {
            None => $conn,
            Some(handler) => {
                let (returner, retriever) = oneshot::channel();
                handler.trigger(($conn, returner));
                match retriever.await {
                    Err(_) => return Err(io::ErrorKind::BrokenPipe.into()),
                    Ok(Err(e)) => return Err(e),
                    Ok(Ok(conn)) => conn,
                }
            }
        }
    };
}
/// Process-wide counter used to name nodes: every `Node::new` call whose config has no
/// `name` takes the next value of this counter as the node's name.
static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0);
/// The public handle to a node.
///
/// It wraps an `Arc<InnerNode>`, so cloning a `Node` is cheap and every clone refers to the
/// same shared node state; the state itself is reachable through `Deref`.
#[derive(Clone)]
pub struct Node(Arc<InnerNode>);

impl Deref for Node {
    type Target = Arc<InnerNode>;

    fn deref(&self) -> &Arc<InnerNode> {
        let Node(inner) = self;
        inner
    }
}
/// The state shared by all clones of a [`Node`] handle.
#[doc(hidden)]
pub struct InnerNode {
    /// The tracing span used as the parent of the node's log events.
    span: Span,
    /// The node's configuration; its `name` is always set once `Node::new` returns.
    config: Config,
    /// The address the node listens on; set at most once, by `Node::start_listening`.
    listening_addr: OnceCell<SocketAddr>,
    /// The protocol handlers of this node, consulted when connections are set up
    /// (`enable_protocols`) and torn down (`disconnect`).
    pub(crate) protocols: Protocols,
    /// Addresses whose connection is currently being set up (outbound targets as well as
    /// inbound peers); an address leaves this set once its connection is registered or fails.
    connecting: Mutex<HashSet<SocketAddr>>,
    /// The established connections.
    connections: Connections,
    /// Peers the node has set up (or tried to set up) connections with; setup failures are
    /// registered here.
    known_peers: KnownPeers,
    /// The node's statistics.
    stats: Stats,
    /// Node-level background tasks (such as the listening task), aborted by `Node::shut_down`.
    /// NOTE(review): as this is `pub(crate)`, other modules presumably push tasks here too.
    pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
}
impl Node {
pub fn new(mut config: Config) -> Self {
if config.name.is_none() {
config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, SeqCst).to_string());
}
let span = create_span(config.name.as_deref().unwrap());
let node = Node(Arc::new(InnerNode {
span,
config,
listening_addr: Default::default(),
protocols: Default::default(),
connecting: Default::default(),
connections: Default::default(),
known_peers: Default::default(),
stats: Default::default(),
tasks: Default::default(),
}));
debug!(parent: node.span(), "the node is ready");
node
}
async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> {
let listener = if let Some(port) = self.config().desired_listening_port {
let desired_listening_addr = SocketAddr::new(listener_ip, port);
match TcpListener::bind(desired_listening_addr).await {
Ok(listener) => listener,
Err(e) => {
if self.config().allow_random_port {
warn!(
parent: self.span(),
"trying any port, the desired one is unavailable: {}", e
);
let random_available_addr = SocketAddr::new(listener_ip, 0);
TcpListener::bind(random_available_addr).await?
} else {
error!(parent: self.span(), "the desired port is unavailable: {}", e);
return Err(e);
}
}
}
} else if self.config().allow_random_port {
let random_available_addr = SocketAddr::new(listener_ip, 0);
TcpListener::bind(random_available_addr).await?
} else {
panic!("you must either provide a desired port or allow a random one");
};
Ok(listener)
}
pub async fn start_listening(&self) -> io::Result<SocketAddr> {
if let Some(listening_addr) = self.listening_addr.get() {
panic!(
"the node already has a listening address associated with it: {}",
listening_addr
);
} else {
let listener_ip = self
.config()
.listener_ip
.expect("Node::start_listening was called, but Config::listener_ip is not set");
let listener = self.create_listener(listener_ip).await?;
let port = listener.local_addr()?.port(); let listening_addr = (listener_ip, port).into();
self.listening_addr
.set(listening_addr)
.expect("the node's listener was started more than once");
let (tx, rx) = oneshot::channel();
let node = self.clone();
let listening_task = tokio::spawn(async move {
trace!(parent: node.span(), "spawned the listening task");
if tx.send(()).is_err() {
error!(parent: node.span(), "node creation interrupted; shutting down the listening task");
return;
}
loop {
match listener.accept().await {
Ok((stream, addr)) => {
debug!(parent: node.span(), "tentatively accepted a connection from {}", addr);
if !node.can_add_connection() {
debug!(parent: node.span(), "rejecting the connection from {}", addr);
continue;
}
node.connecting.lock().insert(addr);
let node2 = node.clone();
tokio::spawn(async move {
if let Err(e) = node2
.adapt_stream(stream, addr, ConnectionSide::Responder)
.await
{
node2.connecting.lock().remove(&addr);
node2.known_peers().register_failure(addr);
error!(parent: node2.span(), "couldn't accept a connection: {}", e);
}
});
}
Err(e) => {
error!(parent: node.span(), "couldn't accept a connection: {}", e);
}
}
}
});
self.tasks.lock().push(listening_task);
let _ = rx.await;
debug!(parent: self.span(), "listening on {}", listening_addr);
Ok(listening_addr)
}
}
#[inline]
pub fn name(&self) -> &str {
self.config.name.as_deref().unwrap()
}
#[inline]
pub fn config(&self) -> &Config {
&self.config
}
#[inline]
pub fn stats(&self) -> &Stats {
&self.stats
}
#[inline]
pub fn span(&self) -> &Span {
&self.span
}
pub fn listening_addr(&self) -> io::Result<SocketAddr> {
self.listening_addr
.get()
.copied()
.ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
}
async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> {
let mut conn = enable_protocol!(handshake, self, conn);
if let Some(stream) = conn.stream.take() {
let (reader, writer) = split(stream);
conn.reader = Some(Box::new(reader));
conn.writer = Some(Box::new(writer));
}
let conn = enable_protocol!(reading, self, conn);
let conn = enable_protocol!(writing, self, conn);
Ok(conn)
}
async fn adapt_stream(
&self,
stream: TcpStream,
peer_addr: SocketAddr,
own_side: ConnectionSide,
) -> io::Result<()> {
self.known_peers.add(peer_addr);
if own_side == ConnectionSide::Initiator {
if let Ok(addr) = stream.local_addr() {
debug!(
parent: self.span(), "establishing connection with {}; the peer is connected on port {}",
peer_addr, addr.port()
);
} else {
warn!(parent: self.span(), "couldn't determine the peer's port");
}
}
let connection = Connection::new(peer_addr, stream, !own_side);
let mut connection = self.enable_protocols(connection).await?;
let conn_ready_tx = connection.readiness_notifier.take();
self.connections.add(connection);
self.connecting.lock().remove(&peer_addr);
if let Some(tx) = conn_ready_tx {
let _ = tx.send(());
}
Ok(())
}
async fn create_stream(
&self,
addr: SocketAddr,
socket: Option<TcpSocket>,
) -> io::Result<TcpStream> {
if let Some(socket) = socket {
socket.connect(addr).await
} else {
TcpStream::connect(addr).await
}
}
pub async fn connect(&self, addr: SocketAddr) -> io::Result<()> {
self.connect_inner(addr, None).await
}
pub async fn connect_using_socket(
&self,
addr: SocketAddr,
socket: TcpSocket,
) -> io::Result<()> {
self.connect_inner(addr, Some(socket)).await
}
async fn connect_inner(&self, addr: SocketAddr, socket: Option<TcpSocket>) -> io::Result<()> {
if let Ok(listening_addr) = self.listening_addr() {
if addr == listening_addr
|| addr.ip().is_loopback() && addr.port() == listening_addr.port()
{
error!(parent: self.span(), "can't connect to node's own listening address ({})", addr);
return Err(io::ErrorKind::AddrInUse.into());
}
}
if !self.can_add_connection() {
error!(parent: self.span(), "too many connections; refusing to connect to {}", addr);
return Err(io::ErrorKind::PermissionDenied.into());
}
if self.connections.is_connected(addr) {
warn!(parent: self.span(), "already connected to {}", addr);
return Err(io::ErrorKind::AlreadyExists.into());
}
if !self.connecting.lock().insert(addr) {
warn!(parent: self.span(), "already connecting to {}", addr);
return Err(io::ErrorKind::AlreadyExists.into());
}
let stream = self.create_stream(addr, socket).await.map_err(|e| {
self.connecting.lock().remove(&addr);
e
})?;
let ret = self
.adapt_stream(stream, addr, ConnectionSide::Initiator)
.await;
if let Err(ref e) = ret {
self.connecting.lock().remove(&addr);
self.known_peers().register_failure(addr);
error!(parent: self.span(), "couldn't initiate a connection with {}: {}", addr, e);
}
ret
}
pub async fn disconnect(&self, addr: SocketAddr) -> bool {
if let Some(handler) = self.protocols.disconnect.get() {
if self.is_connected(addr) {
let (sender, receiver) = oneshot::channel();
handler.trigger((addr, sender));
let _ = receiver.await; }
}
let conn = self.connections.remove(addr);
if let Some(ref conn) = conn {
debug!(parent: self.span(), "disconnecting from {}", conn.addr());
for task in conn.tasks.iter().rev() {
task.abort();
}
if conn.side() == ConnectionSide::Initiator {
self.known_peers().remove(conn.addr());
}
debug!(parent: self.span(), "disconnected from {}", addr);
} else {
debug!(parent: self.span(), "couldn't disconnect from {}, as it wasn't connected", addr);
}
conn.is_some()
}
pub fn connected_addrs(&self) -> Vec<SocketAddr> {
self.connections.addrs()
}
#[inline]
pub fn known_peers(&self) -> &KnownPeers {
&self.known_peers
}
pub fn is_connected(&self, addr: SocketAddr) -> bool {
self.connections.is_connected(addr)
}
pub fn is_connecting(&self, addr: SocketAddr) -> bool {
self.connecting.lock().contains(&addr)
}
pub fn num_connected(&self) -> usize {
self.connections.num_connected()
}
pub fn num_connecting(&self) -> usize {
self.connecting.lock().len()
}
fn can_add_connection(&self) -> bool {
let num_connected = self.num_connected();
let limit = self.config.max_connections as usize;
if num_connected >= limit || num_connected + self.num_connecting() >= limit {
warn!(parent: self.span(), "maximum number of connections ({}) reached", limit);
false
} else {
true
}
}
pub async fn shut_down(&self) {
debug!(parent: self.span(), "shutting down");
let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter();
if let Some(listening_task) = tasks.next() {
listening_task.abort(); }
for addr in self.connected_addrs() {
self.disconnect(addr).await;
}
for handle in tasks {
handle.abort();
}
}
}
/// Creates the `node` span for `node_name` at the most verbose level that is enabled.
///
/// Levels are tried from `TRACE` towards `ERROR`, and a span for the next level is only
/// created if the current one is disabled. If every level up to `WARN` is disabled, the
/// `ERROR`-level span is returned regardless of whether it is enabled.
fn create_span(node_name: &str) -> Span {
    let mut span = trace_span!("node", name = node_name);
    if span.is_disabled() {
        span = debug_span!("node", name = node_name);
    }
    if span.is_disabled() {
        span = info_span!("node", name = node_name);
    }
    if span.is_disabled() {
        span = warn_span!("node", name = node_name);
    }
    if span.is_disabled() {
        span = error_span!("node", name = node_name);
    }
    span
}