use std::collections::VecDeque;
use std::time::{Duration, Instant};
#[cfg(feature = "redis-storage")]
use serde::{Deserialize, Serialize};
/// Configuration errors reported by the policy constructors in this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PolicyError {
    /// Returned by `CountBasedPolicy::new` when `max_count` is 0.
    ZeroMaxCount,
    /// Returned by `TimeWindowPolicy::new` when `max_events` is 0.
    ZeroMaxEvents,
    /// Returned by `TimeWindowPolicy::new` when the window duration is zero.
    ZeroWindowDuration,
    /// Returned by `TokenBucketPolicy::new` for a non-positive capacity.
    ZeroCapacity,
    /// Returned by `TokenBucketPolicy::new` for a non-positive refill rate.
    ZeroRefillRate,
}

impl std::fmt::Display for PolicyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant maps to one fixed, human-readable sentence.
        let message = match self {
            Self::ZeroMaxCount => "max_count must be greater than 0",
            Self::ZeroMaxEvents => "max_events must be greater than 0",
            Self::ZeroWindowDuration => "window duration must be greater than 0",
            Self::ZeroCapacity => "capacity must be greater than 0",
            Self::ZeroRefillRate => "refill_rate must be greater than 0",
        };
        f.write_str(message)
    }
}

impl std::error::Error for PolicyError {}
/// Verdict produced for each event passed to a [`RateLimitPolicy`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PolicyDecision {
    /// The event fits within the policy's budget.
    Allow,
    /// The event exceeds the policy's budget.
    Suppress,
}

/// Common interface for the stateful rate-limiting strategies in this module.
///
/// Implementors must be `Send + Sync`.
pub trait RateLimitPolicy: Send + Sync {
    /// Records one event observed at `timestamp` and decides whether it may pass.
    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision;

    /// Discards accumulated state; the built-in policies return to their
    /// freshly constructed state.
    fn reset(&mut self);
}
/// Allows the first `max_count` events and suppresses every later one until
/// [`RateLimitPolicy::reset`] is called. Event timestamps are ignored.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
pub struct CountBasedPolicy {
    /// Number of events allowed before suppression starts (validated `> 0` by `new`).
    max_count: usize,
    /// Events registered since construction or the last reset; saturates at `usize::MAX`.
    current_count: usize,
}

impl CountBasedPolicy {
    /// Creates a policy that allows at most `max_count` events.
    ///
    /// # Errors
    ///
    /// Returns [`PolicyError::ZeroMaxCount`] if `max_count` is 0.
    pub fn new(max_count: usize) -> Result<Self, PolicyError> {
        if max_count == 0 {
            return Err(PolicyError::ZeroMaxCount);
        }
        Ok(Self {
            max_count,
            current_count: 0,
        })
    }
}

impl RateLimitPolicy for CountBasedPolicy {
    /// Counts the event; allows it while the count is within `max_count`.
    fn register_event(&mut self, _timestamp: Instant) -> PolicyDecision {
        // Saturate instead of overflowing: a long-lived limiter can exceed `usize::MAX`
        // events (notably on 32-bit targets). A plain `+= 1` would panic in debug builds
        // or wrap to 0 in release builds, silently re-allowing events.
        self.current_count = self.current_count.saturating_add(1);
        if self.current_count <= self.max_count {
            PolicyDecision::Allow
        } else {
            PolicyDecision::Suppress
        }
    }

    /// Clears the event count so the full budget is available again.
    fn reset(&mut self) {
        self.current_count = 0;
    }
}
/// Sliding-window limiter: allows at most `max_events` recorded events within a window of
/// `window_duration` ending at the current event. Only allowed events are recorded.
#[derive(Clone, Debug, PartialEq)]
pub struct TimeWindowPolicy {
    /// Maximum number of recorded events the window may hold.
    max_events: usize,
    /// Length of the sliding window.
    window_duration: Duration,
    /// Timestamps of allowed events in registration order; expiry pops from the front.
    event_timestamps: VecDeque<Instant>,
}
#[cfg(feature = "redis-storage")]
impl Serialize for TimeWindowPolicy {
    /// Serializes the policy. `Instant` has no portable absolute value, so timestamps are
    /// stored as nanosecond offsets from the front (oldest) recorded event.
    ///
    /// Wire fields:
    /// - `max_events`
    /// - `window_duration_nanos` (`u128`)
    /// - `timestamps_nanos` (offsets; the first is always 0)
    /// - `base_timestamp_nanos` (`Some(0)` when any timestamp exists, otherwise `None`)
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let base = self.event_timestamps.front().copied();
        let timestamps_nanos: Vec<u64> = match base {
            Some(base_instant) => self
                .event_timestamps
                .iter()
                .map(|instant| {
                    let nanos = instant.saturating_duration_since(base_instant).as_nanos();
                    // Saturate offsets that do not fit in u64 (~584 years) instead of wrapping.
                    u64::try_from(nanos).unwrap_or(u64::MAX)
                })
                .collect(),
            None => Vec::new(),
        };

