1use std::{
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use n0_error::{ensure, stack_error};
10use n0_future::{FutureExt, Sink, Stream, ready, time};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tracing::instrument;
13
14use super::{ClientRateLimit, Metrics};
15use crate::{
16 ExportKeyingMaterial, KeyCache, MAX_PACKET_SIZE,
17 protos::{
18 relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg},
19 streams::{StreamError, WsBytesFramed},
20 },
21};
22
/// Stream of relay protocol messages exchanged with a single client.
///
/// Outgoing [`RelayToClientMsg`]s are written through the [`Sink`] impl and incoming
/// [`ClientToRelayMsg`]s are read through the [`Stream`] impl; both travel as byte
/// frames over `inner`.
#[derive(Debug)]
pub(crate) struct RelayedStream {
    /// Framed byte transport layered over the rate-limited, optionally TLS-wrapped connection.
    pub(crate) inner: WsBytesFramed<RateLimited<MaybeTlsStream>>,
    /// Cache handed to `ClientToRelayMsg::from_bytes` when decoding received frames.
    pub(crate) key_cache: KeyCache,
}
33
#[cfg(test)]
impl RelayedStream {
    /// Builds a [`RelayedStream`] on top of an in-memory duplex pipe without any rate limit.
    pub(crate) fn test(stream: tokio::io::DuplexStream) -> Self {
        let metrics = Arc::new(Metrics::default());
        let io = RateLimited::unlimited(MaybeTlsStream::Test(stream), metrics);
        Self::from_test_io(io)
    }

    /// Builds a [`RelayedStream`] on top of an in-memory duplex pipe whose reads are
    /// rate-limited with the given burst size and sustained rate.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidBucketConfig`] when [`RateLimited::new`] rejects the parameters.
    pub(crate) fn test_limited(
        stream: tokio::io::DuplexStream,
        max_burst_bytes: u32,
        bytes_per_second: u32,
    ) -> Result<Self, InvalidBucketConfig> {
        let metrics = Arc::new(Metrics::default());
        let io = RateLimited::new(
            MaybeTlsStream::Test(stream),
            max_burst_bytes,
            bytes_per_second,
            metrics,
        )?;
        Ok(Self::from_test_io(io))
    }

    /// Wraps an already rate-limit-wrapped test transport in the websocket server framing
    /// and pairs it with a test key cache.
    fn from_test_io(io: RateLimited<MaybeTlsStream>) -> Self {
        let websocket = tokio_websockets::ServerBuilder::new()
            .limits(Self::limits())
            .serve(io);
        Self {
            inner: WsBytesFramed { io: websocket },
            key_cache: KeyCache::test(),
        }
    }

