use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::browser::ContextOverride;
use crate::launcher::Proxy;
use super::health::{DEFAULT_CHECK_URL, ProxyHealth, probe_proxy};
use super::rotate::{RotateStrategy, Rotator};
/// Cheaply cloneable pool of proxies with shared rotation state and
/// per-proxy health records.
///
/// All three fields sit behind `Arc`, so every clone shares the same proxy
/// list, the same rotator (and therefore the same cursor), and the same
/// health table.
#[derive(Clone)]
pub struct ProxyPool {
/// The proxies; never mutated after construction within this module.
proxies: Arc<Vec<Proxy>>,
/// Chooses indices into `proxies` according to the configured strategy.
rotator: Arc<Rotator>,
/// Health entries, index-aligned with `proxies` (exactly one per proxy).
health: Arc<Mutex<Vec<ProxyHealth>>>,
}
impl ProxyPool {
    /// Creates a pool that hands out `proxies` in round-robin order.
    pub fn new(proxies: Vec<Proxy>) -> Self {
        Self::with_strategy(proxies, RotateStrategy::RoundRobin)
    }

    /// Creates a pool that selects among `proxies` using `strategy`.
    ///
    /// Every proxy starts with a `ProxyHealth::default()` entry (never probed).
    pub fn with_strategy(proxies: Vec<Proxy>, strategy: RotateStrategy) -> Self {
        let health = vec![ProxyHealth::default(); proxies.len()];
        Self {
            proxies: Arc::new(proxies),
            rotator: Arc::new(Rotator::new(strategy)),
            health: Arc::new(Mutex::new(health)),
        }
    }

    /// Builder-style: swaps in a new rotation `strategy`.
    ///
    /// The proxy list and health table are kept (and stay shared with earlier
    /// clones), but a fresh [`Rotator`] is created, so the returned pool no
    /// longer shares its cursor with clones made before this call.
    pub fn strategy(self, strategy: RotateStrategy) -> Self {
        Self {
            rotator: Arc::new(Rotator::new(strategy)),
            ..self
        }
    }

    /// Number of proxies in the pool.
    pub fn len(&self) -> usize {
        self.proxies.len()
    }

    /// `true` when the pool holds no proxies.
    pub fn is_empty(&self) -> bool {
        self.proxies.is_empty()
    }

    /// Returns the next proxy chosen by the rotator, ignoring health state.
    ///
    /// Returns `None` when the rotator yields no index (e.g. an empty pool).
    // Named `next` for ergonomic call sites, but this takes `&self` and shares
    // its cursor across clones, so it cannot be `Iterator::next` (which needs
    // `&mut self`); hence the clippy allow.
    #[allow(clippy::should_implement_trait)]
    pub fn next(&self) -> Option<Proxy> {
        self.rotator
            .pick(self.proxies.len(), None)
            .map(|i| self.proxies[i].clone())
    }

    /// Returns the proxy the rotator assigns to `key`.
    ///
    /// With [`RotateStrategy::Sticky`] the same key keeps mapping to the same
    /// proxy. Returns `None` when the rotator yields no index (e.g. an empty pool).
    pub fn for_key(&self, key: &str) -> Option<Proxy> {
        self.rotator
            .pick(self.proxies.len(), Some(key))
            .map(|i| self.proxies[i].clone())
    }

    /// Probes every proxy against [`DEFAULT_CHECK_URL`] with a 10-second timeout.
    ///
    /// Same as [`ProxyPool::check_health_with`]; returns the healthy count.
    pub async fn check_health(&self) -> usize {
        self.check_health_with(DEFAULT_CHECK_URL, Duration::from_secs(10))
            .await
    }

    /// Probes every proxy via `probe_proxy` and replaces the stored health table.
    ///
    /// All probes are awaited together with `join_all`. The results overwrite
    /// the previous table wholesale, so a `mark_bad` made while the probes were
    /// in flight is discarded. Returns how many probes reported
    /// `healthy == Some(true)`.
    pub async fn check_health_with(&self, check_url: &str, timeout: Duration) -> usize {
        let results = futures_util::future::join_all(
            self.proxies
                .iter()
                .map(|p| probe_proxy(p, check_url, timeout)),
        )
        .await;
        let healthy = results.iter().filter(|h| h.healthy == Some(true)).count();
        // Locked only after the await: no std guard is held across a suspension point.
        *self.lock_health() = results;
        healthy
    }

