async_speed_limit/
io.rs

1// Copyright 2019 TiKV Project Authors. Licensed under MIT or Apache-2.0.
2
3use crate::{clock::Clock, limiter::Resource};
4use std::{
5    io::{self, IoSlice, IoSliceMut},
6    pin::Pin,
7    task::{Context, Poll},
8};
9
/// Extracts the byte count from an I/O result for rate-limit accounting.
///
/// A successful result yields the number of bytes it carries; any error is
/// counted as zero bytes. The second argument is ignored and exists only so
/// this function can be passed directly as the length callback of
/// `poll_limited`.
fn length_of_result_usize<B>(result: &io::Result<usize>, _: &B) -> usize {
    match result {
        Ok(len) => *len,
        Err(_) => 0,
    }
}
13
14impl<R: futures_io::AsyncRead, C: Clock> futures_io::AsyncRead for Resource<R, C> {
15    fn poll_read(
16        self: Pin<&mut Self>,
17        cx: &mut Context<'_>,
18        buf: &mut [u8],
19    ) -> Poll<io::Result<usize>> {
20        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
21            r.poll_read(cx, buf)
22        })
23    }
24
25    fn poll_read_vectored(
26        self: Pin<&mut Self>,
27        cx: &mut Context<'_>,
28        bufs: &mut [IoSliceMut<'_>],
29    ) -> Poll<io::Result<usize>> {
30        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
31            r.poll_read_vectored(cx, bufs)
32        })
33    }
34}
35
36impl<R: futures_io::AsyncWrite, C: Clock> futures_io::AsyncWrite for Resource<R, C> {
37    fn poll_write(
38        self: Pin<&mut Self>,
39        cx: &mut Context<'_>,
40        buf: &[u8],
41    ) -> Poll<io::Result<usize>> {
42        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
43            r.poll_write(cx, buf)
44        })
45    }
46
47    fn poll_write_vectored(
48        self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50        bufs: &[IoSlice<'_>],
51    ) -> Poll<io::Result<usize>> {
52        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
53            r.poll_write_vectored(cx, bufs)
54        })
55    }
56
57    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
58        self.get_pin_mut().poll_flush(cx)
59    }
60
61    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
62        self.get_pin_mut().poll_close(cx)
63    }
64}
65
#[cfg(feature = "tokio")]
impl<R: tokio::io::AsyncRead, C: Clock> tokio::io::AsyncRead for Resource<R, C> {
    /// Reads into `buf` from the wrapped reader through `poll_limited`.
    ///
    /// Only the bytes appended to `buf.filled()` by this call are charged to
    /// the limiter; a failed read is charged nothing.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Snapshot the already-filled length so pre-existing data in `buf`
        // is not charged again.
        let filled_before = buf.filled().len();

        self.poll_limited(
            cx,
            buf,
            |res: &io::Result<()>, buf| {
                if res.is_ok() {
                    // `ReadBuf::set_filled` lets a misbehaving inner reader
                    // shrink the filled region. Plain subtraction would then
                    // panic (debug) or wrap to a huge count that stalls the
                    // limiter (release), so clamp at zero instead.
                    buf.filled().len().saturating_sub(filled_before)
                } else {
                    0
                }
            },
            |r, cx, buf| r.poll_read(cx, buf),
        )
    }
}
87
#[cfg(feature = "tokio")]
impl<R: tokio::io::AsyncWrite, C: Clock> tokio::io::AsyncWrite for Resource<R, C> {
    /// Writes `buf` to the wrapped writer through `poll_limited`, charging the
    /// limiter for the byte count a successful write reports (zero on error).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
            r.poll_write(cx, buf)
        })
    }

    /// Forwards vectored writes to the wrapped writer under the same
    /// accounting as `poll_write`.
    ///
    /// Without this override, tokio's default implementation would submit
    /// only the first non-empty slice, hiding any vectored-write support of
    /// the inner writer.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
            r.poll_write_vectored(cx, bufs)
        })
    }

    /// Reports the wrapped writer's vectored-write capability, so callers that
    /// consult this flag see the same answer as for the unwrapped writer.
    fn is_write_vectored(&self) -> bool {
        self.get_ref().is_write_vectored()
    }

    /// Flushes the wrapped writer directly, without going through the limiter.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_flush(cx)
    }

    /// Shuts down the wrapped writer directly, without going through the
    /// limiter.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_shutdown(cx)
    }
}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        clock::{ManualClock, Nanoseconds},
        Limiter,
    };
    use futures_executor::LocalPool;
    use futures_util::{
        io::{copy_buf, BufReader},
        task::SpawnExt,
    };
    use rand::{thread_rng, RngCore};

    /// Reading 1024 random bytes through a 512 B/s limiter in 256-byte
    /// chunks, driven by a manual clock.
    #[test]
    fn limited_read() {
        let mut pool = LocalPool::new();
        let sp = pool.spawner();

        let limiter = Limiter::<ManualClock>::new(512.0);
        let clock = limiter.clock();

        sp.spawn({
            let limiter = limiter.clone();
            let clock = clock.clone();
            async move {
                let mut src = vec![0_u8; 1024];
                thread_rng().fill_bytes(&mut src);
                let mut dst = Vec::new();

                let read = BufReader::with_capacity(256, limiter.limit(&*src));
                let count = copy_buf(read, &mut dst).await.unwrap();

                assert_eq!(clock.now(), Nanoseconds(2_000_000_000));
                assert_eq!(count, src.len() as u64);
                assert!(src == dst);
            }
        })
        .unwrap();

        // Advance the manual clock in 0.5 s steps; after each step one more
        // 256-byte chunk is expected to have been consumed.
        for &(now, consumed) in &[
            (0, 256),
            (500_000_000, 512),
            (1_000_000_000, 768),
            (1_500_000_000, 1024),
        ] {
            clock.set_time(Nanoseconds(now));
            pool.run_until_stalled();
            assert_eq!(limiter.total_bytes_consumed(), consumed, "at {} ns", now);
        }

        clock.set_time(Nanoseconds(2_000_000_000));
        pool.run_until_stalled();

        assert!(!pool.try_run_one());
    }

    /// An infinite rate must let the whole copy finish without advancing the
    /// manual clock.
    #[test]
    fn unlimited_read() {
        let mut pool = LocalPool::new();
        let sp = pool.spawner();

        let limiter = Limiter::<ManualClock>::new(f64::INFINITY);

        sp.spawn(async move {
            let mut src = vec![0_u8; 1024];
            thread_rng().fill_bytes(&mut src);
            let mut dst = Vec::new();

            let read = BufReader::with_capacity(256, limiter.limit(&*src));
            let count = copy_buf(read, &mut dst).await.unwrap();

            assert_eq!(count, src.len() as u64);
            assert!(src == dst);
        })
        .unwrap();

        pool.run_until_stalled();
        assert!(!pool.try_run_one());
    }

    /// Writing 1024 random bytes (fed in 256-byte chunks) into a limited
    /// 512 B/s writer, driven by a manual clock.
    #[test]
    fn limited_write() {
        let mut pool = LocalPool::new();
        let sp = pool.spawner();

        let limiter = Limiter::<ManualClock>::new(512.0);
        let clock = limiter.clock();

        sp.spawn({
            let limiter = limiter.clone();
            let clock = clock.clone();
            async move {
                let mut src = vec![0_u8; 1024];
                thread_rng().fill_bytes(&mut src);

                let read = BufReader::with_capacity(256, &*src);
                let mut write = limiter.limit(Vec::new());
                let count = copy_buf(read, &mut write).await.unwrap();

                assert_eq!(clock.now(), Nanoseconds(1_500_000_000));
                assert_eq!(count, src.len() as u64);
                assert!(src == write.into_inner());
            }
        })
        .unwrap();

        // Advance the manual clock in 0.5 s steps; after each step one more
        // 256-byte chunk is expected to have been consumed.
        for &(now, consumed) in &[
            (0, 256),
            (500_000_000, 512),
            (1_000_000_000, 768),
            (1_500_000_000, 1024),
        ] {
            clock.set_time(Nanoseconds(now));
            pool.run_until_stalled();
            assert_eq!(limiter.total_bytes_consumed(), consumed, "at {} ns", now);
        }

        clock.set_time(Nanoseconds(2_000_000_000));
        pool.run_until_stalled();

        assert!(!pool.try_run_one());
    }
}
245
#[cfg(test)]
#[cfg(all(feature = "tokio", feature = "standard-clock"))]
mod tokio_tests {
    use std::{
        io,
        time::{Duration, Instant},
    };