    /// Websocket limits used by the test constructors: the payload length is capped at the
    /// relay protocol's `MAX_FRAME_SIZE`.
    fn limits() -> tokio_websockets::Limits {
        let max_frame = crate::protos::relay::MAX_FRAME_SIZE;
        tokio_websockets::Limits::default().max_payload_len(Some(max_frame))
    }
}
76
/// Errors that can occur while sending a [`RelayToClientMsg`] through a [`RelayedStream`].
#[stack_error(derive, add_meta)]
#[non_exhaustive]
pub enum SendError {
    /// The underlying framed stream reported an error.
    #[error(transparent)]
    StreamError {
        /// The stream error being forwarded.
        #[error(from, std_err)]
        source: StreamError,
    },
    /// The encoded message is larger than [`MAX_PACKET_SIZE`].
    #[error("Packet exceeds max packet size")]
    ExceedsMaxPacketSize {
        /// Encoded length of the rejected message.
        size: usize,
    },
    /// A datagrams message with empty contents was handed to the sink.
    #[error("Attempted to send empty packet")]
    EmptyPacket {},
}
91
92impl Sink<RelayToClientMsg> for RelayedStream {
93 type Error = SendError;
94
95 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96 Pin::new(&mut self.inner).poll_ready(cx).map_err(Into::into)
97 }
98
99 fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> {
100 let size = item.encoded_len();
101 ensure!(
102 size <= MAX_PACKET_SIZE,
103 SendError::ExceedsMaxPacketSize { size }
104 );
105 if let RelayToClientMsg::Datagrams { datagrams, .. } = &item {
106 ensure!(!datagrams.contents.is_empty(), SendError::EmptyPacket);
107 }
108
109 Pin::new(&mut self.inner)
110 .start_send(item.to_bytes().freeze())
111 .map_err(Into::into)
112 }
113
114 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
115 Pin::new(&mut self.inner).poll_flush(cx).map_err(Into::into)
116 }
117
118 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119 Pin::new(&mut self.inner).poll_close(cx).map_err(Into::into)
120 }
121}
122
/// Errors that can occur while receiving a [`ClientToRelayMsg`] from a [`RelayedStream`].
#[stack_error(derive, add_meta, from_sources)]
#[non_exhaustive]
pub enum RecvError {
    /// A received frame could not be decoded into a [`ClientToRelayMsg`].
    #[error(transparent)]
    Proto {
        /// The protocol error reported by the decoder.
        source: ProtoError,
    },
    /// The underlying framed stream reported an error.
    #[error(transparent)]
    StreamError {
        /// The stream error being forwarded.
        #[error(std_err)]
        source: StreamError,
    },
}
135
136impl Stream for RelayedStream {
137 type Item = Result<ClientToRelayMsg, RecvError>;
138
139 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140 Poll::Ready(match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
141 Some(Ok(msg)) => {
142 Some(ClientToRelayMsg::from_bytes(msg, &self.key_cache).map_err(Into::into))
143 }
144 Some(Err(e)) => Some(Err(e.into())),
145 None => None,
146 })
147 }
148}
149
/// A connection that is either plain TCP, TLS over TCP, or (in test builds) an in-memory pipe.
///
/// All [`AsyncRead`] and [`AsyncWrite`] calls are dispatched to the active variant.
#[derive(Debug)]
// NOTE(review): presumably silenced because the TLS variant is much larger than the
// others and boxing it was not wanted - confirm.
#[allow(clippy::large_enum_variant)]
pub enum MaybeTlsStream {
    /// An unencrypted TCP connection.
    Plain(tokio::net::TcpStream),
    /// A server-side TLS session on top of a TCP connection.
    Tls(tokio_rustls::server::TlsStream<tokio::net::TcpStream>),
    /// An in-memory duplex stream; only compiled for tests.
    #[cfg(test)]
    Test(tokio::io::DuplexStream),
}
164
165impl ExportKeyingMaterial for MaybeTlsStream {
166 fn export_keying_material<T: AsMut<[u8]>>(
167 &self,
168 output: T,
169 label: &[u8],
170 context: Option<&[u8]>,
171 ) -> Option<T> {
172 let Self::Tls(tls) = self else {
173 return None;
174 };
175
176 tls.get_ref()
177 .1
178 .export_keying_material(output, label, context)
179 .ok()
180 }
181}
182
183impl AsyncRead for MaybeTlsStream {
184 fn poll_read(
185 mut self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 buf: &mut tokio::io::ReadBuf<'_>,
188 ) -> Poll<std::io::Result<()>> {
189 match &mut *self {
190 MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
191 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
192 #[cfg(test)]
193 MaybeTlsStream::Test(s) => Pin::new(s).poll_read(cx, buf),
194 }
195 }
196}
197
198impl AsyncWrite for MaybeTlsStream {
199 fn poll_flush(
200 mut self: Pin<&mut Self>,
201 cx: &mut Context<'_>,
202 ) -> Poll<std::result::Result<(), std::io::Error>> {
203 match &mut *self {
204 MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
205 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
206 #[cfg(test)]
207 MaybeTlsStream::Test(s) => Pin::new(s).poll_flush(cx),
208 }
209 }
210
211 fn poll_shutdown(
212 mut self: Pin<&mut Self>,
213 cx: &mut Context<'_>,
214 ) -> Poll<std::result::Result<(), std::io::Error>> {
215 match &mut *self {
216 MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
217 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
218 #[cfg(test)]
219 MaybeTlsStream::Test(s) => Pin::new(s).poll_shutdown(cx),
220 }
221 }
222
223 fn poll_write(
224 mut self: Pin<&mut Self>,
225 cx: &mut Context<'_>,
226 buf: &[u8],
227 ) -> Poll<std::result::Result<usize, std::io::Error>> {
228 match &mut *self {
229 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
230 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
231 #[cfg(test)]
232 MaybeTlsStream::Test(s) => Pin::new(s).poll_write(cx, buf),
233 }
234 }
235
236 fn poll_write_vectored(
237 mut self: Pin<&mut Self>,
238 cx: &mut Context<'_>,
239 bufs: &[std::io::IoSlice<'_>],
240 ) -> Poll<std::result::Result<usize, std::io::Error>> {
241 match &mut *self {
242 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
243 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
244 #[cfg(test)]
245 MaybeTlsStream::Test(s) => Pin::new(s).poll_write_vectored(cx, bufs),
246 }
247 }
248
249 fn is_write_vectored(&self) -> bool {
250 match self {
251 MaybeTlsStream::Plain(s) => s.is_write_vectored(),
252 MaybeTlsStream::Tls(s) => s.is_write_vectored(),
253 #[cfg(test)]
254 MaybeTlsStream::Test(s) => s.is_write_vectored(),
255 }
256 }
257}
258
/// Wrapper around a stream that rate-limits the bytes read from it.
///
/// Reads are metered against an optional token bucket. Once the bucket is exhausted,
/// further reads wait until it has been refilled.
#[derive(Debug)]
pub(crate) struct RateLimited<S> {
    /// The wrapped stream.
    inner: S,
    /// Token bucket metering reads; `None` disables rate limiting.
    bucket: Option<Bucket>,
    /// Timer that reads wait on after the bucket has been exhausted.
    bucket_refilled: Option<Pin<Box<time::Sleep>>>,
    /// Whether this stream has already been counted in `conns_rx_ratelimited_total`.
    limited_once: bool,
    /// Metrics updated whenever a read gets rate-limited.
    metrics: Arc<Metrics>,
}
274
/// Token bucket state used to meter bytes.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; drops to zero or below when more was consumed than
    /// the bucket held.
    fill: i64,
    /// Upper bound for `fill`.
    max: i64,
    /// Start of the refill period that `fill` was last brought up to date for.
    last_fill: time::Instant,
    /// Length of one refill period.
    refill_period: time::Duration,
    /// Tokens added to `fill` per elapsed refill period.
    refill: i64,
}
288
/// Error returned when a token bucket cannot be built from the given parameters.
///
/// The fields carry the rejected parameters.
#[allow(missing_docs)]
#[stack_error(derive, add_meta)]
pub struct InvalidBucketConfig {
    /// The rejected bucket capacity.
    max: i64,
    /// The rejected sustained rate.
    bytes_per_second: i64,
    /// The rejected refill period.
    refill_period: time::Duration,
}
296
297impl Bucket {
298 fn new(
299 max: i64,
300 bytes_per_second: i64,
301 refill_period: time::Duration,
302 ) -> Result<Self, InvalidBucketConfig> {
303 let refill = bytes_per_second.saturating_mul(refill_period.as_millis() as i64) / 1000;
305 ensure!(
306 max > 0 && bytes_per_second > 0 && refill_period.as_millis() as u32 > 0 && refill > 0,
307 InvalidBucketConfig {
308 max,
309 bytes_per_second,
310 refill_period
311 }
312 );
313 Ok(Self {
314 fill: max,
315 max,
316 last_fill: time::Instant::now(),
317 refill_period,
318 refill,
319 })
320 }
321
322 fn update_state(&mut self) {
323 let now = time::Instant::now();
324 let refill_periods = now.saturating_duration_since(self.last_fill).as_millis() as u32
326 / self.refill_period.as_millis() as u32;
327 if refill_periods == 0 {
328 return;
330 }
331
332 self.fill = self
333 .fill
334 .saturating_add(refill_periods as i64 * self.refill);
335 self.fill = std::cmp::min(self.fill, self.max);
336 self.last_fill += self.refill_period * refill_periods;
337 }
338
339 fn consume(&mut self, bytes: usize) -> Result<(), time::Instant> {
340 let bytes = i64::try_from(bytes).unwrap_or(i64::MAX);
341 self.update_state();
342
343 self.fill = self.fill.saturating_sub(bytes);
344
345 if self.fill > 0 {
346 return Ok(());
347 }
348
349 let missing = self.fill.saturating_neg();
350
351 let periods_needed = (missing / self.refill) + 1;
352 let periods_needed = u32::try_from(periods_needed).unwrap_or(u32::MAX);
353
354 Err(self.last_fill + periods_needed * self.refill_period)
355 }
356}
357
358impl<S> RateLimited<S> {
359 pub(crate) fn from_cfg(
360 cfg: Option<ClientRateLimit>,
361 io: S,
362 metrics: Arc<Metrics>,
363 ) -> Result<Self, InvalidBucketConfig> {
364 match cfg {
365 Some(cfg) => {
366 let bytes_per_second = cfg.bytes_per_second.into();
367 let max_burst_bytes = cfg.max_burst_bytes.map_or(bytes_per_second / 10, u32::from);
368 Self::new(io, max_burst_bytes, bytes_per_second, metrics)
369 }
370 None => Ok(Self::unlimited(io, metrics)),
371 }
372 }
373
374 pub(crate) fn new(
375 inner: S,
376 max_burst_bytes: u32,
377 bytes_per_second: u32,
378 metrics: Arc<Metrics>,
379 ) -> Result<Self, InvalidBucketConfig> {
380 Ok(Self {
381 inner,
382 bucket: Some(Bucket::new(
383 max_burst_bytes as i64,
384 bytes_per_second as i64,
385 time::Duration::from_millis(100),
386 )?),
387 bucket_refilled: None,
388 limited_once: false,
389 metrics,
390 })
391 }
392
393 pub(crate) fn unlimited(inner: S, metrics: Arc<Metrics>) -> Self {
394 Self {
395 inner,
396 bucket: None,
397 bucket_refilled: None,
398 limited_once: false,
399 metrics,
400 }
401 }
402
403 fn record_rate_limited(&mut self, bytes: usize) {
405 self.metrics.bytes_rx_ratelimited_total.inc_by(bytes as u64);
407 if !self.limited_once {
408 self.metrics.conns_rx_ratelimited_total.inc();
409 self.limited_once = true;
410 }
411 }
412}
413
414impl<S: ExportKeyingMaterial> ExportKeyingMaterial for RateLimited<S> {
415 fn export_keying_material<T: AsMut<[u8]>>(
416 &self,
417 output: T,
418 label: &[u8],
419 context: Option<&[u8]>,
420 ) -> Option<T> {
421 self.inner.export_keying_material(output, label, context)
422 }
423}
424
425impl<S: AsyncRead + Unpin> AsyncRead for RateLimited<S> {
426 #[instrument(name = "rate_limited_poll_read", skip_all)]
427 fn poll_read(
428 mut self: Pin<&mut Self>,
429 cx: &mut std::task::Context<'_>,
430 buf: &mut tokio::io::ReadBuf<'_>,
431 ) -> Poll<std::io::Result<()>> {
432 let this = &mut *self;
433 let Some(bucket) = &mut this.bucket else {
434 return Pin::new(&mut this.inner).poll_read(cx, buf);
436 };
437
438 if let Some(bucket_refilled) = &mut this.bucket_refilled {
440 ready!(bucket_refilled.poll(cx));
441 this.bucket_refilled = None;
442 }
443
444 let bytes_before = buf.remaining();
448 ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?;
449 let bytes_read = bytes_before - buf.remaining();
450
451 if let Err(refill_time) = bucket.consume(bytes_read) {
453 this.record_rate_limited(bytes_read);
454 this.bucket_refilled = Some(Box::pin(time::sleep_until(refill_time)));
455 }
456
457 Poll::Ready(Ok(()))
458 }
459}
460
461impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimited<S> {
462 fn poll_write(
463 mut self: Pin<&mut Self>,
464 cx: &mut std::task::Context<'_>,
465 buf: &[u8],
466 ) -> Poll<Result<usize, std::io::Error>> {
467 Pin::new(&mut self.inner).poll_write(cx, buf)
468 }
469
470 fn poll_flush(
471 mut self: Pin<&mut Self>,
472 cx: &mut std::task::Context<'_>,
473 ) -> Poll<Result<(), std::io::Error>> {
474 Pin::new(&mut self.inner).poll_flush(cx)
475 }
476
477 fn poll_shutdown(
478 mut self: Pin<&mut Self>,
479 cx: &mut std::task::Context<'_>,
480 ) -> Poll<Result<(), std::io::Error>> {
481 Pin::new(&mut self.inner).poll_shutdown(cx)
482 }
483}
484
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use n0_error::{Result, StdResultExt};
    use n0_future::time;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tracing_test::traced_test;

    use super::Bucket;
    use crate::server::{Metrics, streams::RateLimited};

    /// Pushes 10 MiB through a rate-limited reader and checks that the observed
    /// throughput, rounded to whole bytes per second, equals the configured rate.
    ///
    /// The runtime clock starts paused, so the measured duration is runtime time rather
    /// than wall-clock time.
    #[tokio::test(start_paused = true)]
    #[traced_test]
    async fn test_ratelimiter() -> Result {
        let (read, mut write) = tokio::io::duplex(4096);

        // 10 MiB of payload.
        let send_total = 10 * 1024 * 1024;
        let send_data = vec![42u8; send_total];

        let bytes_per_second = 12_345;

        // Burst size is one tenth of the per-second rate.
        let mut rate_limited = RateLimited::new(
            read,
            bytes_per_second / 10,
            bytes_per_second,
            Arc::new(Metrics::default()),
        )?;

        let before = time::Instant::now();
        // Run the reader and the writer concurrently; the duplex pipe only buffers 4096
        // bytes, so the writer is paced by the rate-limited reader.
        n0_future::future::try_zip(
            async {
                let mut remaining = send_total;
                let mut buf = [0u8; 4096];
                while remaining > 0 {
                    remaining -= rate_limited.read(&mut buf).await?;
                }
                Ok(())
            },
            async {
                write.write_all(&send_data).await?;
                write.flush().await
            },
        )
        .await
        .anyerr()?;

        let duration = time::Instant::now().duration_since(before);
        // Guards the division below against a zero duration.
        assert_ne!(duration.as_millis(), 0);

        let actual_bytes_per_second = send_total as f64 / duration.as_secs_f64();
        println!("{actual_bytes_per_second}");
        assert_eq!(actual_bytes_per_second.round() as u32, bytes_per_second);

        Ok(())
    }

    /// A bucket configured with the largest possible capacity and rate must keep
    /// accepting moderate consumption after every refill period.
    #[tokio::test(start_paused = true)]
    async fn test_bucket_high_refill() -> Result {
        let bytes_per_second = i64::MAX;
        let mut bucket = Bucket::new(i64::MAX, bytes_per_second, time::Duration::from_millis(100))?;
        for _ in 0..100 {
            time::sleep(time::Duration::from_millis(100)).await;
            assert!(bucket.consume(1_000_000).is_ok());
        }

        Ok(())
    }

    /// Repeatedly consuming the largest possible amount must always be reported as
    /// rate-limited, and waiting for the returned instant must not panic.
    #[tokio::test(start_paused = true)]
    async fn smoke_test_bucket_high_consume() -> Result {
        let bytes_per_second = 123_456;
        let mut bucket = Bucket::new(
            bytes_per_second / 10,
            bytes_per_second,
            time::Duration::from_millis(100),
        )?;
        for _ in 0..100 {
            let Err(until) = bucket.consume(usize::MAX) else {
                panic!("i64::MAX shouldn't be within limits");
            };
            time::sleep_until(until).await;
        }

        Ok(())
    }
}