#![deny(missing_docs)]
use std::os::unix::io::{AsRawFd, RawFd};
use std::time::{Duration, Instant};
use std::{fmt, io};
use log::error;
use timerfd::{ClockId, SetTimeFlags, TimerFd, TimerState};
/// Describes the errors which may occur while handling rate limiter events.
#[derive(Debug)]
pub enum Error {
    /// The event handler was called spuriously: the internal timer fd
    /// reported zero expirations. Carries a human-readable explanation.
    SpuriousRateLimiterEvent(&'static str),
}
// Interval (in milliseconds) after which a blocked rate limiter's timer fires,
// signaling that consumers may retry.
const REFILL_TIMER_INTERVAL_MS: u64 = 10;
// One-shot timer state used to arm the timer whenever a bucket runs dry.
const TIMER_REFILL_STATE: TimerState =
    TimerState::Oneshot(Duration::from_millis(REFILL_TIMER_INTERVAL_MS));
// Nanoseconds per millisecond, used to convert refill times for token math.
const NANOSEC_IN_ONE_MILLISEC: u64 = 1_000_000;
/// Computes the greatest common divisor of `x` and `y` using Euclid's
/// algorithm. When one argument is zero the other is returned (so
/// `gcd(0, 0) == 0`).
fn gcd(x: u64, y: u64) -> u64 {
    // Recursive formulation of the Euclidean algorithm; depth is
    // logarithmic in the smaller argument, so no stack-depth concern.
    if y == 0 {
        x
    } else {
        gcd(y, x % y)
    }
}
/// Outcome of a `TokenBucket::reduce()` call.
#[derive(Clone, Debug, PartialEq)]
pub enum BucketReduction {
    /// There are not enough tokens available to satisfy the request.
    Failure,
    /// The tokens were successfully consumed.
    Success,
    /// The request exceeded the bucket's total capacity and was allowed
    /// through anyway; the `f64` is the ratio of the overdrawn tokens to the
    /// bucket size, which callers use to size a back-off timer.
    OverConsumption(f64),
}
/// A token bucket used for rate limiting: holds a budget of tokens that
/// refills at a fixed rate and can be drawn down (or overdrawn) by consumers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TokenBucket {
    // Maximum number of tokens the budget can hold.
    size: u64,
    // The one-time burst the bucket was created with, kept for inspection;
    // `one_time_burst` below is the remaining, mutable counter.
    initial_one_time_burst: u64,
    // Complete refill time in milliseconds, as passed to `new()`.
    refill_time: u64,
    // Remaining one-time burst budget; consumed before the regular budget
    // and never replenished automatically.
    one_time_burst: u64,
    // Tokens currently available for consumption.
    budget: u64,
    // Timestamp of the last budget update; the time elapsed since this
    // instant determines how many tokens `auto_replenish()` adds.
    last_update: Instant,
    // `size` and the refill time (in ns) divided by their GCD; pre-computed
    // so the multiply in `auto_replenish()` is less likely to overflow u64.
    processed_capacity: u64,
    processed_refill_time: u64,
}
impl TokenBucket {
    /// Creates a token bucket of capacity `size` tokens that completely
    /// refills over `complete_refill_time_ms` milliseconds, plus an extra
    /// `one_time_burst` budget that is spent first and never replenished
    /// automatically. The bucket starts off full.
    ///
    /// Debug builds assert that `size` and `complete_refill_time_ms` are
    /// non-zero (both end up as divisors in the refill math).
    pub fn new(size: u64, one_time_burst: u64, complete_refill_time_ms: u64) -> Self {
        debug_assert!(size != 0 && complete_refill_time_ms != 0);
        // NOTE(review): this multiplication can overflow u64 for absurdly
        // large refill times (above ~584 years in ms) — no guard here.
        let complete_refill_time_ns = complete_refill_time_ms * NANOSEC_IN_ONE_MILLISEC;
        // Reduce capacity and refill time by their GCD so the product
        // computed in `auto_replenish()` stays within u64 range longer.
        let common_factor = gcd(size, complete_refill_time_ns);
        let processed_capacity: u64 = size / common_factor;
        let processed_refill_time: u64 = complete_refill_time_ns / common_factor;
        TokenBucket {
            size,
            one_time_burst,
            initial_one_time_burst: one_time_burst,
            refill_time: complete_refill_time_ms,
            // Start with a full budget.
            budget: size,
            last_update: Instant::now(),
            processed_capacity,
            processed_refill_time,
        }
    }