        let mut state = serializer.serialize_struct("TimeWindowPolicy", 4)?;
        state.serialize_field("max_events", &self.max_events)?;
        state.serialize_field("window_duration_nanos", &self.window_duration.as_nanos())?;
        state.serialize_field("timestamps_nanos", &timestamps_nanos)?;
        state.serialize_field("base_timestamp_nanos", &base.map(|_| 0u64))?;
        state.end()
    }
}
#[cfg(feature = "redis-storage")]
impl<'de> Deserialize<'de> for TimeWindowPolicy {
    /// Rebuilds a policy from the layout written by `TimeWindowPolicy`'s `Serialize` impl.
    ///
    /// Stored timestamps are offsets from the oldest recorded event, so the original
    /// `Instant`s cannot be recovered. They are re-anchored so that the most recent event
    /// lands on `Instant::now()` and older events keep their relative spacing in the past.
    /// Every stored event happened before it was serialized, so this never makes an event
    /// look older than it really is, and no restored timestamp lies in the future.
    ///
    /// # Errors
    ///
    /// Fails on:
    /// - missing or duplicate fields;
    /// - `max_events == 0` or a zero window (the invariants enforced by
    ///   `TimeWindowPolicy::new`);
    /// - a window whose whole-second part does not fit in a `u64`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            MaxEvents,
            WindowDurationNanos,
            TimestampsNanos,
            BaseTimestampNanos,
        }

        const NANOS_PER_SEC: u128 = 1_000_000_000;

        struct TimeWindowPolicyVisitor;

        impl<'de> Visitor<'de> for TimeWindowPolicyVisitor {
            type Value = TimeWindowPolicy;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct TimeWindowPolicy")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TimeWindowPolicy, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut max_events: Option<usize> = None;
                let mut window_duration_nanos: Option<u128> = None;
                let mut timestamps_nanos: Option<Vec<u64>> = None;
                // Written as `Some(0)` or `None` and carries no information. It is read only
                // so that a duplicate entry is rejected like every other field.
                let mut base_timestamp_nanos: Option<Option<u64>> = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::MaxEvents => {
                            if max_events.is_some() {
                                return Err(de::Error::duplicate_field("max_events"));
                            }
                            max_events = Some(map.next_value()?);
                        }
                        Field::WindowDurationNanos => {
                            if window_duration_nanos.is_some() {
                                return Err(de::Error::duplicate_field("window_duration_nanos"));
                            }
                            window_duration_nanos = Some(map.next_value()?);
                        }
                        Field::TimestampsNanos => {
                            if timestamps_nanos.is_some() {
                                return Err(de::Error::duplicate_field("timestamps_nanos"));
                            }
                            timestamps_nanos = Some(map.next_value()?);
                        }
                        Field::BaseTimestampNanos => {
                            if base_timestamp_nanos.is_some() {
                                return Err(de::Error::duplicate_field("base_timestamp_nanos"));
                            }
                            base_timestamp_nanos = Some(map.next_value()?);
                        }
                    }
                }

                let max_events =
                    max_events.ok_or_else(|| de::Error::missing_field("max_events"))?;
                let window_duration_nanos = window_duration_nanos
                    .ok_or_else(|| de::Error::missing_field("window_duration_nanos"))?;
                let timestamps_nanos = timestamps_nanos
                    .ok_or_else(|| de::Error::missing_field("timestamps_nanos"))?;

                // Enforce the same invariants as `TimeWindowPolicy::new`.
                if max_events == 0 {
                    return Err(de::Error::custom(PolicyError::ZeroMaxEvents));
                }
                if window_duration_nanos == 0 {
                    return Err(de::Error::custom(PolicyError::ZeroWindowDuration));
                }

                // Convert without truncation: casting the raw nanosecond count with `as u64`
                // would silently wrap windows longer than ~584 years.
                let secs = u64::try_from(window_duration_nanos / NANOS_PER_SEC)
                    .map_err(|_| de::Error::custom("window_duration_nanos is out of range"))?;
                // The remainder is below 1_000_000_000, so it always fits in a u32.
                let subsec_nanos = (window_duration_nanos % NANOS_PER_SEC) as u32;
                let window_duration = Duration::new(secs, subsec_nanos);

                let now = Instant::now();
                let newest = timestamps_nanos.iter().copied().max().unwrap_or(0);
                let event_timestamps: VecDeque<Instant> = timestamps_nanos
                    .into_iter()
                    .map(|offset| {
                        // `newest` is the maximum offset, so this cannot underflow.
                        let age = Duration::from_nanos(newest - offset);
                        // Fall back to `now` (the conservative choice) if the platform
                        // cannot represent an instant that far in the past.
                        now.checked_sub(age).unwrap_or(now)
                    })
                    .collect();

                Ok(TimeWindowPolicy {
                    max_events,
                    window_duration,
                    event_timestamps,
                })
            }
        }

        const FIELDS: &[&str] = &[
            "max_events",
            "window_duration_nanos",
            "timestamps_nanos",
            "base_timestamp_nanos",
        ];
        deserializer.deserialize_struct("TimeWindowPolicy", FIELDS, TimeWindowPolicyVisitor)
    }
}
impl TimeWindowPolicy {
    /// Creates a sliding-window policy that allows `max_events` events per `window_duration`.
    ///
    /// # Errors
    ///
    /// * [`PolicyError::ZeroMaxEvents`] if `max_events` is 0 (checked first).
    /// * [`PolicyError::ZeroWindowDuration`] if `window_duration` is zero.
    pub fn new(max_events: usize, window_duration: Duration) -> Result<Self, PolicyError> {
        match (max_events, window_duration.is_zero()) {
            (0, _) => Err(PolicyError::ZeroMaxEvents),
            (_, true) => Err(PolicyError::ZeroWindowDuration),
            _ => Ok(Self {
                max_events,
                window_duration,
                event_timestamps: VecDeque::new(),
            }),
        }
    }

