use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use anyhow::{Context, Result};
use bytes::Bytes;
use futures_lite::StreamExt;
use futures_util::SinkExt;
use iroh_metrics::{inc, inc_by};
use tokio::sync::mpsc;
use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
use tracing::{trace, Instrument};
use crate::{
disco::looks_like_disco_wrapper,
key::PublicKey,
relay::{
codec::{write_frame, Frame, KEEP_ALIVE},
server::{
metrics::Metrics,
streams::RelayIo,
types::{Packet, ServerMessage},
},
},
};
/// Owner-side handle for a single client connection.
///
/// Holds the sending halves of the queues that feed the connection's IO loop,
/// plus the cancellation token and task handle used to stop that loop.
#[derive(Debug)]
pub(crate) struct ClientConnManager {
    /// Connection number; together with `key` it identifies this connection
    /// (it is sent back to the server in `ServerMessage::RemoveClient`).
    pub(crate) conn_num: usize,
    /// Public key of the connected client.
    pub(crate) key: PublicKey,
    /// Cancelled by `shutdown` to make the IO loop exit.
    done: CancellationToken,
    /// Handle to the spawned IO task; the task is aborted if this handle is dropped.
    io_handle: AbortOnDropHandle<Result<()>>,
    /// Senders used by the server to push work to this client's IO loop.
    pub(crate) client_channels: ClientChannels,
}
/// Sending halves of the bounded queues consumed by a client's IO loop.
#[derive(Debug)]
pub(crate) struct ClientChannels {
    /// Packets to deliver to the client; written as `Frame::RecvPacket`.
    pub(crate) send_queue: mpsc::Sender<Packet>,
    /// Disco packets to deliver to the client; also written as `Frame::RecvPacket`,
    /// but kept on a separate queue from regular packets.
    pub(crate) disco_send_queue: mpsc::Sender<Packet>,
    /// Keys of peers to announce to the client via `Frame::PeerGone`.
    pub(crate) peer_gone: mpsc::Sender<PublicKey>,
}
/// Parameters for creating a [`ClientConnManager`]; see [`ClientConnBuilder::build`].
#[derive(Debug)]
pub struct ClientConnBuilder {
    /// Public key of the connecting client.
    pub(crate) key: PublicKey,
    /// Connection number, paired with `key` to identify the connection.
    pub(crate) conn_num: usize,
    /// Framed stream to the client.
    pub(crate) io: RelayIo,
    /// Timeout passed to every frame write (`None` presumably means no timeout — see `write_frame`).
    pub(crate) write_timeout: Option<Duration>,
    /// Capacity of each of the three per-client queues.
    pub(crate) channel_capacity: usize,
    /// Channel to the server; receives forwarded packets and the final `RemoveClient`.
    pub(crate) server_channel: mpsc::Sender<ServerMessage>,
}
impl ClientConnBuilder {
    /// Consumes the builder and starts the connection, returning its manager.
    ///
    /// This is a thin forwarding wrapper around [`ClientConnManager::new`], which
    /// spawns the connection's IO task.
    pub(crate) fn build(self) -> ClientConnManager {
        let Self {
            key,
            conn_num,
            io,
            write_timeout,
            channel_capacity,
            server_channel,
        } = self;
        ClientConnManager::new(key, conn_num, io, write_timeout, channel_capacity, server_channel)
    }
}
impl ClientConnManager {
    /// Creates the per-client queues and spawns the task that drives the
    /// connection's IO loop ([`ClientConnIo::run`]).
    ///
    /// The task runs until it is cancelled via [`ClientConnManager::shutdown`], the
    /// connection fails, or one of the queues closes. Whatever the outcome, it then
    /// sends [`ServerMessage::RemoveClient`] for `(key, conn_num)` on `server_channel`.
    ///
    /// * `key` / `conn_num` — identify this connection to the server.
    /// * `io` — framed stream to the client.
    /// * `write_timeout` — forwarded to every frame write.
    /// * `channel_capacity` — bound of each of the three outgoing queues.
    /// * `server_channel` — receives packets sent by the client and the final
    ///   `RemoveClient` message.
    pub fn new(
        key: PublicKey,
        conn_num: usize,
        io: RelayIo,
        write_timeout: Option<Duration>,
        channel_capacity: usize,
        server_channel: mpsc::Sender<ServerMessage>,
    ) -> ClientConnManager {
        let done = CancellationToken::new();
        let (send_queue_s, send_queue_r) = mpsc::channel(channel_capacity);
        let (disco_send_queue_s, disco_send_queue_r) = mpsc::channel(channel_capacity);
        let (peer_gone_s, peer_gone_r) = mpsc::channel(channel_capacity);

        let conn_io = ClientConnIo {
            io,
            timeout: write_timeout,
            send_queue: send_queue_r,
            disco_send_queue: disco_send_queue_r,
            peer_gone: peer_gone_r,
            key,
            preferred: Arc::new(AtomicBool::new(false)),
            server_channel: server_channel.clone(),
        };

        let io_done = done.clone();
        let io_handle = tokio::task::spawn(
            async move {
                let res = conn_io.run(io_done).await;
                // Always deregister, whatever stopped the loop. The send can only
                // fail once the server's receiver is gone, so the error is ignored.
                let _ = server_channel
                    .send(ServerMessage::RemoveClient((key, conn_num)))
                    .await;
                match res {
                    Err(e) => {
                        // `{e:#}` prints the full anyhow context chain; plain `{e}` would
                        // only show the outermost context (e.g. "handle_read").
                        tracing::warn!(
                            "connection manager for {key:?} {conn_num}: writer closed in error {e:#}"
                        );
                        Err(e)
                    }
                    Ok(()) => {
                        // A clean shutdown is the expected path, not a warning.
                        tracing::debug!("connection manager for {key:?} {conn_num}: writer closed");
                        Ok(())
                    }
                }
            }
            .instrument(tracing::debug_span!("conn_io")),
        );

        ClientConnManager {
            conn_num,
            key,
            io_handle: AbortOnDropHandle::new(io_handle),
            done,
            client_channels: ClientChannels {
                send_queue: send_queue_s,
                disco_send_queue: disco_send_queue_s,
                peer_gone: peer_gone_s,
            },
        }
    }

