1use crate::config::HttpConfig;
7use crate::error::{EngineError, NetworkErrorKind, Result};
8use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
9use reqwest::Client;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
/// HTTP client wrapper that carries optional bandwidth limiters, transfer
/// counters and request statistics alongside a `reqwest::Client`.
pub struct ConnectionPool {
    /// HTTP client built from the `HttpConfig` given to `new`; exposed via `client()`.
    client: Client,
    /// Limiter awaited by `acquire_download`; `None` means downloads are not throttled.
    download_limiter: Option<DefaultDirectRateLimiter>,
    /// Limiter awaited by `acquire_upload`; `None` means uploads are not throttled.
    upload_limiter: Option<DefaultDirectRateLimiter>,
    /// Sum of the byte counts passed to `record_download`.
    total_downloaded: AtomicU64,
    /// Sum of the byte counts passed to `record_upload`.
    total_uploaded: AtomicU64,
    /// Counter raised by `connection_started` and lowered by `connection_finished`.
    active_connections: AtomicU64,
    /// Request statistics, kept behind an async read/write lock.
    stats: RwLock<ConnectionStats>,
}
32
/// Snapshot of the request statistics tracked by `ConnectionPool`.
///
/// All counters start at zero and `last_error` starts as `None` (derived `Default`).
#[derive(Debug, Clone, Default)]
pub struct ConnectionStats {
    /// Number of connections created.
    // NOTE(review): nothing in the visible part of this file updates this field — confirm who does.
    pub connections_created: u64,
    /// Incremented once per `ConnectionPool::record_success` call.
    pub successful_requests: u64,
    /// Incremented once per `ConnectionPool::record_failure` call.
    pub failed_requests: u64,
    /// Incremented once per `ConnectionPool::record_retry` call.
    pub retried_requests: u64,
    /// Exponentially weighted moving average of the response times (in
    /// milliseconds) reported to `ConnectionPool::record_success`.
    pub avg_response_time_ms: f64,
    /// Message passed to the most recent `ConnectionPool::record_failure` call, if any.
    pub last_error: Option<String>,
}
49
50impl ConnectionPool {
51 pub fn new(config: &HttpConfig) -> Result<Self> {
53 let mut builder = Client::builder()
54 .connect_timeout(Duration::from_secs(config.connect_timeout))
55 .read_timeout(Duration::from_secs(config.read_timeout))
56 .redirect(reqwest::redirect::Policy::limited(config.max_redirects))
57 .danger_accept_invalid_certs(config.accept_invalid_certs)
58 .pool_max_idle_per_host(32)
59 .pool_idle_timeout(Duration::from_secs(90))
60 .gzip(false)
64 .brotli(false);
65
66 if let Some(ref proxy_url) = config.proxy_url {
68 let proxy = reqwest::Proxy::all(proxy_url)
69 .map_err(|e| EngineError::Internal(format!("Invalid proxy URL: {}", e)))?;
70 builder = builder.proxy(proxy);
71 }
72
73 let client = builder
74 .build()
75 .map_err(|e| EngineError::Internal(format!("Failed to create HTTP client: {}", e)))?;
76
77 Ok(Self {
78 client,
79 download_limiter: None,
80 upload_limiter: None,
81 total_downloaded: AtomicU64::new(0),
82 total_uploaded: AtomicU64::new(0),
83 active_connections: AtomicU64::new(0),
84 stats: RwLock::new(ConnectionStats::default()),
85 })
86 }
87
88 pub fn with_limits(
90 config: &HttpConfig,
91 download_limit: Option<u64>,
92 upload_limit: Option<u64>,
93 ) -> Result<Self> {
94 let mut pool = Self::new(config)?;
95
96 pool.download_limiter = download_limit.and_then(|limit| {
97 let clamped = limit.min(u32::MAX as u64) as u32;
98 NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
99 });
100
101 pool.upload_limiter = upload_limit.and_then(|limit| {
102 let clamped = limit.min(u32::MAX as u64) as u32;
103 NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
104 });
105
106 Ok(pool)
107 }
108
109 pub fn client(&self) -> &Client {
111 &self.client
112 }
113
114 pub fn set_download_limit(&mut self, limit: Option<u64>) {
116 self.download_limiter = limit.and_then(|l| {
117 let clamped = l.min(u32::MAX as u64) as u32;
118 NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
119 });
120 }
121
122 pub fn set_upload_limit(&mut self, limit: Option<u64>) {
124 self.upload_limiter = limit.and_then(|l| {
125 let clamped = l.min(u32::MAX as u64) as u32;
126 NonZeroU32::new(clamped).map(|n| RateLimiter::direct(Quota::per_second(n)))
127 });
128 }
129
130 pub async fn acquire_download(&self, bytes: u64) {
132 if let Some(ref limiter) = self.download_limiter {
133 let chunk_size = 16384; let chunks = (bytes / chunk_size).max(1) as u32;
136 for _ in 0..chunks {
137 if let Some(n) = NonZeroU32::new(chunk_size as u32) {
138 let _ = limiter.until_n_ready(n).await;
139 }
140 }
141 }
142 }
143
144 pub async fn acquire_upload(&self, bytes: u64) {
146 if let Some(ref limiter) = self.upload_limiter {
147 let chunk_size = 16384;
148 let chunks = (bytes / chunk_size).max(1) as u32;
149 for _ in 0..chunks {
150 if let Some(n) = NonZeroU32::new(chunk_size as u32) {
151 let _ = limiter.until_n_ready(n).await;
152 }
153 }
154 }
155 }
156
157 pub fn record_download(&self, bytes: u64) {
159 self.total_downloaded.fetch_add(bytes, Ordering::Relaxed);
160 }
161
162 pub fn record_upload(&self, bytes: u64) {
164 self.total_uploaded.fetch_add(bytes, Ordering::Relaxed);
165 }
166
167 pub fn total_downloaded(&self) -> u64 {
169 self.total_downloaded.load(Ordering::Relaxed)
170 }
171
172 pub fn total_uploaded(&self) -> u64 {
174 self.total_uploaded.load(Ordering::Relaxed)
175 }
176
177 pub fn connection_started(&self) {
179 self.active_connections.fetch_add(1, Ordering::Relaxed);
180 }
181
182 pub fn connection_finished(&self) {
184 self.active_connections.fetch_sub(1, Ordering::Relaxed);
185 }
186
187 pub fn active_connections(&self) -> u64 {
189 self.active_connections.load(Ordering::Relaxed)
190 }
191
192 pub async fn record_success(&self, response_time_ms: f64) {
194 let mut stats = self.stats.write().await;
195 stats.successful_requests += 1;
196
197 let alpha = 0.2;
199 stats.avg_response_time_ms =
200 alpha * response_time_ms + (1.0 - alpha) * stats.avg_response_time_ms;
201 }
202
203 pub async fn record_failure(&self, error: &str) {
205 let mut stats = self.stats.write().await;
206 stats.failed_requests += 1;
207 stats.last_error = Some(error.to_string());
208 }
209
210 pub async fn record_retry(&self) {
212 let mut stats = self.stats.write().await;
213 stats.retried_requests += 1;
214 }
215
216 pub async fn stats(&self) -> ConnectionStats {
218 self.stats.read().await.clone()
219 }
220}
221
/// Parameters controlling how failed operations are retried.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Attempt-count limit checked by `should_retry`.
    pub max_attempts: u32,
    /// Backoff delay for attempt 0, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound applied to the backoff delay before jitter, in milliseconds.
    pub max_delay_ms: u64,
    /// Relative jitter amplitude applied to the capped delay.
    pub jitter_factor: f64,
}

