codealong_github/
cursor.rs1use regex::Regex;
2use reqwest::header::HeaderMap;
3use slog::Logger;
4
5use crate::client::Client;
6use crate::error::Result;
7
8pub struct Cursor<'client, T>
10where
11 for<'de> T: serde::Deserialize<'de>,
12{
13 client: &'client Client,
14 next_url: Option<String>,
15 num_pages: Option<usize>,
16 per_page: Option<usize>,
17 current_page: Option<std::vec::IntoIter<T>>,
18 has_loaded_page: bool,
19 logger: Logger,
20}
21
22impl<'client, T> Cursor<'client, T>
23where
24 for<'de> T: serde::Deserialize<'de>,
25{
26 pub fn new(client: &'client Client, url: &str, logger: &Logger) -> Cursor<'client, T> {
27 Cursor {
28 client,
29 next_url: Some(url.to_owned()),
30 current_page: None,
31 num_pages: None,
32 per_page: None,
33 has_loaded_page: false,
34 logger: logger.clone(),
35 }
36 }
37
38 pub fn guess_len(&mut self) -> Option<usize> {
39 self.ensure_page_loaded();
40 self.num_pages
41 .and_then(|num_page| self.per_page.map(|per_page| num_page * per_page))
42 }
43
44 fn get_next_url(&self, headers: &HeaderMap) -> Option<String> {
45 let link = headers.get("link");
46 link.and_then(|link| {
47 lazy_static! {
48 static ref LINK_NEXT_REGEX: Regex = Regex::new(r#"<([^ ]*)>; rel="next""#).unwrap();
49 }
50 LINK_NEXT_REGEX
51 .captures(link.to_str().unwrap())
52 .map(|captures| captures[1].to_owned())
53 })
54 }
55
56 fn read_from_current_page(&mut self) -> Option<T> {
57 self.current_page.as_mut().and_then(|iter| iter.next())
58 }
59
60 fn get_num_pages(&self, headers: &HeaderMap) -> Option<usize> {
61 let link = headers.get("link");
62 link.and_then(|link| {
63 lazy_static! {
64 static ref LINK_LAST_PAGE_REGEX: Regex =
65 Regex::new(r#"<[^ ]*page=(\d+)[^ ]*>; rel="last""#).unwrap();
66 }
67 LINK_LAST_PAGE_REGEX
68 .captures(link.to_str().unwrap())
69 .map(|captures| captures[1].to_owned().parse::<usize>().unwrap())
70 })
71 }
72
73 fn ensure_page_loaded(&mut self) {
74 if !self.has_loaded_page {
75 self.load_next_page()
76 }
77 }
78
79 fn load_next_page(&mut self) {
80 match self.load_next_page_helper() {
81 Ok(_) => (),
82 Err(e) => error!(self.logger, "Error loading page: {}", e),
83 }
84 }
85
86 fn load_next_page_helper(&mut self) -> Result<()> {
87 if let Some(next_url) = self.next_url.take() {
88 let mut res = self.client.get(&next_url)?;
89 self.has_loaded_page = true;
90 let new_page = res.json::<Vec<T>>().unwrap().into_iter();
91 let headers = res.headers();
92 self.next_url = self.get_next_url(&headers);
93 if let None = self.num_pages {
94 self.num_pages = self.get_num_pages(&headers);
95 }
96 if let None = self.per_page {
97 self.per_page = Some(new_page.len());
98 }
99 self.current_page = Some(new_page);
100 Ok(())
101 } else {
102 Ok(())
103 }
104 }
105}
106
107impl<'client, T> Iterator for Cursor<'client, T>
108where
109 for<'de> T: serde::Deserialize<'de>,
110{
111 type Item = T;
112
113 fn next(&mut self) -> Option<T> {
114 self.read_from_current_page().or_else(|| {
115 self.load_next_page();
116 self.read_from_current_page()
117 })
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pull_request::PullRequest;
    use codealong::test::build_test_logger;

    /// Pages through the pull requests of a large public repository: the
    /// length estimate must exceed 100 and at least 100 items must be yielded.
    #[test]
    fn test_cursor() {
        let logger = build_test_logger();
        let client = Client::from_env();
        let url = "https://api.github.com/repos/facebook/react/pulls?state=all";
        let mut cursor: Cursor<PullRequest> = Cursor::new(&client, url, &logger);

        let estimate = cursor.guess_len().unwrap();
        assert!(estimate > 100);

        let pulls: Vec<PullRequest> = cursor.take(100).collect();
        assert_eq!(pulls.len(), 100);
    }
}