#![allow(dead_code)]
use std::collections::HashMap;
/// A byte-oriented rate-limit configuration: a sustained throughput plus a
/// burst allowance (used as the token-bucket capacity).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RateLimit {
    /// Sustained throughput in bytes per second.
    pub bytes_per_sec: u64,
    /// Maximum burst size in bytes (bucket capacity).
    pub burst_bytes: u64,
}

impl RateLimit {
    /// Creates a limit whose burst allowance equals one second of sustained
    /// throughput.
    #[must_use]
    pub fn new(bytes_per_sec: u64) -> Self {
        Self::with_burst(bytes_per_sec, bytes_per_sec)
    }

    /// Creates a limit with an explicit burst allowance.
    #[must_use]
    pub fn with_burst(bytes_per_sec: u64, burst_bytes: u64) -> Self {
        Self {
            bytes_per_sec,
            burst_bytes,
        }
    }
}
/// Classic token-bucket state, tracked in fractional bytes.
///
/// Tokens refill continuously at `refill_rate` (bytes per millisecond) up to
/// `capacity`; consuming `n` bytes requires `n` tokens.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Maximum number of tokens (bytes) the bucket can hold.
    pub capacity: f64,
    /// Currently available tokens (bytes).
    pub tokens: f64,
    /// Refill rate in bytes per millisecond.
    pub refill_rate: f64,
    /// Timestamp (ms) at which `tokens` was last brought up to date.
    pub last_refill_ms: u64,
}

impl TokenBucket {
    /// Builds a bucket that starts full (at `limit.burst_bytes` tokens) as of
    /// `now_ms`.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn new(limit: RateLimit, now_ms: u64) -> Self {
        Self {
            capacity: limit.burst_bytes as f64,
            tokens: limit.burst_bytes as f64,
            // Convert bytes/second to bytes/millisecond to match the
            // millisecond timestamps used throughout this module.
            refill_rate: limit.bytes_per_sec as f64 / 1000.0,
            last_refill_ms: now_ms,
        }
    }

    /// Refills the bucket for the time elapsed since the last refill, then
    /// attempts to consume `bytes` tokens. Returns `true` if the tokens were
    /// available (and are now deducted), `false` otherwise.
    ///
    /// A `now_ms` earlier than `last_refill_ms` is treated as "no time has
    /// passed"; the bucket is never drained by clock regressions.
    #[allow(clippy::cast_precision_loss)]
    pub fn try_consume(&mut self, bytes: u64, now_ms: u64) -> bool {
        if now_ms > self.last_refill_ms {
            let elapsed_ms = (now_ms - self.last_refill_ms) as f64;
            // Refill proportionally to elapsed time, capped at capacity.
            self.tokens = (self.tokens + elapsed_ms * self.refill_rate).min(self.capacity);
            self.last_refill_ms = now_ms;
        }
        if self.tokens >= bytes as f64 {
            self.tokens -= bytes as f64;
            true
        } else {
            false
        }
    }

    /// Returns how many milliseconds must elapse, at the current refill rate,
    /// before `bytes` tokens are available. Returns 0 when the request is
    /// already satisfiable, and `u64::MAX` when the bucket can never refill
    /// (`refill_rate <= 0`) — previously this case returned 0, which told
    /// callers of `Throttled(wait)` to retry immediately and busy-spin forever.
    #[must_use]
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    pub fn wait_ms_for(&self, bytes: u64) -> u64 {
        let needed = bytes as f64 - self.tokens;
        if needed <= 0.0 {
            return 0;
        }
        if self.refill_rate <= 0.0 {
            // A non-refilling bucket will never satisfy the request.
            return u64::MAX;
        }
        // Float-to-int `as` casts saturate, so an enormous wait clamps to
        // u64::MAX rather than wrapping.
        (needed / self.refill_rate).ceil() as u64
    }
}
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The requested bytes were within budget and have been consumed.
    Allowed,
    /// The request exceeded the budget; the payload is the suggested wait
    /// time in milliseconds before retrying.
    Throttled(u64),
}

impl RateLimitResult {
    /// Returns `true` only for [`RateLimitResult::Allowed`].
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allowed => true,
            Self::Throttled(_) => false,
        }
    }
}
/// Per-stream rate limiter: each registered stream id gets its own
/// independent token bucket.
///
/// Streams with no registered limit are never throttled.
pub struct RateLimiter {
    buckets: HashMap<String, TokenBucket>,
}

impl RateLimiter {
    /// Creates a limiter with no streams registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }

