relay_core_lib/proxy/
throttle.rs1use crate::interceptor::{BoxError, HttpBody};
2use hyper::body::{Body, Bytes, Frame, SizeHint};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6use tokio::time::Instant;
7
8pub struct ThrottleBody {
12 inner: HttpBody,
13 bytes_per_sec: u64,
14 last_frame_at: Option<Instant>,
15}
16
17impl ThrottleBody {
18 pub fn new(inner: HttpBody, bytes_per_sec: u64) -> Self {
19 Self {
20 inner,
21 bytes_per_sec,
22 last_frame_at: None,
23 }
24 }
25}
26
27impl Body for ThrottleBody {
28 type Data = Bytes;
29 type Error = BoxError;
30
31 fn poll_frame(
32 mut self: Pin<&mut Self>,
33 cx: &mut Context<'_>,
34 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
35 let frame = match Pin::new(&mut self.inner).poll_frame(cx) {
36 Poll::Ready(Some(Ok(frame))) => frame,
37 other => return other,
38 };
39
40 if let Some(data) = frame.data_ref() {
42 let bytes = data.len() as u64;
43 if bytes > 0 && self.bytes_per_sec > 0 {
44 let frame_dur = Duration::from_micros(bytes * 1_000_000 / self.bytes_per_sec);
45 let now = Instant::now();
46
47 if let Some(last) = self.last_frame_at {
48 let elapsed = now.duration_since(last);
49 if elapsed < frame_dur {
50 let remaining = frame_dur - elapsed;
51 let waker = cx.waker().clone();
53 tokio::spawn(async move {
54 tokio::time::sleep(remaining).await;
55 waker.wake();
56 });
57 return Poll::Pending;
58 }
59 }
60 self.last_frame_at = Some(now);
61 }
62 }
63
64 Poll::Ready(Some(Ok(frame)))
65 }
66
67 fn is_end_stream(&self) -> bool {
68 self.inner.is_end_stream()
69 }
70
71 fn size_hint(&self) -> SizeHint {
72 self.inner.size_hint()
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http_body_util::{BodyExt, Full};
    use hyper::body::Frame;
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    /// Wraps `payload` in a `Full` body boxed as the proxy's `HttpBody`.
    fn full_body(payload: Bytes) -> HttpBody {
        Full::new(payload)
            .map_err(|e| -> BoxError { Box::new(e) })
            .boxed()
    }

    #[tokio::test]
    async fn test_throttle_body_preserves_data() {
        let payload = Bytes::from("test-body-data");
        let output = ThrottleBody::new(full_body(payload.clone()), 1_000_000)
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(output, payload);
    }

    #[tokio::test]
    async fn test_throttle_body_passthrough_empty() {
        let output = ThrottleBody::new(full_body(Bytes::new()), 1000)
            .collect()
            .await
            .unwrap()
            .to_bytes();
        assert_eq!(output.len(), 0);
    }

    #[tokio::test]
    async fn test_throttle_body_passes_trailers() {
        // Emits one data frame, then one trailers frame, then ends.
        enum TrailerBody {
            Data,
            Trailers,
            Done,
        }

        impl Body for TrailerBody {
            type Data = Bytes;
            type Error = BoxError;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                let (frame, next_state) = match *self {
                    TrailerBody::Data => (
                        Some(Frame::data(Bytes::from("body-data"))),
                        TrailerBody::Trailers,
                    ),
                    TrailerBody::Trailers => {
                        let mut trailers = hyper::HeaderMap::new();
                        trailers.insert("x-trailer", "present".parse().unwrap());
                        trailers.insert("x-end-stream", "true".parse().unwrap());
                        (Some(Frame::trailers(trailers)), TrailerBody::Done)
                    }
                    TrailerBody::Done => (None, TrailerBody::Done),
                };
                *self = next_state;
                Poll::Ready(frame.map(Ok))
            }
        }

        let body: HttpBody = TrailerBody::Data.boxed();
        let mut throttled = ThrottleBody::new(body, 1_000_000);

        let waker = Waker::noop();
        let mut cx = Context::from_waker(&waker);

        // Drive the body by hand so an unexpected `Pending` is caught immediately.
        let mut frames = Vec::new();
        loop {
            let Poll::Ready(next) = Pin::new(&mut throttled).poll_frame(&mut cx) else {
                panic!("ThrottleBody should not pend at full speed");
            };
            match next {
                Some(Ok(frame)) => frames.push(frame),
                Some(Err(e)) => panic!("unexpected error: {}", e),
                None => break,
            }
        }

        let data_frames = frames.iter().filter(|frame| frame.is_data()).count();
        let trailer_frames = frames.iter().filter(|frame| frame.is_trailers()).count();

        assert_eq!(frames.len(), 2, "should yield data + trailers = 2 frames");
        assert_eq!(data_frames, 1, "should have 1 data frame");
        assert_eq!(trailer_frames, 1, "should have 1 trailers frame");

        let trailers = frames
            .iter()
            .find_map(|frame| frame.trailers_ref())
            .cloned()
            .expect("trailers should be present");
        assert_eq!(
            trailers.get("x-trailer").and_then(|v| v.to_str().ok()),
            Some("present")
        );
        assert_eq!(
            trailers.get("x-end-stream").and_then(|v| v.to_str().ok()),
            Some("true")
        );
    }
}