Skip to main content

faucet_source_rest/
stream.rs

1//! The main REST stream executor.
2
3use crate::auth::Auth;
4use crate::auth::oauth2::TokenCache;
5use crate::auth::token_endpoint::TokenEndpointCache;
6use crate::config::RestStreamConfig;
7use crate::extract;
8use crate::pagination::{PaginationState, PaginationStyle};
9use crate::retry;
10use async_trait::async_trait;
11use faucet_core::replication::{
12    ReplicationMethod, filter_incremental, max_replication_value, max_value,
13};
14use faucet_core::schema;
15use faucet_core::{AuthSpec, Credential, FaucetError, SharedAuthProvider};
16use futures_core::Stream;
17use reqwest::Client;
18use reqwest::header::HeaderMap;
19use serde::Deserialize;
20use serde_json::Value;
21use std::collections::HashMap;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::time::Duration;
25use tokio::sync::Mutex as AsyncMutex;
26
27/// A configured REST API stream that handles pagination, auth, and extraction.
28pub struct RestStream {
29    config: RestStreamConfig,
30    client: Client,
31    /// Shared OAuth2 token cache (only used when `config.auth` is `Auth::OAuth2`).
32    token_cache: TokenCache,
33    /// Shared token endpoint cache (only used when `config.auth` is `Auth::TokenEndpoint`).
34    token_endpoint_cache: TokenEndpointCache,
35    /// Optional shared auth provider. Set when `config.auth` is an
36    /// `AuthSpec::Reference` resolved by the caller (e.g. the CLI `auth:`
37    /// catalog), or injected directly by a library caller to share one token
38    /// across multiple sources. When present it takes precedence over inline
39    /// auth.
40    auth_provider: Option<SharedAuthProvider>,
41    /// Bookmark applied at runtime via
42    /// [`Source::apply_start_bookmark`](faucet_core::Source::apply_start_bookmark).
43    /// Takes precedence over `config.start_replication_value` when set.
44    runtime_start: Arc<AsyncMutex<Option<Value>>>,
45}
46
47/// Map a [`Credential`] from a shared provider onto the REST [`Auth`]
48/// representation so the existing header-application path can be reused.
49fn credential_to_auth(cred: Credential) -> Auth {
50    match cred {
51        Credential::Bearer(token) => Auth::Bearer { token },
52        Credential::Token(token) => Auth::Custom {
53            headers: std::iter::once(("Authorization".to_string(), token)).collect(),
54        },
55        Credential::Basic { username, password } => Auth::Basic { username, password },
56        Credential::Header { name, value } => Auth::Custom {
57            headers: std::iter::once((name, value)).collect(),
58        },
59    }
60}
61
62impl RestStream {
63    /// Create a new stream from the given configuration.
64    pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
65        // Validate expiry_ratio at construction time.
66        let expiry_ratio_to_validate = match &config.auth {
67            AuthSpec::Inline(Auth::OAuth2 { expiry_ratio, .. })
68            | AuthSpec::Inline(Auth::TokenEndpoint { expiry_ratio, .. }) => Some(*expiry_ratio),
69            _ => None,
70        };
71        if let Some(ratio) = expiry_ratio_to_validate
72            && (ratio <= 0.0 || ratio > 1.0)
73        {
74            return Err(FaucetError::Auth(format!(
75                "expiry_ratio must be in (0.0, 1.0], got {ratio}"
76            )));
77        }
78
79        let mut builder = Client::builder();
80        if let Some(t) = config.timeout {
81            builder = builder.timeout(t);
82        }
83        Ok(Self {
84            config,
85            client: builder.build()?,
86            token_cache: TokenCache::new(),
87            token_endpoint_cache: TokenEndpointCache::new(),
88            auth_provider: None,
89            runtime_start: Arc::new(AsyncMutex::new(None)),
90        })
91    }
92
93    /// Attach a shared [`AuthProvider`](faucet_core::AuthProvider). When set, the
94    /// provider supplies the credential for every request (taking precedence
95    /// over inline auth), so several sources can share one token with
96    /// single-flight refresh. Used by the CLI to resolve `auth: { ref }`, and by
97    /// library callers who construct one provider and inject it into many
98    /// sources.
99    pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
100        self.auth_provider = Some(provider);
101        self
102    }
103
104    /// Fetch all records across all pages as raw JSON values.
105    ///
106    /// When `partitions` are configured, the stream is executed once per
107    /// partition and all results are concatenated.
108    ///
109    /// When `replication_method` is `Incremental` and `replication_key` +
110    /// `start_replication_value` are both set, records at or before the
111    /// bookmark are filtered out.
112    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
113        if self.config.partitions.is_empty() {
114            self.fetch_partition(None, None).await
115        } else if let Some(concurrency) = self.config.partition_concurrency {
116            // Process partitions concurrently using a semaphore to limit parallelism.
117            let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
118            let mut handles = Vec::with_capacity(self.config.partitions.len());
119
120            for ctx in &self.config.partitions {
121                let permit =
122                    semaphore.clone().acquire_owned().await.map_err(|e| {
123                        FaucetError::Config(format!("semaphore acquire failed: {e}"))
124                    })?;
125                let fut = self.fetch_partition(Some(ctx), None);
126                handles.push(async move {
127                    let result = fut.await;
128                    drop(permit);
129                    result
130                });
131            }
132
133            let results = futures::future::try_join_all(handles).await?;
134            Ok(results.into_iter().flatten().collect())
135        } else {
136            let mut all_records = Vec::new();
137            for ctx in &self.config.partitions {
138                let records = self.fetch_partition(Some(ctx), None).await?;
139                all_records.extend(records);
140            }
141            Ok(all_records)
142        }
143    }
144
145    /// Fetch all records and deserialize into typed structs.
146    pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
147        let values = self.fetch_all().await?;
148        values
149            .into_iter()
150            .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
151            .collect()
152    }
153
154    /// Infer a JSON Schema for this stream's records.
155    ///
156    /// If a `schema` is already set on the config, it is returned immediately
157    /// without making any HTTP requests.
158    ///
159    /// Otherwise the stream fetches up to `schema_sample_size` records
160    /// (respecting `max_pages`) and derives a JSON Schema from them.  Fields
161    /// that are absent in some records, or that carry a `null` value, are
162    /// marked as nullable (`["<type>", "null"]`).
163    ///
164    /// Set `schema_sample_size` to `0` to sample all available records.
165    pub async fn infer_schema(&self) -> Result<Value, FaucetError> {
166        if let Some(ref s) = self.config.schema {
167            return Ok(s.clone());
168        }
169        let limit = match self.config.schema_sample_size {
170            0 => None,
171            n => Some(n),
172        };
173        let records = self.fetch_partition(None, limit).await?;
174        Ok(schema::infer_schema(&records))
175    }
176
177    /// Fetch all records in incremental mode, returning the records along with
178    /// the maximum value of `replication_key` observed across those records.
179    ///
180    /// The returned bookmark should be persisted by the caller and passed back
181    /// as `start_replication_value` on the next run.
182    ///
183    /// If no `replication_key` is configured, this behaves identically to
184    /// [`fetch_all`](Self::fetch_all) and the bookmark is `None`.
185    pub async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
186        let records = self.fetch_all().await?;
187        let bookmark = self
188            .config
189            .replication_key
190            .as_deref()
191            .and_then(|key| max_replication_value(&records, key))
192            .cloned();
193        Ok((records, bookmark))
194    }
195
196    /// Stream API pages without buffering the full result set.
197    ///
198    /// This is a thin convenience wrapper around the
199    /// [`Source::stream_pages`](faucet_core::Source::stream_pages) trait
200    /// method — it discards bookmarks and yields one `Vec<Value>` per
201    /// upstream API page. Use the trait method directly if you need
202    /// per-page bookmarks for incremental replication.
203    ///
204    /// Note: partitions are not supported by `stream_pages`. Use `fetch_all`
205    /// for multi-partition streams.
206    ///
207    /// ```rust,no_run
208    /// use faucet_source_rest::{RestStream, RestStreamConfig};
209    /// use futures::StreamExt;
210    ///
211    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
212    /// let stream = RestStream::new(RestStreamConfig::new("https://api.example.com", "/items"))?;
213    /// let mut pages = stream.stream_pages();
214    /// while let Some(page) = pages.next().await {
215    ///     let records = page?;
216    ///     println!("got {} records", records.len());
217    /// }
218    /// # Ok(())
219    /// # }
220    /// ```
221    pub fn stream_pages(
222        &self,
223    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
224        let mut inner = self.stream_pages_inner(None);
225        Box::pin(async_stream::try_stream! {
226            loop {
227                let page = std::future::poll_fn(|cx| inner.as_mut().poll_next(cx)).await;
228                match page {
229                    Some(Ok(p)) => yield p.records,
230                    Some(Err(e)) => Err(e)?,
231                    None => break,
232                }
233            }
234        })
235    }
236
237    // ── Private helpers ───────────────────────────────────────────────────────
238
239    /// Core pagination loop shared by [`Source::stream_pages`] and
240    /// [`fetch_partition`](Self::fetch_partition).
241    ///
242    /// Yields one [`faucet_core::StreamPage`] per page. The final page carries
243    /// the consolidated replication bookmark (`Some(value)`); all intermediate
244    /// pages carry `None`. When `context` is `Some`, path placeholders are
245    /// substituted for partition support.
246    fn stream_pages_inner(
247        &self,
248        context: Option<&HashMap<String, Value>>,
249    ) -> Pin<Box<dyn Stream<Item = Result<faucet_core::StreamPage, FaucetError>> + Send + '_>> {
250        // Clone the context into an owned map so it can live inside the
251        // `async_stream` generator without borrowing from the caller.
252        let owned_context: Option<HashMap<String, Value>> = context.cloned();
253
254        Box::pin(async_stream::try_stream! {
255            // Resolve the effective start-bookmark once at the top of the stream.
256            // A runtime override (applied via `Source::apply_start_bookmark` —
257            // typically by the pipeline reading from a `StateStore`) takes
258            // precedence over the static config value.
259            let effective_start: Option<Value> = {
260                let guard = self.runtime_start.lock().await;
261                guard
262                    .clone()
263                    .or_else(|| self.config.start_replication_value.clone())
264            };
265
266            let mut state = PaginationState::default();
267            let mut pages_fetched = 0usize;
268            let mut running_max: Option<Value> = effective_start.clone();
269            let mut bookmark_emitted = false;
270
271            // H13 (audit #146): combining `max_pages` with incremental
272            // replication only makes safe forward progress when the API returns
273            // rows ordered ascending by the replication key. On truncation we
274            // advance the bookmark to the max key seen so far (so the next run
275            // resumes past it — without this the stream would re-read the same
276            // first `max_pages` window forever and never progress); but if the
277            // feed is unordered, unfetched later pages may hold lower keys that
278            // resuming past `running_max` would then drop. Warn loudly so the
279            // requirement is explicit rather than a silent data-loss edge.
280            if self.config.max_pages.is_some()
281                && self.config.replication_method == ReplicationMethod::Incremental
282                && self.config.replication_key.is_some()
283            {
284                tracing::warn!(
285                    "max_pages combined with incremental replication assumes the API returns rows \
286                     ordered ascending by the replication key; an unordered feed can drop unfetched \
287                     lower-key records on resume. Ensure ordering, or remove max_pages for a full \
288                     incremental sweep."
289                );
290            }
291
292            loop {
293                if let Some(max) = self.config.max_pages
294                    && pages_fetched >= max
295                {
296                    tracing::warn!("max pages ({max}) reached");
297                    break;
298                }
299
300                let mut params = self.config.query_params.clone();
301                self.config.pagination.apply_params(&mut params, &state);
302
303                let url_override = match &self.config.pagination {
304                    PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
305                        state.next_link.clone()
306                    }
307                    _ => None,
308                };
309
310                let params_clone = params.clone();
311                let ctx_ref = owned_context.as_ref();
312                let is_first_page = pages_fetched == 0;
313                let (body, resp_headers) = retry::execute_with_retry(
314                    self.config.max_retries,
315                    self.config.retry_backoff,
316                    || {
317                        self.execute_request(
318                            &params_clone,
319                            url_override.as_deref(),
320                            ctx_ref,
321                            is_first_page,
322                        )
323                    },
324                )
325                .await?;
326
327                let raw_records =
328                    extract::extract_records(&body, self.config.records_path.as_deref())?;
329                let raw_count = raw_records.len();
330
331                let records =
332                    if self.config.replication_method == ReplicationMethod::Incremental {
333                        if let (Some(key), Some(start)) =
334                            (&self.config.replication_key, effective_start.as_ref())
335                        {
336                            filter_incremental(raw_records, key, start)
337                        } else {
338                            raw_records
339                        }
340                    } else {
341                        raw_records
342                    };
343
344                // Track the running max replication value across pages so the
345                // final page can carry the consolidated bookmark.
346                if self.config.replication_method == ReplicationMethod::Incremental
347                    && let Some(key) = self.config.replication_key.as_deref()
348                        && let Some(page_max) = max_replication_value(&records, key) {
349                            let page_max = page_max.clone();
350                            running_max = Some(match running_max.take() {
351                                Some(prev) => max_value(prev, page_max),
352                                None => page_max,
353                            });
354                        }
355
356                // Advance pagination state to learn whether there is a next
357                // page BEFORE yielding the current one. This way the bookmark
358                // is only attached to pages where `has_next == false`, and we
359                // never pre-fetch the next page just to classify the current
360                // one as "final" (which would prevent early exit in callers
361                // such as `fetch_partition` with `max_records`).
362                let has_next = self
363                    .config
364                    .pagination
365                    .advance(&body, &resp_headers, &mut state, raw_count)?;
366                pages_fetched += 1;
367
368                if has_next {
369                    // Intermediate page — yield without bookmark so the
370                    // pipeline does not persist a partial checkpoint.
371                    yield faucet_core::StreamPage { records, bookmark: None };
372                } else {
373                    // Final page — attach the consolidated bookmark.
374                    bookmark_emitted = running_max.is_some();
375                    yield faucet_core::StreamPage {
376                        records,
377                        bookmark: running_max.clone(),
378                    };
379                    break;
380                }
381
382                if let Some(delay) = self.config.request_delay {
383                    tokio::time::sleep(delay).await;
384                }
385            }
386
387            // Trailing checkpoint: if the loop exited without carrying the
388            // bookmark on a real page (e.g. via max_pages truncation, or with
389            // zero pages fetched and a seeded start bookmark), emit one empty
390            // page carrying the consolidated bookmark so the pipeline still
391            // persists incremental progress and the next run resumes from here.
392            // (Safe forward progress under max_pages assumes ascending order by
393            // the replication key — see the warning emitted above, audit #146 H13.)
394            if !bookmark_emitted && running_max.is_some() {
395                yield faucet_core::StreamPage {
396                    records: Vec::new(),
397                    bookmark: running_max,
398                };
399            }
400        })
401    }
402
403    /// Run the full pagination loop for a single partition context.
404    ///
405    /// `max_records`: when `Some(n)`, stop collecting after `n` records
406    /// (used for schema sampling).
407    async fn fetch_partition(
408        &self,
409        context: Option<&HashMap<String, Value>>,
410        max_records: Option<usize>,
411    ) -> Result<Vec<Value>, FaucetError> {
412        let mut all_records = Vec::new();
413        let mut pages_fetched = 0usize;
414        let mut pages = self.stream_pages_inner(context);
415
416        // Poll the stream without requiring StreamExt (avoids extra dependency).
417        loop {
418            let page = std::future::poll_fn(|cx: &mut std::task::Context<'_>| {
419                pages.as_mut().poll_next(cx)
420            })
421            .await;
422
423            match page {
424                Some(Ok(page)) => {
425                    pages_fetched += 1;
426                    let records = page.records;
427                    match max_records {
428                        Some(limit) => {
429                            let remaining = limit.saturating_sub(all_records.len());
430                            all_records.extend(records.into_iter().take(remaining));
431                            if all_records.len() >= limit {
432                                break;
433                            }
434                        }
435                        None => all_records.extend(records),
436                    }
437                }
438                Some(Err(e)) => return Err(e),
439                None => break,
440            }
441        }
442
443        tracing::info!(
444            stream = self.config.name.as_deref().unwrap_or("(unnamed)"),
445            records = all_records.len(),
446            pages = pages_fetched,
447            "fetch complete"
448        );
449        Ok(all_records)
450    }
451
452    /// Execute a single HTTP request and return the response body and headers.
453    ///
454    /// - When `url_override` is `Some`, that full URL is used and query params
455    ///   are **not** appended (Link header pagination encodes them in the URL).
456    /// - When `path_context` is `Some`, `{key}` placeholders in `config.path`
457    ///   are substituted with values from the context map (partition support).
458    async fn execute_request(
459        &self,
460        params: &HashMap<String, String>,
461        url_override: Option<&str>,
462        path_context: Option<&HashMap<String, Value>>,
463        is_first_page: bool,
464    ) -> Result<(Value, HeaderMap), FaucetError> {
465        let use_override = url_override.is_some();
466        let url = match url_override {
467            Some(u) => u.to_string(),
468            None => {
469                let path = match path_context {
470                    Some(ctx) => faucet_core::util::substitute_context(&self.config.path, ctx),
471                    None => self.config.path.clone(),
472                };
473                format!("{}/{}", self.config.base_url, path.trim_start_matches('/'))
474            }
475        };
476
477        // Resolve credentials to concrete auth headers. A shared auth provider
478        // (from `auth: { ref }` or injected by a library caller) takes
479        // precedence; otherwise inline OAuth2 / TokenEndpoint are resolved to a
480        // Bearer token via the per-source cache (cached until expiry, avoiding a
481        // token fetch on every request).
482        let resolved_auth = if let Some(provider) = &self.auth_provider {
483            credential_to_auth(provider.credential().await?)
484        } else {
485            match &self.config.auth {
486                AuthSpec::Inline(Auth::OAuth2 {
487                    token_url,
488                    client_id,
489                    client_secret,
490                    scopes,
491                    expiry_ratio,
492                }) => {
493                    let token = self
494                        .token_cache
495                        .get_or_refresh(
496                            &self.client,
497                            token_url,
498                            client_id,
499                            client_secret,
500                            scopes,
501                            *expiry_ratio,
502                        )
503                        .await?;
504                    Auth::Bearer { token }
505                }
506                AuthSpec::Inline(Auth::TokenEndpoint {
507                    url: token_url,
508                    method: token_method,
509                    headers: token_headers,
510                    body: token_body,
511                    token_path,
512                    expiry_path,
513                    expiry_ratio,
514                    response_validator,
515                }) => {
516                    let token = self
517                        .token_endpoint_cache
518                        .get_or_refresh(
519                            &self.client,
520                            token_url,
521                            token_method,
522                            token_headers,
523                            token_body.as_ref(),
524                            token_path,
525                            expiry_path.as_deref(),
526                            *expiry_ratio,
527                            response_validator.as_ref(),
528                        )
529                        .await?;
530                    Auth::Bearer { token }
531                }
532                AuthSpec::Inline(other) => other.clone(),
533                AuthSpec::Reference(r) => {
534                    return Err(FaucetError::Auth(format!(
535                        "auth references provider '{}' but no provider was supplied; \
536                         set one via the CLI `auth:` catalog or `with_auth_provider`",
537                        r.name
538                    )));
539                }
540            }
541        };
542
543        let mut headers = self.config.headers.clone();
544        resolved_auth.apply(&mut headers)?;
545
546        let mut req = self
547            .client
548            .request(self.config.method.clone(), &url)
549            .headers(headers);
550
551        if !use_override {
552            // When parent context is available, substitute {placeholders} in
553            // query param values so child sources can be parameterised.
554            if let Some(ctx) = path_context {
555                let substituted: HashMap<String, String> = params
556                    .iter()
557                    .map(|(k, v)| (k.clone(), faucet_core::util::substitute_context(v, ctx)))
558                    .collect();
559                req = req.query(&substituted.iter().collect::<Vec<_>>());
560            } else {
561                req = req.query(params);
562            }
563        }
564
565        // ApiKeyQuery: inject the API key as a query parameter.
566        if let AuthSpec::Inline(Auth::ApiKeyQuery { param, value }) = &self.config.auth {
567            req = req.query(&[(param.as_str(), value.as_str())]);
568        }
569
570        if let Some(body) = &self.config.body {
571            // Substitute context into body string values when available.
572            if let Some(ctx) = path_context {
573                let body_str = body.to_string();
574                let substituted = faucet_core::util::substitute_context(&body_str, ctx);
575                let substituted_value: Value =
576                    serde_json::from_str(&substituted).unwrap_or(Value::String(substituted));
577                req = req.json(&substituted_value);
578            } else {
579                req = req.json(body);
580            }
581        }
582
583        let resp = req.send().await?;
584        let status = resp.status();
585
586        // 429 Too Many Requests: honour Retry-After before retrying.
587        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
588            let wait = parse_retry_after(resp.headers());
589            return Err(FaucetError::RateLimited(wait));
590        }
591
592        // Tolerated errors: treat as an empty page ONLY on the first request,
593        // where they legitimately mean "this resource is absent/empty". Mid-
594        // pagination, an empty page makes every pagination style read "last
595        // page" and stop, silently dropping every remaining page as a
596        // "successful" run (#78/#7). There we fall through to the real error
597        // path: the retry executor retries 5xx, and a persistent error fails
598        // loudly instead of truncating the stream.
599        if is_first_page && self.config.tolerated_http_errors.contains(&status.as_u16()) {
600            tracing::debug!(
601                status = status.as_u16(),
602                "tolerated HTTP error on first request; treating as empty page"
603            );
604            return Ok((Value::Array(vec![]), HeaderMap::new()));
605        }
606        if !is_first_page && self.config.tolerated_http_errors.contains(&status.as_u16()) {
607            tracing::warn!(
608                status = status.as_u16(),
609                "tolerated HTTP error mid-pagination; surfacing as an error to avoid \
610                 silently truncating the stream"
611            );
612        }
613
614        // For non-success responses, capture the body for debugging before
615        // returning the error. This gives callers (and logs) the server's
616        // error message rather than just a status code.
617        if !status.is_success() {
618            let resp_url = resp.url().to_string();
619            let body_text = resp.text().await.unwrap_or_default();
620            // Truncate very long error bodies to avoid bloating logs/errors.
621            let truncated = if body_text.len() > 1024 {
622                // Find a safe UTF-8 boundary at or before 1024 bytes.
623                let end = body_text.floor_char_boundary(1024);
624                format!("{}...(truncated)", &body_text[..end])
625            } else {
626                body_text
627            };
628            return Err(FaucetError::HttpStatus {
629                status: status.as_u16(),
630                url: resp_url,
631                body: truncated,
632            });
633        }
634
635        let resp_headers = resp.headers().clone();
636
637        // A 204 No Content — or any 2xx with an empty / whitespace-only body —
638        // carries no JSON to parse. `resp.json()` on such a response yields a
639        // non-retriable decode error ("EOF while parsing a value") that aborts
640        // the run; treat it as an empty page ("no data") instead (#146 M10). A
641        // non-empty body that isn't valid JSON still surfaces as a parse error.
642        if status == reqwest::StatusCode::NO_CONTENT {
643            return Ok((Value::Array(vec![]), resp_headers));
644        }
645        let bytes = resp.bytes().await?;
646        if bytes.iter().all(u8::is_ascii_whitespace) {
647            return Ok((Value::Array(vec![]), resp_headers));
648        }
649        let body: Value = serde_json::from_slice(&bytes)?;
650        Ok((body, resp_headers))
651    }
652}
653
654/// Parse the `Retry-After` header. RFC 7231 permits **either** delta-seconds
655/// **or** an HTTP-date; we honour both. An HTTP-date in the past yields a zero
656/// wait (retry now). Falls back to 60 s only when the header is absent or in
657/// neither form.
658fn parse_retry_after(headers: &HeaderMap) -> Duration {
659    const DEFAULT: Duration = Duration::from_secs(60);
660    let Some(raw) = headers
661        .get(reqwest::header::RETRY_AFTER)
662        .and_then(|v| v.to_str().ok())
663        .map(str::trim)
664    else {
665        return DEFAULT;
666    };
667    // delta-seconds form.
668    if let Ok(secs) = raw.parse::<u64>() {
669        return Duration::from_secs(secs);
670    }
671    // HTTP-date form (IMF-fixdate / RFC 850 / asctime).
672    if let Ok(when) = httpdate::parse_http_date(raw) {
673        return when
674            .duration_since(std::time::SystemTime::now())
675            .unwrap_or(Duration::ZERO);
676    }
677    DEFAULT
678}
679
680#[async_trait]
681impl faucet_core::Source for RestStream {
682    async fn fetch_with_context(
683        &self,
684        context: &std::collections::HashMap<String, serde_json::Value>,
685    ) -> Result<Vec<Value>, FaucetError> {
686        if context.is_empty() {
687            // No parent context — use normal fetch_all with partitions
688            RestStream::fetch_all(self).await
689        } else if self.config.partitions.is_empty() {
690            // Parent context, no partitions — use context directly as partition context
691            self.fetch_partition(Some(context), None).await
692        } else {
693            // Both parent context and partitions — merge context into each partition
694            let mut all_records = Vec::new();
695            for partition in &self.config.partitions {
696                let mut merged = context.clone();
697                merged.extend(partition.iter().map(|(k, v)| (k.clone(), v.clone())));
698                all_records.extend(self.fetch_partition(Some(&merged), None).await?);
699            }
700            Ok(all_records)
701        }
702    }
703
704    async fn fetch_with_context_incremental(
705        &self,
706        context: &std::collections::HashMap<String, serde_json::Value>,
707    ) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
708        let records = self.fetch_with_context(context).await?;
709        let bookmark = self
710            .config
711            .replication_key
712            .as_deref()
713            .and_then(|key| faucet_core::replication::max_replication_value(&records, key))
714            .cloned();
715        Ok((records, bookmark))
716    }
717
718    fn connector_name(&self) -> &'static str {
719        "rest"
720    }
721
722    fn config_schema(&self) -> serde_json::Value {
723        serde_json::to_value(faucet_core::schema_for!(RestStreamConfig))
724            .expect("schema serialization")
725    }
726
727    fn state_key(&self) -> Option<String> {
728        self.config.state_key.clone()
729    }
730
731    fn stream_pages<'a>(
732        &'a self,
733        context: &'a HashMap<String, Value>,
734        _batch_size: usize,
735    ) -> Pin<Box<dyn Stream<Item = Result<faucet_core::StreamPage, FaucetError>> + Send + 'a>> {
736        // RestStream chunks by upstream-API page boundaries, not by an
737        // in-memory `batch_size` knob. The arg is accepted for trait
738        // conformance and reserved for a future `page_size` mapping.
739        self.stream_pages_inner(Some(context))
740    }
741
742    async fn apply_start_bookmark(&self, bookmark: Value) -> Result<(), FaucetError> {
743        *self.runtime_start.lock().await = Some(bookmark);
744        Ok(())
745    }
746}
747
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Builds a minimal OAuth2-authenticated config with the given expiry
    // ratio. A macro (rather than a fn) lets the ratio literal take whatever
    // numeric type the `Auth::OAuth2` field declares.
    macro_rules! oauth2_config {
        ($ratio:expr) => {
            RestStreamConfig::new("https://example.com", "/data").auth(Auth::OAuth2 {
                token_url: "https://auth.example.com/token".into(),
                client_id: "id".into(),
                client_secret: "secret".into(),
                scopes: vec![],
                expiry_ratio: $ratio,
            })
        };
    }

    /// A header map containing only `Retry-After: <value>`.
    fn retry_after_headers(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_str(value).unwrap(),
        );
        headers
    }

    /// A substitution context holding a single `key -> value` entry.
    fn single_ctx(key: &str, value: Value) -> HashMap<String, Value> {
        HashMap::from([(key.to_string(), value)])
    }

    #[test]
    fn test_substitute_context_substitutes_placeholders() {
        let ctx = HashMap::from([
            ("org_id".to_string(), json!("acme")),
            ("repo".to_string(), json!("myrepo")),
        ]);
        assert_eq!(
            faucet_core::util::substitute_context("/orgs/{org_id}/repos/{repo}/issues", &ctx),
            "/orgs/acme/repos/myrepo/issues"
        );
    }

    #[test]
    fn test_substitute_context_no_placeholders() {
        let ctx = HashMap::new();
        assert_eq!(
            faucet_core::util::substitute_context("/api/users", &ctx),
            "/api/users"
        );
    }

    #[test]
    fn test_substitute_context_numeric_value() {
        let ctx = single_ctx("id", json!(42));
        assert_eq!(
            faucet_core::util::substitute_context("/items/{id}", &ctx),
            "/items/42"
        );
    }

    #[test]
    fn test_parse_retry_after_valid() {
        let headers = retry_after_headers("30");
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(30));
    }

    #[test]
    fn test_parse_retry_after_missing_defaults_to_60() {
        let headers = HeaderMap::new();
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(60));
    }

    #[test]
    fn test_parse_retry_after_non_numeric_defaults_to_60() {
        let headers = retry_after_headers("not-a-number");
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(60));
    }

    #[test]
    fn test_parse_retry_after_http_date() {
        // The HTTP-date form is an RFC 7231 alternative to delta-seconds.
        let target = std::time::SystemTime::now() + Duration::from_secs(7200);
        let headers = retry_after_headers(&httpdate::fmt_http_date(target));
        let d = parse_retry_after(&headers);
        // About two hours out; the 60 s fallback must not be used.
        assert!(
            d > Duration::from_secs(3600),
            "expected ~2h from HTTP-date, got {d:?}"
        );
        assert!(
            d <= Duration::from_secs(7200),
            "should not exceed the target instant, got {d:?}"
        );
    }

    #[test]
    fn test_parse_retry_after_past_http_date_is_zero() {
        // An instant already in the past means "retry now", not the fallback.
        let past = std::time::SystemTime::now() - Duration::from_secs(3600);
        let headers = retry_after_headers(&httpdate::fmt_http_date(past));
        assert_eq!(parse_retry_after(&headers), Duration::ZERO);
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_zero() {
        let result = RestStream::new(oauth2_config!(0.0));
        assert!(result.is_err());
        assert!(matches!(result, Err(FaucetError::Auth(_))));
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_negative() {
        let result = RestStream::new(oauth2_config!(-0.5));
        assert!(result.is_err());
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_above_one() {
        let result = RestStream::new(oauth2_config!(1.5));
        assert!(result.is_err());
    }

    #[test]
    fn test_new_accepts_valid_expiry_ratio() {
        let result = RestStream::new(oauth2_config!(1.0));
        assert!(result.is_ok());
    }

    #[test]
    fn test_new_with_no_auth_succeeds() {
        let result = RestStream::new(RestStreamConfig::new("https://example.com", "/data"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_new_with_timeout() {
        let result = RestStream::new(
            RestStreamConfig::new("https://example.com", "/data").timeout(Duration::from_secs(10)),
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_substitute_context_missing_placeholder_unchanged() {
        let ctx = single_ctx("org", json!("acme"));
        assert_eq!(
            faucet_core::util::substitute_context("/items/{missing}", &ctx),
            "/items/{missing}"
        );
    }

    #[test]
    fn test_substitute_context_boolean_value() {
        let ctx = single_ctx("flag", json!(true));
        assert_eq!(
            faucet_core::util::substitute_context("/items/{flag}", &ctx),
            "/items/true"
        );
    }

    #[test]
    fn rest_source_connector_name_is_rest() {
        use faucet_core::Source;
        let config = RestStreamConfig::new("https://example.com", "/data");
        let source = RestStream::new(config).expect("minimal RestStream construction");
        let name = source.connector_name();
        assert_eq!(name, "rest");
    }
}