1use crate::config::RestStreamConfig;
4use crate::error::FaucetError;
5use crate::extract;
6use crate::pagination::{PaginationState, PaginationStyle};
7use crate::retry;
8use reqwest::Client;
9use reqwest::header::HeaderMap;
10use serde::Deserialize;
11use serde_json::Value;
12use std::collections::HashMap;
13
/// A paginated REST data source driven by a [`RestStreamConfig`].
///
/// Holds the stream's configuration together with the reqwest [`Client`]
/// that [`RestStream::new`] builds from it (applying the configured timeout).
pub struct RestStream {
    /// Endpoint, request, pagination, auth, retry and throttling settings.
    config: RestStreamConfig,
    /// HTTP client reused for every page request issued by this stream.
    client: Client,
}
19
20impl RestStream {
21 pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
23 let mut builder = Client::builder();
24 if let Some(t) = config.timeout {
25 builder = builder.timeout(t);
26 }
27 Ok(Self {
28 config,
29 client: builder.build()?,
30 })
31 }
32
33 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
35 let mut all_records = Vec::new();
36 let mut state = PaginationState::default();
37 let mut pages_fetched = 0usize;
38
39 loop {
40 if let Some(max) = self.config.max_pages
41 && pages_fetched >= max
42 {
43 tracing::warn!("max pages ({max}) reached");
44 break;
45 }
46
47 let mut params = self.config.query_params.clone();
48 self.config.pagination.apply_params(&mut params, &state);
49
50 let url_override = match &self.config.pagination {
53 PaginationStyle::LinkHeader => state.next_link.clone(),
54 _ => None,
55 };
56
57 let params_clone = params.clone();
58 let (body, resp_headers) = retry::execute_with_retry(
59 self.config.max_retries,
60 self.config.retry_backoff,
61 || self.execute_request(¶ms_clone, url_override.as_deref()),
62 )
63 .await?;
64
65 let records = extract::extract_records(&body, self.config.records_path.as_deref())?;
66 let count = records.len();
67 all_records.extend(records);
68
69 let has_next =
70 self.config
71 .pagination
72 .advance(&body, &resp_headers, &mut state, count)?;
73 pages_fetched += 1;
74 if !has_next {
75 break;
76 }
77
78 if let Some(delay) = self.config.request_delay {
79 tokio::time::sleep(delay).await;
80 }
81 }
82
83 tracing::info!(
84 "fetched {} total records across {} page(s)",
85 all_records.len(),
86 pages_fetched
87 );
88 Ok(all_records)
89 }
90
91 pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
93 let values = self.fetch_all().await?;
94 values
95 .into_iter()
96 .map(|v| serde_json::from_value(v).map_err(FaucetError::Json))
97 .collect()
98 }
99
100 async fn execute_request(
105 &self,
106 params: &HashMap<String, String>,
107 url_override: Option<&str>,
108 ) -> Result<(Value, HeaderMap), FaucetError> {
109 let use_override = url_override.is_some();
110 let url = match url_override {
111 Some(u) => u.to_string(),
112 None => format!(
113 "{}/{}",
114 self.config.base_url,
115 self.config.path.trim_start_matches('/')
116 ),
117 };
118
119 let mut headers = self.config.headers.clone();
120 self.config.auth.apply(&mut headers)?;
121
122 let mut req = self
123 .client
124 .request(self.config.method.clone(), &url)
125 .headers(headers);
126
127 if !use_override {
128 req = req.query(params);
129 }
130
131 if let Some(body) = &self.config.body {
132 req = req.json(body);
133 }
134
135 let resp = req.send().await?.error_for_status()?;
136 let resp_headers = resp.headers().clone();
137 let body: Value = resp.json().await?;
138 Ok((body, resp_headers))
139 }
140}