Skip to main content

netray_common/
rate_limit.rs

1use std::num::NonZeroU32;
2
3use governor::clock::DefaultClock;
4use governor::RateLimiter;
5
6/// Keyed rate limiter type alias wrapping governor's `RateLimiter` with
7/// `DefaultKeyedStateStore` and `DefaultClock`.
8pub type KeyedLimiter<K> =
9    RateLimiter<K, governor::state::keyed::DefaultKeyedStateStore<K>, DefaultClock>;
10
/// Result of a failed rate limit check.
///
/// Returned by [`check_keyed_cost`] and [`check_direct_cost`] when a request does
/// not fit in the limiter's budget. Both fields are plain `Copy` data, so the
/// rejection can be freely copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Seconds the client should wait before retrying.
    ///
    /// The check functions in this module always report at least 1.
    pub retry_after_secs: u64,
    /// Which limiter scope rejected the request (e.g. "per_ip", "per_target", "global").
    pub scope: &'static str,
}
19
20/// Check a keyed rate limiter with the given cost.
21///
22/// Returns `Ok(())` if the request is within budget, or `Err(RateLimitRejection)`
23/// with the appropriate retry-after duration and scope.
24///
25/// The `metrics_prefix` is used to increment a counter named
26/// `{metrics_prefix}_rate_limit_hits_total` with a `scope` label on rejection.
27pub fn check_keyed_cost<K: std::hash::Hash + Eq + Clone>(
28    limiter: &KeyedLimiter<K>,
29    key: &K,
30    cost: NonZeroU32,
31    scope: &'static str,
32    metrics_prefix: &'static str,
33) -> Result<(), RateLimitRejection> {
34    let retry_secs = match limiter.check_key_n(key, cost) {
35        Ok(Ok(())) => return Ok(()),
36        Ok(Err(not_until)) => {
37            let wait =
38                not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
39            wait.as_secs()
40        }
41        // InsufficientCapacity: cost exceeds burst size entirely.
42        Err(_) => 60,
43    };
44    let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
45    metrics::counter!(counter_name, "scope" => scope).increment(1);
46    Err(RateLimitRejection {
47        retry_after_secs: retry_secs.max(1),
48        scope,
49    })
50}
51
52/// Check a direct (unkeyed/global) rate limiter with the given cost.
53///
54/// Returns `Ok(())` if the request is within budget, or `Err(RateLimitRejection)`
55/// with scope `"global"`.
56pub fn check_direct_cost(
57    limiter: &RateLimiter<
58        governor::state::NotKeyed,
59        governor::state::InMemoryState,
60        DefaultClock,
61    >,
62    cost: NonZeroU32,
63    metrics_prefix: &'static str,
64) -> Result<(), RateLimitRejection> {
65    let retry_secs = match limiter.check_n(cost) {
66        Ok(Ok(())) => return Ok(()),
67        Ok(Err(not_until)) => {
68            let wait =
69                not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
70            wait.as_secs()
71        }
72        Err(_) => 60,
73    };
74    let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
75    metrics::counter!(counter_name, "scope" => "global").increment(1);
76    Err(RateLimitRejection {
77        retry_after_secs: retry_secs.max(1),
78        scope: "global",
79    })
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use governor::Quota;
    use std::net::IpAddr;

    /// Unkeyed limiter type accepted by `check_direct_cost`.
    type DirectLimiter =
        RateLimiter<governor::state::NotKeyed, governor::state::InMemoryState, DefaultClock>;

    fn make_keyed_limiter(per_minute: u32, burst: u32) -> KeyedLimiter<IpAddr> {
        RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
                .allow_burst(NonZeroU32::new(burst).unwrap()),
        )
    }

    fn make_direct_limiter(per_minute: u32, burst: u32) -> DirectLimiter {
        RateLimiter::direct(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
                .allow_burst(NonZeroU32::new(burst).unwrap()),
        )
    }

    #[test]
    fn keyed_allows_within_budget() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        let cost = NonZeroU32::new(5).unwrap();

        assert!(check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_rejects_when_exhausted() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();

        // Exhaust the burst
        let cost = NonZeroU32::new(10).unwrap();
        assert!(check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").is_ok());

        // Next request should be rejected
        let cost = NonZeroU32::new(1).unwrap();
        let err = check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").unwrap_err();
        assert_eq!(err.scope, "per_ip");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn keyed_reports_caller_scope() {
        let limiter = make_keyed_limiter(30, 1);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        let cost = NonZeroU32::new(1).unwrap();

        assert!(check_keyed_cost(&limiter, &ip, cost, "per_target", "test").is_ok());
        let err = check_keyed_cost(&limiter, &ip, cost, "per_target", "test").unwrap_err();
        // The scope label is whatever the caller passed, not a hard-coded value.
        assert_eq!(err.scope, "per_target");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn keyed_independent_keys() {
        let limiter = make_keyed_limiter(30, 10);
        let ip1: IpAddr = "198.51.100.1".parse().unwrap();
        let ip2: IpAddr = "198.51.100.2".parse().unwrap();
        let cost = NonZeroU32::new(10).unwrap();

        // Each key spends its full burst; ip1 draining its budget must not affect ip2.
        assert!(check_keyed_cost(&limiter, &ip1, cost, "per_ip", "test").is_ok());
        assert!(check_keyed_cost(&limiter, &ip2, cost, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_insufficient_capacity() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        // One more than the burst size: can never be admitted.
        let cost = NonZeroU32::new(11).unwrap();

        let err = check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "per_ip");
    }

    #[test]
    fn direct_allows_within_budget() {
        let limiter = make_direct_limiter(500, 50);
        let cost = NonZeroU32::new(10).unwrap();

        assert!(check_direct_cost(&limiter, cost, "test").is_ok());
    }

    #[test]
    fn direct_rejects_when_exhausted() {
        let limiter = make_direct_limiter(500, 50);

        let cost = NonZeroU32::new(50).unwrap();
        assert!(check_direct_cost(&limiter, cost, "test").is_ok());

        let cost = NonZeroU32::new(1).unwrap();
        let err = check_direct_cost(&limiter, cost, "test").unwrap_err();
        assert_eq!(err.scope, "global");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn direct_insufficient_capacity() {
        let limiter = make_direct_limiter(500, 50);
        // One more than the burst size: can never be admitted.
        let cost = NonZeroU32::new(51).unwrap();

        let err = check_direct_cost(&limiter, cost, "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "global");
    }
}