speedbump/strategy/
mod.rs1use crate::strategy::clock::Clock;
2pub(crate) mod clock;
3
4use std::{
5 convert::Infallible,
6 error::Error,
7 fmt::Debug,
8 time::{Duration, SystemTime},
9};
10
11use serde::{Deserialize, Serialize};
12
13use crate::strategy::clock::{DefaultClock, SystemTimeClock};
14
/// Outcome of a single rate-limit check.
///
/// Records whether the request was admitted and, optionally, strategy-specific
/// metadata of type `M` (for example, how long until the limit resets).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LimitResult<M> {
    allowed: bool,
    metadata: Option<M>,
}

impl<M> LimitResult<M> {
    /// Creates a result that admits the request and carries no metadata.
    #[must_use]
    pub fn allowed() -> Self {
        Self {
            allowed: true,
            metadata: None,
        }
    }

    /// Creates a result that rejects the request and carries no metadata.
    #[must_use]
    pub fn disallowed() -> Self {
        Self {
            allowed: false,
            metadata: None,
        }
    }

    /// Attaches `metadata` to this result, replacing any metadata already present.
    ///
    /// The allowed/disallowed flag is preserved.
    #[must_use]
    pub fn with_metadata(self, metadata: M) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Returns `true` if the request was admitted.
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        self.allowed
    }

    /// Returns the attached metadata, if any.
    #[must_use]
    pub fn metadata(&self) -> Option<&M> {
        self.metadata.as_ref()
    }
}

/// A [`LimitResult`] that carries no meaningful metadata.
pub type DefaultLimitResult = LimitResult<()>;
55
56pub trait LimitStrategy {
57 type State: Debug;
58 type Error: Error;
59 type Metadata: Debug;
60
61 fn check_limit(
66 &self,
67 state: &mut Self::State,
68 ) -> Result<LimitResult<Self::Metadata>, Self::Error>;
69
70 fn initialize_state(&self) -> Self::State;
72}
73
74pub struct FixedWindow {
75 window_duration: Duration,
76 limit: u32,
77 clock: DefaultClock,
78}
79
80impl FixedWindow {
81 #[must_use]
82 pub fn new(window_duration: Duration, limit: u32) -> Self {
83 #[cfg(not(test))]
84 let clock = SystemTimeClock;
85
86 #[cfg(test)]
87 let clock = std::sync::Arc::new(SystemTimeClock) as DefaultClock;
88
89 Self {
90 window_duration,
91 limit,
92 clock,
93 }
94 }
95
96 #[cfg(test)]
97 pub fn with_clock(self, clock: DefaultClock) -> Self {
98 Self {
99 window_duration: self.window_duration,
100 limit: self.limit,
101 clock,
102 }
103 }
104}
105
106#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
107pub struct FixedWindowCounterState {
108 count: u32,
109 window_start: SystemTime,
110}
111
/// Metadata attached to every result produced by [`FixedWindow`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FixedWindowMetadata {
    /// Time remaining until the current window ends and a fresh allowance begins.
    pub till_next_window: Duration,
}
116
117impl LimitStrategy for FixedWindow {
118 type State = FixedWindowCounterState;
119 type Error = Infallible;
120 type Metadata = FixedWindowMetadata;
121
122 fn check_limit(
123 &self,
124 state: &mut Self::State,
125 ) -> Result<LimitResult<FixedWindowMetadata>, Self::Error> {
126 let now = self.clock.now();
127 let time_since_start = now
128 .duration_since(state.window_start)
129 .unwrap_or(Duration::MAX);
130
131 let allowed = if time_since_start < self.window_duration {
132 if state.count < self.limit {
133 state.count += 1;
134 LimitResult::allowed()
135 } else {
136 LimitResult::disallowed()
137 }
138 } else {
139 state.count = 1;
141 state.window_start = now;
142 LimitResult::allowed()
143 };
144
145 let next_window = state
146 .window_start
147 .checked_add(self.window_duration)
148 .map_or(Duration::MAX, |x| {
149 x.duration_since(now).unwrap_or(Duration::MAX)
150 });
151
152 let result = allowed.with_metadata(FixedWindowMetadata {
153 till_next_window: next_window,
154 });
155
156 Ok(result)
157 }
158
159 fn initialize_state(&self) -> Self::State {
160 FixedWindowCounterState {
161 count: 0,
162 window_start: self.clock.now(),
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::strategy::clock::test::MockClock;

    /// A 10-second window with a limit of 3 admits three requests and rejects the
    /// fourth. After the mock clock advances one full window, the same pattern repeats.
    #[test]
    fn test_fixed_window() {
        let clock = Arc::new(MockClock::new());
        let strategy = FixedWindow::new(Duration::from_secs(10), 3).with_clock(Arc::clone(&clock));
        let mut state = strategy.initialize_state();
        let mut check = || strategy.check_limit(&mut state).unwrap().is_allowed();

        for _ in 0..3 {
            assert!(check());
        }
        assert!(!check());

        clock.advance(Duration::from_secs(10));

        for _ in 0..3 {
            assert!(check());
        }
        assert!(!check());
    }
}