    /// Adds the tokens accrued since `last_update` to the budget, capped at
    /// the bucket capacity, and restarts the measurement window.
    fn auto_replenish(&mut self) {
        // NOTE(review): the integer division below discards the fractional
        // token while `last_update` is still advanced to now, so callers
        // polling faster than one token's worth of time lose accrued time;
        // the effective refill rate can be slower than configured.
        let time_delta = self.last_update.elapsed().as_nanos() as u64;
        self.last_update = Instant::now();
        let tokens = (time_delta * self.processed_capacity) / self.processed_refill_time;
        self.budget = std::cmp::min(self.budget + tokens, self.size);
    }

    /// Attempts to consume `tokens` from the bucket.
    ///
    /// Any remaining one-time burst budget is drawn down first. Returns
    /// `Success` when the request fits in the available budget, `Failure`
    /// when it does not, and `OverConsumption(ratio)` when the request
    /// exceeds the entire bucket capacity — in that case the budget is
    /// zeroed, the operation is allowed through anyway, and `ratio`
    /// (overdraft / capacity) tells the caller how long to block.
    pub fn reduce(&mut self, mut tokens: u64) -> BucketReduction {
        if self.one_time_burst > 0 {
            if self.one_time_burst >= tokens {
                // Entirely covered by the burst budget; also restart the
                // auto-replenish window from now.
                self.one_time_burst -= tokens;
                self.last_update = Instant::now();
                return BucketReduction::Success;
            } else {
                // Drain the burst and charge the remainder against the
                // regular budget below.
                tokens -= self.one_time_burst;
                self.one_time_burst = 0;
            }
        }
        if tokens > self.budget {
            // Top up first: the request may fit after replenishing.
            self.auto_replenish();
            if tokens > self.size {
                error!(
                    "Consumed {} tokens from bucket of size {}",
                    tokens, self.size
                );
                // The request can never fit. Allow it as an overdraft and
                // report how far past capacity it went.
                tokens -= self.budget;
                self.budget = 0;
                return BucketReduction::OverConsumption(tokens as f64 / self.size as f64);
            }
            if tokens > self.budget {
                return BucketReduction::Failure;
            }
        }
        self.budget -= tokens;
        BucketReduction::Success
    }

    /// Returns `tokens` to the bucket (e.g. when a consumer used less than
    /// it reserved). Credits the one-time burst budget first if it is still
    /// being drawn down; otherwise credits the regular budget, capped at
    /// capacity.
    pub fn force_replenish(&mut self, tokens: u64) {
        if self.one_time_burst > 0 {
            // NOTE(review): the burst budget is credited without a cap, so
            // it can grow past `initial_one_time_burst` — confirm intended.
            self.one_time_burst += tokens;
            return;
        }
        self.budget = std::cmp::min(self.budget + tokens, self.size);
    }

    /// Returns the total capacity of the bucket, in tokens.
    pub fn capacity(&self) -> u64 {
        self.size
    }

    /// Returns the remaining one-time burst budget.
    pub fn one_time_burst(&self) -> u64 {
        self.one_time_burst
    }

    /// Returns the complete refill time, in milliseconds.
    pub fn refill_time_ms(&self) -> u64 {
        self.refill_time
    }

    /// Returns the currently available budget (not counting tokens accrued
    /// since the last update).
    pub fn budget(&self) -> u64 {
        self.budget
    }