    use crate::Limiter;

    use futures_util::{future::ok, TryStreamExt as _};
    use tokio::io::{copy, repeat, sink, AsyncReadExt as _, AsyncWriteExt as _};
    use tokio_util::codec::{BytesCodec, FramedRead};

    /// Wall-clock window (measured with `Instant`) accepted for a transfer
    /// sized to take about 2 s; the ±100 ms margin absorbs scheduling jitter.
    fn about_two_seconds() -> std::ops::RangeInclusive<Duration> {
        Duration::from_millis(1900)..=Duration::from_millis(2100)
    }

    /// 65536 bytes through a 32768 B/s limiter should take about 2 s.
    #[tokio::test]
    async fn limited_read() -> io::Result<()> {
        let limiter = <Limiter>::new(32768.0);

        let start_time = Instant::now();

        let reader = repeat(50).take(65536);
        let mut reader = limiter.limit(reader);

        let mut sink = sink();
        let total = copy(&mut reader, &mut sink).await?;
        sink.shutdown().await?;

        let elapsed = start_time.elapsed();

        assert!(
            about_two_seconds().contains(&elapsed),
            "elapsed = {:?}",
            elapsed
        );
        assert_eq!(total, 65536);

        Ok(())
    }

    /// With an infinite rate the same copy should finish almost immediately.
    #[tokio::test]
    async fn unlimited_read() -> io::Result<()> {
        let limiter = <Limiter>::new(f64::INFINITY);

        let start_time = Instant::now();

        let reader = repeat(50).take(65536);
        let mut reader = limiter.limit(reader);
        let mut sink = sink();

        let total = copy(&mut reader, &mut sink).await?;
        sink.shutdown().await?;

        let elapsed = start_time.elapsed();

        assert!(
            elapsed <= Duration::from_millis(100),
            "elapsed = {:?}",
            elapsed
        );
        assert_eq!(total, 65536);

        Ok(())
    }

    /// 60000 bytes read as a framed byte stream through a 30000 B/s limiter
    /// should take about 2 s, and every byte must arrive unchanged.
    #[tokio::test]
    async fn limited_read_byte_stream() -> io::Result<()> {
        let limiter = <Limiter>::new(30000.0);

        let start_time = Instant::now();

        let reader = repeat(50).take(60000);
        let reader = limiter.limit(reader);

        let total = FramedRead::new(reader, BytesCodec::new())
            .try_fold(0, |i, j| {
                assert!(j.iter().all(|b| *b == 50_u8), "{} / {:?}", i, j);
                ok(i + j.len())
            })
            .await?;

        let elapsed = start_time.elapsed();

        assert!(
            about_two_seconds().contains(&elapsed),
            "elapsed = {:?}",
            elapsed
        );
        assert_eq!(total, 60000);

        Ok(())
    }
}