use {
cfg_if::cfg_if,
dashmap::{mapref::entry::Entry, DashMap},
solana_svm_type_overrides::sync::atomic::{AtomicU64, AtomicUsize, Ordering},
std::{borrow::Borrow, cmp::Reverse, hash::Hash, time::Instant},
};
/// Lock-free token bucket rate limiter.
///
/// Tokens are replenished lazily — whenever the bucket is read or consumed —
/// based on the time elapsed since `base_time`, at `new_tokens_per_us` tokens
/// per microsecond, capped at `max_tokens`.
pub struct TokenBucket {
    // Refill rate in tokens per microsecond (constructor converts from
    // tokens per second, see `TokenBucket::new`).
    new_tokens_per_us: f64,
    // Hard cap on the number of tokens the bucket can hold.
    max_tokens: u64,
    // Reference point for all time measurements; `last_update` is expressed
    // in microseconds relative to it.
    base_time: Instant,
    // Current number of available tokens.
    tokens: AtomicU64,
    // Microseconds since `base_time` at which `update_state` last advanced
    // the bucket state.
    last_update: AtomicU64,
    // Elapsed microseconds observed but not yet converted into whole tokens;
    // sub-token remainders are banked here so fractional time is not lost.
    credit_time_us: AtomicU64,
}
/// Virtual clock (microseconds) used instead of `Instant` when running under
/// shuttle concurrency tests, so test code can advance time deterministically.
#[cfg(feature = "shuttle-test")]
static TIME_US: AtomicU64 = AtomicU64::new(0);
impl TokenBucket {
    /// Creates a new bucket.
    ///
    /// * `initial_tokens` — tokens available immediately; must be
    ///   `<= max_tokens`.
    /// * `max_tokens` — capacity the bucket refills up to.
    /// * `new_tokens_per_second` — refill rate; must be positive.
    ///
    /// # Panics
    /// Panics if `new_tokens_per_second <= 0.0` or if
    /// `initial_tokens > max_tokens`.
    pub fn new(initial_tokens: u64, max_tokens: u64, new_tokens_per_second: f64) -> Self {
        assert!(
            new_tokens_per_second > 0.0,
            "Token bucket can not have zero influx rate"
        );
        assert!(
            initial_tokens <= max_tokens,
            "Can not have more initial tokens than max tokens"
        );
        let base_time = Instant::now();
        TokenBucket {
            // Convert from a per-second rate to a per-microsecond rate.
            new_tokens_per_us: new_tokens_per_second / 1e6,
            max_tokens,
            tokens: AtomicU64::new(initial_tokens),
            last_update: AtomicU64::new(0),
            base_time,
            credit_time_us: AtomicU64::new(0),
        }
    }
    /// Returns the current token count after refilling for elapsed time.
    ///
    /// The returned value is only a snapshot: concurrent callers may consume
    /// or mint tokens immediately after the load.
    #[inline]
    pub fn current_tokens(&self) -> u64 {
        let now = self.time_us();
        self.update_state(now);
        self.tokens.load(Ordering::Relaxed)
    }
    /// Attempts to atomically consume `request_size` tokens.
    ///
    /// Returns `Ok(tokens_remaining)` on success, or `Err(shortfall)` with
    /// the number of tokens missing to satisfy the request.
    #[inline]
    pub fn consume_tokens(&self, request_size: u64) -> Result<u64, u64> {
        let now = self.time_us();
        self.update_state(now);
        // `fetch_update` retries its CAS loop until it either debits the
        // balance or observes one too small to satisfy the request.
        match self.tokens.fetch_update(Ordering::AcqRel, Ordering::Acquire, |tokens| {
            if tokens >= request_size {
                Some(tokens.saturating_sub(request_size))
            } else {
                None
            }
        }) {
            // In both arms `prev` is the token count before the attempt.
            Ok(prev) => Ok(prev.saturating_sub(request_size)),
            Err(prev) => Err(request_size.saturating_sub(prev)),
        }
    }
    /// Microseconds elapsed since `base_time`, or the virtual shuttle clock
    /// when compiled with the `shuttle-test` feature.
    fn time_us(&self) -> u64 {
        cfg_if! {
            if #[cfg(feature="shuttle-test")] {
                TIME_US.load(Ordering::Relaxed)
            } else {
                let now = Instant::now();
                let elapsed = now.saturating_duration_since(self.base_time);
                elapsed.as_micros() as u64
            }
        }
    }
    /// Advances bucket state to time `now`, minting any tokens earned since
    /// `last_update`.
    ///
    /// Exactly one thread "wins" the CAS on `last_update` for a given
    /// interval and performs the refill; losing threads return immediately,
    /// since the winner has already accounted for the elapsed time.
    fn update_state(&self, now: u64) {
        let last = self.last_update.load(Ordering::SeqCst);
        if now <= last {
            // Time has not advanced past the last recorded update.
            return;
        }
        match self
            .last_update
            .compare_exchange(last, now, Ordering::AcqRel, Ordering::Acquire)
        {
            Ok(_) => {
                let elapsed = now.saturating_sub(last);
                // Add back any previously banked sub-token time credit.
                let elapsed =
                    elapsed.saturating_add(self.credit_time_us.swap(0, Ordering::Relaxed));
                let new_tokens_f64 = elapsed as f64 * self.new_tokens_per_us;
                let new_tokens = new_tokens_f64.floor() as u64;
                let time_to_return = if new_tokens >= 1 {
                    // Credit whole tokens, saturating at bucket capacity.
                    let _ = self.tokens.fetch_update(
                        Ordering::AcqRel,
                        Ordering::Acquire,
                        |tokens| Some(tokens.saturating_add(new_tokens).min(self.max_tokens)),
                    );
                    // Convert the fractional token remainder back into
                    // microseconds and bank it for the next update.
                    (new_tokens_f64.fract() / self.new_tokens_per_us) as u64
                } else {
                    // Not even one whole token earned; bank the entire
                    // interval so short polls do not lose time.
                    elapsed
                };
                self.credit_time_us
                    .fetch_add(time_to_return, Ordering::Relaxed);
            }
            Err(_) => {
                // Another thread advanced `last_update` concurrently; that
                // thread is responsible for crediting the interval.
            }
        }
    }
}
impl Clone for TokenBucket {
    /// Produces an independent bucket seeded from a point-in-time read of
    /// this bucket's counters. Each atomic field is loaded individually with
    /// `Ordering::Relaxed`, so the copy is not guaranteed to be a consistent
    /// snapshot if other threads are updating the bucket concurrently.
    fn clone(&self) -> Self {
        let tokens_snapshot = self.tokens.load(Ordering::Relaxed);
        let last_update_snapshot = self.last_update.load(Ordering::Relaxed);
        let credit_snapshot = self.credit_time_us.load(Ordering::Relaxed);
        Self {
            new_tokens_per_us: self.new_tokens_per_us,
            max_tokens: self.max_tokens,
            base_time: self.base_time,
            tokens: AtomicU64::new(tokens_snapshot),
            last_update: AtomicU64::new(last_update_snapshot),
            credit_time_us: AtomicU64::new(credit_snapshot),
        }
    }
}
/// Per-key rate limiter: maintains one `TokenBucket` per key (e.g. per peer
/// address) in a sharded concurrent map, evicting the least recently updated
/// buckets once the map grows past its target capacity.
pub struct KeyedRateLimiter<K>
where
    K: Hash + Eq,
{
    // Buckets keyed by client identity.
    data: DashMap<K, TokenBucket>,
    // Desired steady-state number of entries; `maybe_shrink` trims toward
    // this value.
    target_capacity: usize,
    // Template bucket cloned whenever a new key is first seen.
    prototype_bucket: TokenBucket,
    // Remaining number of new-key insertions before the next shrink attempt.
    countdown_to_shrink: AtomicUsize,
    // Approximate entry count; recomputed exactly during `maybe_shrink`.
    approx_len: AtomicUsize,
    // Number of insertions between shrink attempts.
    shrink_interval: usize,
}
impl<K> KeyedRateLimiter<K>
where
    K: Hash + Eq,
{
    /// Creates a new keyed rate limiter.
    ///
    /// * `target_capacity` — entry count to keep after shrinking; the map is
    ///   preallocated at twice this size.
    /// * `prototype_bucket` — template cloned for every new key.
    /// * `shard_amount` — number of DashMap shards (DashMap requires a power
    ///   of two — TODO confirm callers guarantee this).
    ///
    /// NOTE(review): `shrink_interval` is `target_capacity / 4`; for
    /// `target_capacity < 4` that is 0, so the countdown starts at 0 and
    /// `maybe_shrink` appears never to run — confirm tiny capacities are
    /// not used.
    #[allow(clippy::arithmetic_side_effects)]
    pub fn new(target_capacity: usize, prototype_bucket: TokenBucket, shard_amount: usize) -> Self {
        let shrink_interval = target_capacity / 4;
        Self {
            data: DashMap::with_capacity_and_shard_amount(target_capacity * 2, shard_amount),
            target_capacity,
            prototype_bucket,
            countdown_to_shrink: AtomicUsize::new(shrink_interval),
            approx_len: AtomicUsize::new(0),
            shrink_interval,
        }
    }
    /// Returns the current token count for `key`, or `None` if no bucket
    /// exists for that key yet.
    #[inline]
    pub fn current_tokens(&self, key: impl Borrow<K>) -> Option<u64> {
        let bucket = self.data.get(key.borrow())?;
        Some(bucket.current_tokens())
    }
    /// Consumes `request_size` tokens from the bucket belonging to `key`,
    /// creating the bucket from the prototype the first time `key` is seen.
    ///
    /// Returns the same `Ok(remaining)` / `Err(shortfall)` contract as
    /// `TokenBucket::consume_tokens`.
    pub fn consume_tokens(&self, key: K, request_size: u64) -> Result<u64, u64> {
        // Scope the entry guard tightly: `maybe_shrink` below takes shard
        // write locks, so the map reference must be dropped first to avoid
        // deadlock.
        let (entry_added, res) = {
            let bucket = self.data.entry(key);
            match bucket {
                Entry::Occupied(entry) => (false, entry.get().consume_tokens(request_size)),
                Entry::Vacant(entry) => {
                    let bucket = self.prototype_bucket.clone();
                    let res = bucket.consume_tokens(request_size);
                    entry.insert(bucket);
                    (true, res)
                }
            }
        };
        if entry_added {
            // Decrement the shrink countdown. `fetch_update` returns Err only
            // when the counter is already 0, i.e. another thread has hit the
            // shrink point but has not yet re-armed the counter.
            if let Ok(count) =
                self.countdown_to_shrink
                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
                        if v == 0 {
                            None
                        } else {
                            Some(v.saturating_sub(1))
                        }
                    })
            {
                if count == 1 {
                    // This insertion brought the countdown to 0: run the
                    // shrink pass, then re-arm the countdown.
                    self.maybe_shrink();
                    self.countdown_to_shrink
                        .store(self.shrink_interval, Ordering::Relaxed);
                }
                // NOTE(review): insertions in this branch do not bump
                // `approx_len`; the count is only reconciled when
                // `maybe_shrink` recounts. Verify this staleness (up to
                // `shrink_interval` entries) is acceptable for callers of
                // `len_approx`.
            } else {
                // Countdown already at 0 (shrink in progress elsewhere):
                // record the insertion in the approximate length counter.
                self.approx_len.fetch_add(1, Ordering::Relaxed);
            }
        }
        res
    }
    /// Returns the approximate number of entries in the map.
    ///
    /// Exact right after a shrink pass; may drift between shrinks (see note
    /// in `consume_tokens`).
    #[inline]
    pub fn len_approx(&self) -> usize {
        self.approx_len.load(Ordering::Relaxed)
    }
    /// Trims oversized shards down to their share of `target_capacity`,
    /// keeping the most recently updated buckets, and refreshes `approx_len`
    /// with the exact post-shrink count.
    #[allow(clippy::arithmetic_side_effects)]
    fn maybe_shrink(&self) {
        let mut actual_len = 0;
        // Give each shard an equal slice of the target capacity.
        let target_shard_size = self.target_capacity / self.data.shards().len();
        let mut entries = Vec::with_capacity(target_shard_size * 2);
        for shardlock in self.data.shards() {
            let mut shard = shardlock.write();
            // Skip shards within 150% of budget to avoid needless churn.
            if shard.len() <= target_shard_size * 3 / 2 {
                actual_len += shard.len();
                continue;
            }
            entries.clear();
            entries.extend(
                shard.drain().map(|(key, value)| {
                    (key, value.get().last_update.load(Ordering::SeqCst), value)
                }),
            );
            // Partition so the `target_shard_size` most recently updated
            // entries (largest `last_update`) come first.
            entries.select_nth_unstable_by_key(target_shard_size, |(_, last_update, _)| {
                Reverse(*last_update)
            });
            // Put only the freshest entries back; the rest are dropped.
            shard.extend(
                entries
                    .drain(..)
                    .take(target_shard_size)
                    .map(|(key, _last_update, value)| (key, value)),
            );
            debug_assert!(shard.len() <= target_shard_size);
            actual_len += shard.len();
        }
        self.approx_len.store(actual_len, Ordering::Relaxed);
    }
    /// Sets the number of insertions between shrink attempts. Takes `&mut
    /// self`, so it can only be called while the limiter is not shared.
    pub fn set_shrink_interval(&mut self, interval: usize) {
        self.shrink_interval = interval;
    }
    /// Returns the configured number of insertions between shrink attempts.
    pub fn shrink_interval(&self) -> usize {
        self.shrink_interval
    }
}
#[cfg(test)]
pub mod test {
    use {
        super::*,
        solana_svm_type_overrides::thread,
        std::{
            net::{IpAddr, Ipv4Addr},
            time::Duration,
        },
    };
    /// Basic fill/drain behavior of a single bucket: 100 initial tokens,
    /// refilling at 1000 tokens/s (~1 token/ms), capped at 100.
    ///
    /// NOTE(review): relies on wall-clock sleeps, so the refill-range
    /// assertions can be flaky on a heavily loaded machine.
    #[test]
    fn test_token_bucket() {
        let tb = TokenBucket::new(100, 100, 1000.0);
        assert_eq!(tb.current_tokens(), 100);
        tb.consume_tokens(50).expect("Bucket is initially full");
        tb.consume_tokens(50)
            .expect("We should still have >50 tokens left");
        tb.consume_tokens(50)
            .expect_err("There should not be enough tokens now");
        // ~50 ms at ~1 token/ms should mint roughly 50 tokens.
        thread::sleep(Duration::from_millis(50));
        assert!(
            tb.current_tokens() > 40,
            "We should be refilling at ~1 token per millisecond"
        );
        assert!(
            tb.current_tokens() < 70,
            "We should be refilling at ~1 token per millisecond"
        );
        tb.consume_tokens(40)
            .expect("Bucket should have enough for another request now");
        // 120 ms is more than enough to refill fully; the cap must hold.
        thread::sleep(Duration::from_millis(120));
        assert_eq!(tb.current_tokens(), 100, "Bucket should not overfill");
    }
    /// Keyed limiter behavior: independent buckets per IP, time-based
    /// refill, and eviction of stale entries once many keys are inserted.
    #[test]
    fn test_keyed_rate_limiter() {
        // target_capacity 8 with 2 shards => shrink_interval 2,
        // target_shard_size 4.
        let prototype_bucket = TokenBucket::new(100, 100, 1000.0);
        let rl = KeyedRateLimiter::new(8, prototype_bucket, 2);
        let ip1 = IpAddr::V4(Ipv4Addr::from_bits(1234));
        let ip2 = IpAddr::V4(Ipv4Addr::from_bits(4321));
        assert_eq!(rl.current_tokens(ip1), None, "Initially no buckets exist");
        // Each IP gets its own bucket with an independent 100-token balance.
        rl.consume_tokens(ip1, 50)
            .expect("Bucket is initially full");
        rl.consume_tokens(ip1, 50)
            .expect("We should still have >50 tokens left");
        rl.consume_tokens(ip1, 50)
            .expect_err("There should not be enough tokens now");
        rl.consume_tokens(ip2, 50)
            .expect("Bucket is initially full");
        rl.consume_tokens(ip2, 50)
            .expect("We should still have >50 tokens left");
        rl.consume_tokens(ip2, 50)
            .expect_err("There should not be enough tokens now");
        std::thread::sleep(Duration::from_millis(50));
        assert!(
            rl.current_tokens(ip1).unwrap() > 40,
            "We should be refilling at ~1 token per millisecond"
        );
        assert!(
            rl.current_tokens(ip1).unwrap() < 70,
            "We should be refilling at ~1 token per millisecond"
        );
        rl.consume_tokens(ip1, 40)
            .expect("Bucket should have enough for another request now");
        thread::sleep(Duration::from_millis(120));
        assert_eq!(
            rl.current_tokens(ip1),
            Some(100),
            "Bucket should not overfill"
        );
        assert_eq!(
            rl.current_tokens(ip2),
            Some(100),
            "Bucket should not overfill"
        );
        // Drain ip2 completely, then flood with fresh keys to push the map
        // over capacity and trigger shrink passes.
        rl.consume_tokens(ip2, 100).expect("Bucket should be full");
        for ip in 0..64 {
            let ip = IpAddr::V4(Ipv4Addr::from_bits(ip));
            rl.consume_tokens(ip, 50).unwrap();
        }
        // ip1's bucket was updated longest ago, so shrinking evicts it.
        assert_eq!(
            rl.current_tokens(ip1),
            None,
            "Very old record should have been erased"
        );
        // ip2 was also evicted (its old bucket was empty); a fresh full
        // bucket is created from the prototype on next use.
        rl.consume_tokens(ip2, 100)
            .expect("New bucket should have been made for ip2");
    }
    /// Shuttle-driven race test: two consumer threads contend on one bucket
    /// while a third advances the virtual clock in 100 us steps.
    ///
    /// With 10 initial tokens and 5000 tokens/s (= 0.005 tokens/us) over
    /// 2500 us, ~22.5 tokens are ever available, so exactly 4 requests of 5
    /// tokens can succeed regardless of the schedule.
    #[cfg(feature = "shuttle-test")]
    #[test]
    fn shuttle_test_token_bucket_race() {
        use shuttle::sync::atomic::AtomicBool;
        shuttle::check_random(
            || {
                // Reset the virtual clock between shuttle iterations.
                TIME_US.store(0, Ordering::SeqCst);
                let test_duration_us = 2500;
                // Leaked so spawned threads can borrow with 'static lifetime.
                let run: &AtomicBool = Box::leak(Box::new(AtomicBool::new(true)));
                let tb: &TokenBucket = Box::leak(Box::new(TokenBucket::new(10, 20, 5000.0)));
                let time_advancer = thread::spawn(move || {
                    let mut current_time = 0;
                    while current_time < test_duration_us && run.load(Ordering::SeqCst) {
                        let increment = 100;
                        current_time += increment;
                        TIME_US.store(current_time, Ordering::SeqCst);
                        shuttle::thread::yield_now();
                    }
                    // Signal consumers to stop once time runs out.
                    run.store(false, Ordering::SeqCst);
                });
                let threads: Vec<_> = (0..2)
                    .map(|_| {
                        thread::spawn(move || {
                            let mut total = 0;
                            while run.load(Ordering::SeqCst) {
                                if tb.consume_tokens(5).is_ok() {
                                    total += 1;
                                }
                                shuttle::thread::yield_now();
                            }
                            total
                        })
                    })
                    .collect();
                time_advancer.join().unwrap();
                let received = threads.into_iter().map(|t| t.join().unwrap()).sum();
                assert_eq!(4, received);
            },
            100,
        );
    }
}