1use std::{
6 collections::HashMap,
7 sync::{
8 atomic::{AtomicUsize, Ordering},
9 Arc,
10 },
11 time::Duration,
12};
13
14use anyhow::{bail, Context as _, Result};
15use hyper::HeaderMap;
16use iroh_metrics::{core::UsageStatsReport, inc, inc_by, report_usage_stats};
17use time::{Date, OffsetDateTime};
18use tokio::sync::mpsc;
19use tokio_tungstenite::WebSocketStream;
20use tokio_util::{codec::Framed, sync::CancellationToken, task::AbortOnDropHandle};
21use tracing::{info_span, trace, Instrument};
22use tungstenite::protocol::Role;
23
24use crate::{
25 defaults::timeouts::relay::SERVER_WRITE_TIMEOUT as WRITE_TIMEOUT,
26 key::PublicKey,
27 relay::{
28 codec::{
29 recv_client_key, DerpCodec, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION,
30 SERVER_CHANNEL_SIZE,
31 },
32 http::Protocol,
33 server::{
34 client_conn::ClientConnBuilder,
35 clients::Clients,
36 metrics::Metrics,
37 streams::{MaybeTlsStream, RelayIo},
38 types::ServerMessage,
39 },
40 },
41};
42
/// Process-wide source of connection numbers; the first value handed out is `1`.
static CONN_NUM: AtomicUsize = AtomicUsize::new(1);

