use std::collections::{HashMap, HashSet};
use iroh_metrics::inc;
use tokio::{sync::mpsc, task::JoinSet};
use tracing::{Instrument, Span};
use super::{
client_conn::{ClientConnBuilder, ClientConnManager},
metrics::Metrics,
types::Packet,
};
use crate::key::PublicKey;
/// Number of immediate `try_send` attempts made on a full channel before the message
/// is reported as dropped (see `try_send`).
const RETRIES: usize = 3;
/// A single connected relay client plus bookkeeping about whom it has sent packets to.
#[derive(Debug)]
struct Client {
    /// Connection handle; its `client_channels` queues are where outgoing messages go.
    conn: ClientConnManager,
    /// Keys this client has sent packets to (filled by `record_send`). When the client
    /// is unregistered, each of these is sent a best-effort "peer gone" notification.
    sent_to: HashSet<PublicKey>,
}
impl Client {
    /// Wraps a connection manager; the client starts with no recorded recipients.
    pub fn new(conn: ClientConnManager) -> Self {
        Self {
            conn,
            sent_to: HashSet::default(),
        }
    }

    /// Records that this client sent a packet to `dst`, so `dst` can later be told
    /// that this client went away.
    pub fn record_send(&mut self, dst: PublicKey) {
        self.sent_to.insert(dst);
    }

    /// Shuts the connection down in a detached, spawned task without waiting for it.
    ///
    /// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
    pub fn shutdown(self) {
        tokio::spawn(
            async move {
                self.conn.shutdown().await;
            }
            .instrument(Span::current()),
        );
    }

    /// Shuts the connection down and waits for the shutdown to finish.
    pub async fn shutdown_await(self) {
        self.conn.shutdown().await;
    }

    /// Queues a regular (non-disco) packet for this client.
    ///
    /// Increments `send_packets_sent` on success and `send_packets_dropped` on any
    /// failure, the same sent/dropped accounting `send_peer_gone` uses.
    ///
    /// # Errors
    /// `SendError::PacketDropped` if the queue stayed full for every retry,
    /// `SendError::SenderClosed` if the queue is closed.
    pub fn send_packet(&self, packet: Packet) -> Result<(), SendError> {
        let res = try_send(&self.conn.client_channels.send_queue, packet);
        match res {
            Ok(()) => {
                inc!(Metrics, send_packets_sent);
            }
            Err(_) => {
                inc!(Metrics, send_packets_dropped);
            }
        }
        res
    }

    /// Queues a disco packet for this client on its dedicated disco queue.
    ///
    /// Increments `disco_packets_sent` on success and `disco_packets_dropped` on any
    /// failure.
    ///
    /// # Errors
    /// Same as `send_packet`.
    pub fn send_disco_packet(&self, packet: Packet) -> Result<(), SendError> {
        let res = try_send(&self.conn.client_channels.disco_send_queue, packet);
        match res {
            Ok(()) => {
                inc!(Metrics, disco_packets_sent);
            }
            Err(_) => {
                inc!(Metrics, disco_packets_dropped);
            }
        }
        res
    }

    /// Notifies this client that the peer identified by `key` is gone.
    ///
    /// Increments `other_packets_sent` on success and `other_packets_dropped` on any
    /// failure.
    ///
    /// # Errors
    /// Same as `send_packet`.
    pub fn send_peer_gone(&self, key: PublicKey) -> Result<(), SendError> {
        let res = try_send(&self.conn.client_channels.peer_gone, key);
        match res {
            Ok(()) => {
                inc!(Metrics, other_packets_sent);
            }
            Err(_) => {
                inc!(Metrics, other_packets_dropped);
            }
        }
        res
    }
}
fn try_send<T>(sender: &mpsc::Sender<T>, msg: T) -> Result<(), SendError> {
let mut msg = msg;
for _ in 0..RETRIES {
match sender.try_send(msg) {
Ok(_) => return Ok(()),
Err(mpsc::error::TrySendError::Full(m)) => msg = m,
Err(_) => return Err(SendError::SenderClosed),
}
}
Err(SendError::PacketDropped)
}
/// Why a message could not be queued for a client.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be compared
/// directly (useful when inspecting send outcomes).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SendError {
    /// The client's queue stayed full for every retry, so the message was discarded.
    PacketDropped,
    /// The client's queue is closed; the connection can no longer be written to.
    SenderClosed,
}
/// Registry of currently connected relay clients, keyed by their public key.
#[derive(Debug)]
pub(crate) struct Clients {
    /// At most one client per key. Registering an already-present key replaces the old
    /// entry, and the old connection is shut down (see `register`).
    inner: HashMap<PublicKey, Client>,
}
// NOTE(review): this `Drop` impl does nothing itself; the `inner` map and its clients are
// dropped normally afterwards. Whether dropping a `ClientConnManager` stops its connection
// is not visible here. For an orderly, awaited shutdown call `Clients::shutdown`. Having
// an explicit `Drop` impl also prevents moving fields out of `Clients` by destructuring.
impl Drop for Clients {
    fn drop(&mut self) {}
}
impl Clients {
    /// Creates a registry with no clients.
    pub fn new() -> Self {
        let inner = HashMap::default();
        Self { inner }
    }