    /// Cancels the IO loop and waits for its task to finish.
    ///
    /// Only a task-level failure (the task panicked or was aborted) is logged here;
    /// an error returned by the loop itself was already logged inside the task.
    pub async fn shutdown(self) {
        self.done.cancel();
        if let Err(e) = self.io_handle.await {
            tracing::warn!(
                "error closing IO loop for client connection {:?} {}: {e:?}",
                self.key,
                self.conn_num
            );
        }
    }
}
/// State owned by a client connection's IO loop (see `ClientConnIo::run`).
#[derive(Debug)]
pub(crate) struct ClientConnIo {
    /// Framed stream to the client; frames are read from and written to it.
    io: RelayIo,
    /// Timeout passed to every `write_frame` call.
    timeout: Option<Duration>,
    /// Packets to forward to the client as `Frame::RecvPacket`.
    send_queue: mpsc::Receiver<Packet>,
    /// Disco packets to forward to the client as `Frame::RecvPacket`.
    disco_send_queue: mpsc::Receiver<Packet>,
    /// Peer keys to announce to the client as `Frame::PeerGone`.
    peer_gone: mpsc::Receiver<PublicKey>,
    /// This client's key; used as the `src` of packets it sends.
    key: PublicKey,
    /// Channel used to hand packets sent by the client to the server.
    server_channel: mpsc::Sender<ServerMessage>,
    /// Last `NotePreferred` value received from the client.
    preferred: Arc<AtomicBool>,
}
impl ClientConnIo {
    /// Drives the connection until `done` is cancelled or an error occurs.
    ///
    /// Each iteration handles exactly one event and then flushes the stream. The
    /// event is one of: cancellation, an incoming frame, a queued peer-gone notice,
    /// a queued (disco) packet, or a keep-alive tick. The `select!` is `biased`, so
    /// the branches are polled in declaration order.
    ///
    /// # Errors
    ///
    /// Returns an error if reading, handling or writing a frame fails, if the read
    /// stream ends (`UnexpectedEof`), or if a server-side queue sender is dropped.
    async fn run(mut self, done: CancellationToken) -> Result<()> {
        // Fixed slack added to the keep-alive period (not randomized, despite the name).
        let jitter = Duration::from_secs(5);
        let mut keep_alive = tokio::time::interval(KEEP_ALIVE + jitter);
        // The keep-alive branch is polled last in a biased select, so ticks can be
        // missed under load. With the default `Burst` behavior that would emit
        // several keep-alives back to back; `Delay` just postpones the next one.
        keep_alive.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // An interval's first tick completes immediately; consume it so the first
        // keep-alive is sent one full period after start.
        keep_alive.tick().await;
        loop {
            trace!("tick");
            tokio::select! {
                biased;
                _ = done.cancelled() => {
                    trace!("cancelled");
                    // Push out anything still buffered before reporting a clean exit.
                    self.io.flush().await.context("flush")?;
                    return Ok(());
                }
                read_res = self.io.next() => {
                    trace!("handle read");
                    match read_res {
                        Some(Ok(frame)) => {
                            self.handle_read(frame).await.context("handle_read")?;
                        }
                        Some(Err(err)) => {
                            return Err(err);
                        }
                        None => {
                            return Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "read stream ended").into());
                        }
                    }
                }
                peer = self.peer_gone.recv() => {
                    let peer = peer.context("Server.peer_gone dropped")?;
                    trace!("peer gone: {:?}", peer);
                    self.send_peer_gone(peer).await?;
                }
                packet = self.send_queue.recv() => {
                    let packet = packet.context("Server.send_queue dropped")?;
                    trace!("send packet");
                    self.send_packet(packet).await.context("send packet")?;
                }
                packet = self.disco_send_queue.recv() => {
                    let packet = packet.context("Server.disco_send_queue dropped")?;
                    trace!("send disco packet");
                    self.send_packet(packet).await.context("send packet")?;
                }
                _ = keep_alive.tick() => {
                    trace!("keep alive");
                    self.send_keep_alive().await.context("send keep alive")?;
                }
            }
            self.io.flush().await.context("final flush")?;
        }
    }

    /// Writes a `KeepAlive` frame.
    async fn send_keep_alive(&mut self) -> Result<()> {
        write_frame(&mut self.io, Frame::KeepAlive, self.timeout).await
    }

    /// Writes a `Pong` frame echoing `data`.
    async fn send_pong(&mut self, data: [u8; 8]) -> Result<()> {
        write_frame(&mut self.io, Frame::Pong { data }, self.timeout).await
    }

    /// Writes a `PeerGone` frame announcing `peer`.
    async fn send_peer_gone(&mut self, peer: PublicKey) -> Result<()> {
        write_frame(&mut self.io, Frame::PeerGone { peer }, self.timeout).await
    }

    /// Forwards `packet` to the client as a `RecvPacket` frame, counting its bytes.
    async fn send_packet(&mut self, packet: Packet) -> Result<()> {
        let src_key = packet.src;
        let content = packet.bytes;
        if let Ok(len) = content.len().try_into() {
            inc_by!(Metrics, bytes_sent, len);
        }
        write_frame(
            &mut self.io,
            Frame::RecvPacket { src_key, content },
            self.timeout,
        )
        .await
    }

    /// Dispatches a frame received from the client and updates metrics.
    ///
    /// Unrecognized frame types are counted and otherwise ignored.
    async fn handle_read(&mut self, frame: Frame) -> Result<()> {
        match frame {
            Frame::NotePreferred { preferred } => {
                self.handle_frame_note_preferred(preferred);
                inc!(Metrics, other_packets_recv);
            }
            Frame::SendPacket { dst_key, packet } => {
                let packet_len = packet.len();
                self.handle_frame_send_packet(dst_key, packet).await?;
                inc_by!(Metrics, bytes_recv, packet_len as u64);
            }
            Frame::Ping { data } => {
                self.handle_frame_ping(data).await?;
                inc!(Metrics, got_ping);
            }
            Frame::Health { .. } => {
                inc!(Metrics, other_packets_recv);
            }
            _ => {
                inc!(Metrics, unknown_frames);
            }
        }
        Ok(())
    }

    /// Stores the client's latest "preferred" flag in the shared atomic.
    fn set_preferred(&self, v: bool) {
        self.preferred.store(v, Ordering::Relaxed);
    }

    /// Handles a `NotePreferred` frame from the client.
    fn handle_frame_note_preferred(&self, preferred: bool) {
        self.set_preferred(preferred);
    }

    /// Sends `msg` to the server, failing if the server's receiver has been dropped.
    async fn send_server(&self, msg: ServerMessage) -> Result<()> {
        self.server_channel
            .send(msg)
            .await
            .map_err(|_| anyhow::anyhow!("server gone"))
    }

    /// Answers a `Ping` frame with a `Pong` carrying the same payload.
    async fn handle_frame_ping(&mut self, data: [u8; 8]) -> Result<()> {
        self.send_pong(data).await?;
        inc!(Metrics, sent_pong);
        Ok(())
    }

    /// Wraps a payload sent by this client into a [`Packet`] and routes it to `dst_key`.
    async fn handle_frame_send_packet(&self, dst_key: PublicKey, data: Bytes) -> Result<()> {
        let packet = Packet {
            src: self.key,
            bytes: data,
        };
        self.transfer_packet(dst_key, packet).await
    }

    /// Hands `packet` to the server, choosing the disco or regular message variant
    /// based on `looks_like_disco_wrapper`.
    async fn transfer_packet(&self, dst_key: PublicKey, packet: Packet) -> Result<()> {
        if looks_like_disco_wrapper(&packet.bytes) {
            inc!(Metrics, disco_packets_recv);
            self.send_server(ServerMessage::SendDiscoPacket((dst_key, packet)))
                .await?;
        } else {
            inc!(Metrics, send_packets_recv);
            self.send_server(ServerMessage::SendPacket((dst_key, packet)))
                .await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use anyhow::bail;
    use tokio_util::codec::Framed;

    use super::*;
    use crate::{
        key::SecretKey,
        relay::{
            client::conn,
            codec::{recv_frame, DerpCodec, FrameType},
            server::streams::MaybeTlsStream,
        },
    };

    /// Exercises every read and write path of `ClientConnIo::run` over an in-memory duplex.
    #[tokio::test]
    async fn test_client_conn_io_basic() -> Result<()> {
        let (send_queue_s, send_queue_r) = mpsc::channel(10);
        let (disco_send_queue_s, disco_send_queue_r) = mpsc::channel(10);
        let (peer_gone_s, peer_gone_r) = mpsc::channel(10);
        let preferred = Arc::from(AtomicBool::from(true));
        let key = SecretKey::generate().public();
        let (io, io_rw) = tokio::io::duplex(1024);
        let mut io_rw = Framed::new(io_rw, DerpCodec);
        let (server_channel_s, mut server_channel_r) = mpsc::channel(10);
        let conn_io = ClientConnIo {
            io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)),
            timeout: None,
            send_queue: send_queue_r,
            disco_send_queue: disco_send_queue_r,
            peer_gone: peer_gone_r,
            key,
            server_channel: server_channel_s,
            preferred: Arc::clone(&preferred),
        };
        let done = CancellationToken::new();
        let io_done = done.clone();
        let io_handle = tokio::task::spawn(async move { conn_io.run(io_done).await });

        println!("-- write");
        let data = b"hello world!";
        println!(" send packet");
        let packet = Packet {
            src: key,
            bytes: Bytes::from(&data[..]),
        };
        send_queue_s.send(packet.clone()).await?;
        let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?;
        assert_eq!(
            frame,
            Frame::RecvPacket {
                src_key: key,
                content: data.to_vec().into()
            }
        );

        println!(" send disco packet");
        disco_send_queue_s.send(packet.clone()).await?;
        let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?;
        assert_eq!(
            frame,
            Frame::RecvPacket {
                src_key: key,
                content: data.to_vec().into()
            }
        );

        println!("send peer gone");
        peer_gone_s.send(key).await?;
        let frame = recv_frame(FrameType::PeerGone, &mut io_rw).await?;
        assert_eq!(frame, Frame::PeerGone { peer: key });

        println!("--read");
        let data = b"pingpong";
        write_frame(&mut io_rw, Frame::Ping { data: *data }, None).await?;
        println!(" recv pong");
        let frame = recv_frame(FrameType::Pong, &mut io_rw).await?;
        assert_eq!(frame, Frame::Pong { data: *data });

        println!(" preferred: false");
        write_frame(&mut io_rw, Frame::NotePreferred { preferred: false }, None).await?;
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!preferred.load(Ordering::Relaxed));

        println!(" preferred: true");
        write_frame(&mut io_rw, Frame::NotePreferred { preferred: true }, None).await?;
        tokio::time::sleep(Duration::from_millis(100)).await;
        // A plain load: the original `fetch_and(true, ..)` read like a mutation.
        assert!(preferred.load(Ordering::Relaxed));

        let target = SecretKey::generate().public();
        println!(" send packet");
        let data = b"hello world!";
        conn::send_packet(&mut io_rw, &None, target, Bytes::from_static(data)).await?;
        let msg = server_channel_r.recv().await.unwrap();
        match msg {
            ServerMessage::SendPacket((got_target, packet)) => {
                assert_eq!(target, got_target);
                assert_eq!(key, packet.src);
                assert_eq!(&data[..], &packet.bytes);
            }
            m => {
                bail!("expected ServerMessage::SendPacket, got {m:?}");
            }
        }

        println!(" send disco packet");
        let mut disco_data = crate::disco::MAGIC.as_bytes().to_vec();
        disco_data.extend_from_slice(target.as_bytes());
        disco_data.extend_from_slice(data);
        conn::send_packet(&mut io_rw, &None, target, disco_data.clone().into()).await?;
        let msg = server_channel_r.recv().await.unwrap();
        match msg {
            ServerMessage::SendDiscoPacket((got_target, packet)) => {
                assert_eq!(target, got_target);
                assert_eq!(key, packet.src);
                assert_eq!(&disco_data[..], &packet.bytes);
            }
            m => {
                bail!("expected ServerMessage::SendDiscoPacket, got {m:?}");
            }
        }

        done.cancel();
        io_handle.await??;
        Ok(())
    }

    /// Dropping the client side of the stream must end the IO loop with `UnexpectedEof`.
    #[tokio::test]
    async fn test_client_conn_read_err() -> Result<()> {
        let (_send_queue_s, send_queue_r) = mpsc::channel(10);
        let (_disco_send_queue_s, disco_send_queue_r) = mpsc::channel(10);
        let (_peer_gone_s, peer_gone_r) = mpsc::channel(10);
        let preferred = Arc::from(AtomicBool::from(true));
        let key = SecretKey::generate().public();
        let (io, io_rw) = tokio::io::duplex(1024);
        let mut io_rw = Framed::new(io_rw, DerpCodec);
        let (server_channel_s, mut server_channel_r) = mpsc::channel(10);

        println!("-- create client conn");
        let conn_io = ClientConnIo {
            io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)),
            timeout: None,
            send_queue: send_queue_r,
            disco_send_queue: disco_send_queue_r,
            peer_gone: peer_gone_r,
            key,
            server_channel: server_channel_s,
            preferred: Arc::clone(&preferred),
        };
        let done = CancellationToken::new();
        let io_done = done.clone();

        println!("-- run client conn");
        let io_handle = tokio::task::spawn(async move { conn_io.run(io_done).await });

        println!(" send packet");
        let data = b"hello world!";
        let target = SecretKey::generate().public();
        conn::send_packet(&mut io_rw, &None, target, Bytes::from_static(data)).await?;
        let msg = server_channel_r.recv().await.unwrap();
        match msg {
            ServerMessage::SendPacket((got_target, packet)) => {
                assert_eq!(target, got_target);
                assert_eq!(key, packet.src);
                assert_eq!(&data[..], &packet.bytes);
                println!(" send packet success");
            }
            m => {
                bail!("expected ServerMessage::SendPacket, got {m:?}");
            }
        }

        println!("-- drop io");
        drop(io_rw);
        if let Err(err) = tokio::time::timeout(Duration::from_secs(1), io_handle).await?? {
            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
                if io_err.kind() == std::io::ErrorKind::UnexpectedEof {
                    println!(" task closed successfully with `UnexpectedEof` error");
                } else {
                    bail!("expected `UnexpectedEof` error, got unknown error: {io_err:?}");
                }
            } else {
                // Report the actual error instead of a bare "None".
                bail!("expected `std::io::Error`, got {err:?}");
            }
        } else {
            bail!("expected task to finish in `UnexpectedEof` error, got `Ok(())`");
        }
        Ok(())
    }
}