hyper_throttle/
lib.rs

1//! ```rust
2//! use bytes::Bytes;
3//! use http_body_util::{BodyExt, Empty};
4//! use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
5//! use hyper_util::client::legacy::connect::HttpConnector;
6//! use hyper_util::client::legacy::Client;
7//! use hyper_util::rt::{TokioExecutor, TokioTimer};
8//!
9//! let mut connector = HttpConnector::new();
10//! connector.enforce_http(false);
11//! let connector = hyper_throttle::Connector::builder(TokioTimer::new())
12//!     .read_rate(65536) // 64 KiB/s
13//!     .build(connector);
14//! let connector = HttpsConnectorBuilder::new()
15//!     .with_native_roots()?
16//!     .https_or_http()
17//!     .enable_all_versions()
18//!     .wrap_connector(connector);
19//! let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);
20//! # std::io::Result::Ok(())
21//! ```
22
23use http::Uri;
24use hyper::rt::{Read, ReadBuf, ReadBufCursor, Sleep, Timer, Write};
25use std::future;
26use std::io;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{ready, Context, Poll};
30use std::time::Duration;
31use tower_service::Service;
32
/// Collects the settings for a throttling [`Connector`] before it is built.
///
/// Created by [`Connector::builder`] and turned into a [`Connector`] by
/// [`ConnectorBuilder::build`].
pub struct ConnectorBuilder {
    /// Timer used to create the delays that enforce the configured rates.
    timer: Arc<dyn Timer + Send + Sync>,
    /// Read rate limit in bytes per second; `None` leaves reads unthrottled.
    read_rate: Option<u64>,
    /// Write rate limit in bytes per second; `None` leaves writes unthrottled.
    write_rate: Option<u64>,
}
38
39impl ConnectorBuilder {
40    pub fn build<C>(self, inner: C) -> Connector<C> {
41        Connector {
42            inner,
43            timer: self.timer,
44            read_rate: self.read_rate,
45            write_rate: self.write_rate,
46        }
47    }
48
49    #[must_use]
50    pub fn read_rate(mut self, rate: u64) -> Self {
51        self.read_rate = Some(rate);
52        self
53    }
54
55    #[must_use]
56    pub fn write_rate(mut self, rate: u64) -> Self {
57        self.write_rate = Some(rate);
58        self
59    }
60}
61
/// A connector that wraps another connector `C` and produces rate-limited
/// [`Stream`]s from the connections `C` establishes.
#[derive(Clone)]
pub struct Connector<C> {
    /// The wrapped connector that actually establishes connections.
    inner: C,
    /// Timer handed to every stream so it can create throttling delays.
    timer: Arc<dyn Timer + Send + Sync>,
    /// Read rate limit in bytes per second; `None` leaves reads unthrottled.
    read_rate: Option<u64>,
    /// Write rate limit in bytes per second; `None` leaves writes unthrottled.
    write_rate: Option<u64>,
}
69
70impl Connector<()> {
71    pub fn builder<T>(timer: T) -> ConnectorBuilder
72    where
73        T: Timer + Send + Sync + 'static,
74    {
75        ConnectorBuilder {
76            timer: Arc::new(timer),
77            read_rate: None,
78            write_rate: None,
79        }
80    }
81}
82
83impl<C> Service<Uri> for Connector<C>
84where
85    C: Service<Uri>,
86{
87    type Response = Stream<C::Response>;
88    type Error = C::Error;
89    type Future = Future<C::Future>;
90
91    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92        self.inner.poll_ready(cx)
93    }
94
95    fn call(&mut self, request: Uri) -> Self::Future {
96        Future {
97            inner: self.inner.call(request),
98            timer: self.timer.clone(),
99            read_rate: self.read_rate,
100            write_rate: self.write_rate,
101        }
102    }
103}
104
/// A connection wrapper that delays the completion of reads and writes so that
/// throughput stays at or below the configured rates.
#[pin_project::pin_project]
pub struct Stream<S> {
    /// The wrapped connection.
    #[pin]
    inner: S,
    /// Timer used to create throttling delays.
    timer: Arc<dyn Timer + Send + Sync>,
    /// Read rate limit in bytes per second; `None` leaves reads unthrottled.
    read_rate: Option<u64>,
    /// Write rate limit in bytes per second; `None` leaves writes unthrottled.
    write_rate: Option<u64>,
    /// Delay still pending for a read that already completed on `inner`,
    /// together with the number of bytes that read produced.
    read_sleep: Option<(usize, Pin<Box<dyn Sleep>>)>,
    /// Delay still pending for a write that already completed on `inner`,
    /// together with the number of bytes that write accepted.
    write_sleep: Option<(usize, Pin<Box<dyn Sleep>>)>,
}
115
116fn call<T, F>(
117    cx: &mut Context<'_>,
118    timer: &T,
119    rate: Option<u64>,
120    sleep: &mut Option<(usize, Pin<Box<dyn Sleep>>)>,
121    mut f: F,
122) -> Poll<Result<usize, io::Error>>
123where
124    T: Timer + ?Sized,
125    F: FnMut(&mut Context<'_>) -> Poll<Result<usize, io::Error>>,
126{
127    loop {
128        if let Some((len, ref mut f)) = *sleep {
129            ready!(f.as_mut().poll(cx));
130            *sleep = None;
131            break Poll::Ready(Ok(len));
132        }
133        let len = ready!(f(cx)?);
134        if let Some(rate) = rate {
135            *sleep = Some((
136                len,
137                timer.sleep(Duration::from_nanos(len as u64 * 1_000_000_000 / rate)),
138            ));
139        } else {
140            break Poll::Ready(Ok(len));
141        }
142    }
143}
144
145impl<S> Read for Stream<S>
146where
147    S: Read,
148{
149    fn poll_read(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        mut buf: ReadBufCursor<'_>,
153    ) -> Poll<Result<(), io::Error>> {
154        let mut this = self.project();
155        let len = ready!(call(
156            cx,
157            this.timer.as_ref(),
158            *this.read_rate,
159            this.read_sleep,
160            |cx| unsafe {
161                let mut buf = ReadBuf::uninit(buf.as_mut());
162                ready!(this.inner.as_mut().poll_read(cx, buf.unfilled())?);
163                Poll::Ready(Ok(buf.filled().len()))
164            }
165        )?);
166        unsafe { buf.advance(len) };
167        Poll::Ready(Ok(()))
168    }
169}
170
171impl<S> Write for Stream<S>
172where
173    S: Write,
174{
175    fn poll_write(
176        self: Pin<&mut Self>,
177        cx: &mut Context<'_>,
178        buf: &[u8],
179    ) -> Poll<Result<usize, io::Error>> {
180        let mut this = self.project();
181        call(
182            cx,
183            this.timer.as_ref(),
184            *this.write_rate,
185            this.write_sleep,
186            |cx| this.inner.as_mut().poll_write(cx, buf),
187        )
188    }
189
190    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
191        let this = self.project();
192        this.inner.poll_flush(cx)
193    }
194
195    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
196        let this = self.project();
197        this.inner.poll_shutdown(cx)
198    }
199}
200
/// Forwards connection metadata from the wrapped connection unchanged.
#[cfg(feature = "hyper-util")]
impl<S> hyper_util::client::legacy::connect::Connection for Stream<S>
where
    S: hyper_util::client::legacy::connect::Connection,
{
    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
        let wrapped = &self.inner;
        wrapped.connected()
    }
}
210
/// Future returned by [`Connector`]'s `Service::call`; resolves to a throttled
/// [`Stream`] once the wrapped connector's future succeeds.
#[pin_project::pin_project]
pub struct Future<F> {
    /// The wrapped connector's connection future.
    #[pin]
    inner: F,
    /// Timer passed on to the resulting stream.
    timer: Arc<dyn Timer + Send + Sync>,
    /// Read rate limit passed on to the resulting stream.
    read_rate: Option<u64>,
    /// Write rate limit passed on to the resulting stream.
    write_rate: Option<u64>,
}
219
220impl<F, T, E> future::Future for Future<F>
221where
222    F: future::Future<Output = Result<T, E>>,
223{
224    type Output = Result<Stream<T>, E>;
225
226    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227        let this = self.project();
228        this.inner.poll(cx).map_ok(|inner| Stream {
229            inner,
230            timer: this.timer.clone(),
231            read_rate: *this.read_rate,
232            write_rate: *this.write_rate,
233            read_sleep: None,
234            write_sleep: None,
235        })
236    }
237}
238
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use http_body_util::{BodyExt, Empty};
    use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
    use hyper_util::client::legacy::connect::HttpConnector;
    use hyper_util::client::legacy::Client;
    use hyper_util::rt::{TokioExecutor, TokioTimer};
    use std::time::{Duration, Instant};

    /// Page fetched by the tests below; they need network access to reach it.
    const URL: &str = "https://www.rust-lang.org";

    /// Builds an HTTPS-only client whose plain connector is wrapped by the
    /// throttling connector with the given optional rate limits.
    fn client(
        read_rate: Option<u64>,
        write_rate: Option<u64>,
    ) -> Client<HttpsConnector<super::Connector<HttpConnector>>, Empty<Bytes>> {
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);
        let mut builder = super::Connector::builder(TokioTimer::new());
        if let Some(rate) = read_rate {
            builder = builder.read_rate(rate);
        }
        if let Some(rate) = write_rate {
            builder = builder.write_rate(rate);
        }
        let connector = builder.build(connector);
        let connector = HttpsConnectorBuilder::new()
            .with_native_roots()
            .unwrap()
            .https_only()
            .enable_http1()
            .enable_http2()
            .wrap_connector(connector);
        Client::builder(TokioExecutor::new()).build(connector)
    }

    /// Fetches [`URL`] with the given rate limits, checks that the request
    /// succeeded, reads the whole body and returns the time the request took.
    async fn fetch(read_rate: Option<u64>, write_rate: Option<u64>) -> Duration {
        let client = client(read_rate, write_rate);
        let started = Instant::now();
        let response = client.get(URL.parse().unwrap()).await.unwrap();
        assert!(response.status().is_success());
        response.into_body().collect().await.unwrap();
        started.elapsed()
    }

    /// With both directions limited to 4096 bytes per second, the fetch is
    /// expected to take more than four seconds.
    #[tokio::test]
    async fn test_throttle() {
        let elapsed = fetch(Some(4096), Some(4096)).await;
        assert!(elapsed > Duration::from_secs(4));
    }

    /// Without any limit, the fetch is expected to finish within two seconds.
    #[tokio::test]
    async fn test_passthrough() {
        let elapsed = fetch(None, None).await;
        assert!(elapsed < Duration::from_secs(2));
    }
}