iroh_relay/server/
streams.rs

1//! Streams used in the server-side implementation of iroh relays.
2
3use std::{
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use n0_error::{ensure, stack_error};
10use n0_future::{FutureExt, Sink, Stream, ready, time};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tracing::instrument;
13
14use super::{ClientRateLimit, Metrics};
15use crate::{
16    ExportKeyingMaterial, KeyCache, MAX_PACKET_SIZE,
17    protos::{
18        relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg},
19        streams::{StreamError, WsBytesFramed},
20    },
21};
22
23/// The relay's connection to a client.
24///
25/// This implements
26/// - a [`Stream`] of [`ClientToRelayMsg`]s that are received from the client,
27/// - a [`Sink`] of [`RelayToClientMsg`]s that can be sent to the client.
28#[derive(Debug)]
29pub(crate) struct RelayedStream {
30    pub(crate) inner: WsBytesFramed<RateLimited<MaybeTlsStream>>,
31    pub(crate) key_cache: KeyCache,
32}
33
#[cfg(test)]
impl RelayedStream {
    /// Builds a relayed stream over an in-memory pipe without any read rate limiting.
    pub(crate) fn test(stream: tokio::io::DuplexStream) -> Self {
        let metrics = Arc::new(Metrics::default());
        Self::wrap_test_stream(RateLimited::unlimited(MaybeTlsStream::Test(stream), metrics))
    }

