use std::num::NonZero;
use std::sync::Arc;
use std::time::Duration;
use axum::body::Body;
use axum::http::Request;
use axum::response::Response;
use futures_util::StreamExt;
use governor::Jitter;
use crate::quota_config::LimitKey;
use super::limiter_pool::KeyedRateLimiter;
/// Returns `req` with its body wrapped so that it streams no faster than the
/// quota `limiter` tracks for `key`; method, URI, headers and extensions are
/// carried over untouched.
pub(super) fn throttle_request(
    req: Request<Body>,
    key: &LimitKey,
    limiter: &Arc<KeyedRateLimiter>,
) -> Request<Body> {
    req.map(|inner| throttle_body(inner, key, limiter))
}
/// Returns `res` with its body wrapped so that it streams no faster than the
/// quota `limiter` tracks for `key`; status, headers and extensions are
/// carried over untouched.
pub(super) fn throttle_response(
    res: Response<Body>,
    key: &LimitKey,
    limiter: &Arc<KeyedRateLimiter>,
) -> Response<Body> {
    res.map(|inner| throttle_body(inner, key, limiter))
}
fn throttle_body(body: Body, key: &LimitKey, limiter: &Arc<KeyedRateLimiter>) -> Body {
let body_stream = body.into_data_stream();
let limiter = limiter.clone();
let key = key.clone();
let throttled = body_stream
.map(move |chunk| {
let limiter = limiter.clone();
let key = key.clone();
let jitter = Jitter::new(Duration::from_millis(25), Duration::from_millis(500));
async move {
let bytes = match chunk {
Ok(actual_chunk) => actual_chunk,
Err(e) => return Err(e),
};
let chunk_kilobytes = bytes.len().div_ceil(1024);
for _ in 0..chunk_kilobytes {
if limiter
.until_key_n_ready_with_jitter(
&key,
NonZero::new(1).expect("1 is always non zero"),
jitter,
)
.await
.is_err()
{
tracing::error!(
"Rate limiter rejected a 1 KB cell — limit may be misconfigured"
);
return Err(axum::Error::new("Rate limit exceeded"));
};
}
Ok(bytes)
}
})
.buffered(1);
Body::from_stream(throttled)
}