1use std::{collections::HashSet, sync::Arc, time::Duration};
4
5use iroh_base::EndpointId;
6use n0_error::{e, stack_error};
7use n0_future::{SinkExt, StreamExt};
8use rand::RngExt;
9use time::{Date, OffsetDateTime};
10use tokio::{
11 sync::mpsc::{self, error::TrySendError},
12 time::MissedTickBehavior,
13};
14use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle};
15use tracing::{Instrument, debug, trace, warn};
16
17use crate::{
18 PingTracker,
19 http::ProtocolVersion,
20 protos::{
21 relay::{ClientToRelayMsg, Datagrams, PING_INTERVAL, RelayToClientMsg, Status},
22 streams::BytesStreamSink,
23 },
24 server::{
25 clients::Clients,
26 metrics::Metrics,
27 streams::{RecvError as RelayRecvError, RelayedStream, SendError as RelaySendError},
28 },
29};
30
/// A batch of datagrams queued for delivery to this client, tagged with its sender.
#[derive(Debug, Clone)]
pub(super) struct Packet {
    /// Endpoint the datagrams came from; written to the client as `remote_endpoint_id`.
    src: EndpointId,
    /// The datagrams to deliver to this client.
    data: Datagrams,
}
39
/// Settings consumed by `Client::new` to set up one client connection and its actor task.
#[derive(Debug)]
pub struct Config<S> {
    /// Endpoint id identifying the connected client.
    pub endpoint_id: EndpointId,
    /// Relayed stream used to exchange frames with the client.
    pub stream: RelayedStream<S>,
    /// Upper bound on how long writing a single frame to `stream` may take before it fails.
    pub write_timeout: Duration,
    /// Capacity of each bounded queue (datagrams and control messages) feeding the actor.
    pub channel_capacity: usize,
    /// Protocol version spoken by the client; selects how health notifications are encoded.
    pub protocol_version: ProtocolVersion,
}
56
/// Server-side handle to a single connected relay client.
///
/// The connection itself is driven by a spawned actor task; this handle feeds that
/// task through two bounded queues and can ask it to shut down.
#[derive(Debug)]
pub struct Client {
    /// Endpoint id of the connected client.
    endpoint_id: EndpointId,
    /// Identifier of this connection, as supplied to `Client::new`.
    connection_id: u64,
    /// Cancelled to ask the actor task to flush and stop.
    done: CancellationToken,
    /// Handle to the actor task; the task is aborted if this handle is dropped.
    handle: AbortOnDropHandle<()>,
    /// Queue of datagrams to deliver to the client.
    packet_queue: mpsc::Sender<Packet>,
    /// Queue of control messages (endpoint gone, health/status) to deliver to the client.
    message_queue: mpsc::Sender<RelayToClientMsg>,
    /// Protocol version spoken by the client; selects the health message format.
    protocol_version: ProtocolVersion,
}
78
79impl Client {
80 pub(super) fn new<S>(
84 config: Config<S>,
85 connection_id: u64,
86 clients: &Clients,
87 metrics: Arc<Metrics>,
88 ) -> Client
89 where
90 S: BytesStreamSink + Send + 'static,
91 {
92 let Config {
93 endpoint_id,
94 stream,
95 write_timeout,
96 channel_capacity,
97 protocol_version,
98 } = config;
99
100 let (packet_send_queue_s, packet_send_queue_r) = mpsc::channel(channel_capacity);
101 let (message_send_queue_s, message_send_queue_r) = mpsc::channel(channel_capacity);
102 let done = CancellationToken::new();
103
104 let actor = Actor {
105 stream,
106 timeout: write_timeout,
107 packet_send_queue: packet_send_queue_r,
108 message_send_queue: message_send_queue_r,
109 endpoint_id,
110 connection_id,
111 clients: clients.clone(),
112 client_counter: ClientCounter::default(),
113 ping_tracker: PingTracker::default(),
114 metrics,
115 };
116
117 let io_done = done.clone();
119 let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!(
120 "client-connection-actor",
121 remote_endpoint = %endpoint_id.fmt_short(),
122 connection_id = connection_id
123 )));
124
125 Client {
126 endpoint_id,
127 connection_id,
128 handle: AbortOnDropHandle::new(handle),
129 done,
130 packet_queue: packet_send_queue_s,
131 message_queue: message_send_queue_s,
132 protocol_version,
133 }
134 }
135
136 pub(super) fn connection_id(&self) -> u64 {
137 self.connection_id
138 }
139
140 pub(super) async fn shutdown(self) {
144 self.start_shutdown();
145 if let Err(e) = self.handle.await {
146 warn!(
147 remote_endpoint = %self.endpoint_id.fmt_short(),
148 "error closing actor loop: {e:#?}",
149 );
150 };
151 }
152
153 pub(super) fn start_shutdown(&self) {
155 self.done.cancel();
156 }
157
158 pub(super) fn try_send_packet(
159 &self,
160 src: EndpointId,
161 data: Datagrams,
162 ) -> Result<(), TrySendError<Packet>> {
163 self.packet_queue.try_send(Packet { src, data })
164 }
165
166 pub(super) fn try_send_peer_gone(
167 &self,
168 key: EndpointId,
169 ) -> Result<(), TrySendError<RelayToClientMsg>> {
170 self.message_queue
171 .try_send(RelayToClientMsg::EndpointGone(key))
172 }
173
174 pub(super) fn try_send_health(
175 &self,
176 status: Status,
177 ) -> Result<(), TrySendError<RelayToClientMsg>> {
178 let message = match self.protocol_version {
179 ProtocolVersion::V2 => RelayToClientMsg::Status(status),
180 ProtocolVersion::V1 => RelayToClientMsg::Health {
181 problem: status.to_string(),
182 },
183 };
184 self.message_queue.try_send(message)
185 }
186}
187
/// Errors produced while handling a single item read from the client stream.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum HandleFrameError {
    // Not constructed by `Actor::handle_frame`, which only logs forwarding failures.
    #[error(transparent)]
    ForwardPacket { source: ForwardPacketError },
    // The client stream ended (`next()` yielded `None`).
    #[error("Stream terminated")]
    StreamTerminated {},
    // Reading a frame from the client stream failed.
    #[error(transparent)]
    Recv { source: RelayRecvError },
    // Writing the reply frame (a pong) back to the client failed.
    #[error(transparent)]
    Send { source: WriteFrameError },
}
202
/// Errors from writing one frame to the client stream.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum WriteFrameError {
    // Sending the frame on the underlying stream failed.
    #[error(transparent)]
    Stream { source: RelaySendError },
    // The write did not complete within the configured write timeout.
    #[error(transparent)]
    Timeout {
        #[error(std_err)]
        source: tokio::time::error::Elapsed,
    },
}
216
/// Reasons the client actor's event loop (`Actor::run_inner`) stopped with an error.
#[stack_error(derive, add_meta)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum RunError {
    // Convertible via `?`, but not constructed by the actor loop in this file.
    #[error(transparent)]
    ForwardPacket {
        #[error(from)]
        source: ForwardPacketError,
    },
    // Flushing the stream after cancellation failed; the underlying error is discarded.
    #[error("Flush")]
    Flush {},
    // Handling a frame received from the client failed.
    #[error(transparent)]
    HandleFrame {
        #[error(from)]
        source: HandleFrameError,
    },
    // Writing a queued datagram packet to the client failed.
    #[error("Failed to send packet")]
    PacketSend { source: WriteFrameError },
    // The sending side of the packet or message queue was closed.
    #[error("Handle was dropped")]
    HandleDropped {},
    // Writing a queued control message or a keep-alive ping failed.
    #[error("Writing a frame failed")]
    WriteFrame { source: WriteFrameError },
    // Flushing the stream at the end of a loop iteration failed; the underlying error is discarded.
    #[error("Tick flush")]
    TickFlush {},
}
243
/// Task that owns one client connection and runs its event loop.
///
/// It writes queued datagrams and control messages to the client, handles frames
/// the client sends, and emits periodic keep-alive pings.
#[derive(Debug)]
struct Actor<S> {
    /// Stream used to exchange frames with the client.
    stream: RelayedStream<S>,
    /// Maximum time allowed for writing a single frame.
    timeout: Duration,
    /// Datagrams queued through `Client::try_send_packet`.
    packet_send_queue: mpsc::Receiver<Packet>,
    /// Control messages queued through `Client` (endpoint gone, health/status).
    message_send_queue: mpsc::Receiver<RelayToClientMsg>,
    /// Endpoint id of the connected client.
    endpoint_id: EndpointId,
    /// Identifier of this connection; passed to `Clients::unregister` on exit.
    connection_id: u64,
    /// Registry used to forward packets to other clients and to unregister on exit.
    clients: Clients,
    /// Endpoint ids seen on the current UTC day, feeding the `unique_client_keys` metric.
    client_counter: ClientCounter,
    /// Tracks outstanding keep-alive pings and their deadlines.
    ping_tracker: PingTracker,
    /// Server metrics updated by this actor.
    metrics: Arc<Metrics>,
}
282
283impl<S> Actor<S>
284where
285 S: BytesStreamSink,
286{
287 async fn run(mut self, done: CancellationToken) {
288 self.metrics.accepts.inc();
292 if self.client_counter.update(self.endpoint_id) {
293 self.metrics.unique_client_keys.inc();
294 }
295 match self.run_inner(done).await {
296 Err(e) => {
297 warn!("actor errored {e:#}, exiting");
298 }
299 Ok(()) => {
300 debug!("actor finished, exiting");
301 }
302 }
303
304 self.clients
305 .unregister(self.connection_id, self.endpoint_id, &self.metrics);
306 self.metrics.disconnects.inc();
307 }
308
309 async fn run_inner(&mut self, done: CancellationToken) -> Result<(), RunError> {
310 let next_interval = || {
312 let random_secs = rand::rng().random_range(1..=5);
313 Duration::from_secs(random_secs) + PING_INTERVAL
314 };
315
316 let mut ping_interval = tokio::time::interval(next_interval());
317 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
319 ping_interval.tick().await;
320
321 loop {
322 tokio::select! {
323 biased;
324
325 _ = done.cancelled() => {
326 trace!("actor loop cancelled, exiting");
327 self.stream.flush().await.map_err(|_| e!(RunError::Flush))?;
329 break;
330 }
331 maybe_frame = self.stream.next() => {
332 self
333 .handle_frame(maybe_frame)
334 .await?;
335 ping_interval.reset();
337 }
338 packet = self.packet_send_queue.recv() => {
340 let packet = packet.ok_or_else(|| e!(RunError::HandleDropped))?;
341 self.send_packet(packet)
342 .await
343 .map_err(|err| e!(RunError::PacketSend, err))?;
344 }
345 message = self.message_send_queue.recv() => {
347 let message = message .ok_or_else(|| e!(RunError::HandleDropped))?;
348 trace!("send {message:?}");
349 self.write_frame(message)
350 .await
351 .map_err(|err| e!(RunError::WriteFrame, err))?;
352 }
353 _ = self.ping_tracker.timeout() => {
354 trace!("pong timed out");
355 break;
356 }
357 _ = ping_interval.tick() => {
358 trace!("keep alive ping");
359 ping_interval.reset_after(next_interval());
361 let data = self.ping_tracker.new_ping();
362 self.write_frame(RelayToClientMsg::Ping(data))
363 .await
364 .map_err(|err| e!(RunError::WriteFrame, err))?;
365 }
366 }
367
368 self.stream
369 .flush()
370 .await
371 .map_err(|_| e!(RunError::TickFlush))?;
372 }
373 Ok(())
374 }
375
376 async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), WriteFrameError> {
380 tokio::time::timeout(self.timeout, self.stream.send(frame)).await??;
381 Ok(())
382 }
383
384 async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
389 let remote_endpoint_id = packet.src;
390 let datagrams = packet.data;
391
392 if let Ok(len) = datagrams.contents.len().try_into() {
393 self.metrics.bytes_sent.inc_by(len);
394 }
395 self.write_frame(RelayToClientMsg::Datagrams {
396 remote_endpoint_id,
397 datagrams,
398 })
399 .await
400 }
401
402 async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> {
403 trace!("send packet");
404 match self.send_raw(packet).await {
405 Ok(()) => {
406 self.metrics.send_packets_sent.inc();
407 Ok(())
408 }
409 Err(err) => {
410 self.metrics.send_packets_dropped.inc();
411 Err(err)
412 }
413 }
414 }
415
416 async fn handle_frame(
418 &mut self,
419 maybe_frame: Option<Result<ClientToRelayMsg, RelayRecvError>>,
420 ) -> Result<(), HandleFrameError> {
421 trace!(?maybe_frame, "handle incoming frame");
422 let frame = match maybe_frame {
423 Some(frame) => frame?,
424 None => return Err(e!(HandleFrameError::StreamTerminated)),
425 };
426
427 match frame {
428 ClientToRelayMsg::Datagrams {
429 dst_endpoint_id: dst_key,
430 datagrams,
431 } => {
432 let packet_len = datagrams.contents.len();
433 if let Err(err @ ForwardPacketError { .. }) =
434 self.handle_frame_send_packet(dst_key, datagrams)
435 {
436 warn!("failed to handle send packet frame: {err:#}");
437 }
438 self.metrics.bytes_recv.inc_by(packet_len as u64);
439 }
440 ClientToRelayMsg::Ping(data) => {
441 self.metrics.got_ping.inc();
442 self.write_frame(RelayToClientMsg::Pong(data)).await?;
444 self.metrics.sent_pong.inc();
445 }
446 ClientToRelayMsg::Pong(data) => {
447 self.ping_tracker.pong_received(data);
448 }
449 }
450 Ok(())
451 }
452
453 fn handle_frame_send_packet(
454 &self,
455 dst: EndpointId,
456 data: Datagrams,
457 ) -> Result<(), ForwardPacketError> {
458 self.metrics.send_packets_recv.inc();
459 self.clients
460 .send_packet(dst, data, self.endpoint_id, &self.metrics)?;
461
462 Ok(())
463 }
464}
465
/// Reason carried by `ForwardPacketError` when a packet could not be forwarded.
///
/// NOTE(review): presumably mirrors the destination queue's `TrySendError::{Full, Closed}`;
/// confirm against `Clients::send_packet`. Variant names appear in the error's Display text.
#[derive(Debug)]
pub(crate) enum SendError {
    /// The destination could not accept the packet (presumably its queue was full).
    Full,
    /// The destination is no longer accepting packets (presumably its queue was closed).
    Closed,
}
471
/// Error returned when a packet received from a client could not be forwarded.
///
/// Displays as `failed to forward packet: <reason:?>`.
#[stack_error(derive, add_meta)]
#[error("failed to forward packet: {reason:?}")]
pub struct ForwardPacketError {
    // Why forwarding failed.
    reason: SendError,
}
482
/// Set of endpoint ids seen since the last UTC date change.
#[derive(Debug)]
struct ClientCounter {
    /// Endpoint ids recorded since `last_clear_date`.
    clients: HashSet<EndpointId>,
    /// UTC date on which `clients` was last reset.
    last_clear_date: Date,
}
489
490impl Default for ClientCounter {
491 fn default() -> Self {
492 Self {
493 clients: HashSet::new(),
494 last_clear_date: OffsetDateTime::now_utc().date(),
495 }
496 }
497}
498
499impl ClientCounter {
500 fn check_and_clear(&mut self) {
501 let today = OffsetDateTime::now_utc().date();
502 if today != self.last_clear_date {
503 self.clients.clear();
504 self.last_clear_date = today;
505 }
506 }
507
508 fn update(&mut self, client: EndpointId) -> bool {
510 self.check_and_clear();
511 self.clients.insert(client)
512 }
513}
514
#[cfg(test)]
mod tests {
    use iroh_base::SecretKey;
    use n0_error::{Result, StdResultExt, bail_any};
    use n0_future::Stream;
    use n0_tracing_test::traced_test;
    use rand::SeedableRng;
    use tracing::info;

