use dashmap::DashMap;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info};
/// Per-target request statistics used to drive adaptive concurrency decisions.
#[derive(Debug, Clone)]
pub struct TargetMetrics {
    /// Concurrency limit currently in effect for this target.
    pub current_concurrency: usize,
    /// Exponentially-weighted moving average of response times, in milliseconds.
    pub avg_response_time_ms: f64,
    /// Number of failed requests observed so far.
    pub error_count: u64,
    /// Number of successful requests observed so far.
    pub success_count: u64,
    /// Total requests observed (successes + errors).
    pub total_requests: u64,
}
impl TargetMetrics {
    /// Creates a zeroed metrics record seeded with the given concurrency limit.
    fn new(initial_concurrency: usize) -> Self {
        Self {
            avg_response_time_ms: 0.0,
            success_count: 0,
            error_count: 0,
            total_requests: 0,
            current_concurrency: initial_concurrency,
        }
    }
    /// Fraction of recorded requests that failed, in `[0.0, 1.0]`.
    /// Returns 0.0 before any request has been recorded.
    fn error_rate(&self) -> f64 {
        match self.total_requests {
            0 => 0.0,
            total => self.error_count as f64 / total as f64,
        }
    }
}
/// Tracks per-target request outcomes and adapts a per-domain concurrency
/// limit based on observed error rates and average response latency.
pub struct AdaptiveConcurrencyTracker {
    // Per-domain metrics, keyed by target domain. Arc lets the map be
    // shared cheaply across owners of the tracker.
    targets: Arc<DashMap<String, TargetMetrics>>,
    // Concurrency assigned to a domain the first time it is seen.
    initial_concurrency: usize,
    // Upper bound the adaptive algorithm never exceeds when ramping up.
    max_concurrency: usize,
}
impl AdaptiveConcurrencyTracker {
    /// Error rate above which concurrency is halved.
    const HIGH_ERROR_RATE: f64 = 0.1;
    /// Average latency (ms) above which concurrency is halved.
    const SLOW_RESPONSE_MS: f64 = 500.0;
    /// Average latency (ms) below which concurrency may be doubled.
    const FAST_RESPONSE_MS: f64 = 100.0;
    /// Error rate below which concurrency may be doubled.
    const LOW_ERROR_RATE: f64 = 0.01;
    /// Minimum number of samples before concurrency is allowed to ramp up.
    const MIN_REQUESTS_TO_INCREASE: u64 = 10;
    /// EWMA smoothing factor: weight kept on the running average
    /// (the new sample gets `1.0 - EWMA_DECAY`).
    const EWMA_DECAY: f64 = 0.7;

    /// Creates a tracker that starts every new target at `initial_concurrency`
    /// and never raises any target above `max_concurrency`.
    pub fn new(initial_concurrency: usize, max_concurrency: usize) -> Self {
        Self {
            targets: Arc::new(DashMap::new()),
            initial_concurrency,
            max_concurrency,
        }
    }

    /// Returns the concurrency limit currently in effect for `target_domain`,
    /// or the configured initial concurrency if the domain has not been seen.
    pub async fn get_concurrency(&self, target_domain: &str) -> usize {
        self.targets
            .get(target_domain)
            .map(|m| m.current_concurrency)
            .unwrap_or(self.initial_concurrency)
    }

    /// Records a successful request, folds `response_time` into the EWMA
    /// latency, and re-evaluates the target's concurrency limit.
    pub async fn record_success(&self, target_domain: &str, response_time: Duration) {
        let mut entry = self
            .targets
            .entry(target_domain.to_string())
            .or_insert_with(|| TargetMetrics::new(self.initial_concurrency));
        let metrics = entry.value_mut();
        metrics.success_count += 1;
        metrics.total_requests += 1;
        // as_secs_f64 preserves sub-millisecond precision; as_millis() would
        // truncate any response under 1ms to zero and skew the average.
        let response_ms = response_time.as_secs_f64() * 1000.0;
        if metrics.success_count == 1 {
            // First successful sample: seed the average directly. Checking the
            // counter (rather than `avg == 0.0`) stays correct even when the
            // first response was effectively 0ms, and avoids float equality.
            metrics.avg_response_time_ms = response_ms;
        } else {
            metrics.avg_response_time_ms = Self::EWMA_DECAY * metrics.avg_response_time_ms
                + (1.0 - Self::EWMA_DECAY) * response_ms;
        }
        self.adjust_concurrency(metrics);
        debug!(
            "Target {}: response_time={:.2}ms, concurrency={}, error_rate={:.2}%",
            target_domain,
            metrics.avg_response_time_ms,
            metrics.current_concurrency,
            metrics.error_rate() * 100.0
        );
    }

    /// Records a failed request and re-evaluates the target's concurrency
    /// limit. Errors do not contribute to the latency average.
    pub async fn record_error(&self, target_domain: &str) {
        let mut entry = self
            .targets
            .entry(target_domain.to_string())
            .or_insert_with(|| TargetMetrics::new(self.initial_concurrency));
        let metrics = entry.value_mut();
        metrics.error_count += 1;
        metrics.total_requests += 1;
        self.adjust_concurrency(metrics);
        debug!(
            "Target {}: ERROR recorded, concurrency={}, error_rate={:.2}%",
            target_domain,
            metrics.current_concurrency,
            metrics.error_rate() * 100.0
        );
    }

