1use crate::auth::Auth;
4use crate::auth::oauth2::TokenCache;
5use crate::config::RestStreamConfig;
6use crate::error::FaucetError;
7use crate::extract;
8use crate::pagination::{PaginationState, PaginationStyle};
9use crate::replication::{ReplicationMethod, filter_incremental, max_replication_value};
10use crate::retry;
11use crate::schema;
12use crate::transform::{self, CompiledTransform};
13use futures_core::Stream;
14use reqwest::Client;
15use reqwest::header::HeaderMap;
16use serde::Deserialize;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::time::Duration;
21
/// A REST-backed record stream: holds the configuration, the HTTP client and
/// the derived state needed to page through an endpoint.
pub struct RestStream {
    /// Stream configuration supplied to [`RestStream::new`].
    config: RestStreamConfig,
    /// HTTP client built in [`RestStream::new`]; `config.timeout` is applied
    /// to it when set.
    client: Client,
    /// `config.transforms` compiled once in [`RestStream::new`] and applied
    /// to every record of every page.
    compiled_transforms: Vec<CompiledTransform>,
    /// Token cache consulted by `execute_request` when the configured auth is
    /// `Auth::OAuth2`.
    token_cache: TokenCache,
}
31
32impl RestStream {
33 pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
39 if let Auth::OAuth2 { expiry_ratio, .. } = &config.auth
41 && (*expiry_ratio <= 0.0 || *expiry_ratio > 1.0)
42 {
43 return Err(FaucetError::Auth(format!(
44 "expiry_ratio must be in (0.0, 1.0], got {expiry_ratio}"
45 )));
46 }
47
48 let compiled_transforms = config
49 .transforms
50 .iter()
51 .map(transform::compile)
52 .collect::<Result<Vec<_>, _>>()?;
53
54 let mut builder = Client::builder();
55 if let Some(t) = config.timeout {
56 builder = builder.timeout(t);
57 }
58 Ok(Self {
59 config,
60 client: builder.build()?,
61 compiled_transforms,
62 token_cache: TokenCache::new(),
63 })
64 }
65
66 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
75 if self.config.partitions.is_empty() {
76 self.fetch_partition(None, None).await
77 } else {
78 let mut all_records = Vec::new();
79 for ctx in &self.config.partitions {
80 let records = self.fetch_partition(Some(ctx), None).await?;
81 all_records.extend(records);
82 }
83 Ok(all_records)
84 }
85 }
86
87 pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
89 let values = self.fetch_all().await?;
90 values
91 .into_iter()
92 .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
93 .collect()
94 }
95
96 pub async fn infer_schema(&self) -> Result<Value, FaucetError> {
108 if let Some(ref s) = self.config.schema {
109 return Ok(s.clone());
110 }
111 let limit = match self.config.schema_sample_size {
112 0 => None,
113 n => Some(n),
114 };
115 let records = self.fetch_partition(None, limit).await?;
116 Ok(schema::infer_schema(&records))
117 }
118
119 pub async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
128 let records = self.fetch_all().await?;
129 let bookmark = self
130 .config
131 .replication_key
132 .as_deref()
133 .and_then(|key| max_replication_value(&records, key))
134 .cloned();
135 Ok((records, bookmark))
136 }
137
138 pub fn stream_pages(
161 &self,
162 ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
163 self.stream_pages_inner(None)
164 }
165
166 fn stream_pages_inner(
174 &self,
175 context: Option<&HashMap<String, Value>>,
176 ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
177 let owned_context: Option<HashMap<String, Value>> = context.cloned();
180
181 Box::pin(async_stream::try_stream! {
182 let mut state = PaginationState::default();
183 let mut pages_fetched = 0usize;
184
185 loop {
186 if let Some(max) = self.config.max_pages
187 && pages_fetched >= max
188 {
189 tracing::warn!("max pages ({max}) reached");
190 break;
191 }
192
193 let mut params = self.config.query_params.clone();
194 self.config.pagination.apply_params(&mut params, &state);
195
196 let url_override = match &self.config.pagination {
197 PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
198 state.next_link.clone()
199 }
200 _ => None,
201 };
202
203 let params_clone = params.clone();
204 let ctx_ref = owned_context.as_ref();
205 let (body, resp_headers) = retry::execute_with_retry(
206 self.config.max_retries,
207 self.config.retry_backoff,
208 || self.execute_request(¶ms_clone, url_override.as_deref(), ctx_ref),
209 )
210 .await?;
211
212 let raw_records =
213 extract::extract_records(&body, self.config.records_path.as_deref())?;
214 let raw_count = raw_records.len();
215
216 let records =
217 if self.config.replication_method == ReplicationMethod::Incremental {
218 if let (Some(key), Some(start)) = (
219 &self.config.replication_key,
220 &self.config.start_replication_value,
221 ) {
222 filter_incremental(raw_records, key, start)
223 } else {
224 raw_records
225 }
226 } else {
227 raw_records
228 };
229
230 let records: Vec<Value> = records
231 .into_iter()
232 .map(|rec| transform::apply_all(rec, &self.compiled_transforms))
233 .collect();
234
235 yield records;
236
237 let has_next = self
238 .config
239 .pagination
240 .advance(&body, &resp_headers, &mut state, raw_count)?;
241 pages_fetched += 1;
242 if !has_next {
243 break;
244 }
245
246 if let Some(delay) = self.config.request_delay {
247 tokio::time::sleep(delay).await;
248 }
249 }
250 })
251 }
252
253 async fn fetch_partition(
258 &self,
259 context: Option<&HashMap<String, Value>>,
260 max_records: Option<usize>,
261 ) -> Result<Vec<Value>, FaucetError> {
262 let mut all_records = Vec::new();
263 let mut pages_fetched = 0usize;
264 let mut pages = self.stream_pages_inner(context);
265
266 loop {
268 let page = std::future::poll_fn(|cx: &mut std::task::Context<'_>| {
269 pages.as_mut().poll_next(cx)
270 })
271 .await;
272
273 match page {
274 Some(Ok(records)) => {
275 pages_fetched += 1;
276 match max_records {
277 Some(limit) => {
278 let remaining = limit.saturating_sub(all_records.len());
279 all_records.extend(records.into_iter().take(remaining));
280 if all_records.len() >= limit {
281 break;
282 }
283 }
284 None => all_records.extend(records),
285 }
286 }
287 Some(Err(e)) => return Err(e),
288 None => break,
289 }
290 }
291
292 tracing::info!(
293 stream = self.config.name.as_deref().unwrap_or("(unnamed)"),
294 records = all_records.len(),
295 pages = pages_fetched,
296 "fetch complete"
297 );
298 Ok(all_records)
299 }
300
301 async fn execute_request(
308 &self,
309 params: &HashMap<String, String>,
310 url_override: Option<&str>,
311 path_context: Option<&HashMap<String, Value>>,
312 ) -> Result<(Value, HeaderMap), FaucetError> {
313 let use_override = url_override.is_some();
314 let url = match url_override {
315 Some(u) => u.to_string(),
316 None => {
317 let path = match path_context {
318 Some(ctx) => resolve_path(&self.config.path, ctx),
319 None => self.config.path.clone(),
320 };
321 format!("{}/{}", self.config.base_url, path.trim_start_matches('/'))
322 }
323 };
324
325 let resolved_auth = match &self.config.auth {
329 Auth::OAuth2 {
330 token_url,
331 client_id,
332 client_secret,
333 scopes,
334 expiry_ratio,
335 } => {
336 let token = self
337 .token_cache
338 .get_or_refresh(
339 &self.client,
340 token_url,
341 client_id,
342 client_secret,
343 scopes,
344 *expiry_ratio,
345 )
346 .await?;
347 Auth::Bearer(token)
348 }
349 other => other.clone(),
350 };
351
352 let mut headers = self.config.headers.clone();
353 resolved_auth.apply(&mut headers)?;
354
355 let mut req = self
356 .client
357 .request(self.config.method.clone(), &url)
358 .headers(headers);
359
360 if !use_override {
361 req = req.query(params);
362 }
363
364 if let Auth::ApiKeyQuery { param, value } = &self.config.auth {
366 req = req.query(&[(param.as_str(), value.as_str())]);
367 }
368
369 if let Some(body) = &self.config.body {
370 req = req.json(body);
371 }
372
373 let resp = req.send().await?;
374 let status = resp.status();
375
376 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
378 let wait = parse_retry_after(resp.headers());
379 return Err(FaucetError::RateLimited(wait));
380 }
381
382 if self.config.tolerated_http_errors.contains(&status.as_u16()) {
384 tracing::debug!(
385 status = status.as_u16(),
386 "tolerated HTTP error; treating as empty page"
387 );
388 return Ok((Value::Array(vec![]), HeaderMap::new()));
389 }
390
391 if !status.is_success() {
395 let resp_url = resp.url().to_string();
396 let body_text = resp.text().await.unwrap_or_default();
397 let truncated = if body_text.len() > 1024 {
399 let end = body_text.floor_char_boundary(1024);
401 format!("{}...(truncated)", &body_text[..end])
402 } else {
403 body_text
404 };
405 return Err(FaucetError::HttpStatus {
406 status: status.as_u16(),
407 url: resp_url,
408 body: truncated,
409 });
410 }
411
412 let resp_headers = resp.headers().clone();
413 let body: Value = resp.json().await?;
414 Ok((body, resp_headers))
415 }
416}
417
418fn resolve_path(path: &str, context: &HashMap<String, Value>) -> String {
420 let mut result = path.to_string();
421 for (key, value) in context {
422 let placeholder = format!("{{{key}}}");
423 let replacement = match value {
424 Value::String(s) => s.clone(),
425 other => other.to_string(),
426 };
427 result = result.replace(&placeholder, &replacement);
428 }
429 result
430}
431
432fn parse_retry_after(headers: &HeaderMap) -> Duration {
435 headers
436 .get(reqwest::header::RETRY_AFTER)
437 .and_then(|v| v.to_str().ok())
438 .and_then(|s| s.parse::<u64>().ok())
439 .map(Duration::from_secs)
440 .unwrap_or(Duration::from_secs(60))
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every placeholder with a matching context key is replaced.
    #[test]
    fn test_resolve_path_substitutes_placeholders() {
        let ctx = HashMap::from([
            ("org_id".to_string(), json!("acme")),
            ("repo".to_string(), json!("myrepo")),
        ]);
        assert_eq!(
            resolve_path("/orgs/{org_id}/repos/{repo}/issues", &ctx),
            "/orgs/acme/repos/myrepo/issues"
        );
    }

    /// A path without placeholders comes back unchanged.
    #[test]
    fn test_resolve_path_no_placeholders() {
        let ctx: HashMap<String, Value> = HashMap::new();
        assert_eq!(resolve_path("/api/users", &ctx), "/api/users");
    }

    /// Non-string values are rendered using their JSON text form.
    #[test]
    fn test_resolve_path_numeric_value() {
        let ctx = HashMap::from([("id".to_string(), json!(42))]);
        assert_eq!(resolve_path("/items/{id}", &ctx), "/items/42");
    }

    /// An integer `Retry-After` is interpreted as seconds.
    #[test]
    fn test_parse_retry_after_valid() {
        let headers: HeaderMap = [(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static("30"),
        )]
        .into_iter()
        .collect();
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(30));
    }

    /// A missing header falls back to the 60-second default.
    #[test]
    fn test_parse_retry_after_missing_defaults_to_60() {
        let no_headers = HeaderMap::new();
        assert_eq!(parse_retry_after(&no_headers), Duration::from_secs(60));
    }
}