    /// Pops recorded events from the front while they are strictly older than the window,
    /// measured back from `current_time`. Stops at the first event still inside the window.
    fn expire_old_events(&mut self, current_time: Instant) {
        let window = self.window_duration;
        while matches!(
            self.event_timestamps.front(),
            Some(&oldest) if current_time.saturating_duration_since(oldest) > window
        ) {
            self.event_timestamps.pop_front();
        }
    }
}
impl RateLimitPolicy for TimeWindowPolicy {
    /// Expires stale events, then allows and records `timestamp` only if the window still
    /// has room. Suppressed events are not recorded.
    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
        self.expire_old_events(timestamp);
        let window_full = self.event_timestamps.len() >= self.max_events;
        if window_full {
            return PolicyDecision::Suppress;
        }
        self.event_timestamps.push_back(timestamp);
        PolicyDecision::Allow
    }

    /// Forgets every recorded event.
    fn reset(&mut self) {
        self.event_timestamps.clear();
    }
}
/// Allows events whose 1-based ordinal is a power of two (1st, 2nd, 4th, 8th, ...) and
/// suppresses all others. Event timestamps are ignored.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
pub struct ExponentialBackoffPolicy {
    /// Events registered since construction or the last reset.
    event_count: u64,
    /// Ordinal of the next event to allow; doubles (saturating) after each allowed event.
    next_allowed: u64,
}

impl ExponentialBackoffPolicy {
    /// Creates a policy whose very first event will be allowed.
    pub fn new() -> Self {
        Self {
            event_count: 0,
            next_allowed: 1,
        }
    }
}

impl Default for ExponentialBackoffPolicy {
    /// Same as [`ExponentialBackoffPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl RateLimitPolicy for ExponentialBackoffPolicy {
    /// Counts the event and allows it only when its ordinal equals `next_allowed`.
    fn register_event(&mut self, _timestamp: Instant) -> PolicyDecision {
        self.event_count += 1;
        if self.event_count != self.next_allowed {
            return PolicyDecision::Suppress;
        }
        self.next_allowed = self.next_allowed.saturating_mul(2);
        PolicyDecision::Allow
    }

    /// Restores the freshly constructed state.
    fn reset(&mut self) {
        *self = Self::new();
    }
}
/// Token-bucket limiter: holds up to `capacity` tokens, refills at `refill_rate` tokens per
/// second of elapsed time, and spends one whole token per allowed event.
#[derive(Clone, Debug, PartialEq)]
pub struct TokenBucketPolicy {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Tokens currently available; fractional amounts accumulate between events.
    tokens: f64,
    /// Time of the last refill; `None` until the first event and after a reset.
    last_refill: Option<Instant>,
}
#[cfg(feature = "redis-storage")]
impl Serialize for TokenBucketPolicy {
    /// Writes `capacity`, `refill_rate`, `tokens` and `has_last_refill`.
    ///
    /// The last-refill `Instant` has no portable value, so only its presence is recorded.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct as _;