/// Returns a connection number not returned by any earlier call (until the
/// counter wraps at `usize::MAX`).
///
/// `Relaxed` is sufficient: `fetch_add` is an atomic read-modify-write, so every
/// caller observes a distinct previous value regardless of memory ordering.
fn new_conn_num() -> usize {
    AtomicUsize::fetch_add(&CONN_NUM, 1, Ordering::Relaxed)
}
49
/// Handle to the spawned relay server actor.
///
/// Construct it with [`ServerActorTask::new`] (or `Default`), which spawns the
/// actor loop. Incoming connections reach the actor through the
/// [`ClientConnHandler`]s returned by [`ServerActorTask::client_conn_handler`].
#[derive(Debug)]
pub struct ServerActorTask {
    /// Optional write timeout handed to every client connection.
    write_timeout: Option<Duration>,
    /// Sending half of the channel the actor loop reads [`ServerMessage`]s from.
    server_channel: mpsc::Sender<ServerMessage>,
    /// Set at the end of [`ServerActorTask::close`].
    // NOTE(review): `close` consumes `self`, so no live handle ever observes `true`.
    closed: bool,
    /// Handle to the spawned actor loop (an `AbortOnDropHandle`, which aborts the
    /// task when dropped).
    loop_handler: AbortOnDropHandle<Result<()>>,
    /// Cancelling this token forces the actor loop to shut down.
    cancel: CancellationToken,
}
71
72impl Default for ServerActorTask {
73 fn default() -> Self {
74 let (server_channel_s, server_channel_r) = mpsc::channel(SERVER_CHANNEL_SIZE);
75 let server_actor = ServerActor::new(server_channel_r);
76 let cancel_token = CancellationToken::new();
77 let done = cancel_token.clone();
78 let server_task = AbortOnDropHandle::new(tokio::spawn(
79 async move { server_actor.run(done).await }.instrument(info_span!("relay.server")),
80 ));
81
82 Self {
83 write_timeout: Some(WRITE_TIMEOUT),
84 server_channel: server_channel_s,
85 closed: false,
86 loop_handler: server_task,
87 cancel: cancel_token,
88 }
89 }
90}
91
92impl ServerActorTask {
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub async fn close(mut self) {
100 if !self.closed {
101 if let Err(err) = self.server_channel.send(ServerMessage::Shutdown).await {
102 tracing::warn!(
103 "could not shutdown the server gracefully, doing a forced shutdown: {:?}",
104 err
105 );
106 self.cancel.cancel();
107 }
108 match self.loop_handler.await {
109 Ok(Ok(())) => {}
110 Ok(Err(e)) => tracing::warn!("error shutting down server: {e:?}"),
111 Err(e) => tracing::warn!("error waiting for the server process to close: {e:?}"),
112 }
113 self.closed = true;
114 }
115 }
116
117 pub fn abort(&self) {
121 self.cancel.cancel();
122 }
123
124 pub fn is_closed(&self) -> bool {
126 self.closed
127 }
128
129 pub fn client_conn_handler(&self, default_headers: HeaderMap) -> ClientConnHandler {
132 ClientConnHandler {
133 server_channel: self.server_channel.clone(),
134 write_timeout: self.write_timeout,
135 default_headers: Arc::new(default_headers),
136 }
137 }
138}
139
140#[derive(Debug)]
146pub struct ClientConnHandler {
147 server_channel: mpsc::Sender<ServerMessage>,
148 write_timeout: Option<Duration>,
149 pub(crate) default_headers: Arc<HeaderMap>,
150}
151
152impl Clone for ClientConnHandler {
153 fn clone(&self) -> Self {
154 Self {
155 server_channel: self.server_channel.clone(),
156 write_timeout: self.write_timeout,
157 default_headers: Arc::clone(&self.default_headers),
158 }
159 }
160}
161
162impl ClientConnHandler {
163 pub async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<()> {
174 trace!(?protocol, "accept: start");
175 let mut io = match protocol {
176 Protocol::Relay => {
177 inc!(Metrics, derp_accepts);
178 RelayIo::Derp(Framed::new(io, DerpCodec))
179 }
180 Protocol::Websocket => {
181 inc!(Metrics, websocket_accepts);
182 RelayIo::Ws(WebSocketStream::from_raw_socket(io, Role::Server, None).await)
183 }
184 };
185 trace!("accept: recv client key");
186 let (client_key, info) = recv_client_key(&mut io)
187 .await
188 .context("unable to receive client information")?;
189
190 if info.version != PROTOCOL_VERSION {
191 bail!(
192 "unexpected client version {}, expected {}",
193 info.version,
194 PROTOCOL_VERSION
195 );
196 }
197
198 trace!("accept: build client conn");
199 let client_conn_builder = ClientConnBuilder {
200 key: client_key,
201 conn_num: new_conn_num(),
202 io,
203 write_timeout: self.write_timeout,
204 channel_capacity: PER_CLIENT_SEND_QUEUE_DEPTH,
205 server_channel: self.server_channel.clone(),
206 };
207 trace!("accept: create client");
208 self.server_channel
209 .send(ServerMessage::CreateClient(client_conn_builder))
210 .await
211 .map_err(|_| {
212 anyhow::anyhow!("server channel closed, the server is probably shutdown")
213 })?;
214 Ok(())
215 }
216}
217
/// State owned by the server actor loop (see `ServerActor::run`).
struct ServerActor {
    /// Incoming messages from `ServerActorTask`, connection handlers and clients.
    receiver: mpsc::Receiver<ServerMessage>,
    /// Registry of connected clients, used for routing packets and for shutdown.
    clients: Clients,
    /// Tracks client keys seen per UTC day; feeds the `unique_client_keys` metric.
    client_counter: ClientCounter,
}
224
225impl ServerActor {
226 fn new(receiver: mpsc::Receiver<ServerMessage>) -> Self {
227 Self {
228 receiver,
229 clients: Clients::new(),
230 client_counter: ClientCounter::default(),
231 }
232 }
233
234 async fn run(mut self, done: CancellationToken) -> Result<()> {
235 loop {
236 tokio::select! {
237 biased;
238 _ = done.cancelled() => {
239 tracing::warn!("server actor loop cancelled, closing loop");
240 self.clients.shutdown().await;
243 return Ok(());
244 }
245 msg = self.receiver.recv() => {
246 let msg = match msg {
247 Some(m) => m,
248 None => {
249 tracing::warn!("server channel sender closed unexpectedly, shutting down server loop");
250 self.clients.shutdown().await;
251 anyhow::bail!("server channel sender closed unexpectedly, closed client connections, and shutting down server loop");
252 }
253 };
254 match msg {
255 ServerMessage::SendPacket((key, packet)) => {
256 tracing::trace!("send packet from: {:?} to: {:?} ({}b)", packet.src, key, packet.bytes.len());
257 let src = packet.src;
258 if self.clients.contains_key(&key) {
259 if self.clients.send_packet(&key, packet).is_ok() {
262 self.clients.record_send(&src, key);
263 }
264 } else {
265 tracing::warn!("send packet: no way to reach client {key:?}, dropped packet");
266 inc!(Metrics, send_packets_dropped);
267 }
268 }
269 ServerMessage::SendDiscoPacket((key, packet)) => {
270 tracing::trace!("send disco packet from: {:?} to: {:?} ({}b)", packet.src, key, packet.bytes.len());
271 let src = packet.src;
272 if self.clients.contains_key(&key) {
273 if self.clients.send_disco_packet(&key, packet).is_ok() {
276 self.clients.record_send(&src, key);
277 }
278 } else {
279 tracing::warn!("send disco packet: no way to reach client {key:?}, dropped packet");
280 inc!(Metrics, disco_packets_dropped);
281 }
282 }
283 ServerMessage::CreateClient(client_builder) => {
284 inc!(Metrics, accepts);
285
286 tracing::trace!("create client: {:?}", client_builder.key);
287 let key = client_builder.key;
288
289 report_usage_stats(&UsageStatsReport::new(
290 "relay_accepts".to_string(),
291 "relay_server".to_string(), 1,
293 None, Some(key.to_string()),
295 )).await;
296 let nc = self.client_counter.update(key);
297 inc_by!(Metrics, unique_client_keys, nc);
298
299 self.clients.register(client_builder);
302
303 }
304 ServerMessage::RemoveClient((key, conn_num)) => {
305 inc!(Metrics, disconnects);
306 tracing::trace!("remove client: {:?}", key);
307 if self.clients.has_client(&key, conn_num) {
309 self.clients.unregister(&key);
312 }
313 }
314 ServerMessage::Shutdown => {
315 tracing::info!("server gracefully shutting down...");
316 self.clients.shutdown().await;
318 return Ok(());
319 }
320 }
321 }
322 }
323 }
324 }
325}
326
327struct ClientCounter {
328 clients: HashMap<PublicKey, usize>,
329 last_clear_date: Date,
330}
331
332impl Default for ClientCounter {
333 fn default() -> Self {
334 Self {
335 clients: HashMap::new(),
336 last_clear_date: OffsetDateTime::now_utc().date(),
337 }
338 }
339}
340
341impl ClientCounter {
342 fn check_and_clear(&mut self) {
343 let today = OffsetDateTime::now_utc().date();
344 if today != self.last_clear_date {
345 self.clients.clear();
346 self.last_clear_date = today;
347 }
348 }
349
350 pub fn update(&mut self, client: PublicKey) -> u64 {
352 self.check_and_clear();
353 let new_conn = !self.clients.contains_key(&client);
354 let counter = self.clients.entry(client).or_insert(0);
355 *counter += 1;
356 new_conn as u64
357 }
358}
359
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use iroh_base::key::SecretKey;
    use tokio::io::DuplexStream;
    use tokio_util::codec::{FramedRead, FramedWrite};
    use tracing_subscriber::{prelude::*, EnvFilter};

    use super::*;
    use crate::relay::{
        client::{
            conn::{ConnBuilder, ConnReader, ConnWriter, ReceivedMessage},
            streams::{MaybeTlsStreamReader, MaybeTlsStreamWriter},
        },
        codec::{recv_frame, ClientInfo, Frame, FrameType},
    };

    /// Builds a `ClientConnBuilder` backed by one end of an in-memory duplex
    /// pipe (1024-byte buffer). Also returns a `DerpCodec`-framed handle on the
    /// other end, which the test uses to play the remote client.
    fn test_client_builder(
        key: PublicKey,
        conn_num: usize,
        server_channel: mpsc::Sender<ServerMessage>,
    ) -> (ClientConnBuilder, Framed<DuplexStream, DerpCodec>) {
        let (test_io, io) = tokio::io::duplex(1024);
        (
            ClientConnBuilder {
                key,
                conn_num,
                io: RelayIo::Derp(Framed::new(MaybeTlsStream::Test(io), DerpCodec)),
                write_timeout: None,
                channel_capacity: 10,
                server_channel,
            },
            Framed::new(test_io, DerpCodec),
        )
    }

    /// Drives a bare `ServerActor` through its channel: register two clients,
    /// forward a packet, remove a client, then shut down gracefully.
    #[tokio::test]
    async fn test_server_actor() -> Result<()> {
        // The actor is driven directly, without the `ServerActorTask` wrapper.
        let (server_channel, server_channel_r) = mpsc::channel(20);
        let server_actor: ServerActor = ServerActor::new(server_channel_r);
        let done = CancellationToken::new();
        let server_done = done.clone();

        // Run the actor loop in the background.
        let server_task = tokio::spawn(
            async move { server_actor.run(server_done).await }
                .instrument(info_span!("relay.server")),
        );

        let key_a = SecretKey::generate().public();
        let (client_a, mut a_io) = test_client_builder(key_a, 1, server_channel.clone());

        // Register client a (conn_num 1).
        server_channel
            .send(ServerMessage::CreateClient(client_a))
            .await
            .map_err(|_| anyhow::anyhow!("server gone"))?;

        // Register client b (conn_num 2).
        let key_b = SecretKey::generate().public();
        let (client_b, mut b_io) = test_client_builder(key_b, 2, server_channel.clone());
        server_channel
            .send(ServerMessage::CreateClient(client_b))
            .await
            .map_err(|_| anyhow::anyhow!("server gone"))?;

        // Client b sends a packet addressed to client a...
        let msg = b"hello world!";
        crate::relay::client::conn::send_packet(&mut b_io, &None, key_a, Bytes::from_static(msg))
            .await?;

        // ...and client a receives it with b's key as the source.
        let frame = recv_frame(FrameType::RecvPacket, &mut a_io).await?;
        assert_eq!(
            frame,
            Frame::RecvPacket {
                src_key: key_b,
                content: msg.to_vec().into()
            }
        );

        // Remove client b, using the conn_num it was registered with.
        server_channel
            .send(ServerMessage::RemoveClient((key_b, 2)))
            .await
            .map_err(|_| anyhow::anyhow!("server gone"))?;

        // Client a, which has received from b, is expected to get a PeerGone
        // frame for b.
        let frame = recv_frame(FrameType::PeerGone, &mut a_io).await?;
        assert_eq!(Frame::PeerGone { peer: key_b }, frame);

        // A graceful shutdown makes the actor loop return Ok(()).
        server_channel
            .send(ServerMessage::Shutdown)
            .await
            .map_err(|_| anyhow::anyhow!("server gone"))?;
        server_task.await??;
        Ok(())
    }

    /// Checks that `ClientConnHandler::accept` completes the handshake and
    /// forwards a `CreateClient` message carrying the client's key.
    #[tokio::test]
    async fn test_client_conn_handler() -> Result<()> {
        // The test reads the handler's channel directly, standing in for the actor.
        let (server_channel_s, mut server_channel_r) = mpsc::channel(10);
        let client_key = SecretKey::generate();
        let handler = ClientConnHandler {
            write_timeout: None,
            server_channel: server_channel_s,
            default_headers: Default::default(),
        };

        // In-memory connection: `client` is the client end; `server_io` goes to
        // the handler.
        let (client, server_io) = tokio::io::duplex(10);
        let (client_reader, client_writer) = tokio::io::split(client);
        let _client_reader = FramedRead::new(client_reader, DerpCodec);
        let mut client_writer = FramedWrite::new(client_writer, DerpCodec);

        let pub_client_key = client_key.public();
        // Client side of the handshake: send our key and protocol version.
        let client_task = AbortOnDropHandle::<Result<()>>::new(tokio::spawn(async move {
            let client_info = ClientInfo {
                version: PROTOCOL_VERSION,
            };
            crate::relay::codec::send_client_key(&mut client_writer, &client_key, &client_info)
                .await?;

            Ok(())
        }));

        // Server side of the handshake.
        handler
            .accept(Protocol::Relay, MaybeTlsStream::Test(server_io))
            .await?;
        client_task.await??;

        // The handler must have forwarded a CreateClient for the client's key.
        match server_channel_r.recv().await.unwrap() {
            ServerMessage::CreateClient(builder) => {
                assert_eq!(pub_client_key, builder.key);
            }
            _ => anyhow::bail!("unexpected server message"),
        }
        Ok(())
    }

    /// Creates an in-memory client connection. Returns the server end of a
    /// duplex pipe (10-byte buffer) and a `ConnBuilder` wrapping the client end
    /// with `DerpCodec` framing.
    fn make_test_client(secret_key: SecretKey) -> (tokio::io::DuplexStream, ConnBuilder) {
        let (client, server) = tokio::io::duplex(10);
        let (client_reader, client_writer) = tokio::io::split(client);

        let client_reader = MaybeTlsStreamReader::Mem(client_reader);
        let client_writer = MaybeTlsStreamWriter::Mem(client_writer);

        let client_reader = ConnReader::Derp(FramedRead::new(client_reader, DerpCodec));
        let client_writer = ConnWriter::Derp(FramedWrite::new(client_writer, DerpCodec));

        (
            server,
            ConnBuilder::new(secret_key, None, client_reader, client_writer),
        )
    }

    /// End-to-end test through `ServerActorTask`: two clients exchange packets.
    /// After `close`, sending fails and receiving errors.
    #[tokio::test]
    async fn test_server_basic() -> Result<()> {
        let _guard = iroh_test::logging::setup();

        // Spawn a full server actor task.
        let server: ServerActorTask = ServerActorTask::new();

        // Connect client a. `accept` runs in a spawned task so that the client's
        // side of the handshake (`build`) can proceed concurrently.
        let key_a = SecretKey::generate();
        let public_key_a = key_a.public();
        let (rw_a, client_a_builder) = make_test_client(key_a);
        let handler = server.client_conn_handler(Default::default());
        let handler_task = tokio::spawn(async move {
            handler
                .accept(Protocol::Relay, MaybeTlsStream::Test(rw_a))
                .await
        });
        let (client_a, mut client_receiver_a) = client_a_builder.build().await?;
        handler_task.await??;

        // Connect client b the same way.
        let key_b = SecretKey::generate();
        let public_key_b = key_b.public();
        let (rw_b, client_b_builder) = make_test_client(key_b);
        let handler = server.client_conn_handler(Default::default());
        let handler_task = tokio::spawn(async move {
            handler
                .accept(Protocol::Relay, MaybeTlsStream::Test(rw_b))
                .await
        });
        let (client_b, mut client_receiver_b) = client_b_builder.build().await?;
        handler_task.await??;

        // a -> b
        let msg = Bytes::from_static(b"hello client b!!");
        client_a.send(public_key_b, msg.clone()).await?;
        match client_receiver_b.recv().await? {
            ReceivedMessage::ReceivedPacket { source, data } => {
                assert_eq!(public_key_a, source);
                assert_eq!(&msg[..], data);
            }
            msg => {
                anyhow::bail!("expected ReceivedPacket msg, got {msg:?}");
            }
        }

        // b -> a
        let msg = Bytes::from_static(b"nice to meet you client a!!");
        client_b.send(public_key_a, msg.clone()).await?;
        match client_receiver_a.recv().await? {
            ReceivedMessage::ReceivedPacket { source, data } => {
                assert_eq!(public_key_b, source);
                assert_eq!(&msg[..], data);
            }
            msg => {
                anyhow::bail!("expected ReceivedPacket msg, got {msg:?}");
            }
        }

        // Shut the server down.
        server.close().await;

        // After shutdown, client connections are unusable.
        let res = client_a
            .send(public_key_b, Bytes::from_static(b"try to send"))
            .await;
        assert!(res.is_err());
        assert!(client_receiver_b.recv().await.is_err());
        Ok(())
    }

    /// Like `test_server_basic`, but client b then reconnects with the same key.
    /// Afterwards, traffic to and from b must go through the new connection.
    #[tokio::test]
    async fn test_server_replace_client() -> Result<()> {
        tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
            .with(EnvFilter::from_default_env())
            .try_init()
            .ok();

        // Spawn a full server actor task.
        let server: ServerActorTask = ServerActorTask::new();

        // Connect client a.
        let key_a = SecretKey::generate();
        let public_key_a = key_a.public();
        let (rw_a, client_a_builder) = make_test_client(key_a);
        let handler = server.client_conn_handler(Default::default());
        let handler_task = tokio::spawn(async move {
            handler
                .accept(Protocol::Relay, MaybeTlsStream::Test(rw_a))
                .await
        });
        let (client_a, mut client_receiver_a) = client_a_builder.build().await?;
        handler_task.await??;

        // Connect client b. The key is cloned because it is reused for the
        // replacement connection below.
        let key_b = SecretKey::generate();
        let public_key_b = key_b.public();
        let (rw_b, client_b_builder) = make_test_client(key_b.clone());
        let handler = server.client_conn_handler(Default::default());
        let handler_task = tokio::spawn(async move {
            handler
                .accept(Protocol::Relay, MaybeTlsStream::Test(rw_b))
                .await
        });
        let (client_b, mut client_receiver_b) = client_b_builder.build().await?;
        handler_task.await??;

        // a -> b
        let msg = Bytes::from_static(b"hello client b!!");
        client_a.send(public_key_b, msg.clone()).await?;
        match client_receiver_b.recv().await? {
            ReceivedMessage::ReceivedPacket { source, data } => {
                assert_eq!(public_key_a, source);
                assert_eq!(&msg[..], data);
            }
            msg => {
                anyhow::bail!("expected ReceivedPacket msg, got {msg:?}");
            }
        }

        // b -> a
        let msg = Bytes::from_static(b"nice to meet you client a!!");
        client_b.send(public_key_a, msg.clone()).await?;
        match client_receiver_a.recv().await? {
            ReceivedMessage::ReceivedPacket { source, data } => {
                assert_eq!(public_key_b, source);
                assert_eq!(&msg[..], data);
            }
            msg => {
                anyhow::bail!("expected ReceivedPacket msg, got {msg:?}");
            }
        }

        // Reconnect client b with the same key over a new connection.
        let (new_rw_b, new_client_b_builder) = make_test_client(key_b);
        let handler = server.client_conn_handler(Default::default());
        let handler_task = tokio::spawn(async move {
            handler
                .accept(Protocol::Relay, MaybeTlsStream::Test(new_rw_b))
                .await
        });
        let (new_client_b, mut new_client_receiver_b) = new_client_b_builder.build().await?;
        handler_task.await??;

        // a -> b now arrives on the new connection.
        let msg = Bytes::from_static(b"are you still there, b?!");
        client_a.send(public_key_b, msg.clone()).await?;
        match new_client_receiver_b.recv().await? {
            ReceivedMessage::ReceivedPacket { source, data } => {
                assert_eq!(public_key_a, source);
                assert_eq!(&msg[..], data);
            }
            msg => {
                anyhow::bail!("expected ReceivedPacket msg, got {msg:?}");
            }
        }

        // The new b connection can send to a.
        let msg = Bytes::from_static(b"just had a spot of trouble but I'm back now,a!!");
        new_client_b.send(public_key_a, msg.clone()).await?;
        match client_receiver_a.recv().await? {
            ReceivedMessage::ReceivedPacket { source, data } => {
                assert_eq!(public_key_b, source);
                assert_eq!(&msg[..], data);
            }
            msg => {
                anyhow::bail!("expected ReceivedPacket msg, got {msg:?}");
            }
        }

        // Shut the server down.
        server.close().await;

        // After shutdown, client connections are unusable.
        let res = client_a
            .send(public_key_b, Bytes::from_static(b"try to send"))
            .await;
        assert!(res.is_err());
        assert!(new_client_receiver_b.recv().await.is_err());
        Ok(())
    }
}