Skip to main content

faucet_stream/
stream.rs

1//! The main REST stream executor.
2
3use crate::config::RestStreamConfig;
4use crate::error::FaucetError;
5use crate::extract;
6use crate::pagination::{PaginationState, PaginationStyle};
7use crate::retry;
8use futures_core::Stream;
9use reqwest::Client;
10use reqwest::header::HeaderMap;
11use serde::Deserialize;
12use serde_json::Value;
13use std::collections::HashMap;
14use std::pin::Pin;
15
16/// A configured REST API stream that handles pagination, auth, and extraction.
17pub struct RestStream {
18    config: RestStreamConfig,
19    client: Client,
20}
21
22impl RestStream {
23    /// Create a new stream from the given configuration.
24    pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
25        let mut builder = Client::builder();
26        if let Some(t) = config.timeout {
27            builder = builder.timeout(t);
28        }
29        Ok(Self {
30            config,
31            client: builder.build()?,
32        })
33    }
34
35    /// Fetch all records across all pages as raw JSON values.
36    pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
37        let mut all_records = Vec::new();
38        let mut state = PaginationState::default();
39        let mut pages_fetched = 0usize;
40
41        loop {
42            if let Some(max) = self.config.max_pages
43                && pages_fetched >= max
44            {
45                tracing::warn!("max pages ({max}) reached");
46                break;
47            }
48
49            let mut params = self.config.query_params.clone();
50            self.config.pagination.apply_params(&mut params, &state);
51
52            // For LinkHeader pagination, subsequent requests use the full URL from the
53            // Link header rather than constructing from base_url + path.
54            let url_override = match &self.config.pagination {
55                PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
56                    state.next_link.clone()
57                }
58                _ => None,
59            };
60
61            let params_clone = params.clone();
62            let (body, resp_headers) = retry::execute_with_retry(
63                self.config.max_retries,
64                self.config.retry_backoff,
65                || self.execute_request(&params_clone, url_override.as_deref()),
66            )
67            .await?;
68
69            let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
70            let count = records.len();
71            all_records.extend(records);
72
73            let has_next =
74                self.config
75                    .pagination
76                    .advance(&body, &resp_headers, &mut state, count)?;
77            pages_fetched += 1;
78            if !has_next {
79                break;
80            }
81
82            if let Some(delay) = self.config.request_delay {
83                tokio::time::sleep(delay).await;
84            }
85        }
86
87        tracing::info!(
88            "fetched {} total records across {} page(s)",
89            all_records.len(),
90            pages_fetched
91        );
92        Ok(all_records)
93    }
94
95    /// Fetch all records and deserialize into typed structs.
96    pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
97        let values = self.fetch_all().await?;
98        values
99            .into_iter()
100            .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
101            .collect()
102    }
103
104    /// Stream records page-by-page, yielding one `Vec<Value>` per page as it arrives.
105    ///
106    /// Unlike [`fetch_all`](Self::fetch_all), this does not wait for all pages to be fetched
107    /// before returning — callers can process each page immediately.
108    ///
109    /// ```rust,no_run
110    /// use faucet_stream::{RestStream, RestStreamConfig};
111    /// use futures::StreamExt;
112    ///
113    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
114    /// let stream = RestStream::new(RestStreamConfig::new("https://api.example.com", "/items"))?;
115    /// let mut pages = stream.stream_pages();
116    /// while let Some(page) = pages.next().await {
117    ///     let records = page?;
118    ///     println!("got {} records", records.len());
119    /// }
120    /// # Ok(())
121    /// # }
122    /// ```
123    pub fn stream_pages(
124        &self,
125    ) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + '_>> {
126        Box::pin(async_stream::try_stream! {
127            let mut state = PaginationState::default();
128            let mut pages_fetched = 0usize;
129
130            loop {
131                if let Some(max) = self.config.max_pages
132                    && pages_fetched >= max
133                {
134                    tracing::warn!("max pages ({max}) reached");
135                    break;
136                }
137
138                let mut params = self.config.query_params.clone();
139                self.config.pagination.apply_params(&mut params, &state);
140
141                let url_override = match &self.config.pagination {
142                    PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
143                        state.next_link.clone()
144                    }
145                    _ => None,
146                };
147
148                let params_clone = params.clone();
149                let (body, resp_headers) = retry::execute_with_retry(
150                    self.config.max_retries,
151                    self.config.retry_backoff,
152                    || self.execute_request(&params_clone, url_override.as_deref()),
153                )
154                .await?;
155
156                let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
157                let count = records.len();
158
159                yield records;
160
161                let has_next = self
162                    .config
163                    .pagination
164                    .advance(&body, &resp_headers, &mut state, count)?;
165                pages_fetched += 1;
166                if !has_next {
167                    break;
168                }
169
170                if let Some(delay) = self.config.request_delay {
171                    tokio::time::sleep(delay).await;
172                }
173            }
174        })
175    }
176
177    /// Execute a single HTTP request and return the response body and headers.
178    ///
179    /// When `url_override` is `Some`, that URL is used directly (no query params are
180    /// appended — the override URL already encodes them, as with Link header pagination).
181    async fn execute_request(
182        &self,
183        params: &HashMap<String, String>,
184        url_override: Option<&str>,
185    ) -> Result<(Value, HeaderMap), FaucetError> {
186        let use_override = url_override.is_some();
187        let url = match url_override {
188            Some(u) => u.to_string(),
189            None => format!(
190                "{}/{}",
191                self.config.base_url,
192                self.config.path.trim_start_matches('/')
193            ),
194        };
195
196        let mut headers = self.config.headers.clone();
197        self.config.auth.apply(&mut headers)?;
198
199        let mut req = self
200            .client
201            .request(self.config.method.clone(), &url)
202            .headers(headers);
203
204        if !use_override {
205            req = req.query(params);
206        }
207
208        if let Some(body) = &self.config.body {
209            req = req.json(body);
210        }
211
212        let resp = req.send().await?.error_for_status()?;
213        let resp_headers = resp.headers().clone();
214        let body: Value = resp.json().await?;
215        Ok((body, resp_headers))
216    }
217}