paging_stream/
lib.rs

//! A utility to simplify consuming paginated data sources as a `futures::Stream`.
//!
//! This crate provides the `Paginated` trait, which you implement for your API client
//! or repository, and the `PagingStream` struct, which wraps your type and yields
//! its items as a stream. Consumers of your API can then work with one continuous
//! stream of data while the underlying pagination logic stays hidden.
//!
//! ## Example
//!
//! ```rust
//! use futures::{StreamExt, future::Future};
//! use paging_stream::{Paginated, PagingStream};
//!
//! // 1. Define a client/repository struct.
//! struct MyApiClient;
//!
//! // 2. Define your types for parameters, items, and errors.
//! struct MyParams {
//!     since: usize,
//!     until: usize,
//!     limit: usize
//! }
//!
//! // 3. Implement the `Paginated` trait for your client.
//! impl Paginated for MyApiClient {
//!     type Params = MyParams;
//!     type Item = usize;
//!     type Error = ();
//!
//!     fn fetch_page(
//!         &self,
//!         params: Self::Params,
//!     ) -> impl Future<Output = Result<(Vec<Self::Item>, Option<Self::Params>), Self::Error>>
//!     + Send
//!     + 'static {
//!         async move {
//!             // Replace with your actual asynchronous data fetching logic.
//!             //
//!             // - `params`: Contains the necessary information to fetch the current page.
//!             // - Return `Ok((items, next_params))` where:
//!             //   - `items`: A `Vec` of fetched items for the current page.
//!             //   - `next_params`: An `Option<Self::Params>`:
//!             //     - `Some(params)`: Contains the parameters needed to fetch the *next* page.
//!             //     - `None`: Signifies that there are no more pages.
//!             // - Return `Err(your_error)` if fetching fails.
//!             Ok((Vec::new(), None)) // Placeholder for example
//!         }
//!     }
//! }
//!
//! async fn consume_as_stream() {
//!     let client = MyApiClient;
//!     let initial_params = MyParams {
//!         since: 0,
//!         until: 100,
//!         limit: 20
//!     };
//!
//!     // 4. Create a `PagingStream`.
//!     let mut stream = PagingStream::new(client, initial_params);
//!
//!     // 5. Consume the stream.
//!     while let Some(result) = stream.next().await {
//!         match result {
//!             Ok(item) => { /* process `item` */ }
//!             Err(e) => { /* handle `e` */ break; }
//!         }
//!     }
//! }
//! ```
71
72use futures::Stream;
73use std::collections::VecDeque;
74use std::pin::Pin;
75use std::task::Poll;
76
/// A data source whose contents are served one page at a time.
///
/// Implement this for an API client or repository so that it can be wrapped in a
/// `PagingStream` and consumed as a `futures::Stream`.
pub trait Paginated {
    /// Parameters that identify the page to request (e.g., page number, cursor, offset).
    type Params: Unpin;
    /// Element type contained in each page.
    type Item: Unpin;
    /// Error produced when a page cannot be fetched.
    type Error;

    /// Asynchronously fetches the single page described by `params`.
    ///
    /// The returned future resolves to one of:
    /// - `Ok((items, next))`:
    ///   - `items` is the `Vec` of items belonging to the requested page.
    ///   - `next` carries the parameters for the following page. `Some(next_params)`
    ///     means there might be more data; `None` marks the end of the data source.
    /// - `Err(Self::Error)` when fetching the page failed.
    ///
    /// The returned `Future` must be `Send + 'static`, so it has to own everything it
    /// needs rather than borrow from `self`.
    fn fetch_page(
        &self,
        params: Self::Params,
    ) -> impl Future<Output = Result<(Vec<Self::Item>, Option<Self::Params>), Self::Error>>
    + Send
    + 'static;
}
103
/// Slot for the page request that is currently in flight, if any.
///
/// `Some` holds the boxed, pinned future of a page fetch that resolves to the page's
/// items (`Item`) plus the optional parameters for the next page (`Params`), or to an
/// error (`Error`). `None` means no request is outstanding.
type MaybeInFlight<Item, Params, Error> = Option<
    Pin<Box<dyn Future<Output = Result<(Vec<Item>, Option<Params>), Error>> + Send + 'static>>,
