lava_api/
paginator.rs

1use futures::future::BoxFuture;
2use futures::stream::Stream;
3use futures::FutureExt;
4use log::debug;
5use reqwest::Client;
6use serde::{de::DeserializeOwned, Deserialize};
7use std::collections::VecDeque;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use thiserror::Error;
11use url::Url;
12
/// Errors that can occur while fetching pages for a [`Paginator`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum PaginationError {
    /// Sending the request failed, the final response carried an error
    /// status, or its body could not be decoded as the expected JSON.
    #[error("http request failed: {0}")]
    ReqWest(#[from] reqwest::Error),
    /// A redirect response did not carry a `Location` header.
    #[error("HTTP redirect without location")]
    RedirectMissing,
    /// The `Location` header of a redirect response was not valid UTF-8.
    #[error("HTTP redirect not valid utf-8")]
    RedirectInvalidUTF8,
    /// The server kept redirecting after ten redirects had already been
    /// followed for a single page request.
    #[error("Too many redirects")]
    TooManyRedirects,
    /// A redirect target or a page's `next` link could not be parsed as a URL.
    #[error("Failed to parse next uri: {0}")]
    ParseNextError(#[from] url::ParseError),
}
27
/// One page of a paginated listing, as decoded from a response's JSON body.
#[derive(Deserialize, Debug)]
struct PaginatedReply<T> {
    /// Item count reported by the server (surfaced through
    /// `Paginator::reported_items`); presumably the total across all pages —
    /// confirm against the API.
    count: u32,
    /// Link to the following page; `None` means this is the last page.
    next: Option<String>,
    /// Items on this page, handed out front-to-back by the paginator.
    results: VecDeque<T>,
}
34
/// What a [`Paginator`] is currently doing.
enum State<T> {
    /// A page has been received; its items are being handed out one by one.
    Data(PaginatedReply<T>),
    /// A request for the next (or a retried) page, not yet completed.
    Next(BoxFuture<'static, Result<PaginatedReply<T>, PaginationError>>),
    /// A `next` link could not be parsed; no further items will be produced.
    Failed,
}
40
41pub struct Paginator<T> {
42    client: Client,
43    current: Url,
44    next: State<T>,
45    count: Option<u32>,
46}
47
48impl<T> Paginator<T>
49where
50    T: DeserializeOwned + 'static,
51{
52    pub fn new(client: Client, url: Url) -> Self {
53        let next = State::Next(Self::get(client.clone(), url.clone()).boxed());
54
55        Paginator {
56            client,
57            current: url,
58            next,
59            count: None,
60        }
61    }
62
63    async fn get(client: Client, uri: Url) -> Result<PaginatedReply<T>, PaginationError>
64    where
65        T: DeserializeOwned,
66    {
67        let mut redirects: u8 = 0;
68        let mut u = uri.clone();
69        let response = loop {
70            let response = client.get(u.clone()).send().await?;
71
72            if !response.status().is_redirection() {
73                break response;
74            }
75
76            if redirects > 9 {
77                return Err(PaginationError::TooManyRedirects);
78            }
79
80            redirects += 1;
81            if let Some(location) = response.headers().get("location") {
82                let redirect = std::str::from_utf8(location.as_bytes())
83                    .or(Err(PaginationError::RedirectInvalidUTF8))?;
84
85                debug!("Redirecting from {:?} to {:?}", u, location);
86                u = u.join(redirect)?;
87                // Prevent https to http downgrade as we might have a token in
88                // the request
89                if uri.scheme() == "https" && u.scheme() == "http" {
90                    u.set_scheme("https").unwrap();
91                }
92            } else {
93                return Err(PaginationError::RedirectMissing);
94            }
95        };
96
97        response
98            .error_for_status()?
99            .json()
100            .await
101            .map_err(|e| e.into())
102    }
103
104    fn next_data(&mut self) -> Result<Option<T>, PaginationError> {
105        if let State::Data(d) = &mut self.next {
106            self.count = Some(d.count);
107            if let Some(data) = d.results.pop_front() {
108                return Ok(Some(data));
109            }
110
111            if let Some(n) = &d.next {
112                let u: Result<Url, _> = n.parse();
113                match u {
114                    Ok(u) => {
115                        self.next = State::Next(Self::get(self.client.clone(), u.clone()).boxed());
116                        self.current = u;
117                    }
118                    Err(e) => {
119                        self.next = State::Failed;
120                        return Err(e.into());
121                    }
122                }
123            }
124        }
125        Ok(None)
126    }
127
128    pub fn reported_items(&self) -> Option<u32> {
129        self.count
130    }
131}
132
133impl<T> Stream for Paginator<T>
134where
135    T: DeserializeOwned + Unpin + 'static,
136{
137    type Item = Result<T, PaginationError>;
138
139    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
140        let me = self.get_mut();
141        if let Some(data) = me.next_data()? {
142            return Poll::Ready(Some(Ok(data)));
143        }
144
145        if let State::Next(n) = &mut me.next {
146            match n.as_mut().poll(cx) {
147                Poll::Ready(r) => {
148                    match r {
149                        Ok(r) => me.next = State::Data(r),
150                        Err(e) => {
151                            me.next = State::Next(
152                                Self::get(me.client.clone(), me.current.clone()).boxed(),
153                            );
154                            return Poll::Ready(Some(Err(e)));
155                        }
156                    }
157                    Poll::Ready(me.next_data().transpose())
158                }
159                _ => Poll::Pending,
160            }
161        } else {
162            Poll::Ready(None)
163        }
164    }
165}