use crate::error::{io_error_with_path, FoundationError, Result};
use std::future::Future;
use std::path::Path;
use std::time::Duration;
/// Runs `operation` until it succeeds or the attempt budget is exhausted.
///
/// Between failed attempts the task sleeps with a linear backoff of
/// `delay_ms * attempt_number` milliseconds. Each failure that will be retried is logged
/// with `tracing::warn!`.
///
/// Note: despite its name, `max_retries` bounds the *total* number of attempts
/// (including the first one). A value of `0` behaves like `1`: the operation runs once.
///
/// # Errors
///
/// Returns the error produced by the final attempt once the attempt budget is used up.
pub async fn with_retry<T, F, Fut>(mut operation: F, max_retries: u32, delay_ms: u64) -> Result<T>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T>>,
{
    let mut attempts: u32 = 0;
    loop {
        match operation().await {
            Ok(value) => return Ok(value),
            Err(e) => {
                attempts += 1;
                // `attempts` never exceeds `max_retries`, so the increment cannot overflow.
                if attempts >= max_retries {
                    return Err(e);
                }
                tracing::warn!(
                    "Operation attempt {} failed, retrying... Error: {}",
                    attempts,
                    e
                );
                // Linear backoff. Saturate instead of overflowing (which panics in debug
                // builds) when callers pass very large delays.
                let backoff_ms = delay_ms.saturating_mul(u64::from(attempts));
                tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
            }
        }
    }
}
/// Thin wrapper around a configured [`reqwest::Client`].
///
/// The timeout used to configure the client is kept alongside it so it can be
/// queried later via [`HttpClient::timeout`].
pub struct HttpClient {
    /// Timeout recorded when the client was built.
    timeout: Duration,
    /// The underlying reqwest client that performs requests.
    client: reqwest::Client,
}
/// Builder for [`HttpClient`] that collects a timeout and a set of default headers.
pub struct HttpClientBuilder {
    /// Headers installed as the client's default headers when `build` is called.
    headers: reqwest::header::HeaderMap,
    /// Timeout passed to the underlying client builder.
    timeout: Duration,
}
impl HttpClientBuilder {
    /// Creates a builder with a 30-second timeout and no default headers.
    pub fn new() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            headers: reqwest::header::HeaderMap::new(),
        }
    }

    /// Sets the timeout applied to the built client.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Adds an `Authorization: Bearer <token>` default header.
    ///
    /// Any previously configured `Authorization` header is replaced. The value is
    /// marked sensitive so it is redacted from `Debug` output of the header map.
    ///
    /// # Errors
    ///
    /// Returns `FoundationError::Authentication` if the token contains characters
    /// that are not valid in an HTTP header value.
    pub fn bearer_auth(mut self, token: &str) -> Result<Self> {
        let mut value = reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
            .map_err(|e| FoundationError::Authentication(format!("Invalid auth token: {e}")))?;
        // Keep the credential out of logs/Debug dumps and out of HTTP/2 header compression tables.
        value.set_sensitive(true);
        self.headers.insert(reqwest::header::AUTHORIZATION, value);
        Ok(self)
    }

    /// Sets the `User-Agent` default header, replacing any previous value.
    ///
    /// # Errors
    ///
    /// Returns `FoundationError::InvalidInput` if `user_agent` is not a valid header value.
    pub fn user_agent(mut self, user_agent: &str) -> Result<Self> {
        let value = reqwest::header::HeaderValue::from_str(user_agent).map_err(|e| {
            FoundationError::InvalidInput(format!("Invalid user-agent header: {e}"))
        })?;
        self.headers.insert(reqwest::header::USER_AGENT, value);
        Ok(self)
    }

    /// Sets an arbitrary default header, replacing any previous value for `key`.
    ///
    /// # Errors
    ///
    /// Returns `FoundationError::InvalidInput` if `value` is not a valid header value.
    pub fn header(mut self, key: reqwest::header::HeaderName, value: &str) -> Result<Self> {
        let value = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
            FoundationError::InvalidInput(format!("Invalid header value for {key}: {e}"))
        })?;
        self.headers.insert(key, value);
        Ok(self)
    }

    /// Builds the [`HttpClient`] from the collected timeout and default headers.
    ///
    /// # Errors
    ///
    /// Returns `FoundationError::Other` if reqwest fails to build the client.
    pub fn build(self) -> Result<HttpClient> {
        let client = reqwest::Client::builder()
            .timeout(self.timeout)
            .default_headers(self.headers)
            .build()
            .map_err(|e| {
                FoundationError::Other(anyhow::anyhow!("Failed to build HTTP client: {e}"))
            })?;
        Ok(HttpClient {
            client,
            // `Duration` is `Copy`, so it is still usable after `headers` was moved out.
            timeout: self.timeout,
        })
    }
}
impl Default for HttpClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl HttpClient {
    /// Creates a client with the builder defaults (30-second timeout, no default headers).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`HttpClientBuilder::build`].
    pub fn new() -> Result<Self> {
        HttpClientBuilder::new().build()
    }

    /// Returns a fresh [`HttpClientBuilder`] for custom configuration.
    pub fn builder() -> HttpClientBuilder {
        HttpClientBuilder::new()
    }

