use crate::{RateLimitRequest, Sender};
use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Relative slack a participant may spend beyond its fair share of a bucket
/// before [`sync_necessary`] reports that a sync is required (0.1 = 10%).
const BUCKET_ERROR: f64 = 0.1;
/// Per-sender credit ledger.
///
/// Each entry maps a [`Sender`] to `(credits, deadline)`. [`Leaser::touch`]
/// charges entries and [`Leaser::sync`] merges them with tables from other
/// participants.
pub(crate) struct Leaser {
    /// Per-sender `(credits, deadline)` state.
    pub(crate) table: HashMap<Sender, (u64, Instant)>,
    /// Stored at construction; not read by the methods below.
    _node_id: u64,
    /// Balance cap for senders that are neither `Sender::Node` nor `Sender::Payer`.
    max_limit: u64,
    /// Balance cap for `Sender::Payer` entries.
    max_limit_payer: u64,
    /// Construction time, then the most recent `gc` cutoff; not read by the methods below.
    epoch: Instant,
    /// How far `touch` pushes a deadline, and the refill window used by `sync`.
    interval: Duration,
    /// Credits added per elapsed `interval` when `sync` joins buckets.
    refill_amount: u64,
}

impl Leaser {
    /// Creates an empty ledger whose epoch is the current instant.
    pub fn new(
        _node_id: u64,
        max_limit: u64,
        max_limit_payer: u64,
        interval: Duration,
        refill_amount: u64,
    ) -> Self {
        Self {
            table: HashMap::new(),
            _node_id,
            max_limit,
            max_limit_payer,
            epoch: Instant::now(),
            interval,
            refill_amount,
        }
    }

    /// Charges `cost` credits to `req.sender`.
    ///
    /// A sender without an entry is first given `default`. That entry is inserted
    /// even if the charge is then rejected.
    ///
    /// Returns `None`, leaving the balance untouched, when the sender holds fewer
    /// than `cost` credits. Otherwise deducts `cost`, extends the deadline to at
    /// least `Instant::now() + interval`, and returns the updated
    /// `(credits, deadline)`.
    pub fn touch(
        &mut self,
        req: RateLimitRequest,
        default: (u64, Instant),
        cost: u64,
    ) -> Option<(u64, Instant)> {
        let (credits, deadline) = self.table.entry(req.sender).or_insert(default);
        if cost > *credits {
            return None;
        }
        // Cannot underflow: the guard above ensures `cost <= *credits`.
        *credits -= cost;
        *deadline = (*deadline).max(Instant::now() + self.interval);
        Some((*credits, *deadline))
    }

    /// Returns the balance cap for `sender`.
    ///
    /// `Sender::Node` is uncapped (`u64::MAX`), `Sender::Payer` uses
    /// `max_limit_payer`, and every other sender uses `max_limit`.
    pub(crate) fn get_max_limit(&self, sender: &Sender) -> u64 {
        match sender {
            Sender::Node(_) => u64::MAX,
            Sender::Payer(_) => self.max_limit_payer,
            _ => self.max_limit,
        }
    }

    /// Replaces the local table with its merge against the peer tables in `other`.
    ///
    /// Entries from `other` and from the local table are folded into a fresh map.
    /// Entries for the same sender are combined with [`join_buckets`], capped at
    /// [`Leaser::get_max_limit`], using this ledger's `interval` and `refill_amount`.
    pub fn sync(&mut self, other: &[HashMap<Sender, (u64, Instant)>]) {
        // The merged map holds at least as many senders as the largest input table.
        let capacity = other
            .iter()
            .map(HashMap::len)
            .fold(self.table.len(), usize::max);
        let mut merged: HashMap<Sender, (u64, Instant)> = HashMap::with_capacity(capacity);
        // Iterate the peer tables by reference; cloning them first is unnecessary.
        for (sender, &(credits, deadline)) in other.iter().flatten().chain(self.table.iter()) {
            let max_limit = self.get_max_limit(sender);
            merged
                .entry(sender.clone())
                .and_modify(|e| {
                    // A stack array avoids a heap allocation per merged pair.
                    *e = join_buckets(
                        &[*e, (credits, deadline)],
                        max_limit,
                        self.interval,
                        self.refill_amount,
                    );
                })
                .or_insert((credits, deadline));
        }
        self.table = merged;
    }

