1use std::collections::VecDeque;
7use std::time::{Duration, Instant};
8
9#[cfg(feature = "redis-storage")]
10use serde::{Deserialize, Serialize};
11
/// Validation failures reported by the policy constructors in this file.
///
/// Each variant names the constructor argument that was rejected; the
/// `Display` text spells out the requirement that was violated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyError {
    /// `CountBasedPolicy::new` was given a `max_count` of zero.
    ZeroMaxCount,
    /// `TimeWindowPolicy::new` was given a `max_events` of zero.
    ZeroMaxEvents,
    /// `TimeWindowPolicy::new` was given a zero-length window.
    ZeroWindowDuration,
    /// `TokenBucketPolicy::new` was given a capacity of zero or less.
    ZeroCapacity,
    /// `TokenBucketPolicy::new` was given a refill rate of zero or less.
    ZeroRefillRate,
}

impl std::fmt::Display for PolicyError {
    /// Writes the human-readable requirement that the rejected argument violated.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ZeroMaxCount => "max_count must be greater than 0",
            Self::ZeroMaxEvents => "max_events must be greater than 0",
            Self::ZeroWindowDuration => "window duration must be greater than 0",
            Self::ZeroCapacity => "capacity must be greater than 0",
            Self::ZeroRefillRate => "refill_rate must be greater than 0",
        };
        f.write_str(message)
    }
}

// No underlying cause to expose, so the default `Error` methods are sufficient.
impl std::error::Error for PolicyError {}
43
/// Verdict returned by a rate-limit policy for a single registered event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolicyDecision {
    /// The event is within the policy's limit.
    Allow,
    /// The event exceeds the policy's limit.
    Suppress,
}
52
53pub trait RateLimitPolicy: Send + Sync {
58 fn register_event(&mut self, timestamp: Instant) -> PolicyDecision;
66
67 fn reset(&mut self);
71}
72
73#[derive(Debug, Clone, PartialEq)]
95#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
96pub struct CountBasedPolicy {
97 max_count: usize,
98 current_count: usize,
99}
100
101impl CountBasedPolicy {
102 pub fn new(max_count: usize) -> Result<Self, PolicyError> {
110 if max_count == 0 {
111 return Err(PolicyError::ZeroMaxCount);
112 }
113 Ok(Self {
114 max_count,
115 current_count: 0,
116 })
117 }
118}
119
120impl RateLimitPolicy for CountBasedPolicy {
121 fn register_event(&mut self, _timestamp: Instant) -> PolicyDecision {
122 self.current_count += 1;
123 if self.current_count <= self.max_count {
124 PolicyDecision::Allow
125 } else {
126 PolicyDecision::Suppress
127 }
128 }
129
130 fn reset(&mut self) {
131 self.current_count = 0;
132 }
133}
134
/// Policy that allows at most `max_events` events within a sliding time window.
#[derive(Debug, Clone, PartialEq)]
pub struct TimeWindowPolicy {
    /// Maximum number of allowed events retained in the window at once.
    max_events: usize,
    /// Length of the sliding window.
    window_duration: Duration,
    /// Timestamps of the allowed events still inside the window, in the
    /// order they were registered (oldest at the front).
    event_timestamps: VecDeque<Instant>,
}
166
/// Serializes the policy as a four-field struct.
///
/// `Instant` has no portable representation, so the recorded timestamps are
/// written as nanosecond offsets from the front (oldest) entry of the queue,
/// each capped at `u64::MAX`.
#[cfg(feature = "redis-storage")]
impl Serialize for TimeWindowPolicy {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // The front entry is the reference point; with no events there is
        // nothing to encode.
        let base = self.event_timestamps.front().copied();
        let timestamps_nanos: Vec<u64> = match base {
            Some(base_instant) => self
                .event_timestamps
                .iter()
                .map(|instant| {
                    // Entries earlier than the base saturate to a zero offset.
                    let offset = instant.saturating_duration_since(base_instant).as_nanos();
                    offset.min(u64::MAX as u128) as u64
                })
                .collect(),
            None => Vec::new(),
        };