    /// Creates a client with the given timeout and no default headers.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`HttpClientBuilder::build`].
    pub fn with_timeout(timeout: Duration) -> Result<Self> {
        HttpClientBuilder::new().timeout(timeout).build()
    }

    /// Returns the timeout this client was built with.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Sends a GET request to `url` and returns the whole response body.
    ///
    /// # Errors
    ///
    /// Returns `FoundationError::Other` if the request fails, if the response status
    /// is not a success (2xx) status, or if the body cannot be read.
    pub async fn download(&self, url: &str) -> Result<Vec<u8>> {
        let response = self.client.get(url).send().await.map_err(|e| {
            FoundationError::Other(anyhow::anyhow!("HTTP request failed for {url}: {e}"))
        })?;
        let status = response.status();
        if !status.is_success() {
            return Err(FoundationError::Other(anyhow::anyhow!(
                "HTTP request to {url} failed with status: {status}"
            )));
        }
        response.bytes().await.map(|b| b.to_vec()).map_err(|e| {
            // Include the URL, consistent with the other error messages in this method.
            FoundationError::Other(anyhow::anyhow!(
                "Failed to read response body from {url}: {e}"
            ))
        })
    }

    /// Downloads `url` and decodes the body as strict UTF-8.
    ///
    /// The response's declared charset is not consulted.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`HttpClient::download`], and returns
    /// `FoundationError::Parse` if the body is not valid UTF-8.
    pub async fn download_text(&self, url: &str) -> Result<String> {
        let bytes = self.download(url).await?;
        String::from_utf8(bytes)
            .map_err(|e| FoundationError::Parse(format!("Response body is not valid UTF-8: {e}")))
    }

    /// Downloads `url` and deserializes the body as JSON into `T`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`HttpClient::download`]. A `serde_json` error is converted
    /// via `Into` into `FoundationError` (the exact variant is defined by that conversion).
    pub async fn download_json<T>(&self, url: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let bytes = self.download(url).await?;
        serde_json::from_slice(&bytes).map_err(Into::into)
    }

    /// Downloads `url` and writes the body to `path`, creating or truncating the file.
    ///
    /// The whole body is buffered in memory before writing.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`HttpClient::download`], and returns the error produced by
    /// `io_error_with_path` if the file cannot be written.
    pub async fn download_to_file(&self, url: &str, path: impl AsRef<Path>) -> Result<()> {
        let bytes = self.download(url).await?;
        let path = path.as_ref();
        // NOTE(review): `std::fs::write` blocks the async executor thread. Consider
        // `tokio::fs::write` if the tokio "fs" feature is enabled for this crate.
        std::fs::write(path, bytes)
            .map_err(|e| io_error_with_path(e, path, "write downloaded content to"))
    }
}
impl Default for HttpClient {
fn default() -> Self {
Self::new().expect("Failed to create default HTTP client")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_http_client_creation() {
        let client = HttpClient::new().unwrap();
        assert_eq!(client.timeout(), Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_custom_timeout() {
        let client = HttpClient::with_timeout(Duration::from_secs(10)).unwrap();
        assert_eq!(client.timeout(), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_builder_default_matches_new() {
        let client = HttpClientBuilder::default().build().unwrap();
        assert_eq!(client.timeout(), Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_builder_accepts_valid_headers() {
        let client = HttpClient::builder()
            .timeout(Duration::from_secs(5))
            .user_agent("foundation-test/1.0")
            .unwrap()
            .bearer_auth("token123")
            .unwrap()
            .header(reqwest::header::ACCEPT, "application/json")
            .unwrap()
            .build()
            .unwrap();
        assert_eq!(client.timeout(), Duration::from_secs(5));
    }

    #[test]
    fn test_builder_rejects_invalid_header_values() {
        // Control characters such as '\n' are not allowed in header values.
        assert!(HttpClient::builder().user_agent("bad\nagent").is_err());
        assert!(HttpClient::builder().bearer_auth("bad\ntoken").is_err());
        assert!(HttpClient::builder()
            .header(reqwest::header::ACCEPT, "bad\nvalue")
            .is_err());
    }

    #[tokio::test]
    async fn test_with_retry_success() {
        let result = with_retry(|| async { Ok::<_, FoundationError>("success") }, 3, 10).await;
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_with_retry_fail_then_succeed() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_clone = attempts.clone();
        let result = with_retry(
            move || {
                let attempts = attempts_clone.clone();
                async move {
                    let count = attempts.fetch_add(1, Ordering::SeqCst) + 1;
                    if count < 3 {
                        Err(FoundationError::Other(anyhow::anyhow!("fail")))
                    } else {
                        Ok("success")
                    }
                }
            },
            5,
            10,
        )
        .await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_gives_up_after_max_attempts() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_clone = Arc::clone(&attempts);
        let result = with_retry(
            move || {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<&str, _>(FoundationError::Other(anyhow::anyhow!("always fails")))
                }
            },
            3,
            1,
        )
        .await;
        assert!(result.is_err());
        // `max_retries` bounds the total number of attempts, including the first one.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}