use crate::config::HttpConfig;
use crate::error::{EngineError, NetworkErrorKind, Result};
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use parking_lot::RwLock as ParkingRwLock;
use reqwest::Client;
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
pub struct ConnectionPool {
client: Client,
download_limiter: ParkingRwLock<Option<Arc<DefaultDirectRateLimiter>>>,
upload_limiter: ParkingRwLock<Option<Arc<DefaultDirectRateLimiter>>>,
total_downloaded: AtomicU64,
total_uploaded: AtomicU64,
active_connections: AtomicU64,
stats: RwLock<ConnectionStats>,
}
#[derive(Debug, Clone, Default)]
pub struct ConnectionStats {
pub connections_created: u64,
pub successful_requests: u64,
pub failed_requests: u64,
pub retried_requests: u64,
pub avg_response_time_ms: f64,
pub last_error: Option<String>,
}
impl ConnectionPool {
pub fn new(config: &HttpConfig) -> Result<Self> {
let mut builder = Client::builder()
.connect_timeout(Duration::from_secs(config.connect_timeout))
.read_timeout(Duration::from_secs(config.read_timeout))
.redirect(reqwest::redirect::Policy::limited(config.max_redirects))
.danger_accept_invalid_certs(config.accept_invalid_certs)
.pool_max_idle_per_host(32)
.pool_idle_timeout(Duration::from_secs(90))
.gzip(false)
.brotli(false);
if let Some(ref proxy_url) = config.proxy_url {
let proxy = reqwest::Proxy::all(proxy_url)
.map_err(|e| EngineError::Internal(format!("Invalid proxy URL: {}", e)))?;
builder = builder.proxy(proxy);
}
let client = builder
.build()
.map_err(|e| EngineError::Internal(format!("Failed to create HTTP client: {}", e)))?;
Ok(Self {
client,
download_limiter: ParkingRwLock::new(None),
upload_limiter: ParkingRwLock::new(None),
total_downloaded: AtomicU64::new(0),
total_uploaded: AtomicU64::new(0),
active_connections: AtomicU64::new(0),
stats: RwLock::new(ConnectionStats::default()),
})
}
pub fn with_limits(
config: &HttpConfig,
download_limit: Option<u64>,
upload_limit: Option<u64>,
) -> Result<Self> {
let pool = Self::new(config)?;
pool.set_download_limit(download_limit);
pool.set_upload_limit(upload_limit);
Ok(pool)
}
pub fn client(&self) -> &Client {
&self.client
}
pub fn set_download_limit(&self, limit: Option<u64>) {
*self.download_limiter.write() = limit.and_then(build_rate_limiter);
}
pub fn set_upload_limit(&self, limit: Option<u64>) {
*self.upload_limiter.write() = limit.and_then(build_rate_limiter);
}
pub async fn acquire_download(&self, bytes: u64) {
let limiter = self.download_limiter.read().clone();
if let Some(limiter) = limiter {
for chunk in limiter_chunks(bytes) {
let _ = limiter.until_n_ready(chunk).await;
}
}
}
pub async fn acquire_upload(&self, bytes: u64) {
let limiter = self.upload_limiter.read().clone();
if let Some(limiter) = limiter {
for chunk in limiter_chunks(bytes) {
let _ = limiter.until_n_ready(chunk).await;
}
}
}
pub fn record_download(&self, bytes: u64) {
self.total_downloaded.fetch_add(bytes, Ordering::Relaxed);
}
pub fn record_upload(&self, bytes: u64) {
self.total_uploaded.fetch_add(bytes, Ordering::Relaxed);
}
pub fn total_downloaded(&self) -> u64 {
self.total_downloaded.load(Ordering::Relaxed)
}
pub fn total_uploaded(&self) -> u64 {
self.total_uploaded.load(Ordering::Relaxed)
}
pub fn connection_started(&self) {
self.active_connections.fetch_add(1, Ordering::Relaxed);
}
pub fn connection_finished(&self) {
self.active_connections.fetch_sub(1, Ordering::Relaxed);
}
pub fn active_connections(&self) -> u64 {
self.active_connections.load(Ordering::Relaxed)
}
pub async fn record_success(&self, response_time_ms: f64) {
let mut stats = self.stats.write().await;
stats.successful_requests += 1;
let alpha = 0.2;
stats.avg_response_time_ms =
alpha * response_time_ms + (1.0 - alpha) * stats.avg_response_time_ms;
}
pub async fn record_failure(&self, error: &str) {
let mut stats = self.stats.write().await;
stats.failed_requests += 1;
stats.last_error = Some(error.to_string());
}
pub async fn record_retry(&self) {
let mut stats = self.stats.write().await;
stats.retried_requests += 1;
}
pub async fn stats(&self) -> ConnectionStats {
self.stats.read().await.clone()
}
}
fn build_rate_limiter(limit: u64) -> Option<Arc<DefaultDirectRateLimiter>> {
let clamped = limit.min(u32::MAX as u64) as u32;
NonZeroU32::new(clamped).map(|n| Arc::new(RateLimiter::direct(Quota::per_second(n))))
}
fn limiter_chunks(bytes: u64) -> Vec<NonZeroU32> {
const CHUNK_SIZE: u64 = 16 * 1024;
if bytes == 0 {
return Vec::new();
}
let full_chunks = bytes / CHUNK_SIZE;
let remainder = bytes % CHUNK_SIZE;
let mut chunks = Vec::with_capacity(full_chunks as usize + usize::from(remainder > 0));
for _ in 0..full_chunks {
chunks.push(NonZeroU32::new(CHUNK_SIZE as u32).expect("chunk size is non-zero"));
}
if remainder > 0 {
chunks.push(NonZeroU32::new(remainder as u32).expect("remainder is non-zero"));
}
chunks
}
#[cfg(test)]
mod limiter_tests {
use super::limiter_chunks;
#[test]
fn limiter_chunks_is_empty_for_zero_bytes() {
assert!(limiter_chunks(0).is_empty());
}
#[test]
fn limiter_chunks_preserves_exact_byte_count() {
let chunks = limiter_chunks(16 * 1024 + 17);
let total: u64 = chunks.into_iter().map(|chunk| chunk.get() as u64).sum();
assert_eq!(total, 16 * 1024 + 17);
}
#[test]
fn limiter_chunks_does_not_over_throttle_small_reads() {
let chunks = limiter_chunks(1);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].get(), 1);
}
}
#[derive(Debug, Clone)]
pub struct RetryPolicy {
pub max_attempts: u32,
pub initial_delay_ms: u64,
pub max_delay_ms: u64,
pub jitter_factor: f64,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay_ms: 1000,
max_delay_ms: 30000,
jitter_factor: 0.25,
}
}
}
impl RetryPolicy {
pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
Self {
max_attempts,
initial_delay_ms,
max_delay_ms,
jitter_factor: 0.25,
}
}
pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
let base = self.initial_delay_ms * 2u64.pow(attempt.min(10));
let capped = base.min(self.max_delay_ms);
let jitter = (rand::random::<f64>() - 0.5) * 2.0 * self.jitter_factor;
let with_jitter = (capped as f64 * (1.0 + jitter)) as u64;
Duration::from_millis(with_jitter)
}
pub fn should_retry(&self, attempt: u32, error: &EngineError) -> bool {
if attempt >= self.max_attempts {
return false;
}
error.is_retryable()
}
}
pub async fn with_retry<F, T, Fut>(
pool: &ConnectionPool,
policy: &RetryPolicy,
operation: F,
) -> Result<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut last_error = None;
for attempt in 0..policy.max_attempts {
let start = Instant::now();
match operation().await {
Ok(result) => {
let elapsed = start.elapsed().as_millis() as f64;
pool.record_success(elapsed).await;
return Ok(result);
}
Err(e) => {
let _elapsed = start.elapsed().as_millis() as f64;
pool.record_failure(&e.to_string()).await;
if policy.should_retry(attempt, &e) {
pool.record_retry().await;
let delay = policy.delay_for_attempt(attempt);
tracing::debug!(
"Request failed (attempt {}), retrying in {:?}: {}",
attempt + 1,
delay,
e
);
tokio::time::sleep(delay).await;
last_error = Some(e);
} else {
return Err(e);
}
}
}
}
Err(last_error
.unwrap_or_else(|| EngineError::network(NetworkErrorKind::Other, "Max retries exceeded")))
}
#[derive(Debug)]
pub struct SpeedCalculator {
window_size: usize,
measurements: Vec<(u64, Instant)>,
total_bytes: u64,
}
impl SpeedCalculator {
pub fn new(window_size: usize) -> Self {
Self {
window_size,
measurements: Vec::with_capacity(window_size),
total_bytes: 0,
}
}
pub fn add_bytes(&mut self, bytes: u64) {
let now = Instant::now();
self.total_bytes += bytes;
if self.measurements.len() >= self.window_size {
self.measurements.remove(0);
}
self.measurements.push((bytes, now));
}
pub fn speed(&self) -> u64 {
if self.measurements.len() < 2 {
return 0;
}
let first = &self.measurements[0];
let last = &self.measurements[self.measurements.len() - 1];
let elapsed = last.1.duration_since(first.1).as_secs_f64();
if elapsed <= 0.0 {
return 0;
}
let bytes: u64 = self.measurements.iter().map(|(b, _)| *b).sum();
(bytes as f64 / elapsed) as u64
}
pub fn total(&self) -> u64 {
self.total_bytes
}
pub fn reset(&mut self) {
self.measurements.clear();
self.total_bytes = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_retry_delay() {
let policy = RetryPolicy::new(3, 1000, 30000);
let delay0 = policy.delay_for_attempt(0);
assert!(delay0.as_millis() >= 750 && delay0.as_millis() <= 1250);
let delay1 = policy.delay_for_attempt(1);
assert!(delay1.as_millis() >= 1500 && delay1.as_millis() <= 2500);
let delay2 = policy.delay_for_attempt(2);
assert!(delay2.as_millis() >= 3000 && delay2.as_millis() <= 5000);
}
#[test]
fn test_speed_calculator() {
let mut calc = SpeedCalculator::new(10);
calc.add_bytes(1000);
std::thread::sleep(Duration::from_millis(100));
calc.add_bytes(1000);
std::thread::sleep(Duration::from_millis(100));
calc.add_bytes(1000);
let speed = calc.speed();
assert!(speed > 0);
assert_eq!(calc.total(), 3000);
}
#[test]
fn test_retry_policy_defaults() {
let policy = RetryPolicy::default();
assert_eq!(policy.max_attempts, 3);
assert_eq!(policy.initial_delay_ms, 1000);
assert_eq!(policy.max_delay_ms, 30000);
}
}