        let mut state = serializer.serialize_struct("TimeWindowPolicy", 4)?;
        state.serialize_field("max_events", &self.max_events)?;
        state.serialize_field("window_duration_nanos", &self.window_duration.as_nanos())?;
        state.serialize_field("timestamps_nanos", &timestamps_nanos)?;
        // Placeholder: `Some(0)` when at least one event exists, `None` otherwise.
        state.serialize_field("base_timestamp_nanos", &base.map(|_| 0u64))?;
        state.end()
    }
}
223
/// Rebuilds a [`TimeWindowPolicy`] from the struct written by its `Serialize` impl.
///
/// The stored timestamps are offsets, not absolute times, so each one is
/// re-anchored to `Instant::now()` taken during deserialization: an offset of
/// zero lands on "now" and larger offsets land after it.
///
/// Only map-shaped input is handled (no `visit_seq` is provided), and the
/// constructor invariants (`max_events > 0`, non-zero window) are not
/// re-validated here.
#[cfg(feature = "redis-storage")]
impl<'de> Deserialize<'de> for TimeWindowPolicy {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            MaxEvents,
            WindowDurationNanos,
            TimestampsNanos,
            BaseTimestampNanos,
        }

        struct TimeWindowPolicyVisitor;

        impl<'de> Visitor<'de> for TimeWindowPolicyVisitor {
            type Value = TimeWindowPolicy;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                formatter.write_str("struct TimeWindowPolicy")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TimeWindowPolicy, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut max_events: Option<usize> = None;
                let mut window_duration_nanos: Option<u128> = None;
                let mut timestamps_nanos: Option<Vec<u64>> = None;
                // Read only to consume the value and detect duplicates; the
                // payload itself is a placeholder and is not used.
                let mut base_timestamp_nanos: Option<Option<u64>> = None;

                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::MaxEvents => {
                            if max_events.is_some() {
                                return Err(de::Error::duplicate_field("max_events"));
                            }
                            max_events = Some(map.next_value()?);
                        }
                        Field::WindowDurationNanos => {
                            if window_duration_nanos.is_some() {
                                return Err(de::Error::duplicate_field("window_duration_nanos"));
                            }
                            window_duration_nanos = Some(map.next_value()?);
                        }
                        Field::TimestampsNanos => {
                            if timestamps_nanos.is_some() {
                                return Err(de::Error::duplicate_field("timestamps_nanos"));
                            }
                            timestamps_nanos = Some(map.next_value()?);
                        }
                        Field::BaseTimestampNanos => {
                            // Same duplicate handling as every other field.
                            if base_timestamp_nanos.is_some() {
                                return Err(de::Error::duplicate_field("base_timestamp_nanos"));
                            }
                            base_timestamp_nanos = Some(map.next_value()?);
                        }
                    }
                }

                let max_events =
                    max_events.ok_or_else(|| de::Error::missing_field("max_events"))?;
                let window_duration_nanos = window_duration_nanos
                    .ok_or_else(|| de::Error::missing_field("window_duration_nanos"))?;
                let timestamps_nanos = timestamps_nanos
                    .ok_or_else(|| de::Error::missing_field("timestamps_nanos"))?;

                // Clamp before narrowing: a bare `as u64` would silently keep
                // only the low 64 bits of an oversized value and shrink the
                // window.
                let window_nanos = window_duration_nanos.min(u128::from(u64::MAX)) as u64;

                // Re-anchor every stored offset to the current instant; an
                // offset too large to represent falls back to `now`.
                let now = Instant::now();
                let event_timestamps: VecDeque<Instant> = timestamps_nanos
                    .into_iter()
                    .map(|nanos| now.checked_add(Duration::from_nanos(nanos)).unwrap_or(now))
                    .collect();

                Ok(TimeWindowPolicy {
                    max_events,
                    window_duration: Duration::from_nanos(window_nanos),
                    event_timestamps,
                })
            }
        }

        const FIELDS: &[&str] = &[
            "max_events",
            "window_duration_nanos",
            "timestamps_nanos",
            "base_timestamp_nanos",
        ];
        deserializer.deserialize_struct("TimeWindowPolicy", FIELDS, TimeWindowPolicyVisitor)
    }
}
321
322impl TimeWindowPolicy {
323 pub fn new(max_events: usize, window_duration: Duration) -> Result<Self, PolicyError> {
333 if max_events == 0 {
334 return Err(PolicyError::ZeroMaxEvents);
335 }
336 if window_duration.is_zero() {
337 return Err(PolicyError::ZeroWindowDuration);
338 }
339 Ok(Self {
340 max_events,
341 window_duration,
342 event_timestamps: VecDeque::new(),
343 })
344 }
345
346 fn expire_old_events(&mut self, current_time: Instant) {
348 while let Some(&oldest) = self.event_timestamps.front() {
349 if current_time.saturating_duration_since(oldest) > self.window_duration {
350 self.event_timestamps.pop_front();
351 } else {
352 break;
353 }
354 }
355 }
356}
357
358impl RateLimitPolicy for TimeWindowPolicy {
359 fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
360 self.expire_old_events(timestamp);
361
362 if self.event_timestamps.len() < self.max_events {
363 self.event_timestamps.push_back(timestamp);
364 PolicyDecision::Allow
365 } else {
366 PolicyDecision::Suppress
367 }
368 }
369
370 fn reset(&mut self) {
371 self.event_timestamps.clear();
372 }
373}
374
/// Policy that allows events at exponentially growing positions: the 1st,
/// 2nd, 4th, 8th, ... event since construction or the last reset.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
pub struct ExponentialBackoffPolicy {
    /// Events registered since construction or the last reset.
    event_count: u64,
    /// Value of `event_count` at which the next event will be allowed.
    next_allowed: u64,
}

