1use crate::{clock::Clock, limiter::Resource};
4use std::{
5 io::{self, IoSlice, IoSliceMut},
6 pin::Pin,
7 task::{Context, Poll},
8};
9
/// Length callback for `poll_limited` on operations returning `io::Result<usize>`.
///
/// Returns the byte count carried by a successful result and `0` for an error.
/// The second argument (the per-operation state handed to `poll_limited`) is
/// not consulted.
fn length_of_result_usize<B>(result: &io::Result<usize>, _: &B) -> usize {
    match result {
        Ok(len) => *len,
        Err(_) => 0,
    }
}
13
14impl<R: futures_io::AsyncRead, C: Clock> futures_io::AsyncRead for Resource<R, C> {
15 fn poll_read(
16 self: Pin<&mut Self>,
17 cx: &mut Context<'_>,
18 buf: &mut [u8],
19 ) -> Poll<io::Result<usize>> {
20 self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
21 r.poll_read(cx, buf)
22 })
23 }
24
25 fn poll_read_vectored(
26 self: Pin<&mut Self>,
27 cx: &mut Context<'_>,
28 bufs: &mut [IoSliceMut<'_>],
29 ) -> Poll<io::Result<usize>> {
30 self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
31 r.poll_read_vectored(cx, bufs)
32 })
33 }
34}
35
36impl<R: futures_io::AsyncWrite, C: Clock> futures_io::AsyncWrite for Resource<R, C> {
37 fn poll_write(
38 self: Pin<&mut Self>,
39 cx: &mut Context<'_>,
40 buf: &[u8],
41 ) -> Poll<io::Result<usize>> {
42 self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
43 r.poll_write(cx, buf)
44 })
45 }
46
47 fn poll_write_vectored(
48 self: Pin<&mut Self>,
49 cx: &mut Context<'_>,
50 bufs: &[IoSlice<'_>],
51 ) -> Poll<io::Result<usize>> {
52 self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
53 r.poll_write_vectored(cx, bufs)
54 })
55 }
56
57 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
58 self.get_pin_mut().poll_flush(cx)
59 }
60
61 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
62 self.get_pin_mut().poll_close(cx)
63 }
64}
65
#[cfg(feature = "tokio")]
impl<R: tokio::io::AsyncRead, C: Clock> tokio::io::AsyncRead for Resource<R, C> {
    /// Reads into `buf` through `poll_limited`.
    ///
    /// The byte count reported to `poll_limited` is how much the filled region
    /// of `buf` grew during this call. A failed read reports `0`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // `buf` may already hold data from earlier reads, so remember where the
        // filled region ends before handing it to the inner reader.
        let filled_before = buf.filled().len();

        self.poll_limited(
            cx,
            buf,
            |res: &io::Result<()>, buf| {
                if res.is_ok() {
                    // `ReadBuf::set_filled` allows an inner reader to shrink the
                    // filled region. Saturate instead of underflowing, which would
                    // panic in debug builds and report a huge length in release.
                    buf.filled().len().saturating_sub(filled_before)
                } else {
                    0
                }
            },
            |r, cx, buf| r.poll_read(cx, buf),
        )
    }
}
87
#[cfg(feature = "tokio")]
impl<R: tokio::io::AsyncWrite, C: Clock> tokio::io::AsyncWrite for Resource<R, C> {
    /// Writes `buf` through `poll_limited` using the inner writer's `poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
            r.poll_write(cx, buf)
        })
    }

    /// Gather-writes `bufs` through `poll_limited` using the inner writer's
    /// `poll_write_vectored`.
    ///
    /// Without this override, tokio's default would write only the first
    /// non-empty slice via `poll_write`, hiding the inner writer's vectored
    /// support. This mirrors the `futures_io::AsyncWrite` impl.
    // NOTE(review): `is_write_vectored` is not overridden and keeps tokio's
    // default (`false`). Forwarding it needs shared access to the inner writer.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.poll_limited(cx, (), length_of_result_usize, |r, cx, ()| {
            r.poll_write_vectored(cx, bufs)
        })
    }

    /// Forwards to the inner writer's `poll_flush` without going through the limiter.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_flush(cx)
    }

    /// Forwards to the inner writer's `poll_shutdown` without going through the limiter.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_shutdown(cx)
    }
}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        clock::{ManualClock, Nanoseconds},
        Limiter,
    };
    use futures_executor::LocalPool;
    use futures_util::{
        io::{copy_buf, BufReader},
        task::SpawnExt,
    };
    use rand::{thread_rng, RngCore};

    /// Copies 1024 random bytes through a limited reader (limiter built with
    /// `512.0`, 256-byte buffered reads). The test advances a manual clock in
    /// 0.5 s steps and checks that 256 more bytes are admitted at each step.
    #[test]
    fn limited_read() {
        let mut pool = LocalPool::new();
        let sp = pool.spawner();

        let limiter = Limiter::<ManualClock>::new(512.0);
        let clock = limiter.clock();

        sp.spawn({
            let limiter = limiter.clone();
            let clock = clock.clone();
            async move {
                let mut src = vec![0_u8; 1024];
                thread_rng().fill_bytes(&mut src);
                let mut dst = Vec::new();

                let read = BufReader::with_capacity(256, limiter.limit(&*src));
                let count = copy_buf(read, &mut dst).await.unwrap();

                assert_eq!(clock.now(), Nanoseconds(2_000_000_000));
                assert_eq!(count, src.len() as u64);
                assert!(src == dst);
            }
        })
        .unwrap();

        clock.set_time(Nanoseconds(0));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 256);

        clock.set_time(Nanoseconds(500_000_000));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 512);

        clock.set_time(Nanoseconds(1_000_000_000));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 768);

        clock.set_time(Nanoseconds(1_500_000_000));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 1024);

        clock.set_time(Nanoseconds(2_000_000_000));
        pool.run_until_stalled();

        // The copy task must have finished; nothing is left to run.
        assert!(!pool.try_run_one());
    }

    /// Copies 1024 random bytes through a reader limited with an infinite rate.
    /// The copy is expected to finish in one `run_until_stalled` pass without
    /// advancing the clock.
    #[test]
    fn unlimited_read() {
        let mut pool = LocalPool::new();
        let sp = pool.spawner();

        let limiter = Limiter::<ManualClock>::new(f64::INFINITY);

        sp.spawn(async move {
            let mut src = vec![0_u8; 1024];
            thread_rng().fill_bytes(&mut src);
            let mut dst = Vec::new();

            let read = BufReader::with_capacity(256, limiter.limit(&*src));
            let count = copy_buf(read, &mut dst).await.unwrap();

            assert_eq!(count, src.len() as u64);
            assert!(src == dst);
        })
        .unwrap();

        pool.run_until_stalled();
        assert!(!pool.try_run_one());
    }

    /// Copies 1024 random bytes from a 256-byte buffered reader into a limited
    /// `Vec` writer (limiter built with `512.0`). The test advances a manual
    /// clock in 0.5 s steps and checks the admitted byte count at each step.
    #[test]
    fn limited_write() {
        let mut pool = LocalPool::new();
        let sp = pool.spawner();

        let limiter = Limiter::<ManualClock>::new(512.0);
        let clock = limiter.clock();

        sp.spawn({
            let limiter = limiter.clone();
            let clock = clock.clone();
            async move {
                let mut src = vec![0_u8; 1024];
                thread_rng().fill_bytes(&mut src);

                let read = BufReader::with_capacity(256, &*src);
                let mut write = limiter.limit(Vec::new());
                let count = copy_buf(read, &mut write).await.unwrap();

                assert_eq!(clock.now(), Nanoseconds(1_500_000_000));
                assert_eq!(count, src.len() as u64);
                assert!(src == write.into_inner());
            }
        })
        .unwrap();

        clock.set_time(Nanoseconds(0));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 256);

        clock.set_time(Nanoseconds(500_000_000));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 512);

        clock.set_time(Nanoseconds(1_000_000_000));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 768);

        clock.set_time(Nanoseconds(1_500_000_000));
        pool.run_until_stalled();
        assert_eq!(limiter.total_bytes_consumed(), 1024);

        clock.set_time(Nanoseconds(2_000_000_000));
        pool.run_until_stalled();

        // The copy task must have finished; nothing is left to run.
        assert!(!pool.try_run_one());
    }
}
245
#[cfg(test)]
#[cfg(all(feature = "tokio", feature = "standard-clock"))]
mod tokio_tests {
    use std::{
        io,
        time::{Duration, Instant},
    };

