use dashmap::DashMap;
use futures::StreamExt;
use nym_noise::config::NoiseConfig;
use nym_noise::upgrade_noise_initiator;
use nym_sphinx::forwarding::packet::MixPacket;
use nym_sphinx::framing::codec::NymCodec;
use nym_sphinx::framing::packet::FramedNymPacket;
use std::io;
use std::net::SocketAddr;
use std::ops::Deref;
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::time::sleep;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::codec::Framed;
use tracing::*;
/// Tunable parameters of the mixnet client.
///
/// The type is `Copy`, so it can be handed to background tasks by value.
#[derive(Debug, Clone, Copy)]
pub struct Config {
    /// Base delay that is multiplied by `2^attempt` when computing the
    /// reconnection backoff.
    pub initial_reconnection_backoff: Duration,
    /// Upper bound applied to the computed reconnection backoff.
    pub maximum_reconnection_backoff: Duration,
    /// Time limit applied to a single TCP connection attempt.
    pub initial_connection_timeout: Duration,
    /// Capacity of the bounded per-connection packet queue.
    pub maximum_connection_buffer_size: usize,
    /// Flag forwarded to `FramedNymPacket::from_mix_packet` when framing
    /// outgoing packets.
    pub use_legacy_packet_encoding: bool,
}

impl Config {
    /// Builds a [`Config`] from its individual settings.
    ///
    /// The arguments are stored verbatim; no validation is performed.
    pub fn new(
        initial_reconnection_backoff: Duration,
        maximum_reconnection_backoff: Duration,
        initial_connection_timeout: Duration,
        maximum_connection_buffer_size: usize,
        use_legacy_packet_encoding: bool,
    ) -> Self {
        Self {
            initial_reconnection_backoff,
            maximum_reconnection_backoff,
            initial_connection_timeout,
            maximum_connection_buffer_size,
            use_legacy_packet_encoding,
        }
    }
}
/// Abstraction over anything that can send a mix packet towards its next hop
/// without producing a reply value for the caller.
pub trait SendWithoutResponse {
/// Hands `packet` over for sending.
///
/// # Errors
///
/// Implementations report a packet that could not be handed over as an
/// `io::Error`.
fn send_without_response(&self, packet: MixPacket) -> io::Result<()>;
}
/// Mixnet client that keeps one packet queue per remote address and spawns a
/// background task for each connection it needs.
pub struct Client {
/// Per-address senders feeding the spawned connection tasks.
active_connections: ActiveConnections,
/// Noise configuration cloned into every spawned connection task.
noise_config: NoiseConfig,
/// Shared counter that a spawned connection task increments right before it
/// starts its connection attempt and decrements once that task is done.
connections_count: Arc<AtomicUsize>,
/// Backoff, timeout, queue-size and packet-encoding settings.
config: Config,
}
/// Cheaply cloneable handle to the map of per-address connection senders.
///
/// Clones share the same underlying map because it lives behind an `Arc`.
#[derive(Default, Clone)]
pub struct ActiveConnections {
/// Remote socket address -> sender used to queue packets for that address.
inner: Arc<DashMap<SocketAddr, ConnectionSender>>,
}
impl ActiveConnections {
    /// Returns the total number of queue slots currently in use across all
    /// tracked connections.
    ///
    /// For each sender this is `max_capacity() - capacity()`, i.e. the part of
    /// the bounded channel that is not available for new packets right now.
    pub fn pending_packets(&self) -> usize {
        let mut in_use = 0;
        for entry in self.inner.iter() {
            let total_slots = entry.channel.max_capacity();
            let free_slots = entry.channel.capacity();
            in_use += total_slots - free_slots;
        }
        in_use
    }
}
/// Exposes the underlying `DashMap` API directly on [`ActiveConnections`].
impl Deref for ActiveConnections {
    type Target = DashMap<SocketAddr, ConnectionSender>;