impl ExponentialBackoffPolicy {
    /// Creates a policy whose very first event will be allowed.
    pub fn new() -> Self {
        Self {
            next_allowed: 1,
            event_count: 0,
        }
    }
}

impl Default for ExponentialBackoffPolicy {
    /// Same as [`ExponentialBackoffPolicy::new`].
    fn default() -> Self {
        ExponentialBackoffPolicy::new()
    }
}
415
416impl RateLimitPolicy for ExponentialBackoffPolicy {
417 fn register_event(&mut self, _timestamp: Instant) -> PolicyDecision {
418 self.event_count += 1;
419
420 if self.event_count == self.next_allowed {
421 self.next_allowed = self.next_allowed.saturating_mul(2);
422 PolicyDecision::Allow
423 } else {
424 PolicyDecision::Suppress
425 }
426 }
427
428 fn reset(&mut self) {
429 self.event_count = 0;
430 self.next_allowed = 1;
431 }
432}
433
/// Token-bucket policy: each allowed event spends one token, and tokens are
/// replenished over time up to a fixed capacity.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenBucketPolicy {
    /// Upper bound on the number of stored tokens.
    capacity: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Tokens currently available; may be fractional.
    tokens: f64,
    /// Timestamp of the most recent refill, or `None` before the first
    /// event and after a reset.
    last_refill: Option<Instant>,
}
482
/// Serializes the bucket as a four-field struct.
///
/// The `Instant` in `last_refill` is not written; only whether one was set
/// is recorded, as `has_last_refill`.
#[cfg(feature = "redis-storage")]
impl Serialize for TokenBucketPolicy {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let has_last_refill = self.last_refill.is_some();

