1use std::pin::Pin;
11
12use futures::future;
13use futures::future::BoxFuture;
14use futures::ready;
15use futures::stream;
16use futures::task::Context;
17use futures::task::Poll;
18use futures::Future;
19use futures::FutureExt;
20use futures::Stream;
21use futures::StreamExt;
22use futures::TryStream;
23use pin_project::pin_project;
24
25#[derive(Clone, Copy, Debug)]
27pub struct BufferedParams {
28 pub weight_limit: u64,
30 pub buffer_size: usize,
32}
33
34#[pin_project]
36pub struct WeightLimitedBufferedStream<'a, S, I> {
37 #[pin]
38 queue: stream::FuturesOrdered<BoxFuture<'a, (I, u64)>>,
39 current_weight: u64,
40 weight_limit: u64,
41 max_buffer_size: usize,
42 #[pin]
43 stream: stream::Fuse<S>,
44}
45
46impl<S, I> WeightLimitedBufferedStream<'_, S, I>
47where
48 S: Stream,
49{
50 pub fn new(params: BufferedParams, stream: S) -> Self {
52 Self {
53 queue: stream::FuturesOrdered::new(),
54 current_weight: 0,
55 weight_limit: params.weight_limit,
56 max_buffer_size: params.buffer_size,
57 stream: stream.fuse(),
58 }
59 }
60}
61
62impl<'a, S, Fut, I: 'a> Stream for WeightLimitedBufferedStream<'a, S, I>
63where
64 S: Stream<Item = (Fut, u64)>,
65 Fut: Future<Output = I> + Send + 'a,
66{
67 type Item = I;
68
69 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70 let mut this = self.project();
71
72 while this.queue.len() < *this.max_buffer_size && this.current_weight < this.weight_limit {
75 let future = match this.stream.as_mut().poll_next(cx) {
76 Poll::Ready(Some((f, weight))) => {
77 *this.current_weight += weight;
78 f.map(move |val| (val, weight)).boxed()
79 }
80 Poll::Ready(None) | Poll::Pending => break,
81 };
82
83 this.queue.push_back(future);
84 }
85
86 if let Some((val, weight)) = ready!(this.queue.poll_next(cx)) {
88 *this.current_weight -= weight;
89 return Poll::Ready(Some(val));
90 }
91
92 if this.stream.is_done() {
96 Poll::Ready(None)
97 } else {
98 Poll::Pending
99 }
100 }
101}
102
103#[pin_project]
106pub struct WeightLimitedBufferedTryStream<'a, S, I, E> {
107 #[pin]
108 queue: stream::FuturesOrdered<BoxFuture<'a, (Result<I, E>, u64)>>,
109 current_weight: u64,
110 weight_limit: u64,
111 max_buffer_size: usize,
112 #[pin]
113 stream: stream::Fuse<S>,
114}
115
116impl<S, I, E> WeightLimitedBufferedTryStream<'_, S, I, E>
117where
118 S: TryStream,
119{
120 pub fn new(params: BufferedParams, stream: S) -> Self {
122 Self {
123 queue: stream::FuturesOrdered::new(),
124 current_weight: 0,
125 weight_limit: params.weight_limit,
126 max_buffer_size: params.buffer_size,
127 stream: stream.fuse(),
128 }
129 }
130}
131
132impl<'a, S, Fut, I: 'a, E> Stream for WeightLimitedBufferedTryStream<'a, S, I, E>
133where
134 S: Stream<Item = Result<(Fut, u64), E>>,
135 Fut: Future<Output = Result<I, E>> + Send + 'a,
136 E: Send + 'a,
137 I: Send,
138{
139 type Item = Result<I, E>;
140
141 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142 let mut this = self.project();
143
144 while this.queue.len() < *this.max_buffer_size && this.current_weight < this.weight_limit {
147 let future = match this.stream.as_mut().poll_next(cx) {
148 Poll::Ready(Some(Ok((f, weight)))) => {
149 *this.current_weight += weight;
150 f.map(move |val| (val, weight)).boxed()
151 }
152 Poll::Ready(Some(Err(e))) => {
153 future::ready((Err(e), 0u64)).boxed()
163 }
164 Poll::Ready(None) | Poll::Pending => break,
165 };
166
167 this.queue.push_back(future);
168 }
169
170 if let Some((val, weight)) = ready!(this.queue.poll_next(cx)) {
172 *this.current_weight -= weight;
173 return Poll::Ready(Some(val));
174 }
175
176 if this.stream.is_done() {
180 Poll::Ready(None)
181 } else {
182 Poll::Pending
183 }
184 }
185}
186
187#[cfg(test)]
188mod test {
189 use std::sync::atomic::AtomicUsize;
190 use std::sync::atomic::Ordering;
191 use std::sync::Arc;
192
193 use futures::future;
194 use futures::future::BoxFuture;
195 use futures::stream;
196 use futures::stream::BoxStream;
197 use futures::FutureExt;
198 use futures::StreamExt;
199
200 use super::*;
201
202 type TestStream = BoxStream<'static, (BoxFuture<'static, ()>, u64)>;
203
204 fn create_stream() -> (Arc<AtomicUsize>, TestStream) {
205 let s: TestStream = stream::iter(vec![
206 (future::ready(()).boxed(), 100),
207 (future::ready(()).boxed(), 2),
208 (future::ready(()).boxed(), 7),
209 ])
210 .boxed();
211
212 let counter = Arc::new(AtomicUsize::new(0));
213
214 (
215 counter.clone(),
216 s.inspect({
217 move |_val| {
218 counter.fetch_add(1, Ordering::SeqCst);
219 }
220 })
221 .boxed(),
222 )
223 }
224
225 #[tokio::test]
226 async fn test_too_much_weight_to_do_in_one_go() {
227 let (counter, s) = create_stream();
228 let params = BufferedParams {
229 weight_limit: 10,
230 buffer_size: 10,
231 };
232 let s = WeightLimitedBufferedStream::new(params, s);
233
234 if let (Some(()), s) = s.into_future().await {
235 assert_eq!(counter.load(Ordering::SeqCst), 1);
236 assert_eq!(s.collect::<Vec<()>>().await.len(), 2);
237 assert_eq!(counter.load(Ordering::SeqCst), 3);
238 } else {
239 panic!("Stream did not produce even a single value");
240 }
241 }
242
243 #[tokio::test]
244 async fn test_all_in_one_go() {
245 let (counter, s) = create_stream();
246 let params = BufferedParams {
247 weight_limit: 200,
248 buffer_size: 10,
249 };
250 let s = WeightLimitedBufferedStream::new(params, s);
251
252 if let (Some(()), s) = s.into_future().await {
253 assert_eq!(counter.load(Ordering::SeqCst), 3);
254 assert_eq!(s.collect::<Vec<()>>().await.len(), 2);
255 assert_eq!(counter.load(Ordering::SeqCst), 3);
256 } else {
257 panic!("Stream did not produce even a single value");
258 }
259 }
260
261 #[tokio::test]
262 async fn test_too_much_items_to_do_in_one_go() {
263 let (counter, s) = create_stream();
264 let params = BufferedParams {
265 weight_limit: 1000,
266 buffer_size: 2,
267 };
268 let s = WeightLimitedBufferedStream::new(params, s);
269
270 if let (Some(()), s) = s.into_future().await {
271 assert_eq!(counter.load(Ordering::SeqCst), 2);
272 assert_eq!(s.collect::<Vec<()>>().await.len(), 2);
273 assert_eq!(counter.load(Ordering::SeqCst), 3);
274 } else {
275 panic!("Stream did not produce even a single value");
276 }
277 }
278
279 type Error = String;
280 type TestTryStream =
281 BoxStream<'static, Result<(BoxFuture<'static, Result<(), Error>>, u64), Error>>;
282
283 fn counted_try_stream(s: TestTryStream) -> (Arc<AtomicUsize>, TestTryStream) {
284 let counter = Arc::new(AtomicUsize::new(0));
285
286 (
287 counter.clone(),
288 s.inspect({
289 move |_val| {
290 counter.fetch_add(1, Ordering::SeqCst);
291 }
292 })
293 .boxed(),
294 )
295 }
296
297 fn create_try_stream_all_good() -> (Arc<AtomicUsize>, TestTryStream) {
298 let s: TestTryStream = stream::iter(vec![
299 Ok((future::ready(Ok(())).boxed(), 100)),
300 Ok((future::ready(Ok(())).boxed(), 2)),
301 Ok((future::ready(Ok(())).boxed(), 7)),
302 ])
303 .boxed();
304
305 counted_try_stream(s)
306 }
307
308 #[tokio::test]
309 async fn test_try_all_in_one_go() {
310 let (counter, s) = create_try_stream_all_good();
311 let params = BufferedParams {
312 weight_limit: 200,
313 buffer_size: 10,
314 };
315 let s = WeightLimitedBufferedTryStream::new(params, s);
316
317 if let (Some(Ok(())), s) = s.into_future().await {
318 assert_eq!(counter.load(Ordering::SeqCst), 3);
319 assert_eq!(s.collect::<Vec<_>>().await.len(), 2);
320 assert_eq!(counter.load(Ordering::SeqCst), 3);
321 } else {
322 panic!("Stream did not produce even a single value");
323 }
324 }
325
326 #[tokio::test]
327 async fn test_try_too_much_weight_to_do_in_one_go() {
328 let (counter, s) = create_try_stream_all_good();
329 let params = BufferedParams {
330 weight_limit: 10,
331 buffer_size: 10,
332 };
333 let s = WeightLimitedBufferedTryStream::new(params, s);
334
335 if let (Some(Ok(())), s) = s.into_future().await {
336 assert_eq!(counter.load(Ordering::SeqCst), 1);
337 assert_eq!(s.collect::<Vec<_>>().await.len(), 2);
338 assert_eq!(counter.load(Ordering::SeqCst), 3);
339 } else {
340 panic!("Stream did not produce even a single value");
341 }
342 }
343
344 #[tokio::test]
345 async fn test_try_too_much_items_to_do_in_one_go() {
346 let (counter, s) = create_try_stream_all_good();
347 let params = BufferedParams {
348 weight_limit: 1000,
349 buffer_size: 2,
350 };
351 let s = WeightLimitedBufferedTryStream::new(params, s);
352
353 if let (Some(Ok(())), s) = s.into_future().await {
354 assert_eq!(counter.load(Ordering::SeqCst), 2);
355 assert_eq!(s.collect::<Vec<_>>().await.len(), 2);
356 assert_eq!(counter.load(Ordering::SeqCst), 3);
357 } else {
358 panic!("Stream did not produce even a single value");
359 }
360 }
361
362 fn create_try_stream_fail_external() -> (Arc<AtomicUsize>, TestTryStream) {
363 let s: TestTryStream = stream::iter(vec![
364 Ok((future::ready(Ok(())).boxed(), 100)),
365 Err("failed to calculate weight".to_string()),
366 Ok((future::ready(Ok(())).boxed(), 7)),
367 ])
368 .boxed();
369
370 counted_try_stream(s)
371 }
372
373 #[tokio::test]
374 async fn test_try_fail_to_calculate_weight() {
375 let (counter, s) = create_try_stream_fail_external();
376 let params = BufferedParams {
377 weight_limit: 1000,
378 buffer_size: 2,
379 };
380 let s = WeightLimitedBufferedTryStream::new(params, s);
381
382 if let (Some(Ok(())), s) = s.into_future().await {
383 assert_eq!(counter.load(Ordering::SeqCst), 2);
386 let v = s.collect::<Vec<Result<_, _>>>().await;
387 assert!(v[0].is_err());
391 assert!(
392 v[0].clone()
393 .unwrap_err()
394 .contains("failed to calculate weight")
395 );
396 assert_eq!(v[1], Ok(()));
399 assert_eq!(v.len(), 2);
400 assert_eq!(counter.load(Ordering::SeqCst), 3);
403 } else {
404 panic!("Stream did not produce even a single value");
405 }
406 }
407
408 fn create_try_stream_fail_internal() -> (Arc<AtomicUsize>, TestTryStream) {
409 let s: TestTryStream = stream::iter(vec![
410 Ok((future::ready(Ok(())).boxed(), 100)),
411 Ok((
412 future::ready(Err("failed to produce interesting value".to_string())).boxed(),
413 2,
414 )),
415 Ok((future::ready(Ok(())).boxed(), 7)),
416 ])
417 .boxed();
418
419 counted_try_stream(s)
420 }
421
422 #[tokio::test]
423 async fn test_try_fail_to_calculate_inner_value() {
424 let (counter, s) = create_try_stream_fail_internal();
425 let params = BufferedParams {
426 weight_limit: 1000,
427 buffer_size: 2,
428 };
429 let s = WeightLimitedBufferedTryStream::new(params, s);
430
431 if let (Some(Ok(())), s) = s.into_future().await {
432 assert_eq!(counter.load(Ordering::SeqCst), 2);
435 let v = s.collect::<Vec<Result<_, _>>>().await;
436 assert!(v[0].is_err());
439 assert!(
440 v[0].clone()
441 .unwrap_err()
442 .contains("failed to produce interesting value")
443 );
444 assert_eq!(v[1], Ok(()));
447 assert_eq!(v.len(), 2);
448 assert_eq!(counter.load(Ordering::SeqCst), 3);
451 } else {
452 panic!("Stream did not produce even a single value");
453 }
454 }
455}