impl Default for RetryPolicy {
    /// Three attempts, starting at one second, capped at thirty seconds,
    /// with 25% jitter.
    fn default() -> Self {
        const ATTEMPTS: u32 = 3;
        const FIRST_DELAY_MS: u64 = 1_000;
        const DELAY_CEILING_MS: u64 = 30_000;
        const JITTER: f64 = 0.25;

        RetryPolicy {
            max_attempts: ATTEMPTS,
            initial_delay_ms: FIRST_DELAY_MS,
            max_delay_ms: DELAY_CEILING_MS,
            jitter_factor: JITTER,
        }
    }
}
245
246impl RetryPolicy {
247 pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
249 Self {
250 max_attempts,
251 initial_delay_ms,
252 max_delay_ms,
253 jitter_factor: 0.25,
254 }
255 }
256
257 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
259 let base = self.initial_delay_ms * 2u64.pow(attempt.min(10));
261 let capped = base.min(self.max_delay_ms);
262
263 let jitter = (rand::random::<f64>() - 0.5) * 2.0 * self.jitter_factor;
265 let with_jitter = (capped as f64 * (1.0 + jitter)) as u64;
266
267 Duration::from_millis(with_jitter)
268 }
269
270 pub fn should_retry(&self, attempt: u32, error: &EngineError) -> bool {
272 if attempt >= self.max_attempts {
273 return false;
274 }
275
276 error.is_retryable()
277 }
278}
279
280pub async fn with_retry<F, T, Fut>(
282 pool: &ConnectionPool,
283 policy: &RetryPolicy,
284 operation: F,
285) -> Result<T>
286where
287 F: Fn() -> Fut,
288 Fut: std::future::Future<Output = Result<T>>,
289{
290 let mut last_error = None;
291
292 for attempt in 0..policy.max_attempts {
293 let start = Instant::now();
294
295 match operation().await {
296 Ok(result) => {
297 let elapsed = start.elapsed().as_millis() as f64;
298 pool.record_success(elapsed).await;
299 return Ok(result);
300 }
301 Err(e) => {
302 let _elapsed = start.elapsed().as_millis() as f64;
303 pool.record_failure(&e.to_string()).await;
304
305 if policy.should_retry(attempt, &e) {
306 pool.record_retry().await;
307 let delay = policy.delay_for_attempt(attempt);
308 tracing::debug!(
309 "Request failed (attempt {}), retrying in {:?}: {}",
310 attempt + 1,
311 delay,
312 e
313 );
314 tokio::time::sleep(delay).await;
315 last_error = Some(e);
316 } else {
317 return Err(e);
318 }
319 }
320 }
321 }
322
323 Err(last_error
324 .unwrap_or_else(|| EngineError::network(NetworkErrorKind::Other, "Max retries exceeded")))
325}
326
/// Estimates transfer speed from a sliding window of the most recent samples.
#[derive(Debug)]
pub struct SpeedCalculator {
    /// Maximum number of samples kept in `measurements`.
    window_size: usize,
    /// Retained samples, oldest first: (byte count, time it was recorded).
    measurements: std::collections::VecDeque<(u64, Instant)>,
    /// Sum of every byte count added since creation or the last `reset`.
    total_bytes: u64,
}

