Skip to main content

faucet_stream/
stream.rs

1//! The main REST stream executor.
2
3use crate::config::RestStreamConfig;
4use crate::error::FaucetError;
5use crate::extract;
6use crate::pagination::{PaginationState, PaginationStyle};
7use crate::retry;
8use futures_core::Stream;
9use reqwest::Client;
10use reqwest::header::HeaderMap;
11use serde::Deserialize;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::pin::Pin;
15
16/// A configured REST API stream that handles pagination, auth, and extraction.
17pub struct RestStream {
18    config: RestStreamConfig,
19    client: Client,
20}
21
22impl RestStream {
23    /// Create a new stream from the given configuration.
24    pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
25        let mut builder = Client::builder();
26        if let Some(t) = config.timeout {
27            builder = builder.timeout(t);
28        }
29        Ok(Self {
30            config,
31            client: builder.build()?,
32        })
33    }
34
35    /// Fetch all records across all pages as raw JSON values.
36    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
37        let mut all_records = Vec::new();
38        let mut state = PaginationState::default();
39        let mut pages_fetched = 0usize;
40
41        loop {
42            if let Some(max) = self.config.max_pages
43                && pages_fetched >= max
44            {
45                tracing::warn!("max pages ({max}) reached");
46                break;
47            }
48
49            let mut params = self.config.query_params.clone();
50            self.config.pagination.apply_params(&mut params, &state);
51
52            // For LinkHeader pagination, subsequent requests use the full URL from the
53            // Link header rather than constructing from base_url + path.
54            let url_override = match &self.config.pagination {
55                PaginationStyle::LinkHeader => state.next_link.clone(),
56                _ => None,
57            };
58
59            let params_clone = params.clone();
60            let (body, resp_headers) = retry::execute_with_retry(
61                self.config.max_retries,
62                self.config.retry_backoff,
63                || self.execute_request(&params_clone, url_override.as_deref()),
64            )
65            .await?;
66
67            let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
68            let count = records.len();
69            all_records.extend(records);
70
71            let has_next =
72                self.config
73                    .pagination
74                    .advance(&body, &resp_headers, &mut state, count)?;
75            pages_fetched += 1;
76            if !has_next {
77                break;
78            }
79
80            if let Some(delay) = self.config.request_delay {
81                tokio::time::sleep(delay).await;
82            }
83        }
84
85        tracing::info!(
86            "fetched {} total records across {} page(s)",
87            all_records.len(),
88            pages_fetched
89        );
90        Ok(all_records)
91    }
92
93    /// Fetch all records and deserialize into typed structs.
94    pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
95        let values = self.fetch_all().await?;
96        values
97            .into_iter()
98            .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
99            .collect()
100    }
101
102    /// Stream records page-by-page, yielding one `Vec<Value>` per page as it arrives.
103    ///
104    /// Unlike [`fetch_all`](Self::fetch_all), this does not wait for all pages to be fetched
105    /// before returning — callers can process each page immediately.
106    ///
107    /// ```rust,no_run
108    /// use faucet_stream::{RestStream, RestStreamConfig};
109    /// use futures::StreamExt;
110    ///
111    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
112    /// let stream = RestStream::new(RestStreamConfig::new("https://api.example.com", "/items"))?;
113    /// let mut pages = stream.stream_pages();
114    /// while let Some(page) = pages.next().await {
115    ///     let records = page?;
116    ///     println!("got {} records", records.len());
117    /// }
118    /// # Ok(())
119    /// # }
120    /// ```
121    pub fn stream_pages(
122        &self,
123    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + '_>> {
124        Box::pin(async_stream::try_stream! {
125            let mut state = PaginationState::default();
126            let mut pages_fetched = 0usize;
127
128            loop {
129                if let Some(max) = self.config.max_pages
130                    && pages_fetched >= max
131                {
132                    tracing::warn!("max pages ({max}) reached");
133                    break;
134                }
135
136                let mut params = self.config.query_params.clone();
137                self.config.pagination.apply_params(&mut params, &state);
138
139                let url_override = match &self.config.pagination {
140                    PaginationStyle::LinkHeader => state.next_link.clone(),
141                    _ => None,
142                };
143
144                let params_clone = params.clone();
145                let (body, resp_headers) = retry::execute_with_retry(
146                    self.config.max_retries,
147                    self.config.retry_backoff,
148                    || self.execute_request(&params_clone, url_override.as_deref()),
149                )
150                .await?;
151
152                let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
153                let count = records.len();
154
155                yield records;
156
157                let has_next = self
158                    .config
159                    .pagination
160                    .advance(&body, &resp_headers, &mut state, count)?;
161                pages_fetched += 1;
162                if !has_next {
163                    break;
164                }
165
166                if let Some(delay) = self.config.request_delay {
167                    tokio::time::sleep(delay).await;
168                }
169            }
170        })
171    }
172
173    /// Execute a single HTTP request and return the response body and headers.
174    ///
175    /// When `url_override` is `Some`, that URL is used directly (no query params are
176    /// appended — the override URL already encodes them, as with Link header pagination).
177    async fn execute_request(
178        &self,
179        params: &HashMap<String, String>,
180        url_override: Option<&str>,
181    ) -> Result<(Value, HeaderMap), FaucetError> {
182        let use_override = url_override.is_some();
183        let url = match url_override {
184            Some(u) => u.to_string(),
185            None => format!(
186                "{}/{}",
187                self.config.base_url,
188                self.config.path.trim_start_matches('/')
189            ),
190        };
191
192        let mut headers = self.config.headers.clone();
193        self.config.auth.apply(&mut headers)?;
194
195        let mut req = self
196            .client
197            .request(self.config.method.clone(), &url)
198            .headers(headers);
199
200        if !use_override {
201            req = req.query(params);
202        }
203
204        if let Some(body) = &self.config.body {
205            req = req.json(body);
206        }
207
208        let resp = req.send().await?.error_for_status()?;
209        let resp_headers = resp.headers().clone();
210        let body: Value = resp.json().await?;
211        Ok((body, resp_headers))
212    }
213}