use std::time::{Duration, Instant};
use crate::error::Error;
/// Speed constraints for a transfer. All speeds are in bytes per second.
#[derive(Clone, Debug, Default)]
pub struct SpeedLimits {
    /// Cap on the receive speed; `None` leaves receiving unthrottled.
    pub max_recv_speed: Option<u64>,
    /// Cap on the send speed; `None` leaves sending unthrottled.
    pub max_send_speed: Option<u64>,
    /// Speed below which the transfer counts as too slow.
    /// Has no effect unless `low_speed_time` is also set.
    pub low_speed_limit: Option<u32>,
    /// How long the speed may stay under `low_speed_limit` before an error is reported.
    /// Has no effect unless `low_speed_limit` is also set.
    pub low_speed_time: Option<Duration>,
}

impl SpeedLimits {
    /// Reports whether any limit is configured: a maximum speed in either
    /// direction, or a complete low-speed pair (both limit *and* time).
    #[must_use]
    pub const fn has_limits(&self) -> bool {
        let throttled = self.max_recv_speed.is_some() || self.max_send_speed.is_some();
        let stall_check = matches!(
            (self.low_speed_limit, self.low_speed_time),
            (Some(_), Some(_))
        );
        throttled || stall_check
    }
}
pub(crate) struct RateLimiter {
max_speed: Option<u64>,
low_speed_limit: Option<u32>,
low_speed_time: Option<Duration>,
bytes_transferred: u64,
start_time: Instant,
low_speed_start: Option<Instant>,
}
impl RateLimiter {
pub(crate) fn for_recv(limits: &SpeedLimits) -> Self {
Self {
max_speed: limits.max_recv_speed,
low_speed_limit: limits.low_speed_limit,
low_speed_time: limits.low_speed_time,
bytes_transferred: 0,
start_time: Instant::now(),
low_speed_start: None,
}
}
pub(crate) fn for_send(limits: &SpeedLimits) -> Self {
Self {
max_speed: limits.max_send_speed,
low_speed_limit: limits.low_speed_limit,
low_speed_time: limits.low_speed_time,
bytes_transferred: 0,
start_time: Instant::now(),
low_speed_start: None,
}
}
pub(crate) const fn is_active(&self) -> bool {
self.max_speed.is_some()
|| (self.low_speed_limit.is_some() && self.low_speed_time.is_some())
}
pub(crate) async fn record(&mut self, bytes: usize) -> Result<(), Error> {
self.bytes_transferred += bytes as u64;
self.check_low_speed()?;
self.throttle().await;
Ok(())
}
fn check_low_speed(&mut self) -> Result<(), Error> {
let (Some(limit), Some(time)) = (self.low_speed_limit, self.low_speed_time) else {
return Ok(());
};
let elapsed = self.start_time.elapsed();
if elapsed < Duration::from_millis(100) {
return Ok(());
}
#[allow(clippy::cast_precision_loss)]
let current_speed = self.bytes_transferred as f64 / elapsed.as_secs_f64();
#[allow(clippy::cast_precision_loss)]
if current_speed < f64::from(limit) {
match self.low_speed_start {
None => {
self.low_speed_start = Some(Instant::now());
}
Some(start) => {
if start.elapsed() >= time {
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
return Err(Error::SpeedLimit {
speed: current_speed.max(0.0) as u64,
limit: u64::from(limit),
duration: start.elapsed(),
});
}
}
}
} else {
self.low_speed_start = None;
}
Ok(())
}
async fn throttle(&self) {
let Some(max_speed) = self.max_speed else {
return;
};
if max_speed == 0 {
return;
}
let elapsed = self.start_time.elapsed();
#[allow(clippy::cast_precision_loss)]
let expected_time =
Duration::from_secs_f64(self.bytes_transferred as f64 / max_speed as f64);
if let Some(delay) = expected_time.checked_sub(elapsed) {
tokio::time::sleep(delay).await;
}
}
}
/// Chunk size for throttled transfers: 16 KiB.
/// NOTE(review): the unit (bytes) is inferred from the name; confirm at call sites.
pub(crate) const THROTTLE_CHUNK_SIZE: usize = 16_384;
#[cfg(test)]
// Tests are expected to panic on unexpected errors.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn speed_limits_default_has_no_limits() {
        let limits = SpeedLimits::default();
        assert!(!limits.has_limits());
    }

    #[test]
    fn speed_limits_has_limits_with_recv() {
        let limits = SpeedLimits { max_recv_speed: Some(1024), ..Default::default() };
        assert!(limits.has_limits());
    }

    #[test]
    fn speed_limits_has_limits_with_send() {
        let limits = SpeedLimits { max_send_speed: Some(1024), ..Default::default() };
        assert!(limits.has_limits());
    }

    #[test]
    fn speed_limits_has_limits_with_low_speed() {
        let limits = SpeedLimits {
            low_speed_limit: Some(100),
            low_speed_time: Some(Duration::from_secs(10)),
            ..Default::default()
        };
        assert!(limits.has_limits());
    }

    #[test]
    fn speed_limits_no_limits_with_only_low_speed_limit() {
        let limits = SpeedLimits { low_speed_limit: Some(100), ..Default::default() };
        assert!(!limits.has_limits());
    }

    #[test]
    fn rate_limiter_inactive_by_default() {
        let limits = SpeedLimits::default();
        let limiter = RateLimiter::for_recv(&limits);
        assert!(!limiter.is_active());
    }

    #[test]
    fn rate_limiter_inactive_with_partial_low_speed_config() {
        let only_limit = SpeedLimits { low_speed_limit: Some(100), ..Default::default() };
        assert!(!RateLimiter::for_recv(&only_limit).is_active());
        let only_time = SpeedLimits {
            low_speed_time: Some(Duration::from_secs(1)),
            ..Default::default()
        };
        assert!(!RateLimiter::for_send(&only_time).is_active());
    }

    #[test]
    fn rate_limiter_active_with_max_speed() {
        let limits = SpeedLimits { max_recv_speed: Some(1024), ..Default::default() };
        let limiter = RateLimiter::for_recv(&limits);
        assert!(limiter.is_active());
    }

    #[test]
    fn rate_limiter_for_recv_uses_recv_speed() {
        let limits = SpeedLimits {
            max_recv_speed: Some(1024),
            max_send_speed: Some(2048),
            ..Default::default()
        };
        let limiter = RateLimiter::for_recv(&limits);
        assert_eq!(limiter.max_speed, Some(1024));
    }

    #[test]
    fn rate_limiter_for_send_uses_send_speed() {
        let limits = SpeedLimits {
            max_recv_speed: Some(1024),
            max_send_speed: Some(2048),
            ..Default::default()
        };
        let limiter = RateLimiter::for_send(&limits);
        assert!(limiter.is_active());
        assert_eq!(limiter.max_speed, Some(2048));
    }

    #[tokio::test]
    async fn rate_limiter_record_no_limits() {
        let limits = SpeedLimits::default();
        let mut limiter = RateLimiter::for_recv(&limits);
        limiter.record(1000).await.unwrap();
        assert_eq!(limiter.bytes_transferred, 1000);
    }

    #[tokio::test]
    async fn rate_limiter_throttle_slows_transfer() {
        // 1_000 bytes at 10_000 B/s should take ~100 ms; kept short so the suite stays fast.
        let limits = SpeedLimits { max_recv_speed: Some(10_000), ..Default::default() };
        let mut limiter = RateLimiter::for_recv(&limits);
        let start = Instant::now();
        limiter.record(1_000).await.unwrap();
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(80), "expected >= 80ms, got {elapsed:?}");
    }

    #[tokio::test]
    async fn rate_limiter_zero_max_speed_does_not_throttle() {
        let limits = SpeedLimits { max_recv_speed: Some(0), ..Default::default() };
        let mut limiter = RateLimiter::for_recv(&limits);
        let start = Instant::now();
        limiter.record(1_000_000).await.unwrap();
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(500), "unexpected throttling: {elapsed:?}");
    }

    #[tokio::test]
    async fn rate_limiter_low_speed_triggers() {
        let limits = SpeedLimits {
            low_speed_limit: Some(1_000_000),
            low_speed_time: Some(Duration::from_millis(100)),
            ..Default::default()
        };
        let mut limiter = RateLimiter::for_recv(&limits);
        tokio::time::sleep(Duration::from_millis(150)).await;
        limiter.record(1).await.unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        let result = limiter.record(1).await;
        assert!(
            matches!(result, Err(Error::SpeedLimit { .. })),
            "expected SpeedLimit error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn rate_limiter_low_speed_resets_on_fast() {
        let limits = SpeedLimits {
            low_speed_limit: Some(100),
            low_speed_time: Some(Duration::from_secs(30)),
            ..Default::default()
        };
        let mut limiter = RateLimiter::for_recv(&limits);
        limiter.record(10_000).await.unwrap();
        assert!(limiter.low_speed_start.is_none());
    }
}