1use http::Uri;
24use hyper::rt::{Read, ReadBuf, ReadBufCursor, Sleep, Timer, Write};
25use std::future;
26use std::io;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{ready, Context, Poll};
30use std::time::Duration;
31use tower_service::Service;
32
33pub struct ConnectorBuilder {
34 timer: Arc<dyn Timer + Send + Sync>,
35 read_rate: Option<u64>,
36 write_rate: Option<u64>,
37}
38
39impl ConnectorBuilder {
40 pub fn build<C>(self, inner: C) -> Connector<C> {
41 Connector {
42 inner,
43 timer: self.timer,
44 read_rate: self.read_rate,
45 write_rate: self.write_rate,
46 }
47 }
48
49 #[must_use]
50 pub fn read_rate(mut self, rate: u64) -> Self {
51 self.read_rate = Some(rate);
52 self
53 }
54
55 #[must_use]
56 pub fn write_rate(mut self, rate: u64) -> Self {
57 self.write_rate = Some(rate);
58 self
59 }
60}
61
62#[derive(Clone)]
63pub struct Connector<C> {
64 inner: C,
65 timer: Arc<dyn Timer + Send + Sync>,
66 read_rate: Option<u64>,
67 write_rate: Option<u64>,
68}
69
70impl Connector<()> {
71 pub fn builder<T>(timer: T) -> ConnectorBuilder
72 where
73 T: Timer + Send + Sync + 'static,
74 {
75 ConnectorBuilder {
76 timer: Arc::new(timer),
77 read_rate: None,
78 write_rate: None,
79 }
80 }
81}
82
83impl<C> Service<Uri> for Connector<C>
84where
85 C: Service<Uri>,
86{
87 type Response = Stream<C::Response>;
88 type Error = C::Error;
89 type Future = Future<C::Future>;
90
91 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 self.inner.poll_ready(cx)
93 }
94
95 fn call(&mut self, request: Uri) -> Self::Future {
96 Future {
97 inner: self.inner.call(request),
98 timer: self.timer.clone(),
99 read_rate: self.read_rate,
100 write_rate: self.write_rate,
101 }
102 }
103}
104
105#[pin_project::pin_project]
106pub struct Stream<S> {
107 #[pin]
108 inner: S,
109 timer: Arc<dyn Timer + Send + Sync>,
110 read_rate: Option<u64>,
111 write_rate: Option<u64>,
112 read_sleep: Option<(usize, Pin<Box<dyn Sleep>>)>,
113 write_sleep: Option<(usize, Pin<Box<dyn Sleep>>)>,
114}
115
116fn call<T, F>(
117 cx: &mut Context<'_>,
118 timer: &T,
119 rate: Option<u64>,
120 sleep: &mut Option<(usize, Pin<Box<dyn Sleep>>)>,
121 mut f: F,
122) -> Poll<Result<usize, io::Error>>
123where
124 T: Timer + ?Sized,
125 F: FnMut(&mut Context<'_>) -> Poll<Result<usize, io::Error>>,
126{
127 loop {
128 if let Some((len, ref mut f)) = *sleep {
129 ready!(f.as_mut().poll(cx));
130 *sleep = None;
131 break Poll::Ready(Ok(len));
132 }
133 let len = ready!(f(cx)?);
134 if let Some(rate) = rate {
135 *sleep = Some((
136 len,
137 timer.sleep(Duration::from_nanos(len as u64 * 1_000_000_000 / rate)),
138 ));
139 } else {
140 break Poll::Ready(Ok(len));
141 }
142 }
143}
144
145impl<S> Read for Stream<S>
146where
147 S: Read,
148{
149 fn poll_read(
150 self: Pin<&mut Self>,
151 cx: &mut Context<'_>,
152 mut buf: ReadBufCursor<'_>,
153 ) -> Poll<Result<(), io::Error>> {
154 let mut this = self.project();
155 let len = ready!(call(
156 cx,
157 this.timer.as_ref(),
158 *this.read_rate,
159 this.read_sleep,
160 |cx| unsafe {
161 let mut buf = ReadBuf::uninit(buf.as_mut());
162 ready!(this.inner.as_mut().poll_read(cx, buf.unfilled())?);
163 Poll::Ready(Ok(buf.filled().len()))
164 }
165 )?);
166 unsafe { buf.advance(len) };
167 Poll::Ready(Ok(()))
168 }
169}
170
171impl<S> Write for Stream<S>
172where
173 S: Write,
174{
175 fn poll_write(
176 self: Pin<&mut Self>,
177 cx: &mut Context<'_>,
178 buf: &[u8],
179 ) -> Poll<Result<usize, io::Error>> {
180 let mut this = self.project();
181 call(
182 cx,
183 this.timer.as_ref(),
184 *this.write_rate,
185 this.write_sleep,
186 |cx| this.inner.as_mut().poll_write(cx, buf),
187 )
188 }
189
190 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
191 let this = self.project();
192 this.inner.poll_flush(cx)
193 }
194
195 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
196 let this = self.project();
197 this.inner.poll_shutdown(cx)
198 }
199}
200
#[cfg(feature = "hyper-util")]
impl<S> hyper_util::client::legacy::connect::Connection for Stream<S>
where
    S: hyper_util::client::legacy::connect::Connection,
{
    /// Delegates to the wrapped stream: throttling does not alter connection metadata.
    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
        let Self { inner, .. } = self;
        inner.connected()
    }
}
210
211#[pin_project::pin_project]
212pub struct Future<F> {
213 #[pin]
214 inner: F,
215 timer: Arc<dyn Timer + Send + Sync>,
216 read_rate: Option<u64>,
217 write_rate: Option<u64>,
218}
219
220impl<F, T, E> future::Future for Future<F>
221where
222 F: future::Future<Output = Result<T, E>>,
223{
224 type Output = Result<Stream<T>, E>;
225
226 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227 let this = self.project();
228 this.inner.poll(cx).map_ok(|inner| Stream {
229 inner,
230 timer: this.timer.clone(),
231 read_rate: *this.read_rate,
232 write_rate: *this.write_rate,
233 read_sleep: None,
234 write_sleep: None,
235 })
236 }
237}
238
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::{BodyExt, Empty};
    use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
    use hyper_util::client::legacy::connect::HttpConnector;
    use hyper_util::client::legacy::Client;
    use hyper_util::rt::{TokioExecutor, TokioTimer};
    use std::time::{Duration, Instant};

    /// HTTPS client whose TCP connections go through the throttling connector under test.
    type TestClient = Client<HttpsConnector<super::Connector<HttpConnector>>, Empty<Bytes>>;

    /// Page fetched by the tests below. These tests need network access, and their timing
    /// bounds depend on the remote page size and network speed.
    const URL: &str = "https://www.rust-lang.org";

    /// Builds an HTTPS-only client whose connections are throttled at the given
    /// bytes-per-second rates. `None` leaves that direction unthrottled.
    fn client(read_rate: Option<u64>, write_rate: Option<u64>) -> TestClient {
        let mut connector = HttpConnector::new();
        // The TLS layer wraps this connector, so it must accept `https` URIs.
        connector.enforce_http(false);
        let mut builder = super::Connector::builder(TokioTimer::new());
        if let Some(rate) = read_rate {
            builder = builder.read_rate(rate);
        }
        if let Some(rate) = write_rate {
            builder = builder.write_rate(rate);
        }
        let connector = builder.build(connector);
        let connector = HttpsConnectorBuilder::new()
            .with_native_roots()
            .unwrap()
            .https_only()
            .enable_http1()
            .enable_http2()
            .wrap_connector(connector);
        Client::builder(TokioExecutor::new()).build(connector)
    }

    /// Fetches [`URL`] with `client`, asserts a success status, drains the body, and returns
    /// how long the whole exchange took.
    async fn fetch_elapsed(client: &TestClient) -> Duration {
        let started = Instant::now();
        let response = client.get(URL.parse().unwrap()).await.unwrap();
        assert!(response.status().is_success());
        response.into_body().collect().await.unwrap();
        started.elapsed()
    }

    #[tokio::test]
    async fn test_throttle() {
        let client = client(Some(4096), Some(4096));
        assert!(fetch_elapsed(&client).await > Duration::from_secs(4));
    }

    #[tokio::test]
    async fn test_passthrough() {
        let client = client(None, None);
        assert!(fetch_elapsed(&client).await < Duration::from_secs(2));
    }
}