Skip to main content

structured_proxy/shield/
rate.rs

//! Rate-limit rate parsing.
//!
//! A rate is a request count over a time window. It is written either as
//! `"<count>/<unit>"` (e.g. `"20/min"`) or as a bare count (e.g. `"20"`), in
//! which case the configured default window applies.
6
7use std::time::Duration;
8
/// A request budget: `limit` requests per `window`.
///
/// Values are usually produced by [`Rate::parse`], but both fields are public
/// so a budget can also be built directly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Rate {
    /// Largest number of requests permitted during one window.
    pub limit: u64,
    /// How long one window lasts.
    pub window: Duration,
}
17
18impl Rate {
19    /// Parse a rate string.
20    ///
21    /// Accepts `"<count>/<unit>"` where unit is one of `s`/`sec`/`second`,
22    /// `m`/`min`/`minute`, `h`/`hour` (and their plurals), or a bare
23    /// `"<count>"` which uses `default_window`.
24    ///
25    /// # Errors
26    /// Returns an error string when the count or unit cannot be parsed.
27    pub fn parse(raw: &str, default_window: Duration) -> Result<Self, String> {
28        let raw = raw.trim();
29        match raw.split_once('/') {
30            None => {
31                let limit = raw
32                    .parse::<u64>()
33                    .map_err(|_| format!("invalid rate count: {raw:?}"))?;
34                Ok(Rate {
35                    limit,
36                    window: default_window,
37                })
38            }
39            Some((count, unit)) => {
40                let limit = count
41                    .trim()
42                    .parse::<u64>()
43                    .map_err(|_| format!("invalid rate count: {count:?}"))?;
44                let window = parse_unit(unit.trim())?;
45                Ok(Rate { limit, window })
46            }
47        }
48    }
49}
50
/// Map a time-unit token (e.g. `"s"`, `"min"`, `"hours"`) to the [`Duration`]
/// it names.
///
/// Matching is ASCII case-insensitive, so `"MIN"` and `"Min"` are accepted as
/// well as `"min"`.
///
/// # Errors
/// Returns `invalid rate unit: "<unit>"` when the token is not a recognised
/// seconds, minutes or hours spelling. The message echoes the token exactly as
/// it was given, not its lowercased form, so it can be matched against the
/// caller's input.
fn parse_unit(unit: &str) -> Result<Duration, String> {
    let secs = match unit.to_ascii_lowercase().as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        // `hrs` mirrors the abbreviated plurals `secs` and `mins` above.
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600,
        _ => return Err(format!("invalid rate unit: {unit:?}")),
    };
    Ok(Duration::from_secs(secs))
}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback window handed to `Rate::parse`; one minute.
    const DEFAULT: Duration = Duration::from_secs(60);

    /// Build the expected `Rate` for `limit` requests per `secs` seconds.
    fn rate(limit: u64, secs: u64) -> Rate {
        Rate {
            limit,
            window: Duration::from_secs(secs),
        }
    }

    #[test]
    fn parse_count_per_unit() {
        assert_eq!(Rate::parse("20/min", DEFAULT).unwrap(), rate(20, 60));
        assert_eq!(Rate::parse("5/s", DEFAULT).unwrap(), rate(5, 1));
        assert_eq!(Rate::parse("100/hour", DEFAULT).unwrap(), rate(100, 3600));
    }

    #[test]
    fn parse_accepts_every_unit_spelling() {
        let cases = [
            ("1/s", 1),
            ("1/sec", 1),
            ("1/secs", 1),
            ("1/second", 1),
            ("1/seconds", 1),
            ("1/m", 60),
            ("1/min", 60),
            ("1/mins", 60),
            ("1/minute", 60),
            ("1/minutes", 60),
            ("1/h", 3600),
            ("1/hr", 3600),
            ("1/hour", 3600),
            ("1/hours", 3600),
        ];
        for (raw, secs) in cases {
            assert_eq!(
                Rate::parse(raw, DEFAULT).unwrap(),
                rate(1, secs),
                "input {raw:?}"
            );
        }
    }

    #[test]
    fn parse_is_case_insensitive_and_trims_whitespace() {
        assert_eq!(Rate::parse("  20 / MIN ", DEFAULT).unwrap(), rate(20, 60));
        assert_eq!(Rate::parse("3/Hour", DEFAULT).unwrap(), rate(3, 3600));
        assert_eq!(Rate::parse(" 42\t", DEFAULT).unwrap(), rate(42, 60));
    }

    #[test]
    fn parse_bare_count_uses_default_window() {
        // Deliberately not one minute, so the default cannot be mistaken for a
        // `min` unit.
        let window = Duration::from_secs(15);
        assert_eq!(
            Rate::parse("42", window).unwrap(),
            Rate { limit: 42, window }
        );
    }

    #[test]
    fn parse_rejects_garbage() {
        let bad = [
            "",                       // empty input
            "abc",                    // non-numeric bare count
            "10/fortnight",           // unknown unit
            "/min",                   // missing count
            "20/",                    // missing unit
            "-5/min",                 // negative count
            "1.5/s",                  // fractional count
            "18446744073709551616/s", // u64::MAX + 1
            "1/2/min",                // extra separator ends up in the unit
        ];
        for raw in bad {
            assert!(
                Rate::parse(raw, DEFAULT).is_err(),
                "expected an error for {raw:?}"
            );
        }
    }
}