Skip to main content

relay_core_lib/proxy/
throttle.rs

1use crate::interceptor::{BoxError, HttpBody};
2use hyper::body::{Body, Bytes, Frame, SizeHint};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6use tokio::time::Instant;
7
8/// Wraps a body stream with bandwidth throttling (bytes/sec).
9/// Inserts artificial delays between data frames to ensure
10/// throughput does not exceed the configured rate.
11pub struct ThrottleBody {
12    inner: HttpBody,
13    bytes_per_sec: u64,
14    last_frame_at: Option<Instant>,
15}
16
17impl ThrottleBody {
18    pub fn new(inner: HttpBody, bytes_per_sec: u64) -> Self {
19        Self {
20            inner,
21            bytes_per_sec,
22            last_frame_at: None,
23        }
24    }
25}
26
27impl Body for ThrottleBody {
28    type Data = Bytes;
29    type Error = BoxError;
30
31    fn poll_frame(
32        mut self: Pin<&mut Self>,
33        cx: &mut Context<'_>,
34    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
35        let frame = match Pin::new(&mut self.inner).poll_frame(cx) {
36            Poll::Ready(Some(Ok(frame))) => frame,
37            other => return other,
38        };
39
40        // Calculate per-frame delay based on data size
41        if let Some(data) = frame.data_ref() {
42            let bytes = data.len() as u64;
43            if bytes > 0 && self.bytes_per_sec > 0 {
44                let frame_dur = Duration::from_micros(bytes * 1_000_000 / self.bytes_per_sec);
45                let now = Instant::now();
46
47                if let Some(last) = self.last_frame_at {
48                    let elapsed = now.duration_since(last);
49                    if elapsed < frame_dur {
50                        let remaining = frame_dur - elapsed;
51                        // Since we can't .await in poll_frame, schedule a wake
52                        let waker = cx.waker().clone();
53                        tokio::spawn(async move {
54                            tokio::time::sleep(remaining).await;
55                            waker.wake();
56                        });
57                        return Poll::Pending;
58                    }
59                }
60                self.last_frame_at = Some(now);
61            }
62        }
63
64        Poll::Ready(Some(Ok(frame)))
65    }
66
67    fn is_end_stream(&self) -> bool {
68        self.inner.is_end_stream()
69    }
70
71    fn size_hint(&self) -> SizeHint {
72        self.inner.size_hint()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::{BodyExt, Full};
    use hyper::body::Frame;
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    /// Builds an `HttpBody` that yields `payload` as one in-memory chunk.
    fn full_body(payload: Bytes) -> HttpBody {
        Full::new(payload)
            .map_err(|e| -> BoxError { Box::new(e) })
            .boxed()
    }

    /// Collecting a throttled single-chunk body returns the original bytes.
    #[tokio::test]
    async fn test_throttle_body_preserves_data() {
        let expected = Bytes::from("test-body-data");
        // Generous rate limit; the body is a single chunk.
        let throttled = ThrottleBody::new(full_body(expected.clone()), 1_000_000);
        let received = throttled.collect().await.unwrap().to_bytes();
        assert_eq!(received, expected);
    }

    /// An empty body collects to zero bytes through the wrapper.
    #[tokio::test]
    async fn test_throttle_body_passthrough_empty() {
        let throttled = ThrottleBody::new(full_body(Bytes::new()), 1000);
        let received = throttled.collect().await.unwrap().to_bytes();
        assert_eq!(received.len(), 0);
    }

    /// Verify ThrottleBody passes trailers through unchanged.
    #[tokio::test]
    async fn test_throttle_body_passes_trailers() {
        /// Fake body: one data frame, then one trailers frame, then EOF.
        struct TrailerBody {
            phase: u8,
        }

        impl Body for TrailerBody {
            type Data = Bytes;
            type Error = BoxError;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                let step = self.phase;
                // Both frames have been produced: report end of stream.
                if step >= 2 {
                    return Poll::Ready(None);
                }
                self.phase = step + 1;

                let frame = if step == 0 {
                    Frame::data(Bytes::from("body-data"))
                } else {
                    let mut trailers = hyper::HeaderMap::new();
                    trailers.insert("x-trailer", "present".parse().unwrap());
                    trailers.insert("x-end-stream", "true".parse().unwrap());
                    Frame::trailers(trailers)
                };
                Poll::Ready(Some(Ok(frame)))
            }
        }

        let body: HttpBody = TrailerBody { phase: 0 }
            .map_err(|e| -> BoxError { e })
            .boxed();
        let mut throttled = ThrottleBody::new(body, 1_000_000);

        // Poll by hand with a no-op waker so a `Pending` result is observable.
        let noop = Waker::noop();
        let mut cx = Context::from_waker(&noop);

        let mut frames_seen = 0;
        let mut data_frames = 0;
        let mut trailer_frames = 0;
        let mut seen_trailers: Option<hyper::HeaderMap> = None;

        loop {
            let frame = match Pin::new(&mut throttled).poll_frame(&mut cx) {
                Poll::Pending => panic!("ThrottleBody should not pend at full speed"),
                Poll::Ready(None) => break,
                Poll::Ready(Some(Err(e))) => panic!("unexpected error: {}", e),
                Poll::Ready(Some(Ok(frame))) => frame,
            };

            frames_seen += 1;
            if frame.data_ref().is_some() {
                data_frames += 1;
            }
            if let Some(map) = frame.trailers_ref() {
                trailer_frames += 1;
                seen_trailers = Some(map.clone());
            }
        }

        assert_eq!(frames_seen, 2, "should yield data + trailers = 2 frames");
        assert_eq!(data_frames, 1, "should have 1 data frame");
        assert_eq!(trailer_frames, 1, "should have 1 trailers frame");

        let trailers = seen_trailers.expect("trailers should be present");
        for (name, expected) in [("x-trailer", "present"), ("x-end-stream", "true")] {
            assert_eq!(
                trailers.get(name).and_then(|v| v.to_str().ok()),
                Some(expected)
            );
        }
    }
}