    /// Registers (or replaces) the limit for stream `id`; its bucket starts
    /// full as of `now_ms`.
    pub fn add_stream(&mut self, id: &str, limit: RateLimit, now_ms: u64) {
        self.buckets
            .insert(id.to_string(), TokenBucket::new(limit, now_ms));
    }

    /// Attempts to consume `bytes` from stream `id`'s bucket at `now_ms`.
    ///
    /// Unknown stream ids are always [`RateLimitResult::Allowed`].
    pub fn check_and_consume(&mut self, id: &str, bytes: u64, now_ms: u64) -> RateLimitResult {
        let bucket = match self.buckets.get_mut(id) {
            Some(bucket) => bucket,
            None => return RateLimitResult::Allowed,
        };
        if bucket.try_consume(bytes, now_ms) {
            RateLimitResult::Allowed
        } else {
            RateLimitResult::Throttled(bucket.wait_ms_for(bytes))
        }
    }

    /// Unregisters `id`; subsequent checks on it are unlimited.
    pub fn remove_stream(&mut self, id: &str) {
        self.buckets.remove(id);
    }

    /// Number of streams currently tracked.
    #[must_use]
    pub fn stream_count(&self) -> usize {
        self.buckets.len()
    }
}

impl Default for RateLimiter {
    /// Equivalent to [`RateLimiter::new`].
    fn default() -> Self {
        Self {
            buckets: HashMap::new(),
        }
    }
}
/// Sliding-window bandwidth accounting.
pub struct BandwidthTracker {
    /// Window length in milliseconds; observations older than this are pruned.
    window_ms: u64,
    /// `(timestamp_ms, bytes)` pairs recorded within the window.
    observations: Vec<(u64, u64)>,
}

impl BandwidthTracker {
    /// Creates a tracker that reports over a sliding window of `window_ms`
    /// milliseconds.
    #[must_use]
    pub fn new(window_ms: u64) -> Self {
        Self {
            window_ms,
            observations: Vec::new(),
        }
    }

    /// Records a transfer of `bytes` at time `now_ms` and prunes observations
    /// that have fallen out of the window.
    pub fn record(&mut self, bytes: u64, now_ms: u64) {
        self.observations.push((now_ms, bytes));
        self.prune(now_ms);
    }

    /// Average throughput in bytes per second over the window ending at
    /// `now_ms`. Returns 0 for a zero-length window.
    #[must_use]
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    pub fn current_bps(&self, now_ms: u64) -> u64 {
        // Handle the degenerate window up front so we neither sum the
        // observations pointlessly nor divide by zero below.
        if self.window_ms == 0 {
            return 0;
        }
        // Reuse the single source of truth for the windowed sum instead of
        // duplicating the cutoff/filter logic here.
        let total_bytes = self.total_bytes_in_window(now_ms);
        let window_sec = self.window_ms as f64 / 1000.0;
        (total_bytes as f64 / window_sec) as u64
    }

    /// Total bytes observed within the window ending at `now_ms`.
    #[must_use]
    pub fn total_bytes_in_window(&self, now_ms: u64) -> u64 {
        let cutoff = now_ms.saturating_sub(self.window_ms);
        self.observations
            .iter()
            .filter(|(ts, _)| *ts >= cutoff)
            .map(|(_, bytes)| bytes)
            .sum()
    }

    /// Drops observations older than the window relative to `now_ms`.
    fn prune(&mut self, now_ms: u64) {
        let cutoff = now_ms.saturating_sub(self.window_ms);
        self.observations.retain(|(ts, _)| *ts >= cutoff);
    }
}
/// Direction of an I/O operation, used to select the read or write limit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IoDirection {
    /// A read operation.
    Read,
    /// A write operation.
    Write,
}
/// Independent rate/burst limits for reads and writes.
#[derive(Debug, Clone, Copy)]
pub struct DirectionalLimits {
    /// Sustained read throughput in bytes per second.
    pub read_bytes_per_sec: u64,
    /// Read burst allowance in bytes.
    pub read_burst_bytes: u64,
    /// Sustained write throughput in bytes per second.
    pub write_bytes_per_sec: u64,
    /// Write burst allowance in bytes.
    pub write_burst_bytes: u64,
}

impl DirectionalLimits {
    /// Same rate (and matching burst) in both directions.
    #[must_use]
    pub fn symmetric(bytes_per_sec: u64) -> Self {
        Self::asymmetric(bytes_per_sec, bytes_per_sec)
    }