        let has_last_refill = self.last_refill.is_some();
        let mut out = serializer.serialize_struct("TokenBucketPolicy", 4)?;
        out.serialize_field("capacity", &self.capacity)?;
        out.serialize_field("refill_rate", &self.refill_rate)?;
        out.serialize_field("tokens", &self.tokens)?;
        out.serialize_field("has_last_refill", &has_last_refill)?;
        out.end()
    }
}
#[cfg(feature = "redis-storage")]
impl<'de> Deserialize<'de> for TokenBucketPolicy {
    /// Rebuilds a policy from the layout written by `TokenBucketPolicy`'s `Serialize` impl.
    ///
    /// `last_refill` is always restored as `None`, because the stored `has_last_refill` flag
    /// cannot be turned back into an `Instant`. The first event after restoring therefore
    /// adds no tokens and only starts the refill clock.
    ///
    /// # Errors
    ///
    /// Fails on:
    /// - missing or duplicate fields;
    /// - a `capacity` or `refill_rate` that is not a positive number (the invariants
    ///   enforced by `TokenBucketPolicy::new`, with NaN also rejected);
    /// - a NaN `tokens` value.
    ///
    /// A `tokens` value outside `0.0..=capacity` is clamped into that range.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            Capacity,
            RefillRate,
            Tokens,
            HasLastRefill,
        }

        struct TokenBucketPolicyVisitor;

        impl<'de> Visitor<'de> for TokenBucketPolicyVisitor {
            type Value = TokenBucketPolicy;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct TokenBucketPolicy")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TokenBucketPolicy, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut capacity: Option<f64> = None;
                let mut refill_rate: Option<f64> = None;
                let mut tokens: Option<f64> = None;
                // Informational only (see the doc comment); read to reject duplicates.
                let mut has_last_refill: Option<bool> = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Capacity => {
                            if capacity.is_some() {
                                return Err(de::Error::duplicate_field("capacity"));
                            }
                            capacity = Some(map.next_value()?);
                        }
                        Field::RefillRate => {
                            if refill_rate.is_some() {
                                return Err(de::Error::duplicate_field("refill_rate"));
                            }
                            refill_rate = Some(map.next_value()?);
                        }
                        Field::Tokens => {
                            if tokens.is_some() {
                                return Err(de::Error::duplicate_field("tokens"));
                            }
                            tokens = Some(map.next_value()?);
                        }
                        Field::HasLastRefill => {
                            if has_last_refill.is_some() {
                                return Err(de::Error::duplicate_field("has_last_refill"));
                            }
                            has_last_refill = Some(map.next_value()?);
                        }
                    }
                }

                let capacity = capacity.ok_or_else(|| de::Error::missing_field("capacity"))?;
                let refill_rate =
                    refill_rate.ok_or_else(|| de::Error::missing_field("refill_rate"))?;
                let tokens = tokens.ok_or_else(|| de::Error::missing_field("tokens"))?;

                // Same checks as `TokenBucketPolicy::new`, plus NaN, which passes a plain
                // `<= 0.0` comparison.
                if capacity.is_nan() || capacity <= 0.0 {
                    return Err(de::Error::custom(PolicyError::ZeroCapacity));
                }
                if refill_rate.is_nan() || refill_rate <= 0.0 {
                    return Err(de::Error::custom(PolicyError::ZeroRefillRate));
                }
                if tokens.is_nan() {
                    return Err(de::Error::custom("tokens must not be NaN"));
                }
                // No refill, and therefore no capping, runs on the first event after a
                // restore, so an oversized balance would allow a burst beyond `capacity`.
                // `capacity` is positive and non-NaN here, so `clamp` cannot panic.
                let tokens = tokens.clamp(0.0, capacity);

                Ok(TokenBucketPolicy {
                    capacity,
                    refill_rate,
                    tokens,
                    last_refill: None,
                })
            }
        }

        const FIELDS: &[&str] = &["capacity", "refill_rate", "tokens", "has_last_refill"];
        deserializer.deserialize_struct("TokenBucketPolicy", FIELDS, TokenBucketPolicyVisitor)
    }
}
impl TokenBucketPolicy {
    /// Creates a full bucket holding `capacity` tokens that refills at `refill_rate` tokens
    /// per second.
    ///
    /// # Errors
    ///
    /// * [`PolicyError::ZeroCapacity`] if `capacity` is not a positive number (`<= 0` or NaN).
    /// * [`PolicyError::ZeroRefillRate`] if `refill_rate` is not a positive number
    ///   (`<= 0` or NaN).
    pub fn new(capacity: f64, refill_rate: f64) -> Result<Self, PolicyError> {
        // `x <= 0.0` is false for NaN, so NaN must be rejected explicitly. A NaN capacity
        // would suppress every event. A NaN refill rate would make `refill` snap the bucket
        // back to full on every subsequent event, because `NaN.min(capacity)` returns
        // `capacity`.
        if capacity.is_nan() || capacity <= 0.0 {
            return Err(PolicyError::ZeroCapacity);
        }
        if refill_rate.is_nan() || refill_rate <= 0.0 {
            return Err(PolicyError::ZeroRefillRate);
        }
        Ok(Self {
            capacity,
            refill_rate,
            tokens: capacity,
            last_refill: None,
        })
    }

