1use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use serde::Serialize;
11use tokio::sync::RwLock;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use crate::circuit_breaker::CircuitBreaker;
17use crate::error::{ProxyError, ProxyResult};
18use crate::health::{HealthChecker, HealthMap};
19use crate::storage::ProxyStoragePort;
20use crate::strategy::{
21 BoxedRotationStrategy, LeastUsedStrategy, ProxyCandidate, RandomStrategy, RoundRobinStrategy,
22 WeightedStrategy,
23};
24use crate::types::{Proxy, ProxyConfig};
25
26#[derive(Debug, Serialize)]
32pub struct PoolStats {
33 pub total: usize,
35 pub healthy: usize,
37 pub open: usize,
39}
40
41pub struct ProxyHandle {
51 pub proxy_url: String,
53 circuit_breaker: Arc<CircuitBreaker>,
54 succeeded: AtomicBool,
55}
56
57impl ProxyHandle {
58 fn new(proxy_url: String, circuit_breaker: Arc<CircuitBreaker>) -> Self {
59 Self {
60 proxy_url,
61 circuit_breaker,
62 succeeded: AtomicBool::new(false),
63 }
64 }
65
66 pub fn direct() -> Self {
71 let noop_cb = Arc::new(CircuitBreaker::new(u32::MAX, u64::MAX));
72 Self {
73 proxy_url: String::new(),
74 circuit_breaker: noop_cb,
75 succeeded: AtomicBool::new(true),
76 }
77 }
78
79 pub fn mark_success(&self) {
81 self.succeeded.store(true, Ordering::Release);
82 }
83}
84
85impl std::fmt::Debug for ProxyHandle {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("ProxyHandle")
88 .field("proxy_url", &self.proxy_url)
89 .finish_non_exhaustive()
90 }
91}
92
93impl Drop for ProxyHandle {
94 fn drop(&mut self) {
95 if self.succeeded.load(Ordering::Acquire) {
96 self.circuit_breaker.record_success();
97 } else {
98 self.circuit_breaker.record_failure();
99 }
100 }
101}
102
103pub struct ProxyManager {
140 storage: Arc<dyn ProxyStoragePort>,
141 strategy: BoxedRotationStrategy,
142 health_checker: HealthChecker,
143 circuit_breakers: Arc<RwLock<HashMap<Uuid, Arc<CircuitBreaker>>>>,
144 config: ProxyConfig,
145}
146
147impl ProxyManager {
148 pub fn builder() -> ProxyManagerBuilder {
150 ProxyManagerBuilder::default()
151 }
152
153 pub fn with_round_robin(
155 storage: Arc<dyn ProxyStoragePort>,
156 config: ProxyConfig,
157 ) -> ProxyResult<Self> {
158 Self::builder()
159 .storage(storage)
160 .strategy(Arc::new(RoundRobinStrategy::default()))
161 .config(config)
162 .build()
163 }
164
165 pub fn with_random(
167 storage: Arc<dyn ProxyStoragePort>,
168 config: ProxyConfig,
169 ) -> ProxyResult<Self> {
170 Self::builder()
171 .storage(storage)
172 .strategy(Arc::new(RandomStrategy))
173 .config(config)
174 .build()
175 }
176
177 pub fn with_weighted(
179 storage: Arc<dyn ProxyStoragePort>,
180 config: ProxyConfig,
181 ) -> ProxyResult<Self> {
182 Self::builder()
183 .storage(storage)
184 .strategy(Arc::new(WeightedStrategy))
185 .config(config)
186 .build()
187 }
188
189 pub fn with_least_used(
191 storage: Arc<dyn ProxyStoragePort>,
192 config: ProxyConfig,
193 ) -> ProxyResult<Self> {
194 Self::builder()
195 .storage(storage)
196 .strategy(Arc::new(LeastUsedStrategy))
197 .config(config)
198 .build()
199 }
200
201 pub async fn add_proxy(&self, proxy: Proxy) -> ProxyResult<Uuid> {
212 let mut cb_map = self.circuit_breakers.write().await;
213 let record = self.storage.add(proxy).await?;
214 cb_map.insert(
215 record.id,
216 Arc::new(CircuitBreaker::new(
217 self.config.circuit_open_threshold,
218 self.config.circuit_half_open_after.as_millis() as u64,
219 )),
220 );
221 Ok(record.id)
222 }
223
224 pub async fn remove_proxy(&self, id: Uuid) -> ProxyResult<()> {
226 self.storage.remove(id).await?;
227 self.circuit_breakers.write().await.remove(&id);
228 Ok(())
229 }
230
231 pub fn start(&self) -> (CancellationToken, JoinHandle<()>) {
238 let token = CancellationToken::new();
239 let handle = self.health_checker.clone().spawn(token.clone());
240 (token, handle)
241 }
242
243 pub async fn acquire_proxy(&self) -> ProxyResult<ProxyHandle> {
251 let with_metrics = self.storage.list_with_metrics().await?;
252 if with_metrics.is_empty() {
253 return Err(ProxyError::PoolExhausted);
254 }
255
256 let health_map: tokio::sync::RwLockReadGuard<'_, _> =
257 self.health_checker.health_map().read().await;
258 let cb_map = self.circuit_breakers.read().await;
259
260 let candidates: Vec<ProxyCandidate> = with_metrics
261 .iter()
262 .map(|(record, metrics)| {
263 let healthy = health_map.get(&record.id).copied().unwrap_or(true);
265 let available = cb_map
266 .get(&record.id)
267 .map(|cb| cb.is_available())
268 .unwrap_or(true);
269 ProxyCandidate {
270 id: record.id,
271 weight: record.proxy.weight,
272 metrics: Arc::clone(metrics),
273 healthy: healthy && available,
274 }
275 })
276 .collect();
277
278 drop(health_map);
279 let selected = self.strategy.select(&candidates).await?;
280 let id = selected.id;
281
282 let cb = cb_map
286 .get(&id)
287 .cloned()
288 .ok_or(ProxyError::PoolExhausted)?;
289
290 let url = with_metrics
291 .iter()
292 .find(|(r, _)| r.id == id)
293 .map(|(r, _)| r.proxy.url.clone())
294 .unwrap_or_default();
295
296 Ok(ProxyHandle::new(url, cb))
297 }
298
299 pub async fn pool_stats(&self) -> ProxyResult<PoolStats> {
303 let records = self.storage.list().await?;
304 let total = records.len();
305 let health_map = self.health_checker.health_map().read().await;
306 let cb_map = self.circuit_breakers.read().await;
307
308 let mut healthy = 0usize;
309 let mut open = 0usize;
310 for r in &records {
311 if health_map.get(&r.id).copied().unwrap_or(true) {
312 healthy += 1;
313 }
314 if cb_map
315 .get(&r.id)
316 .map(|cb| !cb.is_available())
317 .unwrap_or(false)
318 {
319 open += 1;
320 }
321 }
322 Ok(PoolStats {
323 total,
324 healthy,
325 open,
326 })
327 }
328}
329
330#[derive(Default)]
336pub struct ProxyManagerBuilder {
337 storage: Option<Arc<dyn ProxyStoragePort>>,
338 strategy: Option<BoxedRotationStrategy>,
339 config: Option<ProxyConfig>,
340}
341
342impl ProxyManagerBuilder {
343 pub fn storage(mut self, s: Arc<dyn ProxyStoragePort>) -> Self {
344 self.storage = Some(s);
345 self
346 }
347
348 pub fn strategy(mut self, s: BoxedRotationStrategy) -> Self {
349 self.strategy = Some(s);
350 self
351 }
352
353 pub fn config(mut self, c: ProxyConfig) -> Self {
354 self.config = Some(c);
355 self
356 }
357
358 pub fn build(self) -> ProxyResult<ProxyManager> {
364 let storage = self.storage.ok_or_else(|| {
365 ProxyError::ConfigError("ProxyManagerBuilder: storage is required".into())
366 })?;
367 let strategy = self
368 .strategy
369 .unwrap_or_else(|| Arc::new(RoundRobinStrategy::default()));
370 let config = self.config.unwrap_or_default();
371 let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
372 let health_checker = HealthChecker::new(
373 config.clone(),
374 Arc::clone(&storage),
375 Arc::clone(&health_map),
376 );
377 Ok(ProxyManager {
378 storage,
379 strategy,
380 health_checker,
381 circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
382 config,
383 })
384 }
385}
386
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use super::*;
    use crate::circuit_breaker::{STATE_CLOSED, STATE_OPEN};
    use crate::storage::MemoryProxyStore;
    use crate::types::ProxyType;

    /// Minimal HTTP proxy: weight 1, no credentials, no tags.
    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    /// Fresh, empty in-memory store.
    fn storage() -> Arc<MemoryProxyStore> {
        Arc::new(MemoryProxyStore::default())
    }

    /// Round-robin over three proxies must select each of them within ten
    /// acquisitions.
    #[tokio::test]
    async fn round_robin_distribution() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store.clone(), ProxyConfig::default()).unwrap();
        mgr.add_proxy(make_proxy("http://a.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://b.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://c.test:8080"))
            .await
            .unwrap();

        let mut seen = HashSet::new();
        for _ in 0..10 {
            let h = mgr.acquire_proxy().await.unwrap();
            h.mark_success();
            seen.insert(h.proxy_url.clone());
        }
        assert_eq!(seen.len(), 3, "all three proxies should have been selected");
    }

    /// With a threshold of 1, a single failure opens the only breaker, so
    /// acquisition must fail with `AllProxiesUnhealthy`.
    #[tokio::test]
    async fn all_open_returns_error() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store.clone(),
            ProxyConfig {
                circuit_open_threshold: 1,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://x.test:8080"))
            .await
            .unwrap();

        {
            // Trip the breaker directly; the guard is released at the end of
            // this scope, before acquiring.
            let map = mgr.circuit_breakers.read().await;
            let cb = map.get(&id).unwrap();
            cb.record_failure();
        }

        let err = mgr.acquire_proxy().await.unwrap_err();
        assert!(
            matches!(err, ProxyError::AllProxiesUnhealthy),
            "expected AllProxiesUnhealthy, got {err:?}"
        );
    }

    /// Dropping a handle without `mark_success` records a failure. With a
    /// threshold of 1, that opens the breaker.
    #[tokio::test]
    async fn handle_drop_records_failure() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store.clone(),
            ProxyConfig {
                circuit_open_threshold: 1,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://y.test:8080"))
            .await
            .unwrap();

        {
            let _h = mgr.acquire_proxy().await.unwrap();
        }
        // `_h` has been dropped unmarked at this point.

        let cb_map = mgr.circuit_breakers.read().await;
        let cb = cb_map.get(&id).unwrap();
        assert_eq!(cb.state(), STATE_OPEN);
    }

    /// Marking the handle successful before dropping it keeps the breaker closed.
    #[tokio::test]
    async fn handle_success_keeps_closed() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store.clone(), ProxyConfig::default()).unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://z.test:8080"))
            .await
            .unwrap();

        let h = mgr.acquire_proxy().await.unwrap();
        h.mark_success();
        drop(h);

        let cb_map = mgr.circuit_breakers.read().await;
        let cb = cb_map.get(&id).unwrap();
        assert_eq!(cb.state(), STATE_CLOSED);
    }

    /// Cancelling the token must stop the health checker promptly and cleanly.
    #[tokio::test]
    async fn start_and_graceful_shutdown() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                health_check_interval: Duration::from_secs(3600),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let (token, handle) = mgr.start();
        token.cancel();
        let result = tokio::time::timeout(Duration::from_secs(1), handle).await;
        // `Ok(Ok(()))` means the task finished within the timeout AND did not
        // panic or get aborted. Checking only the outer `Ok` would let a
        // `JoinError` slip through.
        assert!(
            matches!(result, Ok(Ok(()))),
            "health checker task should exit cleanly within 1s, got {result:?}"
        );
    }
}