// futures_ext/stream/weight_limited_buffered_stream.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under both the MIT license found in the
5 * LICENSE-MIT file in the root directory of this source tree and the Apache
6 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
7 * of this source tree.
8 */
9
10use std::pin::Pin;
11
12use futures::future;
13use futures::future::BoxFuture;
14use futures::ready;
15use futures::stream;
16use futures::task::Context;
17use futures::task::Poll;
18use futures::Future;
19use futures::FutureExt;
20use futures::Stream;
21use futures::StreamExt;
22use futures::TryStream;
23use pin_project::pin_project;
24
/// Params for [crate::FbStreamExt::buffered_weight_limited],
/// [WeightLimitedBufferedStream] and [WeightLimitedBufferedTryStream].
///
/// Both limits are checked *before* each new item is pulled from the wrapped
/// stream. A single item whose weight exceeds `weight_limit` is therefore still
/// admitted whenever the buffer is below its limits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BufferedParams {
    /// Limit for the sum of the weights of the futures buffered at once.
    pub weight_limit: u64,
    /// Limit for the number of futures buffered at once.
    pub buffer_size: usize,
}
33
34/// Like [stream::Buffered], but can also limit number of futures in a buffer by "weight".
35#[pin_project]
36pub struct WeightLimitedBufferedStream<'a, S, I> {
37    #[pin]
38    queue: stream::FuturesOrdered<BoxFuture<'a, (I, u64)>>,
39    current_weight: u64,
40    weight_limit: u64,
41    max_buffer_size: usize,
42    #[pin]
43    stream: stream::Fuse<S>,
44}
45
46impl<S, I> WeightLimitedBufferedStream<'_, S, I>
47where
48    S: Stream,
49{
50    /// Create a new instance that will be configured using the `params` provided
51    pub fn new(params: BufferedParams, stream: S) -> Self {
52        Self {
53            queue: stream::FuturesOrdered::new(),
54            current_weight: 0,
55            weight_limit: params.weight_limit,
56            max_buffer_size: params.buffer_size,
57            stream: stream.fuse(),
58        }
59    }
60}
61
62impl<'a, S, Fut, I: 'a> Stream for WeightLimitedBufferedStream<'a, S, I>
63where
64    S: Stream<Item = (Fut, u64)>,
65    Fut: Future<Output = I> + Send + 'a,
66{
67    type Item = I;
68
69    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
70        let mut this = self.project();
71
72        // First up, try to spawn off as many futures as possible by filling up
73        // our slab of futures.
74        while this.queue.len() < *this.max_buffer_size && this.current_weight < this.weight_limit {
75            let future = match this.stream.as_mut().poll_next(cx) {
76                Poll::Ready(Some((f, weight))) => {
77                    *this.current_weight += weight;
78                    f.map(move |val| (val, weight)).boxed()
79                }
80                Poll::Ready(None) | Poll::Pending => break,
81            };
82
83            this.queue.push_back(future);
84        }
85
86        // Try polling a new future
87        if let Some((val, weight)) = ready!(this.queue.poll_next(cx)) {
88            *this.current_weight -= weight;
89            return Poll::Ready(Some(val));
90        }
91
92        // If we've gotten this far, then there are no events for us to process
93        // and nothing was ready, so figure out if we're not done yet or if
94        // we've reached the end.
95        if this.stream.is_done() {
96            Poll::Ready(None)
97        } else {
98            Poll::Pending
99        }
100    }
101}
102
103/// Like [stream::Buffered], but is for TryStream and can also
104/// limit number of futures in a buffer by "weight"
105#[pin_project]
106pub struct WeightLimitedBufferedTryStream<'a, S, I, E> {
107    #[pin]
108    queue: stream::FuturesOrdered<BoxFuture<'a, (Result<I, E>, u64)>>,
109    current_weight: u64,
110    weight_limit: u64,
111    max_buffer_size: usize,
112    #[pin]
113    stream: stream::Fuse<S>,
114}
115
116impl<S, I, E> WeightLimitedBufferedTryStream<'_, S, I, E>
117where
118    S: TryStream,
119{
120    /// Create a new instance that will be configured using the `params` provided
121    pub fn new(params: BufferedParams, stream: S) -> Self {
122        Self {
123            queue: stream::FuturesOrdered::new(),
124            current_weight: 0,
125            weight_limit: params.weight_limit,
126            max_buffer_size: params.buffer_size,
127            stream: stream.fuse(),
128        }
129    }
130}
131
132impl<'a, S, Fut, I: 'a, E> Stream for WeightLimitedBufferedTryStream<'a, S, I, E>
133where
134    S: Stream<Item = Result<(Fut, u64), E>>,
135    Fut: Future<Output = Result<I, E>> + Send + 'a,
136    E: Send + 'a,
137    I: Send,
138{
139    type Item = Result<I, E>;
140
141    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142        let mut this = self.project();
143
144        // First up, try to spawn off as many futures as possible by filling up
145        // our slab of futures.
146        while this.queue.len() < *this.max_buffer_size && this.current_weight < this.weight_limit {
147            let future = match this.stream.as_mut().poll_next(cx) {
148                Poll::Ready(Some(Ok((f, weight)))) => {
149                    *this.current_weight += weight;
150                    f.map(move |val| (val, weight)).boxed()
151                }
152                Poll::Ready(Some(Err(e))) => {
153                    // We failed to even get the weight of the future
154                    // Let's record the failure in the queue instead
155                    // of returning error from the stream now. Otherwise
156                    // the error returned now may actually correspond
157                    // to a future for which we succeeded querying weight.
158                    // Note: this behavior is different from what we had
159                    //       in `WeightLimitedBufferedStream` for Stream 0.1
160                    //       but IMO it's more correct, as the stream can
161                    //       keep returning successes after an error
162                    future::ready((Err(e), 0u64)).boxed()
163                }
164                Poll::Ready(None) | Poll::Pending => break,
165            };
166
167            this.queue.push_back(future);
168        }
169
170        // Try polling a new future
171        if let Some((val, weight)) = ready!(this.queue.poll_next(cx)) {
172            *this.current_weight -= weight;
173            return Poll::Ready(Some(val));
174        }
175
176        // If we've gotten this far, then there are no events for us to process
177        // and nothing was ready, so figure out if we're not done yet or if
178        // we've reached the end.
179        if this.stream.is_done() {
180            Poll::Ready(None)
181        } else {
182            Poll::Pending
183        }
184    }
185}
186
#[cfg(test)]
mod test {
    //! Tests for the weight-limited buffered streams. Each inner stream is
    //! wrapped with `inspect` and a shared counter, so the tests can assert
    //! how many items the adapter had pulled at a given point.

    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;

