netray_common/
rate_limit.rs1use std::num::NonZeroU32;
2
3use governor::RateLimiter;
4use governor::clock::DefaultClock;
5
6pub type KeyedLimiter<K> =
9 RateLimiter<K, governor::state::keyed::DefaultKeyedStateStore<K>, DefaultClock>;
10
/// Describes why a request was refused by one of the rate-limit checks.
///
/// This is plain, cheap-to-copy data, so it derives `Clone`/`Copy` and
/// equality to let callers pass it around and compare it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Suggested number of whole seconds to wait before retrying.
    ///
    /// The check functions in this module never report less than 1.
    pub retry_after_secs: u64,
    /// Which limiter rejected the request: `"global"` for the direct limiter,
    /// or the caller-supplied scope label for keyed limiters.
    pub scope: &'static str,
}
19
20pub fn check_keyed_cost<K: std::hash::Hash + Eq + Clone>(
28 limiter: &KeyedLimiter<K>,
29 key: &K,
30 cost: NonZeroU32,
31 scope: &'static str,
32 metrics_prefix: &'static str,
33) -> Result<(), RateLimitRejection> {
34 let retry_secs = match limiter.check_key_n(key, cost) {
35 Ok(Ok(())) => return Ok(()),
36 Ok(Err(not_until)) => {
37 let wait =
38 not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
39 wait.as_secs()
40 }
41 Err(_) => 60,
43 };
44 let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
45 metrics::counter!(counter_name, "scope" => scope).increment(1);
46 Err(RateLimitRejection {
47 retry_after_secs: retry_secs.max(1),
48 scope,
49 })
50}
51
52pub fn check_direct_cost(
57 limiter: &RateLimiter<governor::state::NotKeyed, governor::state::InMemoryState, DefaultClock>,
58 cost: NonZeroU32,
59 metrics_prefix: &'static str,
60) -> Result<(), RateLimitRejection> {
61 let retry_secs = match limiter.check_n(cost) {
62 Ok(Ok(())) => return Ok(()),
63 Ok(Err(not_until)) => {
64 let wait =
65 not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
66 wait.as_secs()
67 }
68 Err(_) => 60,
69 };
70 let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
71 metrics::counter!(counter_name, "scope" => "global").increment(1);
72 Err(RateLimitRejection {
73 retry_after_secs: retry_secs.max(1),
74 scope: "global",
75 })
76}
77
// Unit tests for the keyed and direct (global) rate-limit helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use governor::Quota;
    use std::net::IpAddr;

    /// Builds a per-IP limiter with a `per_minute` quota and a bucket of
    /// `burst` units.
    fn make_keyed_limiter(per_minute: u32, burst: u32) -> KeyedLimiter<IpAddr> {
        RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
                .allow_burst(NonZeroU32::new(burst).unwrap()),
        )
    }

    /// Builds an un-keyed limiter with a `per_minute` quota and a bucket of
    /// `burst` units.
    fn make_direct_limiter(
        per_minute: u32,
        burst: u32,
    ) -> RateLimiter<governor::state::NotKeyed, governor::state::InMemoryState, DefaultClock> {
        RateLimiter::direct(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
                .allow_burst(NonZeroU32::new(burst).unwrap()),
        )
    }

    #[test]
    fn keyed_allows_within_budget() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        let cost = NonZeroU32::new(5).unwrap();

        assert!(check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_rejects_when_exhausted() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();

        // Spend the entire burst in one go...
        let cost = NonZeroU32::new(10).unwrap();
        assert!(check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").is_ok());

        // ...so even a single extra unit is refused.
        let cost = NonZeroU32::new(1).unwrap();
        let err = check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").unwrap_err();
        assert_eq!(err.scope, "per_ip");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn keyed_independent_keys() {
        let limiter = make_keyed_limiter(30, 10);
        let ip1: IpAddr = "198.51.100.1".parse().unwrap();
        let ip2: IpAddr = "198.51.100.2".parse().unwrap();
        let cost = NonZeroU32::new(10).unwrap();

        assert!(check_keyed_cost(&limiter, &ip1, cost, "per_ip", "test").is_ok());
        assert!(check_keyed_cost(&limiter, &ip2, cost, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_rejection_is_isolated_to_key() {
        let limiter = make_keyed_limiter(30, 10);
        let ip1: IpAddr = "198.51.100.1".parse().unwrap();
        let ip2: IpAddr = "198.51.100.2".parse().unwrap();
        let full = NonZeroU32::new(10).unwrap();
        let one = NonZeroU32::new(1).unwrap();

        // Drain ip1 and confirm it is now limited...
        assert!(check_keyed_cost(&limiter, &ip1, full, "per_ip", "test").is_ok());
        assert!(check_keyed_cost(&limiter, &ip1, one, "per_ip", "test").is_err());
        // ...while ip2 still has its whole burst available.
        assert!(check_keyed_cost(&limiter, &ip2, full, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_insufficient_capacity() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        // One more than the burst size: can never be satisfied.
        let cost = NonZeroU32::new(11).unwrap();

        let err = check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "per_ip");
    }

    #[test]
    fn direct_allows_within_budget() {
        let limiter = make_direct_limiter(500, 50);
        let cost = NonZeroU32::new(10).unwrap();

        assert!(check_direct_cost(&limiter, cost, "test").is_ok());
    }

    #[test]
    fn direct_rejects_when_exhausted() {
        let limiter = make_direct_limiter(500, 50);

        // Spend the entire burst in one go...
        let cost = NonZeroU32::new(50).unwrap();
        assert!(check_direct_cost(&limiter, cost, "test").is_ok());

        // ...so even a single extra unit is refused.
        let cost = NonZeroU32::new(1).unwrap();
        let err = check_direct_cost(&limiter, cost, "test").unwrap_err();
        assert_eq!(err.scope, "global");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn direct_insufficient_capacity() {
        let limiter = make_direct_limiter(500, 50);
        // One more than the burst size: can never be satisfied.
        let cost = NonZeroU32::new(51).unwrap();

        let err = check_direct_cost(&limiter, cost, "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "global");
    }
}