use std::time::Duration;
use tokio::time::Instant;
/// Predefined rate-limit tiers, each mapping to a ticket budget per [`RateClass::period`].
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RateClass {
    /// Default tier: 20 tickets per period.
    Regular,
    /// 100 tickets per period.
    Moderator,
    /// 50 tickets per period.
    Known,
    /// 7500 tickets per period.
    Verified,
}

impl Default for RateClass {
    /// Returns [`RateClass::Regular`].
    fn default() -> Self {
        Self::Regular
    }
}

impl RateClass {
    /// Number of tickets this class may spend per [`RateClass::period`].
    pub fn tickets(self) -> u64 {
        match self {
            Self::Regular => 20,
            Self::Moderator => 100,
            Self::Known => 50,
            Self::Verified => 7500,
        }
    }

    /// Length of the window that [`RateClass::tickets`] applies to (30 seconds for every class).
    pub const fn period() -> Duration {
        Duration::from_secs(30)
    }
}
#[derive(Debug, Clone)]
pub struct RateLimit {
cap: u64,
bucket: Bucket,
}
impl Default for RateLimit {
fn default() -> Self {
Self::from_class(<_>::default())
}
}
impl RateLimit {
    /// Creates a full limiter using the ticket budget and period of `rate_class`.
    pub fn from_class(rate_class: RateClass) -> Self {
        Self::full(rate_class.tickets(), RateClass::period())
    }

    /// Creates a limiter that holds at most `cap` tokens and is credited `cap`
    /// tokens per elapsed `period`, starting with `initial` tokens.
    ///
    /// `initial` is not clamped to `cap`; the excess disappears at the first refill.
    pub fn new(cap: u64, initial: u64, period: Duration) -> Self {
        Self {
            cap,
            bucket: Bucket::new(cap, initial, period),
        }
    }

    /// Creates a limiter that starts with all `cap` tokens available.
    pub fn full(cap: u64, period: Duration) -> Self {
        Self {
            cap,
            bucket: Bucket::new(cap, cap, period),
        }
    }

    /// Creates a limiter that starts with no tokens; the first refill happens
    /// one `period` after construction.
    pub fn empty(cap: u64, period: Duration) -> Self {
        Self {
            cap,
            bucket: Bucket::new(cap, 0, period),
        }
    }

    /// Tries to take `tokens` tokens without waiting and returns how many remain.
    ///
    /// Tokens for every whole period elapsed since the last refill are credited
    /// first, clamped to the cap.
    ///
    /// # Errors
    ///
    /// Returns `Err(wait)` with an estimate of how long until enough tokens will
    /// have accrued. Nothing is taken in that case.
    pub fn consume(&mut self, tokens: u64) -> Result<u64, Duration> {
        let now = Instant::now();
        let cap = self.cap;
        let bucket = &mut self.bucket;
        if let Some(credit) = bucket.refill(now) {
            // Saturate: with a cap near u64::MAX, a plain `+` would overflow
            // (panic in debug, wrap to a bogus count in release).
            bucket.tokens = bucket.tokens.saturating_add(credit).min(cap);
        }
        if tokens <= bucket.tokens {
            bucket.tokens -= tokens;
            bucket.backoff = 0;
            return Ok(bucket.tokens);
        }
        let missing = tokens - bucket.tokens;
        Err(bucket.estimate(missing, now))
    }

    /// Takes `tokens` tokens, sleeping until enough have accrued, and returns
    /// how many remain afterwards.
    ///
    /// Calls [`RateLimit::consume`] repeatedly and sleeps for the estimated wait
    /// between attempts.
    ///
    /// NOTE(review): a request larger than the cap can never succeed once the
    /// count has been clamped to the cap. Such a call keeps sleeping forever.
    pub async fn throttle(&mut self, tokens: u64) -> u64 {
        loop {
            match self.consume(tokens) {
                Ok(rem) => return rem,
                Err(time) => {
                    log::debug!("blocking for: {:.3?}", time);
                    tokio::time::delay_for(time).await
                }
            }
        }
    }

    /// Takes a single token, waiting if necessary. See [`RateLimit::throttle`].
    #[inline]
    pub async fn take(&mut self) -> u64 {
        self.throttle(1).await
    }
}
/// Token-bucket state behind [`RateLimit`].
///
/// Tokens are credited in whole periods: each full `period` that has elapsed
/// since `last` is worth `quantum` tokens. Clamping to a cap is left to the caller.
#[derive(Debug, Clone)]
struct Bucket {
    /// Tokens currently available.
    tokens: u64,
    /// Only ever reset to 0 (by `RateLimit::consume`); nothing in this file reads it.
    /// NOTE(review): possibly a leftover from a planned backoff scheme.
    backoff: u32,
    /// Earliest instant at which a refill can happen; always `last + period`.
    next: Instant,
    /// Start of the current, partially elapsed period.
    last: Instant,
    /// Tokens credited per elapsed period.
    quantum: u64,
    /// Length of one refill period.
    period: Duration,
}

