1use crate::config::RestStreamConfig;
4use crate::error::FaucetError;
5use crate::extract;
6use crate::pagination::{PaginationState, PaginationStyle};
7use crate::retry;
8use futures_core::Stream;
9use reqwest::Client;
10use reqwest::header::HeaderMap;
11use serde::Deserialize;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::pin::Pin;
15
16pub struct RestStream {
18 config: RestStreamConfig,
19 client: Client,
20}
21
22impl RestStream {
23 pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
25 let mut builder = Client::builder();
26 if let Some(t) = config.timeout {
27 builder = builder.timeout(t);
28 }
29 Ok(Self {
30 config,
31 client: builder.build()?,
32 })
33 }
34
35 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
37 let mut all_records = Vec::new();
38 let mut state = PaginationState::default();
39 let mut pages_fetched = 0usize;
40
41 loop {
42 if let Some(max) = self.config.max_pages
43 && pages_fetched >= max
44 {
45 tracing::warn!("max pages ({max}) reached");
46 break;
47 }
48
49 let mut params = self.config.query_params.clone();
50 self.config.pagination.apply_params(&mut params, &state);
51
52 let url_override = match &self.config.pagination {
55 PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
56 state.next_link.clone()
57 }
58 _ => None,
59 };
60
61 let params_clone = params.clone();
62 let (body, resp_headers) = retry::execute_with_retry(
63 self.config.max_retries,
64 self.config.retry_backoff,
65 || self.execute_request(¶ms_clone, url_override.as_deref()),
66 )
67 .await?;
68
69 let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
70 let count = records.len();
71 all_records.extend(records);
72
73 let has_next =
74 self.config
75 .pagination
76 .advance(&body, &resp_headers, &mut state, count)?;
77 pages_fetched += 1;
78 if !has_next {
79 break;
80 }
81
82 if let Some(delay) = self.config.request_delay {
83 tokio::time::sleep(delay).await;
84 }
85 }
86
87 tracing::info!(
88 "fetched {} total records across {} page(s)",
89 all_records.len(),
90 pages_fetched
91 );
92 Ok(all_records)
93 }
94
95 pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
97 let values = self.fetch_all().await?;
98 values
99 .into_iter()
100 .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
101 .collect()
102 }
103
104 pub fn stream_pages(
124 &self,
125 ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + '_>> {
126 Box::pin(async_stream::try_stream! {
127 let mut state = PaginationState::default();
128 let mut pages_fetched = 0usize;
129
130 loop {
131 if let Some(max) = self.config.max_pages
132 && pages_fetched >= max
133 {
134 tracing::warn!("max pages ({max}) reached");
135 break;
136 }
137
138 let mut params = self.config.query_params.clone();
139 self.config.pagination.apply_params(&mut params, &state);
140
141 let url_override = match &self.config.pagination {
142 PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
143 state.next_link.clone()
144 }
145 _ => None,
146 };
147
148 let params_clone = params.clone();
149 let (body, resp_headers) = retry::execute_with_retry(
150 self.config.max_retries,
151 self.config.retry_backoff,
152 || self.execute_request(¶ms_clone, url_override.as_deref()),
153 )
154 .await?;
155
156 let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
157 let count = records.len();
158
159 yield records;
160
161 let has_next = self
162 .config
163 .pagination
164 .advance(&body, &resp_headers, &mut state, count)?;
165 pages_fetched += 1;
166 if !has_next {
167 break;
168 }
169
170 if let Some(delay) = self.config.request_delay {
171 tokio::time::sleep(delay).await;
172 }
173 }
174 })
175 }
176
177 async fn execute_request(
182 &self,
183 params: &HashMap<String, String>,
184 url_override: Option<&str>,
185 ) -> Result<(Value, HeaderMap), FaucetError> {
186 let use_override = url_override.is_some();
187 let url = match url_override {
188 Some(u) => u.to_string(),
189 None => format!(
190 "{}/{}",
191 self.config.base_url,
192 self.config.path.trim_start_matches('/')
193 ),
194 };
195
196 let mut headers = self.config.headers.clone();
197 self.config.auth.apply(&mut headers)?;
198
199 let mut req = self
200 .client
201 .request(self.config.method.clone(), &url)
202 .headers(headers);
203
204 if !use_override {
205 req = req.query(params);
206 }
207
208 if let Some(body) = &self.config.body {
209 req = req.json(body);
210 }
211
212 let resp = req.send().await?.error_for_status()?;
213 let resp_headers = resp.headers().clone();
214 let body: Value = resp.json().await?;
215 Ok((body, resp_headers))
216 }
217}