use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::cu_tracker::CuCostTable;
/// Tuning parameters for a [`TokenBucket`].
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum number of tokens the bucket can hold; a new bucket starts this full.
    pub capacity: f64,
    /// Tokens credited back per second of elapsed time.
    pub refill_rate: f64,
}

impl Default for RateLimiterConfig {
    /// 300 tokens of capacity, refilled at 300 tokens per second.
    fn default() -> Self {
        Self {
            capacity: 300.0,
            refill_rate: 300.0,
        }
    }
}

/// Mutable bucket state, kept behind the [`TokenBucket`]'s mutex.
struct BucketState {
    /// Tokens available as of `last_refill`.
    tokens: f64,
    /// When `tokens` was last brought up to date by `TokenBucket::refill`.
    last_refill: Instant,
}

/// A token bucket: tokens accrue continuously at `refill_rate` per second up to
/// `capacity`, and each request spends `cost` tokens.
///
/// The balance is guarded by a [`Mutex`], so all methods take `&self`.
pub struct TokenBucket {
    config: RateLimiterConfig,
    state: Mutex<BucketState>,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`config.capacity` tokens).
    pub fn new(config: RateLimiterConfig) -> Self {
        Self {
            state: Mutex::new(BucketState {
                tokens: config.capacity,
                last_refill: Instant::now(),
            }),
            config,
        }
    }

    /// Tries to spend `cost` tokens.
    ///
    /// Returns `true` and deducts the tokens if enough are available after
    /// crediting elapsed time; otherwise returns `false` and leaves the balance
    /// untouched.
    pub fn try_acquire(&self, cost: f64) -> bool {
        let mut state = self.lock_state();
        self.refill(&mut state);
        if state.tokens >= cost {
            state.tokens -= cost;
            true
        } else {
            false
        }
    }

    /// Estimates how long until `cost` tokens will be available.
    ///
    /// Returns [`Duration::ZERO`] if they are available now. The estimate
    /// includes tokens accrued since the last call.
    ///
    /// Returns [`Duration::MAX`] when the bucket never refills (`refill_rate` is
    /// zero, negative or NaN) or the wait does not fit in a [`Duration`].
    ///
    /// If `cost` exceeds `capacity`, a finite duration is still returned, but
    /// `try_acquire(cost)` can never succeed because the balance is capped at
    /// `capacity`.
    pub fn wait_time(&self, cost: f64) -> Duration {
        let mut state = self.lock_state();
        // Bring the balance up to date first; otherwise the estimate is based on
        // the stale balance from the previous call and overstates the wait.
        self.refill(&mut state);
        let deficit = cost - state.tokens;
        if deficit <= 0.0 {
            return Duration::ZERO;
        }
        let secs = deficit / self.config.refill_rate;
        // `Duration::from_secs_f64` panics on infinite, NaN, negative or
        // overflowing input; `u64::MAX as f64` is 2^64, just above the largest
        // whole-second count a `Duration` can hold.
        if secs.is_finite() && secs >= 0.0 && secs < u64::MAX as f64 {
            Duration::from_secs_f64(secs)
        } else {
            Duration::MAX
        }
    }

    /// Returns the current token balance after crediting elapsed time.
    pub fn available(&self) -> f64 {
        let mut state = self.lock_state();
        self.refill(&mut state);
        state.tokens
    }

    /// Locks the bucket state, recovering the guard if a previous holder panicked.
    ///
    /// The state is just a token count and a timestamp. After a panic mid-update,
    /// the worst case is an inaccurate balance that the capacity clamp in
    /// `refill` still bounds. So a poisoned lock is not worth turning into a panic
    /// for every later caller.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, BucketState> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Credits `refill_rate` tokens per second elapsed since `last_refill`,
    /// clamps the balance to `capacity`, and advances `last_refill` to now.
    fn refill(&self, state: &mut BucketState) {
        let now = Instant::now();
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        let new_tokens = elapsed * self.config.refill_rate;
        state.tokens = (state.tokens + new_tokens).min(self.config.capacity);
        state.last_refill = now;
    }
}
/// Rate limiter that charges a fixed per-request cost against a [`TokenBucket`].
pub struct RateLimiter {
    bucket: TokenBucket,
    /// Tokens charged by [`RateLimiter::try_acquire`] and used by
    /// [`RateLimiter::wait_time`]; set to `1.0` by [`RateLimiter::new`].
    pub default_cost: f64,
}

impl RateLimiter {
    /// Builds a limiter over a fresh bucket with a default cost of one token.
    pub fn new(config: RateLimiterConfig) -> Self {
        let bucket = TokenBucket::new(config);
        Self {
            bucket,
            default_cost: 1.0,
        }
    }

    /// Tries to spend `default_cost` tokens; see [`TokenBucket::try_acquire`].
    pub fn try_acquire(&self) -> bool {
        self.try_acquire_cost(self.default_cost)
    }

    /// Tries to spend an explicit `cost` instead of `default_cost`.
    pub fn try_acquire_cost(&self, cost: f64) -> bool {
        self.bucket.try_acquire(cost)
    }

