1use crate::auth::Auth;
4use crate::auth::oauth2::TokenCache;
5use crate::auth::token_endpoint::TokenEndpointCache;
6use crate::config::RestStreamConfig;
7use crate::extract;
8use crate::pagination::{PaginationState, PaginationStyle};
9use crate::retry;
10use async_trait::async_trait;
11use faucet_core::FaucetError;
12use faucet_core::replication::{ReplicationMethod, filter_incremental, max_replication_value};
13use faucet_core::schema;
14use faucet_core::transform::{self, CompiledTransform};
15use futures_core::Stream;
16use reqwest::Client;
17use reqwest::header::HeaderMap;
18use serde::Deserialize;
19use serde_json::Value;
20use std::collections::HashMap;
21use std::pin::Pin;
22use std::time::Duration;
23
/// REST source that bundles a stream's configuration with the HTTP client
/// and per-instance state built from it in [`RestStream::new`].
pub struct RestStream {
    /// Stream configuration exactly as supplied to [`RestStream::new`].
    config: RestStreamConfig,
    /// HTTP client used for data requests and passed to the token caches.
    client: Client,
    /// `config.transforms`, compiled once at construction time.
    compiled_transforms: Vec<CompiledTransform>,
    /// Token cache consulted when `config.auth` is `Auth::OAuth2`.
    token_cache: TokenCache,
    /// Token cache consulted when `config.auth` is `Auth::TokenEndpoint`.
    token_endpoint_cache: TokenEndpointCache,
}
35
36impl RestStream {
37 pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
43 let expiry_ratio_to_validate = match &config.auth {
45 Auth::OAuth2 { expiry_ratio, .. } | Auth::TokenEndpoint { expiry_ratio, .. } => {
46 Some(*expiry_ratio)
47 }
48 _ => None,
49 };
50 if let Some(ratio) = expiry_ratio_to_validate
51 && (ratio <= 0.0 || ratio > 1.0)
52 {
53 return Err(FaucetError::Auth(format!(
54 "expiry_ratio must be in (0.0, 1.0], got {ratio}"
55 )));
56 }
57
58 let compiled_transforms = config
59 .transforms
60 .iter()
61 .map(transform::compile)
62 .collect::<Result<Vec<_>, _>>()?;
63
64 let mut builder = Client::builder();
65 if let Some(t) = config.timeout {
66 builder = builder.timeout(t);
67 }
68 Ok(Self {
69 config,
70 client: builder.build()?,
71 compiled_transforms,
72 token_cache: TokenCache::new(),
73 token_endpoint_cache: TokenEndpointCache::new(),
74 })
75 }
76
77 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
86 if self.config.partitions.is_empty() {
87 self.fetch_partition(None, None).await
88 } else if let Some(concurrency) = self.config.partition_concurrency {
89 let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
91 let mut handles = Vec::with_capacity(self.config.partitions.len());
92
93 for ctx in &self.config.partitions {
94 let permit =
95 semaphore.clone().acquire_owned().await.map_err(|e| {
96 FaucetError::Config(format!("semaphore acquire failed: {e}"))
97 })?;
98 let fut = self.fetch_partition(Some(ctx), None);
99 handles.push(async move {
100 let result = fut.await;
101 drop(permit);
102 result
103 });
104 }
105
106 let results = futures::future::try_join_all(handles).await?;
107 Ok(results.into_iter().flatten().collect())
108 } else {
109 let mut all_records = Vec::new();
110 for ctx in &self.config.partitions {
111 let records = self.fetch_partition(Some(ctx), None).await?;
112 all_records.extend(records);
113 }
114 Ok(all_records)
115 }
116 }
117
118 pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
120 let values = self.fetch_all().await?;
121 values
122 .into_iter()
123 .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
124 .collect()
125 }
126
127 pub async fn infer_schema(&self) -> Result<Value, FaucetError> {
139 if let Some(ref s) = self.config.schema {
140 return Ok(s.clone());
141 }
142 let limit = match self.config.schema_sample_size {
143 0 => None,
144 n => Some(n),
145 };
146 let records = self.fetch_partition(None, limit).await?;
147 Ok(schema::infer_schema(&records))
148 }
149
150 pub async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
159 let records = self.fetch_all().await?;
160 let bookmark = self
161 .config
162 .replication_key
163 .as_deref()
164 .and_then(|key| max_replication_value(&records, key))
165 .cloned();
166 Ok((records, bookmark))
167 }
168
169 pub fn stream_pages(
192 &self,
193 ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
194 self.stream_pages_inner(None)
195 }
196
197 fn stream_pages_inner(
205 &self,
206 context: Option<&HashMap<String, Value>>,
207 ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
208 let owned_context: Option<HashMap<String, Value>> = context.cloned();
211
212 Box::pin(async_stream::try_stream! {
213 let mut state = PaginationState::default();
214 let mut pages_fetched = 0usize;
215
216 loop {
217 if let Some(max) = self.config.max_pages
218 && pages_fetched >= max
219 {
220 tracing::warn!("max pages ({max}) reached");
221 break;
222 }
223
224 let mut params = self.config.query_params.clone();
225 self.config.pagination.apply_params(&mut params, &state);
226
227 let url_override = match &self.config.pagination {
228 PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
229 state.next_link.clone()
230 }
231 _ => None,
232 };
233
234 let params_clone = params.clone();
235 let ctx_ref = owned_context.as_ref();
236 let (body, resp_headers) = retry::execute_with_retry(
237 self.config.max_retries,
238 self.config.retry_backoff,
239 || self.execute_request(¶ms_clone, url_override.as_deref(), ctx_ref),
240 )
241 .await?;
242
243 let raw_records =
244 extract::extract_records(&body, self.config.records_path.as_deref())?;
245 let raw_count = raw_records.len();
246
247 let records =
248 if self.config.replication_method == ReplicationMethod::Incremental {
249 if let (Some(key), Some(start)) = (
250 &self.config.replication_key,
251 &self.config.start_replication_value,
252 ) {
253 filter_incremental(raw_records, key, start)
254 } else {
255 raw_records
256 }
257 } else {
258 raw_records
259 };
260
261 let records: Vec<Value> = records
262 .into_iter()
263 .map(|rec| transform::apply_all(rec, &self.compiled_transforms))
264 .collect();
265
266 yield records;
267
268 let has_next = self
269 .config
270 .pagination
271 .advance(&body, &resp_headers, &mut state, raw_count)?;
272 pages_fetched += 1;
273 if !has_next {
274 break;
275 }
276
277 if let Some(delay) = self.config.request_delay {
278 tokio::time::sleep(delay).await;
279 }
280 }
281 })
282 }
283
284 async fn fetch_partition(
289 &self,
290 context: Option<&HashMap<String, Value>>,
291 max_records: Option<usize>,
292 ) -> Result<Vec<Value>, FaucetError> {
293 let mut all_records = Vec::new();
294 let mut pages_fetched = 0usize;
295 let mut pages = self.stream_pages_inner(context);
296
297 loop {
299 let page = std::future::poll_fn(|cx: &mut std::task::Context<'_>| {
300 pages.as_mut().poll_next(cx)
301 })
302 .await;
303
304 match page {
305 Some(Ok(records)) => {
306 pages_fetched += 1;
307 match max_records {
308 Some(limit) => {
309 let remaining = limit.saturating_sub(all_records.len());
310 all_records.extend(records.into_iter().take(remaining));
311 if all_records.len() >= limit {
312 break;
313 }
314 }
315 None => all_records.extend(records),
316 }
317 }
318 Some(Err(e)) => return Err(e),
319 None => break,
320 }
321 }
322
323 tracing::info!(
324 stream = self.config.name.as_deref().unwrap_or("(unnamed)"),
325 records = all_records.len(),
326 pages = pages_fetched,
327 "fetch complete"
328 );
329 Ok(all_records)
330 }
331
332 async fn execute_request(
339 &self,
340 params: &HashMap<String, String>,
341 url_override: Option<&str>,
342 path_context: Option<&HashMap<String, Value>>,
343 ) -> Result<(Value, HeaderMap), FaucetError> {
344 let use_override = url_override.is_some();
345 let url = match url_override {
346 Some(u) => u.to_string(),
347 None => {
348 let path = match path_context {
349 Some(ctx) => resolve_path(&self.config.path, ctx),
350 None => self.config.path.clone(),
351 };
352 format!("{}/{}", self.config.base_url, path.trim_start_matches('/'))
353 }
354 };
355
356 let resolved_auth = match &self.config.auth {
360 Auth::OAuth2 {
361 token_url,
362 client_id,
363 client_secret,
364 scopes,
365 expiry_ratio,
366 } => {
367 let token = self
368 .token_cache
369 .get_or_refresh(
370 &self.client,
371 token_url,
372 client_id,
373 client_secret,
374 scopes,
375 *expiry_ratio,
376 )
377 .await?;
378 Auth::Bearer(token)
379 }
380 Auth::TokenEndpoint {
381 url: token_url,
382 method: token_method,
383 headers: token_headers,
384 body: token_body,
385 token_path,
386 expiry_path,
387 expiry_ratio,
388 response_validator,
389 } => {
390 let token = self
391 .token_endpoint_cache
392 .get_or_refresh(
393 &self.client,
394 token_url,
395 token_method,
396 token_headers,
397 token_body.as_ref(),
398 token_path,
399 expiry_path.as_deref(),
400 *expiry_ratio,
401 response_validator.as_ref(),
402 )
403 .await?;
404 Auth::Bearer(token)
405 }
406 other => other.clone(),
407 };
408
409 let mut headers = self.config.headers.clone();
410 resolved_auth.apply(&mut headers)?;
411
412 let mut req = self
413 .client
414 .request(self.config.method.clone(), &url)
415 .headers(headers);
416
417 if !use_override {
418 req = req.query(params);
419 }
420
421 if let Auth::ApiKeyQuery { param, value } = &self.config.auth {
423 req = req.query(&[(param.as_str(), value.as_str())]);
424 }
425
426 if let Some(body) = &self.config.body {
427 req = req.json(body);
428 }
429
430 let resp = req.send().await?;
431 let status = resp.status();
432
433 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
435 let wait = parse_retry_after(resp.headers());
436 return Err(FaucetError::RateLimited(wait));
437 }
438
439 if self.config.tolerated_http_errors.contains(&status.as_u16()) {
441 tracing::debug!(
442 status = status.as_u16(),
443 "tolerated HTTP error; treating as empty page"
444 );
445 return Ok((Value::Array(vec![]), HeaderMap::new()));
446 }
447
448 if !status.is_success() {
452 let resp_url = resp.url().to_string();
453 let body_text = resp.text().await.unwrap_or_default();
454 let truncated = if body_text.len() > 1024 {
456 let end = body_text.floor_char_boundary(1024);
458 format!("{}...(truncated)", &body_text[..end])
459 } else {
460 body_text
461 };
462 return Err(FaucetError::HttpStatus {
463 status: status.as_u16(),
464 url: resp_url,
465 body: truncated,
466 });
467 }
468
469 let resp_headers = resp.headers().clone();
470 let body: Value = resp.json().await?;
471 Ok((body, resp_headers))
472 }
473}
474
475fn resolve_path(path: &str, context: &HashMap<String, Value>) -> String {
477 let mut result = path.to_string();
478 for (key, value) in context {
479 let placeholder = format!("{{{key}}}");
480 let replacement = match value {
481 Value::String(s) => s.clone(),
482 other => other.to_string(),
483 };
484 result = result.replace(&placeholder, &replacement);
485 }
486 result
487}
488
489fn parse_retry_after(headers: &HeaderMap) -> Duration {
492 headers
493 .get(reqwest::header::RETRY_AFTER)
494 .and_then(|v| v.to_str().ok())
495 .and_then(|s| s.parse::<u64>().ok())
496 .map(Duration::from_secs)
497 .unwrap_or(Duration::from_secs(60))
498}
499
500#[async_trait]
501impl faucet_core::Source for RestStream {
502 async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
503 RestStream::fetch_all(self).await
504 }
505
506 async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
507 RestStream::fetch_all_incremental(self).await
508 }
509
510 fn config_schema(&self) -> serde_json::Value {
511 serde_json::to_value(faucet_core::schema_for!(RestStreamConfig))
512 .expect("schema serialization")
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Expands to a config for `https://example.com/data` using OAuth2 auth
    /// with the given `expiry_ratio`. A macro keeps the literal's float type
    /// inferred from the `Auth::OAuth2` field.
    macro_rules! oauth_config {
        ($ratio:expr) => {
            RestStreamConfig::new("https://example.com", "/data").auth(Auth::OAuth2 {
                token_url: "https://auth.example.com/token".into(),
                client_id: "id".into(),
                client_secret: "secret".into(),
                scopes: vec![],
                expiry_ratio: $ratio,
            })
        };
    }

    /// Builds a header map whose only entry is `Retry-After: <value>`.
    fn retry_after_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static(value),
        );
        headers
    }

    // --- resolve_path ---

    #[test]
    fn test_resolve_path_substitutes_placeholders() {
        let ctx = HashMap::from([
            ("org_id".to_string(), json!("acme")),
            ("repo".to_string(), json!("myrepo")),
        ]);
        assert_eq!(
            resolve_path("/orgs/{org_id}/repos/{repo}/issues", &ctx),
            "/orgs/acme/repos/myrepo/issues"
        );
    }

    #[test]
    fn test_resolve_path_no_placeholders() {
        let ctx: HashMap<String, Value> = HashMap::new();
        assert_eq!(resolve_path("/api/users", &ctx), "/api/users");
    }

    #[test]
    fn test_resolve_path_numeric_value() {
        let ctx = HashMap::from([("id".to_string(), json!(42))]);
        assert_eq!(resolve_path("/items/{id}", &ctx), "/items/42");
    }

    // --- parse_retry_after ---

    #[test]
    fn test_parse_retry_after_valid() {
        let headers = retry_after_headers("30");
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(30));
    }

    #[test]
    fn test_parse_retry_after_missing_defaults_to_60() {
        let headers = HeaderMap::new();
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(60));
    }

    #[test]
    fn test_parse_retry_after_non_numeric_defaults_to_60() {
        let headers = retry_after_headers("not-a-number");
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(60));
    }

    // --- RestStream::new ---

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_zero() {
        let result = RestStream::new(oauth_config!(0.0));
        assert!(result.is_err());
        assert!(matches!(result, Err(FaucetError::Auth(_))));
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_negative() {
        assert!(RestStream::new(oauth_config!(-0.5)).is_err());
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_above_one() {
        assert!(RestStream::new(oauth_config!(1.5)).is_err());
    }

    #[test]
    fn test_new_accepts_valid_expiry_ratio() {
        assert!(RestStream::new(oauth_config!(1.0)).is_ok());
    }

    #[test]
    fn test_new_rejects_invalid_transform_regex() {
        let transform = faucet_core::RecordTransform::RenameKeys {
            pattern: "[invalid".into(),
            replacement: "".into(),
        };
        let config = RestStreamConfig::new("https://example.com", "/data").add_transform(transform);
        let result = RestStream::new(config);
        assert!(result.is_err());
        assert!(matches!(result, Err(FaucetError::Transform(_))));
    }

    #[test]
    fn test_new_with_no_auth_succeeds() {
        let config = RestStreamConfig::new("https://example.com", "/data");
        assert!(RestStream::new(config).is_ok());
    }

    #[test]
    fn test_new_with_timeout() {
        let timeout = Duration::from_secs(10);
        let config = RestStreamConfig::new("https://example.com", "/data").timeout(timeout);
        assert!(RestStream::new(config).is_ok());
    }

    // --- resolve_path edge cases ---

    #[test]
    fn test_resolve_path_missing_placeholder_unchanged() {
        let ctx = HashMap::from([("org".to_string(), json!("acme"))]);
        assert_eq!(resolve_path("/items/{missing}", &ctx), "/items/{missing}");
    }

    #[test]
    fn test_resolve_path_boolean_value() {
        let ctx = HashMap::from([("flag".to_string(), json!(true))]);
        assert_eq!(resolve_path("/items/{flag}", &ctx), "/items/true");
    }
}