use crate::auth::Auth;
use crate::auth::oauth2::TokenCache;
use crate::auth::token_endpoint::TokenEndpointCache;
use crate::config::RestStreamConfig;
use crate::extract;
use crate::pagination::{PaginationState, PaginationStyle};
use crate::retry;
use async_trait::async_trait;
use faucet_core::FaucetError;
use faucet_core::replication::{ReplicationMethod, filter_incremental, max_replication_value};
use faucet_core::schema;
use faucet_core::transform::{self, CompiledTransform};
use futures_core::Stream;
use reqwest::Client;
use reqwest::header::HeaderMap;
use serde::Deserialize;
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
/// A REST API stream: the stream configuration bundled with the HTTP client
/// and the per-stream state that `RestStream::new` derives from it.
pub struct RestStream {
/// Stream configuration passed to `new`.
config: RestStreamConfig,
/// HTTP client used for every request; built in `new` with the configured
/// timeout when one is set.
client: Client,
/// `config.transforms` compiled once in `new` and applied to every record.
compiled_transforms: Vec<CompiledTransform>,
/// Token cache consulted when the auth mode is `Auth::OAuth2`.
token_cache: TokenCache,
/// Token cache consulted when the auth mode is `Auth::TokenEndpoint`.
token_endpoint_cache: TokenEndpointCache,
}
impl RestStream {
/// Builds a stream from `config`.
///
/// Validates the auth `expiry_ratio` (for `Auth::OAuth2` and
/// `Auth::TokenEndpoint`), compiles every configured transform, and builds
/// the HTTP client, applying `config.timeout` when it is set.
///
/// # Errors
///
/// - `FaucetError::Auth` when `expiry_ratio` is NaN or outside `(0.0, 1.0]`.
/// - Whatever `transform::compile` returns for a transform that fails to
///   compile.
/// - The client builder's error if the HTTP client cannot be constructed.
pub fn new(config: RestStreamConfig) -> Result<Self, FaucetError> {
    let expiry_ratio = match &config.auth {
        Auth::OAuth2 { expiry_ratio, .. } | Auth::TokenEndpoint { expiry_ratio, .. } => {
            Some(*expiry_ratio)
        }
        _ => None,
    };
    // NaN compares false against every bound, so it has to be rejected
    // explicitly; a plain range check would silently let it through.
    if let Some(ratio) = expiry_ratio
        && (ratio.is_nan() || ratio <= 0.0 || ratio > 1.0)
    {
        return Err(FaucetError::Auth(format!(
            "expiry_ratio must be in (0.0, 1.0], got {ratio}"
        )));
    }

    // Compile transforms up front so an invalid one fails construction
    // instead of failing on the first record.
    let compiled_transforms = config
        .transforms
        .iter()
        .map(transform::compile)
        .collect::<Result<Vec<_>, _>>()?;

    let mut builder = Client::builder();
    if let Some(t) = config.timeout {
        builder = builder.timeout(t);
    }

    Ok(Self {
        config,
        client: builder.build()?,
        compiled_transforms,
        token_cache: TokenCache::new(),
        token_endpoint_cache: TokenEndpointCache::new(),
    })
}
pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
if self.config.partitions.is_empty() {
self.fetch_partition(None, None).await
} else if let Some(concurrency) = self.config.partition_concurrency {
let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(concurrency.max(1)));
let mut handles = Vec::with_capacity(self.config.partitions.len());
for ctx in &self.config.partitions {
let permit =
semaphore.clone().acquire_owned().await.map_err(|e| {
FaucetError::Config(format!("semaphore acquire failed: {e}"))
})?;
let fut = self.fetch_partition(Some(ctx), None);
handles.push(async move {
let result = fut.await;
drop(permit);
result
});
}
let results = futures::future::try_join_all(handles).await?;
Ok(results.into_iter().flatten().collect())
} else {
let mut all_records = Vec::new();
for ctx in &self.config.partitions {
let records = self.fetch_partition(Some(ctx), None).await?;
all_records.extend(records);
}
Ok(all_records)
}
}
/// Fetches every record via `fetch_all` and deserializes each one into `T`.
///
/// # Errors
///
/// Propagates any fetch error, and returns `FaucetError::Json` for the first
/// record that does not deserialize into `T`.
pub async fn fetch_all_as<T: for<'de> Deserialize<'de>>(&self) -> Result<Vec<T>, FaucetError> {
    let mut typed = Vec::new();
    for raw in self.fetch_all().await? {
        let item = serde_json::from_value::<T>(raw).map_err(FaucetError::Json)?;
        typed.push(item);
    }
    Ok(typed)
}
pub async fn infer_schema(&self) -> Result<Value, FaucetError> {
if let Some(ref s) = self.config.schema {
return Ok(s.clone());
}
let limit = match self.config.schema_sample_size {
0 => None,
n => Some(n),
};
let records = self.fetch_partition(None, limit).await?;
Ok(schema::infer_schema(&records))
}
/// Fetches every record and computes the replication bookmark.
///
/// The bookmark is the value `max_replication_value` finds for the
/// configured `replication_key` across the fetched records. It is `None`
/// when no replication key is configured or no value is found.
///
/// # Errors
///
/// Propagates any error from `fetch_all`.
pub async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
    let records = self.fetch_all().await?;

    let bookmark = match self.config.replication_key.as_deref() {
        Some(key) => max_replication_value(&records, key).cloned(),
        None => None,
    };

    Ok((records, bookmark))
}
pub fn stream_pages(
&self,
) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
self.stream_pages_inner(None)
}
/// Streams pages for one partition (or for the unpartitioned endpoint when
/// `context` is `None`).
///
/// For every page the stream:
/// 1. stops if `max_pages` pages have already been fetched,
/// 2. builds the query parameters from the configuration and the current
///    pagination state,
/// 3. performs the request through `retry::execute_with_retry`,
/// 4. extracts the records, applies the incremental filter when
///    replication is incremental and both a key and a start value are
///    configured, then applies the compiled transforms,
/// 5. yields the page and advances the pagination state, sleeping for
///    `request_delay` before the next request when one is configured.
///
/// Any request, extraction, or pagination error ends the stream with that
/// error.
fn stream_pages_inner(
    &self,
    context: Option<&HashMap<String, Value>>,
) -> Pin<Box<dyn Stream<Item = Result<Vec<Value>, FaucetError>> + Send + '_>> {
    // The returned stream is only tied to `self`'s lifetime, so it keeps its
    // own copy of the partition context rather than borrowing it.
    let owned_context: Option<HashMap<String, Value>> = context.cloned();
    Box::pin(async_stream::try_stream! {
        let mut state = PaginationState::default();
        let mut pages_fetched = 0usize;
        loop {
            if let Some(max) = self.config.max_pages
                && pages_fetched >= max
            {
                tracing::warn!("max pages ({max}) reached");
                break;
            }

            let mut params = self.config.query_params.clone();
            self.config.pagination.apply_params(&mut params, &state);

            // Link-based pagination follows the URL handed back by the
            // previous response instead of rebuilding it from the config.
            let url_override = match &self.config.pagination {
                PaginationStyle::LinkHeader | PaginationStyle::NextLinkInBody { .. } => {
                    state.next_link.clone()
                }
                _ => None,
            };

            let ctx_ref = owned_context.as_ref();
            // The closure is re-invoked for each retry attempt; it only
            // borrows `params`, so no per-page copy of the map is needed.
            let (body, resp_headers) = retry::execute_with_retry(
                self.config.max_retries,
                self.config.retry_backoff,
                || self.execute_request(&params, url_override.as_deref(), ctx_ref),
            )
            .await?;

            let raw_records =
                extract::extract_records(&body, self.config.records_path.as_deref())?;
            // Pagination is advanced with the unfiltered count so that
            // filtering records out of a page does not change paging.
            let raw_count = raw_records.len();

            let records =
                if self.config.replication_method == ReplicationMethod::Incremental {
                    if let (Some(key), Some(start)) = (
                        &self.config.replication_key,
                        &self.config.start_replication_value,
                    ) {
                        filter_incremental(raw_records, key, start)
                    } else {
                        raw_records
                    }
                } else {
                    raw_records
                };

            let records: Vec<Value> = records
                .into_iter()
                .map(|rec| transform::apply_all(rec, &self.compiled_transforms))
                .collect();
            yield records;

            let has_next = self
                .config
                .pagination
                .advance(&body, &resp_headers, &mut state, raw_count)?;
            pages_fetched += 1;
            if !has_next {
                break;
            }

            if let Some(delay) = self.config.request_delay {
                tokio::time::sleep(delay).await;
            }
        }
    })
}
/// Drains the page stream for one partition into a single vector.
///
/// `context` is the partition context (`None` for the unpartitioned
/// endpoint). When `max_records` is `Some(limit)`, collection stops as soon
/// as `limit` records have been gathered and any surplus records of the last
/// page are dropped. Logs the record and page counts on completion.
///
/// # Errors
///
/// Returns the first error yielded by the page stream.
async fn fetch_partition(
    &self,
    context: Option<&HashMap<String, Value>>,
    max_records: Option<usize>,
) -> Result<Vec<Value>, FaucetError> {
    let mut collected = Vec::new();
    let mut page_count = 0usize;
    let mut page_stream = self.stream_pages_inner(context);

    // Poll the stream directly: one `poll_next` per awaited page.
    while let Some(page) = std::future::poll_fn(|cx: &mut std::task::Context<'_>| {
        page_stream.as_mut().poll_next(cx)
    })
    .await
    {
        let records = page?;
        page_count += 1;

        if let Some(limit) = max_records {
            let remaining = limit.saturating_sub(collected.len());
            collected.extend(records.into_iter().take(remaining));
            if collected.len() >= limit {
                break;
            }
        } else {
            collected.extend(records);
        }
    }

    tracing::info!(
        stream = self.config.name.as_deref().unwrap_or("(unnamed)"),
        records = collected.len(),
        pages = page_count,
        "fetch complete"
    );
    Ok(collected)
}
/// Performs a single HTTP request and returns the parsed JSON body together
/// with the response headers.
///
/// - `params`: query parameters for this page; only sent when no
///   `url_override` is given.
/// - `url_override`: full URL to request instead of `base_url` + `path`
///   (used by link-based pagination).
/// - `path_context`: values substituted into `{placeholder}` segments of the
///   configured path.
///
/// # Errors
///
/// - `FaucetError::RateLimited` on HTTP 429, carrying the wait parsed from
///   the `Retry-After` header.
/// - `FaucetError::HttpStatus` on any other non-success status that is not in
///   `tolerated_http_errors`.
/// - Errors from token refresh, from applying auth to the headers, from
///   sending the request, or from decoding the JSON body are propagated.
async fn execute_request(
&self,
params: &HashMap<String, String>,
url_override: Option<&str>,
path_context: Option<&HashMap<String, Value>>,
) -> Result<(Value, HeaderMap), FaucetError> {
let use_override = url_override.is_some();
// NOTE(review): `base_url` is joined with a literal '/'; if it can end with
// a slash the result contains "//" — confirm whether the config normalizes it.
let url = match url_override {
Some(u) => u.to_string(),
None => {
let path = match path_context {
Some(ctx) => resolve_path(&self.config.path, ctx),
None => self.config.path.clone(),
};
format!("{}/{}", self.config.base_url, path.trim_start_matches('/'))
}
};
// Token-based auth modes are resolved to a concrete bearer token through
// their cache; every other mode is used as configured.
let resolved_auth = match &self.config.auth {
Auth::OAuth2 {
token_url,
client_id,
client_secret,
scopes,
expiry_ratio,
} => {
let token = self
.token_cache
.get_or_refresh(
&self.client,
token_url,
client_id,
client_secret,
scopes,
*expiry_ratio,
)
.await?;
Auth::Bearer(token)
}
Auth::TokenEndpoint {
url: token_url,
method: token_method,
headers: token_headers,
body: token_body,
token_path,
expiry_path,
expiry_ratio,
response_validator,
} => {
let token = self
.token_endpoint_cache
.get_or_refresh(
&self.client,
token_url,
token_method,
token_headers,
token_body.as_ref(),
token_path,
expiry_path.as_deref(),
*expiry_ratio,
response_validator.as_ref(),
)
.await?;
Auth::Bearer(token)
}
other => other.clone(),
};
let mut headers = self.config.headers.clone();
resolved_auth.apply(&mut headers)?;
let mut req = self
.client
.request(self.config.method.clone(), &url)
.headers(headers);
// An override URL is requested as given: the page's query parameters are
// not added to it (presumably the link already carries them — confirm).
if !use_override {
req = req.query(params);
}
// The API-key query parameter is added in both cases, override or not.
if let Auth::ApiKeyQuery { param, value } = &self.config.auth {
req = req.query(&[(param.as_str(), value.as_str())]);
}
if let Some(body) = &self.config.body {
req = req.json(body);
}
let resp = req.send().await?;
let status = resp.status();
// 429 is checked first, so it is reported as rate limiting even when it is
// also listed in `tolerated_http_errors`.
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
let wait = parse_retry_after(resp.headers());
return Err(FaucetError::RateLimited(wait));
}
// A tolerated status is reported as an empty JSON array with no headers.
if self.config.tolerated_http_errors.contains(&status.as_u16()) {
tracing::debug!(
status = status.as_u16(),
"tolerated HTTP error; treating as empty page"
);
return Ok((Value::Array(vec![]), HeaderMap::new()));
}
if !status.is_success() {
let resp_url = resp.url().to_string();
// A body that cannot be read is reported as an empty string.
let body_text = resp.text().await.unwrap_or_default();
// Cap the body carried in the error at 1024 bytes, cutting on a char
// boundary so the slice cannot split a multi-byte character.
let truncated = if body_text.len() > 1024 {
let end = body_text.floor_char_boundary(1024);
format!("{}...(truncated)", &body_text[..end])
} else {
body_text
};
return Err(FaucetError::HttpStatus {
status: status.as_u16(),
url: resp_url,
body: truncated,
});
}
// Headers are copied before `json()` consumes the response.
let resp_headers = resp.headers().clone();
let body: Value = resp.json().await?;
Ok((body, resp_headers))
}
}
/// Substitutes `{key}` placeholders in `path` with values from `context`.
///
/// String values are inserted verbatim; any other JSON value is inserted in
/// its JSON text form (e.g. `42`, `true`). Placeholders whose key is not in
/// `context` are left untouched.
///
/// The path is scanned once, left to right, and substituted text is never
/// re-scanned. The result is therefore independent of `HashMap` iteration
/// order even when a value itself contains something that looks like a
/// placeholder.
fn resolve_path(path: &str, context: &HashMap<String, Value>) -> String {
    let mut result = String::with_capacity(path.len());
    let mut rest = path;

    while let Some(open) = rest.find('{') {
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];

        // A placeholder is `{key}` with no further brace inside: the first
        // brace after the opening one must be the closing `}`.
        let substitution = after_open
            .find(|c: char| c == '{' || c == '}')
            .filter(|&idx| after_open[idx..].starts_with('}'))
            .and_then(|close| {
                context
                    .get(&after_open[..close])
                    .map(|value| (close, value))
            });

        match substitution {
            Some((close, value)) => {
                match value {
                    Value::String(s) => result.push_str(s),
                    other => result.push_str(&other.to_string()),
                }
                rest = &after_open[close + 1..];
            }
            None => {
                // Not a known placeholder: keep the brace literally and carry
                // on scanning right after it.
                result.push('{');
                rest = after_open;
            }
        }
    }

    result.push_str(rest);
    result
}
/// Reads the `Retry-After` header as a whole number of seconds.
///
/// Falls back to 60 seconds when the header is absent, is not valid header
/// text, or does not parse as an unsigned integer.
fn parse_retry_after(headers: &HeaderMap) -> Duration {
    const DEFAULT_WAIT: Duration = Duration::from_secs(60);

    let Some(raw) = headers.get(reqwest::header::RETRY_AFTER) else {
        return DEFAULT_WAIT;
    };

    match raw.to_str().ok().and_then(|text| text.parse::<u64>().ok()) {
        Some(seconds) => Duration::from_secs(seconds),
        None => DEFAULT_WAIT,
    }
}
/// `faucet_core::Source` implementation that forwards to the inherent
/// `RestStream` methods of the same names.
#[async_trait]
impl faucet_core::Source for RestStream {
/// Delegates to the inherent `RestStream::fetch_all`.
async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
// The `RestStream::` path resolves to the inherent method, which takes
// precedence over this trait method of the same name, so this does not recurse.
RestStream::fetch_all(self).await
}
/// Delegates to the inherent `RestStream::fetch_all_incremental`.
async fn fetch_all_incremental(&self) -> Result<(Vec<Value>, Option<Value>), FaucetError> {
RestStream::fetch_all_incremental(self).await
}
/// Returns the schema generated for `RestStreamConfig`, as a JSON value.
///
/// # Panics
///
/// Panics if the generated schema cannot be serialized to JSON.
fn config_schema(&self) -> serde_json::Value {
serde_json::to_value(faucet_core::schema_for!(RestStreamConfig))
.expect("schema serialization")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a partition context from `(key, value)` pairs.
    fn context_of(entries: &[(&str, Value)]) -> HashMap<String, Value> {
        entries
            .iter()
            .map(|(key, value)| ((*key).to_string(), value.clone()))
            .collect()
    }

    /// Builds a header map holding a single `Retry-After` header.
    fn retry_after_headers(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            reqwest::header::HeaderValue::from_static(value),
        );
        headers
    }

    /// Expands to a config using OAuth2 auth with the given `expiry_ratio`.
    macro_rules! oauth2_config {
        ($ratio:expr) => {
            RestStreamConfig::new("https://example.com", "/data").auth(Auth::OAuth2 {
                token_url: "https://auth.example.com/token".into(),
                client_id: "id".into(),
                client_secret: "secret".into(),
                scopes: vec![],
                expiry_ratio: $ratio,
            })
        };
    }

    // --- resolve_path ---

    #[test]
    fn test_resolve_path_substitutes_placeholders() {
        let ctx = context_of(&[("org_id", json!("acme")), ("repo", json!("myrepo"))]);
        let resolved = resolve_path("/orgs/{org_id}/repos/{repo}/issues", &ctx);
        assert_eq!(resolved, "/orgs/acme/repos/myrepo/issues");
    }

    #[test]
    fn test_resolve_path_no_placeholders() {
        let resolved = resolve_path("/api/users", &HashMap::new());
        assert_eq!(resolved, "/api/users");
    }

    #[test]
    fn test_resolve_path_numeric_value() {
        let ctx = context_of(&[("id", json!(42))]);
        let resolved = resolve_path("/items/{id}", &ctx);
        assert_eq!(resolved, "/items/42");
    }

    #[test]
    fn test_resolve_path_missing_placeholder_unchanged() {
        let ctx = context_of(&[("org", json!("acme"))]);
        let resolved = resolve_path("/items/{missing}", &ctx);
        assert_eq!(resolved, "/items/{missing}");
    }

    #[test]
    fn test_resolve_path_boolean_value() {
        let ctx = context_of(&[("flag", json!(true))]);
        let resolved = resolve_path("/items/{flag}", &ctx);
        assert_eq!(resolved, "/items/true");
    }

    // --- parse_retry_after ---

    #[test]
    fn test_parse_retry_after_valid() {
        let headers = retry_after_headers("30");
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(30));
    }

    #[test]
    fn test_parse_retry_after_missing_defaults_to_60() {
        let wait = parse_retry_after(&HeaderMap::new());
        assert_eq!(wait, Duration::from_secs(60));
    }

    #[test]
    fn test_parse_retry_after_non_numeric_defaults_to_60() {
        let headers = retry_after_headers("not-a-number");
        assert_eq!(parse_retry_after(&headers), Duration::from_secs(60));
    }

    // --- RestStream::new ---

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_zero() {
        let outcome = RestStream::new(oauth2_config!(0.0));
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(FaucetError::Auth(_))));
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_negative() {
        assert!(RestStream::new(oauth2_config!(-0.5)).is_err());
    }

    #[test]
    fn test_new_rejects_invalid_expiry_ratio_above_one() {
        assert!(RestStream::new(oauth2_config!(1.5)).is_err());
    }

    #[test]
    fn test_new_accepts_valid_expiry_ratio() {
        assert!(RestStream::new(oauth2_config!(1.0)).is_ok());
    }

    #[test]
    fn test_new_rejects_invalid_transform_regex() {
        let broken_transform = faucet_core::RecordTransform::RenameKeys {
            pattern: "[invalid".into(),
            replacement: "".into(),
        };
        let config =
            RestStreamConfig::new("https://example.com", "/data").add_transform(broken_transform);
        let outcome = RestStream::new(config);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(FaucetError::Transform(_))));
    }

    #[test]
    fn test_new_with_no_auth_succeeds() {
        let config = RestStreamConfig::new("https://example.com", "/data");
        assert!(RestStream::new(config).is_ok());
    }

    #[test]
    fn test_new_with_timeout() {
        let ten_seconds = Duration::from_secs(10);
        let config = RestStreamConfig::new("https://example.com", "/data").timeout(ten_seconds);
        assert!(RestStream::new(config).is_ok());
    }
}