use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
/// Shared state for a GCRA-style rate limiter.
///
/// The only stored value is `tat_nanos` (presumably GCRA's "theoretical arrival
/// time"): a nanosecond timestamp on the same clock the caller supplies to
/// [`GcraState::try_acquire`]. It is kept in an `AtomicU64`, so every operation
/// takes `&self`.
#[derive(Debug)]
pub(crate) struct GcraState {
    /// Timestamp (nanoseconds) up to which capacity has already been handed out.
    tat_nanos: AtomicU64,
}

impl GcraState {
    /// Creates a fresh state whose stored timestamp is `0`.
    pub fn new() -> Self {
        let tat_nanos = AtomicU64::new(0);
        Self { tat_nanos }
    }

    /// Returns the currently stored timestamp, loaded with the given `ordering`.
    pub fn tat(&self, ordering: Ordering) -> u64 {
        let cell = &self.tat_nanos;
        cell.load(ordering)
    }

    /// Attempts to take one slot at time `now_nanos`.
    ///
    /// The candidate timestamp is `max(stored, now_nanos) + emission_interval_nanos`.
    /// The request is admitted when that candidate does not exceed
    /// `now_nanos + limit_nanos`; the candidate then replaces the stored value via
    /// a compare-and-swap that is retried if another update got in first.
    /// All additions saturate at `u64::MAX` instead of wrapping.
    ///
    /// # Errors
    ///
    /// Returns `Err(wait)` when the candidate would exceed `now_nanos + limit_nanos`.
    /// `wait` is the size of that excess, in nanoseconds. The stored timestamp is
    /// left untouched in this case.
    pub fn try_acquire(
        &self,
        now_nanos: u64,
        emission_interval_nanos: u64,
        limit_nanos: u64,
    ) -> Result<(), Duration> {
        // Latest candidate timestamp that is still admissible for this call.
        let deadline = now_nanos.saturating_add(limit_nanos);
        // Where the stored timestamp would move to if this request were admitted.
        let advance =
            |stored: u64| stored.max(now_nanos).saturating_add(emission_interval_nanos);

        self.tat_nanos
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |stored| {
                let candidate = advance(stored);
                if candidate > deadline {
                    // Over the limit: abort the update and leave the state as is.
                    None
                } else {
                    Some(candidate)
                }
            })
            .map(|_previous| ())
            // On rejection `fetch_update` hands back the value it last observed;
            // recompute the candidate from it to report how far over the limit it was.
            .map_err(|stored| Duration::from_nanos(advance(stored).saturating_sub(deadline)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With a 100 ms interval and a 1 s limit, ten acquisitions at the same
    /// instant succeed and the eleventh is refused.
    #[test]
    fn test_gcra_allows_burst() {
        let interval_ns = Duration::from_millis(100).as_nanos() as u64;
        let window_ns = Duration::from_secs(1).as_nanos() as u64;
        let start = 0u64;
        let limiter = GcraState::new();

        for _attempt in 0..10 {
            let outcome = limiter.try_acquire(start, interval_ns, window_ns);
            assert!(outcome.is_ok());
        }

        let overflow = limiter.try_acquire(start, interval_ns, window_ns);
        assert!(overflow.is_err());
    }

    /// After ten acquisitions at time zero, moving the clock forward by one
    /// interval (100 ms) lets another acquisition through.
    #[test]
    fn test_gcra_recovers_after_time() {
        let interval_ns = Duration::from_millis(100).as_nanos() as u64;
        let window_ns = Duration::from_secs(1).as_nanos() as u64;
        let limiter = GcraState::new();

        // Use up the burst at time zero; individual results are deliberately ignored.
        let start = 0u64;
        for _attempt in 0..10 {
            let _ = limiter.try_acquire(start, interval_ns, window_ns);
        }

        let later = Duration::from_millis(100).as_nanos() as u64;
        let outcome = limiter.try_acquire(later, interval_ns, window_ns);
        assert!(outcome.is_ok());
    }
}