use super::{FetchRequest, Middleware};
use crate::{
error::KumoError,
extract::Response,
logging::{event, target},
};
use std::{
sync::{Arc, Mutex},
time::Duration,
};
/// Mutable adaptive state shared by all requests passing through an
/// `AutoThrottle`.
struct ThrottleState {
    /// Smoothed response latency, in seconds.
    ewma_latency_secs: f64,
    /// Delay slept before each outgoing request.
    current_delay: Duration,
}
/// Middleware that adapts the delay between requests from observed response
/// latency and backs off on configured HTTP status codes.
///
/// Construct with [`AutoThrottle::new`] (or `Default`) and tune with the
/// builder methods.
pub struct AutoThrottle {
    /// Adaptive state; behind a mutex so the `&self` middleware hooks can
    /// update it.
    state: Arc<Mutex<ThrottleState>>,
    /// Lower bound for the adaptive delay.
    min_delay: Duration,
    /// Upper bound for the adaptive delay.
    max_delay: Duration,
    /// Desired concurrency used to derive the target delay from latency.
    target_concurrency: f64,
    /// HTTP status codes that trigger a backoff.
    backoff_statuses: Vec<u16>,
}
impl AutoThrottle {
    /// Creates a throttle with the default configuration.
    ///
    /// Defaults:
    /// - target concurrency `1.0`
    /// - start delay 500 ms, which also seeds the latency estimate
    /// - minimum delay 100 ms
    /// - maximum delay 60 s
    /// - backoff on HTTP `429` and `503`
    #[must_use]
    pub fn new() -> Self {
        let start_delay = Duration::from_millis(500);
        Self {
            target_concurrency: 1.0,
            min_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(60),
            backoff_statuses: vec![429, 503],
            state: Arc::new(Mutex::new(ThrottleState {
                current_delay: start_delay,
                ewma_latency_secs: start_delay.as_secs_f64(),
            })),
        }
    }

    /// Sets the desired concurrency.
    ///
    /// The smoothed latency is divided by this value to obtain the target
    /// delay. Values below `0.1`, and NaN, are raised to `0.1`
    /// (`f64::max` returns the non-NaN operand).
    #[must_use]
    pub fn target_concurrency(mut self, n: f64) -> Self {
        self.target_concurrency = n.max(0.1);
        self
    }

    /// Sets the initial delay and seeds the latency estimate with the same
    /// value.
    #[must_use]
    pub fn start_delay(self, d: Duration) -> Self {
        {
            // The state holds plain numbers, so recovering from a poisoned
            // lock is safe and preferable to panicking in a builder.
            let mut st = self
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            st.current_delay = d;
            st.ewma_latency_secs = d.as_secs_f64();
        } // guard released here, before `self` is moved out
        self
    }

    /// Sets the lower bound for the adaptive delay.
    ///
    /// Keep this less than or equal to `max_delay`.
    #[must_use]
    pub fn min_delay(mut self, d: Duration) -> Self {
        self.min_delay = d;
        self
    }

    /// Sets the upper bound for the adaptive delay.
    ///
    /// Keep this greater than or equal to `min_delay`.
    #[must_use]
    pub fn max_delay(mut self, d: Duration) -> Self {
        self.max_delay = d;
        self
    }

    /// Sets the HTTP status codes that double the current delay, capped at
    /// `max_delay`.
    #[must_use]
    pub fn backoff_statuses(mut self, codes: Vec<u16>) -> Self {
        self.backoff_statuses = codes;
        self
    }
}
impl Default for AutoThrottle {
fn default() -> Self {
Self::new()
}
}
#[async_trait::async_trait]
impl Middleware for AutoThrottle {
    /// Sleeps for the current adaptive delay before the request proceeds.
    async fn before_request(&self, _request: &mut FetchRequest) -> Result<(), KumoError> {
        // The guard is a temporary dropped at the end of this statement, so
        // the lock is not held across the `.await` below. A poisoned lock is
        // recovered: the state is plain numbers and always holds valid values.
        let delay = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .current_delay;
        tokio::time::sleep(delay).await;
        Ok(())
    }

    /// Folds the response latency into the moving average and recomputes the
    /// delay.
    ///
    /// If the status is a configured backoff code, the delay doubles
    /// (saturating). Otherwise it moves halfway toward
    /// `ewma_latency / target_concurrency`. The result is bounded by
    /// `min_delay` and `max_delay`.
    async fn after_response(&self, response: &mut Response) -> Result<(), KumoError> {
        let latency = response.elapsed().as_secs_f64();
        let status = response.status();

        // Keep the critical section to the state update; log after releasing.
        let (delay, ewma_latency_secs) = {
            let mut st = self
                .state
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            st.ewma_latency_secs = 0.3 * latency + 0.7 * st.ewma_latency_secs;

            let proposed = if self.backoff_statuses.contains(&status) {
                // `Duration * u32` panics on overflow (e.g. with
                // `max_delay(Duration::MAX)`); saturate instead.
                st.current_delay.saturating_mul(2)
            } else {
                let target_secs = st.ewma_latency_secs / self.target_concurrency;
                let blended = (st.current_delay.as_secs_f64() + target_secs) / 2.0;
                // Anything at or above the ceiling ends up as `max_delay`
                // anyway. Short-circuiting also keeps `from_secs_f64` from
                // panicking on overflow with an extreme start delay.
                if blended >= self.max_delay.as_secs_f64() {
                    self.max_delay
                } else {
                    Duration::from_secs_f64(blended)
                }
            };

            // Floor first, ceiling last. `Ord::clamp` panics when
            // `min_delay > max_delay` (a configuration the builder allows);
            // this form lets `max_delay` win instead.
            st.current_delay = proposed.max(self.min_delay).min(self.max_delay);
            (st.current_delay, st.ewma_latency_secs)
        };

        tracing::debug!(
            target: target::REQUEST,
            event = event::REQUEST_AUTOTHROTTLE,
            url = response.url(),
            delay_ms = delay.as_millis(),
            ewma_latency_ms = (ewma_latency_secs * 1000.0) as u64,
            status = status,
            "request.autothrottle"
        );
        Ok(())
    }
}