    /// Borrows the shared map stored behind the inner `Arc`.
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
/// Sending side of a single managed connection, stored in the active
/// connections map.
pub struct ConnectionSender {
/// Bounded channel used to queue framed packets for the connection task.
channel: mpsc::Sender<FramedNymPacket>,
/// Reconnection attempt counter shared with the connection task; it is read
/// to decide how long to back off before the next connection attempt.
current_reconnection_attempt: Arc<AtomicU32>,
}
impl ConnectionSender {
    /// Wraps `channel` together with a fresh reconnection counter starting at
    /// zero.
    fn new(channel: mpsc::Sender<FramedNymPacket>) -> Self {
        let current_reconnection_attempt = Arc::new(AtomicU32::default());
        Self {
            channel,
            current_reconnection_attempt,
        }
    }
}
/// State owned by the background task that drives one outbound connection.
struct ManagedConnection {
/// Remote address the task connects to.
address: SocketAddr,
/// Configuration passed to the Noise initiator handshake.
noise_config: NoiseConfig,
/// Stream of queued packets that is forwarded into the connection.
message_receiver: ReceiverStream<FramedNymPacket>,
/// Time limit for establishing the TCP connection.
connection_timeout: Duration,
/// Shared attempt counter: incremented when the connection attempt times out
/// or the Noise handshake fails, reset to zero once the handshake succeeds.
current_reconnection: Arc<AtomicU32>,
}
impl ManagedConnection {
    /// Assembles the state for one connection task.
    ///
    /// The raw channel receiver is wrapped in a `ReceiverStream` so that it
    /// can later be forwarded into the framed connection.
    fn new(
        address: SocketAddr,
        noise_config: NoiseConfig,
        message_receiver: mpsc::Receiver<FramedNymPacket>,
        connection_timeout: Duration,
        current_reconnection: Arc<AtomicU32>,
    ) -> Self {
        let message_receiver = ReceiverStream::new(message_receiver);
        Self {
            address,
            noise_config,
            message_receiver,
            connection_timeout,
            current_reconnection,
        }
    }

    /// Connects to the remote address, performs the Noise initiator handshake
    /// and then forwards every queued packet into the resulting connection.
    ///
    /// The method returns when connecting or the handshake fails, when
    /// forwarding fails, or when the packet stream ends.
    async fn run(self) {
        let address = self.address;

        // Step 1: TCP connect, bounded by the configured timeout.
        let connect_outcome =
            tokio::time::timeout(self.connection_timeout, TcpStream::connect(address)).await;
        let stream = match connect_outcome {
            Ok(Ok(stream)) => stream,
            Ok(Err(err)) => {
                // NOTE(review): unlike the timeout and handshake failures, this
                // path leaves the reconnection counter untouched - confirm
                // whether skipping the backoff here is intentional.
                debug!("failed to establish connection to {address} (err: {err})",);
                return;
            }
            Err(_) => {
                debug!(
                    "failed to connect to {address} within {:?}",
                    self.connection_timeout
                );
                // Timed out: make the next attempt back off for longer.
                self.current_reconnection.fetch_add(1, Ordering::SeqCst);
                return;
            }
        };
        debug!("Managed to establish connection to {}", self.address);

        // Step 2: upgrade the plain TCP stream with the Noise handshake.
        let noise_stream = match upgrade_noise_initiator(stream, &self.noise_config).await {
            Ok(noise_stream) => noise_stream,
            Err(err) => {
                error!("Failed to perform Noise handshake with {address} - {err}");
                self.current_reconnection.fetch_add(1, Ordering::SeqCst);
                return;
            }
        };

        // Fully connected: the next reconnection starts without backoff.
        self.current_reconnection.store(0, Ordering::Release);
        debug!("Noise initiator handshake completed for {:?}", address);
        let conn = Framed::new(noise_stream, NymCodec);

        // Step 3: pump queued packets into the connection until either side ends.
        if let Err(err) = self.message_receiver.map(Ok).forward(conn).await {
            warn!("Failed to forward packets to {address}: {err}");
        }

        debug!(
            "connection manager to {address} is finished. Either the connection failed or mixnet client got dropped",
        );
    }
}
impl Client {
pub fn new(
config: Config,
noise_config: NoiseConfig,
connections_count: Arc<AtomicUsize>,
) -> Client {
Client {
active_connections: Default::default(),
noise_config,
connections_count,
config,
}
}
pub fn active_connections(&self) -> ActiveConnections {
self.active_connections.clone()
}
fn determine_backoff(&self, current_attempt: u32) -> Option<Duration> {
if current_attempt == 0 {
None
} else {
let exp = 2_u32.checked_pow(current_attempt);
let backoff = exp
.and_then(|exp| self.config.initial_reconnection_backoff.checked_mul(exp))
.unwrap_or(self.config.maximum_reconnection_backoff);
Some(std::cmp::min(
backoff,
self.config.maximum_reconnection_backoff,
))
}
}
fn make_connection(&self, address: SocketAddr, pending_packet: FramedNymPacket) {
let (sender, receiver) = mpsc::channel(self.config.maximum_connection_buffer_size);
if self.config.maximum_connection_buffer_size > 0 {
sender.try_send(pending_packet).unwrap();
}
let current_reconnection_attempt =
if let Some(mut existing) = self.active_connections.get_mut(&address) {
existing.channel = sender;
Arc::clone(&existing.current_reconnection_attempt)
} else {
let new_entry = ConnectionSender::new(sender);
let current_attempt = Arc::clone(&new_entry.current_reconnection_attempt);
self.active_connections.insert(address, new_entry);
current_attempt
};
let reconnection_attempt = current_reconnection_attempt.load(Ordering::Acquire);
let backoff = self.determine_backoff(reconnection_attempt);
let initial_connection_timeout = self.config.initial_connection_timeout;
let connections_count = self.connections_count.clone();
let noise_config = self.noise_config.clone();
tokio::spawn(async move {
if let Some(backoff) = backoff {
trace!("waiting for {:?} before attempting connection", backoff);
sleep(backoff).await;
}
connections_count.fetch_add(1, Ordering::SeqCst);
ManagedConnection::new(
address,
noise_config,
receiver,
initial_connection_timeout,
current_reconnection_attempt,
)
.run()
.await;
connections_count.fetch_sub(1, Ordering::SeqCst);
});
}
}
impl SendWithoutResponse for Client {
    /// Queues `packet` on the connection to its next hop without waiting for
    /// it to be written to the network.
    ///
    /// # Errors
    ///
    /// * `NotConnected` - no connection entry existed for the next hop; a
    ///   connection attempt was started with this packet handed to it.
    /// * `WouldBlock` - the connection queue is full; the packet is dropped.
    /// * `ConnectionAborted` - the connection task is gone; a reconnection
    ///   attempt was started with this packet handed to it.
    fn send_without_response(&self, packet: MixPacket) -> io::Result<()> {
        let address = packet.next_hop_address();
        trace!("Sending packet to {address}");

        let framed_packet =
            FramedNymPacket::from_mix_packet(packet, self.config.use_legacy_packet_encoding);

        // `try_send` only needs shared access to the sender, so a shared map
        // guard is enough; concurrent senders do not serialise on an exclusive
        // lock here.
        let Some(sender) = self.active_connections.get(&address) else {
            debug!("establishing initial connection to {address}");
            self.make_connection(address, framed_packet);
            return Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "connection is in progress",
            ));
        };

