Skip to main content

gosh_dl/http/
connection.rs

//! Connection Pool Management
//!
//! This module provides a shared HTTP client with request statistics,
//! retry logic with exponential backoff and jitter, and speed limiting.
5
6use crate::config::HttpConfig;
7use crate::error::{EngineError, NetworkErrorKind, Result};
8use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
9use reqwest::Client;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
15/// Connection pool with rate limiting and health monitoring
16pub struct ConnectionPool {
17    /// HTTP client (reqwest handles its own connection pool)
18    client: Client,
19    /// Global rate limiter for download speed
20    download_limiter: Option<DefaultDirectRateLimiter>,
21    /// Global rate limiter for upload speed
22    upload_limiter: Option<DefaultDirectRateLimiter>,
23    /// Total bytes downloaded
24    total_downloaded: AtomicU64,
25    /// Total bytes uploaded
26    total_uploaded: AtomicU64,
27    /// Active connection count
28    active_connections: AtomicU64,
29    /// Connection statistics
30    stats: RwLock<ConnectionStats>,
31}
32
/// Snapshot of request statistics collected by a connection pool.
///
/// `ConnectionPool::stats` hands out a clone, so a snapshot can be inspected
/// without holding the pool's lock.
#[derive(Clone, Debug, Default)]
pub struct ConnectionStats {
    /// Number of connections created.
    ///
    /// NOTE(review): no method in this module updates this field — confirm
    /// whether callers are expected to maintain it.
    pub connections_created: u64,
    /// Requests that completed successfully (see `record_success`).
    pub successful_requests: u64,
    /// Failed attempts (see `record_failure`); `with_retry` counts every failed
    /// attempt, including ones that are subsequently retried.
    pub failed_requests: u64,
    /// Retries scheduled after a failed attempt (see `record_retry`).
    pub retried_requests: u64,
    /// Exponential moving average (smoothing factor 0.2) of response times,
    /// in milliseconds.
    pub avg_response_time_ms: f64,
    /// Text of the most recently recorded failure, if any.
    pub last_error: Option<String>,
}
49
50impl ConnectionPool {
51    /// Create a new connection pool
52    pub fn new(config: &HttpConfig) -> Result<Self> {
53        let mut builder = Client::builder()
54            .connect_timeout(Duration::from_secs(config.connect_timeout))
55            .read_timeout(Duration::from_secs(config.read_timeout))
56            .redirect(reqwest::redirect::Policy::limited(config.max_redirects))
57            .danger_accept_invalid_certs(config.accept_invalid_certs)
58            .pool_max_idle_per_host(32)
59            .pool_idle_timeout(Duration::from_secs(90))
60            // This is a download engine: preserve the exact bytes on the wire.
61            // Transparent decompression breaks progress accounting, checksums,
62            // range semantics, and on-disk fidelity.
63            .gzip(false)
64            .brotli(false);
65
66        // Add proxy if configured
67        if let Some(ref proxy_url) = config.proxy_url {
68            let proxy = reqwest::Proxy::all(proxy_url)
69                .map_err(|e| EngineError::Internal(format!("Invalid proxy URL: {}", e)))?;
70            builder = builder.proxy(proxy);
71        }
72
73        let client = builder
74            .build()
75            .map_err(|e| EngineError::Internal(format!("Failed to create HTTP client: {}", e)))?;
76
77        Ok(Self {
78            client,
79            download_limiter: None,
80            upload_limiter: None,
81            total_downloaded: AtomicU64::new(0),
82            total_uploaded: AtomicU64::new(0),
83            active_connections: AtomicU64::new(0),
84            stats: RwLock::new(ConnectionStats::default()),
85        })
86    }
87
88    /// Create a connection pool with rate limiting
89    pub fn with_limits(
90        config: &HttpConfig,
91        download_limit: Option<u64>,
92        upload_limit: Option<u64>,
93    ) -> Result<Self> {
94        let mut pool = Self::new(config)?;
95
96        pool.download_limiter = download_limit.and_then(|limit| {
97            let clamped = limit.min(u32::MAX as u64) as u32;
98            NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
99        });
100
101        pool.upload_limiter = upload_limit.and_then(|limit| {
102            let clamped = limit.min(u32::MAX as u64) as u32;
103            NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
104        });
105
106        Ok(pool)
107    }
108
109    /// Get the underlying HTTP client
110    pub fn client(&self) -> &Client {
111        &self.client
112    }
113
114    /// Update download speed limit
115    pub fn set_download_limit(&mut self, limit: Option<u64>) {
116        self.download_limiter = limit.and_then(|l| {
117            let clamped = l.min(u32::MAX as u64) as u32;
118            NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
119        });
120    }
121
122    /// Update upload speed limit
123    pub fn set_upload_limit(&mut self, limit: Option<u64>) {
124        self.upload_limiter = limit.and_then(|l| {
125            let clamped = l.min(u32::MAX as u64) as u32;
126            NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
127        });
128    }
129
130    /// Wait for rate limiter permission to download bytes
131    pub async fn acquire_download(&self, bytes: u64) {
132        if let Some(ref limiter) = self.download_limiter {
133            // Acquire permission in chunks to avoid blocking too long
134            let chunk_size = 16384; // 16KB chunks
135            let chunks = (bytes / chunk_size).max(1) as u32;
136            for _ in 0..chunks {
137                if let Some(n) = NonZeroU32::new(chunk_size as u32) {
138                    let _ = limiter.until_n_ready(n).await;
139                }
140            }
141        }
142    }
143
144    /// Wait for rate limiter permission to upload bytes
145    pub async fn acquire_upload(&self, bytes: u64) {
146        if let Some(ref limiter) = self.upload_limiter {
147            let chunk_size = 16384;
148            let chunks = (bytes / chunk_size).max(1) as u32;
149            for _ in 0..chunks {
150                if let Some(n) = NonZeroU32::new(chunk_size as u32) {
151                    let _ = limiter.until_n_ready(n).await;
152                }
153            }
154        }
155    }
156
157    /// Record downloaded bytes
158    pub fn record_download(&self, bytes: u64) {
159        self.total_downloaded.fetch_add(bytes, Ordering::Relaxed);
160    }
161
162    /// Record uploaded bytes
163    pub fn record_upload(&self, bytes: u64) {
164        self.total_uploaded.fetch_add(bytes, Ordering::Relaxed);
165    }
166
167    /// Get total downloaded bytes
168    pub fn total_downloaded(&self) -> u64 {
169        self.total_downloaded.load(Ordering::Relaxed)
170    }
171
172    /// Get total uploaded bytes
173    pub fn total_uploaded(&self) -> u64 {
174        self.total_uploaded.load(Ordering::Relaxed)
175    }
176
177    /// Increment active connection count
178    pub fn connection_started(&self) {
179        self.active_connections.fetch_add(1, Ordering::Relaxed);
180    }
181
182    /// Decrement active connection count
183    pub fn connection_finished(&self) {
184        self.active_connections.fetch_sub(1, Ordering::Relaxed);
185    }
186
187    /// Get active connection count
188    pub fn active_connections(&self) -> u64 {
189        self.active_connections.load(Ordering::Relaxed)
190    }
191
192    /// Record a successful request
193    pub async fn record_success(&self, response_time_ms: f64) {
194        let mut stats = self.stats.write().await;
195        stats.successful_requests += 1;
196
197        // Update average response time (exponential moving average)
198        let alpha = 0.2;
199        stats.avg_response_time_ms =
200            alpha * response_time_ms + (1.0 - alpha) * stats.avg_response_time_ms;
201    }
202
203    /// Record a failed request
204    pub async fn record_failure(&self, error: &str) {
205        let mut stats = self.stats.write().await;
206        stats.failed_requests += 1;
207        stats.last_error = Some(error.to_string());
208    }
209
210    /// Record a retried request
211    pub async fn record_retry(&self) {
212        let mut stats = self.stats.write().await;
213        stats.retried_requests += 1;
214    }
215
216    /// Get connection statistics
217    pub async fn stats(&self) -> ConnectionStats {
218        self.stats.read().await.clone()
219    }
220}
221
/// Retry policy: exponential backoff with random jitter.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first one, that `with_retry`
    /// makes before giving up.
    pub max_attempts: u32,
    /// Base delay in milliseconds; doubled for each successive attempt
    /// (exponent capped at 10).
    pub initial_delay_ms: u64,
    /// Cap applied to the exponential backoff, in milliseconds.
    pub max_delay_ms: u64,
    /// Relative jitter: the delay is randomly scaled by up to
    /// ±`jitter_factor` (intended range 0.0 to 1.0).
    pub jitter_factor: f64,
}
234
235impl Default for RetryPolicy {
236    fn default() -> Self {
237        Self {
238            max_attempts: 3,
239            initial_delay_ms: 1000,
240            max_delay_ms: 30000,
241            jitter_factor: 0.25,
242        }
243    }
244}
245
246impl RetryPolicy {
247    /// Create a new retry policy
248    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
249        Self {
250            max_attempts,
251            initial_delay_ms,
252            max_delay_ms,
253            jitter_factor: 0.25,
254        }
255    }
256
257    /// Calculate delay for a given attempt (0-indexed)
258    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
259        // Exponential backoff
260        let base = self.initial_delay_ms * 2u64.pow(attempt.min(10));
261        let capped = base.min(self.max_delay_ms);
262
263        // Add jitter: ±jitter_factor randomness
264        let jitter = (rand::random::<f64>() - 0.5) * 2.0 * self.jitter_factor;
265        let with_jitter = (capped as f64 * (1.0 + jitter)) as u64;
266
267        Duration::from_millis(with_jitter)
268    }
269
270    /// Check if we should retry based on error type
271    pub fn should_retry(&self, attempt: u32, error: &EngineError) -> bool {
272        if attempt >= self.max_attempts {
273            return false;
274        }
275
276        error.is_retryable()
277    }
278}
279
280/// Execute a request with retry logic
281pub async fn with_retry<F, T, Fut>(
282    pool: &ConnectionPool,
283    policy: &RetryPolicy,
284    operation: F,
285) -> Result<T>
286where
287    F: Fn() -> Fut,
288    Fut: std::future::Future<Output = Result<T>>,
289{
290    let mut last_error = None;
291
292    for attempt in 0..policy.max_attempts {
293        let start = Instant::now();
294
295        match operation().await {
296            Ok(result) => {
297                let elapsed = start.elapsed().as_millis() as f64;
298                pool.record_success(elapsed).await;
299                return Ok(result);
300            }
301            Err(e) => {
302                let _elapsed = start.elapsed().as_millis() as f64;
303                pool.record_failure(&e.to_string()).await;
304
305                if policy.should_retry(attempt, &e) {
306                    pool.record_retry().await;
307                    let delay = policy.delay_for_attempt(attempt);
308                    tracing::debug!(
309                        "Request failed (attempt {}), retrying in {:?}: {}",
310                        attempt + 1,
311                        delay,
312                        e
313                    );
314                    tokio::time::sleep(delay).await;
315                    last_error = Some(e);
316                } else {
317                    return Err(e);
318                }
319            }
320        }
321    }
322
323    Err(last_error
324        .unwrap_or_else(|| EngineError::network(NetworkErrorKind::Other, "Max retries exceeded")))
325}
326
/// Speed calculator for tracking download/upload rates.
///
/// Keeps the most recent `window_size` samples of `(bytes, arrival time)` and
/// reports the average rate over the interval they span.
#[derive(Debug)]
pub struct SpeedCalculator {
    /// Maximum number of samples retained (always at least 2).
    window_size: usize,
    /// Recent measurements `(bytes, timestamp)`, oldest first.
    measurements: std::collections::VecDeque<(u64, Instant)>,
    /// Total bytes tracked since creation or the last reset.
    total_bytes: u64,
}