        let mut out = serializer.serialize_struct("TokenBucketPolicy", 4)?;
        out.serialize_field("capacity", &self.capacity)?;
        out.serialize_field("refill_rate", &self.refill_rate)?;
        out.serialize_field("tokens", &self.tokens)?;
        out.serialize_field("has_last_refill", &has_last_refill)?;
        out.end()
    }
}
521
/// Rebuilds a [`TokenBucketPolicy`] from the struct written by its `Serialize` impl.
///
/// `last_refill` is always restored as `None`, so the first event after
/// deserialization only records its timestamp and adds no tokens. The
/// `has_last_refill` field is optional and its value is not used.
///
/// Only map-shaped input is handled (no `visit_seq` is provided), and the
/// numeric fields are taken as-is without re-running the constructor's
/// validation.
#[cfg(feature = "redis-storage")]
impl<'de> Deserialize<'de> for TokenBucketPolicy {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            Capacity,
            RefillRate,
            Tokens,
            HasLastRefill,
        }

        struct TokenBucketPolicyVisitor;

        impl<'de> Visitor<'de> for TokenBucketPolicyVisitor {
            type Value = TokenBucketPolicy;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                formatter.write_str("struct TokenBucketPolicy")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TokenBucketPolicy, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut capacity: Option<f64> = None;
                let mut refill_rate: Option<f64> = None;
                let mut tokens: Option<f64> = None;
                // Read only to consume the value and detect duplicates.
                let mut has_last_refill: Option<bool> = None;

                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::Capacity => {
                            if capacity.is_some() {
                                return Err(de::Error::duplicate_field("capacity"));
                            }
                            capacity = Some(map.next_value()?);
                        }
                        Field::RefillRate => {
                            if refill_rate.is_some() {
                                return Err(de::Error::duplicate_field("refill_rate"));
                            }
                            refill_rate = Some(map.next_value()?);
                        }
                        Field::Tokens => {
                            if tokens.is_some() {
                                return Err(de::Error::duplicate_field("tokens"));
                            }
                            tokens = Some(map.next_value()?);
                        }
                        Field::HasLastRefill => {
                            // Same duplicate handling as every other field.
                            if has_last_refill.is_some() {
                                return Err(de::Error::duplicate_field("has_last_refill"));
                            }
                            has_last_refill = Some(map.next_value()?);
                        }
                    }
                }

                let capacity = capacity.ok_or_else(|| de::Error::missing_field("capacity"))?;
                let refill_rate =
                    refill_rate.ok_or_else(|| de::Error::missing_field("refill_rate"))?;
                let tokens = tokens.ok_or_else(|| de::Error::missing_field("tokens"))?;

                Ok(TokenBucketPolicy {
                    capacity,
                    refill_rate,
                    tokens,
                    // The original `Instant` cannot be recovered; refilling
                    // restarts from the first event registered after loading.
                    last_refill: None,
                })
            }
        }

        const FIELDS: &[&str] = &["capacity", "refill_rate", "tokens", "has_last_refill"];
        deserializer.deserialize_struct("TokenBucketPolicy", FIELDS, TokenBucketPolicyVisitor)
    }
}
604
605impl TokenBucketPolicy {
606 pub fn new(capacity: f64, refill_rate: f64) -> Result<Self, PolicyError> {
624 if capacity <= 0.0 {
625 return Err(PolicyError::ZeroCapacity);
626 }
627 if refill_rate <= 0.0 {
628 return Err(PolicyError::ZeroRefillRate);
629 }
630
631 Ok(Self {
632 capacity,
633 refill_rate,
634 tokens: capacity,
635 last_refill: None,
636 })
637 }
638
639 fn refill(&mut self, now: Instant) {
644 if let Some(last) = self.last_refill {
645 if now < last {
647 self.last_refill = Some(now);
648 return;
649 }
650
651 let elapsed = now.duration_since(last).as_secs_f64();
652 let new_tokens = elapsed * self.refill_rate;
653 self.tokens = (self.tokens + new_tokens).min(self.capacity);
654 }
655 self.last_refill = Some(now);
656 }
657}
658
659impl RateLimitPolicy for TokenBucketPolicy {
660 fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
661 self.refill(timestamp);
663
664 if self.tokens >= 1.0 {
666 self.tokens -= 1.0;
667 PolicyDecision::Allow
668 } else {
669 PolicyDecision::Suppress
670 }
671 }
672
673 fn reset(&mut self) {
674 self.tokens = self.capacity;
675 self.last_refill = None;
676 }
677}
678
679#[derive(Debug, Clone)]
681#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
682pub enum Policy {
683 CountBased(CountBasedPolicy),
685 TimeWindow(TimeWindowPolicy),
687 ExponentialBackoff(ExponentialBackoffPolicy),
689 TokenBucket(TokenBucketPolicy),
691}
692
693impl Policy {
694 pub fn count_based(max_count: usize) -> Result<Self, PolicyError> {
699 Ok(Policy::CountBased(CountBasedPolicy::new(max_count)?))
700 }
701
702 pub fn time_window(max_events: usize, window: Duration) -> Result<Self, PolicyError> {
708 Ok(Policy::TimeWindow(TimeWindowPolicy::new(
709 max_events, window,
710 )?))
711 }
712
713 pub fn exponential_backoff() -> Self {
717 Policy::ExponentialBackoff(ExponentialBackoffPolicy::new())
718 }
719
720 pub fn token_bucket(capacity: f64, refill_rate: f64) -> Result<Self, PolicyError> {
738 Ok(Policy::TokenBucket(TokenBucketPolicy::new(
739 capacity,
740 refill_rate,
741 )?))
742 }
743}
744
745impl RateLimitPolicy for Policy {
746 fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
747 match self {
748 Policy::CountBased(p) => p.register_event(timestamp),
749 Policy::TimeWindow(p) => p.register_event(timestamp),
750 Policy::ExponentialBackoff(p) => p.register_event(timestamp),
751 Policy::TokenBucket(p) => p.register_event(timestamp),
752 }
753 }
754
755 fn reset(&mut self) {
756 match self {
757 Policy::CountBased(p) => p.reset(),
758 Policy::TimeWindow(p) => p.reset(),
759 Policy::ExponentialBackoff(p) => p.reset(),
760 Policy::TokenBucket(p) => p.reset(),
761 }
762 }
763}
764
765impl PolicyDecision {
766 pub fn is_allow(&self) -> bool {
768 matches!(self, PolicyDecision::Allow)
769 }
770
771 pub fn is_suppress(&self) -> bool {
773 matches!(self, PolicyDecision::Suppress)
774 }
775}
776
#[cfg(test)]
mod tests {
    use super::PolicyDecision::{Allow, Suppress};
    use super::*;

