iroh_net/relay/client/
conn.rs

1//! Manages client-side connections to the relay server.
2//!
3//! based on tailscale/derp/derp_client.go
4
5use std::{
6    net::SocketAddr,
7    num::NonZeroU32,
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11    time::Duration,
12};
13
14use anyhow::{anyhow, bail, ensure, Context as _, Result};
15use bytes::Bytes;
16use futures_lite::Stream;
17use futures_sink::Sink;
18use futures_util::{
19    stream::{SplitSink, SplitStream, StreamExt},
20    SinkExt,
21};
22use tokio::sync::mpsc;
23use tokio_tungstenite_wasm::WebSocketStream;
24use tokio_util::{
25    codec::{FramedRead, FramedWrite},
26    task::AbortOnDropHandle,
27};
28use tracing::{debug, info_span, trace, Instrument};
29
30use crate::{
31    defaults::timeouts::relay::CLIENT_RECV_TIMEOUT,
32    key::{PublicKey, SecretKey},
33    relay::{
34        client::streams::{MaybeTlsStreamReader, MaybeTlsStreamWriter},
35        codec::{
36            write_frame, ClientInfo, DerpCodec, Frame, MAX_PACKET_SIZE,
37            PER_CLIENT_READ_QUEUE_DEPTH, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION,
38        },
39    },
40};
41
42impl PartialEq for Conn {
43    fn eq(&self, other: &Self) -> bool {
44        Arc::ptr_eq(&self.inner, &other.inner)
45    }
46}
47
48impl Eq for Conn {}
49
50/// A connection to a relay server.
51///
52/// Cheaply clonable.
53/// Call `close` to shut down the write loop and read functionality.
54#[derive(Debug, Clone)]
55pub struct Conn {
56    inner: Arc<ConnTasks>,
57}
58
59/// The channel on which a relay connection sends received messages.
60///
61/// The [`Conn`] to a relay is easily clonable but can only send DISCO messages to a relay
62/// server.  This is the counterpart which receives DISCO messages from the relay server for
63/// a connection.  It is not clonable.
64#[derive(Debug)]
65pub struct ConnReceiver {
66    /// The reader channel, receiving incoming messages.
67    reader_channel: mpsc::Receiver<Result<ReceivedMessage>>,
68}
69
70impl ConnReceiver {
71    /// Reads a messages from a relay server.
72    ///
73    /// Once it returns an error, the [`Conn`] is dead forever.
74    pub async fn recv(&mut self) -> Result<ReceivedMessage> {
75        let msg = self
76            .reader_channel
77            .recv()
78            .await
79            .ok_or(anyhow!("shut down"))??;
80        Ok(msg)
81    }
82}
83
84#[derive(derive_more::Debug)]
85pub struct ConnTasks {
86    /// Our local address, if known.
87    ///
88    /// Is `None` in tests or when using websockets (because we don't control connection establishment in browsers).
89    local_addr: Option<SocketAddr>,
90    /// Channel on which to communicate to the server. The associated [`mpsc::Receiver`] will close
91    /// if there is ever an error writing to the server.
92    writer_channel: mpsc::Sender<ConnWriterMessage>,
93    /// JoinHandle for the [`ConnWriter`] task
94    writer_task: AbortOnDropHandle<Result<()>>,
95    reader_task: AbortOnDropHandle<()>,
96}
97
98impl Conn {
99    /// Sends a packet to the node identified by `dstkey`
100    ///
101    /// Errors if the packet is larger than [`MAX_PACKET_SIZE`]
102    pub async fn send(&self, dstkey: PublicKey, packet: Bytes) -> Result<()> {
103        trace!(%dstkey, len = packet.len(), "[RELAY] send");
104
105        self.inner
106            .writer_channel
107            .send(ConnWriterMessage::Packet((dstkey, packet)))
108            .await?;
109        Ok(())
110    }
111
112    /// Send a ping with 8 bytes of random data.
113    pub async fn send_ping(&self, data: [u8; 8]) -> Result<()> {
114        self.inner
115            .writer_channel
116            .send(ConnWriterMessage::Ping(data))
117            .await?;
118        Ok(())
119    }
120
121    /// Respond to a ping request. The `data` field should be filled
122    /// by the 8 bytes of random data send by the ping.
123    pub async fn send_pong(&self, data: [u8; 8]) -> Result<()> {
124        self.inner
125            .writer_channel
126            .send(ConnWriterMessage::Pong(data))
127            .await?;
128        Ok(())
129    }
130
131    /// Sends a packet that tells the server whether this
132    /// connection is to the user's preferred server. This is only
133    /// used in the server for stats.
134    pub async fn note_preferred(&self, preferred: bool) -> Result<()> {
135        self.inner
136            .writer_channel
137            .send(ConnWriterMessage::NotePreferred(preferred))
138            .await?;
139        Ok(())
140    }
141
142    /// The local address that the [`Conn`] is listening on.
143    ///
144    /// `None`, when run in a testing environment or when using websockets.
145    pub fn local_addr(&self) -> Option<SocketAddr> {
146        self.inner.local_addr
147    }
148
149    /// Whether or not this [`Conn`] is closed.
150    ///
151    /// The [`Conn`] is considered closed if the write side of the connection is no longer running.
152    pub fn is_closed(&self) -> bool {
153        self.inner.writer_task.is_finished()
154    }
155
156    /// Close the connection
157    ///
158    /// Shuts down the write loop directly and marks the connection as closed. The [`Conn`] will
159    /// check if the it is closed before attempting to read from it.
160    pub async fn close(&self) {
161        if self.inner.writer_task.is_finished() && self.inner.reader_task.is_finished() {
162            return;
163        }
164
165        self.inner
166            .writer_channel
167            .send(ConnWriterMessage::Shutdown)
168            .await
169            .ok();
170        self.inner.reader_task.abort();
171    }
172}
173
174fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
175    match frame {
176        Frame::KeepAlive => {
177            // A one-way keep-alive message that doesn't require an ack.
178            // This predated FrameType::Ping/FrameType::Pong.
179            Ok(ReceivedMessage::KeepAlive)
180        }
181        Frame::PeerGone { peer } => Ok(ReceivedMessage::PeerGone(peer)),
182        Frame::RecvPacket { src_key, content } => {
183            let packet = ReceivedMessage::ReceivedPacket {
184                source: src_key,
185                data: content,
186            };
187            Ok(packet)
188        }
189        Frame::Ping { data } => Ok(ReceivedMessage::Ping(data)),
190        Frame::Pong { data } => Ok(ReceivedMessage::Pong(data)),
191        Frame::Health { problem } => {
192            let problem = std::str::from_utf8(&problem)?.to_owned();
193            let problem = Some(problem);
194            Ok(ReceivedMessage::Health { problem })
195        }
196        Frame::Restarting {
197            reconnect_in,
198            try_for,
199        } => {
200            let reconnect_in = Duration::from_millis(reconnect_in as u64);
201            let try_for = Duration::from_millis(try_for as u64);
202            Ok(ReceivedMessage::ServerRestarting {
203                reconnect_in,
204                try_for,
205            })
206        }
207        _ => bail!("unexpected packet: {:?}", frame.typ()),
208    }
209}
210
211/// The kinds of messages we can send to the [`Server`](crate::relay::server::Server)
212#[derive(Debug)]
213enum ConnWriterMessage {
214    /// Send a packet (addressed to the [`PublicKey`]) to the server
215    Packet((PublicKey, Bytes)),
216    /// Send a pong to the server
217    Pong([u8; 8]),
218    /// Send a ping to the server
219    Ping([u8; 8]),
220    /// Tell the server whether or not this client is the user's preferred client
221    NotePreferred(bool),
222    /// Shutdown the writer
223    Shutdown,
224}
225
226/// Call [`ConnWriterTasks::run`] to listen for messages to send to the connection.
227/// Should be used by the [`Conn`]
228///
229/// Shutsdown when you send a [`ConnWriterMessage::Shutdown`], or if there is an error writing to
230/// the server.
231struct ConnWriterTasks {
232    recv_msgs: mpsc::Receiver<ConnWriterMessage>,
233    writer: ConnWriter,
234    rate_limiter: Option<RateLimiter>,
235}
236
237impl ConnWriterTasks {
238    async fn run(mut self) -> Result<()> {
239        while let Some(msg) = self.recv_msgs.recv().await {
240            match msg {
241                ConnWriterMessage::Packet((key, bytes)) => {
242                    send_packet(&mut self.writer, &self.rate_limiter, key, bytes).await?;
243                }
244                ConnWriterMessage::Pong(data) => {
245                    write_frame(&mut self.writer, Frame::Pong { data }, None).await?;
246                    self.writer.flush().await?;
247                }
248                ConnWriterMessage::Ping(data) => {
249                    write_frame(&mut self.writer, Frame::Ping { data }, None).await?;
250                    self.writer.flush().await?;
251                }
252                ConnWriterMessage::NotePreferred(preferred) => {
253                    write_frame(&mut self.writer, Frame::NotePreferred { preferred }, None).await?;
254                    self.writer.flush().await?;
255                }
256                ConnWriterMessage::Shutdown => {
257                    return Ok(());
258                }
259            }
260        }
261
262        bail!("channel unexpectedly closed");
263    }
264}
265
266/// The Builder returns a [`Conn`] and a [`ConnReceiver`] and
267/// runs a [`ConnWriterTasks`] in the background.
268pub struct ConnBuilder {
269    secret_key: SecretKey,
270    reader: ConnReader,
271    writer: ConnWriter,
272    local_addr: Option<SocketAddr>,
273}
274
275pub(crate) enum ConnReader {
276    Derp(FramedRead<MaybeTlsStreamReader, DerpCodec>),
277    Ws(SplitStream<WebSocketStream>),
278}
279
280pub(crate) enum ConnWriter {
281    Derp(FramedWrite<MaybeTlsStreamWriter, DerpCodec>),
282    Ws(SplitSink<WebSocketStream, tokio_tungstenite_wasm::Message>),
283}
284
285fn tung_wasm_to_io_err(e: tokio_tungstenite_wasm::Error) -> std::io::Error {
286    match e {
287        tokio_tungstenite_wasm::Error::Io(io_err) => io_err,
288        _ => std::io::Error::new(std::io::ErrorKind::Other, e.to_string()),
289    }
290}
291
292impl Stream for ConnReader {
293    type Item = Result<Frame>;
294
295    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296        match *self {
297            Self::Derp(ref mut ws) => Pin::new(ws).poll_next(cx),
298            Self::Ws(ref mut ws) => match Pin::new(ws).poll_next(cx) {
299                Poll::Ready(Some(Ok(tokio_tungstenite_wasm::Message::Binary(vec)))) => {
300                    Poll::Ready(Some(Frame::decode_from_ws_msg(vec)))
301                }
302                Poll::Ready(Some(Ok(msg))) => {
303                    tracing::warn!(?msg, "Got websocket message of unsupported type, skipping.");
304                    Poll::Pending
305                }
306                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
307                Poll::Ready(None) => Poll::Ready(None),
308                Poll::Pending => Poll::Pending,
309            },
310        }
311    }
312}
313
314impl Sink<Frame> for ConnWriter {
315    type Error = std::io::Error;
316
317    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
318        match *self {
319            Self::Derp(ref mut ws) => Pin::new(ws).poll_ready(cx),
320            Self::Ws(ref mut ws) => Pin::new(ws).poll_ready(cx).map_err(tung_wasm_to_io_err),
321        }
322    }
323
324    fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
325        match *self {
326            Self::Derp(ref mut ws) => Pin::new(ws).start_send(item),
327            Self::Ws(ref mut ws) => Pin::new(ws)
328                .start_send(tokio_tungstenite_wasm::Message::binary(
329                    item.encode_for_ws_msg(),
330                ))
331                .map_err(tung_wasm_to_io_err),
332        }
333    }
334
335    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
336        match *self {
337            Self::Derp(ref mut ws) => Pin::new(ws).poll_flush(cx),
338            Self::Ws(ref mut ws) => Pin::new(ws).poll_flush(cx).map_err(tung_wasm_to_io_err),
339        }
340    }
341
342    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
343        match *self {
344            Self::Derp(ref mut ws) => Pin::new(ws).poll_close(cx),
345            Self::Ws(ref mut ws) => Pin::new(ws).poll_close(cx).map_err(tung_wasm_to_io_err),
346        }
347    }
348}
349
350impl ConnBuilder {
351    pub fn new(
352        secret_key: SecretKey,
353        local_addr: Option<SocketAddr>,
354        reader: ConnReader,
355        writer: ConnWriter,
356    ) -> Self {
357        Self {
358            secret_key,
359            reader,
360            writer,
361            local_addr,
362        }
363    }
364
365    async fn server_handshake(&mut self) -> Result<Option<RateLimiter>> {
366        debug!("server_handshake: started");
367        let client_info = ClientInfo {
368            version: PROTOCOL_VERSION,
369        };
370        debug!("server_handshake: sending client_key: {:?}", &client_info);
371        crate::relay::codec::send_client_key(&mut self.writer, &self.secret_key, &client_info)
372            .await?;
373
374        // TODO: add some actual configuration
375        let rate_limiter = RateLimiter::new(0, 0)?;
376
377        debug!("server_handshake: done");
378        Ok(rate_limiter)
379    }
380
381    pub async fn build(mut self) -> Result<(Conn, ConnReceiver)> {
382        // exchange information with the server
383        let rate_limiter = self.server_handshake().await?;
384
385        // create task to handle writing to the server
386        let (writer_sender, writer_recv) = mpsc::channel(PER_CLIENT_SEND_QUEUE_DEPTH);
387        let writer_task = tokio::task::spawn(
388            ConnWriterTasks {
389                rate_limiter,
390                writer: self.writer,
391                recv_msgs: writer_recv,
392            }
393            .run()
394            .instrument(info_span!("conn.writer")),
395        );
396
397        let (reader_sender, reader_recv) = mpsc::channel(PER_CLIENT_READ_QUEUE_DEPTH);
398        let reader_task = tokio::task::spawn({
399            let writer_sender = writer_sender.clone();
400            async move {
401                loop {
402                    let frame = tokio::time::timeout(CLIENT_RECV_TIMEOUT, self.reader.next()).await;
403                    let res = match frame {
404                        Ok(Some(Ok(frame))) => process_incoming_frame(frame),
405                        Ok(Some(Err(err))) => {
406                            // Error processing incoming messages
407                            Err(err)
408                        }
409                        Ok(None) => {
410                            // EOF
411                            Err(anyhow::anyhow!("EOF: reader stream ended"))
412                        }
413                        Err(err) => {
414                            // Timeout
415                            Err(err.into())
416                        }
417                    };
418                    if res.is_err() {
419                        // shutdown
420                        writer_sender.send(ConnWriterMessage::Shutdown).await.ok();
421                        break;
422                    }
423                    if reader_sender.send(res).await.is_err() {
424                        // shutdown, as the reader is gone
425                        writer_sender.send(ConnWriterMessage::Shutdown).await.ok();
426                        break;
427                    }
428                }
429            }
430            .instrument(info_span!("conn.reader"))
431        });
432
433        let conn = Conn {
434            inner: Arc::new(ConnTasks {
435                local_addr: self.local_addr,
436                writer_channel: writer_sender,
437                writer_task: AbortOnDropHandle::new(writer_task),
438                reader_task: AbortOnDropHandle::new(reader_task),
439            }),
440        };
441
442        let conn_receiver = ConnReceiver {
443            reader_channel: reader_recv,
444        };
445
446        Ok((conn, conn_receiver))
447    }
448}
449
450#[derive(derive_more::Debug, Clone)]
451/// The type of message received by the [`Conn`] from a relay server.
452pub enum ReceivedMessage {
453    /// Represents an incoming packet.
454    ReceivedPacket {
455        /// The [`PublicKey`] of the packet sender.
456        source: PublicKey,
457        /// The received packet bytes.
458        #[debug(skip)]
459        data: Bytes, // TODO: ref
460    },
461    /// Indicates that the client identified by the underlying public key had previously sent you a
462    /// packet but has now disconnected from the server.
463    PeerGone(PublicKey),
464    /// Request from a client or server to reply to the
465    /// other side with a [`ReceivedMessage::Pong`] with the given payload.
466    Ping([u8; 8]),
467    /// Reply to a [`ReceivedMessage::Ping`] from a client or server
468    /// with the payload sent previously in the ping.
469    Pong([u8; 8]),
470    /// A one-way empty message from server to client, just to
471    /// keep the connection alive. It's like a [`ReceivedMessage::Ping`], but doesn't solicit
472    /// a reply from the client.
473    KeepAlive,
474    /// A one-way message from server to client, declaring the connection health state.
475    Health {
476        /// If set, is a description of why the connection is unhealthy.
477        ///
478        /// If `None` means the connection is healthy again.
479        ///
480        /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`]
481        /// until a problem exists.
482        problem: Option<String>,
483    },
484    /// A one-way message from server to client, advertising that the server is restarting.
485    ServerRestarting {
486        /// An advisory duration that the client should wait before attempting to reconnect.
487        /// It might be zero. It exists for the server to smear out the reconnects.
488        reconnect_in: Duration,
489        /// An advisory duration for how long the client should attempt to reconnect
490        /// before giving up and proceeding with its normal connection failure logic. The interval
491        /// between retries is undefined for now. A server should not send a TryFor duration more
492        /// than a few seconds.
493        try_for: Duration,
494    },
495}
496
497pub(crate) async fn send_packet<S: Sink<Frame, Error = std::io::Error> + Unpin>(
498    mut writer: S,
499    rate_limiter: &Option<RateLimiter>,
500    dst_key: PublicKey,
501    packet: Bytes,
502) -> Result<()> {
503    ensure!(
504        packet.len() <= MAX_PACKET_SIZE,
505        "packet too big: {}",
506        packet.len()
507    );
508
509    let frame = Frame::SendPacket { dst_key, packet };
510    if let Some(rate_limiter) = rate_limiter {
511        if rate_limiter.check_n(frame.len()).is_err() {
512            tracing::warn!("dropping send: rate limit reached");
513            return Ok(());
514        }
515    }
516    writer.send(frame).await?;
517    writer.flush().await?;
518
519    Ok(())
520}
521
522pub(crate) struct RateLimiter {
523    inner: governor::RateLimiter<
524        governor::state::direct::NotKeyed,
525        governor::state::InMemoryState,
526        governor::clock::DefaultClock,
527        governor::middleware::NoOpMiddleware,
528    >,
529}
530
531impl RateLimiter {
532    pub(crate) fn new(bytes_per_second: usize, bytes_burst: usize) -> Result<Option<Self>> {
533        if bytes_per_second == 0 || bytes_burst == 0 {
534            return Ok(None);
535        }
536        let bytes_per_second = NonZeroU32::new(u32::try_from(bytes_per_second)?)
537            .context("bytes_per_second not non-zero")?;
538        let bytes_burst =
539            NonZeroU32::new(u32::try_from(bytes_burst)?).context("bytes_burst not non-zero")?;
540        Ok(Some(Self {
541            inner: governor::RateLimiter::direct(
542                governor::Quota::per_second(bytes_per_second).allow_burst(bytes_burst),
543            ),
544        }))
545    }
546
547    pub(crate) fn check_n(&self, n: usize) -> Result<()> {
548        let n = NonZeroU32::new(u32::try_from(n)?).context("n not non-zero")?;
549        match self.inner.check_n(n) {
550            Ok(_) => Ok(()),
551            Err(_) => bail!("batch cannot go through"),
552        }
553    }
554}