    /// Re-evaluates `metrics.current_concurrency` in priority order:
    /// 1. high error rate  -> halve (floor 1), then stop;
    /// 2. slow responses   -> halve (floor 1), then stop;
    /// 3. fast and reliable -> double (cap `max_concurrency`), but only once
    ///    enough samples have accumulated to trust the signal.
    fn adjust_concurrency(&self, metrics: &mut TargetMetrics) {
        let error_rate = metrics.error_rate();
        if error_rate > Self::HIGH_ERROR_RATE {
            let new_concurrency = (metrics.current_concurrency / 2).max(1);
            if new_concurrency != metrics.current_concurrency {
                info!(
                    "🔻 Reducing concurrency: {} -> {} (high error rate: {:.1}%)",
                    metrics.current_concurrency,
                    new_concurrency,
                    error_rate * 100.0
                );
                metrics.current_concurrency = new_concurrency;
            }
            return;
        }
        if metrics.avg_response_time_ms > Self::SLOW_RESPONSE_MS {
            let new_concurrency = (metrics.current_concurrency / 2).max(1);
            if new_concurrency != metrics.current_concurrency {
                info!(
                    "🔻 Reducing concurrency: {} -> {} (slow response: {:.2}ms)",
                    metrics.current_concurrency, new_concurrency, metrics.avg_response_time_ms
                );
                metrics.current_concurrency = new_concurrency;
            }
            return;
        }
        if metrics.avg_response_time_ms < Self::FAST_RESPONSE_MS
            && error_rate < Self::LOW_ERROR_RATE
        {
            let new_concurrency = (metrics.current_concurrency * 2).min(self.max_concurrency);
            if new_concurrency != metrics.current_concurrency
                && metrics.total_requests > Self::MIN_REQUESTS_TO_INCREASE
            {
                info!(
                    "🔺 Increasing concurrency: {} -> {} (fast response: {:.2}ms)",
                    metrics.current_concurrency, new_concurrency, metrics.avg_response_time_ms
                );
                metrics.current_concurrency = new_concurrency;
            }
        }
    }

    /// Returns a snapshot (clone) of the metrics for `target_domain`, or
    /// `None` if the domain has never been recorded.
    pub async fn get_metrics(&self, target_domain: &str) -> Option<TargetMetrics> {
        self.targets.get(target_domain).map(|m| m.value().clone())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default starting concurrency for all tests.
    const INITIAL: usize = 10;
    /// Default ceiling for all tests.
    const MAX: usize = 50;

    /// Builds a tracker with the shared test configuration.
    fn make_tracker() -> AdaptiveConcurrencyTracker {
        AdaptiveConcurrencyTracker::new(INITIAL, MAX)
    }

    #[tokio::test]
    async fn test_initial_concurrency() {
        // An unseen domain falls back to the configured initial limit.
        let tracker = make_tracker();
        assert_eq!(tracker.get_concurrency("example.com").await, INITIAL);
    }

    #[tokio::test]
    async fn test_success_recording() {
        // A single success is reflected in all counters and the latency average.
        let tracker = make_tracker();
        tracker
            .record_success("example.com", Duration::from_millis(50))
            .await;
        let snapshot = tracker.get_metrics("example.com").await.unwrap();
        assert_eq!(snapshot.success_count, 1);
        assert_eq!(snapshot.total_requests, 1);
        assert!(snapshot.avg_response_time_ms > 0.0);
    }

    #[tokio::test]
    async fn test_error_recording() {
        // A single error bumps both the error and total counters.
        let tracker = make_tracker();
        tracker.record_error("example.com").await;
        let snapshot = tracker.get_metrics("example.com").await.unwrap();
        assert_eq!(snapshot.error_count, 1);
        assert_eq!(snapshot.total_requests, 1);
    }

    #[tokio::test]
    async fn test_concurrency_increases_on_fast_responses() {
        // Sustained fast, error-free responses should ramp the limit up.
        let tracker = make_tracker();
        for _ in 0..20 {
            tracker
                .record_success("example.com", Duration::from_millis(50))
                .await;
        }
        assert!(tracker.get_concurrency("example.com").await > INITIAL);
    }

    #[tokio::test]
    async fn test_concurrency_decreases_on_slow_responses() {
        // Sustained slow responses should drive the limit down.
        let tracker = make_tracker();
        for _ in 0..5 {
            tracker
                .record_success("example.com", Duration::from_millis(600))
                .await;
        }
        assert!(tracker.get_concurrency("example.com").await < INITIAL);
    }

    #[tokio::test]
    async fn test_concurrency_decreases_on_errors() {
        // A burst of errors should drive the limit down.
        let tracker = make_tracker();
        for _ in 0..5 {
            tracker.record_error("example.com").await;
        }
        assert!(tracker.get_concurrency("example.com").await < INITIAL);
    }
}