    /// Registers one event at `at` per entry of `expected` and checks every
    /// decision in order.
    fn expect_sequence(
        policy: &mut impl RateLimitPolicy,
        at: Instant,
        expected: &[PolicyDecision],
    ) {
        for &want in expected {
            assert_eq!(policy.register_event(at), want);
        }
    }

    /// Registers `count` events at `at`, ignoring the decisions.
    fn burn(policy: &mut impl RateLimitPolicy, at: Instant, count: usize) {
        for _ in 0..count {
            policy.register_event(at);
        }
    }

    // ---- CountBasedPolicy -------------------------------------------------

    #[test]
    fn test_count_based_policy() {
        let mut limiter = CountBasedPolicy::new(3).unwrap();
        let at = Instant::now();

        expect_sequence(&mut limiter, at, &[Allow, Allow, Allow, Suppress, Suppress]);

        limiter.reset();
        expect_sequence(&mut limiter, at, &[Allow]);
    }

    #[test]
    fn test_count_based_policy_zero_limit() {
        assert_eq!(CountBasedPolicy::new(0), Err(PolicyError::ZeroMaxCount));
    }

    #[test]
    fn test_count_based_policy_one_limit() {
        let mut limiter = CountBasedPolicy::new(1).unwrap();
        expect_sequence(&mut limiter, Instant::now(), &[Allow, Suppress, Suppress]);
    }

    #[test]
    fn test_count_based_policy_reset() {
        let mut limiter = CountBasedPolicy::new(2).unwrap();
        let at = Instant::now();

        expect_sequence(&mut limiter, at, &[Allow, Allow, Suppress]);

        // A reset makes the full budget available again.
        limiter.reset();
        expect_sequence(&mut limiter, at, &[Allow, Allow, Suppress]);
    }

    // ---- TimeWindowPolicy -------------------------------------------------

    #[test]
    fn test_time_window_policy() {
        let mut limiter = TimeWindowPolicy::new(2, Duration::from_secs(1)).unwrap();
        let start = Instant::now();

        expect_sequence(&mut limiter, start, &[Allow, Allow, Suppress]);

        // Two seconds on, both recorded events have left the one-second window.
        let after_window = start + Duration::from_secs(2);
        expect_sequence(&mut limiter, after_window, &[Allow]);
    }