    /// Returns the one-time burst budget the bucket was created with.
    pub fn initial_one_time_burst(&self) -> u64 {
        self.initial_one_time_burst
    }
}
/// The type of token a `RateLimiter::consume()` call charges against.
///
/// Derives `Debug` (public types should be debuggable) plus `Clone`/`Copy`
/// since the enum is a plain two-variant tag; both additions are
/// backward-compatible. Docs are required by the crate's
/// `#![deny(missing_docs)]`.
#[derive(Clone, Copy, Debug)]
pub enum TokenType {
    /// Token type used for bandwidth (bytes/s) limiting.
    Bytes,
    /// Token type used for operations/s limiting.
    Ops,
}
/// Describes how one of a rate limiter's buckets should be updated when the
/// limiter is reconfigured via `RateLimiter::update_buckets()`.
#[derive(Clone, Debug)]
pub enum BucketUpdate {
    /// Leave the existing bucket unchanged.
    None,
    /// Disable limiting on this dimension (the bucket is removed).
    Disabled,
    /// Replace the existing bucket with this new one.
    Update(TokenBucket),
}
/// Rate limiter combining an optional bandwidth (bytes/s) bucket and an
/// optional ops/s bucket, backed by a timer fd that signals when a blocked
/// limiter should be retried.
pub struct RateLimiter {
    // Bucket limiting bytes throughput; `None` means unlimited.
    bandwidth: Option<TokenBucket>,
    // Bucket limiting operations per second; `None` means unlimited.
    ops: Option<TokenBucket>,
    // One-shot timer; its fd is exposed via `AsRawFd` for event loops.
    timer_fd: TimerFd,
    // True while the timer is armed, i.e. the limiter is blocked.
    timer_active: bool,
}
impl PartialEq for RateLimiter {
    // Equality is defined by the two buckets only; the timer fd and the
    // `timer_active` flag are not part of the comparison.
    fn eq(&self, other: &RateLimiter) -> bool {
        let bandwidth_matches = self.bandwidth == other.bandwidth;
        let ops_matches = self.ops == other.ops;
        bandwidth_matches && ops_matches
    }
}
impl fmt::Debug for RateLimiter {
    // Manual impl because `TimerFd` is intentionally omitted from the
    // output; only the two buckets are shown.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let bandwidth = &self.bandwidth;
        let ops = &self.ops;
        write!(
            formatter,
            "RateLimiter {{ bandwidth: {:?}, ops: {:?} }}",
            bandwidth, ops
        )
    }
}
impl RateLimiter {
    /// Builds a `TokenBucket` from the given configuration, or `None` when
    /// either the capacity or the refill time is zero — i.e. limiting is
    /// disabled for that dimension.
    pub fn make_bucket(
        total_capacity: u64,
        one_time_burst: u64,
        complete_refill_time_ms: u64,
    ) -> Option<TokenBucket> {
        // If either token bucket capacity or refill time is 0, disable limiting.
        if total_capacity != 0 && complete_refill_time_ms != 0 {
            Some(TokenBucket::new(
                total_capacity,
                one_time_burst,
                complete_refill_time_ms,
            ))
        } else {
            None
        }
    }

    /// Creates a new rate limiter with independent bytes/s and ops/s
    /// buckets. A zero capacity or refill time disables the corresponding
    /// bucket (see `make_bucket`).
    ///
    /// # Errors
    ///
    /// Fails if the underlying timer fd cannot be created.
    pub fn new(
        bytes_total_capacity: u64,
        bytes_one_time_burst: u64,
        bytes_complete_refill_time_ms: u64,
        ops_total_capacity: u64,
        ops_one_time_burst: u64,
        ops_complete_refill_time_ms: u64,
    ) -> io::Result<Self> {
        let bytes_token_bucket = Self::make_bucket(
            bytes_total_capacity,
            bytes_one_time_burst,
            bytes_complete_refill_time_ms,
        );
        let ops_token_bucket = Self::make_bucket(
            ops_total_capacity,
            ops_one_time_burst,
            ops_complete_refill_time_ms,
        );
        // Monotonic-clock timer; the two bool flags presumably select
        // non-blocking and close-on-exec — confirm against the `timerfd`
        // crate's `TimerFd::new_custom` documentation.
        let timer_fd = TimerFd::new_custom(ClockId::Monotonic, true, true)?;
        Ok(RateLimiter {
            bandwidth: bytes_token_bucket,
            ops: ops_token_bucket,
            timer_fd,
            timer_active: false,
        })
    }

    // Arms the internal timer with the given one-shot state and marks the
    // limiter as blocked.
    fn activate_timer(&mut self, timer_state: TimerState) {
        self.timer_fd.set_state(timer_state, SetTimeFlags::Default);
        self.timer_active = true;
    }

