1use crate::config::RestStreamConfig;
4use crate::error::FaucetError;
5use crate::extract;
6use crate::pagination::{PaginationState, PaginationStyle};
7use crate::retry;
8use futures_core::Stream;
9use reqwest::Client;
10use reqwest::header::HeaderMap;
11use serde::Deserialize;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::pin::Pin;
15
16pub struct RestStream {
18 config: RestStreamConfig,
19 client: Client,
20}
21
22impl RestStream {
23 pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
25 let mut builder = Client::builder();
26 if let Some(t) = config.timeout {
27 builder = builder.timeout(t);
28 }
29 Ok(Self {
30 config,
31 client: builder.build()?,
32 })
33 }
34
35 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
37 let mut all_records = Vec::new();
38 let mut state = PaginationState::default();
39 let mut pages_fetched = 0usize;
40
41 loop {
42 if let Some(max) = self.config.max_pages
43 && pages_fetched >= max
44 {
45 tracing::warn!("max pages ({max}) reached");
46 break;
47 }
48
49 let mut params = self.config.query_params.clone();
50 self.config.pagination.apply_params(&mut params, &state);
51
52 let url_override = match &self.config.pagination {
55 PaginationStyle::LinkHeader => state.next_link.clone(),
56 _ => None,
57 };
58
59 let params_clone = params.clone();
60 let (body, resp_headers) = retry::execute_with_retry(
61 self.config.max_retries,
62 self.config.retry_backoff,
63 || self.execute_request(¶ms_clone, url_override.as_deref()),
64 )
65 .await?;
66
67 let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
68 let count = records.len();
69 all_records.extend(records);
70
71 let has_next =
72 self.config
73 .pagination
74 .advance(&body, &resp_headers, &mut state, count)?;
75 pages_fetched += 1;
76 if !has_next {
77 break;
78 }
79
80 if let Some(delay) = self.config.request_delay {
81 tokio::time::sleep(delay).await;
82 }
83 }
84
85 tracing::info!(
86 "fetched {} total records across {} page(s)",
87 all_records.len(),
88 pages_fetched
89 );
90 Ok(all_records)
91 }
92
93 pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
95 let values = self.fetch_all().await?;
96 values
97 .into_iter()
98 .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
99 .collect()
100 }
101
102 pub fn stream_pages(
122 &self,
123 ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + '_>> {
124 Box::pin(async_stream::try_stream! {
125 let mut state = PaginationState::default();
126 let mut pages_fetched = 0usize;
127
128 loop {
129 if let Some(max) = self.config.max_pages
130 && pages_fetched >= max
131 {
132 tracing::warn!("max pages ({max}) reached");
133 break;
134 }
135
136 let mut params = self.config.query_params.clone();
137 self.config.pagination.apply_params(&mut params, &state);
138
139 let url_override = match &self.config.pagination {
140 PaginationStyle::LinkHeader => state.next_link.clone(),
141 _ => None,
142 };
143
144 let params_clone = params.clone();
145 let (body, resp_headers) = retry::execute_with_retry(
146 self.config.max_retries,
147 self.config.retry_backoff,
148 || self.execute_request(¶ms_clone, url_override.as_deref()),
149 )
150 .await?;
151
152 let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
153 let count = records.len();
154
155 yield records;
156
157 let has_next = self
158 .config
159 .pagination
160 .advance(&body, &resp_headers, &mut state, count)?;
161 pages_fetched += 1;
162 if !has_next {
163 break;
164 }
165
166 if let Some(delay) = self.config.request_delay {
167 tokio::time::sleep(delay).await;
168 }
169 }
170 })
171 }
172
173 async fn execute_request(
178 &self,
179 params: &HashMap<String, String>,
180 url_override: Option<&str>,
181 ) -> Result<(Value, HeaderMap), FaucetError> {
182 let use_override = url_override.is_some();
183 let url = match url_override {
184 Some(u) => u.to_string(),
185 None => format!(
186 "{}/{}",
187 self.config.base_url,
188 self.config.path.trim_start_matches('/')
189 ),
190 };
191
192 let mut headers = self.config.headers.clone();
193 self.config.auth.apply(&mut headers)?;
194
195 let mut req = self
196 .client
197 .request(self.config.method.clone(), &url)
198 .headers(headers);
199
200 if !use_override {
201 req = req.query(params);
202 }
203
204 if let Some(body) = &self.config.body {
205 req = req.json(body);
206 }
207
208 let resp = req.send().await?.error_for_status()?;
209 let resp_headers = resp.headers().clone();
210 let body: Value = resp.json().await?;
211 Ok((body, resp_headers))
212 }
213}