    #[test]
    fn test_time_window_policy_zero_duration() {
        assert_eq!(
            TimeWindowPolicy::new(2, Duration::from_secs(0)),
            Err(PolicyError::ZeroWindowDuration)
        );
    }

    #[test]
    fn test_time_window_policy_rapid_events() {
        let mut limiter = TimeWindowPolicy::new(3, Duration::from_millis(100)).unwrap();
        let at = Instant::now();

        for i in 0..10 {
            let decision = limiter.register_event(at);
            if i < 3 {
                assert_eq!(decision, Allow, "Event {} should be allowed", i);
            } else {
                assert_eq!(decision, Suppress, "Event {} should be suppressed", i);
            }
        }
    }

    #[test]
    fn test_time_window_policy_reset() {
        let mut limiter = TimeWindowPolicy::new(2, Duration::from_secs(60)).unwrap();
        let at = Instant::now();

        expect_sequence(&mut limiter, at, &[Allow, Allow, Suppress]);

        limiter.reset();
        expect_sequence(&mut limiter, at, &[Allow]);
    }

    #[test]
    fn test_time_window_with_many_events() {
        let mut limiter = TimeWindowPolicy::new(100, Duration::from_secs(60)).unwrap();
        let start = Instant::now();

        // Fill the window with events spaced 10 ms apart.
        for step in 0..100u64 {
            limiter.register_event(start + Duration::from_millis(step * 10));
        }

        assert_eq!(
            limiter.register_event(start + Duration::from_millis(1000)),
            Suppress
        );

        // Seventy seconds in, every recorded event is older than the window.
        let later = start + Duration::from_secs(70);
        assert_eq!(limiter.register_event(later), Allow);
    }

    // ---- ExponentialBackoffPolicy -----------------------------------------

    #[test]
    fn test_exponential_backoff_policy() {
        let mut limiter = ExponentialBackoffPolicy::new();

        // Events 1, 2, 4 and 8 pass; everything in between is suppressed.
        expect_sequence(
            &mut limiter,
            Instant::now(),
            &[
                Allow, Allow, Suppress, Allow, Suppress, Suppress, Suppress, Allow,
            ],
        );
    }

    #[test]
    fn test_exponential_backoff_large_count() {
        let mut limiter = ExponentialBackoffPolicy::new();
        let at = Instant::now();

        // Zero-based indices of the events expected to pass (1st, 2nd, 4th, ...).
        let expected_allowed = [0, 1, 3, 7, 15, 31, 63];

        for i in 0..100 {
            let decision = limiter.register_event(at);
            if expected_allowed.contains(&i) {
                assert_eq!(decision, Allow, "Event {} should be allowed", i + 1);
            } else {
                assert_eq!(decision, Suppress, "Event {} should be suppressed", i + 1);
            }
        }
    }

    #[test]
    fn test_exponential_backoff_reset() {
        let mut limiter = ExponentialBackoffPolicy::new();
        let at = Instant::now();

        expect_sequence(&mut limiter, at, &[Allow, Allow, Suppress]);

        limiter.reset();
        expect_sequence(&mut limiter, at, &[Allow]);
    }

    // ---- Policy enum ------------------------------------------------------

    #[test]
    fn test_policy_enum() {
        let mut limiter = Policy::count_based(2).unwrap();
        let at = Instant::now();

        assert!(limiter.register_event(at).is_allow());
        assert!(limiter.register_event(at).is_allow());
        assert!(limiter.register_event(at).is_suppress());
    }

    #[test]
    fn test_token_bucket_policy_enum() {
        let mut limiter = Policy::token_bucket(5.0, 2.0).unwrap();
        let at = Instant::now();

        for i in 0..5 {
            assert!(
                limiter.register_event(at).is_allow(),
                "Event {} should be allowed",
                i
            );
        }
        assert!(limiter.register_event(at).is_suppress());

        limiter.reset();
        assert!(limiter.register_event(at).is_allow());
    }