impl SpeedCalculator {
    /// Smallest useful window: a rate needs at least two timestamps.
    const MIN_WINDOW: usize = 2;

    /// Create a new speed calculator that averages over the last
    /// `window_size` samples.
    ///
    /// Windows smaller than two are raised to two. Fewer samples can never
    /// produce a rate, and a zero-sized window has no room for a sample at all.
    pub fn new(window_size: usize) -> Self {
        let window_size = window_size.max(Self::MIN_WINDOW);
        Self {
            window_size,
            measurements: std::collections::VecDeque::with_capacity(window_size),
            total_bytes: 0,
        }
    }

    /// Add a measurement of `bytes` arriving now, evicting the oldest sample
    /// once the window is full.
    pub fn add_bytes(&mut self, bytes: u64) {
        let now = Instant::now();
        self.total_bytes = self.total_bytes.saturating_add(bytes);

        if self.measurements.len() >= self.window_size {
            // O(1), unlike removing the front element of a Vec.
            self.measurements.pop_front();
        }
        self.measurements.push_back((bytes, now));
    }

    /// Calculate current speed in bytes/second.
    ///
    /// The rate is the bytes that arrived *after* the oldest retained sample,
    /// divided by the time elapsed since that sample. The oldest sample only
    /// marks the start of the interval, so its own bytes are excluded; counting
    /// them would overstate the rate (e.g. 2000 B/s for two 1000-byte samples
    /// one second apart). Returns 0 with fewer than two samples or when no
    /// time has elapsed.
    pub fn speed(&self) -> u64 {
        let (start, end) = match (self.measurements.front(), self.measurements.back()) {
            (Some(&(_, first)), Some(&(_, last))) if self.measurements.len() >= 2 => {
                (first, last)
            }
            _ => return 0,
        };

        let elapsed = end.duration_since(start).as_secs_f64();
        if elapsed <= 0.0 {
            return 0;
        }

        let bytes = self
            .measurements
            .iter()
            .skip(1)
            .fold(0u64, |acc, &(b, _)| acc.saturating_add(b));
        (bytes as f64 / elapsed) as u64
    }

