Skip to main content

netray_common/
rate_limit.rs

1use std::num::NonZeroU32;
2
3use governor::RateLimiter;
4use governor::clock::DefaultClock;
5
6/// Keyed rate limiter type alias wrapping governor's `RateLimiter` with
7/// `DefaultKeyedStateStore` and `DefaultClock`.
8pub type KeyedLimiter<K> =
9    RateLimiter<K, governor::state::keyed::DefaultKeyedStateStore<K>, DefaultClock>;
10
/// Details of a request rejected by one of this module's rate limit checks.
///
/// Produced as the `Err` value of [`check_keyed_cost`] and
/// [`check_direct_cost`]. Both fields are plain `Copy` data, so the value is
/// cheap to copy and can be compared directly in tests and callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Seconds the client should wait before retrying. The check functions in
    /// this module never report less than 1.
    pub retry_after_secs: u64,
    /// Which limiter scope rejected the request (e.g. "per_ip", "per_target", "global").
    pub scope: &'static str,
}
19
20/// Check a keyed rate limiter with the given cost.
21///
22/// Returns `Ok(())` if the request is within budget, or `Err(RateLimitRejection)`
23/// with the appropriate retry-after duration and scope.
24///
25/// The `metrics_prefix` is used to increment a counter named
26/// `{metrics_prefix}_rate_limit_hits_total` with a `scope` label on rejection.
27pub fn check_keyed_cost<K: std::hash::Hash + Eq + Clone>(
28    limiter: &KeyedLimiter<K>,
29    key: &K,
30    cost: NonZeroU32,
31    scope: &'static str,
32    metrics_prefix: &'static str,
33) -> Result<(), RateLimitRejection> {
34    let retry_secs = match limiter.check_key_n(key, cost) {
35        Ok(Ok(())) => return Ok(()),
36        Ok(Err(not_until)) => {
37            let wait =
38                not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
39            wait.as_secs()
40        }
41        // InsufficientCapacity: cost exceeds burst size entirely.
42        Err(_) => 60,
43    };
44    let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
45    metrics::counter!(counter_name, "scope" => scope).increment(1);
46    Err(RateLimitRejection {
47        retry_after_secs: retry_secs.max(1),
48        scope,
49    })
50}
51
52/// Check a direct (unkeyed/global) rate limiter with the given cost.
53///
54/// Returns `Ok(())` if the request is within budget, or `Err(RateLimitRejection)`
55/// with scope `"global"`.
56pub fn check_direct_cost(
57    limiter: &RateLimiter<governor::state::NotKeyed, governor::state::InMemoryState, DefaultClock>,
58    cost: NonZeroU32,
59    metrics_prefix: &'static str,
60) -> Result<(), RateLimitRejection> {
61    let retry_secs = match limiter.check_n(cost) {
62        Ok(Ok(())) => return Ok(()),
63        Ok(Err(not_until)) => {
64            let wait =
65                not_until.wait_time_from(governor::clock::Clock::now(&DefaultClock::default()));
66            wait.as_secs()
67        }
68        Err(_) => 60,
69    };
70    let counter_name = format!("{metrics_prefix}_rate_limit_hits_total");
71    metrics::counter!(counter_name, "scope" => "global").increment(1);
72    Err(RateLimitRejection {
73        retry_after_secs: retry_secs.max(1),
74        scope: "global",
75    })
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use governor::Quota;
    use std::net::IpAddr;

    /// Builds a keyed limiter refilling `per_minute` cells per minute with the
    /// given burst size.
    fn make_keyed_limiter(per_minute: u32, burst: u32) -> KeyedLimiter<IpAddr> {
        RateLimiter::keyed(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
                .allow_burst(NonZeroU32::new(burst).unwrap()),
        )
    }

    /// Builds a direct (global) limiter refilling `per_minute` cells per minute
    /// with the given burst size.
    fn make_direct_limiter(
        per_minute: u32,
        burst: u32,
    ) -> RateLimiter<governor::state::NotKeyed, governor::state::InMemoryState, DefaultClock> {
        RateLimiter::direct(
            Quota::per_minute(NonZeroU32::new(per_minute).unwrap())
                .allow_burst(NonZeroU32::new(burst).unwrap()),
        )
    }

    #[test]
    fn keyed_allows_within_budget() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        let cost = NonZeroU32::new(5).unwrap();

        assert!(check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_rejects_when_exhausted() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();

        // Exhaust the burst
        let cost = NonZeroU32::new(10).unwrap();
        assert!(check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").is_ok());

        // Next request should be rejected
        let cost = NonZeroU32::new(1).unwrap();
        let err = check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").unwrap_err();
        assert_eq!(err.scope, "per_ip");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn keyed_independent_keys() {
        let limiter = make_keyed_limiter(30, 10);
        let ip1: IpAddr = "198.51.100.1".parse().unwrap();
        let ip2: IpAddr = "198.51.100.2".parse().unwrap();
        let cost = NonZeroU32::new(10).unwrap();
        let one = NonZeroU32::new(1).unwrap();

        assert!(check_keyed_cost(&limiter, &ip1, cost, "per_ip", "test").is_ok());
        // ip1 has spent its whole burst...
        assert!(check_keyed_cost(&limiter, &ip1, one, "per_ip", "test").is_err());
        // ...but ip2 still has its own, untouched budget.
        assert!(check_keyed_cost(&limiter, &ip2, cost, "per_ip", "test").is_ok());
    }

    #[test]
    fn keyed_insufficient_capacity() {
        let limiter = make_keyed_limiter(30, 10);
        let ip: IpAddr = "198.51.100.1".parse().unwrap();
        let cost = NonZeroU32::new(11).unwrap();

        let err = check_keyed_cost(&limiter, &ip, cost, "per_ip", "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "per_ip");
    }

    #[test]
    fn direct_allows_within_budget() {
        let limiter = make_direct_limiter(500, 50);
        let cost = NonZeroU32::new(10).unwrap();

        assert!(check_direct_cost(&limiter, cost, "test").is_ok());
    }

    #[test]
    fn direct_rejects_when_exhausted() {
        let limiter = make_direct_limiter(500, 50);

        let cost = NonZeroU32::new(50).unwrap();
        assert!(check_direct_cost(&limiter, cost, "test").is_ok());

        let cost = NonZeroU32::new(1).unwrap();
        let err = check_direct_cost(&limiter, cost, "test").unwrap_err();
        assert_eq!(err.scope, "global");
        assert!(err.retry_after_secs >= 1);
    }

    #[test]
    fn direct_insufficient_capacity() {
        let limiter = make_direct_limiter(500, 50);
        let cost = NonZeroU32::new(51).unwrap();

        let err = check_direct_cost(&limiter, cost, "test").unwrap_err();
        assert_eq!(err.retry_after_secs, 60);
        assert_eq!(err.scope, "global");
    }
}