use std::{cmp, fmt::Display, time::Duration};
use super::{StateStore, clock, nanos::Nanos, quota::Quota};
/// A point-in-time view of the GCRA state for a single rate-limited key.
///
/// Bundles the GCRA parameters (`t`, `tau`) with the instant the state was
/// measured and the theoretical arrival time (TAT) observed at that instant.
/// Time values are relative offsets; `NotUntil` adds `tat` to its `start`
/// reference to obtain an absolute instant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StateSnapshot {
    /// Per-cell replenish interval (GCRA `t`).
    t: Nanos,
    /// Burst tolerance (GCRA `tau`).
    tau: Nanos,
    /// Relative time at which this state was measured.
    pub(crate) time_of_measurement: Nanos,
    /// Theoretical arrival time recorded for the key.
    pub(crate) tat: Nanos,
}

impl StateSnapshot {
    /// Creates a snapshot from raw GCRA parameters and measured values.
    #[inline]
    pub(crate) const fn new(t: Nanos, tau: Nanos, time_of_measurement: Nanos, tat: Nanos) -> Self {
        Self {
            t,
            tau,
            time_of_measurement,
            tat,
        }
    }

    /// Builds a [`Quota`] from this snapshot's GCRA parameters via
    /// `Quota::from_gcra_parameters`.
    pub fn quota(&self) -> Quota {
        let (interval, tolerance) = (self.t, self.tau);
        Quota::from_gcra_parameters(interval, tolerance)
    }

    /// Returns how many more cells fit into the burst, as of one interval after
    /// the measurement time.
    ///
    /// The headroom between `time_of_measurement + t + tau` and the recorded TAT
    /// is capped at `tau` and then expressed in whole cells of size `t`.
    ///
    /// # Panics
    ///
    /// Panics (integer division by zero) if the per-cell interval `t` is zero.
    // This public accessor may have no callers inside the crate; the lint is
    // silenced rather than dropping the API.
    #[allow(dead_code)]
    pub fn remaining_burst_capacity(&self) -> u32 {
        let horizon = self.time_of_measurement + self.t + self.tau;
        let headroom = horizon.saturating_sub(self.tat).as_u64();
        let capped = headroom.min(self.tau.as_u64());
        (capped / self.t.as_u64()) as u32
    }
}
#[derive(Debug, Eq, PartialEq)]
pub struct NotUntil<P: clock::Reference> {
state: StateSnapshot,
start: P,
}
impl<P: clock::Reference> NotUntil<P> {
#[inline]
pub(crate) const fn new(state: StateSnapshot, start: P) -> Self {
Self { state, start }
}
#[inline]
pub fn earliest_possible(&self) -> P {
let tat: Nanos = self.state.tat;
self.start + tat
}
#[inline]
pub fn wait_time_from(&self, from: P) -> Duration {
let earliest = self.earliest_possible();
earliest.duration_since(earliest.min(from)).into()
}
#[inline]
pub fn quota(&self) -> Quota {
self.state.quota()
}
}
impl<P: clock::Reference> Display for NotUntil<P> {
    /// Formats the rejection as `rate-limited until <instant>`. The instant is
    /// the `Debug` rendering of the earliest possible admission time.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let until = self.earliest_possible();
        write!(f, "rate-limited until {until:?}")
    }
}
/// Parameters of the Generic Cell Rate Algorithm (GCRA) for one [`Quota`].
///
/// Each key stores a single theoretical arrival time (TAT). A request is
/// admitted while, after accounting for its weight, that TAT stays within the
/// burst tolerance `tau` of the current time.
#[derive(Debug, Eq, PartialEq)]
pub struct Gcra {
    /// Per-cell replenish interval (`Quota::replenish_1_per`).
    t: Nanos,
    /// Burst tolerance: `t` multiplied by `Quota::max_burst`.
    tau: Nanos,
}

impl Gcra {
    /// Derives the GCRA parameters `t` and `tau` from a [`Quota`].
    pub(crate) fn new(quota: Quota) -> Self {
        let tau: Nanos = (quota.replenish_1_per * quota.max_burst.get()).into();
        let t: Nanos = quota.replenish_1_per.into();
        Self { t, tau }
    }

    /// Returns the TAT assumed for a key that has no stored state yet.
    fn starting_state(&self, t0: Nanos) -> Nanos {
        t0 + self.t
    }

    /// Tests a single cell against the state stored for `key`. If the cell is
    /// admitted, updates that state.
    ///
    /// # Errors
    ///
    /// Returns [`NotUntil`] if the cell cannot be admitted at `t0`.
    pub(crate) fn test_and_update<K, S: StateStore<Key = K>, P: clock::Reference>(
        &self,
        start: P,
        key: &K,
        state: &S,
        t0: P,
    ) -> Result<(), NotUntil<P>> {
        self.test_and_update_n(start, key, state, t0, 1)
    }

    /// Tests whether `n` cells can be admitted together at `t0`. If they can,
    /// charges all of them to the state stored for `key`.
    ///
    /// The batch is all-or-nothing. It is admitted only if the *last* of the
    /// `n` cells would still conform, so a weighted request can never exceed
    /// the burst capacity. A weight of zero performs the same admissibility
    /// check as a single cell but charges nothing.
    ///
    /// # Errors
    ///
    /// Returns [`NotUntil`] if the batch does not fit yet. Its earliest possible
    /// instant is when all `n` cells would fit.
    ///
    /// # Panics
    ///
    /// Panics if `n` cells exceed the burst capacity (`t * n > tau`), because
    /// such a request could never be satisfied.
    pub(crate) fn test_and_update_n<K, S: StateStore<Key = K>, P: clock::Reference>(
        &self,
        start: P,
        key: &K,
        state: &S,
        t0: P,
        n: u32,
    ) -> Result<(), NotUntil<P>> {
        let tau = self.tau;
        let t = self.t;
        assert!(
            t.as_u64().saturating_mul(u64::from(n)) <= tau.as_u64(),
            "weight {n} exceeds burst capacity ({}); this request can never be satisfied",
            tau.as_u64() / t.as_u64()
        );
        // Cannot overflow: the assertion above bounds `t * n` by `tau`.
        let weight = t * u64::from(n);
        // Weight of the cells beyond the first. The stored TAT already reflects
        // the slot the first cell would take, so only the rest shift the check.
        let additional_weight = weight.saturating_sub(t);
        let t0 = t0.duration_since(start);
        state.measure_and_replace(key, |tat| {
            let tat = tat.unwrap_or_else(|| self.starting_state(t0));
            // The batch conforms only if the TAT of its last cell stays within
            // `tau` of now. Checking just `tat - tau` would admit a batch whose
            // first cell fits while the rest overrun the burst.
            let earliest_time = (tat + additional_weight).saturating_sub(tau);
            if t0 < earliest_time {
                Err(NotUntil::new(
                    StateSnapshot::new(t, tau, earliest_time, earliest_time),
                    start,
                ))
            } else {
                let next = cmp::max(tat, t0) + weight;
                Ok(((), next))
            }
        })
    }
}