    /// Adds the tokens accrued since the previous refill (capped at `capacity`) and records
    /// `now` as the new refill point.
    ///
    /// The first call (no previous refill) only records `now`. If `now` is earlier than the
    /// previous refill, no tokens are added and the refill point moves back to `now`.
    fn refill(&mut self, now: Instant) {
        if let Some(last) = self.last_refill {
            if now < last {
                // Time went backwards: add nothing, restart the clock from `now`.
                self.last_refill = Some(now);
                return;
            }
            let elapsed = now.duration_since(last).as_secs_f64();
            let new_tokens = elapsed * self.refill_rate;
            self.tokens = (self.tokens + new_tokens).min(self.capacity);
        }
        self.last_refill = Some(now);
    }
}
impl RateLimitPolicy for TokenBucketPolicy {
    /// Refills the bucket up to `timestamp`, then spends one token if a whole token is
    /// available.
    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
        self.refill(timestamp);
        // Keep the comparison in its positive form so a NaN balance is never treated as
        // having a token.
        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
            PolicyDecision::Allow
        } else {
            PolicyDecision::Suppress
        }
    }

    /// Refills the bucket to capacity and forgets the last refill time.
    fn reset(&mut self) {
        self.last_refill = None;
        self.tokens = self.capacity;
    }
}
/// One of the built-in rate-limit policies behind a single concrete type. It forwards
/// [`RateLimitPolicy`] calls to the wrapped policy via a `match`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
pub enum Policy {
    /// Fixed event budget; see [`CountBasedPolicy`].
    CountBased(CountBasedPolicy),
    /// Sliding time window; see [`TimeWindowPolicy`].
    TimeWindow(TimeWindowPolicy),
    /// Power-of-two sampling; see [`ExponentialBackoffPolicy`].
    ExponentialBackoff(ExponentialBackoffPolicy),
    /// Token bucket; see [`TokenBucketPolicy`].
    TokenBucket(TokenBucketPolicy),
}

impl Policy {
    /// Builds a [`Policy::CountBased`]; fails exactly like [`CountBasedPolicy::new`].
    pub fn count_based(max_count: usize) -> Result<Self, PolicyError> {
        CountBasedPolicy::new(max_count).map(Policy::CountBased)
    }

    /// Builds a [`Policy::TimeWindow`]; fails exactly like [`TimeWindowPolicy::new`].
    pub fn time_window(max_events: usize, window: Duration) -> Result<Self, PolicyError> {
        TimeWindowPolicy::new(max_events, window).map(Policy::TimeWindow)
    }

    /// Builds a [`Policy::ExponentialBackoff`] in its initial state.
    pub fn exponential_backoff() -> Self {
        Policy::ExponentialBackoff(ExponentialBackoffPolicy::new())
    }

    /// Builds a [`Policy::TokenBucket`]; fails exactly like [`TokenBucketPolicy::new`].
    pub fn token_bucket(capacity: f64, refill_rate: f64) -> Result<Self, PolicyError> {
        TokenBucketPolicy::new(capacity, refill_rate).map(Policy::TokenBucket)
    }
}

impl RateLimitPolicy for Policy {
    /// Forwards to the wrapped policy's `register_event`.
    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
        match self {
            Self::CountBased(inner) => inner.register_event(timestamp),
            Self::TimeWindow(inner) => inner.register_event(timestamp),
            Self::ExponentialBackoff(inner) => inner.register_event(timestamp),
            Self::TokenBucket(inner) => inner.register_event(timestamp),
        }
    }

    /// Forwards to the wrapped policy's `reset`.
    fn reset(&mut self) {
        match self {
            Self::CountBased(inner) => inner.reset(),
            Self::TimeWindow(inner) => inner.reset(),
            Self::ExponentialBackoff(inner) => inner.reset(),
            Self::TokenBucket(inner) => inner.reset(),
        }
    }
}
impl PolicyDecision {
    /// Returns `true` only for [`PolicyDecision::Allow`].
    pub fn is_allow(&self) -> bool {
        match self {
            PolicyDecision::Allow => true,
            PolicyDecision::Suppress => false,
        }
    }