    // ---- TokenBucketPolicy ------------------------------------------------

    #[test]
    fn test_token_bucket_basic_consumption() {
        let mut bucket = TokenBucketPolicy::new(3.0, 1.0).unwrap();
        expect_sequence(
            &mut bucket,
            Instant::now(),
            &[Allow, Allow, Allow, Suppress, Suppress],
        );
    }

    #[test]
    fn test_token_bucket_refill_over_time() {
        // Ten tokens of capacity, ten tokens per second.
        let mut bucket = TokenBucketPolicy::new(10.0, 10.0).unwrap();
        let start = Instant::now();

        expect_sequence(&mut bucket, start, &[Allow; 10]);
        expect_sequence(&mut bucket, start, &[Suppress]);

        // Half a second refills five tokens.
        let later = start + Duration::from_millis(500);
        for i in 0..5 {
            assert_eq!(
                bucket.register_event(later),
                Allow,
                "Event {} should be allowed after refill",
                i
            );
        }
        assert_eq!(bucket.register_event(later), Suppress);
    }

    #[test]
    fn test_token_bucket_burst_tolerance() {
        let mut bucket = TokenBucketPolicy::new(100.0, 1.0).unwrap();
        let at = Instant::now();

        for i in 0..100 {
            assert_eq!(
                bucket.register_event(at),
                Allow,
                "Event {} in burst should be allowed",
                i
            );
        }
        assert_eq!(bucket.register_event(at), Suppress);
    }

    #[test]
    fn test_token_bucket_sustained_rate() {
        // Ten tokens of capacity, ten tokens per second.
        let mut bucket = TokenBucketPolicy::new(10.0, 10.0).unwrap();
        let start = Instant::now();

        expect_sequence(&mut bucket, start, &[Allow; 10]);
        expect_sequence(&mut bucket, start, &[Suppress]);

        let later = start + Duration::from_secs(1);
        for i in 0..10 {
            assert_eq!(
                bucket.register_event(later),
                Allow,
                "Event {} after 1s should be allowed",
                i
            );
        }
        assert_eq!(bucket.register_event(later), Suppress);

        let even_later = later + Duration::from_millis(500);
        for i in 0..5 {
            assert_eq!(
                bucket.register_event(even_later),
                Allow,
                "Event {} after 0.5s should be allowed",
                i
            );
        }
        assert_eq!(bucket.register_event(even_later), Suppress);
    }

    #[test]
    fn test_token_bucket_recovery_after_quiet() {
        let mut bucket = TokenBucketPolicy::new(5.0, 2.0).unwrap();
        let start = Instant::now();

        burn(&mut bucket, start, 5);
        assert_eq!(bucket.register_event(start), Suppress);

        // Ten quiet seconds would earn twenty tokens, but capacity caps it at five.
        let much_later = start + Duration::from_secs(10);
        for i in 0..5 {
            assert_eq!(
                bucket.register_event(much_later),
                Allow,
                "Event {} after recovery should be allowed",
                i
            );
        }
        assert_eq!(bucket.register_event(much_later), Suppress);
    }

    #[test]
    fn test_token_bucket_fractional_refill() {
        // Half a token per second.
        let mut bucket = TokenBucketPolicy::new(10.0, 0.5).unwrap();
        let start = Instant::now();

        burn(&mut bucket, start, 10);
        assert_eq!(bucket.register_event(start), Suppress);

        // Three seconds earn 1.5 tokens: one event passes, 0.5 is left over.
        let later = start + Duration::from_secs(3);
        expect_sequence(&mut bucket, later, &[Allow, Suppress]);

        // One more second tops the leftover up to exactly one token.
        let even_later = later + Duration::from_secs(1);
        expect_sequence(&mut bucket, even_later, &[Allow, Suppress]);
    }

    #[test]
    fn test_token_bucket_reset() {
        let mut bucket = TokenBucketPolicy::new(5.0, 1.0).unwrap();
        let at = Instant::now();

        burn(&mut bucket, at, 5);
        assert_eq!(bucket.register_event(at), Suppress);

        bucket.reset();
        for i in 0..5 {
            assert_eq!(
                bucket.register_event(at),
                Allow,
                "Event {} after reset should be allowed",
                i
            );
        }
        assert_eq!(bucket.register_event(at), Suppress);
    }

