speedbump/strategy/
mod.rs

1use crate::strategy::clock::Clock;
2pub(crate) mod clock;
3
4use std::{
5    convert::Infallible,
6    error::Error,
7    fmt::Debug,
8    time::{Duration, SystemTime},
9};
10
11use serde::{Deserialize, Serialize};
12
13use crate::strategy::clock::{DefaultClock, SystemTimeClock};
14
/// Outcome of a single rate-limit check.
///
/// Carries the allow/deny verdict plus optional strategy-specific metadata of
/// type `M`. `Debug` is derived so that values can be logged and so that
/// `Result<LimitResult<M>, E>` supports helpers such as `unwrap_err`, which
/// require the `Ok` type to be `Debug`.
#[derive(Debug, Clone)]
pub struct LimitResult<M> {
    /// `true` when the request may proceed.
    allowed: bool,
    /// Extra information supplied by the strategy, if any.
    metadata: Option<M>,
}

impl<M> LimitResult<M> {
    /// Creates a result that permits the request, with no metadata attached.
    #[must_use]
    pub fn allowed() -> Self {
        Self {
            allowed: true,
            metadata: None,
        }
    }

    /// Creates a result that rejects the request, with no metadata attached.
    #[must_use]
    pub fn disallowed() -> Self {
        Self {
            allowed: false,
            metadata: None,
        }
    }

    /// Returns this result with `metadata` attached, replacing any metadata
    /// that was already present. The verdict is left untouched.
    #[must_use]
    pub fn with_metadata(self, metadata: M) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Returns `true` when the request may proceed.
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        self.allowed
    }

    /// Returns the attached metadata, or `None` when none was set.
    #[must_use]
    pub fn metadata(&self) -> Option<&M> {
        self.metadata.as_ref()
    }
}

/// A [`LimitResult`] whose metadata type is the unit type.
pub type DefaultLimitResult = LimitResult<()>;
55
56pub trait LimitStrategy {
57    type State: Debug;
58    type Error: Error;
59    type Metadata: Debug;
60
61    /// Check if a request should be allowed and update the provided state
62    ///
63    /// # Errors
64    /// Errors when a user-defined error occurs in the strategy.
65    fn check_limit(
66        &self,
67        state: &mut Self::State,
68    ) -> Result<LimitResult<Self::Metadata>, Self::Error>;
69
70    /// Initialize a new rate limit state for a key
71    fn initialize_state(&self) -> Self::State;
72}
73
74pub struct FixedWindow {
75    window_duration: Duration,
76    limit: u32,
77    clock: DefaultClock,
78}
79
80impl FixedWindow {
81    #[must_use]
82    pub fn new(window_duration: Duration, limit: u32) -> Self {
83        #[cfg(not(test))]
84        let clock = SystemTimeClock;
85
86        #[cfg(test)]
87        let clock = std::sync::Arc::new(SystemTimeClock) as DefaultClock;
88
89        Self {
90            window_duration,
91            limit,
92            clock,
93        }
94    }
95
96    #[cfg(test)]
97    pub fn with_clock(self, clock: DefaultClock) -> Self {
98        Self {
99            window_duration: self.window_duration,
100            limit: self.limit,
101            clock,
102        }
103    }
104}
105
106#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
107pub struct FixedWindowCounterState {
108    count: u32,
109    window_start: SystemTime,
110}
111
/// Metadata returned alongside every fixed-window limit decision.
#[derive(Debug, Clone, Copy)]
pub struct FixedWindowMetadata {
    /// Time remaining until the current window ends and a new one can begin.
    pub till_next_window: Duration,
}
116
117impl LimitStrategy for FixedWindow {
118    type State = FixedWindowCounterState;
119    type Error = Infallible;
120    type Metadata = FixedWindowMetadata;
121
122    fn check_limit(
123        &self,
124        state: &mut Self::State,
125    ) -> Result<LimitResult<FixedWindowMetadata>, Self::Error> {
126        let now = self.clock.now();
127        let time_since_start = now
128            .duration_since(state.window_start)
129            .unwrap_or(Duration::MAX);
130
131        let allowed = if time_since_start < self.window_duration {
132            if state.count < self.limit {
133                state.count += 1;
134                LimitResult::allowed()
135            } else {
136                LimitResult::disallowed()
137            }
138        } else {
139            // New window, reset the counter and update the start time
140            state.count = 1;
141            state.window_start = now;
142            LimitResult::allowed()
143        };
144
145        let next_window = state
146            .window_start
147            .checked_add(self.window_duration)
148            .map_or(Duration::MAX, |x| {
149                x.duration_since(now).unwrap_or(Duration::MAX)
150            });
151
152        let result = allowed.with_metadata(FixedWindowMetadata {
153            till_next_window: next_window,
154        });
155
156        Ok(result)
157    }
158
159    fn initialize_state(&self) -> Self::State {
160        FixedWindowCounterState {
161            count: 0,
162            window_start: self.clock.now(),
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::strategy::clock::test::MockClock;

    /// A window admits exactly `LIMIT` requests, rejects the next one, and
    /// admits again once the mock clock has advanced by a full window.
    #[test]
    fn test_fixed_window() {
        const LIMIT: u32 = 3;
        let window = Duration::from_secs(10);

        let clock = Arc::new(MockClock::new());
        let strategy = FixedWindow::new(window, LIMIT).with_clock(clock.clone());
        let mut state = strategy.initialize_state();
        let mut admitted = || strategy.check_limit(&mut state).unwrap().is_allowed();

        // first window: LIMIT requests pass, the following one is rejected
        for _ in 0..LIMIT {
            assert!(admitted());
        }
        assert!(!admitted());

        // moving the clock forward by one full window opens a new window
        clock.advance(window);

        // second window: again LIMIT requests pass, then rejection
        for _ in 0..LIMIT {
            assert!(admitted());
        }
        assert!(!admitted());
    }
}