    /// Returns `true` only for [`PolicyDecision::Suppress`].
    pub fn is_suppress(&self) -> bool {
        match self {
            PolicyDecision::Allow => false,
            PolicyDecision::Suppress => true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::PolicyDecision::{Allow, Suppress};
    use super::*;

    /// Registers `n` events stamped `at` and asserts that each one receives `expected`.
    fn expect_n<P: RateLimitPolicy>(
        policy: &mut P,
        at: Instant,
        n: usize,
        expected: PolicyDecision,
    ) {
        for _ in 0..n {
            assert_eq!(policy.register_event(at), expected);
        }
    }

    /// Registers `n` events stamped `at`, discarding the decisions (used to drain budgets).
    fn feed<P: RateLimitPolicy>(policy: &mut P, at: Instant, n: usize) {
        for _ in 0..n {
            policy.register_event(at);
        }
    }

    /// A budget of three allows exactly three events; `reset` reopens it.
    #[test]
    fn test_count_based_policy() {
        let mut policy = CountBasedPolicy::new(3).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 3, Allow);
        expect_n(&mut policy, now, 2, Suppress);
        policy.reset();
        assert_eq!(policy.register_event(now), Allow);
    }

    /// Two events fill a one-second window; an event two seconds later fits again.
    #[test]
    fn test_time_window_policy() {
        let mut policy = TimeWindowPolicy::new(2, Duration::from_secs(1)).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 2, Allow);
        assert_eq!(policy.register_event(now), Suppress);
        let later = now + Duration::from_secs(2);
        assert_eq!(policy.register_event(later), Allow);
    }

    /// Only the 1st, 2nd, 4th and 8th events pass.
    #[test]
    fn test_exponential_backoff_policy() {
        let mut policy = ExponentialBackoffPolicy::new();
        let now = Instant::now();
        let expected = [Allow, Allow, Suppress, Allow, Suppress, Suppress, Suppress, Allow];
        for &decision in expected.iter() {
            assert_eq!(policy.register_event(now), decision);
        }
    }

    /// The `Policy` wrapper forwards to the wrapped count-based policy.
    #[test]
    fn test_policy_enum() {
        let mut policy = Policy::count_based(2).unwrap();
        let now = Instant::now();
        assert!(policy.register_event(now).is_allow());
        assert!(policy.register_event(now).is_allow());
        assert!(policy.register_event(now).is_suppress());
    }

    #[test]
    fn test_count_based_policy_zero_limit() {
        assert_eq!(CountBasedPolicy::new(0), Err(PolicyError::ZeroMaxCount));
    }

    #[test]
    fn test_count_based_policy_one_limit() {
        let mut policy = CountBasedPolicy::new(1).unwrap();
        let now = Instant::now();
        assert_eq!(policy.register_event(now), Allow);
        expect_n(&mut policy, now, 2, Suppress);
    }

    /// After a reset the policy behaves exactly as it did when first constructed.
    #[test]
    fn test_count_based_policy_reset() {
        let mut policy = CountBasedPolicy::new(2).unwrap();
        let now = Instant::now();
        for round in 0..2 {
            if round > 0 {
                policy.reset();
            }
            expect_n(&mut policy, now, 2, Allow);
            assert_eq!(policy.register_event(now), Suppress);
        }
    }

    #[test]
    fn test_time_window_policy_zero_duration() {
        assert_eq!(
            TimeWindowPolicy::new(2, Duration::from_secs(0)),
            Err(PolicyError::ZeroWindowDuration)
        );
    }

    /// Ten simultaneous events: the first three fit the window, the rest are suppressed.
    #[test]
    fn test_time_window_policy_rapid_events() {
        let mut policy = TimeWindowPolicy::new(3, Duration::from_millis(100)).unwrap();
        let now = Instant::now();
        for i in 0..10 {
            let decision = policy.register_event(now);
            match i {
                0..=2 => assert_eq!(decision, Allow, "Event {} should be allowed", i),
                _ => assert_eq!(decision, Suppress, "Event {} should be suppressed", i),
            }
        }
    }