    /// Attempts to consume `tokens` of the given type.
    ///
    /// Returns `true` when the tokens were consumed (or when the matching
    /// bucket is disabled), `false` when the limiter is — or just became —
    /// blocked. On failure the internal timer is armed; the caller should
    /// retry after `event_handler()` reports the timer fired. A request
    /// larger than the bucket capacity is allowed through (`true`) but
    /// blocks the limiter for a duration proportional to the overdraft.
    pub fn consume(&mut self, tokens: u64, token_type: TokenType) -> bool {
        // While the timer is armed the limiter is blocked: reject outright.
        if self.timer_active {
            return false;
        }
        let token_bucket = match token_type {
            TokenType::Bytes => self.bandwidth.as_mut(),
            TokenType::Ops => self.ops.as_mut(),
        };
        if let Some(bucket) = token_bucket {
            let refill_time = bucket.refill_time_ms();
            match bucket.reduce(tokens) {
                BucketReduction::Failure => {
                    // NOTE(review): `timer_active` is always false here
                    // (early return above), so this guard is redundant.
                    if !self.timer_active {
                        self.activate_timer(TIMER_REFILL_STATE);
                    }
                    false
                }
                BucketReduction::Success => true,
                // Over-capacity request: let it through but block for
                // `ratio` complete refill periods.
                BucketReduction::OverConsumption(ratio) => {
                    self.activate_timer(TimerState::Oneshot(Duration::from_millis(
                        (ratio * refill_time as f64) as u64,
                    )));
                    true
                }
            }
        } else {
            // No bucket configured for this token type: unlimited.
            true
        }
    }

    /// Manually returns `tokens` of the given type to the corresponding
    /// bucket (no-op if that bucket is disabled). See
    /// `TokenBucket::force_replenish` for the crediting rules.
    pub fn manual_replenish(&mut self, tokens: u64, token_type: TokenType) {
        let token_bucket = match token_type {
            TokenType::Bytes => self.bandwidth.as_mut(),
            TokenType::Ops => self.ops.as_mut(),
        };
        if let Some(bucket) = token_bucket {
            bucket.force_replenish(tokens);
        }
    }

    /// Returns whether the limiter is currently blocked (its retry timer is
    /// armed and has not yet been acknowledged via `event_handler()`).
    pub fn is_blocked(&self) -> bool {
        self.timer_active
    }

    /// Handles a timer-fd event: clears the blocked state so `consume()`
    /// calls can proceed again.
    ///
    /// # Errors
    ///
    /// Returns `Error::SpuriousRateLimiterEvent` when the timer fd reads
    /// zero expirations, i.e. the handler was invoked without the timer
    /// actually having fired.
    pub fn event_handler(&mut self) -> Result<(), Error> {
        match self.timer_fd.read() {
            0 => Err(Error::SpuriousRateLimiterEvent(
                "Rate limiter event handler called without a present timer",
            )),
            _ => {
                self.timer_active = false;
                Ok(())
            }
        }
    }

    /// Applies the given updates to the bandwidth and ops buckets; `None`
    /// leaves a bucket unchanged, `Disabled` removes it, `Update` replaces
    /// it.
    pub fn update_buckets(&mut self, bytes: BucketUpdate, ops: BucketUpdate) {
        match bytes {
            BucketUpdate::Disabled => self.bandwidth = None,
            BucketUpdate::Update(tb) => self.bandwidth = Some(tb),
            BucketUpdate::None => (),
        };
        match ops {
            BucketUpdate::Disabled => self.ops = None,
            BucketUpdate::Update(tb) => self.ops = Some(tb),
            BucketUpdate::None => (),
        };
    }

    /// Returns an immutable view of the bandwidth (bytes/s) bucket, if any.
    pub fn bandwidth(&self) -> Option<&TokenBucket> {
        self.bandwidth.as_ref()
    }