    /// Distinct read/write rates; each burst defaults to one second of its
    /// sustained rate.
    #[must_use]
    pub fn asymmetric(read_bps: u64, write_bps: u64) -> Self {
        Self {
            read_bytes_per_sec: read_bps,
            read_burst_bytes: read_bps,
            write_bytes_per_sec: write_bps,
            write_burst_bytes: write_bps,
        }
    }

    /// Returns a copy with the burst allowances overridden, keeping the
    /// sustained rates.
    #[must_use]
    pub fn with_burst(self, read_burst: u64, write_burst: u64) -> Self {
        Self {
            read_burst_bytes: read_burst,
            write_burst_bytes: write_burst,
            ..self
        }
    }

    /// Projects the limits for one direction into a [`RateLimit`].
    #[must_use]
    pub fn for_direction(&self, direction: IoDirection) -> RateLimit {
        match direction {
            IoDirection::Read => {
                RateLimit::with_burst(self.read_bytes_per_sec, self.read_burst_bytes)
            }
            IoDirection::Write => {
                RateLimit::with_burst(self.write_bytes_per_sec, self.write_burst_bytes)
            }
        }
    }
}
/// Rate limiter holding one independent token bucket per I/O direction.
pub struct DirectionalRateLimiter {
    read_bucket: TokenBucket,
    write_bucket: TokenBucket,
    config: DirectionalLimits,
}

impl DirectionalRateLimiter {
    /// Builds one full bucket per direction from `limits`, timestamped at
    /// `now_ms`.
    #[must_use]
    pub fn new(limits: DirectionalLimits, now_ms: u64) -> Self {
        Self {
            read_bucket: TokenBucket::new(limits.for_direction(IoDirection::Read), now_ms),
            write_bucket: TokenBucket::new(limits.for_direction(IoDirection::Write), now_ms),
            config: limits,
        }
    }

    // Shared borrow of the bucket serving `direction`.
    fn bucket(&self, direction: IoDirection) -> &TokenBucket {
        match direction {
            IoDirection::Read => &self.read_bucket,
            IoDirection::Write => &self.write_bucket,
        }
    }

    // Mutable borrow of the bucket serving `direction`.
    fn bucket_mut(&mut self, direction: IoDirection) -> &mut TokenBucket {
        match direction {
            IoDirection::Read => &mut self.read_bucket,
            IoDirection::Write => &mut self.write_bucket,
        }
    }

    /// Attempts to consume `bytes` in `direction` at `now_ms`; on failure,
    /// reports the suggested wait in milliseconds.
    pub fn check_and_consume(
        &mut self,
        direction: IoDirection,
        bytes: u64,
        now_ms: u64,
    ) -> RateLimitResult {
        let bucket = self.bucket_mut(direction);
        if bucket.try_consume(bytes, now_ms) {
            RateLimitResult::Allowed
        } else {
            RateLimitResult::Throttled(bucket.wait_ms_for(bytes))
        }
    }

    /// Milliseconds until `bytes` would be available in `direction`.
    #[must_use]
    pub fn wait_ms_for(&self, direction: IoDirection, bytes: u64) -> u64 {
        self.bucket(direction).wait_ms_for(bytes)
    }

    /// The limits this limiter was constructed with.
    #[must_use]
    pub fn config(&self) -> &DirectionalLimits {
        &self.config
    }

    /// Tokens currently available in `direction`'s bucket.
    #[must_use]
    pub fn available_tokens(&self, direction: IoDirection) -> f64 {
        self.bucket(direction).tokens
    }
}
/// Independent sliding-window bandwidth tracking for reads and writes.
pub struct DirectionalBandwidthTracker {
    read_tracker: BandwidthTracker,
    write_tracker: BandwidthTracker,
}

impl DirectionalBandwidthTracker {
    /// Creates a tracker pair sharing the same window length.
    #[must_use]
    pub fn new(window_ms: u64) -> Self {
        Self {
            read_tracker: BandwidthTracker::new(window_ms),
            write_tracker: BandwidthTracker::new(window_ms),
        }
    }

    // Shared borrow of the tracker serving `direction`.
    fn tracker(&self, direction: IoDirection) -> &BandwidthTracker {
        match direction {
            IoDirection::Read => &self.read_tracker,
            IoDirection::Write => &self.write_tracker,
        }
    }

    /// Records `bytes` transferred in `direction` at time `now_ms`.
    pub fn record(&mut self, direction: IoDirection, bytes: u64, now_ms: u64) {
        let tracker = match direction {
            IoDirection::Read => &mut self.read_tracker,
            IoDirection::Write => &mut self.write_tracker,
        };
        tracker.record(bytes, now_ms);
    }