    /// Removes every client and waits until all of their connections have shut down.
    ///
    /// Each shutdown runs as its own task in a `JoinSet`. Tasks that fail to join are
    /// only reported at trace level.
    pub async fn shutdown(&mut self) {
        tracing::trace!("shutting down conn");
        let mut shutdowns = JoinSet::default();
        for client in self.inner.drain().map(|(_key, client)| client) {
            let task = async move { client.shutdown_await().await };
            shutdowns.spawn(task.instrument(Span::current()));
        }
        while let Some(joined) = shutdowns.join_next().await {
            if let Some(err) = joined.err() {
                tracing::trace!("shutdown error: {:?}", err);
            }
        }
    }

    /// Notes that `src` sent a packet to `dst`. Does nothing if `src` is not registered.
    pub fn record_send(&mut self, src: &PublicKey, dst: PublicKey) {
        if let Some(sender) = self.inner.get_mut(src) {
            sender.record_send(dst);
        }
    }

    /// Returns whether any client is registered under `key`.
    pub fn contains_key(&self, key: &PublicKey) -> bool {
        self.inner.contains_key(key)
    }

    /// Returns whether the client registered under `key` uses connection number `conn_num`.
    pub fn has_client(&self, key: &PublicKey, conn_num: usize) -> bool {
        matches!(self.inner.get(key), Some(client) if client.conn.conn_num == conn_num)
    }

    /// Builds and registers a client connection.
    ///
    /// If a client with the same key was already registered, it is replaced and its
    /// connection is shut down in the background.
    pub fn register(&mut self, client_builder: ClientConnBuilder) {
        let key = client_builder.key;
        tracing::trace!("registering client: {:?}", key);
        let replaced = self.inner.insert(key, Client::new(client_builder.build()));
        if let Some(stale) = replaced {
            tracing::warn!("multiple connections found for {key:?}, pruning old connection");
            stale.shutdown();
        }
    }

    /// Removes the client for `peer` and shuts down its connection in the background.
    ///
    /// Every key the client had sent to is sent a best-effort "peer gone" message first.
    /// This can recursively unregister recipients whose channels turn out to be closed.
    pub fn unregister(&mut self, peer: &PublicKey) {
        tracing::trace!("unregistering client: {:?}", peer);
        let removed = match self.inner.remove(peer) {
            Some(removed) => removed,
            None => return,
        };
        for recipient in &removed.sent_to {
            self.send_peer_gone(recipient, *peer);
        }
        tracing::warn!("pruning connection {peer:?}");
        removed.shutdown();
    }

    /// Queues a regular packet for the client registered under `key`.
    ///
    /// # Errors
    /// Fails if no such client exists, or if the send fails (see `process_result`).
    pub fn send_packet(&mut self, key: &PublicKey, packet: Packet) -> anyhow::Result<()> {
        let outcome = match self.inner.get(key) {
            Some(client) => client.send_packet(packet),
            None => {
                tracing::warn!("Could not find client for {key:?}, dropping packet");
                anyhow::bail!("Could not find client for {key:?}, dropped packet");
            }
        };
        self.process_result(key, outcome)
    }