    /// Builds a relayed stream over an in-memory pipe whose reads are rate limited.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidBucketConfig`] if the rate limiter rejects the given limits.
    pub(crate) fn test_limited(
        stream: tokio::io::DuplexStream,
        max_burst_bytes: u32,
        bytes_per_second: u32,
    ) -> Result<Self, InvalidBucketConfig> {
        let limited = RateLimited::new(
            MaybeTlsStream::Test(stream),
            max_burst_bytes,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;
        Ok(Self::wrap_test_stream(limited))
    }

    /// Serves a server-side websocket over `stream` and pairs it with a test key cache.
    fn wrap_test_stream(stream: RateLimited<MaybeTlsStream>) -> Self {
        let io = tokio_websockets::ServerBuilder::new()
            .limits(Self::limits())
            .serve(stream);
        Self {
            inner: WsBytesFramed { io },
            key_cache: KeyCache::test(),
        }
    }

    /// Websocket limits used in tests: payloads are capped at the relay protocol's
    /// maximum frame size.
    fn limits() -> tokio_websockets::Limits {
        tokio_websockets::Limits::default()
            .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE))
    }
}
76
77/// Relay send errors
78#[stack_error(derive, add_meta)]
79#[non_exhaustive]
80pub enum SendError {
81    #[error(transparent)]
82    StreamError {
83        #[error(from, std_err)]
84        source: StreamError,
85    },
86    #[error("Packet exceeds max packet size")]
87    ExceedsMaxPacketSize { size: usize },
88    #[error("Attempted to send empty packet")]
89    EmptyPacket {},
90}
91
92impl Sink<RelayToClientMsg> for RelayedStream {
93    type Error = SendError;
94
95    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96        Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
97    }
98
99    fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> {
100        let size = item.encoded_len();
101        ensure!(
102            size <= MAX_PACKET_SIZE,
103            SendError::ExceedsMaxPacketSize { size }
104        );
105        if let RelayToClientMsg::Datagrams { datagrams, .. } = &item {
106            ensure!(!datagrams.contents.is_empty(), SendError::EmptyPacket);
107        }
108
109        Pin::new(&mut self.inner)
110            .start_send(item.to_bytes().freeze())
111            .map_err(Into::into)
112    }
113
114    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115        Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
116    }
117
118    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119        Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
120    }
121}
122
123/// Relay receive errors
124#[stack_error(derive, add_meta, from_sources)]
125#[non_exhaustive]
126pub enum RecvError {
127    #[error(transparent)]
128    Proto { source: ProtoError },
129    #[error(transparent)]
130    StreamError {
131        #[error(std_err)]
132        source: StreamError,
133    },
134}
135
136impl Stream for RelayedStream {
137    type Item = Result<ClientToRelayMsg, RecvError>;
138
139    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140        Poll::Ready(match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
141            Some(Ok(msg)) => {
142                Some(ClientToRelayMsg::from_bytes(msg, &self.key_cache).map_err(Into::into))
143            }
144            Some(Err(e)) => Some(Err(e.into())),
145            None => None,
146        })
147    }
148}
149
150/// The main underlying IO stream type used for the relay server.
151///
152/// Allows choosing whether or not the underlying [`tokio::net::TcpStream`] is served over Tls
153#[derive(Debug)]
154#[allow(clippy::large_enum_variant)]
155pub enum MaybeTlsStream {
156    /// A plain non-Tls [`tokio::net::TcpStream`]
157    Plain(tokio::net::TcpStream),
158    /// A Tls wrapped [`tokio::net::TcpStream`]
159    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
160    /// An in-memory bidirectional pipe.
161    #[cfg(test)]
162    Test(tokio::io::DuplexStream),
163}
164
165impl ExportKeyingMaterial for MaybeTlsStream {
166    fn export_keying_material<T: AsMut<[u8]>>(
167        &self,
168        output: T,
169        label: &[u8],
170        context: Option<&[u8]>,
171    ) -> Option<T> {
172        let Self::Tls(tls) = self else {
173            return None;
174        };
175
176        tls.get_ref()
177            .1
178            .export_keying_material(output, label, context)
179            .ok()
180    }
181}
182
183impl AsyncRead for MaybeTlsStream {
184    fn poll_read(
185        mut self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187        buf: &mut tokio::io::ReadBuf<'_>,
188    ) -> Poll<std::io::Result<()>> {
189        match &mut *self {
190            MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
191            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
192            #[cfg(test)]
193            MaybeTlsStream::Test(s) => Pin::new(s).poll_read(cx, buf),
194        }
195    }
196}
197
198impl AsyncWrite for MaybeTlsStream {
199    fn poll_flush(
200        mut self: Pin<&mut Self>,
201        cx: &mut Context<'_>,
202    ) -> Poll<std::result::Result<(), std::io::Error>> {
203        match &mut *self {
204            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
205            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
206            #[cfg(test)]
207            MaybeTlsStream::Test(s) => Pin::new(s).poll_flush(cx),
208        }
209    }
210
211    fn poll_shutdown(
212        mut self: Pin<&mut Self>,
213        cx: &mut Context<'_>,
214    ) -> Poll<std::result::Result<(), std::io::Error>> {
215        match &mut *self {
216            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
217            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
218            #[cfg(test)]
219            MaybeTlsStream::Test(s) => Pin::new(s).poll_shutdown(cx),
220        }
221    }
222
223    fn poll_write(
224        mut self: Pin<&mut Self>,
225        cx: &mut Context<'_>,
226        buf: &[u8],
227    ) -> Poll<std::result::Result<usize, std::io::Error>> {
228        match &mut *self {
229            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
230            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
231            #[cfg(test)]
232            MaybeTlsStream::Test(s) => Pin::new(s).poll_write(cx, buf),
233        }
234    }
235
236    fn poll_write_vectored(
237        mut self: Pin<&mut Self>,
238        cx: &mut Context<'_>,
239        bufs: &[std::io::IoSlice<'_>],
240    ) -> Poll<std::result::Result<usize, std::io::Error>> {
241        match &mut *self {
242            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
243            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
244            #[cfg(test)]
245            MaybeTlsStream::Test(s) => Pin::new(s).poll_write_vectored(cx, bufs),
246        }
247    }
248
249    fn is_write_vectored(&self) -> bool {
250        match self {
251            MaybeTlsStream::Plain(s) => s.is_write_vectored(),
252            MaybeTlsStream::Tls(s) => s.is_write_vectored(),
253            #[cfg(test)]
254            MaybeTlsStream::Test(s) => s.is_write_vectored(),
255        }
256    }
257}
258
259/// Rate limiter for reading from a [`RelayedStream`].
260///
261/// The writes to the sink are not rate limited.
262///
263/// This potentially buffers one frame if the rate limiter does not allows this frame.
264/// While the frame is buffered the undernlying stream is no longer polled.
265#[derive(Debug)]
266pub(crate) struct RateLimited<S> {
267    inner: S,
268    bucket: Option<Bucket>,
269    bucket_refilled: Option<Pin<Box<time::Sleep>>>,
270    /// Keeps track if this stream was ever rate-limited.
271    limited_once: bool,
272    metrics: Arc<Metrics>,
273}
274
275#[derive(Debug)]
276struct Bucket {
277    // The current bucket fill
278    fill: i64,
279    // The maximum bucket fill
280    max: i64,
281    // The bucket's last fill time
282    last_fill: time::Instant,
283    // Interval length of one refill
284    refill_period: time::Duration,
285    // How much we re-fill per refill period
286    refill: i64,
287}
288
289#[allow(missing_docs)]
290#[stack_error(derive, add_meta)]
291pub struct InvalidBucketConfig {
292    max: i64,
293    bytes_per_second: i64,
294    refill_period: time::Duration,
295}
296
297impl Bucket {
298    fn new(
299        max: i64,
300        bytes_per_second: i64,
301        refill_period: time::Duration,
302    ) -> Result<Self, InvalidBucketConfig> {
303        // milliseconds is the tokio timer resolution
304        let refill = bytes_per_second.saturating_mul(refill_period.as_millis() as i64) / 1000;
305        ensure!(
306            max > 0 && bytes_per_second > 0 && refill_period.as_millis() as u32 > 0 && refill > 0,
307            InvalidBucketConfig {
308                max,
309                bytes_per_second,
310                refill_period
311            }
312        );
313        Ok(Self {
314            fill: max,
315            max,
316            last_fill: time::Instant::now(),
317            refill_period,
318            refill,
319        })
320    }
321
322    fn update_state(&mut self) {
323        let now = time::Instant::now();
324        // div safety: self.refill_period.as_millis() is checked to be non-null in constructor
325        let refill_periods = now.saturating_duration_since(self.last_fill).as_millis() as u32
326            / self.refill_period.as_millis() as u32;
327        if refill_periods == 0 {
328            // Nothing to do - we won't refill yet
329            return;
330        }
331
332        self.fill = self
333            .fill
334            .saturating_add(refill_periods as i64 * self.refill);
335        self.fill = std::cmp::min(self.fill, self.max);
336        self.last_fill += self.refill_period * refill_periods;
337    }
338
339    fn consume(&mut self, bytes: usize) -> Result<(), time::Instant> {
340        let bytes = i64::try_from(bytes).unwrap_or(i64::MAX);
341        self.update_state();
342
343        self.fill = self.fill.saturating_sub(bytes);
344
345        if self.fill > 0 {
346            return Ok(());
347        }
348
349        let missing = self.fill.saturating_neg();
350
351        let periods_needed = (missing / self.refill) + 1;
352        let periods_needed = u32::try_from(periods_needed).unwrap_or(u32::MAX);
353
354        Err(self.last_fill + periods_needed * self.refill_period)
355    }
356}
357
358impl<S> RateLimited<S> {
359    pub(crate) fn from_cfg(
360        cfg: Option<ClientRateLimit>,
361        io: S,
362        metrics: Arc<Metrics>,
363    ) -> Result<Self, InvalidBucketConfig> {
364        match cfg {
365            Some(cfg) => {
366                let bytes_per_second = cfg.bytes_per_second.into();
367                let max_burst_bytes = cfg.max_burst_bytes.map_or(bytes_per_second / 10, u32::from);
368                Self::new(io, max_burst_bytes, bytes_per_second, metrics)
369            }
370            None => Ok(Self::unlimited(io, metrics)),
371        }
372    }
373
374    pub(crate) fn new(
375        inner: S,
376        max_burst_bytes: u32,
377        bytes_per_second: u32,
378        metrics: Arc<Metrics>,
379    ) -> Result<Self, InvalidBucketConfig> {
380        Ok(Self {
381            inner,
382            bucket: Some(Bucket::new(
383                max_burst_bytes as i64,
384                bytes_per_second as i64,
385                time::Duration::from_millis(100),
386            )?),
387            bucket_refilled: None,
388            limited_once: false,
389            metrics,
390        })
391    }
392
393    pub(crate) fn unlimited(inner: S, metrics: Arc<Metrics>) -> Self {
394        Self {
395            inner,
396            bucket: None,
397            bucket_refilled: None,
398            limited_once: false,
399            metrics,
400        }
401    }
402
403    /// Records metrics about being rate-limited.
404    fn record_rate_limited(&mut self, bytes: usize) {
405        // TODO: add a label for the frame type.
406        self.metrics.bytes_rx_ratelimited_total.inc_by(bytes as u64);
407        if !self.limited_once {
408            self.metrics.conns_rx_ratelimited_total.inc();
409            self.limited_once = true;
410        }
411    }
412}
413
414impl<S: ExportKeyingMaterial> ExportKeyingMaterial for RateLimited<S> {
415    fn export_keying_material<T: AsMut<[u8]>>(
416        &self,
417        output: T,
418        label: &[u8],
419        context: Option<&[u8]>,
420    ) -> Option<T> {
421        self.inner.export_keying_material(output, label, context)
422    }
423}
424
425impl<S: AsyncRead + Unpin> AsyncRead for RateLimited<S> {
426    #[instrument(name = "rate_limited_poll_read", skip_all)]
427    fn poll_read(
428        mut self: Pin<&mut Self>,
429        cx: &mut std::task::Context<'_>,
430        buf: &mut tokio::io::ReadBuf<'_>,
431    ) -> Poll<std::io::Result<()>> {
432        let this = &mut *self;
433        let Some(bucket) = &mut this.bucket else {
434            // If there is no rate-limiter, then directly poll the inner.
435            return Pin::new(&mut this.inner).poll_read(cx, buf);
436        };
437
438        // If we're currently limited, wait until we've got some bucket space again
439        if let Some(bucket_refilled) = &mut this.bucket_refilled {
440            ready!(bucket_refilled.poll(cx));
441            this.bucket_refilled = None;
442        }
443
444        // We're not currently limited, let's read
445
446        // Poll inner for a new item.
447        let bytes_before = buf.remaining();
448        ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
449        let bytes_read = bytes_before - buf.remaining();
450
451        // Record how much we've read, rate limit accordingly, if need be.
452        if let Err(refill_time) = bucket.consume(bytes_read) {
453            this.record_rate_limited(bytes_read);
454            this.bucket_refilled = Some(Box::pin(time::sleep_until(refill_time)));
455        }
456
457        Poll::Ready(Ok(()))
458    }
459}
460
461impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimited<S> {
462    fn poll_write(
463        mut self: Pin<&mut Self>,
464        cx: &mut std::task::Context<'_>,
465        buf: &[u8],
466    ) -> Poll<Result<usize, std::io::Error>> {
467        Pin::new(&mut self.inner).poll_write(cx, buf)
468    }
469
470    fn poll_flush(
471        mut self: Pin<&mut Self>,
472        cx: &mut std::task::Context<'_>,
473    ) -> Poll<Result<(), std::io::Error>> {
474        Pin::new(&mut self.inner).poll_flush(cx)
475    }
476
477    fn poll_shutdown(
478        mut self: Pin<&mut Self>,
479        cx: &mut std::task::Context<'_>,
480    ) -> Poll<Result<(), std::io::Error>> {
481        Pin::new(&mut self.inner).poll_shutdown(cx)
482    }
483}
484
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use n0_error::{Result, StdResultExt};
    use n0_future::time;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tracing_test::traced_test;

    use super::Bucket;
    use crate::server::{Metrics, streams::RateLimited};

    /// Pushes 10 MiB through a rate-limited reader and checks that the observed
    /// throughput matches the configured rate (time is paused, so this is deterministic).
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_ratelimiter() -> Result {
        const TOTAL: usize = 10 * 1024 * 1024; // 10MiB

        let (reader, mut writer) = tokio::io::duplex(4096);
        let payload = vec![42u8; TOTAL];
        let bytes_per_second = 12_345;

        let mut limited = RateLimited::new(
            reader,
            bytes_per_second / 10,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;

        let start = time::Instant::now();
        n0_future::future::try_zip(
            async {
                let mut scratch = [0u8; 4096];
                let mut received = 0;
                while received < TOTAL {
                    received += limited.read(&mut scratch).await?;
                }
                Ok(())
            },
            async {
                writer.write_all(&payload).await?;
                writer.flush().await
            },
        )
        .await
        .anyerr()?;

        let elapsed = time::Instant::now().duration_since(start);
        assert_ne!(elapsed.as_millis(), 0);

        let observed_rate = TOTAL as f64 / elapsed.as_secs_f64();
        println!("{observed_rate}");
        assert_eq!(observed_rate.round() as u32, bytes_per_second);

        Ok(())
    }

    /// A bucket with an extreme refill rate must keep granting moderate withdrawals.
    #[tokio::test(start_paused = true)]
    async fn test_bucket_high_refill() -> Result {
        let period = time::Duration::from_millis(100);
        let mut bucket = Bucket::new(i64::MAX, i64::MAX, period)?;
        for _ in 0..100 {
            time::sleep(period).await;
            assert!(bucket.consume(1_000_000).is_ok());
        }

        Ok(())
    }

    /// Withdrawing `usize::MAX` repeatedly must always be refused and must not panic.
    #[tokio::test(start_paused = true)]
    async fn smoke_test_bucket_high_consume() -> Result {
        let bytes_per_second = 123_456;
        let mut bucket = Bucket::new(
            bytes_per_second / 10,
            bytes_per_second,
            time::Duration::from_millis(100),
        )?;
        for _ in 0..100 {
            match bucket.consume(usize::MAX) {
                Ok(()) => panic!("i64::MAX shouldn't be within limits"),
                Err(refill_at) => time::sleep_until(refill_at).await,
            }
        }

        Ok(())
    }
}