Skip to main content

faucet_source_rest/
stream.rs

1//! The main REST stream executor.
2
3use crate::auth::Auth;
4use crate::auth::oauth2::TokenCache;
5use crate::auth::token_endpoint::TokenEndpointCache;
6use crate::config::RestStreamConfig;
7use crate::extract;
8use crate::pagination::{PaginationState, PaginationStyle};
9use crate::retry;
10use async_trait::async_trait;
11use faucet_core::FaucetError;
12use faucet_core::replication::{ReplicationMethod, filter_incremental, max_replication_value};
13use faucet_core::schema;
14use faucet_core::transform::{self, CompiledTransform};
15use futures_core::Stream;
16use reqwest::Client;
17use reqwest::header::HeaderMap;
18use serde::Deserialize;
19use serde_json::Value;
20use std::collections::HashMap;
21use std::pin::Pin;
22use std::time::Duration;
23
24/// A configured REST API stream that handles pagination, auth, and extraction.
25pub struct RestStream {
26    config: RestStreamConfig,
27    client: Client,
28    /// Pre-compiled transforms (regex patterns compiled once at construction time).
29    compiled_transforms: Vec<CompiledTransform>,
30    /// Shared OAuth2 token cache (only used when `config.auth` is `Auth::OAuth2`).
31    token_cache: TokenCache,
32    /// Shared token endpoint cache (only used when `config.auth` is `Auth::TokenEndpoint`).
33    token_endpoint_cache: TokenEndpointCache,
34}
35
36impl RestStream {
37    /// Create a new stream from the given configuration.
38    ///
39    /// Returns [`FaucetError::Transform`] immediately if any `RenameKeys`
40    /// transform contains an invalid regex pattern — fail-fast before any
41    /// HTTP requests are made.
42    pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
43        // Validate expiry_ratio at construction time.
44        let expiry_ratio_to_validate = match &config.auth {
45            Auth::OAuth2 { expiry_ratio, .. } | Auth::TokenEndpoint { expiry_ratio, .. } => {
46                Some(*expiry_ratio)
47            }
48            _ => None,
49        };
50        if let Some(ratio) = expiry_ratio_to_validate
51            && (ratio <= 0.0 || ratio > 1.0)
52        {
53            return Err(FaucetError::Auth(format!(
54                "expiry_ratio must be in (0.0, 1.0], got {ratio}"
55            )));
56        }
57
58        let compiled_transforms = config
59            .transforms
60            .iter()
61            .map(transform::compile)
62            .collect::<Result<Vec<_>, _>>()?;
63
64        let mut builder = Client::builder();
65        if let Some(t) = config.timeout {
66            builder = builder.timeout(t);
67        }
68        Ok(Self {
69            config,
70            client: builder.build()?,
71            compiled_transforms,
72            token_cache: TokenCache::new(),
73            token_endpoint_cache: TokenEndpointCache::new(),
74        })
75    }
76
77    /// Fetch all records across all pages as raw JSON values.
78    ///
79    /// When `partitions` are configured, the stream is executed once per
80    /// partition and all results are concatenated.
81    ///
82    /// When `replication_method` is `Incremental` and `replication_key` +
83    /// `start_replication_value` are both set, records at or before the
84    /// bookmark are filtered out.
85    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
86        if self.config.partitions.is_empty() {
87            self.fetch_partition(None, None).await
88        } else if let Some(concurrency) = self.config.partition_concurrency {
89            // Process partitions concurrently using a semaphore to limit parallelism.
90            let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
91            let mut handles = Vec::with_capacity(self.config.partitions.len());
92
93            for ctx in &self.config.partitions {
94                let permit =
95                    semaphore.clone().acquire_owned().await.map_err(|e| {
96                        FaucetError::Config(format!("semaphore acquire failed: {e}"))
97                    })?;
98                let fut = self.fetch_partition(Some(ctx), None);
99                handles.push(async move {
100                    let result = fut.await;
101                    drop(permit);
102                    result
103                });
104            }
105
106            let results = futures::future::try_join_all(handles).await?;
107            Ok(results.into_iter().flatten().collect())
108        } else {
109            let mut all_records = Vec::new();
110            for ctx in &self.config.partitions {
111                let records = self.fetch_partition(Some(ctx), None).await?;
112                all_records.extend(records);
113            }
114            Ok(all_records)
115        }
116    }
117
118    /// Fetch all records and deserialize into typed structs.
119    pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
120        let values = self.fetch_all().await?;
121        values
122            .into_iter()
123            .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
124            .collect()
125    }
126
127    /// Infer a JSON Schema for this stream's records.
128    ///
129    /// If a `schema` is already set on the config, it is returned immediately
130    /// without making any HTTP requests.
131    ///
132    /// Otherwise the stream fetches up to `schema_sample_size` records
133    /// (respecting `max_pages`) and derives a JSON Schema from them.  Fields
134    /// that are absent in some records, or that carry a `null` value, are
135    /// marked as nullable (`["<type>", "null"]`).
136    ///
137    /// Set `schema_sample_size` to `0` to sample all available records.
138    pub async fn infer_schema(&self) -> Result<Value, FaucetError> {
139        if let Some(ref s) = self.config.schema {
140            return Ok(s.clone());
141        }
142        let limit = match self.config.schema_sample_size {
143            0 => None,
144            n => Some(n),
145        };
146        let records = self.fetch_partition(None, limit).await?;
147        Ok(schema::infer_schema(&records))
148    }
149
150    /// Fetch all records in incremental mode, returning the records along with
151    /// the maximum value of `replication_key` observed across those records.
152    ///
153    /// The returned bookmark should be persisted by the caller and passed back
154    /// as `start_replication_value` on the next run.
155    ///
156    /// If no `replication_key` is configured, this behaves identically to
157    /// [`fetch_all`](Self::fetch_all) and the bookmark is `None`.
158    pub async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
159        let records = self.fetch_all().await?;
160        let bookmark = self
161            .config
162            .replication_key
163            .as_deref()
164            .and_then(|key| max_replication_value(&records, key))
165            .cloned();
166        Ok((records, bookmark))
167    }
168
169    /// Stream records page-by-page, yielding one `Vec<Value>` per page as it arrives.
170    ///
171    /// Unlike [`fetch_all`](Self::fetch_all), this does not wait for all pages to be fetched
172    /// before returning — callers can process each page immediately.
173    ///
174    /// Note: partitions are not supported by `stream_pages`. Use `fetch_all` for
175    /// multi-partition streams.
176    ///
177    /// ```rust,no_run
178    /// use faucet_source_rest::{RestStream, RestStreamConfig};
179    /// use futures::StreamExt;
180    ///
181    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
182    /// let stream = RestStream::new(RestStreamConfig::new("https://api.example.com", "/items"))?;
183    /// let mut pages = stream.stream_pages();
184    /// while let Some(page) = pages.next().await {
185    ///     let records = page?;
186    ///     println!("got {} records", records.len());
187    /// }
188    /// # Ok(())
189    /// # }
190    /// ```
191    pub fn stream_pages(
192        &self,
193    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
194        self.stream_pages_inner(None)
195    }
196
197    // ── Private helpers ───────────────────────────────────────────────────────
198
199    /// Core pagination loop shared by [`stream_pages`](Self::stream_pages) and
200    /// [`fetch_partition`](Self::fetch_partition).
201    ///
202    /// Yields one `Vec<Value>` per page.  When `context` is `Some`, path
203    /// placeholders are substituted for partition support.
204    fn stream_pages_inner(
205        &self,
206        context: Option<&HashMap<String, Value>>,
207    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
208        // Clone the context into an owned map so it can live inside the
209        // `async_stream` generator without borrowing from the caller.
210        let owned_context: Option<HashMap<String, Value>> = context.cloned();
211
212        Box::pin(async_stream::try_stream! {
213            let mut state = PaginationState::default();
214            let mut pages_fetched = 0usize;
215
216            loop {
217                if let Some(max) = self.config.max_pages
218                    && pages_fetched >= max
219                {
220                    tracing::warn!("max pages ({max}) reached");
221                    break;
222                }
223
224                let mut params = self.config.query_params.clone();
225                self.config.pagination.apply_params(&mut params, &state);
226
227                let url_override = match &self.config.pagination {
228                    PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
229                        state.next_link.clone()
230                    }
231                    _ => None,
232                };
233
234                let params_clone = params.clone();
235                let ctx_ref = owned_context.as_ref();
236                let (body, resp_headers) = retry::execute_with_retry(
237                    self.config.max_retries,
238                    self.config.retry_backoff,
239                    || self.execute_request(&params_clone, url_override.as_deref(), ctx_ref),
240                )
241                .await?;
242
243                let raw_records =
244                    extract::extract_records(&body, self.config.records_path.as_deref())?;
245                let raw_count = raw_records.len();
246
247                let records =
248                    if self.config.replication_method == ReplicationMethod::Incremental {
249                        if let (Some(key), Some(start)) = (
250                            &self.config.replication_key,
251                            &self.config.start_replication_value,
252                        ) {
253                            filter_incremental(raw_records, key, start)
254                        } else {
255                            raw_records
256                        }
257                    } else {
258                        raw_records
259                    };
260
261                let records: Vec<Value> = records
262                    .into_iter()
263                    .map(|rec| transform::apply_all(rec, &self.compiled_transforms))
264                    .collect();
265
266                yield records;
267
268                let has_next = self
269                    .config
270                    .pagination
271                    .advance(&body, &resp_headers, &mut state, raw_count)?;
272                pages_fetched += 1;
273                if !has_next {
274                    break;
275                }
276
277                if let Some(delay) = self.config.request_delay {
278                    tokio::time::sleep(delay).await;
279                }
280            }
281        })
282    }
283
284    /// Run the full pagination loop for a single partition context.
285    ///
286    /// `max_records`: when `Some(n)`, stop collecting after `n` records
287    /// (used for schema sampling).
288    async fn fetch_partition(
289        &self,
290        context: Option<&HashMap<String, Value>>,
291        max_records: Option<usize>,
292    ) -> Result<Vec<Value>, FaucetError> {
293        let mut all_records = Vec::new();
294        let mut pages_fetched = 0usize;
295        let mut pages = self.stream_pages_inner(context);
296
297        // Poll the stream without requiring StreamExt (avoids extra dependency).
298        loop {
299            let page = std::future::poll_fn(|cx: &mut std::task::Context<'_>| {
300                pages.as_mut().poll_next(cx)
301            })
302            .await;
303
304            match page {
305                Some(Ok(records)) => {
306                    pages_fetched += 1;
307                    match max_records {
308                        Some(limit) => {
309                            let remaining = limit.saturating_sub(all_records.len());
310                            all_records.extend(records.into_iter().take(remaining));
311                            if all_records.len() >= limit {
312                                break;
313                            }
314                        }
315                        None => all_records.extend(records),
316                    }
317                }
318                Some(Err(e)) => return Err(e),
319                None => break,
320            }
321        }
322
323        tracing::info!(
324            stream = self.config.name.as_deref().unwrap_or("(unnamed)"),
325            records = all_records.len(),
326            pages = pages_fetched,
327            "fetch complete"
328        );
329        Ok(all_records)
330    }
331
332    /// Execute a single HTTP request and return the response body and headers.
333    ///
334    /// - When `url_override` is `Some`, that full URL is used and query params
335    ///   are **not** appended (Link header pagination encodes them in the URL).
336    /// - When `path_context` is `Some`, `{key}` placeholders in `config.path`
337    ///   are substituted with values from the context map (partition support).
338    async fn execute_request(
339        &self,
340        params: &HashMap<String, String>,
341        url_override: Option<&str>,
342        path_context: Option<&HashMap<String, Value>>,
343    ) -> Result<(Value, HeaderMap), FaucetError> {
344        let use_override = url_override.is_some();
345        let url = match url_override {
346            Some(u) => u.to_string(),
347            None => {
348                let path = match path_context {
349                    Some(ctx) => resolve_path(&self.config.path, ctx),
350                    None => self.config.path.clone(),
351                };
352                format!("{}/{}", self.config.base_url, path.trim_start_matches('/'))
353            }
354        };
355
356        // Resolve OAuth2 / TokenEndpoint credentials to a Bearer token before
357        // applying auth headers. Tokens are cached and reused until they expire,
358        // avoiding a token fetch on every HTTP request.
359        let resolved_auth = match &self.config.auth {
360            Auth::OAuth2 {
361                token_url,
362                client_id,
363                client_secret,
364                scopes,
365                expiry_ratio,
366            } => {
367                let token = self
368                    .token_cache
369                    .get_or_refresh(
370                        &self.client,
371                        token_url,
372                        client_id,
373                        client_secret,
374                        scopes,
375                        *expiry_ratio,
376                    )
377                    .await?;
378                Auth::Bearer(token)
379            }
380            Auth::TokenEndpoint {
381                url: token_url,
382                method: token_method,
383                headers: token_headers,
384                body: token_body,
385                token_path,
386                expiry_path,
387                expiry_ratio,
388                response_validator,
389            } => {
390                let token = self
391                    .token_endpoint_cache
392                    .get_or_refresh(
393                        &self.client,
394                        token_url,
395                        token_method,
396                        token_headers,
397                        token_body.as_ref(),
398                        token_path,
399                        expiry_path.as_deref(),
400                        *expiry_ratio,
401                        response_validator.as_ref(),
402                    )
403                    .await?;
404                Auth::Bearer(token)
405            }
406            other => other.clone(),
407        };
408
409        let mut headers = self.config.headers.clone();
410        resolved_auth.apply(&mut headers)?;
411
412        let mut req = self
413            .client
414            .request(self.config.method.clone(), &url)
415            .headers(headers);
416
417        if !use_override {
418            req = req.query(params);
419        }
420
421        // ApiKeyQuery: inject the API key as a query parameter.
422        if let Auth::ApiKeyQuery { param, value } = &self.config.auth {
423            req = req.query(&[(param.as_str(), value.as_str())]);
424        }
425
426        if let Some(body) = &self.config.body {
427            req = req.json(body);
428        }
429
430        let resp = req.send().await?;
431        let status = resp.status();
432
433        // 429 Too Many Requests: honour Retry-After before retrying.
434        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
435            let wait = parse_retry_after(resp.headers());
436            return Err(FaucetError::RateLimited(wait));
437        }
438
439        // Tolerated errors: treat as empty page.
440        if self.config.tolerated_http_errors.contains(&status.as_u16()) {
441            tracing::debug!(
442                status = status.as_u16(),
443                "tolerated HTTP error; treating as empty page"
444            );
445            return Ok((Value::Array(vec![]), HeaderMap::new()));
446        }
447
448        // For non-success responses, capture the body for debugging before
449        // returning the error. This gives callers (and logs) the server's
450        // error message rather than just a status code.
451        if !status.is_success() {
452            let resp_url = resp.url().to_string();
453            let body_text = resp.text().await.unwrap_or_default();
454            // Truncate very long error bodies to avoid bloating logs/errors.
455            let truncated = if body_text.len() > 1024 {
456                // Find a safe UTF-8 boundary at or before 1024 bytes.
457                let end = body_text.floor_char_boundary(1024);
458                format!("{}...(truncated)", &body_text[..end])
459            } else {
460                body_text
461            };
462            return Err(FaucetError::HttpStatus {
463                status: status.as_u16(),
464                url: resp_url,
465                body: truncated,
466            });
467        }
468
469        let resp_headers = resp.headers().clone();
470        let body: Value = resp.json().await?;
471        Ok((body, resp_headers))
472    }
473}
474
475/// Substitute `{key}` placeholders in `path` with values from `context`.
476fn resolve_path(path: &str, context: &HashMap<String, Value>) -> String {
477    let mut result = path.to_string();
478    for (key, value) in context {
479        let placeholder = format!("{{{key}}}");
480        let replacement = match value {
481            Value::String(s) => s.clone(),
482            other => other.to_string(),
483        };
484        result = result.replace(&placeholder, &replacement);
485    }
486    result
487}
488
489/// Parse the `Retry-After` header as a number of seconds.
490/// Falls back to 60 s if the header is absent or unparseable.
491fn parse_retry_after(headers: &HeaderMap) -> Duration {
492    headers
493        .get(reqwest::header::RETRY_AFTER)
494        .and_then(|v| v.to_str().ok())
495        .and_then(|s| s.parse::<u64>().ok())
496        .map(Duration::from_secs)
497        .unwrap_or(Duration::from_secs(60))
498}
499
500#[async_trait]
501impl faucet_core::Source for RestStream {
502    async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
503        RestStream::fetch_all(self).await
504    }
505
506    async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
507        RestStream::fetch_all_incremental(self).await
508    }
509
510    fn config_schema(&self) -> serde_json::Value {
511        serde_json::to_value(faucet_core::schema_for!(RestStreamConfig))
512            .expect("schema serialization")
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal config whose only customisation is OAuth2 auth with `expiry_ratio`.
    fn oauth2_config(expiry_ratio: f64) -> RestStreamConfig {
        RestStreamConfig::new("https://example.com", "/data").auth(Auth::OAuth2 {
            token_url: "https://auth.example.com/token".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            scopes: vec![],
            expiry_ratio,
        })
    }

    /// Build a partition context map from `(key, value)` pairs.
    fn ctx_of(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        pairs
            .iter()
            .map(|(key, value)| ((*key).to_owned(), value.clone()))
            .collect()
    }

    /// A header map containing only `Retry-After: <value>`.
    fn retry_after_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static(value),
        );
        headers
    }

    // ── resolve_path ─────────────────────────────────────────────────────────

    #[test]
    fn test_resolve_path_substitutes_placeholders() {
        let ctx = ctx_of(&[("org_id", json!("acme")), ("repo", json!("myrepo"))]);
        assert_eq!(
            resolve_path("/orgs/{org_id}/repos/{repo}/issues", &ctx),
            "/orgs/acme/repos/myrepo/issues"
        );
    }

    #[test]
    fn test_resolve_path_no_placeholders() {
        assert_eq!(resolve_path("/api/users", &ctx_of(&[])), "/api/users");
    }

    #[test]
    fn test_resolve_path_numeric_value() {
        let ctx = ctx_of(&[("id", json!(42))]);
        assert_eq!(resolve_path("/items/{id}", &ctx), "/items/42");
    }

    #[test]
    fn test_resolve_path_missing_placeholder_unchanged() {
        let ctx = ctx_of(&[("org", json!("acme"))]);
        assert_eq!(resolve_path("/items/{missing}", &ctx), "/items/{missing}");
    }

    #[test]
    fn test_resolve_path_boolean_value() {
        let ctx = ctx_of(&[("flag", json!(true))]);
        assert_eq!(resolve_path("/items/{flag}", &ctx), "/items/true");
    }

    // ── parse_retry_after ────────────────────────────────────────────────────

    #[test]
    fn test_parse_retry_after_valid() {
        assert_eq!(
            parse_retry_after(&retry_after_headers("30")),
            Duration::from_secs(30)
        );
    }

    #[test]
    fn test_parse_retry_after_missing_defaults_to_60() {
        assert_eq!(
            parse_retry_after(&HeaderMap::new()),
            Duration::from_secs(60)
        );
    }

    #[test]
    fn test_parse_retry_after_non_numeric_defaults_to_60() {
        assert_eq!(
            parse_retry_after(&retry_after_headers("not-a-number")),
            Duration::from_secs(60)
        );
    }

    // ── RestStream::new validation ───────────────────────────────────────────

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_zero() {
        let result = RestStream::new(oauth2_config(0.0));
        assert!(result.is_err());
        assert!(matches!(result, Err(FaucetError::Auth(_))));
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_negative() {
        assert!(RestStream::new(oauth2_config(-0.5)).is_err());
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_above_one() {
        assert!(RestStream::new(oauth2_config(1.5)).is_err());
    }

    #[test]
    fn test_new_accepts_valid_expiry_ratio() {
        assert!(RestStream::new(oauth2_config(1.0)).is_ok());
    }

    #[test]
    fn test_new_rejects_invalid_transform_regex() {
        let bad_rename = faucet_core::RecordTransform::RenameKeys {
            pattern: "[invalid".into(),
            replacement: "".into(),
        };
        let config = RestStreamConfig::new("https://example.com", "/data").add_transform(bad_rename);
        let result = RestStream::new(config);
        assert!(result.is_err());
        assert!(matches!(result, Err(FaucetError::Transform(_))));
    }

    #[test]
    fn test_new_with_no_auth_succeeds() {
        let config = RestStreamConfig::new("https://example.com", "/data");
        assert!(RestStream::new(config).is_ok());
    }

    #[test]
    fn test_new_with_timeout() {
        let config =
            RestStreamConfig::new("https://example.com", "/data").timeout(Duration::from_secs(10));
        assert!(RestStream::new(config).is_ok());
    }
}