    use super::*;
    use crate::{
        client::conn::Conn,
        http::ProtocolVersion,
        protos::{common::FrameType, relay::Status, streams::WsBytesFramed},
        server::streams::{MaybeTlsStream, RateLimited, ServerRelayedStream},
    };

    /// Reads the next frame from `stream` and checks that its type is `frame_type`.
    ///
    /// Fails if the stream yields an error, ends, or produces a frame of another type.
    async fn recv_frame<
        E: std::error::Error + Sync + Send + 'static,
        S: Stream<Item = Result<RelayToClientMsg, E>> + Unpin,
    >(
        frame_type: FrameType,
        mut stream: S,
    ) -> Result<RelayToClientMsg> {
        match stream.next().await {
            Some(Ok(frame)) => {
                if frame_type != frame.typ() {
                    bail_any!(
                        "Unexpected frame, got {:?}, but expected {:?}",
                        frame.typ(),
                        frame_type
                    );
                }
                Ok(frame)
            }
            Some(Err(err)) => Err(err).anyerr(),
            None => bail_any!("Unexpected EOF, expected frame {frame_type:?}"),
        }
    }

    /// Drives an `Actor` directly over an in-memory duplex stream.
    ///
    /// Checks the outbound path (queued packet, queued endpoint-gone message) and the
    /// inbound path (ping answered by pong, datagram to another endpoint accepted), then
    /// cancels the actor and joins its task.
    #[tokio::test]
    #[traced_test]
    async fn test_client_actor_basic() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let (send_queue_s, send_queue_r) = mpsc::channel(10);
        let (message_s, message_r) = mpsc::channel(10);