    /// Returns an immutable view of the ops/s bucket, if any.
    pub fn ops(&self) -> Option<&TokenBucket> {
        self.ops.as_ref()
    }
}
impl AsRawFd for RateLimiter {
    /// Exposes the raw fd of the internal timer so callers can register it
    /// with an event loop; it becomes readable when a blocked limiter is
    /// ready to be unblocked via `event_handler()`.
    fn as_raw_fd(&self) -> RawFd {
        self.timer_fd.as_raw_fd()
    }
}
impl Default for RateLimiter {
fn default() -> Self {
RateLimiter::new(0, 0, 0, 0, 0, 0).expect("Failed to build default RateLimiter")
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    // Interval the tests wait out before expecting the limiter to unblock;
    // deliberately larger than the limiter's 10ms refill interval to absorb
    // scheduling jitter.
    const TEST_REFILL_TIMER_INTERVAL_MS: u64 = 100;

    // Test-only helpers for inspecting and resetting a bucket's private state.
    impl TokenBucket {
        // Refills the budget to capacity and restarts the refill clock.
        fn reset(&mut self) {
            self.budget = self.size;
            self.last_update = Instant::now();
        }
        fn get_last_update(&self) -> &Instant {
            &self.last_update
        }
        fn get_processed_capacity(&self) -> u64 {
            self.processed_capacity
        }
        fn get_processed_refill_time(&self) -> u64 {
            self.processed_refill_time
        }
        // Equality over the observable fields only (ignores `last_update`
        // and the processed_* values).
        pub fn partial_eq(&self, other: &TokenBucket) -> bool {
            (other.capacity() == self.capacity())
                && (other.one_time_burst() == self.one_time_burst())
                && (other.refill_time_ms() == self.refill_time_ms())
                && (other.budget() == self.budget())
        }
    }

    impl RateLimiter {
        // Read-only access to a bucket by token type.
        fn get_token_bucket(&self, token_type: TokenType) -> Option<&TokenBucket> {
            match token_type {
                TokenType::Bytes => self.bandwidth.as_ref(),
                TokenType::Ops => self.ops.as_ref(),
            }
        }
    }

    #[test]
    fn test_token_bucket_create() {
        let before = Instant::now();
        let tb = TokenBucket::new(1000, 0, 1000);
        assert_eq!(tb.capacity(), 1000);
        // A fresh bucket starts with a full budget.
        assert_eq!(tb.budget(), 1000);
        assert_eq!(tb.initial_one_time_burst(), 0);
        // `last_update` must be captured between `before` and `after`.
        assert!(*tb.get_last_update() >= before);
        let after = Instant::now();
        assert!(*tb.get_last_update() <= after);
        // 1000 tokens / 1000ms reduce (by their GCD) to 1 token per 1_000_000 ns.
        assert_eq!(tb.get_processed_capacity(), 1);
        assert_eq!(tb.get_processed_refill_time(), 1_000_000);
    }

    #[test]
    fn test_token_bucket_preprocess() {
        // Trivial case: GCD removes everything but 1 token / 1ms-in-ns.
        let tb = TokenBucket::new(1000, 0, 1000);
        assert_eq!(tb.get_processed_capacity(), 1);
        assert_eq!(tb.get_processed_refill_time(), NANOSEC_IN_ONE_MILLISEC);

        // Composite case: the common factor 7 * 11 * 1000 cancels, leaving
        // 3 * 19 tokens per 13 * 17 * (ns-per-ms / 1000) nanoseconds.
        let thousand = 1000;
        let tb = TokenBucket::new(3 * 7 * 11 * 19 * thousand, 0, 7 * 11 * 13 * 17);
        assert_eq!(tb.get_processed_capacity(), 3 * 19);
        assert_eq!(
            tb.get_processed_refill_time(),
            13 * 17 * (NANOSEC_IN_ONE_MILLISEC / thousand)
        );
    }

    #[test]
    fn test_token_bucket_reduce() {
        // Bucket of 1000 tokens, completely refilling over 1000ms.
        let capacity = 1000;
        let refill_ms = 1000;
        let mut tb = TokenBucket::new(capacity, 0, refill_ms as u64);

        assert_eq!(tb.reduce(123), BucketReduction::Success);
        assert_eq!(tb.budget(), capacity - 123);

        // Not enough budget left for a full-capacity request.
        assert_eq!(tb.reduce(capacity), BucketReduction::Failure);

        // After ~80ms roughly 80 tokens accrue; small requests succeed but a
        // full-capacity request still fails.
        thread::sleep(Duration::from_millis(80));
        assert_eq!(tb.reduce(1), BucketReduction::Success);
        assert_eq!(tb.reduce(100), BucketReduction::Success);
        assert_eq!(tb.reduce(capacity), BucketReduction::Failure);

        // Bucket with a one-time burst: the burst budget drains first.
        let mut tb = TokenBucket::new(1000, 1100, 1000);
        assert_eq!(tb.reduce(1000), BucketReduction::Success);
        assert_eq!(tb.one_time_burst(), 100);
        assert_eq!(tb.reduce(500), BucketReduction::Success);
        assert_eq!(tb.one_time_burst(), 0);
        assert_eq!(tb.reduce(500), BucketReduction::Success);
        assert_eq!(tb.reduce(500), BucketReduction::Failure);
        thread::sleep(Duration::from_millis(500));
        assert_eq!(tb.reduce(500), BucketReduction::Success);
        thread::sleep(Duration::from_millis(1000));
        // 2500 from a full 1000-token bucket: overdraft 1500 -> ratio 1.5.
        assert_eq!(tb.reduce(2500), BucketReduction::OverConsumption(1.5));

        // `reset` restores a full budget and a fresh `last_update`.
        let before = Instant::now();
        tb.reset();
        assert_eq!(tb.capacity(), 1000);
        assert_eq!(tb.budget(), 1000);
        assert!(*tb.get_last_update() >= before);
        let after = Instant::now();
        assert!(*tb.get_last_update() <= after);
    }

    #[test]
    fn test_rate_limiter_default() {
        // The default limiter has both buckets disabled: nothing blocks.
        let mut l = RateLimiter::default();
        assert!(!l.is_blocked());
        assert!(l.consume(u64::max_value(), TokenType::Ops));
        assert!(l.consume(u64::max_value(), TokenType::Bytes));
        // No timer was ever armed, so the event handler reports a spurious
        // event.
        assert!(l.event_handler().is_err());
        assert_eq!(
            format!("{:?}", l.event_handler().err().unwrap()),
            "SpuriousRateLimiterEvent(\
             \"Rate limiter event handler called without a present timer\")"
        );
    }

    #[test]
    fn test_rate_limiter_new() {
        // Both buckets configured; verify every field lands where expected.
        let l = RateLimiter::new(1000, 1001, 1002, 1003, 1004, 1005).unwrap();

        let bw = l.bandwidth.unwrap();
        assert_eq!(bw.capacity(), 1000);
        assert_eq!(bw.one_time_burst(), 1001);
        assert_eq!(bw.initial_one_time_burst(), 1001);
        assert_eq!(bw.refill_time_ms(), 1002);
        assert_eq!(bw.budget(), 1000);

        let ops = l.ops.unwrap();
        assert_eq!(ops.capacity(), 1003);
        assert_eq!(ops.one_time_burst(), 1004);
        assert_eq!(ops.initial_one_time_burst(), 1004);
        assert_eq!(ops.refill_time_ms(), 1005);
        assert_eq!(ops.budget(), 1003);
    }

    #[test]
    fn test_rate_limiter_manual_replenish() {
        let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();

        // Consume 123, put back 23: budget should read 900 on each bucket.
        assert!(l.consume(123, TokenType::Bytes));
        l.manual_replenish(23, TokenType::Bytes);
        {
            let bytes_tb = l.get_token_bucket(TokenType::Bytes).unwrap();
            assert_eq!(bytes_tb.budget(), 900);
        }
        assert!(l.consume(123, TokenType::Ops));
        l.manual_replenish(23, TokenType::Ops);
        {
            let bytes_tb = l.get_token_bucket(TokenType::Ops).unwrap();
            assert_eq!(bytes_tb.budget(), 900);
        }
    }

    #[test]
    fn test_rate_limiter_bandwidth() {
        // Limiter with a bandwidth bucket only.
        let mut l = RateLimiter::new(1000, 0, 1000, 0, 0, 0).unwrap();

        assert!(!l.is_blocked());
        assert!(l.as_raw_fd() > 0);

        // Ops are unlimited; draining the byte budget blocks the limiter.
        assert!(l.consume(u64::max_value(), TokenType::Ops));
        assert!(l.consume(1000, TokenType::Bytes));
        assert!(!l.consume(100, TokenType::Bytes));
        assert!(l.is_blocked());

        // Still blocked halfway through the test interval; unblocked (after
        // acknowledging the timer) once it has fully elapsed.
        thread::sleep(Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2));
        assert!(l.is_blocked());
        thread::sleep(Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2));
        assert!(l.event_handler().is_ok());
        assert!(!l.is_blocked());
        assert!(l.consume(100, TokenType::Bytes));
    }

    #[test]
    fn test_rate_limiter_ops() {
        // Mirror of the bandwidth test, with an ops bucket only.
        let mut l = RateLimiter::new(0, 0, 0, 1000, 0, 1000).unwrap();

        assert!(!l.is_blocked());
        assert!(l.as_raw_fd() > 0);

        assert!(l.consume(u64::max_value(), TokenType::Bytes));
        assert!(l.consume(1000, TokenType::Ops));
        assert!(!l.consume(100, TokenType::Ops));
        assert!(l.is_blocked());

        thread::sleep(Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2));
        assert!(l.is_blocked());
        thread::sleep(Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2));
        assert!(l.event_handler().is_ok());
        assert!(!l.is_blocked());
        assert!(l.consume(100, TokenType::Ops));
    }

    #[test]
    fn test_rate_limiter_full() {
        // Both buckets active: draining either one blocks the whole limiter.
        let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();

        assert!(!l.is_blocked());
        assert!(l.as_raw_fd() > 0);

        assert!(l.consume(1000, TokenType::Ops));
        assert!(l.consume(1000, TokenType::Bytes));
        assert!(!l.consume(100, TokenType::Ops));
        assert!(!l.consume(100, TokenType::Bytes));
        assert!(l.is_blocked());

        thread::sleep(Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2));
        assert!(l.is_blocked());
        thread::sleep(Duration::from_millis(TEST_REFILL_TIMER_INTERVAL_MS / 2));
        assert!(l.event_handler().is_ok());
        assert!(!l.is_blocked());
        assert!(l.consume(100, TokenType::Ops));
        assert!(l.consume(100, TokenType::Bytes));
    }

    #[test]
    fn test_rate_limiter_overconsumption() {
        // Consuming 2500 from a 1000-token bucket blocks for 1.5 refill
        // periods = 1500ms; the event handler errs until the timer fires.
        let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
        assert!(l.consume(2500, TokenType::Bytes));
        thread::sleep(Duration::from_millis(1000));
        assert!(l.event_handler().is_err());
        assert!(l.is_blocked());
        thread::sleep(Duration::from_millis(500));
        assert!(l.event_handler().is_ok());
        assert!(!l.is_blocked());

        // 1500 from a 1000-token bucket: blocked for 0.5 periods = 500ms;
        // consume attempts while blocked fail and do not re-arm the timer.
        let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap();
        assert!(l.consume(1500, TokenType::Bytes));
        thread::sleep(Duration::from_millis(200));
        assert!(l.event_handler().is_err());
        assert!(l.is_blocked());
        assert!(!l.consume(100, TokenType::Bytes));
        assert!(l.event_handler().is_err());
        assert!(l.is_blocked());
        thread::sleep(Duration::from_millis(90));
        assert!(l.event_handler().is_err());
        assert!(l.is_blocked());
        assert!(!l.consume(100, TokenType::Bytes));
        thread::sleep(Duration::from_millis(210));
        assert!(l.event_handler().is_ok());
        assert!(!l.is_blocked());
        assert!(l.consume(100, TokenType::Bytes));
    }

    #[test]
    fn test_update_buckets() {
        let mut x = RateLimiter::new(1000, 2000, 1000, 10, 20, 1000).unwrap();

        let initial_bw = x.bandwidth.clone();
        let initial_ops = x.ops.clone();

        // `None` updates leave both buckets untouched.
        x.update_buckets(BucketUpdate::None, BucketUpdate::None);
        assert_eq!(x.bandwidth, initial_bw);
        assert_eq!(x.ops, initial_ops);

        let new_bw = RateLimiter::make_bucket(123, 0, 57).unwrap();
        let new_ops = RateLimiter::make_bucket(321, 12346, 89).unwrap();

        x.update_buckets(
            BucketUpdate::Update(new_bw.clone()),
            BucketUpdate::Update(new_ops.clone()),
        );

        // Align `last_update` so the `Eq`-derived comparison can succeed.
        x.bandwidth.as_mut().unwrap().last_update = new_bw.last_update;
        x.ops.as_mut().unwrap().last_update = new_ops.last_update;

        assert_eq!(x.bandwidth, Some(new_bw));
        assert_eq!(x.ops, Some(new_ops));

        // `Disabled` removes the buckets entirely.
        x.update_buckets(BucketUpdate::Disabled, BucketUpdate::Disabled);
        assert_eq!(x.bandwidth, None);
        assert_eq!(x.ops, None);
    }

    #[test]
    fn test_rate_limiter_debug() {
        let l = RateLimiter::new(1, 2, 3, 4, 5, 6).unwrap();
        // Debug output must list exactly the two buckets, in this shape.
        assert_eq!(
            format!("{:?}", l),
            format!(
                "RateLimiter {{ bandwidth: {:?}, ops: {:?} }}",
                l.bandwidth(),
                l.ops()
            ),
        );
    }
}