    /// Returns the next proxy whose health entry is usable.
    ///
    /// Starts at the rotator's pick and scans forward (wrapping around) for
    /// the first entry where [`ProxyHealth::usable`] holds; a missing entry is
    /// treated as usable. If nothing is usable, it falls back to the rotator's
    /// pick rather than returning `None`. `None` only propagates from the
    /// rotator (e.g. an empty pool).
    pub fn next_healthy(&self) -> Option<Proxy> {
        let len = self.proxies.len();
        let start = self.rotator.pick(len, None)?;
        let idx = {
            let health = self.lock_health();
            (0..len)
                .map(|off| (start + off) % len)
                .find(|&i| health.get(i).map(ProxyHealth::usable).unwrap_or(true))
                .unwrap_or(start)
        };
        Some(self.proxies[idx].clone())
    }

    /// Picks the next usable proxy (see [`ProxyPool::next_healthy`]) and builds
    /// a [`ContextOverride`] for it.
    pub fn next_coherent(&self) -> Option<ContextOverride> {
        let proxy = self.next_healthy()?;
        Some(self.coherent_override_for(&proxy))
    }

    /// Builds a [`ContextOverride`] that routes through `proxy` and is derived
    /// from the `geo` stored in its health entry.
    ///
    /// Proxies not found in the pool (matched by `server`) use the default
    /// geo value.
    pub fn coherent_override_for(&self, proxy: &Proxy) -> ContextOverride {
        let geo = self
            .index_of(&proxy.server)
            .and_then(|i| self.lock_health().get(i).map(|h| h.geo.clone()))
            .unwrap_or_default();
        geo.coherent_override().proxy(proxy.clone())
    }

    /// Marks the proxy with the matching `server` as unhealthy.
    ///
    /// The mark lasts until the next health check overwrites the table.
    /// Unknown servers are ignored.
    pub fn mark_bad(&self, server: &str) {
        let Some(i) = self.index_of(server) else {
            return;
        };
        let mut health = self.lock_health();
        if let Some(entry) = health.get_mut(i) {
            entry.healthy = Some(false);
        }
    }

    /// Number of proxies whose stored health is `healthy == Some(true)`.
    ///
    /// Proxies that were never probed or were marked bad are not counted.
    pub fn healthy_count(&self) -> usize {
        self.lock_health()
            .iter()
            .filter(|h| h.healthy == Some(true))
            .count()
    }

    /// Snapshot of the health entry for `server`, or `None` if it is not in the pool.
    pub fn health_of(&self, server: &str) -> Option<ProxyHealth> {
        let i = self.index_of(server)?;
        self.lock_health().get(i).cloned()
    }

    /// Snapshot of every proxy paired with its health entry, in pool order.
    pub fn report(&self) -> Vec<(Proxy, ProxyHealth)> {
        let health = self.lock_health();
        self.proxies
            .iter()
            .cloned()
            .zip(health.iter().cloned())
            .collect()
    }

    /// Position of the first proxy whose `server` equals `server`.
    fn index_of(&self, server: &str) -> Option<usize> {
        self.proxies.iter().position(|p| p.server == server)
    }

    /// Locks the health table, recovering the data if the mutex was poisoned.
    ///
    /// Within this module the table is only ever replaced wholesale or has a
    /// single `healthy` field overwritten, so a panic elsewhere cannot leave it
    /// half-updated. Ignoring the poison flag keeps the pool usable instead of
    /// silently degrading every health-aware method.
    fn lock_health(&self) -> std::sync::MutexGuard<'_, Vec<ProxyHealth>> {
        self.health
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-proxy round-robin pool shared by most tests.
    fn pool() -> ProxyPool {
        ProxyPool::new(vec![
            Proxy::new("http://a:1"),
            Proxy::new("http://b:2"),
            Proxy::new("http://c:3"),
        ])
    }

    #[test]
    fn round_robin_cycles() {
        let p = pool();
        let got: Vec<String> = (0..4).map(|_| p.next().unwrap().server).collect();
        assert_eq!(
            got,
            vec!["http://a:1", "http://b:2", "http://c:3", "http://a:1"]
        );
    }