        let endpoint_id = SecretKey::from_bytes(&rng.random()).public();
        let (io, io_rw) = tokio::io::duplex(1024);
        // `io_rw` plays the client side; `io` is the server side owned by the actor.
        let mut io_rw = Conn::test(io_rw, Default::default());
        let stream = RelayedStream::test(io);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        let actor = Actor {
            stream,
            timeout: Duration::from_secs(1),
            packet_send_queue: send_queue_r,
            message_send_queue: message_r,
            connection_id: 0,
            endpoint_id,
            clients: clients.clone(),
            client_counter: ClientCounter::default(),
            ping_tracker: PingTracker::default(),
            metrics,
        };

        let done = CancellationToken::new();
        let io_done = done.clone();
        let handle = tokio::task::spawn(async move { actor.run(io_done).await });

        // Outbound: messages queued for the actor must reach the client.
        println!("-- write");
        let data = b"hello world!";

        println!(" send packet");
        let packet = Packet {
            src: endpoint_id,
            data: Datagrams::from(&data[..]),
        };
        send_queue_s
            .send(packet.clone())
            .await
            .std_context("send")?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut io_rw)
            .await
            .anyerr()?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: endpoint_id,
                datagrams: data.to_vec().into()
            }
        );

        println!("send peer gone");
        message_s
            .send(RelayToClientMsg::EndpointGone(endpoint_id))
            .await
            .std_context("send")?;
        let frame = recv_frame(FrameType::EndpointGone, &mut io_rw)
            .await
            .anyerr()?;
        assert_eq!(frame, RelayToClientMsg::EndpointGone(endpoint_id));

        // Inbound: frames sent by the client are handled by the actor.
        println!("--read");

        let data = b"pingpong";
        io_rw.send(ClientToRelayMsg::Ping(*data)).await?;

        println!(" recv pong");
        let frame = recv_frame(FrameType::Pong, &mut io_rw).await?;
        assert_eq!(frame, RelayToClientMsg::Pong(*data));

        // `target` is never registered with `clients`; the actor should keep running
        // (it only logs forwarding failures).
        let target = SecretKey::from_bytes(&rng.random()).public();

        println!(" send packet");
        let data = b"hello world!";
        io_rw
            .send(ClientToRelayMsg::Datagrams {
                dst_endpoint_id: target,
                datagrams: Datagrams::from(data),
            })
            .await
            .std_context("send")?;

        done.cancel();
        handle.await.std_context("join")?;
        Ok(())
    }

    /// Builds a client `Config` over an in-memory duplex stream, plus the matching
    /// client-side `Conn` speaking `protocol_version`.
    fn test_client_builder(
        key: EndpointId,
        protocol_version: ProtocolVersion,
    ) -> (Config<WsBytesFramed<RateLimited<MaybeTlsStream>>>, Conn) {
        let (server, client) = tokio::io::duplex(1024);
        (
            Config {
                endpoint_id: key,
                stream: ServerRelayedStream::test(server),
                write_timeout: Duration::from_secs(1),
                channel_capacity: 10,
                protocol_version,
            },
            Conn::test(client, protocol_version),
        )
    }

    /// A datagram from `b_key` to a registered V1 client `a_key` is delivered to it.
    #[tokio::test]
    #[traced_test]
    async fn test_client_v1_protocol() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42u64);
        let a_key = SecretKey::from_bytes(&rng.random()).public();
        let b_key = SecretKey::from_bytes(&rng.random()).public();

        let (builder_a, mut a_rw) = test_client_builder(a_key, ProtocolVersion::V1);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        clients.register(builder_a, metrics.clone());

        let data = b"hello world v1!";
        clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: b_key,
                datagrams: data.to_vec().into(),
            }
        );

        clients.shutdown().await;
        Ok(())
    }

    /// A datagram from `b_key` to a registered V2 client `a_key` is delivered to it.
    #[tokio::test]
    #[traced_test]
    async fn test_client_v2_protocol() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42u64);
        let a_key = SecretKey::from_bytes(&rng.random()).public();
        let b_key = SecretKey::from_bytes(&rng.random()).public();

        let (builder_a, mut a_rw) = test_client_builder(a_key, ProtocolVersion::V2);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        clients.register(builder_a, metrics.clone());

        let data = b"hello world v2!";
        clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?;
        let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Datagrams {
                remote_endpoint_id: b_key,
                datagrams: data.to_vec().into(),
            }
        );

        clients.shutdown().await;
        Ok(())
    }

    /// Registering a second V1 client with the same key sends the first one a
    /// text `Health` frame.
    #[tokio::test]
    #[traced_test]
    async fn test_duplicate_endpoint_v1_receives_v1health() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42u64);
        let key = SecretKey::from_bytes(&rng.random()).public();

        let (builder_first, mut first_rw) = test_client_builder(key, ProtocolVersion::V1);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        clients.register(builder_first, metrics.clone());

        let (builder_second, _second_rw) = test_client_builder(key, ProtocolVersion::V1);
        clients.register(builder_second, metrics.clone());

        let frame = recv_frame(FrameType::Health, &mut first_rw).await?;
        assert!(
            matches!(frame, RelayToClientMsg::Health { .. }),
            "expected V1Health frame for V1 client, got {frame:?}"
        );

        clients.shutdown().await;
        Ok(())
    }

    /// Registering a second V2 client with the same key sends the first one a
    /// `Status::SameEndpointIdConnected` frame.
    #[tokio::test]
    #[traced_test]
    async fn test_duplicate_endpoint_v2_receives_health() -> Result {
        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42u64);
        let key = SecretKey::from_bytes(&rng.random()).public();

        let (builder_first, mut first_rw) = test_client_builder(key, ProtocolVersion::V2);

        let clients = Clients::default();
        let metrics = Arc::new(Metrics::default());
        clients.register(builder_first, metrics.clone());

        let (builder_second, _second_rw) = test_client_builder(key, ProtocolVersion::V2);
        clients.register(builder_second, metrics.clone());

        let frame = recv_frame(FrameType::Status, &mut first_rw).await?;
        assert_eq!(
            frame,
            RelayToClientMsg::Status(Status::SameEndpointIdConnected)
        );

        clients.shutdown().await;
        Ok(())
    }

    /// A rate-limited stream lets the first frame through and holds back the second
    /// until (paused, auto-advancing) time has moved forward.
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_rate_limit() -> Result {
        const LIMIT: u32 = 50;
        const MAX_FRAMES: u32 = 100;

        let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0u64);

        let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _);
        let mut frame_writer = Conn::test(io_write, Default::default());
        // See `RelayedStream::test_limited` for the meaning of the two limit arguments.
        let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?;

        let data = Datagrams::from(b"hello world!!!!!");
        let target = SecretKey::from_bytes(&rng.random()).public();
        let frame = ClientToRelayMsg::Datagrams {
            dst_endpoint_id: target,
            datagrams: data.clone(),
        };
        // The payload is chosen so that one encoded frame is exactly LIMIT bytes.
        let frame_len = frame.to_bytes().len();
        assert_eq!(frame_len, LIMIT as usize);

        info!("-- send packet");
        frame_writer.send(frame.clone()).await.std_context("send")?;
        frame_writer.flush().await.std_context("flush")?;
        let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
            .await
            .expect("timeout")
            .expect("option")
            .expect("ok");
        assert_eq!(recv_frame, frame);

        // The second frame must be held back by the limiter.
        info!("-- send packet");
        frame_writer.send(frame.clone()).await.std_context("send")?;
        frame_writer.flush().await.std_context("flush")?;
        let res = tokio::time::timeout(Duration::from_millis(100), stream.next()).await;
        assert!(res.is_err(), "expecting a timeout");
        info!("-- timeout happened");

        // Advance the (paused) clock so the limiter releases the pending frame.
        info!("-- sleep");
        tokio::time::sleep(Duration::from_secs(1)).await;

        let recv_frame = tokio::time::timeout(Duration::from_millis(500), stream.next())
            .await
            .expect("timeout")
            .expect("option")
            .expect("ok");
        assert_eq!(recv_frame, frame);

        Ok(())
    }
}