    /// Delegates to [`TokenBucket::wait_time`] for a request of `default_cost`.
    pub fn wait_time(&self) -> Duration {
        let cost = self.default_cost;
        self.bucket.wait_time(cost)
    }
}
/// Rate limiter that charges each request by RPC method name, using the
/// per-method costs from a [`CuCostTable`].
pub struct MethodAwareRateLimiter {
    bucket: TokenBucket,
    cost_table: CuCostTable,
}

impl MethodAwareRateLimiter {
    /// Builds a limiter over a fresh bucket, pricing methods with `cost_table`.
    pub fn new(config: RateLimiterConfig, cost_table: CuCostTable) -> Self {
        let bucket = TokenBucket::new(config);
        Self { bucket, cost_table }
    }

    /// Token cost of `method`, as reported by the cost table.
    fn method_cost(&self, method: &str) -> f64 {
        self.cost_table.cost_for(method) as f64
    }

    /// Tries to spend the cost of `method`; see [`TokenBucket::try_acquire`].
    pub fn try_acquire_method(&self, method: &str) -> bool {
        self.bucket.try_acquire(self.method_cost(method))
    }

    /// Delegates to [`TokenBucket::wait_time`] with the cost of `method`.
    pub fn wait_time_for_method(&self, method: &str) -> Duration {
        self.bucket.wait_time(self.method_cost(method))
    }

    /// Borrow of the underlying bucket, e.g. to spend raw token amounts.
    pub fn bucket(&self) -> &TokenBucket {
        &self.bucket
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixed-cost limiter with the given capacity and refill rate.
    fn limiter(capacity: f64, refill_rate: f64) -> RateLimiter {
        RateLimiter::new(RateLimiterConfig {
            capacity,
            refill_rate,
        })
    }

    /// Builds a method-aware limiter priced by the Alchemy default cost table.
    fn method_limiter(capacity: f64, refill_rate: f64) -> MethodAwareRateLimiter {
        MethodAwareRateLimiter::new(
            RateLimiterConfig {
                capacity,
                refill_rate,
            },
            CuCostTable::alchemy_defaults(),
        )
    }

    #[test]
    fn acquire_within_capacity() {
        let rl = limiter(10.0, 1.0);
        for _ in 0..10 {
            assert!(rl.try_acquire(), "should succeed within capacity");
        }
    }

    #[test]
    fn reject_when_empty() {
        // Refill is negligible over the test's lifetime, so the bucket stays drained.
        let rl = limiter(3.0, 0.0001);
        // Check the draining calls too, so the final assertion can't pass vacuously.
        for _ in 0..3 {
            assert!(rl.try_acquire(), "initial acquisitions should succeed");
        }
        assert!(!rl.try_acquire(), "should be rate limited");
    }

    #[test]
    fn wait_time_when_empty() {
        // 10 tokens/s: one token takes ~100 ms to come back.
        let rl = limiter(1.0, 10.0);
        assert!(rl.try_acquire(), "bucket should start full");
        let wait = rl.wait_time();
        assert!(
            (50..=200).contains(&wait.as_millis()),
            "unexpected wait time: {wait:?}"
        );
    }

    #[test]
    fn method_aware_uses_cu_costs() {
        let rl_expensive = method_limiter(150.0, 0.0001);
        assert!(rl_expensive.try_acquire_method("eth_getLogs"));
        assert!(rl_expensive.try_acquire_method("eth_getLogs"));
        assert!(
            !rl_expensive.try_acquire_method("eth_getLogs"),
            "should be rate limited after 2 expensive calls"
        );

        let rl_cheap = method_limiter(150.0, 0.0001);
        // Bound the attempts so a broken limiter fails the assertion instead of looping forever.
        let count = (0..=20)
            .take_while(|_| rl_cheap.try_acquire_method("eth_blockNumber"))
            .count();
        assert_eq!(
            count, 15,
            "cheap method (10 CU) should fit 15 times in 150 capacity"
        );
    }

    #[test]
    fn method_aware_wait_time() {
        let rl = method_limiter(300.0, 100.0);
        // Drain the bucket in 100-token chunks.
        while rl.bucket().try_acquire(100.0) {}
        let wait_cheap = rl.wait_time_for_method("eth_blockNumber");
        let wait_expensive = rl.wait_time_for_method("eth_getLogs");
        assert!(
            wait_expensive > wait_cheap,
            "expensive method should have longer wait: expensive={wait_expensive:?}, cheap={wait_cheap:?}"
        );
        let ratio = wait_expensive.as_secs_f64() / wait_cheap.as_secs_f64();
        assert!(
            ratio > 5.0 && ratio < 10.0,
            "wait time ratio should be ~7.5, got {ratio:.2}"
        );
    }

    #[test]
    fn method_aware_unknown_method_uses_default() {
        let rl = method_limiter(100.0, 0.0001);
        assert!(rl.try_acquire_method("some_unknown_rpc_method"));
        assert!(rl.try_acquire_method("another_unknown_method"));
        assert!(
            !rl.try_acquire_method("yet_another_unknown"),
            "unknown method should use default cost (50) and be rate limited"
        );
    }
}