    #[test]
    fn empty_pool_returns_none() {
        let p = ProxyPool::new(vec![]);
        assert!(p.is_empty());
        assert_eq!(p.len(), 0);
        assert!(p.next().is_none());
        assert!(p.for_key("x").is_none());
        assert!(p.next_healthy().is_none());
        assert!(p.next_coherent().is_none());
        assert!(p.report().is_empty());
        assert_eq!(p.healthy_count(), 0);
    }

    #[test]
    fn sticky_same_key_same_proxy() {
        let p = pool().strategy(RotateStrategy::Sticky);
        let a = p.for_key("acct-7").unwrap().server;
        let b = p.for_key("acct-7").unwrap().server;
        assert_eq!(a, b);
    }

    #[test]
    fn clones_share_cursor() {
        let p = pool();
        let q = p.clone();
        assert_eq!(p.next().unwrap().server, "http://a:1");
        assert_eq!(q.next().unwrap().server, "http://b:2");
        assert_eq!(p.next().unwrap().server, "http://c:3");
    }

    #[test]
    fn mark_bad_then_next_healthy_skips_it() {
        let p = pool();
        p.mark_bad("http://a:1");
        let got: Vec<String> = (0..6).map(|_| p.next_healthy().unwrap().server).collect();
        assert!(
            !got.iter().any(|s| s == "http://a:1"),
            // Message reads: "should skip the bad proxy, got {got:?}".
            "应跳过坏代理,实得 {got:?}"
        );
        assert!(got.iter().any(|s| s == "http://b:2"));
        assert!(got.iter().any(|s| s == "http://c:3"));
    }

    #[test]
    fn next_healthy_before_check_behaves_like_next() {
        // Before any probe or mark, next_healthy must follow the exact same
        // sequence as next(); a bare is_some() check would not catch drift.
        let p = pool();
        let reference = pool();
        for _ in 0..4 {
            assert_eq!(
                p.next_healthy().unwrap().server,
                reference.next().unwrap().server
            );
        }
        assert_eq!(p.healthy_count(), 0);
    }

    #[test]
    fn all_bad_falls_back_to_cursor() {
        let p = pool();
        for s in ["http://a:1", "http://b:2", "http://c:3"] {
            p.mark_bad(s);
        }
        // With nothing usable, each call returns whatever the cursor points at.
        let got: Vec<String> = (0..3).map(|_| p.next_healthy().unwrap().server).collect();
        assert_eq!(got, vec!["http://a:1", "http://b:2", "http://c:3"]);
    }

    #[test]
    fn mark_bad_records_unhealthy_and_ignores_unknown() {
        let p = pool();
        p.mark_bad("http://b:2");
        p.mark_bad("http://nope:9");
        assert_eq!(
            p.health_of("http://b:2").and_then(|h| h.healthy),
            Some(false)
        );
        assert!(p.health_of("http://a:1").unwrap().healthy.is_none());
        assert!(p.health_of("http://nope:9").is_none());
        assert_eq!(p.healthy_count(), 0);
    }

    #[test]
    fn strategy_keeps_shared_health() {
        let p = pool();
        p.mark_bad("http://a:1");
        let q = p.strategy(RotateStrategy::Sticky);
        assert_eq!(
            q.health_of("http://a:1").and_then(|h| h.healthy),
            Some(false)
        );
    }

    #[test]
    fn coherent_override_includes_proxy() {
        let p = pool();
        let proxy = Proxy::new("http://b:2");
        let ov = p.coherent_override_for(&proxy);
        assert_eq!(
            ov.proxy.as_ref().map(|p| p.server.as_str()),
            Some("http://b:2")
        );
        assert!(ov.timezone_id.is_none());
    }

    #[test]
    fn report_has_one_entry_per_proxy() {
        let p = pool();
        let rep = p.report();
        assert_eq!(rep.len(), 3);
        let servers: Vec<&str> = rep.iter().map(|(px, _)| px.server.as_str()).collect();
        assert_eq!(servers, vec!["http://a:1", "http://b:2", "http://c:3"]);
        assert!(rep.iter().all(|(_, h)| h.healthy.is_none()));
    }
}