use pdk_core::policy_context::api::Tier;
use serde::{Deserialize, Serialize};
use super::{LimitStats, RequestAllowed};
/// Locally tracked rate-limit state composed of one [`LocalLimit`] per configured tier.
///
/// A request is admitted only when every contained limit still has enough quota
/// for it; a bucket with no limits rejects every request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LocalBucket {
/// One quota window per tier, in the same order as the tiers the bucket was built from.
limits: Vec<LocalLimit>,
}
impl LocalBucket {
    /// Builds a bucket holding one [`LocalLimit`] per entry of `tiers`.
    ///
    /// Every limit starts with its full quota and a window that opens at
    /// `start_time`.
    pub fn new(start_time: u128, tiers: &[Tier]) -> Self {
        let mut limits = Vec::with_capacity(tiers.len());
        for tier in tiers {
            limits.push(LocalLimit::from(start_time, tier));
        }
        Self { limits }
    }

    /// Decides whether a request costing `amount` units may proceed at time `now`.
    ///
    /// All limits are first rolled forward to `now`. The request is admitted
    /// only if the bucket has at least one limit and every limit can cover
    /// `amount`; in that case `amount` is deducted from each of them. Otherwise
    /// nothing is deducted and [`RequestAllowed::OutOfQuota`] is returned.
    pub fn request_allowed(&mut self, now: u128, amount: usize) -> RequestAllowed {
        self.set_time(now);

        let out_of_quota = self.limits.is_empty()
            || self.limits.iter().any(|limit| !limit.remaining(amount));
        if out_of_quota {
            return RequestAllowed::OutOfQuota;
        }

        // Deduct only after all limits agreed, so a rejection never consumes quota.
        for limit in &mut self.limits {
            limit.allow_request(amount);
        }
        RequestAllowed::Allowed
    }

    /// Combines the statistics of all limits into one value using
    /// `LimitStats::most_restrictive`.
    ///
    /// Returns `None` when the bucket has no limits.
    pub fn status(&self) -> Option<LimitStats> {
        let mut per_limit = self.limits.iter().map(|limit| limit.stats());
        let first = per_limit.next()?;
        Some(per_limit.fold(first, LimitStats::most_restrictive))
    }

    /// Rolls every limit's window forward to `now`.
    fn set_time(&mut self, now: u128) {
        for limit in &mut self.limits {
            limit.set_time(now);
        }
    }
}
/// Quota state for a single tier: a fixed-length window that refills to
/// `max_requests` each time the `reset` timestamp is reached.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LocalLimit {
/// Quota still available in the current window.
available: u64,
/// Quota restored whenever a new window starts (the tier's `requests`).
max_requests: u64,
/// Window length, taken from the tier's `period_in_millis`.
period: u128,
/// Timestamp at which the current window ends and the quota is refilled.
/// NOTE(review): presumably milliseconds on the same clock as the `now`
/// values passed in by callers — confirm against the caller.
reset: u128,
}
impl LocalLimit {
    /// Creates a limit for `tier` with its full quota available and the first
    /// reset scheduled one period after `start_time`.
    fn from(start_time: u128, tier: &Tier) -> Self {
        let period = tier.period_in_millis as u128;
        Self {
            available: tier.requests,
            max_requests: tier.requests,
            period,
            reset: start_time + period,
        }
    }

    /// Snapshot of this limit: remaining quota, maximum quota and next reset time.
    fn stats(&self) -> LimitStats {
        LimitStats::new(self.available, self.max_requests, self.reset)
    }

    /// Rolls the window forward so that it contains `now`.
    ///
    /// If `now` has reached or passed `reset`, the quota is refilled and
    /// `reset` is moved to the first window boundary strictly after `now`.
    /// Calls with `now` earlier than `reset` leave the limit untouched.
    fn set_time(&mut self, now: u128) {
        if now < self.reset {
            return;
        }

        self.available = self.max_requests;

        if self.period == 0 {
            // A zero-length window can never be advanced past `now`; stepping
            // by it would loop forever. Treat the quota as refreshed right now.
            self.reset = now;
            return;
        }

        // Jump straight to the first boundary after `now` instead of stepping
        // one period at a time, which is linear in the length of the idle gap.
        let elapsed_periods = (now - self.reset) / self.period + 1;
        self.reset += elapsed_periods * self.period;
    }

    /// Returns `true` when at least `amount` units of quota are still available.
    fn remaining(&self, amount: usize) -> bool {
        self.available >= amount as u64
    }

    /// Deducts `amount` units from the available quota.
    ///
    /// Callers are expected to have checked [`Self::remaining`] first. If they
    /// did not, debug builds panic and release builds clamp the quota at zero
    /// instead of wrapping around to a huge value.
    fn allow_request(&mut self, amount: usize) {
        debug_assert!(
            self.remaining(amount),
            "allow_request called without enough remaining quota"
        );
        self.available = self.available.saturating_sub(amount as u64);
    }
}