    /// Drops entries whose deadline is earlier than `cutoff` and records `cutoff`
    /// as the new epoch.
    pub fn gc(&mut self, cutoff: Instant) {
        self.table
            .retain(|_, &mut (_, deadline)| deadline >= cutoff);
        self.epoch = cutoff;
    }
}
/// Decides whether a participant must sync before spending `credits_to_be_used`.
///
/// `max_bucket_credits` is split evenly across `participants`. A sync is needed
/// when the planned spend exceeds that fair share by more than the relative
/// slack [`BUCKET_ERROR`]. A single participant never needs to sync. `_now` is
/// accepted but not used.
///
/// # Panics
///
/// Panics if `participants` is zero.
pub(crate) fn sync_necessary(
    max_bucket_credits: u64,
    credits_to_be_used: u64,
    _now: Instant,
    participants: u64,
) -> bool {
    assert!(participants > 0);
    let tolerated_share =
        max_bucket_credits as f64 / participants as f64 * (1.0 + BUCKET_ERROR);
    participants > 1 && credits_to_be_used as f64 > tolerated_share
}
/// Merges several `(credits, deadline)` buckets into a single bucket.
///
/// Each bucket is first brought forward to the latest deadline among all the
/// buckets via [`drain_bucket`]. The drained credits are summed (saturating) and
/// capped at `max_bucket_balance`. The returned deadline is that latest deadline.
///
/// # Panics
///
/// Panics if `buckets` is empty.
pub(crate) fn join_buckets(
    buckets: &[(u64, Instant)],
    max_bucket_balance: u64,
    bucket_interval: Duration,
    refill_amount: u64,
) -> (u64, Instant) {
    assert!(!buckets.is_empty());
    // Non-emptiness is asserted above, so `max` always yields a value.
    let horizon = buckets.iter().map(|&(_, at)| at).max().unwrap();
    let summed = buckets
        .iter()
        .filter_map(|&(credits, at)| {
            drain_bucket(
                credits,
                at,
                horizon,
                bucket_interval,
                max_bucket_balance,
                refill_amount,
            )
        })
        .fold(0u64, |acc, (drained, _)| acc.saturating_add(drained));
    (summed.min(max_bucket_balance), horizon)
}
/// Refills a bucket whose deadline has passed and computes its next deadline.
///
/// If `now` is at or before `deadline`, the bucket is returned unchanged.
/// Otherwise:
/// - every whole `interval` elapsed since `deadline` adds `refill_amount` credits;
/// - the result is capped at `max_balance`, using saturating arithmetic;
/// - the new deadline is the first point of the grid `deadline + k * interval`
///   strictly after `now`.
///
/// A zero `interval` is treated as unboundedly many elapsed periods, and the
/// deadline becomes `now`.
///
/// Elapsed time is measured in nanoseconds, so sub-millisecond intervals neither
/// panic nor lose phase.
///
/// Always returns `Some`.
pub(crate) fn drain_bucket(
    current_credits: u64,
    deadline: Instant,
    now: Instant,
    interval: Duration,
    max_balance: u64,
    refill_amount: u64,
) -> Option<(u64, Instant)> {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    if now <= deadline {
        return Some((current_credits, deadline));
    }
    let elapsed_nanos = now.saturating_duration_since(deadline).as_nanos();
    let window_nanos = interval.as_nanos();
    let (periods, remaining) = if window_nanos == 0 {
        // A zero-length window refills without bound; saturate instead of dividing by zero.
        (u64::MAX, Duration::ZERO)
    } else {
        let periods = u64::try_from(elapsed_nanos / window_nanos).unwrap_or(u64::MAX);
        let rem = elapsed_nanos % window_nanos;
        // rem < interval.as_nanos(), so rem / 1e9 <= interval.as_secs() fits in u64,
        // and rem % 1e9 < 1e9 fits in u32 (the cast is lossless).
        let remaining = Duration::new(
            u64::try_from(rem / NANOS_PER_SEC).unwrap_or(u64::MAX),
            (rem % NANOS_PER_SEC) as u32,
        );
        (periods, remaining)
    };
    // Saturate rather than overflow; the cap below makes saturation harmless.
    let credits = current_credits
        .saturating_add(periods.saturating_mul(refill_amount))
        .min(max_balance);
    let next_deadline = now + interval.saturating_sub(remaining);
    Some((credits, next_deadline))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RequestData, Sender};
    use rand::Rng;

    /// Checks `drain_bucket` before the deadline, just after it, and 2.5 intervals
    /// after it.
    #[test]
    fn test_drain() {
        let now = Instant::now();
        // (current_credits, interval, max_balance, refill per interval)
        let test_cases = vec![
            (0, Duration::from_secs(1), 100, 10),
            (0, Duration::from_millis(500), 50, 5),
            (0, Duration::from_secs(5), 1000, 100),
        ];
        for (current_credits, interval, max_balance, credits_bucket) in test_cases {
            // Deadline still ahead: the bucket is returned unchanged.
            let deadline = now + Duration::from_millis(100);
            let result = drain_bucket(
                current_credits,
                deadline,
                now,
                interval,
                max_balance,
                credits_bucket,
            );
            assert_eq!(result, Some((current_credits, deadline)));

            // Deadline just passed: the next deadline lies within one interval of `now`.
            let deadline = now - Duration::from_millis(100);
            let (credits, new_deadline) = drain_bucket(
                current_credits,
                deadline,
                now,
                interval,
                max_balance,
                credits_bucket,
            )
            .unwrap();
            assert!(credits <= max_balance);
            assert!(new_deadline > now);
            assert!(new_deadline - now <= interval);

            // 2.5 intervals overdue: exactly two whole refills.
            let deadline = now - interval.mul_f32(2.5);
            let (credits, _) = drain_bucket(
                current_credits,
                deadline,
                now,
                interval,
                max_balance,
                credits_bucket,
            )
            .unwrap();
            let expected = (credits_bucket * 2).min(max_balance);
            assert_eq!(credits, expected);
        }
    }

    /// `join_buckets` sums balances (saturating, capped) and keeps the latest deadline.
    #[test]
    fn test_join_buckets() {
        let now = Instant::now();
        let buckets = vec![(50u64, now)];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, 1000, Duration::from_secs(1), 10);
        assert_eq!(total_credits, 50);
        assert_eq!(latest_deadline, now);

        let buckets = vec![
            (50u64, now + Duration::from_millis(500)),
            (30u64, now + Duration::from_millis(500)),
        ];
        let (total_credits, _latest_deadline) =
            join_buckets(&buckets, 1000, Duration::from_secs(1), 10);
        assert_eq!(total_credits, 80);

        let buckets = vec![
            (50u64, now + Duration::from_millis(500)),
            (30u64, now + Duration::from_millis(300)),
        ];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, 1000, Duration::from_secs(1), 10);
        assert_eq!(total_credits, 80);
        assert_eq!(latest_deadline, now + Duration::from_millis(500));

        let buckets = vec![
            (20u64, now + Duration::from_millis(100)),
            (40u64, now + Duration::from_millis(200)),
            (60u64, now + Duration::from_millis(300)),
        ];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, 1000, Duration::from_secs(1), 10);
        assert_eq!(total_credits, 120);
        assert_eq!(latest_deadline, now + Duration::from_millis(300));

        let buckets = vec![
            (10u64, now + Duration::from_millis(100)),
            (20u64, now + Duration::from_millis(100)),
            (30u64, now + Duration::from_millis(100)),
        ];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, 1000, Duration::from_secs(1), 10);
        assert_eq!(total_credits, 60);
        assert_eq!(latest_deadline, now + Duration::from_millis(100));

        // The sum saturates instead of overflowing.
        let buckets = vec![
            (u64::MAX, now + Duration::from_millis(100)),
            (1u64, now + Duration::from_millis(200)),
        ];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, u64::MAX, Duration::from_secs(1), 10);
        assert_eq!(total_credits, u64::MAX);
        assert_eq!(latest_deadline, now + Duration::from_millis(200));

        let buckets = vec![
            (10u64, now + Duration::from_millis(100)),
            (20u64, now + Duration::from_millis(200)),
            (30u64, now + Duration::from_millis(300)),
            (40u64, now + Duration::from_millis(400)),
            (50u64, now + Duration::from_millis(500)),
            (60u64, now + Duration::from_millis(600)),
            (70u64, now + Duration::from_millis(700)),
            (80u64, now + Duration::from_millis(800)),
            (90u64, now + Duration::from_millis(900)),
            (100u64, now + Duration::from_millis(1000)),
        ];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, 1000, Duration::from_secs(1), 10);
        assert_eq!(total_credits, 550);
        assert_eq!(latest_deadline, now + Duration::from_millis(1000));
    }

    /// The result of `join_buckets` does not depend on the order of its input.
    #[test]
    fn test_join_buckets_permutations() {
        let now = Instant::now();
        let original_buckets = vec![
            (10u64, now + Duration::from_millis(100)),
            (20u64, now + Duration::from_millis(200)),
            (30u64, now + Duration::from_millis(300)),
        ];
        let seeds = [0, 42, 100, 255];
        let expected_total = 60u64;
        let expected_deadline = now + Duration::from_millis(300);
        for seed in seeds {
            let permuted = permute_states_with_seed(original_buckets.clone(), seed);
            let (total_credits, latest_deadline) =
                join_buckets(&permuted, 100, Duration::from_secs(1), 10);
            assert_eq!(
                total_credits, expected_total,
                "Total credits mismatch for seed {seed}"
            );
            assert_eq!(
                latest_deadline, expected_deadline,
                "Latest deadline mismatch for seed {seed}"
            );
            let mut sorted_original = original_buckets.clone();
            let mut sorted_permuted = permuted.clone();
            sorted_original.sort_by_key(|&(c, _)| c);
            sorted_permuted.sort_by_key(|&(c, _)| c);
            assert_eq!(
                sorted_original, sorted_permuted,
                "Permuted buckets missing elements for seed {seed}"
            );
        }
    }

    /// A sync is required only when usage exceeds the fair share by more than `BUCKET_ERROR`.
    #[test]
    fn test_sync_necessary() {
        let now = Instant::now();
        let max_bucket_credits = 100u64;
        assert!(!sync_necessary(max_bucket_credits, 50, now, 1));
        assert!(!sync_necessary(max_bucket_credits, 40, now, 2));
        assert!(sync_necessary(max_bucket_credits, 60, now, 2));
        let max_bucket_credits = 1000u64;
        assert!(sync_necessary(max_bucket_credits, 276, now, 4));
        assert!(!sync_necessary(max_bucket_credits, 275, now, 4));
        let max_bucket_credits = u64::MAX;
        let max_allowed = u64::MAX >> 1;
        let max_allowed = (max_allowed as f64 * (1.0 + BUCKET_ERROR)).ceil() as u64;
        assert!(!sync_necessary(max_bucket_credits, max_allowed, now, 2));
    }

    /// Drains two overdue buckets with their real balances, then joins them.
    #[test]
    fn test_integrated_bucket_operations() {
        let now = Instant::now();
        let interval = Duration::from_secs(1);
        let max_balance = 1000;
        let refill = 100;
        let bucket1 = (500u64, now - Duration::from_millis(500));
        let bucket2 = (400u64, now - Duration::from_millis(300));
        let (credits1, deadline1) =
            drain_bucket(bucket1.0, bucket1.1, now, interval, max_balance, refill).unwrap();
        let (credits2, deadline2) =
            drain_bucket(bucket2.0, bucket2.1, now, interval, max_balance, refill).unwrap();
        // Neither bucket is a whole interval overdue, so neither is refilled.
        assert_eq!(credits1, 500);
        assert_eq!(credits2, 400);
        let buckets = vec![(credits1, deadline1), (credits2, deadline2)];
        let (total_credits, latest_deadline) =
            join_buckets(&buckets, max_balance, interval, refill);
        assert_eq!(
            total_credits, 900,
            "Joined credits should be the sum of both drained buckets"
        );
        assert_eq!(
            latest_deadline,
            deadline1.max(deadline2),
            "Latest deadline should be max of both deadlines"
        );
    }

    /// `touch` charges existing entries, ignores `default` once an entry exists,
    /// and keeps separate balances per sender.
    #[test]
    fn test_touch() {
        let mut leaser = Leaser::new(1, 100, 1000, Duration::from_secs(1), 10);
        let now = Instant::now();
        let request = RateLimitRequest {
            request_data: RequestData::Query { limit: 10 },
            sender: Sender::Payer(1),
        };
        let default = (100, now);
        let cost = 10;
        let result = leaser.touch(request.clone(), default, cost);
        assert!(result.is_some());
        let (credits, deadline) = result.unwrap();
        assert_eq!(credits, 90);
        assert!(deadline >= now);
        assert_eq!(leaser.table.len(), 1);
        assert!(leaser.table.contains_key(&request.sender));
        let (stored_credits, stored_deadline) = leaser.table[&request.sender];
        assert_eq!(stored_credits, 90);
        assert!(stored_deadline >= now);

        // `default` is ignored once the sender has an entry.
        let default = (0, now);
        let result = leaser.touch(request.clone(), default, cost);
        assert!(result.is_some());
        let (credits, _) = result.unwrap();
        assert_eq!(credits, 80);
        let (stored_credits, _) = leaser.table[&request.sender];
        assert_eq!(stored_credits, 80);

        let default = (5, now);
        let result = leaser.touch(request.clone(), default, cost);
        assert!(result.is_some());
        let (credits, _) = result.unwrap();
        assert_eq!(credits, 70);
        let (stored_credits, _) = leaser.table[&request.sender];
        assert_eq!(stored_credits, 70);

        // A different sender gets its own entry.
        let request2 = RateLimitRequest {
            request_data: RequestData::Query { limit: 10 },
            sender: Sender::Node(2),
        };
        let default = (100, now);
        let result = leaser.touch(request2.clone(), default, cost);
        assert!(result.is_some());
        let (credits, _) = result.unwrap();
        assert_eq!(credits, 90);
        assert!(leaser.table.contains_key(&request2.sender));
        let (stored_credits, _) = leaser.table[&request2.sender];
        assert_eq!(stored_credits, 90);

        let default = (100, now);
        let result = leaser.touch(request.clone(), default, cost);
        assert!(result.is_some());
        let (credits, _) = result.unwrap();
        assert_eq!(credits, 60);
        let result = leaser.touch(request.clone(), default, cost);
        assert!(result.is_some());
        let (credits, _) = result.unwrap();
        assert_eq!(credits, 50);
        let (stored_credits, _) = leaser.table[&request.sender];
        assert_eq!(stored_credits, 50);
    }

    /// Refills happen once per whole elapsed interval, and the next deadline stays on
    /// the `deadline + k * interval` grid.
    #[test]
    fn test_bucket_draining_over_time() {
        let now = Instant::now();
        let interval = Duration::from_secs(1);
        let max_balance = 100;
        let initial_credits = 80;
        let test_points = vec![
            (Duration::from_millis(0), Some((80, now))),
            (Duration::from_millis(500), Some((80, now + interval))),
            (
                Duration::from_secs(1),
                Some((100, now + Duration::from_secs(1) + interval)),
            ),
            (
                Duration::from_millis(1500),
                Some((100, now + Duration::from_secs(1) + interval)),
            ),
            (
                Duration::from_secs(2),
                Some((100, now + Duration::from_secs(2) + interval)),
            ),
            (
                Duration::from_secs(4),
                Some((100, now + Duration::from_secs(4) + interval)),
            ),
            (
                Duration::from_secs(400),
                Some((100, now + Duration::from_secs(400) + interval)),
            ),
        ];
        for (elapsed, expected) in test_points {
            let check_time = now + elapsed;
            let result = drain_bucket(
                initial_credits,
                now,
                check_time,
                interval,
                max_balance,
                initial_credits,
            );
            assert_eq!(
                result, expected,
                "Unexpected drain result after {elapsed:?}"
            );
        }
        let far_future = now + Duration::from_secs(1000);
        let result = drain_bucket(
            initial_credits,
            now,
            far_future,
            interval,
            max_balance,
            initial_credits,
        );
        assert_eq!(result, Some((max_balance, far_future + interval)));
    }

    /// Deterministic Fisher-Yates shuffle of `states`, seeded by `seed`.
    fn permute_states_with_seed(mut states: Vec<(u64, Instant)>, seed: u64) -> Vec<(u64, Instant)> {
        use rand::rngs::StdRng;
        use rand::SeedableRng;
        let mut rng = StdRng::seed_from_u64(seed);
        for i in (1..states.len()).rev() {
            // `gen_range` raises a deprecation warning with the rand version in use
            // (presumably rand 0.9, which renamed it to `random_range`) — confirm before migrating.
            #[allow(deprecated)]
            let j = rng.gen_range(0..=i);
            states.swap(i, j);
        }
        states
    }
}