use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, Semaphore};
use tokio::task::JoinSet;
use tracing::{debug, warn};
use crate::classifier::{BodyClassifier, BodyVerdict};
use crate::config::ScatterProxyConfig;
use crate::health::HealthTracker;
use crate::metrics::ThroughputTracker;
use crate::proxy::{ProxyManager, ProxyState};
use crate::rate_limit::RateLimiter;
use crate::score::{adaptive_k, compute_score};
use crate::task::{ScatterResponse, TaskEntry, TaskPool};
/// Result of a single candidate request sent through one proxy while racing
/// several proxies for the same task.
enum RaceOutcome {
/// `client.execute` returned a response before the proxy timeout elapsed.
/// The body has not been read yet at this point.
Response(reqwest::Response),
/// `client.execute` returned an error; the payload is the error rendered
/// with `to_string()`.
RequestError(String),
/// The proxy timeout elapsed before `client.execute` completed.
Timeout,
}
/// Pulls tasks from the task pool and races each one across a set of proxies,
/// recording health statistics and requeueing tasks that did not succeed.
pub(crate) struct Scheduler {
// Tunables read by the scheduler: worker count (`max_inflight`), cooldown
// parameters, per-request concurrency, proxy timeout and eviction threshold.
config: Arc<ScatterProxyConfig>,
// Source of tasks; also receives requeued/delayed tasks and dispatch counters.
task_pool: Arc<TaskPool>,
// Per-(proxy, host) success/failure bookkeeping used for cooldown, scoring
// and eviction decisions.
health: Arc<HealthTracker>,
// Consulted before a proxy is considered and marked when a request is sent.
rate_limiter: Arc<RateLimiter>,
// Supplies the list of active proxies and their HTTP clients; proxies are
// marked `Dead` through it.
proxy_manager: Arc<ProxyManager>,
// Classifies a response (status, headers, body) as success, proxy-blocked
// or target error.
classifier: Arc<dyn BodyClassifier>,
// One permit is held for the duration of each in-flight candidate request.
semaphore: Arc<Semaphore>,
// `record()` is called once per successfully completed task.
throughput: Arc<ThroughputTracker>,
}
impl Scheduler {
#[allow(clippy::too_many_arguments)]
pub fn new(
config: Arc<ScatterProxyConfig>,
task_pool: Arc<TaskPool>,
health: Arc<HealthTracker>,
rate_limiter: Arc<RateLimiter>,
proxy_manager: Arc<ProxyManager>,
classifier: Arc<dyn BodyClassifier>,
semaphore: Arc<Semaphore>,
throughput: Arc<ThroughputTracker>,
) -> Self {
Self {
config,
task_pool,
health,
rate_limiter,
proxy_manager,
classifier,
semaphore,
throughput,
}
}
/// Runs the scheduler until a shutdown signal arrives.
///
/// Spawns `max_inflight` workers (clamped to the range 1..=32), each with its
/// own resubscribed shutdown receiver, then waits on `shutdown`. Once the
/// receiver yields (a message or an error), all workers are aborted and
/// joined before returning.
pub async fn run(self, mut shutdown: broadcast::Receiver<()>) {
    debug!("scheduler started");

    let shared = Arc::new(self);
    let worker_count = shared.config.max_inflight.clamp(1, 32);

    let mut pool = JoinSet::new();
    for worker_id in 0..worker_count {
        let worker = Arc::clone(&shared);
        let mut stop_rx = shutdown.resubscribe();
        pool.spawn(async move {
            worker.worker_loop(worker_id, &mut stop_rx).await;
        });
    }

    // Any outcome of `recv` (message, lag, closed channel) counts as "stop".
    let _ = shutdown.recv().await;

    pool.abort_all();
    // Drain the set so every worker has actually finished or been cancelled.
    loop {
        if pool.join_next().await.is_none() {
            break;
        }
    }

    debug!(workers = worker_count, "scheduler stopped");
}
async fn worker_loop(
self: Arc<Self>,
worker_id: usize,
shutdown: &mut broadcast::Receiver<()>,
) {
debug!(worker_id, "scheduler worker started");
loop {
let delay = self
.task_pool
.next_delayed_ready_in()
.unwrap_or(Duration::from_millis(10));
tokio::select! {
_ = shutdown.recv() => break,
_ = self.task_pool.notified() => {},
_ = tokio::time::sleep(delay.min(Duration::from_millis(50))) => {},
}
if let Some(task) = self.pick_task() {
self.task_pool.mark_dispatch();
self.process_task(task).await;
}
}
debug!(worker_id, "scheduler worker stopped");
}
/// Reports whether `proxy` is currently cooling down for `host`.
///
/// A proxy is only eligible for cooldown once its consecutive failure count
/// reaches `cooldown_consecutive_fails`. The cooldown window is
/// `cooldown_base * 2^(fails - threshold)` (exponent capped at 32), limited
/// to `cooldown_max`; the proxy is in cooldown while the time since its last
/// access for this host is shorter than that window.
fn is_in_cooldown(&self, proxy: &str, host: &str) -> bool {
    let fails = self.health.get_consecutive_fails(proxy, host);
    let threshold = self.config.cooldown_consecutive_fails as u32;

    match fails.checked_sub(threshold) {
        // Below the threshold: never in cooldown.
        None => false,
        Some(excess) => {
            let multiplier = 2f64.powi(excess.min(32) as i32);
            let uncapped = self.config.cooldown_base.as_secs_f64() * multiplier;
            let window_secs = uncapped.min(self.config.cooldown_max.as_secs_f64());
            self.health.seconds_since_last_access(proxy, host) < window_secs
        }
    }
}
/// Takes the next task from the pool, if any.
///
/// An empty set is passed to `pick_next` (presumably an exclusion set, so
/// nothing is excluded; confirm against `TaskPool::pick_next`).
fn pick_task(&self) -> Option<TaskEntry> {
    let empty = std::collections::HashSet::new();
    self.task_pool.pick_next(&empty)
}
/// Dispatches one task by racing it across the best-scoring available proxies.
///
/// Steps:
/// 1. Filter the active proxies by rate limit and cooldown. If none remain,
///    the task is pushed back with a 100 ms delay.
/// 2. Pick `k` candidates: when any available proxy has no samples yet, up to
///    `max(max_concurrent_per_request, 3)` proxies; otherwise `adaptive_k`.
///    Candidates are the `k` highest-scoring proxies.
/// 3. Spawn one request per candidate, each holding a semaphore permit and
///    bounded by `proxy_timeout`. If nothing could be launched the task is
///    pushed back with a 25 ms delay.
/// 4. Consume results as they finish. The first response classified as
///    `Success` is sent to the task's result channel and the remaining
///    requests are aborted. Failures update proxy health.
/// 5. If no candidate succeeded the task is requeued with the last error.
/// 6. Any launched proxy with at least `eviction_min_samples` samples and a
///    global success rate of exactly zero is marked `Dead`.
///
/// A request whose body cannot be cloned is marked failed and dropped.
async fn process_task(&self, mut task: TaskEntry) {
    let host = task.host.clone();

    // --- 1. Collect proxies that are neither rate-limited nor cooling down.
    let active_proxies = self.proxy_manager.get_active_proxies();
    let mut available = Vec::new();
    let mut skipped_rate_limit = 0u64;
    let mut skipped_cooldown = 0u64;
    for proxy in active_proxies {
        if !self.rate_limiter.is_available(&proxy, &host) {
            skipped_rate_limit += 1;
            continue;
        }
        if self.is_in_cooldown(&proxy, &host) {
            skipped_cooldown += 1;
            continue;
        }
        available.push(proxy);
    }
    for _ in 0..skipped_rate_limit {
        self.task_pool.mark_skipped_rate_limit();
    }
    for _ in 0..skipped_cooldown {
        self.task_pool.mark_skipped_cooldown();
    }

    if available.is_empty() {
        self.task_pool.mark_zero_available();
        warn!(
            host = %host,
            attempt = task.attempts + 1,
            skipped_rate_limit,
            skipped_cooldown,
            "no proxy currently available; delaying task"
        );
        task.attempts += 1;
        task.last_error = format!(
            "zero available proxies (rate_limit_skips={skipped_rate_limit}, cooldown_skips={skipped_cooldown})"
        );
        self.task_pool
            .push_delayed(task, Duration::from_millis(100));
        return;
    }

    // --- 2. Decide how many proxies to race and which ones.
    let has_untested = available
        .iter()
        .any(|p| self.health.total_samples_for_proxy(p) == 0);
    let avg_success_rate = self.health.avg_success_rate_for_host(&host);
    let desired_k = if has_untested {
        // Widen the race while some proxies have never been sampled.
        available
            .len()
            .min(self.config.max_concurrent_per_request.max(3))
    } else {
        adaptive_k(
            available.len(),
            avg_success_rate,
            self.config.max_concurrent_per_request,
        )
    };
    let k = desired_k.max(1);

    let mut scored: Vec<(String, f64)> = available
        .iter()
        .map(|proxy| {
            let s = compute_score(&self.health, proxy, &host);
            // A NaN score would make the comparator below a non-total order,
            // for which `sort_by` may panic or produce an unspecified order.
            // Rank such proxies last instead.
            let s = if s.is_nan() { f64::NEG_INFINITY } else { s };
            (proxy.clone(), s)
        })
        .collect();
    // Highest score first. NaN has been removed above, so `partial_cmp`
    // always yields `Some`; the fallback is purely defensive.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let candidates: Vec<String> = scored.into_iter().take(k).map(|(proxy, _)| proxy).collect();

    // --- 3. Launch one request per candidate.
    let mut join_set: JoinSet<(String, RaceOutcome, f64)> = JoinSet::new();
    let mut actual_candidates: Vec<String> = Vec::new();
    let mut skipped_no_permit = 0u64;
    for proxy_url in &candidates {
        let req = match task.request.try_clone() {
            Some(r) => r,
            None => {
                // The request cannot be duplicated, so it can neither be
                // raced nor retried. Dropping the task also drops its result
                // sender.
                warn!(host = %host, "request cannot be cloned; failing task");
                self.task_pool.mark_failed();
                return;
            }
        };
        let client = match self.proxy_manager.get_client(proxy_url) {
            Ok(c) => c,
            Err(_) => {
                debug!(host = %host, proxy = %proxy_url, "no client for proxy; skipping candidate");
                continue;
            }
        };
        // NOTE(review): this waits for a permit, so earlier candidates are
        // already running while later ones queue here. `Err` only occurs if
        // the semaphore has been closed.
        let permit = match self.semaphore.clone().acquire_owned().await {
            Ok(p) => p,
            Err(_) => {
                skipped_no_permit += 1;
                continue;
            }
        };
        self.rate_limiter.mark(proxy_url, &host);
        let proxy_timeout = self.config.proxy_timeout;
        let proxy_url_owned = proxy_url.clone();
        actual_candidates.push(proxy_url.clone());
        join_set.spawn(async move {
            let start = Instant::now();
            let outcome = match tokio::time::timeout(proxy_timeout, client.execute(req)).await {
                Ok(Ok(response)) => RaceOutcome::Response(response),
                Ok(Err(e)) => RaceOutcome::RequestError(e.to_string()),
                Err(_) => RaceOutcome::Timeout,
            };
            let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
            // Release the concurrency slot as soon as the request resolves.
            drop(permit);
            (proxy_url_owned, outcome, latency_ms)
        });
    }
    for _ in 0..skipped_no_permit {
        self.task_pool.mark_skipped_no_permit();
    }

    if join_set.is_empty() {
        warn!(
            host = %host,
            attempt = task.attempts + 1,
            candidate_count = candidates.len(),
            skipped_no_permit,
            "no candidate request launched; requeueing task"
        );
        task.attempts += 1;
        task.last_error =
            format!("no candidate launched (skipped_no_permit={skipped_no_permit})");
        self.task_pool.push_delayed(task, Duration::from_millis(25));
        return;
    }

    // --- 4. Consume results in completion order; first success wins.
    let mut success = false;
    let mut last_error = String::new();
    while let Some(result) = join_set.join_next().await {
        let (proxy_url, outcome, latency_ms) = match result {
            Ok(v) => v,
            Err(e) => {
                last_error = format!("task join error: {e}");
                continue;
            }
        };
        match outcome {
            RaceOutcome::Response(response) => {
                let status = response.status();
                let headers = response.headers().clone();
                // NOTE(review): the body read is not covered by the
                // `proxy_timeout` applied to `client.execute` above; confirm
                // the client itself enforces a total timeout.
                let body = match response.bytes().await {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        // A failed or truncated transfer must not be passed
                        // to the classifier as an empty body, where it could
                        // be reported as success. Treat it like any other
                        // request error on this proxy.
                        self.health.record_failure(&proxy_url, &host);
                        last_error =
                            format!("failed to read response body (status={status}): {e}");
                        continue;
                    }
                };
                let verdict = self.classifier.classify(status, &headers, &body);
                match verdict {
                    BodyVerdict::Success => {
                        self.health.record_success(&proxy_url, &host, latency_ms);
                        // Cancel the losers; their permits are released when
                        // the aborted futures are dropped.
                        join_set.abort_all();
                        debug!(
                            host = %host,
                            proxy = %proxy_url,
                            attempt = task.attempts + 1,
                            latency_ms = latency_ms as u64,
                            "task completed"
                        );
                        if let Some(tx) = task.result_tx.take() {
                            // The receiver may already be gone; that is fine.
                            let _ = tx.send(ScatterResponse {
                                status,
                                headers,
                                body,
                            });
                        }
                        self.task_pool.mark_completed();
                        self.throughput.record();
                        success = true;
                        break;
                    }
                    BodyVerdict::ProxyBlocked => {
                        self.health.record_failure(&proxy_url, &host);
                        last_error = format!("proxy blocked (status={status})");
                    }
                    BodyVerdict::TargetError => {
                        // Not counted against the proxy's health.
                        last_error = format!("target error (status={status})");
                    }
                }
            }
            RaceOutcome::RequestError(err) => {
                self.health.record_failure(&proxy_url, &host);
                last_error = err;
            }
            RaceOutcome::Timeout => {
                self.health.record_failure(&proxy_url, &host);
                last_error = "proxy timeout".into();
            }
        }
    }

    // --- 5. Finish: drain aborted requests on success, otherwise requeue.
    if success {
        while join_set.join_next().await.is_some() {}
    } else {
        task.attempts += 1;
        task.last_error = last_error;
        debug!(host = %host, attempt = task.attempts, reason = %task.last_error, "task requeued");
        self.task_pool.push_back(task);
    }

    // --- 6. Evict proxies that have enough samples and have never succeeded.
    for proxy_url in &actual_candidates {
        let samples = self.health.total_samples_for_proxy(proxy_url);
        if samples >= self.config.eviction_min_samples as u32 {
            let global_rate = self.health.global_success_rate_for_proxy(proxy_url);
            if global_rate == 0.0 {
                self.proxy_manager.set_state(proxy_url, ProxyState::Dead);
                debug!(host = %host, proxy = %proxy_url, samples = samples, "proxy dead | global success_rate=0%");
            }
        }
    }
}
}