    use futures::future;
    use futures::future::BoxFuture;
    use futures::stream;
    use futures::stream::BoxStream;
    use futures::FutureExt;
    use futures::StreamExt;

    use super::*;

    /// Inner stream type for the [WeightLimitedBufferedStream] tests:
    /// `(future, weight)` pairs.
    type TestStream = BoxStream<'static, (BoxFuture<'static, ()>, u64)>;

    /// Builds a stream of three already-completed futures with weights 100, 2
    /// and 7. The returned counter is incremented every time an item is
    /// pulled from the stream.
    fn create_stream() -> (Arc<AtomicUsize>, TestStream) {
        let s: TestStream = stream::iter(vec![
            (future::ready(()).boxed(), 100),
            (future::ready(()).boxed(), 2),
            (future::ready(()).boxed(), 7),
        ])
        .boxed();

        let counter = Arc::new(AtomicUsize::new(0));

        (
            counter.clone(),
            s.inspect({
                move |_val| {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .boxed(),
        )
    }

    #[tokio::test]
    async fn test_too_much_weight_to_do_in_one_go() {
        let (counter, s) = create_stream();
        let params = BufferedParams {
            weight_limit: 10,
            buffer_size: 10,
        };
        let s = WeightLimitedBufferedStream::new(params, s);

        if let (Some(()), s) = s.into_future().await {
            // The first item alone (weight 100) exceeds the weight limit, so
            // producing the first value pulled only one item.
            assert_eq!(counter.load(Ordering::SeqCst), 1);
            assert_eq!(s.collect::<Vec<()>>().await.len(), 2);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    #[tokio::test]
    async fn test_all_in_one_go() {
        let (counter, s) = create_stream();
        let params = BufferedParams {
            weight_limit: 200,
            buffer_size: 10,
        };
        let s = WeightLimitedBufferedStream::new(params, s);

        if let (Some(()), s) = s.into_future().await {
            // All three items (total weight 109) fit within both limits, so
            // the whole inner stream was pulled before the first value.
            assert_eq!(counter.load(Ordering::SeqCst), 3);
            assert_eq!(s.collect::<Vec<()>>().await.len(), 2);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    #[tokio::test]
    async fn test_too_much_items_to_do_in_one_go() {
        let (counter, s) = create_stream();
        let params = BufferedParams {
            weight_limit: 1000,
            buffer_size: 2,
        };
        let s = WeightLimitedBufferedStream::new(params, s);

        if let (Some(()), s) = s.into_future().await {
            // The buffer holds at most two futures, so only two items were
            // pulled before the first value.
            assert_eq!(counter.load(Ordering::SeqCst), 2);
            assert_eq!(s.collect::<Vec<()>>().await.len(), 2);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    type Error = String;
    /// Inner stream type for the [WeightLimitedBufferedTryStream] tests.
    type TestTryStream =
        BoxStream<'static, Result<(BoxFuture<'static, Result<(), Error>>, u64), Error>>;

    /// Wraps `s` so that the returned counter is incremented for every item
    /// (`Ok` or `Err`) pulled from it.
    fn counted_try_stream(s: TestTryStream) -> (Arc<AtomicUsize>, TestTryStream) {
        let counter = Arc::new(AtomicUsize::new(0));

        (
            counter.clone(),
            s.inspect({
                move |_val| {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            })
            .boxed(),
        )
    }

    /// Three successful items with weights 100, 2 and 7.
    fn create_try_stream_all_good() -> (Arc<AtomicUsize>, TestTryStream) {
        let s: TestTryStream = stream::iter(vec![
            Ok((future::ready(Ok(())).boxed(), 100)),
            Ok((future::ready(Ok(())).boxed(), 2)),
            Ok((future::ready(Ok(())).boxed(), 7)),
        ])
        .boxed();

        counted_try_stream(s)
    }

    #[tokio::test]
    async fn test_try_all_in_one_go() {
        let (counter, s) = create_try_stream_all_good();
        let params = BufferedParams {
            weight_limit: 200,
            buffer_size: 10,
        };
        let s = WeightLimitedBufferedTryStream::new(params, s);

        if let (Some(Ok(())), s) = s.into_future().await {
            // Everything fits within both limits, so all three items were
            // pulled before the first value.
            assert_eq!(counter.load(Ordering::SeqCst), 3);
            assert_eq!(s.collect::<Vec<_>>().await.len(), 2);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    #[tokio::test]
    async fn test_try_too_much_weight_to_do_in_one_go() {
        let (counter, s) = create_try_stream_all_good();
        let params = BufferedParams {
            weight_limit: 10,
            buffer_size: 10,
        };
        let s = WeightLimitedBufferedTryStream::new(params, s);

        if let (Some(Ok(())), s) = s.into_future().await {
            // The first item alone (weight 100) exceeds the weight limit, so
            // only one item was pulled before the first value.
            assert_eq!(counter.load(Ordering::SeqCst), 1);
            assert_eq!(s.collect::<Vec<_>>().await.len(), 2);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    #[tokio::test]
    async fn test_try_too_much_items_to_do_in_one_go() {
        let (counter, s) = create_try_stream_all_good();
        let params = BufferedParams {
            weight_limit: 1000,
            buffer_size: 2,
        };
        let s = WeightLimitedBufferedTryStream::new(params, s);

        if let (Some(Ok(())), s) = s.into_future().await {
            // The buffer holds at most two futures, so only two items were
            // pulled before the first value.
            assert_eq!(counter.load(Ordering::SeqCst), 2);
            assert_eq!(s.collect::<Vec<_>>().await.len(), 2);
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    /// The second item is an `Err` produced by the inner stream itself, i.e.
    /// no future (and no weight) could be obtained for it.
    fn create_try_stream_fail_external() -> (Arc<AtomicUsize>, TestTryStream) {
        let s: TestTryStream = stream::iter(vec![
            Ok((future::ready(Ok(())).boxed(), 100)),
            Err("failed to calculate weight".to_string()),
            Ok((future::ready(Ok(())).boxed(), 7)),
        ])
        .boxed();

        counted_try_stream(s)
    }

    #[tokio::test]
    async fn test_try_fail_to_calculate_weight() {
        let (counter, s) = create_try_stream_fail_external();
        let params = BufferedParams {
            weight_limit: 1000,
            buffer_size: 2,
        };
        let s = WeightLimitedBufferedTryStream::new(params, s);

        if let (Some(Ok(())), s) = s.into_future().await {
            // Producing the very first value caused the buffer
            // to be filled with 2 futures
            assert_eq!(counter.load(Ordering::SeqCst), 2);
            let v = s.collect::<Vec<Result<_, _>>>().await;
            // The second element of the resulting stream is an
            // error, since we could not even calculate its
            // weight and get its future
            assert!(v[0].is_err());
            assert!(
                v[0].clone()
                    .unwrap_err()
                    .contains("failed to calculate weight")
            );
            // The third element of the resulting stream was
            // successfully produced
            assert_eq!(v[1], Ok(()));
            assert_eq!(v.len(), 2);
            // Collecting the whole resulting stream caused
            // 3 elements of the inner stream to be polled
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }

    /// The second item's future itself resolves to an `Err`.
    fn create_try_stream_fail_internal() -> (Arc<AtomicUsize>, TestTryStream) {
        let s: TestTryStream = stream::iter(vec![
            Ok((future::ready(Ok(())).boxed(), 100)),
            Ok((
                future::ready(Err("failed to produce interesting value".to_string())).boxed(),
                2,
            )),
            Ok((future::ready(Ok(())).boxed(), 7)),
        ])
        .boxed();

        counted_try_stream(s)
    }

    #[tokio::test]
    async fn test_try_fail_to_calculate_inner_value() {
        let (counter, s) = create_try_stream_fail_internal();
        let params = BufferedParams {
            weight_limit: 1000,
            buffer_size: 2,
        };
        let s = WeightLimitedBufferedTryStream::new(params, s);

        if let (Some(Ok(())), s) = s.into_future().await {
            // Producing the very first value caused the buffer
            // to be filled with 2 futures
            assert_eq!(counter.load(Ordering::SeqCst), 2);
            let v = s.collect::<Vec<Result<_, _>>>().await;
            // The second element of the resulting stream is an
            // error
            assert!(v[0].is_err());
            assert!(
                v[0].clone()
                    .unwrap_err()
                    .contains("failed to produce interesting value")
            );
            // The third element of the resulting stream was
            // successfully produced
            assert_eq!(v[1], Ok(()));
            assert_eq!(v.len(), 2);
            // Collecting the whole resulting stream caused
            // 3 elements of the inner stream to be polled
            assert_eq!(counter.load(Ordering::SeqCst), 3);
        } else {
            panic!("Stream did not produce even a single value");
        }
    }
}