1use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicBool, Ordering};
9
10use serde::Serialize;
11use tokio::sync::RwLock;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use crate::circuit_breaker::CircuitBreaker;
17use crate::error::{ProxyError, ProxyResult};
18use crate::health::{HealthChecker, HealthMap};
19use crate::session::{SessionMap, StickyPolicy};
20use crate::storage::ProxyStoragePort;
21use crate::strategy::{
22 BoxedRotationStrategy, LeastUsedStrategy, ProxyCandidate, RandomStrategy, RoundRobinStrategy,
23 WeightedStrategy,
24};
25use crate::types::{Proxy, ProxyConfig};
26
27#[derive(Debug, Serialize)]
33pub struct PoolStats {
34 pub total: usize,
36 pub healthy: usize,
38 pub open: usize,
40 pub active_sessions: usize,
42}
43
44pub struct ProxyHandle {
54 pub proxy_url: String,
56 circuit_breaker: Arc<CircuitBreaker>,
57 succeeded: AtomicBool,
58 session_key: Option<String>,
60 sessions: Option<SessionMap>,
61}
62
63impl ProxyHandle {
64 fn new(proxy_url: String, circuit_breaker: Arc<CircuitBreaker>) -> Self {
65 Self {
66 proxy_url,
67 circuit_breaker,
68 succeeded: AtomicBool::new(false),
69 session_key: None,
70 sessions: None,
71 }
72 }
73
74 fn new_sticky(
75 proxy_url: String,
76 circuit_breaker: Arc<CircuitBreaker>,
77 session_key: String,
78 sessions: SessionMap,
79 ) -> Self {
80 Self {
81 proxy_url,
82 circuit_breaker,
83 succeeded: AtomicBool::new(false),
84 session_key: Some(session_key),
85 sessions: Some(sessions),
86 }
87 }
88
89 pub fn direct() -> Self {
94 let noop_cb = Arc::new(CircuitBreaker::new(u32::MAX, u64::MAX));
95 Self {
96 proxy_url: String::new(),
97 circuit_breaker: noop_cb,
98 succeeded: AtomicBool::new(true),
99 session_key: None,
100 sessions: None,
101 }
102 }
103
104 pub fn mark_success(&self) {
106 self.succeeded.store(true, Ordering::Release);
107 }
108}
109
110impl std::fmt::Debug for ProxyHandle {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("ProxyHandle")
113 .field("proxy_url", &self.proxy_url)
114 .finish_non_exhaustive()
115 }
116}
117
118impl Drop for ProxyHandle {
119 fn drop(&mut self) {
120 if self.succeeded.load(Ordering::Acquire) {
121 self.circuit_breaker.record_success();
122 } else {
123 self.circuit_breaker.record_failure();
124 if let (Some(key), Some(sessions)) = (&self.session_key, &self.sessions) {
126 sessions.unbind(key);
127 }
128 }
129 }
130}
131
132pub struct ProxyManager {
169 storage: Arc<dyn ProxyStoragePort>,
170 strategy: BoxedRotationStrategy,
171 health_checker: HealthChecker,
172 circuit_breakers: Arc<RwLock<HashMap<Uuid, Arc<CircuitBreaker>>>>,
173 config: ProxyConfig,
174 sessions: SessionMap,
176}
177
178impl ProxyManager {
179 pub fn builder() -> ProxyManagerBuilder {
181 ProxyManagerBuilder::default()
182 }
183
184 pub fn with_round_robin(
186 storage: Arc<dyn ProxyStoragePort>,
187 config: ProxyConfig,
188 ) -> ProxyResult<Self> {
189 Self::builder()
190 .storage(storage)
191 .strategy(Arc::new(RoundRobinStrategy::default()))
192 .config(config)
193 .build()
194 }
195
196 pub fn with_random(
198 storage: Arc<dyn ProxyStoragePort>,
199 config: ProxyConfig,
200 ) -> ProxyResult<Self> {
201 Self::builder()
202 .storage(storage)
203 .strategy(Arc::new(RandomStrategy))
204 .config(config)
205 .build()
206 }
207
208 pub fn with_weighted(
210 storage: Arc<dyn ProxyStoragePort>,
211 config: ProxyConfig,
212 ) -> ProxyResult<Self> {
213 Self::builder()
214 .storage(storage)
215 .strategy(Arc::new(WeightedStrategy))
216 .config(config)
217 .build()
218 }
219
220 pub fn with_least_used(
222 storage: Arc<dyn ProxyStoragePort>,
223 config: ProxyConfig,
224 ) -> ProxyResult<Self> {
225 Self::builder()
226 .storage(storage)
227 .strategy(Arc::new(LeastUsedStrategy))
228 .config(config)
229 .build()
230 }
231
232 pub async fn add_proxy(&self, proxy: Proxy) -> ProxyResult<Uuid> {
243 let mut cb_map = self.circuit_breakers.write().await;
244 let record = self.storage.add(proxy).await?;
245 cb_map.insert(
246 record.id,
247 Arc::new(CircuitBreaker::new(
248 self.config.circuit_open_threshold,
249 self.config.circuit_half_open_after.as_millis() as u64,
250 )),
251 );
252 Ok(record.id)
253 }
254
255 pub async fn remove_proxy(&self, id: Uuid) -> ProxyResult<()> {
257 self.storage.remove(id).await?;
258 self.circuit_breakers.write().await.remove(&id);
259 Ok(())
260 }
261
262 pub fn start(&self) -> (CancellationToken, JoinHandle<()>) {
269 let token = CancellationToken::new();
270 let health_handle = self.health_checker.clone().spawn(token.clone());
271
272 let sessions = self.sessions.clone();
273 let purge_token = token.clone();
274 let purge_handle = tokio::spawn(async move {
275 let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
276 loop {
277 tokio::select! {
278 _ = interval.tick() => { sessions.purge_expired(); }
279 _ = purge_token.cancelled() => break,
280 }
281 }
282 });
283
284 let combined = tokio::spawn(async move {
285 let _ = tokio::join!(health_handle, purge_handle);
286 });
287
288 (token, combined)
289 }
290
291 async fn select_proxy_inner(&self) -> ProxyResult<(String, Arc<CircuitBreaker>, Uuid)> {
297 let with_metrics = self.storage.list_with_metrics().await?;
298 if with_metrics.is_empty() {
299 return Err(ProxyError::PoolExhausted);
300 }
301
302 let health_map = self.health_checker.health_map().read().await;
303 let cb_map = self.circuit_breakers.read().await;
304
305 let candidates: Vec<ProxyCandidate> = with_metrics
306 .iter()
307 .map(|(record, metrics)| {
308 let healthy = health_map.get(&record.id).copied().unwrap_or(true);
309 let available = cb_map
310 .get(&record.id)
311 .map(|cb| cb.is_available())
312 .unwrap_or(true);
313 ProxyCandidate {
314 id: record.id,
315 weight: record.proxy.weight,
316 metrics: Arc::clone(metrics),
317 healthy: healthy && available,
318 }
319 })
320 .collect();
321
322 drop(health_map);
323 let selected = self.strategy.select(&candidates).await?;
324 let id = selected.id;
325
326 let cb = cb_map.get(&id).cloned().ok_or(ProxyError::PoolExhausted)?;
327 let url = with_metrics
328 .iter()
329 .find(|(r, _)| r.id == id)
330 .map(|(r, _)| r.proxy.url.clone())
331 .unwrap_or_default();
332
333 Ok((url, cb, id))
334 }
335
336 pub async fn acquire_proxy(&self) -> ProxyResult<ProxyHandle> {
342 let (url, cb, _id) = self.select_proxy_inner().await?;
343 Ok(ProxyHandle::new(url, cb))
344 }
345
346 pub async fn acquire_for_domain(&self, domain: &str) -> ProxyResult<ProxyHandle> {
360 let ttl = match &self.config.sticky_policy {
361 StickyPolicy::Disabled => return self.acquire_proxy().await,
362 StickyPolicy::Domain { ttl } => *ttl,
363 };
364
365 if let Some(proxy_id) = self.sessions.lookup(domain) {
367 let cb_map = self.circuit_breakers.read().await;
368 if let Some(cb) = cb_map.get(&proxy_id).cloned() {
369 if cb.is_available() {
370 let with_metrics = self.storage.list_with_metrics().await?;
372 if let Some((record, _)) = with_metrics.iter().find(|(r, _)| r.id == proxy_id) {
373 let url = record.proxy.url.clone();
374 drop(cb_map);
375 return Ok(ProxyHandle::new_sticky(
376 url,
377 cb,
378 domain.to_string(),
379 self.sessions.clone(),
380 ));
381 }
382 }
383 }
384 drop(cb_map);
386 self.sessions.unbind(domain);
387 }
388
389 let (url, cb, proxy_id) = self.select_proxy_inner().await?;
391 self.sessions.bind(domain, proxy_id, ttl);
392 Ok(ProxyHandle::new_sticky(
393 url,
394 cb,
395 domain.to_string(),
396 self.sessions.clone(),
397 ))
398 }
399
400 pub async fn pool_stats(&self) -> ProxyResult<PoolStats> {
404 let records = self.storage.list().await?;
405 let total = records.len();
406 let health_map = self.health_checker.health_map().read().await;
407 let cb_map = self.circuit_breakers.read().await;
408
409 let mut healthy = 0usize;
410 let mut open = 0usize;
411 for r in &records {
412 if health_map.get(&r.id).copied().unwrap_or(true) {
413 healthy += 1;
414 }
415 if cb_map
416 .get(&r.id)
417 .map(|cb| !cb.is_available())
418 .unwrap_or(false)
419 {
420 open += 1;
421 }
422 }
423 Ok(PoolStats {
424 total,
425 healthy,
426 open,
427 active_sessions: self.sessions.active_count(),
428 })
429 }
430}
431
432#[derive(Default)]
438pub struct ProxyManagerBuilder {
439 storage: Option<Arc<dyn ProxyStoragePort>>,
440 strategy: Option<BoxedRotationStrategy>,
441 config: Option<ProxyConfig>,
442}
443
444impl ProxyManagerBuilder {
445 pub fn storage(mut self, s: Arc<dyn ProxyStoragePort>) -> Self {
446 self.storage = Some(s);
447 self
448 }
449
450 pub fn strategy(mut self, s: BoxedRotationStrategy) -> Self {
451 self.strategy = Some(s);
452 self
453 }
454
455 pub fn config(mut self, c: ProxyConfig) -> Self {
456 self.config = Some(c);
457 self
458 }
459
460 pub fn build(self) -> ProxyResult<ProxyManager> {
466 let storage = self.storage.ok_or_else(|| {
467 ProxyError::ConfigError("ProxyManagerBuilder: storage is required".into())
468 })?;
469 let strategy = self
470 .strategy
471 .unwrap_or_else(|| Arc::new(RoundRobinStrategy::default()));
472 let config = self.config.unwrap_or_default();
473 let health_map: HealthMap = Arc::new(RwLock::new(HashMap::new()));
474 let health_checker = HealthChecker::new(
475 config.clone(),
476 Arc::clone(&storage),
477 Arc::clone(&health_map),
478 );
479 Ok(ProxyManager {
480 storage,
481 strategy,
482 health_checker,
483 circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
484 config,
485 sessions: SessionMap::new(),
486 })
487 }
488}
489
#[cfg(test)]
// Tests run against known-good setups; an `unwrap` panic is the failure signal.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::collections::HashSet;
    use std::time::Duration;

    use super::*;
    use crate::circuit_breaker::{STATE_CLOSED, STATE_OPEN};
    use crate::storage::MemoryProxyStore;
    use crate::types::ProxyType;

    fn make_proxy(url: &str) -> Proxy {
        Proxy {
            url: url.into(),
            proxy_type: ProxyType::Http,
            username: None,
            password: None,
            weight: 1,
            tags: vec![],
        }
    }

    fn storage() -> Arc<MemoryProxyStore> {
        Arc::new(MemoryProxyStore::default())
    }

    /// Round-robin over three proxies visits all of them.
    #[tokio::test]
    async fn round_robin_distribution() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store.clone(), ProxyConfig::default()).unwrap();
        mgr.add_proxy(make_proxy("http://a.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://b.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://c.test:8080"))
            .await
            .unwrap();

        let mut seen = HashSet::new();
        for _ in 0..10 {
            let h = mgr.acquire_proxy().await.unwrap();
            h.mark_success();
            seen.insert(h.proxy_url.clone());
        }
        assert_eq!(seen.len(), 3, "all three proxies should have been selected");
    }

    /// With every breaker open, acquisition fails with `AllProxiesUnhealthy`.
    #[tokio::test]
    async fn all_open_returns_error() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store.clone(),
            ProxyConfig {
                circuit_open_threshold: 1,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://x.test:8080"))
            .await
            .unwrap();

        {
            // One failure trips a threshold-1 breaker.
            let map = mgr.circuit_breakers.read().await;
            let cb = map.get(&id).unwrap();
            cb.record_failure();
        }

        let err = mgr.acquire_proxy().await.unwrap_err();
        assert!(
            matches!(err, ProxyError::AllProxiesUnhealthy),
            "expected AllProxiesUnhealthy, got {err:?}"
        );
    }

    /// Dropping an unmarked handle records a failure and opens the breaker.
    #[tokio::test]
    async fn handle_drop_records_failure() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store.clone(),
            ProxyConfig {
                circuit_open_threshold: 1,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://y.test:8080"))
            .await
            .unwrap();

        {
            let _h = mgr.acquire_proxy().await.unwrap();
        }

        let cb_map = mgr.circuit_breakers.read().await;
        let cb = cb_map.get(&id).unwrap();
        assert_eq!(cb.state(), STATE_OPEN);
    }

    /// A handle marked successful leaves the breaker closed.
    #[tokio::test]
    async fn handle_success_keeps_closed() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store.clone(), ProxyConfig::default()).unwrap();
        let id = mgr
            .add_proxy(make_proxy("http://z.test:8080"))
            .await
            .unwrap();

        let h = mgr.acquire_proxy().await.unwrap();
        h.mark_success();
        drop(h);

        let cb_map = mgr.circuit_breakers.read().await;
        let cb = cb_map.get(&id).unwrap();
        assert_eq!(cb.state(), STATE_CLOSED);
    }

    /// Cancelling the token stops the background tasks promptly.
    #[tokio::test]
    async fn start_and_graceful_shutdown() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                health_check_interval: Duration::from_secs(3600),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        let (token, handle) = mgr.start();
        token.cancel();
        let result = tokio::time::timeout(Duration::from_secs(1), handle).await;
        assert!(result.is_ok(), "health checker task should exit within 1s");
    }

    fn sticky_config() -> ProxyConfig {
        use crate::session::StickyPolicy;
        ProxyConfig {
            sticky_policy: StickyPolicy::domain_default(),
            ..ProxyConfig::default()
        }
    }

    /// Repeated acquisitions for one domain reuse its bound proxy.
    #[tokio::test]
    async fn sticky_same_domain_returns_same_proxy() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store, sticky_config()).unwrap();
        mgr.add_proxy(make_proxy("http://p1.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://p2.test:8080"))
            .await
            .unwrap();

        let h1 = mgr.acquire_for_domain("example.com").await.unwrap();
        let url1 = h1.proxy_url.clone();
        h1.mark_success();

        let h2 = mgr.acquire_for_domain("example.com").await.unwrap();
        let url2 = h2.proxy_url.clone();
        h2.mark_success();

        assert_eq!(url1, url2, "same domain should return the same proxy");
    }

    /// With round-robin and two proxies, two fresh domains get different proxies.
    #[tokio::test]
    async fn sticky_different_domains_may_differ() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store, sticky_config()).unwrap();
        mgr.add_proxy(make_proxy("http://pa.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://pb.test:8080"))
            .await
            .unwrap();

        let ha = mgr.acquire_for_domain("a.com").await.unwrap();
        let url_a = ha.proxy_url.clone();
        ha.mark_success();

        let hb = mgr.acquire_for_domain("b.com").await.unwrap();
        let url_b = hb.proxy_url.clone();
        hb.mark_success();

        assert_ne!(
            url_a, url_b,
            "different domains should differ in this scenario"
        );
    }

    /// After its TTL elapses, the domain still acquires a proxy successfully.
    #[tokio::test]
    async fn sticky_expired_session_re_acquires() {
        use crate::session::StickyPolicy;
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                sticky_policy: StickyPolicy::domain(Duration::from_millis(1)),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        mgr.add_proxy(make_proxy("http://x.test:8080"))
            .await
            .unwrap();

        let h1 = mgr.acquire_for_domain("expired.com").await.unwrap();
        h1.mark_success();

        tokio::time::sleep(Duration::from_millis(5)).await;

        let h2 = mgr.acquire_for_domain("expired.com").await.unwrap();
        assert_eq!(h2.proxy_url, "http://x.test:8080");
        h2.mark_success();
    }

    /// A failed sticky handle trips its breaker (threshold 1) and unbinds the
    /// domain, so the next acquisition must land on the other proxy.
    #[tokio::test]
    async fn sticky_cb_trip_invalidates_session() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                circuit_open_threshold: 1,
                sticky_policy: sticky_config().sticky_policy,
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        mgr.add_proxy(make_proxy("http://q1.test:8080"))
            .await
            .unwrap();
        mgr.add_proxy(make_proxy("http://q2.test:8080"))
            .await
            .unwrap();

        let h1 = mgr.acquire_for_domain("cb.com").await.unwrap();
        let url1 = h1.proxy_url.clone();
        // Dropped without `mark_success`: records a failure and unbinds "cb.com".
        drop(h1);

        let h2 = mgr.acquire_for_domain("cb.com").await.unwrap();
        assert_ne!(
            h2.proxy_url, url1,
            "a tripped proxy must not be reused for the domain"
        );
        h2.mark_success();
    }

    /// `purge_expired` removes sessions whose TTL has elapsed.
    #[tokio::test]
    async fn sticky_purge_expired() {
        use crate::session::StickyPolicy;
        let store = storage();
        let mgr = ProxyManager::with_round_robin(
            store,
            ProxyConfig {
                sticky_policy: StickyPolicy::domain(Duration::from_millis(1)),
                ..ProxyConfig::default()
            },
        )
        .unwrap();
        mgr.add_proxy(make_proxy("http://r.test:8080"))
            .await
            .unwrap();

        let h = mgr.acquire_for_domain("purge.com").await.unwrap();
        h.mark_success();

        assert_eq!(mgr.sessions.active_count(), 1);

        tokio::time::sleep(Duration::from_millis(5)).await;
        mgr.sessions.purge_expired();

        assert_eq!(mgr.sessions.active_count(), 0);
    }

    /// `pool_stats` reflects sticky sessions created by `acquire_for_domain`.
    #[tokio::test]
    async fn pool_stats_includes_sessions() {
        let store = storage();
        let mgr = ProxyManager::with_round_robin(store, sticky_config()).unwrap();
        mgr.add_proxy(make_proxy("http://s.test:8080"))
            .await
            .unwrap();

        let stats = mgr.pool_stats().await.unwrap();
        assert_eq!(stats.active_sessions, 0);

        let h = mgr.acquire_for_domain("stats.com").await.unwrap();
        h.mark_success();

        let stats = mgr.pool_stats().await.unwrap();
        assert_eq!(stats.active_sessions, 1);
    }
}