Skip to main content

faucet_stream/
stream.rs

1//! The main REST stream executor.
2
3use crate::auth::Auth;
4use crate::auth::oauth2::TokenCache;
5use crate::config::RestStreamConfig;
6use crate::error::FaucetError;
7use crate::extract;
8use crate::pagination::{PaginationState, PaginationStyle};
9use crate::replication::{ReplicationMethod, filter_incremental, max_replication_value};
10use crate::retry;
11use crate::schema;
12use crate::transform::{self, CompiledTransform};
13use futures_core::Stream;
14use reqwest::Client;
15use reqwest::header::HeaderMap;
16use serde::Deserialize;
17use serde_json::Value;
18use std::collections::HashMap;
19use std::pin::Pin;
20use std::time::Duration;
21
22/// A configured REST API stream that handles pagination, auth, and extraction.
23pub struct RestStream {
24    config: RestStreamConfig,
25    client: Client,
26    /// Pre-compiled transforms (regex patterns compiled once at construction time).
27    compiled_transforms: Vec<CompiledTransform>,
28    /// Shared OAuth2 token cache (only used when `config.auth` is `Auth::OAuth2`).
29    token_cache: TokenCache,
30}
31
32impl RestStream {
33    /// Create a new stream from the given configuration.
34    ///
35    /// Returns [`FaucetError::Transform`] immediately if any `RenameKeys`
36    /// transform contains an invalid regex pattern — fail-fast before any
37    /// HTTP requests are made.
38    pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
39        // Validate OAuth2 expiry_ratio at construction time.
40        if let Auth::OAuth2 { expiry_ratio, .. } = &config.auth
41            && (*expiry_ratio <= 0.0 || *expiry_ratio > 1.0)
42        {
43            return Err(FaucetError::Auth(format!(
44                "expiry_ratio must be in (0.0, 1.0], got {expiry_ratio}"
45            )));
46        }
47
48        let compiled_transforms = config
49            .transforms
50            .iter()
51            .map(transform::compile)
52            .collect::<Result<Vec<_>, _>>()?;
53
54        let mut builder = Client::builder();
55        if let Some(t) = config.timeout {
56            builder = builder.timeout(t);
57        }
58        Ok(Self {
59            config,
60            client: builder.build()?,
61            compiled_transforms,
62            token_cache: TokenCache::new(),
63        })
64    }
65
66    /// Fetch all records across all pages as raw JSON values.
67    ///
68    /// When `partitions` are configured, the stream is executed once per
69    /// partition and all results are concatenated.
70    ///
71    /// When `replication_method` is `Incremental` and `replication_key` +
72    /// `start_replication_value` are both set, records at or before the
73    /// bookmark are filtered out.
74    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
75        if self.config.partitions.is_empty() {
76            self.fetch_partition(None, None).await
77        } else {
78            let mut all_records = Vec::new();
79            for ctx in &self.config.partitions {
80                let records = self.fetch_partition(Some(ctx), None).await?;
81                all_records.extend(records);
82            }
83            Ok(all_records)
84        }
85    }
86
87    /// Fetch all records and deserialize into typed structs.
88    pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
89        let values = self.fetch_all().await?;
90        values
91            .into_iter()
92            .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
93            .collect()
94    }
95
96    /// Infer a JSON Schema for this stream's records.
97    ///
98    /// If a `schema` is already set on the config, it is returned immediately
99    /// without making any HTTP requests.
100    ///
101    /// Otherwise the stream fetches up to `schema_sample_size` records
102    /// (respecting `max_pages`) and derives a JSON Schema from them.  Fields
103    /// that are absent in some records, or that carry a `null` value, are
104    /// marked as nullable (`["<type>", "null"]`).
105    ///
106    /// Set `schema_sample_size` to `0` to sample all available records.
107    pub async fn infer_schema(&self) -> Result<Value, FaucetError> {
108        if let Some(ref s) = self.config.schema {
109            return Ok(s.clone());
110        }
111        let limit = match self.config.schema_sample_size {
112            0 => None,
113            n => Some(n),
114        };
115        let records = self.fetch_partition(None, limit).await?;
116        Ok(schema::infer_schema(&records))
117    }
118
119    /// Fetch all records in incremental mode, returning the records along with
120    /// the maximum value of `replication_key` observed across those records.
121    ///
122    /// The returned bookmark should be persisted by the caller and passed back
123    /// as `start_replication_value` on the next run.
124    ///
125    /// If no `replication_key` is configured, this behaves identically to
126    /// [`fetch_all`](Self::fetch_all) and the bookmark is `None`.
127    pub async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
128        let records = self.fetch_all().await?;
129        let bookmark = self
130            .config
131            .replication_key
132            .as_deref()
133            .and_then(|key| max_replication_value(&records, key))
134            .cloned();
135        Ok((records, bookmark))
136    }
137
138    /// Stream records page-by-page, yielding one `Vec<Value>` per page as it arrives.
139    ///
140    /// Unlike [`fetch_all`](Self::fetch_all), this does not wait for all pages to be fetched
141    /// before returning — callers can process each page immediately.
142    ///
143    /// Note: partitions are not supported by `stream_pages`. Use `fetch_all` for
144    /// multi-partition streams.
145    ///
146    /// ```rust,no_run
147    /// use faucet_stream::{RestStream, RestStreamConfig};
148    /// use futures::StreamExt;
149    ///
150    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
151    /// let stream = RestStream::new(RestStreamConfig::new("https://api.example.com", "/items"))?;
152    /// let mut pages = stream.stream_pages();
153    /// while let Some(page) = pages.next().await {
154    ///     let records = page?;
155    ///     println!("got {} records", records.len());
156    /// }
157    /// # Ok(())
158    /// # }
159    /// ```
160    pub fn stream_pages(
161        &self,
162    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
163        self.stream_pages_inner(None)
164    }
165
166    // ── Private helpers ───────────────────────────────────────────────────────
167
168    /// Core pagination loop shared by [`stream_pages`](Self::stream_pages) and
169    /// [`fetch_partition`](Self::fetch_partition).
170    ///
171    /// Yields one `Vec<Value>` per page.  When `context` is `Some`, path
172    /// placeholders are substituted for partition support.
173    fn stream_pages_inner(
174        &self,
175        context: Option<&HashMap<String, Value>>,
176    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
177        // Clone the context into an owned map so it can live inside the
178        // `async_stream` generator without borrowing from the caller.
179        let owned_context: Option<HashMap<String, Value>> = context.cloned();
180
181        Box::pin(async_stream::try_stream! {
182            let mut state = PaginationState::default();
183            let mut pages_fetched = 0usize;
184
185            loop {
186                if let Some(max) = self.config.max_pages
187                    && pages_fetched >= max
188                {
189                    tracing::warn!("max pages ({max}) reached");
190                    break;
191                }
192
193                let mut params = self.config.query_params.clone();
194                self.config.pagination.apply_params(&mut params, &state);
195
196                let url_override = match &self.config.pagination {
197                    PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
198                        state.next_link.clone()
199                    }
200                    _ => None,
201                };
202
203                let params_clone = params.clone();
204                let ctx_ref = owned_context.as_ref();
205                let (body, resp_headers) = retry::execute_with_retry(
206                    self.config.max_retries,
207                    self.config.retry_backoff,
208                    || self.execute_request(&params_clone, url_override.as_deref(), ctx_ref),
209                )
210                .await?;
211
212                let raw_records =
213                    extract::extract_records(&body, self.config.records_path.as_deref())?;
214                let raw_count = raw_records.len();
215
216                let records =
217                    if self.config.replication_method == ReplicationMethod::Incremental {
218                        if let (Some(key), Some(start)) = (
219                            &self.config.replication_key,
220                            &self.config.start_replication_value,
221                        ) {
222                            filter_incremental(raw_records, key, start)
223                        } else {
224                            raw_records
225                        }
226                    } else {
227                        raw_records
228                    };
229
230                let records: Vec<Value> = records
231                    .into_iter()
232                    .map(|rec| transform::apply_all(rec, &self.compiled_transforms))
233                    .collect();
234
235                yield records;
236
237                let has_next = self
238                    .config
239                    .pagination
240                    .advance(&body, &resp_headers, &mut state, raw_count)?;
241                pages_fetched += 1;
242                if !has_next {
243                    break;
244                }
245
246                if let Some(delay) = self.config.request_delay {
247                    tokio::time::sleep(delay).await;
248                }
249            }
250        })
251    }
252
253    /// Run the full pagination loop for a single partition context.
254    ///
255    /// `max_records`: when `Some(n)`, stop collecting after `n` records
256    /// (used for schema sampling).
257    async fn fetch_partition(
258        &self,
259        context: Option<&HashMap<String, Value>>,
260        max_records: Option<usize>,
261    ) -> Result<Vec<Value>, FaucetError> {
262        let mut all_records = Vec::new();
263        let mut pages_fetched = 0usize;
264        let mut pages = self.stream_pages_inner(context);
265
266        // Poll the stream without requiring StreamExt (avoids extra dependency).
267        loop {
268            let page = std::future::poll_fn(|cx: &mut std::task::Context<'_>| {
269                pages.as_mut().poll_next(cx)
270            })
271            .await;
272
273            match page {
274                Some(Ok(records)) => {
275                    pages_fetched += 1;
276                    match max_records {
277                        Some(limit) => {
278                            let remaining = limit.saturating_sub(all_records.len());
279                            all_records.extend(records.into_iter().take(remaining));
280                            if all_records.len() >= limit {
281                                break;
282                            }
283                        }
284                        None => all_records.extend(records),
285                    }
286                }
287                Some(Err(e)) => return Err(e),
288                None => break,
289            }
290        }
291
292        tracing::info!(
293            stream = self.config.name.as_deref().unwrap_or("(unnamed)"),
294            records = all_records.len(),
295            pages = pages_fetched,
296            "fetch complete"
297        );
298        Ok(all_records)
299    }
300
301    /// Execute a single HTTP request and return the response body and headers.
302    ///
303    /// - When `url_override` is `Some`, that full URL is used and query params
304    ///   are **not** appended (Link header pagination encodes them in the URL).
305    /// - When `path_context` is `Some`, `{key}` placeholders in `config.path`
306    ///   are substituted with values from the context map (partition support).
307    async fn execute_request(
308        &self,
309        params: &HashMap<String, String>,
310        url_override: Option<&str>,
311        path_context: Option<&HashMap<String, Value>>,
312    ) -> Result<(Value, HeaderMap), FaucetError> {
313        let use_override = url_override.is_some();
314        let url = match url_override {
315            Some(u) => u.to_string(),
316            None => {
317                let path = match path_context {
318                    Some(ctx) => resolve_path(&self.config.path, ctx),
319                    None => self.config.path.clone(),
320                };
321                format!("{}/{}", self.config.base_url, path.trim_start_matches('/'))
322            }
323        };
324
325        // Resolve OAuth2 credentials to a Bearer token before applying auth headers.
326        // The token is cached and reused until it expires, avoiding a token
327        // fetch on every HTTP request.
328        let resolved_auth = match &self.config.auth {
329            Auth::OAuth2 {
330                token_url,
331                client_id,
332                client_secret,
333                scopes,
334                expiry_ratio,
335            } => {
336                let token = self
337                    .token_cache
338                    .get_or_refresh(
339                        &self.client,
340                        token_url,
341                        client_id,
342                        client_secret,
343                        scopes,
344                        *expiry_ratio,
345                    )
346                    .await?;
347                Auth::Bearer(token)
348            }
349            other => other.clone(),
350        };
351
352        let mut headers = self.config.headers.clone();
353        resolved_auth.apply(&mut headers)?;
354
355        let mut req = self
356            .client
357            .request(self.config.method.clone(), &url)
358            .headers(headers);
359
360        if !use_override {
361            req = req.query(params);
362        }
363
364        // ApiKeyQuery: inject the API key as a query parameter.
365        if let Auth::ApiKeyQuery { param, value } = &self.config.auth {
366            req = req.query(&[(param.as_str(), value.as_str())]);
367        }
368
369        if let Some(body) = &self.config.body {
370            req = req.json(body);
371        }
372
373        let resp = req.send().await?;
374        let status = resp.status();
375
376        // 429 Too Many Requests: honour Retry-After before retrying.
377        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
378            let wait = parse_retry_after(resp.headers());
379            return Err(FaucetError::RateLimited(wait));
380        }
381
382        // Tolerated errors: treat as empty page.
383        if self.config.tolerated_http_errors.contains(&status.as_u16()) {
384            tracing::debug!(
385                status = status.as_u16(),
386                "tolerated HTTP error; treating as empty page"
387            );
388            return Ok((Value::Array(vec![]), HeaderMap::new()));
389        }
390
391        // For non-success responses, capture the body for debugging before
392        // returning the error. This gives callers (and logs) the server's
393        // error message rather than just a status code.
394        if !status.is_success() {
395            let resp_url = resp.url().to_string();
396            let body_text = resp.text().await.unwrap_or_default();
397            // Truncate very long error bodies to avoid bloating logs/errors.
398            let truncated = if body_text.len() > 1024 {
399                // Find a safe UTF-8 boundary at or before 1024 bytes.
400                let end = body_text.floor_char_boundary(1024);
401                format!("{}...(truncated)", &body_text[..end])
402            } else {
403                body_text
404            };
405            return Err(FaucetError::HttpStatus {
406                status: status.as_u16(),
407                url: resp_url,
408                body: truncated,
409            });
410        }
411
412        let resp_headers = resp.headers().clone();
413        let body: Value = resp.json().await?;
414        Ok((body, resp_headers))
415    }
416}
417
418/// Substitute `{key}` placeholders in `path` with values from `context`.
419fn resolve_path(path: &str, context: &HashMap<String, Value>) -> String {
420    let mut result = path.to_string();
421    for (key, value) in context {
422        let placeholder = format!("{{{key}}}");
423        let replacement = match value {
424            Value::String(s) => s.clone(),
425            other => other.to_string(),
426        };
427        result = result.replace(&placeholder, &replacement);
428    }
429    result
430}
431
432/// Parse the `Retry-After` header as a number of seconds.
433/// Falls back to 60 s if the header is absent or unparseable.
434fn parse_retry_after(headers: &HeaderMap) -> Duration {
435    headers
436        .get(reqwest::header::RETRY_AFTER)
437        .and_then(|v| v.to_str().ok())
438        .and_then(|s| s.parse::<u64>().ok())
439        .map(Duration::from_secs)
440        .unwrap_or(Duration::from_secs(60))
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_resolve_path_substitutes_placeholders() {
        let mut ctx = HashMap::new();
        ctx.insert("org_id".to_string(), json!("acme"));
        ctx.insert("repo".to_string(), json!("myrepo"));
        let result = resolve_path("/orgs/{org_id}/repos/{repo}/issues", &ctx);
        assert_eq!(result, "/orgs/acme/repos/myrepo/issues");
    }

    #[test]
    fn test_resolve_path_no_placeholders() {
        let ctx = HashMap::new();
        let result = resolve_path("/api/users", &ctx);
        assert_eq!(result, "/api/users");
    }

    #[test]
    fn test_resolve_path_numeric_value() {
        let mut ctx = HashMap::new();
        ctx.insert("id".to_string(), json!(42));
        let result = resolve_path("/items/{id}", &ctx);
        assert_eq!(result, "/items/42");
    }

    #[test]
    fn test_resolve_path_leaves_unknown_and_unclosed_placeholders() {
        let mut ctx = HashMap::new();
        ctx.insert("id".to_string(), json!("7"));
        // Unknown keys are kept literally; known ones are still substituted.
        assert_eq!(resolve_path("/a/{missing}/{id}", &ctx), "/a/{missing}/7");
        // An unclosed brace is not a placeholder.
        assert_eq!(resolve_path("/items/{id", &ctx), "/items/{id");
    }

    #[test]
    fn test_parse_retry_after_valid() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static("30"),
        );
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(30));
    }

    #[test]
    fn test_parse_retry_after_missing_defaults_to_60() {
        assert_eq!(
            parse_retry_after(&HeaderMap::new()),
            Duration::from_secs(60)
        );
    }

    #[test]
    fn test_parse_retry_after_http_date_defaults_to_60() {
        // The HTTP-date form is not parsed; it falls back like a missing header.
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT"),
        );
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(60));
    }
}