impl Bucket {
    /// Wait reported when refills can never produce tokens (`quantum == 0`)
    /// or when the real wait does not fit in a `Duration`.
    /// NOTE(review): sleeping for this long may overflow a timer's deadline
    /// arithmetic. Consider surfacing this case to callers explicitly.
    const UNBOUNDED: Duration = Duration::from_secs(u64::MAX);

    /// Creates a bucket credited `tokens` per `period` that starts with `initial`
    /// tokens. The first refill becomes possible one `period` after creation.
    fn new(tokens: u64, initial: u64, period: Duration) -> Self {
        let now = Instant::now();
        Self {
            tokens: initial,
            backoff: 0,
            next: now + period,
            last: now,
            quantum: tokens,
            period,
        }
    }

    /// Advances the period bookkeeping to `now` and returns how many tokens the
    /// elapsed whole periods are worth.
    ///
    /// Returns `None` if no full period has elapsed since the last refill, or if
    /// `period` is zero. The amount saturates at `u64::MAX` instead of overflowing.
    fn refill(&mut self, now: Instant) -> Option<u64> {
        if now < self.next {
            return None;
        }
        let elapsed = now.duration_since(self.last);
        let elapsed_nanos = elapsed.as_nanos();
        let period_nanos = self.period.as_nanos();
        let periods = elapsed_nanos.checked_div(period_nanos)?;
        // Advance `last` by exactly the whole periods consumed and keep the
        // partial remainder for the next refill. Subtracting the remainder
        // avoids `period * (periods as u32)`, which truncates (and so credits
        // the same periods twice) once the count exceeds u32::MAX.
        let remainder = Self::duration_from_nanos(elapsed_nanos % period_nanos)
            .expect("remainder is shorter than `elapsed`, which is a Duration");
        self.last += elapsed - remainder;
        self.next = self.last + self.period;
        let periods = periods.min(u128::from(u64::MAX)) as u64;
        Some(periods.saturating_mul(self.quantum))
    }

    /// Estimates how long to wait until `tokens` more tokens have been credited.
    ///
    /// The estimate is the time until the next refill plus one `period` for each
    /// further refill needed. `now` is expected to precede `next`, which holds
    /// right after `refill(now)`. Returns [`Self::UNBOUNDED`] when refills add
    /// nothing or the wait would overflow.
    fn estimate(&mut self, tokens: u64, now: Instant) -> Duration {
        if self.quantum == 0 {
            return Self::UNBOUNDED;
        }
        // Guard so a zero `period` (which leaves `next <= now`) cannot ask
        // `duration_since` to go backwards.
        let until = if self.next > now {
            self.next.duration_since(now)
        } else {
            Duration::from_secs(0)
        };
        // Ceiling division that cannot overflow for huge `tokens`.
        let periods = tokens / self.quantum + u64::from(tokens % self.quantum != 0);
        // The first of those refills is already covered by `until`.
        let extra = periods.saturating_sub(1);
        self.period
            .as_nanos()
            .checked_mul(u128::from(extra))
            .and_then(Self::duration_from_nanos)
            .and_then(|rest| rest.checked_add(until))
            .unwrap_or(Self::UNBOUNDED)
    }

    /// Converts a nanosecond count to a `Duration`, or `None` if its whole
    /// seconds do not fit in a `u64`.
    fn duration_from_nanos(nanos: u128) -> Option<Duration> {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let secs = nanos / NANOS_PER_SEC;
        if secs > u128::from(u64::MAX) {
            return None;
        }
        Some(Duration::new(secs as u64, (nanos % NANOS_PER_SEC) as u32))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::prelude::*;

    /// Fixture shared by every test: 10 tokens per 30 s, starting full.
    fn ten_per_thirty() -> RateLimit {
        RateLimit::full(10, Duration::from_secs(30))
    }

    #[test]
    fn consume() {
        let mut limiter = ten_per_thirty();
        for &(amount, left) in &[(1, 9), (3, 6), (6, 0)] {
            assert_eq!(limiter.consume(amount).unwrap(), left);
        }
        // Empty bucket: the error carries a wait no longer than one period.
        let wait = limiter.consume(1).unwrap_err();
        assert!(wait <= Duration::from_secs(30));
    }

    #[tokio::test]
    async fn throttle() {
        tokio::time::pause();
        let mut limiter = ten_per_thirty();
        for &(amount, left) in &[(3, 7), (3, 4), (3, 1)] {
            assert_eq!(limiter.throttle(amount).now_or_never(), Some(left));
        }
        // Only 1 token left, so asking for 3 must block...
        assert!(limiter.throttle(3).now_or_never().is_none());
        // ...until a full period has passed and the bucket is refilled.
        tokio::time::advance(Duration::from_secs(31)).await;
        assert_eq!(limiter.throttle(3).now_or_never(), Some(7));
    }

    #[tokio::test]
    async fn take() {
        tokio::time::pause();
        let mut limiter = ten_per_thirty();
        for left in (0..=9u64).rev() {
            assert_eq!(limiter.take().now_or_never(), Some(left));
        }
        assert!(limiter.take().now_or_never().is_none());
        tokio::time::advance(Duration::from_secs(31)).await;
        assert_eq!(limiter.take().now_or_never(), Some(9));
    }
}