    /// Get total bytes tracked
    pub fn total(&self) -> u64 {
        self.total_bytes
    }

    /// Reset the calculator, discarding all samples and the running total.
    pub fn reset(&mut self) {
        self.measurements.clear();
        self.total_bytes = 0;
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_delay() {
        let policy = RetryPolicy::new(3, 1000, 30000);

        // (attempt, min ms, max ms): nominal 1s / 2s / 4s backoff with ±25% jitter.
        let cases = [(0, 750, 1250), (1, 1500, 2500), (2, 3000, 5000)];
        for &(attempt, lo, hi) in cases.iter() {
            let millis = policy.delay_for_attempt(attempt).as_millis();
            assert!(lo <= millis && millis <= hi);
        }
    }

    #[test]
    fn test_speed_calculator() {
        let mut calc = SpeedCalculator::new(10);

        // Three 1000-byte samples spaced roughly 100ms apart.
        for sample in 0..3 {
            if sample > 0 {
                std::thread::sleep(Duration::from_millis(100));
            }
            calc.add_bytes(1000);
        }

        // Scheduler timing makes the exact rate unpredictable, so only require
        // a non-zero reading.
        assert!(calc.speed() > 0);
        assert_eq!(calc.total(), 3000);
    }

    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay_ms,
            max_delay_ms,
            ..
        } = RetryPolicy::default();
        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay_ms, 1000);
        assert_eq!(max_delay_ms, 30000);
    }
}