    #[test]
    fn test_token_bucket_capacity_cap() {
        let mut bucket = TokenBucketPolicy::new(5.0, 10.0).unwrap();
        let start = Instant::now();

        burn(&mut bucket, start, 3);

        let much_later = start + Duration::from_secs(100);
        for i in 0..5 {
            assert_eq!(
                bucket.register_event(much_later),
                Allow,
                "Event {} should be allowed (capped at capacity)",
                i
            );
        }
        assert_eq!(bucket.register_event(much_later), Suppress);
    }

    #[test]
    fn test_token_bucket_zero_capacity() {
        assert_eq!(TokenBucketPolicy::new(0.0, 1.0), Err(PolicyError::ZeroCapacity));
    }

    #[test]
    fn test_token_bucket_negative_capacity() {
        assert_eq!(TokenBucketPolicy::new(-5.0, 1.0), Err(PolicyError::ZeroCapacity));
    }

    #[test]
    fn test_token_bucket_zero_refill_rate() {
        assert_eq!(TokenBucketPolicy::new(10.0, 0.0), Err(PolicyError::ZeroRefillRate));
    }

    #[test]
    fn test_token_bucket_negative_refill_rate() {
        assert_eq!(TokenBucketPolicy::new(10.0, -2.0), Err(PolicyError::ZeroRefillRate));
    }

    #[test]
    fn test_token_bucket_incremental_refill() {
        // One token of capacity, refilled at ten tokens per second.
        let mut bucket = TokenBucketPolicy::new(1.0, 10.0).unwrap();
        let start = Instant::now();

        expect_sequence(&mut bucket, start, &[Allow, Suppress]);

        let t1 = start + Duration::from_millis(100);
        expect_sequence(&mut bucket, t1, &[Allow, Suppress]);

        let t2 = t1 + Duration::from_millis(100);
        expect_sequence(&mut bucket, t2, &[Allow, Suppress]);
    }

    #[test]
    fn test_token_bucket_same_timestamp_multiple_events() {
        let mut bucket = TokenBucketPolicy::new(5.0, 2.0).unwrap();
        let start = Instant::now();

        for i in 0..5 {
            assert_eq!(
                bucket.register_event(start),
                Allow,
                "Event {} should be allowed",
                i
            );
        }

        for i in 5..8 {
            assert_eq!(
                bucket.register_event(start),
                Suppress,
                "Event {} should be suppressed (no tokens)",
                i
            );
        }

        let t1 = start + Duration::from_secs(1);

        assert_eq!(
            bucket.register_event(t1),
            Allow,
            "First event after 1s should be allowed"
        );
        assert_eq!(
            bucket.register_event(t1),
            Allow,
            "Second event after 1s should be allowed"
        );
        assert_eq!(
            bucket.register_event(t1),
            Suppress,
            "Third event after 1s should be suppressed (only 2 tokens refilled)"
        );

        expect_sequence(&mut bucket, t1, &[Suppress, Suppress]);
    }

    #[test]
    fn test_token_bucket_time_goes_backwards() {
        let mut bucket = TokenBucketPolicy::new(10.0, 5.0).unwrap();
        let start = Instant::now();

        expect_sequence(&mut bucket, start, &[Allow; 5]);

        // One second later the bucket is full again; drain it completely.
        let future = start + Duration::from_secs(1);
        expect_sequence(&mut bucket, future, &[Allow; 10]);

        let past = start + Duration::from_millis(500);
        assert!(past < future, "Test setup: past must be before future");

        assert_eq!(
            bucket.register_event(past),
            Suppress,
            "Should suppress when no tokens available after time went backwards"
        );

        let future2 = past + Duration::from_secs(1);
        for i in 0..5 {
            assert_eq!(
                bucket.register_event(future2),
                Allow,
                "Token {} should be available after normal time progression",
                i
            );
        }

        assert_eq!(bucket.register_event(future2), Suppress);
    }
}