use crate::{Error, Result};
use reqwest::{Client, Response};
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, trace, warn};
/// Maximum number of retry attempts after the initial request.
const DEFAULT_MAX_RETRIES: u32 = 3;
/// Backoff delay before the first retry, in milliseconds.
const DEFAULT_INITIAL_BACKOFF_MS: u64 = 100;
/// Upper bound on any single backoff delay, in milliseconds.
const DEFAULT_MAX_BACKOFF_MS: u64 = 10_000;
/// Multiplier applied to the backoff on each successive retry.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Fraction of the backoff applied as symmetric random jitter (0.0..=1.0).
const DEFAULT_JITTER_FACTOR: f64 = 0.1;
/// TCP connect timeout for the underlying HTTP client, in seconds.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 30;
/// Overall per-request timeout for the underlying HTTP client, in seconds.
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 300;
/// HTTP client for fetching content from CDN hosts, with exponential
/// backoff, jitter, and rate-limit-aware retry handling built in.
#[derive(Debug, Clone)]
pub struct CdnClient {
/// Underlying `reqwest` client (pooling, timeouts, gzip/deflate).
client: Client,
/// Number of retry attempts after the initial request.
max_retries: u32,
/// Backoff delay before the first retry, in milliseconds.
initial_backoff_ms: u64,
/// Ceiling for any single backoff delay, in milliseconds.
max_backoff_ms: u64,
/// Growth factor applied to the backoff per retry.
backoff_multiplier: f64,
/// Fraction of the backoff used as symmetric jitter, clamped to 0.0..=1.0.
jitter_factor: f64,
/// Optional `User-Agent` header value; `None` sends no custom header.
user_agent: Option<String>,
}
impl CdnClient {
pub fn new() -> Result<Self> {
let client = Client::builder()
.connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
.timeout(Duration::from_secs(DEFAULT_REQUEST_TIMEOUT_SECS))
.pool_max_idle_per_host(20) .gzip(true) .deflate(true) .build()?;
Ok(Self {
client,
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
jitter_factor: DEFAULT_JITTER_FACTOR,
user_agent: None,
})
}
/// Wraps an externally configured `reqwest::Client`, applying the default
/// retry/backoff settings and no custom `User-Agent`.
pub fn with_client(client: Client) -> Self {
Self {
client,
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
jitter_factor: DEFAULT_JITTER_FACTOR,
user_agent: None,
}
}
/// Returns a [`CdnClientBuilder`] for fine-grained configuration.
pub fn builder() -> CdnClientBuilder {
    // `Default` delegates to `CdnClientBuilder::new`.
    CdnClientBuilder::default()
}
/// Sets the number of retry attempts after the initial request.
pub fn with_max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
/// Sets the backoff delay before the first retry, in milliseconds.
pub fn with_initial_backoff_ms(mut self, initial_backoff_ms: u64) -> Self {
self.initial_backoff_ms = initial_backoff_ms;
self
}
/// Sets the ceiling for any single backoff delay, in milliseconds.
pub fn with_max_backoff_ms(mut self, max_backoff_ms: u64) -> Self {
self.max_backoff_ms = max_backoff_ms;
self
}
/// Sets the multiplicative growth factor applied per retry.
pub fn with_backoff_multiplier(mut self, backoff_multiplier: f64) -> Self {
self.backoff_multiplier = backoff_multiplier;
self
}
/// Sets the jitter fraction; out-of-range values are clamped to 0.0..=1.0.
pub fn with_jitter_factor(mut self, jitter_factor: f64) -> Self {
self.jitter_factor = jitter_factor.clamp(0.0, 1.0);
self
}
/// Sets the `User-Agent` header sent with every request.
pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
self.user_agent = Some(user_agent.into());
self
}
/// Computes the delay before retry number `attempt` (0-based): exponential
/// growth from `initial_backoff_ms`, capped at `max_backoff_ms`, with a
/// uniform random jitter of ±`jitter_factor` applied on top.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_wrap,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
pub fn calculate_backoff(&self, attempt: u32) -> Duration {
    // Exponential schedule, clamped to the configured ceiling.
    let uncapped =
        self.initial_backoff_ms as f64 * self.backoff_multiplier.powi(attempt as i32);
    let capped = uncapped.min(self.max_backoff_ms as f64);
    // Symmetric jitter in [-range, +range] de-synchronizes retry storms.
    let range = capped * self.jitter_factor;
    let offset = rand::random::<f64>() * 2.0 * range - range;
    // Negative results (possible when jitter_factor approaches 1) clamp to 0.
    Duration::from_millis((capped + offset).max(0.0) as u64)
}
/// Issues a GET to `url`, retrying with exponential backoff and jitter.
///
/// Retry policy (bounded by `max_retries`):
/// - 429 Too Many Requests: retried.
/// - 5xx server errors: retried.
/// - 404: mapped to `Error::content_not_found` using the last URL segment.
/// - other 4xx: returned immediately, never retried.
/// - connect / timeout / request transport errors: retried; any other
///   transport error is returned immediately.
async fn execute_with_retry(&self, url: &str) -> Result<Response> {
// Most recent retryable failure, surfaced if every attempt is exhausted.
let mut last_error = None;
// Attempt 0 is the initial try; 1..=max_retries are the retries.
for attempt in 0..=self.max_retries {
if attempt > 0 {
// Sleep before every retry (never before the first attempt).
let backoff = self.calculate_backoff(attempt - 1);
debug!("CDN retry attempt {} after {:?} backoff", attempt, backoff);
sleep(backoff).await;
}
debug!("CDN request to {} (attempt {})", url, attempt + 1);
let mut request = self.client.get(url);
if let Some(ref user_agent) = self.user_agent {
request = request.header("User-Agent", user_agent);
}
match request.send().await {
Ok(response) => {
trace!("Response status: {}", response.status());
let status = response.status();
if status.is_success() {
return Ok(response);
}
// Rate limited: record the server-advertised delay and retry.
if status == reqwest::StatusCode::TOO_MANY_REQUESTS
&& attempt < self.max_retries
{
// Malformed or missing Retry-After falls back to 60 seconds.
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(60);
warn!(
"Rate limited (attempt {}): retry after {} seconds",
attempt + 1,
retry_after
);
// NOTE(review): `retry_after` is parsed and logged but the next
// sleep still uses exponential backoff rather than the server's
// hint — confirm whether honoring Retry-After is intended.
last_error = Some(Error::rate_limited(retry_after));
continue;
}
// Server errors are treated as transient and retried.
if status.is_server_error() && attempt < self.max_retries {
warn!(
"Server error {} (attempt {}): will retry",
status,
attempt + 1
);
// error_for_status() is guaranteed Err for a 5xx status, so
// unwrap_err cannot panic here.
last_error = Some(Error::Http(response.error_for_status().unwrap_err()));
continue;
}
// Client errors are permanent; 404 gets a dedicated error carrying
// the hash (the final path segment of the URL).
if status.is_client_error() {
if status == reqwest::StatusCode::NOT_FOUND {
let parts: Vec<&str> = url.rsplitn(2, '/').collect();
let hash = parts.first().copied().unwrap_or("unknown");
return Err(Error::content_not_found(hash));
}
return Err(Error::Http(response.error_for_status().unwrap_err()));
}
// Reached for a 5xx on the final attempt, or any other non-success
// status class: surface as an HTTP error.
return Err(Error::Http(response.error_for_status().unwrap_err()));
}
Err(e) => {
// Only transport failures that are plausibly transient are retried.
let is_retryable = e.is_connect() || e.is_timeout() || e.is_request();
if is_retryable && attempt < self.max_retries {
warn!(
"Request failed (attempt {}): {}, will retry",
attempt + 1,
e
);
last_error = Some(Error::Http(e));
} else {
debug!(
"Request failed (attempt {}): {}, not retrying",
attempt + 1,
e
);
return Err(Error::Http(e));
}
}
}
}
// All attempts exhausted; report the last retryable failure if any.
Err(last_error.unwrap_or_else(|| Error::invalid_response("All retry attempts failed")))
}
/// Performs a GET against `url` with this client's retry policy applied.
///
/// # Errors
/// Returns the terminal error once retries are exhausted or a
/// non-retryable failure occurs.
pub async fn request(&self, url: &str) -> Result<Response> {
self.execute_with_retry(url).await
}
/// Builds the CDN URL for `hash` using the sharded layout
/// `http://{host}/{path}/{hash[0..2]}/{hash[2..4]}/{hash}`.
///
/// # Errors
/// Returns `Error::invalid_hash` when `hash` is shorter than four
/// characters or contains a non-hex character.
pub fn build_url(cdn_host: &str, path: &str, hash: &str) -> Result<String> {
    // Validate before slicing so the &hash[..] indexing below cannot panic.
    let all_hex = hash.chars().all(|c| c.is_ascii_hexdigit());
    if !all_hex || hash.len() < 4 {
        return Err(Error::invalid_hash(hash));
    }
    let (shard_a, shard_b) = (&hash[0..2], &hash[2..4]);
    Ok(format!(
        "http://{}/{}/{shard_a}/{shard_b}/{hash}",
        cdn_host,
        path.trim_matches('/')
    ))
}
/// Downloads the object identified by `hash` from `cdn_host` under `path`,
/// using the standard sharded URL layout (see `build_url`).
pub async fn download(&self, cdn_host: &str, path: &str, hash: &str) -> Result<Response> {
let url = Self::build_url(cdn_host, path, hash)?;
self.request(&url).await
}
/// Downloads a build configuration from the `{path}/config` subdirectory.
pub async fn download_build_config(
&self,
cdn_host: &str,
path: &str,
hash: &str,
) -> Result<Response> {
let config_path = format!("{}/config", path.trim_end_matches('/'));
self.download(cdn_host, &config_path, hash).await
}
/// Downloads a CDN configuration from the `{path}/config` subdirectory.
pub async fn download_cdn_config(
&self,
cdn_host: &str,
path: &str,
hash: &str,
) -> Result<Response> {
let config_path = format!("{}/config", path.trim_end_matches('/'));
self.download(cdn_host, &config_path, hash).await
}
/// Downloads a product configuration; unlike the other config helpers,
/// `config_path` is used as-is with no subdirectory appended.
pub async fn download_product_config(
&self,
cdn_host: &str,
config_path: &str,
hash: &str,
) -> Result<Response> {
self.download(cdn_host, config_path, hash).await
}
/// Downloads a key ring from the `{path}/config` subdirectory.
pub async fn download_key_ring(
&self,
cdn_host: &str,
path: &str,
hash: &str,
) -> Result<Response> {
let config_path = format!("{}/config", path.trim_end_matches('/'));
self.download(cdn_host, &config_path, hash).await
}
/// Downloads a data object from the `{path}/data` subdirectory.
pub async fn download_data(&self, cdn_host: &str, path: &str, hash: &str) -> Result<Response> {
let data_path = format!("{}/data", path.trim_end_matches('/'));
self.download(cdn_host, &data_path, hash).await
}
/// Downloads a patch object from the `{path}/patch` subdirectory.
pub async fn download_patch(&self, cdn_host: &str, path: &str, hash: &str) -> Result<Response> {
let patch_path = format!("{}/patch", path.trim_end_matches('/'));
self.download(cdn_host, &patch_path, hash).await
}
/// Downloads many objects concurrently, keeping at most `max_concurrent`
/// (default 10) requests in flight, and returns one result per hash.
///
/// The returned vector is in the same order as `hashes`, so `results[i]`
/// corresponds to `hashes[i]` — matching the ordering guarantee that
/// `download_parallel_with_progress` already provides. (Previously
/// `buffer_unordered` yielded results in completion order, which made it
/// impossible for callers to associate a result with its hash.)
pub async fn download_parallel(
    &self,
    cdn_host: &str,
    path: &str,
    hashes: &[String],
    max_concurrent: Option<usize>,
) -> Vec<Result<Vec<u8>>> {
    use futures_util::stream::{self, StreamExt};
    let concurrency = max_concurrent.unwrap_or(10);
    let tasks = hashes.iter().map(|hash| {
        // Each task owns its inputs so the futures are 'self-contained.
        let host = cdn_host.to_string();
        let dir = path.to_string();
        let id = hash.clone();
        async move {
            let response = self.download(&host, &dir, &id).await?;
            response.bytes().await.map(|b| b.to_vec()).map_err(Into::into)
        }
    });
    // `buffered` preserves input order while still running up to
    // `concurrency` downloads at once.
    stream::iter(tasks).buffered(concurrency).collect().await
}
pub async fn download_parallel_with_progress<F>(
&self,
cdn_host: &str,
path: &str,
hashes: &[String],
max_concurrent: Option<usize>,
mut progress: F,
) -> Vec<Result<Vec<u8>>>
where
F: FnMut(usize, usize),
{
use futures_util::stream::{self, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
let max_concurrent = max_concurrent.unwrap_or(10);
let total = hashes.len();
let completed = Arc::new(AtomicUsize::new(0));
let futures = hashes.iter().enumerate().map(|(idx, hash)| {
let cdn_host = cdn_host.to_string();
let path = path.to_string();
let hash = hash.clone();
let completed = Arc::clone(&completed);
async move {
let result = match self.download(&cdn_host, &path, &hash).await {
Ok(response) => response.bytes().await
.map(|b| b.to_vec())
.map_err(Into::into),
Err(e) => Err(e),
};
let count = completed.fetch_add(1, Ordering::SeqCst) + 1;
(idx, result, count)
}
});
let mut results: Vec<Result<Vec<u8>>> = Vec::with_capacity(total);
for _ in 0..total {
results.push(Err(Error::invalid_response("Not downloaded")));
}
let mut download_stream = stream::iter(futures).buffer_unordered(max_concurrent);
while let Some((idx, result, count)) = download_stream.next().await {
results[idx] = result;
progress(count, total);
}
results
}
/// Downloads many data objects concurrently from `{path}/data`.
pub async fn download_data_parallel(
&self,
cdn_host: &str,
path: &str,
hashes: &[String],
max_concurrent: Option<usize>,
) -> Vec<Result<Vec<u8>>> {
let data_path = format!("{}/data", path.trim_end_matches('/'));
self.download_parallel(cdn_host, &data_path, hashes, max_concurrent).await
}
/// Downloads many configuration objects concurrently from `{path}/config`.
pub async fn download_config_parallel(
&self,
cdn_host: &str,
path: &str,
hashes: &[String],
max_concurrent: Option<usize>,
) -> Vec<Result<Vec<u8>>> {
let config_path = format!("{}/config", path.trim_end_matches('/'));
self.download_parallel(cdn_host, &config_path, hashes, max_concurrent).await
}
/// Downloads many patch objects concurrently from `{path}/patch`.
pub async fn download_patch_parallel(
&self,
cdn_host: &str,
path: &str,
hashes: &[String],
max_concurrent: Option<usize>,
) -> Vec<Result<Vec<u8>>> {
let patch_path = format!("{}/patch", path.trim_end_matches('/'));
self.download_parallel(cdn_host, &patch_path, hashes, max_concurrent).await
}
/// Streams the object identified by `hash` directly into `writer` without
/// buffering the whole body in memory; returns the byte count written.
///
/// # Errors
/// Fails if the download fails, a chunk cannot be read from the response
/// stream, or writing/flushing `writer` fails.
pub async fn download_streaming<W>(
    &self,
    cdn_host: &str,
    path: &str,
    hash: &str,
    mut writer: W,
) -> Result<u64>
where
    W: tokio::io::AsyncWrite + Unpin,
{
    use futures_util::StreamExt;
    use tokio::io::AsyncWriteExt;
    let mut body = self.download(cdn_host, path, hash).await?.bytes_stream();
    let mut written = 0u64;
    while let Some(next) = body.next().await {
        let bytes = next?;
        writer
            .write_all(&bytes)
            .await
            .map_err(|e| Error::invalid_response(format!("Write error: {e}")))?;
        written += bytes.len() as u64;
    }
    // Flush so buffered data reaches the sink before reporting the count.
    writer
        .flush()
        .await
        .map_err(|e| Error::invalid_response(format!("Write error: {e}")))?;
    Ok(written)
}
/// Downloads the object identified by `hash`, handing each received chunk
/// to `callback`; returns the total number of bytes delivered.
///
/// # Errors
/// Fails if the download fails, a chunk cannot be read, or `callback`
/// returns an error (which aborts the transfer).
pub async fn download_chunked<F>(
    &self,
    cdn_host: &str,
    path: &str,
    hash: &str,
    mut callback: F,
) -> Result<u64>
where
    F: FnMut(&[u8]) -> Result<()>,
{
    use futures_util::StreamExt;
    let mut body = self.download(cdn_host, path, hash).await?.bytes_stream();
    let mut delivered = 0u64;
    while let Some(next) = body.next().await {
        let bytes = next?;
        callback(&bytes)?;
        delivered += bytes.len() as u64;
    }
    Ok(delivered)
}
}
impl Default for CdnClient {
/// Equivalent to [`CdnClient::new`].
///
/// # Panics
/// Panics if the underlying HTTP client cannot be built, because `Default`
/// cannot return an error.
fn default() -> Self {
Self::new().expect("Failed to create default CDN client")
}
}
/// Builder for [`CdnClient`] exposing HTTP transport options (timeouts,
/// pool size) in addition to the retry/backoff settings.
#[derive(Debug, Clone)]
pub struct CdnClientBuilder {
/// TCP connect timeout, in seconds.
connect_timeout_secs: u64,
/// Overall per-request timeout, in seconds.
request_timeout_secs: u64,
/// Maximum idle pooled connections kept per host.
pool_max_idle_per_host: usize,
/// Number of retry attempts after the initial request.
max_retries: u32,
/// Backoff delay before the first retry, in milliseconds.
initial_backoff_ms: u64,
/// Ceiling for any single backoff delay, in milliseconds.
max_backoff_ms: u64,
/// Growth factor applied to the backoff per retry.
backoff_multiplier: f64,
/// Jitter fraction, clamped to 0.0..=1.0 by the setter.
jitter_factor: f64,
/// Optional `User-Agent` header value.
user_agent: Option<String>,
}
impl CdnClientBuilder {
/// Creates a builder pre-populated with the crate defaults.
pub fn new() -> Self {
Self {
connect_timeout_secs: DEFAULT_CONNECT_TIMEOUT_SECS,
request_timeout_secs: DEFAULT_REQUEST_TIMEOUT_SECS,
pool_max_idle_per_host: 20,
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
jitter_factor: DEFAULT_JITTER_FACTOR,
user_agent: None,
}
}
/// Sets the TCP connect timeout, in seconds.
pub fn connect_timeout(mut self, secs: u64) -> Self {
self.connect_timeout_secs = secs;
self
}
/// Sets the overall per-request timeout, in seconds.
pub fn request_timeout(mut self, secs: u64) -> Self {
self.request_timeout_secs = secs;
self
}
/// Sets the maximum idle pooled connections kept per host.
pub fn pool_max_idle_per_host(mut self, max: usize) -> Self {
self.pool_max_idle_per_host = max;
self
}
/// Sets the number of retry attempts after the initial request.
pub fn max_retries(mut self, retries: u32) -> Self {
self.max_retries = retries;
self
}
/// Sets the backoff delay before the first retry, in milliseconds.
pub fn initial_backoff_ms(mut self, ms: u64) -> Self {
self.initial_backoff_ms = ms;
self
}
/// Sets the ceiling for any single backoff delay, in milliseconds.
pub fn max_backoff_ms(mut self, ms: u64) -> Self {
self.max_backoff_ms = ms;
self
}
/// Sets the multiplicative growth factor applied per retry.
pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
self.backoff_multiplier = multiplier;
self
}
/// Sets the jitter fraction; out-of-range values are clamped to 0.0..=1.0.
pub fn jitter_factor(mut self, factor: f64) -> Self {
self.jitter_factor = factor.clamp(0.0, 1.0);
self
}
/// Sets the `User-Agent` header sent with every request.
pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
self.user_agent = Some(user_agent.into());
self
}
/// Builds the [`CdnClient`], constructing the HTTP transport with the
/// configured timeouts, pool size, and compression enabled.
///
/// # Errors
/// Returns an error if the underlying HTTP client cannot be constructed.
pub fn build(self) -> Result<CdnClient> {
let client = Client::builder()
.connect_timeout(Duration::from_secs(self.connect_timeout_secs))
.timeout(Duration::from_secs(self.request_timeout_secs))
.pool_max_idle_per_host(self.pool_max_idle_per_host)
.gzip(true)
.deflate(true)
.build()?;
Ok(CdnClient {
client,
max_retries: self.max_retries,
initial_backoff_ms: self.initial_backoff_ms,
max_backoff_ms: self.max_backoff_ms,
backoff_multiplier: self.backoff_multiplier,
jitter_factor: self.jitter_factor,
user_agent: self.user_agent,
})
}
}
impl Default for CdnClientBuilder {
/// Equivalent to [`CdnClientBuilder::new`].
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` applies the documented retry defaults.
    #[test]
    fn test_client_creation() {
        let client = CdnClient::new().unwrap();
        assert_eq!(client.max_retries, DEFAULT_MAX_RETRIES);
        assert_eq!(client.initial_backoff_ms, DEFAULT_INITIAL_BACKOFF_MS);
        assert_eq!(client.max_backoff_ms, DEFAULT_MAX_BACKOFF_MS);
    }

    /// Every builder setter is reflected in the built client.
    #[test]
    fn test_builder_configuration() {
        let client = CdnClient::builder()
            .max_retries(5)
            .initial_backoff_ms(200)
            .max_backoff_ms(5000)
            .backoff_multiplier(1.5)
            .jitter_factor(0.2)
            .connect_timeout(60)
            .request_timeout(600)
            .pool_max_idle_per_host(100)
            .build()
            .unwrap();
        assert_eq!(client.max_retries, 5);
        assert_eq!(client.initial_backoff_ms, 200);
        assert_eq!(client.max_backoff_ms, 5000);
        assert!((client.backoff_multiplier - 1.5).abs() < f64::EPSILON);
        assert!((client.jitter_factor - 0.2).abs() < f64::EPSILON);
    }

    /// Out-of-range jitter factors are clamped into 0.0..=1.0.
    #[test]
    fn test_jitter_factor_clamping() {
        let client1 = CdnClient::new().unwrap().with_jitter_factor(1.5);
        assert!((client1.jitter_factor - 1.0).abs() < f64::EPSILON);
        let client2 = CdnClient::new().unwrap().with_jitter_factor(-0.5);
        assert!((client2.jitter_factor - 0.0).abs() < f64::EPSILON);
    }

    /// With jitter disabled the backoff schedule is deterministic:
    /// 100, 200, 400, ... capped at 1000 ms.
    #[test]
    fn test_backoff_calculation() {
        let client = CdnClient::new()
            .unwrap()
            .with_initial_backoff_ms(100)
            .with_max_backoff_ms(1000)
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.0);
        let backoff0 = client.calculate_backoff(0);
        assert_eq!(backoff0.as_millis(), 100);
        let backoff1 = client.calculate_backoff(1);
        assert_eq!(backoff1.as_millis(), 200);
        let backoff2 = client.calculate_backoff(2);
        assert_eq!(backoff2.as_millis(), 400);
        let backoff5 = client.calculate_backoff(5);
        assert_eq!(backoff5.as_millis(), 1000);
    }

    /// The builder-style setter stores the custom User-Agent.
    #[test]
    fn test_user_agent_configuration() {
        let client = CdnClient::new()
            .unwrap()
            .with_user_agent("MyNGDPClient/1.0");
        assert_eq!(client.user_agent, Some("MyNGDPClient/1.0".to_string()));
    }

    /// The full builder also stores the custom User-Agent.
    #[test]
    fn test_user_agent_via_builder() {
        let client = CdnClient::builder()
            .user_agent("MyNGDPClient/2.0")
            .build()
            .unwrap();
        assert_eq!(client.user_agent, Some("MyNGDPClient/2.0".to_string()));
    }

    /// No User-Agent header is configured by default.
    #[test]
    fn test_user_agent_default_none() {
        let client = CdnClient::new().unwrap();
        assert!(client.user_agent.is_none());
    }

    /// NOTE: this test issues real DNS/HTTP requests to example.com (with
    /// retry backoff on transport failures), making it slow and flaky in
    /// offline or sandboxed CI. It is therefore ignored by default; run it
    /// explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "performs real network requests to example.com"]
    async fn test_parallel_download_ordering() {
        let client = CdnClient::new().unwrap();
        let cdn_host = "example.com";
        let path = "test";
        let hashes = vec![
            "hash1".to_string(),
            "hash2".to_string(),
            "hash3".to_string(),
        ];
        let results = client.download_parallel(cdn_host, path, &hashes, Some(2)).await;
        assert_eq!(results.len(), 3);
    }
}