1use futures::future::BoxFuture;
2use futures::stream::Stream;
3use futures::FutureExt;
4use log::debug;
5use reqwest::Client;
6use serde::{de::DeserializeOwned, Deserialize};
7use std::collections::VecDeque;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use thiserror::Error;
11use url::Url;
12
13#[derive(Debug, Error)]
14#[non_exhaustive]
15pub enum PaginationError {
16 #[error("http request failed: {0}")]
17 ReqWest(#[from] reqwest::Error),
18 #[error("HTTP redirect without location")]
19 RedirectMissing,
20 #[error("HTTP redirect not valid utf-8")]
21 RedirectInvalidUTF8,
22 #[error("Too many redirects")]
23 TooManyRedirects,
24 #[error("Failed to parse next uri: {0}")]
25 ParseNextError(#[from] url::ParseError),
26}
27
28#[derive(Deserialize, Debug)]
29struct PaginatedReply<T> {
30 count: u32,
31 next: Option<String>,
32 results: VecDeque<T>,
33}
34
35enum State<T> {
36 Data(PaginatedReply<T>),
37 Next(BoxFuture<'static, Result<PaginatedReply<T>, PaginationError>>),
38 Failed,
39}
40
41pub struct Paginator<T> {
42 client: Client,
43 current: Url,
44 next: State<T>,
45 count: Option<u32>,
46}
47
48impl<T> Paginator<T>
49where
50 T: DeserializeOwned + 'static,
51{
52 pub fn new(client: Client, url: Url) -> Self {
53 let next = State::Next(Self::get(client.clone(), url.clone()).boxed());
54
55 Paginator {
56 client,
57 current: url,
58 next,
59 count: None,
60 }
61 }
62
63 async fn get(client: Client, uri: Url) -> Result<PaginatedReply<T>, PaginationError>
64 where
65 T: DeserializeOwned,
66 {
67 let mut redirects: u8 = 0;
68 let mut u = uri.clone();
69 let response = loop {
70 let response = client.get(u.clone()).send().await?;
71
72 if !response.status().is_redirection() {
73 break response;
74 }
75
76 if redirects > 9 {
77 return Err(PaginationError::TooManyRedirects);
78 }
79
80 redirects += 1;
81 if let Some(location) = response.headers().get("location") {
82 let redirect = std::str::from_utf8(location.as_bytes())
83 .or(Err(PaginationError::RedirectInvalidUTF8))?;
84
85 debug!("Redirecting from {:?} to {:?}", u, location);
86 u = u.join(redirect)?;
87 if uri.scheme() == "https" && u.scheme() == "http" {
90 u.set_scheme("https").unwrap();
91 }
92 } else {
93 return Err(PaginationError::RedirectMissing);
94 }
95 };
96
97 response
98 .error_for_status()?
99 .json()
100 .await
101 .map_err(|e| e.into())
102 }
103
104 fn next_data(&mut self) -> Result<Option<T>, PaginationError> {
105 if let State::Data(d) = &mut self.next {
106 self.count = Some(d.count);
107 if let Some(data) = d.results.pop_front() {
108 return Ok(Some(data));
109 }
110
111 if let Some(n) = &d.next {
112 let u: Result<Url, _> = n.parse();
113 match u {
114 Ok(u) => {
115 self.next = State::Next(Self::get(self.client.clone(), u.clone()).boxed());
116 self.current = u;
117 }
118 Err(e) => {
119 self.next = State::Failed;
120 return Err(e.into());
121 }
122 }
123 }
124 }
125 Ok(None)
126 }
127
128 pub fn reported_items(&self) -> Option<u32> {
129 self.count
130 }
131}
132
133impl<T> Stream for Paginator<T>
134where
135 T: DeserializeOwned + Unpin + 'static,
136{
137 type Item = Result<T, PaginationError>;
138
139 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
140 let me = self.get_mut();
141 if let Some(data) = me.next_data()? {
142 return Poll::Ready(Some(Ok(data)));
143 }
144
145 if let State::Next(n) = &mut me.next {
146 match n.as_mut().poll(cx) {
147 Poll::Ready(r) => {
148 match r {
149 Ok(r) => me.next = State::Data(r),
150 Err(e) => {
151 me.next = State::Next(
152 Self::get(me.client.clone(), me.current.clone()).boxed(),
153 );
154 return Poll::Ready(Some(Err(e)));
155 }
156 }
157 Poll::Ready(me.next_data().transpose())
158 }
159 _ => Poll::Pending,
160 }
161 } else {
162 Poll::Ready(None)
163 }
164 }
165}