use std::time::Duration;
use derive_more::{Debug, Display, Error};
/// Error value describing a rate-limited operation.
///
/// Its `Display` output is `rate limited: retry after <retry_after>`, where the
/// duration is rendered with `Debug` formatting (for example `30s`).
#[derive(Debug, Display, Error)]
#[display("rate limited: retry after {retry_after:?}")]
pub struct RateLimitError {
/// Retry delay that is embedded in the display message.
pub retry_after: Duration,
}
/// Outcome of a rate-limit check.
///
/// The value is `Copy` and comparable, so it can be passed around and asserted
/// on freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The check passed; no retry delay is attached.
    Allowed,
    /// The check did not pass.
    Limited {
        /// Retry delay attached to the limited outcome.
        retry_after: Duration,
    },
}

impl RateLimitResult {
    /// Returns `true` only for [`RateLimitResult::Allowed`].
    pub const fn is_allowed(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            Self::Allowed => true,
            Self::Limited { .. } => false,
        }
    }

    /// Returns `true` only for [`RateLimitResult::Limited`].
    pub const fn is_limited(&self) -> bool {
        match self {
            Self::Limited { .. } => true,
            Self::Allowed => false,
        }
    }

    /// Returns the retry delay carried by [`RateLimitResult::Limited`], or
    /// `None` for [`RateLimitResult::Allowed`].
    pub const fn retry_after(&self) -> Option<Duration> {
        if let Self::Limited { retry_after } = self {
            Some(*retry_after)
        } else {
            None
        }
    }
}
/// Keyed rate-limiter interface.
///
/// Every method takes `&self`, and implementors must be `Send + Sync + 'static`.
pub trait RateLimiter: Send + Sync + 'static {
    /// Returns the rate-limit outcome for `key`.
    fn check(&self, key: &str) -> RateLimitResult;

    /// Records an occurrence for `key`; the accounting itself is up to the
    /// implementation.
    fn record(&self, key: &str);

    /// Checks `key` and records it only when the check reports allowed.
    ///
    /// Returns whatever [`RateLimiter::check`] returned. The default
    /// implementation performs `check` and `record` as two separate calls, so
    /// an implementation that needs both to happen as a single step should
    /// override this method.
    fn check_and_record(&self, key: &str) -> RateLimitResult {
        let outcome = self.check(key);
        // Anything other than "allowed" is handed back without recording.
        if !outcome.is_allowed() {
            return outcome;
        }
        self.record(key);
        outcome
    }

    /// Resets the limiter's state for `key` (presumably clearing what
    /// `record` accumulated; exact semantics are implementation-defined).
    fn reset(&self, key: &str);

    /// Returns the implementation's current count for `key` (presumably the
    /// number of recorded occurrences; confirm against the implementor).
    fn count(&self, key: &str) -> u64;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Allowed` reports allowed, not limited, and carries no retry delay.
    #[test]
    fn test_rate_limit_result_allowed() {
        let outcome = RateLimitResult::Allowed;

        assert!(outcome.is_allowed());
        assert!(!outcome.is_limited());
        assert!(outcome.retry_after().is_none());
    }

    /// `Limited` reports limited and hands back the delay it was built with.
    #[test]
    fn test_rate_limit_result_limited() {
        let delay = Duration::from_secs(60);
        let outcome = RateLimitResult::Limited { retry_after: delay };

        assert!(!outcome.is_allowed());
        assert!(outcome.is_limited());
        assert!(matches!(outcome.retry_after(), Some(d) if d == delay));
    }

    /// The error's display text includes the retry delay.
    #[test]
    fn test_rate_limit_error_display() {
        let rendered = RateLimitError {
            retry_after: Duration::from_secs(30),
        }
        .to_string();

        assert!(rendered.contains("30"));
    }
}