    use crate::Limiter;

    use futures_util::{future::ok, TryStreamExt as _};
    use tokio::io::{copy, repeat, sink, AsyncReadExt as _, AsyncWriteExt as _};
    use tokio_util::codec::{BytesCodec, FramedRead};

    /// Copies 65536 bytes through a reader limited with `32768.0` and expects
    /// the copy to take about 2 s of wall-clock time (1.9 s to 2.1 s).
    #[tokio::test]
    async fn limited_read() -> io::Result<()> {
        let limiter = <Limiter>::new(32768.0);

        let start_time = Instant::now();

        let reader = repeat(50).take(65536);
        let mut reader = limiter.limit(reader);

        let mut sink = sink();
        let total = copy(&mut reader, &mut sink).await?;
        sink.shutdown().await?;

        let elapsed = start_time.elapsed();

        // NOTE(review): wall-clock window; may be flaky on heavily loaded CI hosts.
        assert!(
            Duration::from_millis(1900) <= elapsed && elapsed <= Duration::from_millis(2100),
            "elapsed = {:?}",
            elapsed
        );
        assert_eq!(total, 65536);

        Ok(())
    }

    /// Copies 65536 bytes through a reader limited with an infinite rate and
    /// expects the copy to finish within 100 ms.
    #[tokio::test]
    async fn unlimited_read() -> io::Result<()> {
        let limiter = <Limiter>::new(f64::INFINITY);

        let start_time = Instant::now();

        let reader = repeat(50).take(65536);
        let mut reader = limiter.limit(reader);
        let mut sink = sink();

        let total = copy(&mut reader, &mut sink).await?;
        sink.shutdown().await?;

        let elapsed = start_time.elapsed();

        assert!(
            elapsed <= Duration::from_millis(100),
            "elapsed = {:?}",
            elapsed
        );
        assert_eq!(total, 65536);

        Ok(())
    }

    /// Streams 60000 bytes from a limited reader (`30000.0`) through a
    /// `FramedRead`/`BytesCodec` pipeline. It checks that every frame contains
    /// only the repeated byte and that the whole stream takes about 2 s.
    #[tokio::test]
    async fn limited_read_byte_stream() -> io::Result<()> {
        let limiter = <Limiter>::new(30000.0);

        let start_time = Instant::now();

        let reader = repeat(50).take(60000);
        let reader = limiter.limit(reader);

        let total = FramedRead::new(reader, BytesCodec::new())
            .try_fold(0, |i, j| {
                assert!(j.iter().all(|b| *b == 50_u8), "{} / {:?}", i, j);
                ok(i + j.len())
            })
            .await?;

        let elapsed = start_time.elapsed();

        // NOTE(review): wall-clock window; may be flaky on heavily loaded CI hosts.
        assert!(
            Duration::from_millis(1900) <= elapsed && elapsed <= Duration::from_millis(2100),
            "elapsed = {:?}",
            elapsed
        );
        assert_eq!(total, 60000);

        Ok(())
    }
}