use core::time::Duration;
use clock_lib::{Clock, SystemClock};
use crate::decision::Decision;
use crate::key::Key;
use crate::limiter::RateLimiter;
/// Async-aware wrapper around a [`RateLimiter`].
///
/// The synchronous checks are forwarded unchanged. The `until_ready*` methods add the
/// ability to wait (via Tokio timers) until a request is admitted. The clock type `C`
/// defaults to [`SystemClock`].
pub struct AsyncLimiter<C = SystemClock>
where
    C: Clock + Clone,
{
    /// The wrapped synchronous limiter that holds all rate-limiting state.
    inner: RateLimiter<C>,
}
impl<C: Clock + Clone> AsyncLimiter<C> {
    /// Wraps an existing [`RateLimiter`].
    #[must_use]
    pub fn new(inner: RateLimiter<C>) -> Self {
        Self { inner }
    }

    /// Returns a shared reference to the wrapped [`RateLimiter`].
    #[must_use]
    pub fn inner(&self) -> &RateLimiter<C> {
        &self.inner
    }

    /// Consumes the wrapper and returns the wrapped [`RateLimiter`].
    #[must_use]
    pub fn into_inner(self) -> RateLimiter<C> {
        self.inner
    }

    /// Checks a single request for `key` without waiting.
    ///
    /// This forwards to [`RateLimiter::check`] and returns its [`Decision`] unchanged.
    /// It never awaits. Use [`until_ready`](Self::until_ready) to wait for admission.
    // The message form of `must_use` is used so this stays lint-clean even if
    // `Decision` itself is already `#[must_use]` (clippy::double_must_use only
    // fires on the bare attribute).
    #[must_use = "the Decision says whether the request was admitted; ignoring it is almost always a bug"]
    pub fn check(&self, key: impl Into<Key>) -> Decision {
        self.inner.check(key)
    }

    /// Checks a request for `n` permits for `key` without waiting.
    ///
    /// This forwards to [`RateLimiter::check_n`] and returns its [`Decision`] unchanged.
    /// It never awaits. Use [`until_ready_n`](Self::until_ready_n) to wait for admission.
    #[must_use = "the Decision says whether the request was admitted; ignoring it is almost always a bug"]
    pub fn check_n(&self, key: impl Into<Key>, n: u32) -> Decision {
        self.inner.check_n(key, n)
    }

    /// Waits until a single permit for `key` is admitted.
    ///
    /// This is equivalent to `self.until_ready_n(key, 1).await`, and has the same caveats
    /// (see [`until_ready_n`](Self::until_ready_n)).
    pub async fn until_ready(&self, key: impl Into<Key>) {
        self.until_ready_n(key, 1).await;
    }

    /// Waits until a request for `n` permits for `key` is admitted.
    ///
    /// Each round calls [`RateLimiter::check_n`]. On [`Decision::Allow`] this returns. On
    /// [`Decision::Deny`] it sleeps for the reported `retry_after` using
    /// [`tokio::time::sleep`] and then checks again. It must therefore be polled inside a
    /// Tokio runtime with the time driver enabled.
    ///
    /// # Gives up silently
    ///
    /// If the limiter reports `retry_after == Duration::MAX`, this returns immediately
    /// **without** a permit having been granted. The `()` return type cannot distinguish
    /// this from a successful admission. Presumably the limiter uses `Duration::MAX` to mean
    /// "can never be satisfied" (for example `n` above the quota limit). Callers that need to
    /// know should validate `n` up front or use [`check_n`](Self::check_n) directly.
    pub async fn until_ready_n(&self, key: impl Into<Key>, n: u32) {
        // Convert once. `check_n` takes `impl Into<Key>` by value, so each round gets a clone.
        let key = key.into();
        loop {
            // Keep an exhaustive match (no catch-all) so a new `Decision` variant is a
            // compile error here rather than being silently treated as "admitted".
            match self.inner.check_n(key.clone(), n) {
                Decision::Allow => return,
                Decision::Deny { retry_after } => {
                    if retry_after == Duration::MAX {
                        // Unsatisfiable request: give up instead of sleeping forever.
                        return;
                    }
                    tokio::time::sleep(retry_after).await;
                }
            }
        }
    }
}
impl<C: Clock + Clone> From<RateLimiter<C>> for AsyncLimiter<C> {
fn from(inner: RateLimiter<C>) -> Self {
Self::new(inner)
}
}
#[cfg(test)]
mod tests {
    use core::time::Duration;

    use super::AsyncLimiter;
    use crate::limiter::RateLimiter;

    /// Upper bound for every awaiting test. A regression then fails the test instead of
    /// hanging the whole suite.
    const TEST_TIMEOUT: Duration = Duration::from_secs(2);

    #[test]
    fn test_check_passthrough_is_sync() {
        let limiter = AsyncLimiter::new(RateLimiter::per_second(1));
        assert!(limiter.check("k").is_allow());
        assert!(limiter.check("k").is_deny());
    }

    #[tokio::test]
    async fn test_until_ready_admits_after_waiting() {
        let limiter = AsyncLimiter::new(RateLimiter::per_second(200));
        for _ in 0..200 {
            assert!(limiter.check("k").is_allow());
        }
        assert!(limiter.check("k").is_deny());
        let completed = tokio::time::timeout(TEST_TIMEOUT, limiter.until_ready("k")).await;
        assert!(completed.is_ok(), "until_ready did not complete within 2s");
    }

    #[tokio::test]
    async fn test_until_ready_n_gives_up_when_impossible() {
        let limiter = AsyncLimiter::new(RateLimiter::per_second(5));
        // Request more permits than the quota allows. This must return promptly. Without
        // the timeout, a regression (e.g. sleeping on a huge `retry_after`) would hang forever.
        let completed = tokio::time::timeout(TEST_TIMEOUT, limiter.until_ready_n("k", 6)).await;
        assert!(completed.is_ok(), "until_ready_n did not give up within 2s");
    }

    #[test]
    fn test_from_and_into_inner_round_trip() {
        let limiter: AsyncLimiter = RateLimiter::per_second(10).into();
        assert_eq!(limiter.inner().quota().limit(), 10);
        let back = limiter.into_inner();
        assert_eq!(back.quota().limit(), 10);
    }
}