    #[test]
    fn test_time_window_policy_reset() {
        let mut policy = TimeWindowPolicy::new(2, Duration::from_secs(60)).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 2, Allow);
        assert_eq!(policy.register_event(now), Suppress);
        policy.reset();
        assert_eq!(policy.register_event(now), Allow);
    }

    /// Over 100 events only the power-of-two ordinals (1, 2, 4, ..., 64) are allowed.
    #[test]
    fn test_exponential_backoff_large_count() {
        let mut policy = ExponentialBackoffPolicy::new();
        let now = Instant::now();
        // Zero-based indices of the events that should pass.
        let expected_allowed = [0, 1, 3, 7, 15, 31, 63];
        for i in 0..100 {
            let decision = policy.register_event(now);
            if expected_allowed.contains(&i) {
                assert_eq!(decision, Allow, "Event {} should be allowed", i + 1);
            } else {
                assert_eq!(decision, Suppress, "Event {} should be suppressed", i + 1);
            }
        }
    }

    #[test]
    fn test_exponential_backoff_reset() {
        let mut policy = ExponentialBackoffPolicy::new();
        let now = Instant::now();
        expect_n(&mut policy, now, 2, Allow);
        assert_eq!(policy.register_event(now), Suppress);
        policy.reset();
        assert_eq!(policy.register_event(now), Allow);
    }

    #[test]
    fn test_token_bucket_basic_consumption() {
        let mut policy = TokenBucketPolicy::new(3.0, 1.0).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 3, Allow);
        expect_n(&mut policy, now, 2, Suppress);
    }

    /// Half a second at 10 tokens/s restores five tokens.
    #[test]
    fn test_token_bucket_refill_over_time() {
        let mut policy = TokenBucketPolicy::new(10.0, 10.0).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 10, Allow);
        assert_eq!(policy.register_event(now), Suppress);

        let later = now + Duration::from_millis(500);
        for i in 0..5 {
            assert_eq!(
                policy.register_event(later),
                Allow,
                "Event {} should be allowed after refill",
                i
            );
        }
        assert_eq!(policy.register_event(later), Suppress);
    }

    #[test]
    fn test_token_bucket_burst_tolerance() {
        let mut policy = TokenBucketPolicy::new(100.0, 1.0).unwrap();
        let now = Instant::now();
        for i in 0..100 {
            assert_eq!(
                policy.register_event(now),
                Allow,
                "Event {} in burst should be allowed",
                i
            );
        }
        assert_eq!(policy.register_event(now), Suppress);
    }

    /// At 10 tokens/s: one second refills the whole bucket, half a second refills half.
    #[test]
    fn test_token_bucket_sustained_rate() {
        let mut policy = TokenBucketPolicy::new(10.0, 10.0).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 10, Allow);
        assert_eq!(policy.register_event(now), Suppress);

        let later = now + Duration::from_secs(1);
        for i in 0..10 {
            assert_eq!(
                policy.register_event(later),
                Allow,
                "Event {} after 1s should be allowed",
                i
            );
        }
        assert_eq!(policy.register_event(later), Suppress);

        let even_later = later + Duration::from_millis(500);
        for i in 0..5 {
            assert_eq!(
                policy.register_event(even_later),
                Allow,
                "Event {} after 0.5s should be allowed",
                i
            );
        }
        assert_eq!(policy.register_event(even_later), Suppress);
    }

    /// A long quiet period refills the bucket, but never beyond its capacity.
    #[test]
    fn test_token_bucket_recovery_after_quiet() {
        let mut policy = TokenBucketPolicy::new(5.0, 2.0).unwrap();
        let now = Instant::now();
        feed(&mut policy, now, 5);
        assert_eq!(policy.register_event(now), Suppress);

        let much_later = now + Duration::from_secs(10);
        for i in 0..5 {
            assert_eq!(
                policy.register_event(much_later),
                Allow,
                "Event {} after recovery should be allowed",
                i
            );
        }
        assert_eq!(policy.register_event(much_later), Suppress);
    }

    /// At 0.5 tokens/s, three seconds yield 1.5 tokens; one more second tops it up to 1.
    #[test]
    fn test_token_bucket_fractional_refill() {
        let mut policy = TokenBucketPolicy::new(10.0, 0.5).unwrap();
        let now = Instant::now();
        feed(&mut policy, now, 10);
        assert_eq!(policy.register_event(now), Suppress);

        let later = now + Duration::from_secs(3);
        assert_eq!(policy.register_event(later), Allow);
        assert_eq!(policy.register_event(later), Suppress);

        let even_later = later + Duration::from_secs(1);
        assert_eq!(policy.register_event(even_later), Allow);
        assert_eq!(policy.register_event(even_later), Suppress);
    }

    #[test]
    fn test_token_bucket_reset() {
        let mut policy = TokenBucketPolicy::new(5.0, 1.0).unwrap();
        let now = Instant::now();
        feed(&mut policy, now, 5);
        assert_eq!(policy.register_event(now), Suppress);

        policy.reset();
        for i in 0..5 {
            assert_eq!(
                policy.register_event(now),
                Allow,
                "Event {} after reset should be allowed",
                i
            );
        }
        assert_eq!(policy.register_event(now), Suppress);
    }

    #[test]
    fn test_token_bucket_capacity_cap() {
        let mut policy = TokenBucketPolicy::new(5.0, 10.0).unwrap();
        let now = Instant::now();
        feed(&mut policy, now, 3);

        let much_later = now + Duration::from_secs(100);
        for i in 0..5 {
            assert_eq!(
                policy.register_event(much_later),
                Allow,
                "Event {} should be allowed (capped at capacity)",
                i
            );
        }
        assert_eq!(policy.register_event(much_later), Suppress);
    }

    #[test]
    fn test_token_bucket_zero_capacity() {
        assert_eq!(TokenBucketPolicy::new(0.0, 1.0), Err(PolicyError::ZeroCapacity));
    }

    #[test]
    fn test_token_bucket_negative_capacity() {
        assert_eq!(TokenBucketPolicy::new(-5.0, 1.0), Err(PolicyError::ZeroCapacity));
    }

    #[test]
    fn test_token_bucket_zero_refill_rate() {
        assert_eq!(TokenBucketPolicy::new(10.0, 0.0), Err(PolicyError::ZeroRefillRate));
    }

    #[test]
    fn test_token_bucket_negative_refill_rate() {
        assert_eq!(TokenBucketPolicy::new(10.0, -2.0), Err(PolicyError::ZeroRefillRate));
    }

    /// The `Policy` wrapper forwards both events and resets to the token bucket.
    #[test]
    fn test_token_bucket_policy_enum() {
        let mut policy = Policy::token_bucket(5.0, 2.0).unwrap();
        let now = Instant::now();
        for i in 0..5 {
            assert!(
                policy.register_event(now).is_allow(),
                "Event {} should be allowed",
                i
            );
        }
        assert!(policy.register_event(now).is_suppress());
        policy.reset();
        assert!(policy.register_event(now).is_allow());
    }

    /// With capacity 1 at 10 tokens/s, each 100 ms step earns exactly one token.
    #[test]
    fn test_token_bucket_incremental_refill() {
        let mut policy = TokenBucketPolicy::new(1.0, 10.0).unwrap();
        let mut at = Instant::now();
        for step in 0..3 {
            if step > 0 {
                at += Duration::from_millis(100);
            }
            assert_eq!(policy.register_event(at), Allow);
            assert_eq!(policy.register_event(at), Suppress);
        }
    }

    #[test]
    fn test_token_bucket_same_timestamp_multiple_events() {
        let mut policy = TokenBucketPolicy::new(5.0, 2.0).unwrap();
        let start = Instant::now();
        for i in 0..5 {
            assert_eq!(
                policy.register_event(start),
                Allow,
                "Event {} should be allowed",
                i
            );
        }
        for i in 5..8 {
            assert_eq!(
                policy.register_event(start),
                Suppress,
                "Event {} should be suppressed (no tokens)",
                i
            );
        }

        let t1 = start + Duration::from_secs(1);
        assert_eq!(
            policy.register_event(t1),
            Allow,
            "First event after 1s should be allowed"
        );
        assert_eq!(
            policy.register_event(t1),
            Allow,
            "Second event after 1s should be allowed"
        );
        assert_eq!(
            policy.register_event(t1),
            Suppress,
            "Third event after 1s should be suppressed (only 2 tokens refilled)"
        );
        expect_n(&mut policy, t1, 2, Suppress);
    }

    /// A timestamp earlier than the last refill adds nothing but restarts the refill clock.
    #[test]
    fn test_token_bucket_time_goes_backwards() {
        let mut policy = TokenBucketPolicy::new(10.0, 5.0).unwrap();
        let now = Instant::now();
        expect_n(&mut policy, now, 5, Allow);

        let future = now + Duration::from_secs(1);
        expect_n(&mut policy, future, 10, Allow);

        let past = now + Duration::from_millis(500);
        assert!(past < future, "Test setup: past must be before future");
        assert_eq!(
            policy.register_event(past),
            Suppress,
            "Should suppress when no tokens available after time went backwards"
        );

        let future2 = past + Duration::from_secs(1);
        for i in 0..5 {
            assert_eq!(
                policy.register_event(future2),
                Allow,
                "Token {} should be available after normal time progression",
                i
            );
        }
        assert_eq!(policy.register_event(future2), Suppress);
    }

    /// 100 events spread over roughly one second fill a 60 s window until they all expire.
    #[test]
    fn test_time_window_with_many_events() {
        let mut policy = TimeWindowPolicy::new(100, Duration::from_secs(60)).unwrap();
        let now = Instant::now();
        for offset_ms in (0..1000).step_by(10) {
            policy.register_event(now + Duration::from_millis(offset_ms));
        }
        assert_eq!(
            policy.register_event(now + Duration::from_millis(1000)),
            Suppress
        );
        let later = now + Duration::from_secs(70);
        assert_eq!(policy.register_event(later), Allow);
    }
}