use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, Semaphore};
use tokio::task::JoinSet;
use tracing::{debug, warn};
use crate::circuit_breaker::CircuitBreakerManager;
use crate::classifier::{BodyClassifier, BodyVerdict};
use crate::config::ScatterProxyConfig;
use crate::error::ScatterProxyError;
use crate::health::HealthTracker;
use crate::metrics::ThroughputTracker;
use crate::proxy::{ProxyManager, ProxyState};
use crate::rate_limit::RateLimiter;
use crate::score::{adaptive_k, compute_score};
use crate::task::{ScatterResponse, TaskEntry, TaskPool};
/// Result of a single request attempt sent through one proxy.
///
/// Produced by wrapping `client.execute(..)` in `tokio::time::timeout` and
/// consumed by the scheduler to update health statistics and decide whether
/// the task succeeded.
enum RaceOutcome {
/// The client returned an HTTP response before the proxy timeout elapsed.
/// The body has not been read yet at this point.
Response(reqwest::Response),
/// The client returned an error before the timeout; holds the error rendered
/// with `to_string()`.
RequestError(String),
/// The proxy timeout elapsed before the client produced a result.
Timeout,
}
/// Pulls tasks from the task pool and executes them through the available
/// proxies, racing several proxies per task and recording the outcome in the
/// shared health, rate-limit, circuit-breaker and throughput trackers.
pub(crate) struct Scheduler {
/// Tunables read by the scheduler: cooldown thresholds and durations, the
/// per-request concurrency cap, the per-proxy timeout and the eviction
/// sample minimum.
config: Arc<ScatterProxyConfig>,
/// Source of pending tasks (`pick_next` / `push_back`) and sink for the
/// completed / failed counters.
task_pool: Arc<TaskPool>,
/// Per (proxy, host) success/failure statistics used for cooldown, scoring
/// and dead-proxy eviction.
health: Arc<HealthTracker>,
/// Queried with `is_available` before a proxy is used for a host and
/// updated with `mark` when a request is launched.
rate_limiter: Arc<RateLimiter>,
/// Per-host circuit breakers; hosts reported open are skipped by normal
/// scheduling and only reached through probes.
circuit_breakers: Arc<CircuitBreakerManager>,
/// Provides the list of active proxies, an HTTP client per proxy, and the
/// ability to mark a proxy `Dead`.
proxy_manager: Arc<ProxyManager>,
/// Classifies a response (status, headers, body) as success, proxy block
/// or target error.
classifier: Arc<dyn BodyClassifier>,
/// One permit is acquired (non-blocking) for every request that is
/// launched and released once the response headers arrive or the attempt
/// fails.
semaphore: Arc<Semaphore>,
/// Receives one `record()` call per successfully completed task.
throughput: Arc<ThroughputTracker>,
}
impl Scheduler {
#[allow(clippy::too_many_arguments)]
pub fn new(
config: Arc<ScatterProxyConfig>,
task_pool: Arc<TaskPool>,
health: Arc<HealthTracker>,
rate_limiter: Arc<RateLimiter>,
circuit_breakers: Arc<CircuitBreakerManager>,
proxy_manager: Arc<ProxyManager>,
classifier: Arc<dyn BodyClassifier>,
semaphore: Arc<Semaphore>,
throughput: Arc<ThroughputTracker>,
) -> Self {
Self {
config,
task_pool,
health,
rate_limiter,
circuit_breakers,
proxy_manager,
classifier,
semaphore,
throughput,
}
}
pub async fn run(self, mut shutdown: broadcast::Receiver<()>) {
debug!("scheduler started");
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = self.schedule_one() => {}
}
}
debug!("scheduler stopped");
}
/// Reports whether `proxy` is still cooling down for `host`.
///
/// A proxy enters cooldown once its consecutive failure count for the host
/// reaches `cooldown_consecutive_fails`. The cooldown window starts at
/// `cooldown_base` and doubles for every failure beyond the threshold (the
/// doubling count is capped at 32), never exceeding `cooldown_max`. The proxy
/// is in cooldown while the time since its last access to the host is shorter
/// than that window.
fn is_in_cooldown(&self, proxy: &str, host: &str) -> bool {
    let fails = self.health.get_consecutive_fails(proxy, host);
    let trigger = self.config.cooldown_consecutive_fails as u32;

    // Below the trigger there is no cooldown at all.
    let doublings = match fails.checked_sub(trigger) {
        Some(extra) => extra.min(32) as i32,
        None => return false,
    };

    let base_secs = self.config.cooldown_base.as_secs_f64();
    let cap_secs = self.config.cooldown_max.as_secs_f64();
    let window_secs = (base_secs * 2f64.powi(doublings)).min(cap_secs);

    self.health.seconds_since_last_access(proxy, host) < window_secs
}
/// Performs one scheduling step.
///
/// 1. Collects the hosts whose circuit breaker is open. For each of them that
///    the breaker says should be probed, the next eligible task is picked; if
///    it targets that host it is executed as a probe, otherwise it is pushed
///    back.
/// 2. Picks the next task whose host is not skipped (sleeping briefly when
///    there is none) and fails it with `Timeout` if its deadline has passed.
/// 3. Filters the active proxies by rate limit and cooldown. If none remain,
///    the host's breaker is tripped and the task is pushed back.
/// 4. Races the request through the `k` best-scored proxies. The first
///    response classified as `Success` wins and the remaining attempts are
///    aborted.
/// 5. If no attempt succeeded, the attempt counter is bumped and the task is
///    requeued or failed via `requeue_or_fail`.
/// 6. Every proxy that was actually used is marked `Dead` when it has at least
///    `eviction_min_samples` samples and a global success rate of zero.
async fn schedule_one(&self) {
    // Pause applied whenever a step could not make progress, so that the
    // `run` loop yields to the runtime instead of spinning.
    const IDLE_BACKOFF: Duration = Duration::from_millis(50);

    let cb_state = self.circuit_breakers.get_all();
    let skip_hosts: HashSet<String> = cb_state
        .iter()
        .filter(|(_, &is_open)| is_open)
        .map(|(host, _)| host.clone())
        .collect();

    // Give open breakers a chance to recover through a single probe request.
    for probe_host in &skip_hosts {
        if self.circuit_breakers.should_probe(probe_host) {
            let mut probe_skip = skip_hosts.clone();
            probe_skip.remove(probe_host);
            if let Some(task) = self.task_pool.pick_next(&probe_skip) {
                if task.host == *probe_host {
                    self.run_probe(task).await;
                } else {
                    self.task_pool.push_back(task);
                }
            }
        }
    }

    let mut task = match self.task_pool.pick_next(&skip_hosts) {
        Some(t) => t,
        None => {
            tokio::time::sleep(IDLE_BACKOFF).await;
            return;
        }
    };
    let host = task.host.clone();

    // Fail tasks that already exceeded their overall deadline while queued.
    let elapsed = task.submitted_at.elapsed();
    if elapsed >= task.task_timeout {
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(Err(ScatterProxyError::Timeout { host, elapsed }));
        }
        self.task_pool.mark_failed();
        return;
    }

    let active_proxies = self.proxy_manager.get_active_proxies();
    let available: Vec<String> = active_proxies
        .into_iter()
        .filter(|proxy| {
            self.rate_limiter.is_available(proxy, &host) && !self.is_in_cooldown(proxy, &host)
        })
        .collect();
    if available.is_empty() {
        self.circuit_breakers.trip(&host, "zero available proxies");
        warn!(host = %host, "circuit OPEN | reason=zero available proxies");
        self.task_pool.push_back(task);
        return;
    }

    // Decide how many proxies to race and pick the best-scored ones.
    let avg_success_rate = self.health.avg_success_rate_for_host(&host);
    let k = adaptive_k(
        available.len(),
        avg_success_rate,
        self.config.max_concurrent_per_request,
    )
    .max(1);
    let mut scored: Vec<(String, f64)> = available
        .iter()
        .map(|proxy| {
            let s = compute_score(&self.health, proxy, &host);
            (proxy.clone(), s)
        })
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let candidates: Vec<String> = scored.into_iter().take(k).map(|(proxy, _)| proxy).collect();

    let proxy_timeout = self.config.proxy_timeout;
    let mut join_set: JoinSet<(String, RaceOutcome, f64)> = JoinSet::new();
    let mut actual_candidates: Vec<String> = Vec::new();
    for proxy_url in &candidates {
        let req = match task.request.try_clone() {
            Some(r) => r,
            None => {
                // A request that cannot be cloned can never be raced or retried.
                // Dropping `join_set` on return aborts anything already spawned.
                if let Some(tx) = task.result_tx.take() {
                    let _ = tx.send(Err(ScatterProxyError::Init(
                        "request body is not cloneable".into(),
                    )));
                }
                self.task_pool.mark_failed();
                return;
            }
        };
        let client = match self.proxy_manager.get_client(proxy_url) {
            Ok(c) => c,
            Err(_) => continue,
        };
        let permit = match self.semaphore.clone().try_acquire_owned() {
            Ok(p) => p,
            Err(_) => continue,
        };
        self.rate_limiter.mark(proxy_url, &host);
        let proxy_url_owned = proxy_url.clone();
        actual_candidates.push(proxy_url.clone());
        join_set.spawn(async move {
            let start = Instant::now();
            let outcome = match tokio::time::timeout(proxy_timeout, client.execute(req)).await {
                Ok(Ok(response)) => RaceOutcome::Response(response),
                Ok(Err(e)) => RaceOutcome::RequestError(e.to_string()),
                Err(_) => RaceOutcome::Timeout,
            };
            let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
            drop(permit);
            (proxy_url_owned, outcome, latency_ms)
        });
    }

    if join_set.is_empty() {
        // Nothing could be launched (no client or no free permit). Back off
        // before returning: without an await here the caller's loop would
        // re-pick the same task immediately and spin without ever yielding.
        self.task_pool.push_back(task);
        tokio::time::sleep(IDLE_BACKOFF).await;
        return;
    }

    let mut success = false;
    let mut last_error = String::new();
    while let Some(result) = join_set.join_next().await {
        let (proxy_url, outcome, latency_ms) = match result {
            Ok(v) => v,
            Err(e) => {
                last_error = format!("task join error: {e}");
                continue;
            }
        };
        match outcome {
            RaceOutcome::Response(response) => {
                let status = response.status();
                let headers = response.headers().clone();
                // Reading the body is bounded by the proxy timeout as well, and
                // a failed read is a proxy failure rather than an empty body:
                // otherwise a truncated response could be reported as success
                // and a stalled one would block the scheduler.
                let body = match tokio::time::timeout(proxy_timeout, response.bytes()).await {
                    Ok(Ok(bytes)) => bytes,
                    Ok(Err(e)) => {
                        self.health.record_failure(&proxy_url, &host);
                        last_error = format!("body read error: {e}");
                        continue;
                    }
                    Err(_) => {
                        self.health.record_failure(&proxy_url, &host);
                        last_error = "proxy timeout".into();
                        continue;
                    }
                };
                let verdict = self.classifier.classify(status, &headers, &body);
                match verdict {
                    BodyVerdict::Success => {
                        self.health.record_success(&proxy_url, &host, latency_ms);
                        self.circuit_breakers.record_success(&host);
                        // First success wins; cancel the remaining attempts.
                        join_set.abort_all();
                        if let Some(tx) = task.result_tx.take() {
                            let _ = tx.send(Ok(ScatterResponse {
                                status,
                                headers,
                                body,
                            }));
                        }
                        self.task_pool.mark_completed();
                        self.throughput.record();
                        debug!(
                            host = %host,
                            winner = %proxy_url,
                            latency_ms = latency_ms,
                            attempt = task.attempts + 1,
                            max_attempts = task.max_attempts,
                            "task done"
                        );
                        success = true;
                        break;
                    }
                    BodyVerdict::ProxyBlocked => {
                        self.health.record_failure(&proxy_url, &host);
                        last_error = format!("proxy blocked (status={status})");
                    }
                    BodyVerdict::TargetError => {
                        self.circuit_breakers.record_target_error(&host);
                        last_error = format!("target error (status={status})");
                    }
                }
            }
            RaceOutcome::RequestError(err) => {
                self.health.record_failure(&proxy_url, &host);
                last_error = err;
            }
            RaceOutcome::Timeout => {
                self.health.record_failure(&proxy_url, &host);
                last_error = "proxy timeout".into();
            }
        }
    }

    if success {
        // Wait for the aborted attempts to finish unwinding.
        while join_set.join_next().await.is_some() {}
    } else {
        task.attempts += 1;
        task.last_error = last_error;
        self.requeue_or_fail(task, &host);
    }

    // Evict proxies that have never succeeded despite enough samples.
    for proxy_url in &actual_candidates {
        let samples = self.health.total_samples_for_proxy(proxy_url);
        if samples >= self.config.eviction_min_samples as u32 {
            let global_rate = self.health.global_success_rate_for_proxy(proxy_url);
            if global_rate == 0.0 {
                self.proxy_manager.set_state(proxy_url, ProxyState::Dead);
                debug!(
                    proxy = %proxy_url,
                    samples = samples,
                    "proxy dead | global success_rate=0%"
                );
            }
        }
    }
}
/// Executes `task` as a single-proxy probe for a host whose circuit breaker is
/// open.
///
/// The task is failed with `Timeout` if its deadline has already passed.
/// Otherwise the first active proxy that is neither rate limited nor in
/// cooldown is used for exactly one request. If no proxy, client or semaphore
/// permit is available, the task is pushed back untouched.
///
/// A response classified as `Success` completes the task and reports success
/// to the circuit breaker. Every other outcome counts as an attempt and goes
/// through `requeue_or_fail`. Afterwards the proxy is marked `Dead` when it
/// has at least `eviction_min_samples` samples and a global success rate of
/// zero.
async fn run_probe(&self, mut task: TaskEntry) {
    let host = task.host.clone();
    let elapsed = task.submitted_at.elapsed();
    if elapsed >= task.task_timeout {
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(Err(ScatterProxyError::Timeout { host, elapsed }));
        }
        self.task_pool.mark_failed();
        return;
    }

    let active_proxies = self.proxy_manager.get_active_proxies();
    let proxy_url = match active_proxies
        .iter()
        .find(|p| self.rate_limiter.is_available(p, &host) && !self.is_in_cooldown(p, &host))
    {
        Some(p) => p.clone(),
        None => {
            self.task_pool.push_back(task);
            return;
        }
    };
    let req = match task.request.try_clone() {
        Some(r) => r,
        None => {
            if let Some(tx) = task.result_tx.take() {
                let _ = tx.send(Err(ScatterProxyError::Init(
                    "request body is not cloneable".into(),
                )));
            }
            self.task_pool.mark_failed();
            return;
        }
    };
    let client = match self.proxy_manager.get_client(&proxy_url) {
        Ok(c) => c,
        Err(_) => {
            self.task_pool.push_back(task);
            return;
        }
    };
    let permit = match self.semaphore.clone().try_acquire_owned() {
        Ok(p) => p,
        Err(_) => {
            self.task_pool.push_back(task);
            return;
        }
    };
    self.rate_limiter.mark(&proxy_url, &host);

    let proxy_timeout = self.config.proxy_timeout;
    let start = Instant::now();
    let outcome = match tokio::time::timeout(proxy_timeout, client.execute(req)).await {
        Ok(Ok(response)) => RaceOutcome::Response(response),
        Ok(Err(e)) => RaceOutcome::RequestError(e.to_string()),
        Err(_) => RaceOutcome::Timeout,
    };
    let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
    drop(permit);

    match outcome {
        RaceOutcome::Response(response) => {
            let status = response.status();
            let headers = response.headers().clone();
            // The body read is bounded by the proxy timeout too, and a failed
            // read is a proxy failure rather than an empty body: otherwise a
            // truncated response could be reported as success and a stalled
            // one would block the scheduler.
            match tokio::time::timeout(proxy_timeout, response.bytes()).await {
                Ok(Ok(body)) => match self.classifier.classify(status, &headers, &body) {
                    BodyVerdict::Success => {
                        self.health.record_success(&proxy_url, &host, latency_ms);
                        self.circuit_breakers.record_success(&host);
                        if let Some(tx) = task.result_tx.take() {
                            let _ = tx.send(Ok(ScatterResponse {
                                status,
                                headers,
                                body,
                            }));
                        }
                        self.task_pool.mark_completed();
                        self.throughput.record();
                        debug!(
                            host = %host,
                            winner = %proxy_url,
                            latency_ms = latency_ms,
                            attempt = task.attempts + 1,
                            max_attempts = task.max_attempts,
                            "task done"
                        );
                    }
                    BodyVerdict::ProxyBlocked => {
                        self.health.record_failure(&proxy_url, &host);
                        task.attempts += 1;
                        task.last_error = format!("proxy blocked (status={status})");
                        self.requeue_or_fail(task, &host);
                    }
                    BodyVerdict::TargetError => {
                        self.circuit_breakers.record_target_error(&host);
                        task.attempts += 1;
                        task.last_error = format!("target error (status={status})");
                        self.requeue_or_fail(task, &host);
                    }
                },
                Ok(Err(e)) => {
                    self.health.record_failure(&proxy_url, &host);
                    task.attempts += 1;
                    task.last_error = format!("body read error: {e}");
                    self.requeue_or_fail(task, &host);
                }
                Err(_) => {
                    self.health.record_failure(&proxy_url, &host);
                    task.attempts += 1;
                    task.last_error = "proxy timeout".into();
                    self.requeue_or_fail(task, &host);
                }
            }
        }
        RaceOutcome::RequestError(err) => {
            self.health.record_failure(&proxy_url, &host);
            task.attempts += 1;
            task.last_error = err;
            self.requeue_or_fail(task, &host);
        }
        RaceOutcome::Timeout => {
            self.health.record_failure(&proxy_url, &host);
            task.attempts += 1;
            task.last_error = "proxy timeout".into();
            self.requeue_or_fail(task, &host);
        }
    }

    // Evict the proxy if it has never succeeded despite enough samples.
    let samples = self.health.total_samples_for_proxy(&proxy_url);
    if samples >= self.config.eviction_min_samples as u32 {
        let global_rate = self.health.global_success_rate_for_proxy(&proxy_url);
        if global_rate == 0.0 {
            self.proxy_manager.set_state(&proxy_url, ProxyState::Dead);
            debug!(
                proxy = %proxy_url,
                samples = samples,
                "proxy dead | global success_rate=0%"
            );
        }
    }
}
/// Decides the fate of a task whose latest attempt did not succeed.
///
/// The caller is expected to have already incremented `task.attempts` and set
/// `task.last_error`. The task is failed with `MaxAttemptsExhausted` when the
/// attempt budget is used up, failed with `Timeout` when its overall deadline
/// has passed, and pushed back onto the pool otherwise. The attempt budget is
/// checked before the deadline.
fn requeue_or_fail(&self, mut task: TaskEntry, host: &str) {
    let elapsed = task.submitted_at.elapsed();

    // `Some(..)` when the task must be given up, `None` when it may retry.
    let terminal = if task.attempts >= task.max_attempts {
        Some(ScatterProxyError::MaxAttemptsExhausted {
            host: host.to_string(),
            attempts: task.attempts,
            last_error: task.last_error.clone(),
        })
    } else if elapsed >= task.task_timeout {
        Some(ScatterProxyError::Timeout {
            host: host.to_string(),
            elapsed,
        })
    } else {
        None
    };

    match terminal {
        Some(err) => {
            // The receiver may already be gone; a failed send is ignored.
            if let Some(tx) = task.result_tx.take() {
                let _ = tx.send(Err(err));
            }
            self.task_pool.mark_failed();
        }
        None => {
            debug!(
                host = %host,
                attempt = task.attempts,
                max_attempts = task.max_attempts,
                reason = %task.last_error,
                "task requeued"
            );
            self.task_pool.push_back(task);
        }
    }
}
}