    /// Throughput in bytes per second for `direction` over the window.
    #[must_use]
    pub fn current_bps(&self, direction: IoDirection, now_ms: u64) -> u64 {
        self.tracker(direction).current_bps(now_ms)
    }

    /// Total bytes in `direction` within the window ending at `now_ms`.
    #[must_use]
    pub fn total_bytes_in_window(&self, direction: IoDirection, now_ms: u64) -> u64 {
        self.tracker(direction).total_bytes_in_window(now_ms)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering bucket refill/consume behavior, per-stream and
    //! directional limiting, and sliding-window bandwidth accounting.
    use super::*;

    #[test]
    fn test_token_bucket_starts_full() {
        // A new bucket holds its full burst allowance.
        let bucket = TokenBucket::new(RateLimit::with_burst(1000, 5000), 0);
        assert!((bucket.tokens - 5000.0).abs() < f64::EPSILON);
        assert!((bucket.capacity - 5000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(RateLimit::new(1_000_000), 0);
        assert!(bucket.try_consume(500_000, 0));
        assert!((bucket.tokens - 500_000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_token_bucket_consume_fail_insufficient() {
        let mut bucket = TokenBucket::new(RateLimit::new(1000), 0);
        bucket.try_consume(1000, 0);
        assert!(!bucket.try_consume(1, 0));
    }

    #[test]
    fn test_token_bucket_refill_over_time() {
        let mut bucket = TokenBucket::new(RateLimit::new(1000), 0);
        bucket.try_consume(1000, 0);
        // 500 ms at 1000 B/s refills 500 bytes.
        assert!(bucket.try_consume(500, 500));
    }

    #[test]
    fn test_token_bucket_capped_at_capacity() {
        let mut bucket = TokenBucket::new(RateLimit::new(100), 0);
        // A long idle period must not overfill the bucket.
        bucket.try_consume(0, 100_000);
        assert!(bucket.tokens <= bucket.capacity + f64::EPSILON);
    }

    #[test]
    fn test_token_bucket_wait_ms() {
        let mut bucket = TokenBucket::new(RateLimit::new(1000), 0);
        bucket.try_consume(1000, 0);
        assert!(bucket.wait_ms_for(500) >= 500);
    }

    #[test]
    fn test_rate_limit_result_is_allowed() {
        assert!(RateLimitResult::Allowed.is_allowed());
        assert!(!RateLimitResult::Throttled(100).is_allowed());
    }

    #[test]
    fn test_rate_limiter_no_stream_always_allowed() {
        let mut limiter = RateLimiter::new();
        let result = limiter.check_and_consume("unknown", 1_000_000, 0);
        assert_eq!(result, RateLimitResult::Allowed);
    }

    #[test]
    fn test_rate_limiter_stream_allowed() {
        let mut limiter = RateLimiter::new();
        limiter.add_stream("s1", RateLimit::new(10_000), 0);
        let result = limiter.check_and_consume("s1", 5_000, 0);
        assert_eq!(result, RateLimitResult::Allowed);
    }

    #[test]
    fn test_rate_limiter_stream_throttled() {
        let mut limiter = RateLimiter::new();
        limiter.add_stream("s2", RateLimit::new(100), 0);
        limiter.check_and_consume("s2", 100, 0);
        let result = limiter.check_and_consume("s2", 50, 0);
        assert!(matches!(result, RateLimitResult::Throttled(_)));
    }

    #[test]
    fn test_rate_limiter_remove_stream() {
        let mut limiter = RateLimiter::new();
        limiter.add_stream("s3", RateLimit::new(100), 0);
        limiter.remove_stream("s3");
        assert_eq!(limiter.stream_count(), 0);
        // A removed stream behaves like an unknown one: unlimited.
        let result = limiter.check_and_consume("s3", 99999, 0);
        assert_eq!(result, RateLimitResult::Allowed);
    }

    #[test]
    fn test_bandwidth_tracker_empty() {
        let tracker = BandwidthTracker::new(1000);
        assert_eq!(tracker.current_bps(0), 0);
    }

    #[test]
    fn test_bandwidth_tracker_single_observation() {
        let mut tracker = BandwidthTracker::new(1000);
        tracker.record(500_000, 500);
        assert_eq!(tracker.current_bps(1000), 500_000);
    }

    #[test]
    fn test_bandwidth_tracker_old_observations_pruned() {
        let mut tracker = BandwidthTracker::new(1000);
        tracker.record(1_000_000, 0);
        // The observation at t=0 falls outside the [1000, 2000] window.
        assert_eq!(tracker.current_bps(2000), 0);
    }

    #[test]
    fn test_bandwidth_tracker_total_bytes_in_window() {
        let mut tracker = BandwidthTracker::new(2000);
        tracker.record(100, 0);
        tracker.record(200, 1000);
        tracker.record(400, 2000);
        assert_eq!(tracker.total_bytes_in_window(2000), 700);
    }

    #[test]
    fn test_directional_limits_symmetric() {
        let limits = DirectionalLimits::symmetric(1_000_000);
        assert_eq!(limits.read_bytes_per_sec, 1_000_000);
        assert_eq!(limits.write_bytes_per_sec, 1_000_000);
        assert_eq!(limits.read_burst_bytes, 1_000_000);
        assert_eq!(limits.write_burst_bytes, 1_000_000);
    }

    #[test]
    fn test_directional_limits_asymmetric() {
        let limits = DirectionalLimits::asymmetric(10_000_000, 1_000_000);
        assert_eq!(limits.read_bytes_per_sec, 10_000_000);
        assert_eq!(limits.write_bytes_per_sec, 1_000_000);
    }

    #[test]
    fn test_directional_limits_for_direction() {
        let limits = DirectionalLimits::asymmetric(2000, 1000);
        assert_eq!(limits.for_direction(IoDirection::Read).bytes_per_sec, 2000);
        assert_eq!(limits.for_direction(IoDirection::Write).bytes_per_sec, 1000);
    }

    #[test]
    fn test_directional_rate_limiter_read_allowed() {
        let limits = DirectionalLimits::asymmetric(10_000, 1_000);
        let mut limiter = DirectionalRateLimiter::new(limits, 0);
        assert_eq!(
            limiter.check_and_consume(IoDirection::Read, 5_000, 0),
            RateLimitResult::Allowed
        );
    }

    #[test]
    fn test_directional_rate_limiter_write_throttled() {
        let limits = DirectionalLimits::asymmetric(10_000, 500);
        let mut limiter = DirectionalRateLimiter::new(limits, 0);
        limiter.check_and_consume(IoDirection::Write, 500, 0);
        let result = limiter.check_and_consume(IoDirection::Write, 100, 0);
        assert!(matches!(result, RateLimitResult::Throttled(_)));
    }

    #[test]
    fn test_directional_rate_limiter_read_and_write_independent() {
        let limits = DirectionalLimits::asymmetric(10_000, 500);
        let mut limiter = DirectionalRateLimiter::new(limits, 0);
        // Exhaust the write budget...
        limiter.check_and_consume(IoDirection::Write, 500, 0);
        assert!(matches!(
            limiter.check_and_consume(IoDirection::Write, 1, 0),
            RateLimitResult::Throttled(_)
        ));
        // ...while reads remain unaffected.
        assert_eq!(
            limiter.check_and_consume(IoDirection::Read, 5_000, 0),
            RateLimitResult::Allowed
        );
    }

    #[test]
    fn test_directional_rate_limiter_wait_ms() {
        let mut limiter = DirectionalRateLimiter::new(DirectionalLimits::symmetric(1000), 0);
        limiter.check_and_consume(IoDirection::Write, 1000, 0);
        assert!(limiter.wait_ms_for(IoDirection::Write, 500) >= 500);
        assert_eq!(limiter.wait_ms_for(IoDirection::Read, 500), 0);
    }

    #[test]
    fn test_directional_rate_limiter_available_tokens() {
        let limiter = DirectionalRateLimiter::new(DirectionalLimits::symmetric(1000), 0);
        assert!((limiter.available_tokens(IoDirection::Read) - 1000.0).abs() < f64::EPSILON);
        assert!((limiter.available_tokens(IoDirection::Write) - 1000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_directional_bandwidth_tracker_separate() {
        let mut tracker = DirectionalBandwidthTracker::new(1000);
        tracker.record(IoDirection::Read, 500_000, 500);
        tracker.record(IoDirection::Write, 100_000, 500);
        assert_eq!(tracker.current_bps(IoDirection::Read, 1000), 500_000);
        assert_eq!(tracker.current_bps(IoDirection::Write, 1000), 100_000);
    }

    #[test]
    fn test_directional_bandwidth_tracker_total_bytes() {
        let mut tracker = DirectionalBandwidthTracker::new(2000);
        tracker.record(IoDirection::Read, 100, 0);
        tracker.record(IoDirection::Write, 200, 500);
        assert_eq!(tracker.total_bytes_in_window(IoDirection::Read, 2000), 100);
        assert_eq!(tracker.total_bytes_in_window(IoDirection::Write, 2000), 200);
    }
}