codealong_github/
cursor.rs

1use regex::Regex;
2use reqwest::header::HeaderMap;
3use slog::Logger;
4
5use crate::client::Client;
6use crate::error::Result;
7
8/// Provides an iterator on top of the Github pagination API
9pub struct Cursor<'client, T>
10where
11    for<'de> T: serde::Deserialize<'de>,
12{
13    client: &'client Client,
14    next_url: Option<String>,
15    num_pages: Option<usize>,
16    per_page: Option<usize>,
17    current_page: Option<std::vec::IntoIter<T>>,
18    has_loaded_page: bool,
19    logger: Logger,
20}
21
22impl<'client, T> Cursor<'client, T>
23where
24    for<'de> T: serde::Deserialize<'de>,
25{
26    pub fn new(client: &'client Client, url: &str, logger: &Logger) -> Cursor<'client, T> {
27        Cursor {
28            client,
29            next_url: Some(url.to_owned()),
30            current_page: None,
31            num_pages: None,
32            per_page: None,
33            has_loaded_page: false,
34            logger: logger.clone(),
35        }
36    }
37
38    pub fn guess_len(&mut self) -> Option<usize> {
39        self.ensure_page_loaded();
40        self.num_pages
41            .and_then(|num_page| self.per_page.map(|per_page| num_page * per_page))
42    }
43
44    fn get_next_url(&self, headers: &HeaderMap) -> Option<String> {
45        let link = headers.get("link");
46        link.and_then(|link| {
47            lazy_static! {
48                static ref LINK_NEXT_REGEX: Regex = Regex::new(r#"<([^ ]*)>; rel="next""#).unwrap();
49            }
50            LINK_NEXT_REGEX
51                .captures(link.to_str().unwrap())
52                .map(|captures| captures[1].to_owned())
53        })
54    }
55
56    fn read_from_current_page(&mut self) -> Option<T> {
57        self.current_page.as_mut().and_then(|iter| iter.next())
58    }
59
60    fn get_num_pages(&self, headers: &HeaderMap) -> Option<usize> {
61        let link = headers.get("link");
62        link.and_then(|link| {
63            lazy_static! {
64                static ref LINK_LAST_PAGE_REGEX: Regex =
65                    Regex::new(r#"<[^ ]*page=(\d+)[^ ]*>; rel="last""#).unwrap();
66            }
67            LINK_LAST_PAGE_REGEX
68                .captures(link.to_str().unwrap())
69                .map(|captures| captures[1].to_owned().parse::<usize>().unwrap())
70        })
71    }
72
73    fn ensure_page_loaded(&mut self) {
74        if !self.has_loaded_page {
75            self.load_next_page()
76        }
77    }
78
79    fn load_next_page(&mut self) {
80        match self.load_next_page_helper() {
81            Ok(_) => (),
82            Err(e) => error!(self.logger, "Error loading page: {}", e),
83        }
84    }
85
86    fn load_next_page_helper(&mut self) -> Result<()> {
87        if let Some(next_url) = self.next_url.take() {
88            let mut res = self.client.get(&next_url)?;
89            self.has_loaded_page = true;
90            let new_page = res.json::<Vec<T>>().unwrap().into_iter();
91            let headers = res.headers();
92            self.next_url = self.get_next_url(&headers);
93            if let None = self.num_pages {
94                self.num_pages = self.get_num_pages(&headers);
95            }
96            if let None = self.per_page {
97                self.per_page = Some(new_page.len());
98            }
99            self.current_page = Some(new_page);
100            Ok(())
101        } else {
102            Ok(())
103        }
104    }
105}
106
107impl<'client, T> Iterator for Cursor<'client, T>
108where
109    for<'de> T: serde::Deserialize<'de>,
110{
111    type Item = T;
112
113    fn next(&mut self) -> Option<T> {
114        self.read_from_current_page().or_else(|| {
115            self.load_next_page();
116            self.read_from_current_page()
117        })
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pull_request::PullRequest;
    use codealong::test::build_test_logger;

    /// Pages through the pull requests of facebook/react.
    ///
    /// This hits the live GitHub API, so it needs network access.
    #[test]
    fn test_cursor() {
        let client = Client::from_env();
        let logger = build_test_logger();
        let url = "https://api.github.com/repos/facebook/react/pulls?state=all";
        let mut cursor = Cursor::<PullRequest>::new(&client, url, &logger);

        let estimated_total = cursor.guess_len().unwrap();
        assert!(estimated_total > 100);

        let first_hundred: Vec<PullRequest> = cursor.take(100).collect();
        assert_eq!(first_hundred.len(), 100);
    }
}