use std::{cmp, fmt::Display, time::Duration};
use super::{StateStore, clock, nanos::Nanos, quota::Quota};
/// A capture of the GCRA parameters and timing values behind one rate-limiting decision.
///
/// In this file a snapshot is built by `Gcra::test_and_update` when a check is rejected,
/// and is carried inside the resulting `NotUntil`.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct StateSnapshot {
/// GCRA parameter `t`, copied from the `Gcra` that made the decision (where it is derived
/// from the quota's `replenish_1_per`).
t: Nanos,
/// GCRA parameter `tau`, copied from the `Gcra` that made the decision (where it is derived
/// from `replenish_1_per * max_burst`).
tau: Nanos,
/// Time value the remaining-capacity computation is measured from.
// NOTE(review): the only constructor call visible in this file passes the same
// `earliest_time` value for both this field and `tat` — confirm the intended meaning.
pub(crate) time_of_measurement: Nanos,
/// Time value that `NotUntil` adds to its start reference to obtain the earliest instant
/// at which a check can succeed.
pub(crate) tat: Nanos,
}
impl StateSnapshot {
    /// Builds a snapshot from the GCRA parameters (`t`, `tau`) and the two timing values
    /// (`time_of_measurement`, `tat`), storing each argument verbatim.
    #[inline]
    pub(crate) const fn new(t: Nanos, tau: Nanos, time_of_measurement: Nanos, tat: Nanos) -> Self {
        Self {
            t,
            tau,
            time_of_measurement,
            tat,
        }
    }

    /// Reconstructs the `Quota` that corresponds to this snapshot's `t` and `tau`.
    pub fn quota(&self) -> Quota {
        let (t, tau) = (self.t, self.tau);
        Quota::from_gcra_parameters(t, tau)
    }

    /// Returns how many whole intervals of length `t` fit into the unused burst allowance.
    ///
    /// The allowance is `(time_of_measurement + t + tau) - tat` (saturating), capped at
    /// `tau`. A zero `t` yields `0` rather than dividing by zero.
    // NOTE(review): the reason for allowing dead code on a `pub` method is not visible in
    // this file — confirm whether the attribute is still needed.
    #[allow(dead_code)]
    pub fn remaining_burst_capacity(&self) -> u32 {
        match self.t.as_u64() {
            // Guard against division by zero below.
            0 => 0,
            interval => {
                let ceiling = self.tau.as_u64();
                let horizon = self.time_of_measurement + self.t + self.tau;
                let unused = horizon.saturating_sub(self.tat).as_u64();
                // `unused` is capped at `tau`, so the quotient never exceeds `tau / t`.
                (unused.min(ceiling) / interval) as u32
            }
        }
    }
}
/// The negative outcome of a rate-limiting check: describes when a check can next succeed.
///
/// In this file it is produced as the `Err` value of `Gcra::test_and_update`.
#[derive(Debug, PartialEq, Eq)]
pub struct NotUntil<P: clock::Reference> {
/// Snapshot of the GCRA parameters and timing values at the moment of rejection.
state: StateSnapshot,
/// Reference point that the snapshot's `tat` offset is added to in order to obtain an
/// instant of type `P`.
start: P,
}
impl<P: clock::Reference> NotUntil<P> {
#[inline]
pub(crate) const fn new(state: StateSnapshot, start: P) -> Self {
Self { state, start }
}
#[inline]
pub fn earliest_possible(&self) -> P {
let tat: Nanos = self.state.tat;
self.start + tat
}
#[inline]
pub fn wait_time_from(&self, from: P) -> Duration {
let earliest = self.earliest_possible();
earliest.duration_since(earliest.min(from)).into()
}
#[inline]
pub fn quota(&self) -> Quota {
self.state.quota()
}
}
impl<P: clock::Reference> Display for NotUntil<P> {
    /// Writes `rate-limited until <instant>`, where the instant is the earliest possible
    /// time of the next successful check, rendered with its `Debug` formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let until = self.earliest_possible();
        write!(f, "rate-limited until {:?}", until)
    }
}
/// The pair of GCRA parameters used to decide rate-limiting checks.
///
/// Both values are derived from a `Quota` in `Gcra::new`.
#[derive(Debug, PartialEq, Eq)]
pub struct Gcra {
/// The quota's `replenish_1_per` converted to `Nanos`; added to the stored time value
/// on every accepted check.
t: Nanos,
/// The quota's `replenish_1_per * max_burst` converted to `Nanos`; subtracted from the
/// stored time value to find the earliest instant a check may pass.
tau: Nanos,
}
impl Gcra {
pub(crate) fn new(quota: Quota) -> Self {
let tau: Nanos = (quota.replenish_1_per * quota.max_burst.get()).into();
let t: Nanos = quota.replenish_1_per.into();
Self { t, tau }
}
fn starting_state(&self, t0: Nanos) -> Nanos {
t0 + self.t
}
pub(crate) fn test_and_update<K, S: StateStore<Key = K>, P: clock::Reference>(
&self,
start: P,
key: &K,
state: &S,
t0: P,
) -> Result<(), NotUntil<P>> {
let t0 = t0.duration_since(start);
let tau = self.tau;
let t = self.t;
state.measure_and_replace(key, |tat| {
let tat = tat.unwrap_or_else(|| self.starting_state(t0));
let earliest_time = tat.saturating_sub(tau);
if t0 < earliest_time {
Err(NotUntil::new(
StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
start,
))
} else {
let next = cmp::max(tat, t0) + t;
Ok(((), next))
}
})
}
}