    /// Queues a disco packet for the client registered under `key`.
    ///
    /// # Errors
    /// Fails if no such client exists, or if the send fails (see `process_result`).
    pub fn send_disco_packet(&mut self, key: &PublicKey, packet: Packet) -> anyhow::Result<()> {
        let outcome = match self.inner.get(key) {
            Some(client) => client.send_disco_packet(packet),
            None => {
                tracing::warn!("Could not find client for {key:?}, dropping disco packet");
                anyhow::bail!("Could not find client for {key:?}, dropped packet");
            }
        };
        self.process_result(key, outcome)
    }

    /// Tells the client registered under `key` that `peer` is gone.
    ///
    /// This is best effort. Failures are logged (and a closed client pruned) by
    /// `process_result`, and are not returned.
    pub fn send_peer_gone(&mut self, key: &PublicKey, peer: PublicKey) {
        let outcome = match self.inner.get(key) {
            Some(client) => client.send_peer_gone(peer),
            None => {
                tracing::warn!("Could not find client for {key:?}, dropping peer gone packet");
                return;
            }
        };
        let _ = self.process_result(key, outcome);
    }

    /// Converts a send outcome into an `anyhow::Result`, logging any failure.
    ///
    /// A full queue is only logged. A closed queue also unregisters the client for
    /// `key`. Every failure produces the error "unable to send msg".
    fn process_result(
        &mut self,
        key: &PublicKey,
        res: Result<(), SendError>,
    ) -> anyhow::Result<()> {
        let failure = match res {
            Ok(()) => return Ok(()),
            Err(failure) => failure,
        };
        match failure {
            SendError::PacketDropped => {
                tracing::warn!("client {key:?} too busy to receive packet, dropping packet");
            }
            SendError::SenderClosed => {
                tracing::warn!("Can no longer write to client {key:?}, dropping message and pruning connection");
                self.unregister(key);
            }
        }
        anyhow::bail!("unable to send msg");
    }
}
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use bytes::Bytes;
    use tokio::io::DuplexStream;
    use tokio_util::codec::{Framed, FramedRead};

    use super::*;
    use crate::{
        key::SecretKey,
        relay::{
            codec::{recv_frame, DerpCodec, Frame, FrameType},
            server::streams::{MaybeTlsStream, RelayIo},
        },
    };

    /// Builds a `ClientConnBuilder` backed by an in-memory duplex pipe.
    ///
    /// Returns the builder and a framed reader over the test's end of the pipe, which
    /// observes the frames the client connection writes.
    fn test_client_builder(
        key: PublicKey,
        conn_num: usize,
    ) -> (ClientConnBuilder, FramedRead<DuplexStream, DerpCodec>) {
        let (test_side, client_side) = tokio::io::duplex(1024);
        // Only the sending half of the server channel is kept; the test never reads it.
        let (server_channel, _) = mpsc::channel(10);
        let builder = ClientConnBuilder {
            key,
            conn_num,
            io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(client_side), DerpCodec)),
            write_timeout: None,
            channel_capacity: 10,
            server_channel,
        };
        (builder, FramedRead::new(test_side, DerpCodec))
    }

    #[tokio::test]
    async fn test_clients() -> Result<()> {
        let a_key = SecretKey::generate().public();
        let b_key = SecretKey::generate().public();

        let (builder_a, mut a_rw) = test_client_builder(a_key, 0);
        let mut clients = Clients::new();
        clients.register(builder_a);

        let data = b"hello world!";
        let packet = Packet {
            src: b_key,
            bytes: Bytes::from(&data[..]),
        };
        // Regular and disco packets both arrive as the same `RecvPacket` frame.
        let expected = Frame::RecvPacket {
            src_key: b_key,
            content: data.to_vec().into(),
        };

        clients.send_packet(&a_key, packet.clone())?;
        let received = recv_frame(FrameType::RecvPacket, &mut a_rw).await?;
        assert_eq!(received, expected);

        clients.send_disco_packet(&a_key, packet)?;
        let received = recv_frame(FrameType::RecvPacket, &mut a_rw).await?;
        assert_eq!(received, expected);

        clients.send_peer_gone(&a_key, b_key);
        let received = recv_frame(FrameType::PeerGone, &mut a_rw).await?;
        assert_eq!(received, Frame::PeerGone { peer: b_key });

        clients.unregister(&a_key);
        assert!(!clients.contains_key(&a_key));

        clients.shutdown().await;
        Ok(())
    }
}