netray_common/
rate_limit.rs1use std::num::NonZeroU32;
2
3use governor::clock::DefaultClock;
4use governor::RateLimiter;
5
6pub type KeyedLimiter<K> =
9 RateLimiter<K, governor::state::keyed::DefaultKeyedStateStore<K>, DefaultClock>;
10
/// Describes a request that a rate limiter refused to admit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Suggested number of whole seconds to wait before retrying.
    ///
    /// The check functions in this module clamp this to at least 1.
    pub retry_after_secs: u64,
    /// Name of the limiter that refused the request (e.g. `"per_ip"` or `"global"`).
    pub scope: &'static str,
}

impl std::fmt::Display for RateLimitRejection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "rate limit exceeded (scope: {}); retry after {}s",
            self.scope, self.retry_after_secs
        )
    }
}

// Lets callers propagate a rejection with `?` into boxed / dynamic error types.
impl std::error::Error for RateLimitRejection {}
19
20pub fn check_keyed_cost<K: std::hash::Hash + Eq + Clone>(
28 limiter: &KeyedLimiter<K>,
29 key: &K,
30 cost: NonZeroU32,
31 scope: &'static str,
32 metrics_prefix: &'static str,
33) -> Result<(), RateLimitRejection> {
34 let retry_secs = match limiter.check_key_n(key, cost) {
35 Ok(Ok(())) => return Ok(()),
36 Ok(Err(not_until)) => {
37 let wait =
38 not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
39 wait.as_secs()
40 }
41 Err(_) => 60,
43 };
44 let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
45 metrics::counter!(counter_name, "scope" => scope).increment(1);
46 Err(RateLimitRejection {
47 retry_after_secs: retry_secs.max(1),
48 scope,
49 })
50}
51
52pub fn check_direct_cost(
57 limiter: &RateLimiter<
58 governor::state::NotKeyed,
59 governor::state::InMemoryState,
60 DefaultClock,
61 >,
62 cost: NonZeroU32,
63 metrics_prefix: &'static str,
64) -> Result<(), RateLimitRejection> {
65 let retry_secs = match limiter.check_n(cost) {
66 Ok(Ok(())) => return Ok(()),
67 Ok(Err(not_until)) => {
68 let wait =
69 not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
70 wait.as_secs()
71 }
72 Err(_) => 60,
73 };
74 let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
75 metrics::counter!(counter_name, "scope" => "global").increment(1);
76 Err(RateLimitRejection {
77 retry_after_secs: retry_secs.max(1),
78 scope: "global",
79 })
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use governor::Quota;
    use std::net::IpAddr;

    /// Concrete type of the non-keyed limiter accepted by `check_direct_cost`.
    type DirectLimiter =
        RateLimiter<governor::state::NotKeyed, governor::state::InMemoryState, DefaultClock>;

    /// Quota that replenishes `per_minute` cells per minute with a burst of `burst`.
    fn quota(per_minute: u32, burst: u32) -> Quota {
        Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
            .allow_burst(NonZeroU32::new(burst).unwrap())
    }

    /// Shorthand for a non-zero request cost.
    fn cost(n: u32) -> NonZeroU32 {
        NonZeroU32::new(n).unwrap()
    }

    fn make_keyed_limiter(per_minute: u32, burst: u32) -> KeyedLimiter<IpAddr> {
        RateLimiter::keyed(quota(per_minute, burst))
    }

    fn make_direct_limiter(per_minute: u32, burst: u32) -> DirectLimiter {
        RateLimiter::direct(quota(per_minute, burst))
    }

    fn test_ip(last_octet: u8) -> IpAddr {
        format!("198.51.100.{last_octet}").parse().unwrap()
    }

    #[test]
    fn keyed_allows_within_budget() {
        let limiter = make_keyed_limiter(30, 10);
        let ip = test_ip(1);

        assert!(check_keyed_cost(&limiter, &ip, cost(5), "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_rejects_when_exhausted() {
        let limiter = make_keyed_limiter(30, 10);
        let ip = test_ip(1);

        assert!(check_keyed_cost(&limiter, &ip, cost(10), "per_ip", "test").is_ok());

        let err = check_keyed_cost(&limiter, &ip, cost(1), "per_ip", "test").unwrap_err();
        assert_eq!(err.scope, "per_ip");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn keyed_independent_keys() {
        let limiter = make_keyed_limiter(30, 10);

        assert!(check_keyed_cost(&limiter, &test_ip(1), cost(10), "per_ip", "test").is_ok());
        assert!(check_keyed_cost(&limiter, &test_ip(2), cost(10), "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_insufficient_capacity() {
        let limiter = make_keyed_limiter(30, 10);
        let ip = test_ip(1);

        let err = check_keyed_cost(&limiter, &ip, cost(11), "per_ip", "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "per_ip");
    }

    #[test]
    fn keyed_oversized_request_does_not_consume_budget() {
        let limiter = make_keyed_limiter(30, 10);
        let ip = test_ip(1);

        assert!(check_keyed_cost(&limiter, &ip, cost(11), "per_ip", "test").is_err());
        // The rejected request must leave the full burst available.
        assert!(check_keyed_cost(&limiter, &ip, cost(10), "per_ip", "test").is_ok());
    }

    #[test]
    fn direct_allows_within_budget() {
        let limiter = make_direct_limiter(500, 50);

        assert!(check_direct_cost(&limiter, cost(10), "test").is_ok());
    }

    #[test]
    fn direct_rejects_when_exhausted() {
        let limiter = make_direct_limiter(500, 50);

        assert!(check_direct_cost(&limiter, cost(50), "test").is_ok());

        let err = check_direct_cost(&limiter, cost(1), "test").unwrap_err();
        assert_eq!(err.scope, "global");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn direct_insufficient_capacity() {
        let limiter = make_direct_limiter(500, 50);

        let err = check_direct_cost(&limiter, cost(51), "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "global");
    }

    #[test]
    fn direct_oversized_request_does_not_consume_budget() {
        let limiter = make_direct_limiter(500, 50);

        assert!(check_direct_cost(&limiter, cost(51), "test").is_err());
        // The rejected request must leave the full burst available.
        assert!(check_direct_cost(&limiter, cost(50), "test").is_ok());
    }
}