use std::collections::HashMap;
use std::time::Instant;
use parking_lot::Mutex;
use crate::method::RateBucket;
/// Raw byte key identifying a peer; the limiter keeps separate buckets for each distinct key.
pub type PeerKey = Vec<u8>;
/// Token-bucket parameters for one [`RateBucket`] class.
#[derive(Debug, Clone, Copy)]
pub struct BucketSpec {
    /// Tokens credited back per second of elapsed time (the sustained request rate).
    pub fill_per_sec: f64,
    /// Upper bound on stored tokens, i.e. the largest burst. A peer's first request in a
    /// class starts from a bucket holding exactly this many tokens.
    pub capacity: f64,
}
/// Per-class token-bucket limits consumed by [`RateLimitState`].
///
/// A class with no entry in `buckets` is not limited at all: `RateLimitState::check`
/// logs a warning and allows the request.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Limits keyed by bucket class.
    pub buckets: HashMap<RateBucket, BucketSpec>,
}
impl RateLimitConfig {
pub fn defaults() -> Self {
let mut buckets = HashMap::new();
buckets.insert(
RateBucket::ReadLight,
BucketSpec {
fill_per_sec: 50.0,
capacity: 100.0,
},
);
buckets.insert(
RateBucket::ReadHeavy,
BucketSpec {
fill_per_sec: 5.0,
capacity: 10.0,
},
);
buckets.insert(
RateBucket::WriteLight,
BucketSpec {
fill_per_sec: 10.0,
capacity: 20.0,
},
);
buckets.insert(
RateBucket::WriteHeavy,
BucketSpec {
fill_per_sec: 1.0,
capacity: 5.0,
},
);
buckets.insert(
RateBucket::AdminOnly,
BucketSpec {
fill_per_sec: 1.0,
capacity: 3.0,
},
);
Self { buckets }
}
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self::defaults()
}
}
/// Decision returned by [`RateLimitState::check`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitOutcome {
    /// The request may proceed. For a configured class one token was spent; an
    /// unconfigured class is allowed without any accounting.
    Allow,
    /// The request must be rejected.
    Deny {
        /// Whole seconds (rounded up, never less than 1) until the bucket will have
        /// accumulated one token.
        retry_after_secs: u64,
    },
}
/// Cloneable rate limiter that tracks one token bucket per `(peer, bucket class)`.
///
/// Both fields are `Arc`s, so every clone shares the same bucket table and config and
/// therefore enforces a single set of limits.
#[derive(Debug, Clone)]
pub struct RateLimitState {
    /// Bucket table, guarded by a `parking_lot` mutex.
    inner: std::sync::Arc<Mutex<InnerState>>,
    /// Limits, shared read-only between clones.
    config: std::sync::Arc<RateLimitConfig>,
}
/// Tracked-bucket count below which no eviction pass runs.
///
/// Keeps small deployments from paying for a scan on every new peer; above this the
/// threshold scales with the live population (see `InnerState::prune_full`).
const MIN_PRUNE_THRESHOLD: usize = 1024;

/// Mutable limiter state shared by all clones of a [`RateLimitState`] behind its mutex.
#[derive(Debug)]
struct InnerState {
    /// One token bucket per `(peer, bucket class)` pair seen so far.
    buckets: HashMap<(PeerKey, RateBucket), Bucket>,
    /// Table size at which the next eviction pass runs.
    prune_at: usize,
}

/// Token-bucket bookkeeping for a single `(peer, bucket class)` pair.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; each allowed request spends one.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl InnerState {
    /// Evicts every bucket that would be full at `now`, then re-arms `prune_at`.
    ///
    /// A bucket that has refilled to capacity is indistinguishable from the fresh bucket
    /// `RateLimitState::check` creates on a miss (`tokens == capacity`), so dropping it
    /// never changes a later allow/deny decision; it only frees memory. Entries that
    /// cannot be proven full (unknown class, negative or NaN refill rate) are kept, which
    /// is always safe.
    ///
    /// The next pass is scheduled at twice the surviving count (floored at
    /// [`MIN_PRUNE_THRESHOLD`]). At least as many insertions therefore happen between
    /// passes as a pass visits, so the O(n) scan costs amortized O(1) per `check`.
    fn prune_full(&mut self, config: &RateLimitConfig, now: Instant) {
        self.buckets.retain(|(_, class), b| match config.buckets.get(class) {
            // With a non-negative refill rate, a bucket that is full now stays full.
            Some(spec) if spec.fill_per_sec >= 0.0 => {
                let elapsed = now.duration_since(b.last_refill).as_secs_f64();
                b.tokens + spec.fill_per_sec * elapsed < spec.capacity
            }
            _ => true,
        });
        self.prune_at = self
            .buckets
            .len()
            .saturating_mul(2)
            .max(MIN_PRUNE_THRESHOLD);
        // Release memory left behind by a burst of one-off peers, while keeping room for
        // the entries expected before the next pass.
        self.buckets.shrink_to(self.prune_at);
    }
}