        let sending_res = sender.channel.try_send(framed_packet);
        // Release the map guard before `make_connection` may need exclusive
        // access to the very same entry.
        drop(sender);

        match sending_res {
            Ok(()) => Ok(()),
            Err(TrySendError::Full(_)) => {
                debug!("Connection to {address} seems to not be able to handle all the traffic - dropping the current packet");
                Err(io::Error::new(
                    io::ErrorKind::WouldBlock,
                    "connection queue is full",
                ))
            }
            Err(TrySendError::Closed(dropped)) => {
                debug!(
                    "Connection to {address} seems to be dead. attempting to re-establish it...",
                );
                // The rejected packet is handed back to us; reuse it as the
                // first packet of the new connection.
                self.make_connection(address, dropped);
                Err(io::Error::new(
                    io::ErrorKind::ConnectionAborted,
                    "reconnection attempt is in progress",
                ))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nym_crypto::asymmetric::x25519;
    use nym_noise::config::NoiseNetworkView;
    use rand::rngs::OsRng;

    /// Builds a client with a 10s initial / 300s maximum backoff and an empty
    /// Noise network view; no connection is opened.
    fn dummy_client() -> Client {
        let mut rng = OsRng;

        let config = Config {
            initial_reconnection_backoff: Duration::from_millis(10_000),
            maximum_reconnection_backoff: Duration::from_millis(300_000),
            initial_connection_timeout: Duration::from_millis(1_500),
            maximum_connection_buffer_size: 128,
            use_legacy_packet_encoding: false,
        };
        let noise_config = NoiseConfig::new(
            Arc::new(x25519::KeyPair::new(&mut rng)),
            NoiseNetworkView::new_empty(),
            Duration::from_millis(1_500),
        );

        Client::new(config, noise_config, Default::default())
    }

    #[test]
    fn determining_backoff_works_regardless_of_attempt() {
        let client = dummy_client();
        let ceiling = client.config.maximum_reconnection_backoff;

        // The very first attempt never waits.
        assert!(client.determine_backoff(0).is_none());

        // Early retries always produce some delay.
        for attempt in [1, 2] {
            assert!(client.determine_backoff(attempt).is_some());
        }

        // Large attempt counts, including ones that overflow the arithmetic,
        // saturate at the configured maximum.
        for attempt in [16, 32, 1024, 65536, u32::MAX] {
            assert_eq!(client.determine_backoff(attempt).unwrap(), ceiling);
        }
    }
}