1use futures_core::Stream;
40use modkit_odata::{ODataQuery, Page};
41use pin_project_lite::pin_project;
42use std::collections::VecDeque;
43use std::fmt;
44use std::future::Future;
45use std::pin::Pin;
46use std::task::{Context, Poll};
47
/// Failure reported by a pager stream.
///
/// `E` is the error type produced by the caller-supplied fetcher.
#[derive(Debug)]
pub enum PagerError<E> {
    /// The fetcher returned an error; the payload is that error, unchanged.
    Fetch(E),
    /// A cursor string could not be decoded; the payload is the offending string.
    InvalidCursor(String),
}

impl<E> fmt::Display for PagerError<E>
where
    E: fmt::Display,
{
    /// Renders `Fetch error: <inner>` or `Invalid cursor: <cursor>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PagerError::InvalidCursor(cursor) => write!(f, "Invalid cursor: {cursor}"),
            PagerError::Fetch(e) => write!(f, "Fetch error: {e}"),
        }
    }
}

impl<E> std::error::Error for PagerError<E>
where
    E: std::error::Error + 'static,
{
    /// Exposes the fetcher error as the underlying cause; an invalid cursor has none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PagerError::InvalidCursor(_) => None,
            PagerError::Fetch(inner) => Some(inner),
        }
    }
}
77
pin_project! {
    /// Stream adapter that yields individual items of type `T`, fetching pages on demand.
    ///
    /// Each page is obtained by calling `fetcher` with a query cloned from `base_query`;
    /// once a page has reported a next cursor, that cursor is decoded and attached to the
    /// query used for the following call.
    pub struct CursorPager<T, E, F, Fut>
    where
        F: FnMut(ODataQuery) -> Fut,
        Fut: Future<Output = Result<Page<T>, E>>,
    {
        // Query template cloned for every fetch.
        base_query: ODataQuery,
        // Encoded cursor reported by the most recently fetched page, if any.
        next_cursor: Option<String>,
        // Items already fetched but not yet handed to the consumer.
        buffer: VecDeque<T>,
        // Set once no further fetch will be issued (last page seen, or an error was emitted).
        done: bool,
        // Caller-supplied function that produces the future for one page request.
        fetcher: F,
        // In-flight page request, if any; pinned because it is polled in place.
        #[pin]
        current_fetch: Option<Fut>,
    }
}
104
105impl<T, E, F, Fut> CursorPager<T, E, F, Fut>
106where
107 F: FnMut(ODataQuery) -> Fut,
108 Fut: Future<Output = Result<Page<T>, E>>,
109{
110 pub fn new(base_query: ODataQuery, fetcher: F) -> Self {
125 Self {
126 base_query,
127 next_cursor: None,
128 buffer: VecDeque::new(),
129 done: false,
130 fetcher,
131 current_fetch: None,
132 }
133 }
134}
135
136impl<T, E, F, Fut> Stream for CursorPager<T, E, F, Fut>
137where
138 F: FnMut(ODataQuery) -> Fut,
139 Fut: Future<Output = Result<Page<T>, E>>,
140{
141 type Item = Result<T, PagerError<E>>;
142
143 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
144 let mut this = self.project();
145
146 loop {
147 if let Some(item) = this.buffer.pop_front() {
148 return Poll::Ready(Some(Ok(item)));
149 }
150
151 if *this.done {
152 return Poll::Ready(None);
153 }
154
155 if let Some(fut) = this.current_fetch.as_mut().as_pin_mut() {
156 match fut.poll(cx) {
157 Poll::Ready(Ok(page)) => {
158 this.current_fetch.set(None);
159
160 this.next_cursor.clone_from(&page.page_info.next_cursor);
161
162 if this.next_cursor.is_none() {
163 *this.done = true;
164 }
165
166 this.buffer.extend(page.items);
167
168 continue;
169 }
170 Poll::Ready(Err(e)) => {
171 this.current_fetch.set(None);
172 *this.done = true;
173 return Poll::Ready(Some(Err(PagerError::Fetch(e))));
174 }
175 Poll::Pending => return Poll::Pending,
176 }
177 }
178
179 let mut query = this.base_query.clone();
182 if let Some(cursor_str) = this.next_cursor.as_ref() {
183 if let Ok(cursor) = modkit_odata::CursorV1::decode(cursor_str) {
184 query = query.with_cursor(cursor);
185 } else {
186 *this.done = true;
187 return Poll::Ready(Some(Err(PagerError::InvalidCursor(cursor_str.clone()))));
188 }
189 }
190
191 let fut = (this.fetcher)(query);
192 this.current_fetch.set(Some(fut));
193 }
194 }
195}
196
pin_project! {
    /// Stream adapter that yields whole pages (`Page<T>`), fetching each one on demand.
    ///
    /// Each page is obtained by calling `fetcher` with a query cloned from `base_query`;
    /// once a page has reported a next cursor, that cursor is decoded and attached to the
    /// query used for the following call.
    pub struct PagesPager<T, E, F, Fut>
    where
        F: FnMut(ODataQuery) -> Fut,
        Fut: Future<Output = Result<Page<T>, E>>,
    {
        // Query template cloned for every fetch.
        base_query: ODataQuery,
        // Encoded cursor reported by the most recently fetched page, if any.
        next_cursor: Option<String>,
        // Set once no further fetch will be issued (last page seen, or an error was emitted).
        done: bool,
        // Caller-supplied function that produces the future for one page request.
        fetcher: F,
        // In-flight page request, if any; pinned because it is polled in place.
        #[pin]
        current_fetch: Option<Fut>,
    }
}
221
222impl<T, E, F, Fut> PagesPager<T, E, F, Fut>
223where
224 F: FnMut(ODataQuery) -> Fut,
225 Fut: Future<Output = Result<Page<T>, E>>,
226{
227 pub fn new(base_query: ODataQuery, fetcher: F) -> Self {
242 Self {
243 base_query,
244 next_cursor: None,
245 done: false,
246 fetcher,
247 current_fetch: None,
248 }
249 }
250}
251
252impl<T, E, F, Fut> Stream for PagesPager<T, E, F, Fut>
253where
254 F: FnMut(ODataQuery) -> Fut,
255 Fut: Future<Output = Result<Page<T>, E>>,
256{
257 type Item = Result<Page<T>, PagerError<E>>;
258
259 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
260 let mut this = self.project();
261
262 loop {
263 if *this.done {
264 return Poll::Ready(None);
265 }
266
267 if let Some(fut) = this.current_fetch.as_mut().as_pin_mut() {
268 match fut.poll(cx) {
269 Poll::Ready(Ok(page)) => {
270 this.current_fetch.set(None);
271
272 this.next_cursor.clone_from(&page.page_info.next_cursor);
273
274 if this.next_cursor.is_none() {
275 *this.done = true;
276 }
277
278 return Poll::Ready(Some(Ok(page)));
279 }
280 Poll::Ready(Err(e)) => {
281 this.current_fetch.set(None);
282 *this.done = true;
283 return Poll::Ready(Some(Err(PagerError::Fetch(e))));
284 }
285 Poll::Pending => return Poll::Pending,
286 }
287 }
288
289 let mut query = this.base_query.clone();
292 if let Some(cursor_str) = this.next_cursor.as_ref() {
293 if let Ok(cursor) = modkit_odata::CursorV1::decode(cursor_str) {
294 query = query.with_cursor(cursor);
295 } else {
296 *this.done = true;
297 return Poll::Ready(Some(Err(PagerError::InvalidCursor(cursor_str.clone()))));
298 }
299 }
300
301 let fut = (this.fetcher)(query);
302 this.current_fetch.set(Some(fut));
303
304 }
307 }
308}
309
#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use modkit_odata::{CursorV1, PageInfo, SortDir};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    /// Item type carried by the test pages.
    #[derive(Debug, Clone, PartialEq)]
    struct User {
        id: i32,
        name: String,
    }

    /// Error type returned by the fake fetcher.
    #[derive(Debug, Clone, PartialEq)]
    struct FakeError(String);

    /// Fetcher double that serves a fixed list of pages in order, one per call.
    #[derive(Clone)]
    struct FakeFetcher {
        pages: Arc<[Page<User>]>,
        // Shared between clones so every clone advances the same position.
        call_count: Arc<Mutex<usize>>,
    }

    impl FakeFetcher {
        fn new(pages: Vec<Page<User>>) -> Self {
            Self {
                pages: Arc::from(pages),
                call_count: Arc::new(Mutex::new(0)),
            }
        }

        /// Returns the next scripted page, or a `FakeError` once all pages were served.
        fn fetch(&self, _query: ODataQuery) -> Result<Page<User>, FakeError> {
            let mut served = self.call_count.lock().unwrap();
            let page = self
                .pages
                .get(*served)
                .cloned()
                .ok_or_else(|| FakeError("No more pages".to_owned()))?;
            *served += 1;
            Ok(page)
        }
    }

    /// Shorthand for a `User` fixture.
    fn user(id: i32, name: &str) -> User {
        User {
            id,
            name: name.to_owned(),
        }
    }

    /// Encodes an ascending `"fwd"` cursor whose only key is `key`.
    fn encoded_cursor(key: &str) -> String {
        CursorV1 {
            k: vec![key.to_owned()],
            o: SortDir::Asc,
            s: "filter_hash".to_owned(),
            f: Some("filter_hash".to_owned()),
            d: "fwd".to_owned(),
        }
        .encode()
        .unwrap()
    }

    #[tokio::test]
    async fn test_cursor_pager_two_pages() {
        let cursor = encoded_cursor("2");

        let first = Page::new(
            vec![user(1, "Alice"), user(2, "Bob")],
            PageInfo {
                next_cursor: Some(cursor.clone()),
                prev_cursor: None,
                limit: 2,
            },
        );
        let second = Page::new(
            vec![user(3, "Charlie"), user(4, "Diana")],
            PageInfo {
                next_cursor: None,
                prev_cursor: Some(cursor),
                limit: 2,
            },
        );

        let fetcher = FakeFetcher::new(vec![first, second]);
        let pager = CursorPager::new(ODataQuery::new().with_limit(2), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let items: Vec<Result<User, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(items.len(), 4);
        assert!(items.iter().all(Result::is_ok));

        let names: Vec<String> = items.into_iter().map(|r| r.unwrap().name).collect();
        assert_eq!(names, ["Alice", "Bob", "Charlie", "Diana"]);
    }

    #[tokio::test]
    async fn test_cursor_pager_empty_page() {
        let only_page = Page::new(
            Vec::new(),
            PageInfo {
                next_cursor: None,
                prev_cursor: None,
                limit: 10,
            },
        );

        let fetcher = FakeFetcher::new(vec![only_page]);
        let pager = CursorPager::new(ODataQuery::new().with_limit(10), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let items: Vec<Result<User, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(items.len(), 0);
    }

    #[tokio::test]
    async fn test_cursor_pager_error_propagation() {
        // One scripted page that promises a successor, so the second fetch fails.
        let first = Page::new(
            vec![user(1, "Alice")],
            PageInfo {
                next_cursor: Some(encoded_cursor("1")),
                prev_cursor: None,
                limit: 1,
            },
        );

        let fetcher = FakeFetcher::new(vec![first]);
        let pager = CursorPager::new(ODataQuery::new().with_limit(1), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let items: Vec<Result<User, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(items.len(), 2);
        assert!(items[0].is_ok());
        assert!(items[1].is_err());
        assert!(
            matches!(&items[1], Err(PagerError::Fetch(_))),
            "Expected PagerError::Fetch"
        );
    }

    #[tokio::test]
    async fn test_pages_pager_two_pages() {
        let cursor = encoded_cursor("2");

        let first = Page::new(
            vec![user(1, "Alice"), user(2, "Bob")],
            PageInfo {
                next_cursor: Some(cursor.clone()),
                prev_cursor: None,
                limit: 2,
            },
        );
        let second = Page::new(
            vec![user(3, "Charlie")],
            PageInfo {
                next_cursor: None,
                prev_cursor: Some(cursor),
                limit: 2,
            },
        );

        let fetcher = FakeFetcher::new(vec![first, second]);
        let pager = PagesPager::new(ODataQuery::new().with_limit(2), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let pages: Vec<Result<Page<User>, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(pages.len(), 2);
        assert!(pages.iter().all(Result::is_ok));

        let fetched: Vec<Page<User>> = pages.into_iter().map(|r| r.unwrap()).collect();
        assert_eq!(fetched[0].items.len(), 2);
        assert_eq!(fetched[1].items.len(), 1);
        assert_eq!(fetched[0].items[0].name, "Alice");
        assert_eq!(fetched[1].items[0].name, "Charlie");
    }

    #[tokio::test]
    async fn test_pages_pager_single_page() {
        let only_page = Page::new(
            vec![user(1, "Alice")],
            PageInfo {
                next_cursor: None,
                prev_cursor: None,
                limit: 10,
            },
        );

        let fetcher = FakeFetcher::new(vec![only_page]);
        let pager = PagesPager::new(ODataQuery::new().with_limit(10), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let pages: Vec<Result<Page<User>, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(pages.len(), 1);
        assert!(pages[0].is_ok());
    }

    #[tokio::test]
    async fn test_cursor_pager_invalid_cursor() {
        let first = Page::new(
            vec![user(1, "Alice")],
            PageInfo {
                next_cursor: Some("invalid_cursor_string".to_owned()),
                prev_cursor: None,
                limit: 1,
            },
        );

        let fetcher = FakeFetcher::new(vec![first]);
        let pager = CursorPager::new(ODataQuery::new().with_limit(1), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let items: Vec<Result<User, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(items.len(), 2);
        assert!(items[0].is_ok());
        assert!(items[1].is_err());

        match &items[1] {
            Err(PagerError::InvalidCursor(cursor)) => {
                assert_eq!(cursor, "invalid_cursor_string");
            }
            _ => panic!("Expected PagerError::InvalidCursor"),
        }
    }

    #[tokio::test]
    async fn test_pages_pager_invalid_cursor() {
        let first = Page::new(
            vec![user(1, "Alice")],
            PageInfo {
                next_cursor: Some("invalid_cursor_string".to_owned()),
                prev_cursor: None,
                limit: 1,
            },
        );

        let fetcher = FakeFetcher::new(vec![first]);
        let pager = PagesPager::new(ODataQuery::new().with_limit(1), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let pages: Vec<Result<Page<User>, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(pages.len(), 2);
        assert!(pages[0].is_ok());
        assert!(pages[1].is_err());

        match &pages[1] {
            Err(PagerError::InvalidCursor(cursor)) => {
                assert_eq!(cursor, "invalid_cursor_string");
            }
            _ => panic!("Expected PagerError::InvalidCursor"),
        }
    }

    #[tokio::test]
    async fn test_pages_pager_error_propagation() {
        // One scripted page that promises a successor, so the second fetch fails.
        let first = Page::new(
            vec![user(1, "Alice")],
            PageInfo {
                next_cursor: Some(encoded_cursor("1")),
                prev_cursor: None,
                limit: 1,
            },
        );

        let fetcher = FakeFetcher::new(vec![first]);
        let pager = PagesPager::new(ODataQuery::new().with_limit(1), move |q| {
            let fetcher = fetcher.clone();
            async move { fetcher.fetch(q) }
        });

        let pages: Vec<Result<Page<User>, PagerError<FakeError>>> = pager.collect().await;

        assert_eq!(pages.len(), 2);
        assert!(pages[0].is_ok());
        assert!(pages[1].is_err());
        assert!(
            matches!(&pages[1], Err(PagerError::Fetch(_))),
            "Expected PagerError::Fetch"
        );
    }

    #[test]
    fn test_pages_pager_polls_new_future_immediately() {
        /// Future that never completes and counts how often it is polled.
        struct PollCountingFuture {
            polls: Arc<AtomicUsize>,
        }

        impl Future for PollCountingFuture {
            type Output = Result<Page<User>, FakeError>;

            fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.polls.fetch_add(1, Ordering::SeqCst);
                Poll::Pending
            }
        }

        let poll_counter = Arc::new(AtomicUsize::new(0));
        let counter_for_fetcher = Arc::clone(&poll_counter);

        let mut pager = PagesPager::new(ODataQuery::new().with_limit(1), move |_q| {
            PollCountingFuture {
                polls: Arc::clone(&counter_for_fetcher),
            }
        });

        let mut cx = Context::from_waker(futures_util::task::noop_waker_ref());

        let first_poll = Pin::new(&mut pager).poll_next(&mut cx);
        assert!(matches!(first_poll, Poll::Pending));

        // The freshly created future must have been polled within the same `poll_next` call.
        assert_eq!(poll_counter.load(Ordering::SeqCst), 1);
    }
}