structured_proxy/shield/
rate.rs1use std::time::Duration;
8
/// A rate written as a count per time window, e.g. `20/min`.
///
/// Values are cheap to copy and compare field-by-field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Rate {
    /// Count component of the rate (the number before `/` in the textual form).
    pub limit: u64,
    /// Length of the time window that `limit` applies to.
    pub window: Duration,
}
17
18impl Rate {
19 pub fn parse(raw: &str, default_window: Duration) -> Result<Self, String> {
28 let raw = raw.trim();
29 match raw.split_once('/') {
30 None => {
31 let limit = raw
32 .parse::<u64>()
33 .map_err(|_| format!("invalid rate count: {raw:?}"))?;
34 Ok(Rate {
35 limit,
36 window: default_window,
37 })
38 }
39 Some((count, unit)) => {
40 let limit = count
41 .trim()
42 .parse::<u64>()
43 .map_err(|_| format!("invalid rate count: {count:?}"))?;
44 let window = parse_unit(unit.trim())?;
45 Ok(Rate { limit, window })
46 }
47 }
48 }
49}
50
/// Converts a textual time unit into the [`Duration`] of one such unit.
///
/// Matching ignores ASCII case. Accepted spellings:
///
/// * seconds — `s`, `sec`, `secs`, `second`, `seconds`
/// * minutes — `m`, `min`, `mins`, `minute`, `minutes`
/// * hours — `h`, `hr`, `hrs`, `hour`, `hours`
///
/// No whitespace is stripped here; callers should trim `unit` first.
///
/// # Errors
///
/// Returns `invalid rate unit: "<unit>"` when `unit` is not one of the spellings above. The
/// message quotes `unit` exactly as it was passed in.
fn parse_unit(unit: &str) -> Result<Duration, String> {
    let secs = match unit.to_ascii_lowercase().as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        // `hrs` is accepted for parity with the plural short forms `secs` and `mins`.
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600,
        // Report the caller's original text, not the lowercased copy used for matching,
        // so the message shows what was actually written.
        _ => return Err(format!("invalid rate unit: {unit:?}")),
    };
    Ok(Duration::from_secs(secs))
}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback window handed to every `Rate::parse` call in this module.
    const DEFAULT: Duration = Duration::from_secs(60);

    /// Builds the expected `Rate` from a count and a window length in seconds.
    fn rate(limit: u64, secs: u64) -> Rate {
        Rate {
            limit,
            window: Duration::from_secs(secs),
        }
    }

    #[test]
    fn parse_count_per_unit() {
        let cases = [
            ("20/min", rate(20, 60)),
            ("5/s", rate(5, 1)),
            ("100/hour", rate(100, 3600)),
        ];
        for &(input, expected) in &cases {
            assert_eq!(Rate::parse(input, DEFAULT), Ok(expected), "input: {:?}", input);
        }
    }

    #[test]
    fn parse_bare_count_uses_default_window() {
        let parsed = Rate::parse("42", DEFAULT);
        assert_eq!(
            parsed,
            Ok(Rate {
                limit: 42,
                window: DEFAULT
            })
        );
    }

    #[test]
    fn parse_rejects_garbage() {
        for &input in &["abc", "10/fortnight", "/min"] {
            assert!(
                Rate::parse(input, DEFAULT).is_err(),
                "expected an error for {:?}",
                input
            );
        }
    }
}