>;
106
107/// A stream that wraps a `Paginated` type to provide continuous, lazy-loaded data.
108///
109/// `PagingStream` handles the logic of fetching pages, buffering items, and
110/// managing the state of requests. It polls the `Paginated::fetch_page` method
111/// as needed when the stream is consumed.
112///
113/// # Type Parameters
114/// - `T`: The type that implements the `Paginated` trait. It must also be `Unpin`.
115pub struct PagingStream<T>
116where
117    T: Paginated,
118    T: Unpin,
119{
120    client: T,
121    params: Option<T::Params>,
122    buffer: VecDeque<T::Item>,
123    request: MaybeInFlight<T::Item, T::Params, T::Error>,
124}
125
126impl<T> PagingStream<T>
127where
128    T: Paginated,
129    T: Unpin,
130{
131    /// Creates a new `PagingStream`.
132    ///
133    /// # Arguments
134    /// * `client`: An instance of your type that implements `Paginated`.
135    /// * `params`: The initial parameters to fetch the first page.
136    pub fn new(paginated: T, params: T::Params) -> Self {
137        Self {
138            client: paginated,
139            params: Some(params),
140            buffer: VecDeque::new(),
141            request: None,
142        }
143    }
144}
145
146impl<T> Stream for PagingStream<T>
147where
148    T: Paginated,
149    T: Unpin,
150    T::Item: Unpin,
151    T::Params: Unpin,
152{
153    type Item = Result<T::Item, T::Error>;
154
155    fn poll_next(
156        self: Pin<&mut Self>,
157        cx: &mut std::task::Context<'_>,
158    ) -> Poll<Option<Self::Item>> {
159        let slf = self.get_mut();
160
161        loop {
162            // #1: yield results from the buffer until exhaustion
163            if let Some(value) = slf.buffer.pop_front() {
164                return Poll::Ready(Some(Ok(value)));
165            }
166
167            if let Some(mut request) = slf.request.take() {
168                match Pin::as_mut(&mut request).poll(cx) {
169                    Poll::Ready(Ok((values, params))) => {
170                        // #2: assign the returned values if the request was successful
171                        slf.buffer.extend(values);
172                        slf.params = params;
173                        continue;
174                    }
175                    Poll::Ready(Err(err)) => {
176                        // #3: yield the error if the request failed
177                        return Poll::Ready(Some(Err(err)));
178                    }
179                    Poll::Pending => {
180                        // #4: yield pending if the request is pending
181                        slf.request = Some(request);
182                        return Poll::Pending;
183                    }
184                }
185            }
186
187            if let Some(params) = slf.params.take() {
188                // #5: send a new request if:
189                //      1. there are no items in the buffer
190                //      2. there is no pending request
191                //      3. there are params
192                slf.request = Some(Box::pin(slf.client.fetch_page(params)));
193                cx.waker().wake_by_ref();
194                return Poll::Pending;
195            } else {
196                // #6: yield None when there is nothing left to do
197                return Poll::Ready(None);
198            }
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::Once;
    use std::time;

    static INITIALIZE_TRACING: Once = Once::new();

    /// Installs a debug-level tracing subscriber that writes to the test writer.
    ///
    /// The `Once` guard makes it safe to call this from every test.
    fn init_tracing() {
        INITIALIZE_TRACING.call_once(|| {
            tracing_subscriber::fmt()
                .with_test_writer()
                .with_max_level(tracing::Level::DEBUG)
                .init();
        });
    }

    /// Mock paginated source serving the integers below `END_OF_COLLECTION`.
    pub struct Repository;

    /// Page request: items in `since..until`, at most `limit` of them per page.
    pub struct Params {
        since: usize,
        until: usize,
        limit: usize,
    }

    impl Repository {
        /// Derives the parameters of the following page from the last returned item.
        ///
        /// Returns `None` when `results` is empty, which ends the pagination.
        fn get_next_page_params(params: &Params, results: &[usize]) -> Option<Params> {
            results.last().map(|last| Params {
                since: last + 1,
                until: params.until,
                limit: params.limit,
            })
        }
    }

    const END_OF_COLLECTION: usize = 1000;
    // Requests whose `since` lies strictly between these two bounds fail.
    const ERR_RANGE_START: usize = 200;
    const ERR_RANGE_END: usize = 500;

    impl Paginated for Repository {
        type Params = Params;
        type Item = usize;
        type Error = ();

        fn fetch_page(
            &self,
            params: Self::Params,
        ) -> impl Future<Output = Result<(Vec<Self::Item>, Option<Self::Params>), Self::Error>>
        + Send
        + 'static {
            async move {
                tracing::debug!(message="Fetching page", since=?params.since, until=?params.until, limit=?params.limit);

                // Simulate latency so that the request is observed as pending.
                tokio::time::sleep(time::Duration::from_millis(5)).await;

                let mut values = Vec::with_capacity(params.limit);

                // return empty vec if since is larger than the end of the collection
                if params.since > END_OF_COLLECTION {
                    return Ok((values, None));
                }

                // return err if since is in the error range
                if params.since > ERR_RANGE_START && params.since < ERR_RANGE_END {
                    return Err(());
                }

                let requested_until = std::cmp::min(params.since + params.limit, params.until);
                let end_of_page = std::cmp::min(requested_until, END_OF_COLLECTION);

                values.extend(params.since..end_of_page);

                let next_params = Self::get_next_page_params(&params, &values);

                Ok((values, next_params))
            }
        }
    }

    /// Builds a stream over the mock repository for the given request window.
    fn stream_over(since: usize, until: usize, limit: usize) -> PagingStream<Repository> {
        PagingStream::new(
            Repository,
            Params {
                since,
                until,
                limit,
            },
        )
    }

    /// Drains `stream`, asserting that it yields consecutive integers starting at
    /// `first`, and returns the last value seen (`None` if nothing was yielded).
    async fn drain_expecting_sequence(
        stream: &mut PagingStream<Repository>,
        first: usize,
    ) -> Option<usize> {
        let mut expected = first;
        let mut last_seen = None;

        while let Some(value) = stream.next().await {
            assert_eq!(value, Ok(expected));
            last_seen = Some(expected);
            expected += 1;
        }

        last_seen
    }

    #[tokio::test]
    async fn it_streams_up_until() {
        init_tracing();

        let (since, until, limit) = (500, 700, 100);
        let mut stream = stream_over(since, until, limit);

        let last_value = drain_expecting_sequence(&mut stream, since).await;
        assert_eq!(last_value, Some(until - 1));

        // subsequent polls yield None
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn it_terminates_at_the_end_of_the_collection() {
        init_tracing();

        let (since, until, limit) = (900, 1100, 100);
        let mut stream = stream_over(since, until, limit);

        let last_value = drain_expecting_sequence(&mut stream, since).await;
        assert_eq!(last_value, Some(END_OF_COLLECTION - 1));

        // subsequent polls yield None
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn it_streams_multiples_of_limit() {
        init_tracing();

        let (since, until, limit) = (0, 20, 10);
        let mut stream = stream_over(since, until, limit);

        let last_value = drain_expecting_sequence(&mut stream, since).await;
        assert_eq!(last_value, Some(until - 1));

        // subsequent polls yield None
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn it_terminates_if_the_collection_is_empty() {
        init_tracing();

        let mut stream = stream_over(1000, 1001, 1);

        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn it_terminates_if_limit_is_zero() {
        init_tracing();

        let mut stream = stream_over(0, 20, 0);

        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn it_bails_out_on_error() {
        init_tracing();

        let mut stream = stream_over(499, 500, 1);

        // it yields the encountered error
        assert_eq!(stream.next().await, Some(Err(())));

        // it then terminates
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn it_bails_out_on_error_for_a_subsequent_page() {
        init_tracing();

        let mut stream = stream_over(200, 201, 1);

        // it yields valid values from the first page
        assert_eq!(stream.next().await, Some(Ok(200)));

        // it then yields the encountered error
        assert_eq!(stream.next().await, Some(Err(())));

        // it then terminates
        assert_eq!(stream.next().await, None);
    }
}