impl SpeedCalculator {
    /// Creates a calculator that keeps at most `window_size` samples.
    ///
    /// With a window of 0 or 1 the running total is still tracked, but
    /// `speed` always reports 0 because it needs at least two samples.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            measurements: std::collections::VecDeque::with_capacity(window_size),
            total_bytes: 0,
        }
    }

    /// Records that `bytes` bytes were transferred just now.
    pub fn add_bytes(&mut self, bytes: u64) {
        let now = Instant::now();
        self.total_bytes = self.total_bytes.saturating_add(bytes);

        // A zero-sized window retains nothing; evicting from the empty queue
        // used to panic here.
        if self.window_size == 0 {
            return;
        }

        if self.measurements.len() >= self.window_size {
            self.measurements.pop_front();
        }
        self.measurements.push_back((bytes, now));
    }

    /// Returns the average speed over the window in bytes per second.
    ///
    /// Returns 0 with fewer than two samples or when no time has elapsed
    /// between the oldest and newest sample.
    pub fn speed(&self) -> u64 {
        if self.measurements.len() < 2 {
            return 0;
        }

        let (oldest, newest) = match (self.measurements.front(), self.measurements.back()) {
            (Some(oldest), Some(newest)) => (oldest, newest),
            _ => return 0,
        };

        let elapsed = newest.1.duration_since(oldest.1).as_secs_f64();
        if elapsed <= 0.0 {
            return 0;
        }

        // The oldest sample only marks the start of the interval: its bytes
        // arrived before that timestamp, so counting them would overstate the
        // speed (by 2x for a two-sample window).
        let bytes: f64 = self
            .measurements
            .iter()
            .skip(1)
            .map(|&(count, _)| count as f64)
            .sum();

        (bytes / elapsed) as u64
    }

    /// Returns the total number of bytes added since creation or the last `reset`.
    pub fn total(&self) -> u64 {
        self.total_bytes
    }

    /// Drops all samples and zeroes the running total.
    pub fn reset(&mut self) {
        self.measurements.clear();
        self.total_bytes = 0;
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Each attempt's delay must stay within ±25% of the un-jittered backoff.
    #[test]
    fn test_retry_delay() {
        let policy = RetryPolicy::new(3, 1000, 30000);

        // (attempt, lowest allowed ms, highest allowed ms)
        let expectations: [(u32, u128, u128); 3] =
            [(0, 750, 1250), (1, 1500, 2500), (2, 3000, 5000)];

        for (attempt, lowest, highest) in expectations {
            let millis = policy.delay_for_attempt(attempt).as_millis();
            assert!(millis >= lowest && millis <= highest);
        }
    }

    /// Three spaced-out samples yield a positive speed and an exact total.
    #[test]
    fn test_speed_calculator() {
        let mut calc = SpeedCalculator::new(10);
        let pause = Duration::from_millis(100);

        calc.add_bytes(1000);
        for _ in 0..2 {
            std::thread::sleep(pause);
            calc.add_bytes(1000);
        }

        assert!(calc.speed() > 0);
        assert_eq!(calc.total(), 3000);
    }

    /// The default policy exposes the documented attempt and delay settings.
    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay_ms,
            max_delay_ms,
            ..
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay_ms, 1000);
        assert_eq!(max_delay_ms, 30000);
    }
}