impl RateLimitState {
    /// Creates a limiter with an empty bucket table that enforces the limits in `config`.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            inner: std::sync::Arc::new(Mutex::new(InnerState {
                buckets: HashMap::new(),
                prune_at: MIN_PRUNE_THRESHOLD,
            })),
            config: std::sync::Arc::new(config),
        }
    }

    /// Charges one request from `peer` against bucket class `bucket`.
    ///
    /// Tokens refill continuously at `fill_per_sec`, up to `capacity`. A peer's first
    /// request in a class starts from a full bucket.
    ///
    /// Returns [`RateLimitOutcome::Allow`], spending one token, when at least one token
    /// is available. Otherwise returns [`RateLimitOutcome::Deny`] with the seconds
    /// until one token will have accumulated, rounded up and at least 1. If the class
    /// never refills (`fill_per_sec == 0`), that hint saturates to `u64::MAX`.
    ///
    /// A class absent from the config is not limited: a warning is logged and the
    /// request is allowed.
    pub fn check(&self, peer: &PeerKey, bucket: RateBucket) -> RateLimitOutcome {
        let spec = match self.config.buckets.get(&bucket) {
            Some(s) => *s,
            None => {
                tracing::warn!(?bucket, "rate bucket not configured; allowing");
                return RateLimitOutcome::Allow;
            }
        };
        let mut g = self.inner.lock();
        let now = Instant::now();
        // Nothing here bounds the number of distinct peer keys. Evict idle (full)
        // buckets before the table can grow without limit.
        if g.buckets.len() >= g.prune_at {
            g.prune_full(&self.config, now);
        }
        let key = (peer.clone(), bucket);
        let b = g.buckets.entry(key).or_insert(Bucket {
            tokens: spec.capacity,
            last_refill: now,
        });
        let elapsed = now.duration_since(b.last_refill).as_secs_f64();
        b.tokens = (b.tokens + spec.fill_per_sec * elapsed).min(spec.capacity);
        b.last_refill = now;
        if b.tokens >= 1.0 {
            b.tokens -= 1.0;
            RateLimitOutcome::Allow
        } else {
            // Time for the missing fraction of a token to accrue; the float-to-int cast
            // saturates, so a zero refill rate yields u64::MAX rather than wrapping.
            let deficit = 1.0 - b.tokens;
            let wait_s = (deficit / spec.fill_per_sec).ceil() as u64;
            RateLimitOutcome::Deny {
                retry_after_secs: wait_s.max(1),
            }
        }
    }
}
/// Holder for a [`RateLimitState`]. Cloning it shares the same limiter state.
///
/// NOTE(review): presumably plugged into a request-handling stack elsewhere; that
/// wiring is not in this file.
#[derive(Debug, Clone)]
pub struct RateLimitLayer {
    /// Limiter owned by this layer.
    pub state: RateLimitState,
}
impl RateLimitLayer {
pub fn new(state: RateLimitState) -> Self {
Self { state }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter whose only configured class is `ReadLight`.
    fn read_light_only(fill_per_sec: f64, capacity: f64) -> RateLimitState {
        let mut buckets = HashMap::new();
        buckets.insert(
            RateBucket::ReadLight,
            BucketSpec {
                fill_per_sec,
                capacity,
            },
        );
        RateLimitState::new(RateLimitConfig { buckets })
    }

    #[test]
    fn first_request_allowed() {
        let limiter = RateLimitState::new(RateLimitConfig::defaults());
        let peer: PeerKey = vec![0; 32];
        assert_eq!(
            limiter.check(&peer, RateBucket::ReadLight),
            RateLimitOutcome::Allow
        );
    }

    #[test]
    fn exhaust_bucket_denies() {
        // Burst of 3 with a slow refill: three back-to-back calls pass, the fourth cannot.
        let limiter = read_light_only(1.0, 3.0);
        let peer: PeerKey = vec![0; 32];
        for _ in 0..3 {
            assert_eq!(
                limiter.check(&peer, RateBucket::ReadLight),
                RateLimitOutcome::Allow
            );
        }
        if let RateLimitOutcome::Deny { retry_after_secs } =
            limiter.check(&peer, RateBucket::ReadLight)
        {
            assert!(retry_after_secs >= 1);
        } else {
            panic!("expected Deny");
        }
    }

    #[test]
    fn buckets_are_per_peer() {
        let limiter = read_light_only(1.0, 2.0);
        let (peer_a, peer_b): (PeerKey, PeerKey) = (vec![0xAA; 32], vec![0xBB; 32]);
        for _ in 0..2 {
            assert_eq!(
                limiter.check(&peer_a, RateBucket::ReadLight),
                RateLimitOutcome::Allow
            );
        }
        // peer_a has drained its bucket...
        assert!(matches!(
            limiter.check(&peer_a, RateBucket::ReadLight),
            RateLimitOutcome::Deny { .. }
        ));
        // ...while peer_b still starts from a full one.
        assert_eq!(
            limiter.check(&peer_b, RateBucket::ReadLight),
            RateLimitOutcome::Allow
        );
    }

    #[test]
    fn unconfigured_bucket_allows() {
        let limiter = RateLimitState::new(RateLimitConfig {
            buckets: HashMap::new(),
        });
        let peer: PeerKey = vec![0; 32];
        assert_eq!(
            limiter.check(&peer, RateBucket::ReadLight),
            RateLimitOutcome::Allow
        );
    }
}