radicle_protocol/service/limiter.rs

1use std::collections::{HashMap, HashSet};
2
3use localtime::LocalTime;
4use radicle::node::{address, config, HostName, NodeId};
5
6/// Peer rate limiter.
7///
8/// Uses a token bucket algorithm, where each address starts with a certain amount of tokens,
9/// and every request from that address consumes one token. Tokens refill at a predefined
10/// rate. This mechanism allows for consistent request rates with potential bursts up to the
11/// bucket's capacity.
12#[derive(Debug, Default)]
13pub struct RateLimiter {
14    pub buckets: HashMap<HostName, TokenBucket>,
15    pub bypass: HashSet<NodeId>,
16}
17
18impl RateLimiter {
19    /// Create a new rate limiter with a bypass list. Nodes in the bypass list are not limited.
20    pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
21        Self {
22            buckets: HashMap::default(),
23            bypass: bypass.into_iter().collect(),
24        }
25    }
26
27    /// Call this when the address has performed some rate-limited action.
28    /// Returns whether the action is rate-limited or not.
29    ///
30    /// Supplying a different amount of tokens per address is useful if for eg. a peer
31    /// is outbound vs. inbound.
32    pub fn limit<T: AsTokens>(
33        &mut self,
34        addr: HostName,
35        nid: Option<&NodeId>,
36        tokens: &T,
37        now: LocalTime,
38    ) -> bool {
39        if let Some(nid) = nid {
40            if self.bypass.contains(nid) {
41                return false;
42            }
43        }
44        if let HostName::Ip(ip) = addr {
45            // Don't limit LAN addresses.
46            if !address::is_routable(&ip) {
47                return false;
48            }
49        }
50        !self
51            .buckets
52            .entry(addr)
53            .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
54            .take(now)
55    }
56}
57
/// Source of token bucket settings for rate-limiting.
///
/// Implementors describe how large a burst is tolerated and how quickly tokens come back.
pub trait AsTokens {
    /// Maximum number of tokens a bucket holds; a new bucket starts with this many.
    fn capacity(&self) -> usize;

    /// Tokens added back per second.
    ///
    /// For example, `1.0` restores one token every second and `0.2` one token every
    /// five seconds.
    fn rate(&self) -> f64;
}
66
67impl AsTokens for config::RateLimit {
68    fn rate(&self) -> f64 {
69        self.fill_rate
70    }
71
72    fn capacity(&self) -> usize {
73        self.capacity
74    }
75}
76
77#[derive(Debug, serde::Serialize)]
78#[serde(rename_all = "camelCase")]
79pub struct TokenBucket {
80    /// Token refill rate per second.
81    rate: f64,
82    /// Token capacity.
83    capacity: f64,
84    /// Tokens remaining.
85    tokens: f64,
86    /// Time of last token refill.
87    refilled_at: LocalTime,
88}
89
90impl TokenBucket {
91    fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
92        Self {
93            rate,
94            capacity: tokens as f64,
95            tokens: tokens as f64,
96            refilled_at: now,
97        }
98    }
99
100    fn refill(&mut self, now: LocalTime) {
101        let elapsed = now.duration_since(self.refilled_at);
102        let tokens = elapsed.as_secs() as f64 * self.rate;
103
104        self.tokens = (self.tokens + tokens).min(self.capacity);
105        self.refilled_at = now;
106    }
107
108    fn take(&mut self, now: LocalTime) -> bool {
109        self.refill(now);
110
111        if self.tokens >= 1.0 {
112            self.tokens -= 1.0;
113            true
114        } else {
115            false
116        }
117    }
118}
119
#[cfg(test)]
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
mod test {
    use radicle::test::arbitrary;

    use super::*;

    /// Lets tests describe bucket settings as a plain `(capacity, rate)` pair.
    impl AsTokens for (usize, f64) {
        fn capacity(&self) -> usize {
            self.0
        }

        fn rate(&self) -> f64 {
            self.1
        }
    }

    #[test]
    fn test_limitter_refill() {
        let tokens: (usize, f64) = (3, 0.2); // Three tokens burst. One token every 5 seconds.
        let addr = HostName::Dns(String::from("seed.radicle.example.com"));
        let nid = arbitrary::gen::<NodeId>(1);
        let mut limiter = RateLimiter::default();

        // Each step is `(time in seconds, expected to be limited)`.
        let steps = [
            (0, false),  // Burst capacity
            (1, false),  // Burst capacity
            (2, false),  // Burst capacity
            (3, true),   // Limited
            (4, true),   // Limited
            (5, false),  // Refilled (1)
            (6, true),   // Limited
            (7, true),   // Limited
            (8, true),   // Limited
            (9, true),   // Limited
            (10, false), // Refilled (1)
            (11, true),  // Limited
            (12, true),  // Limited
            (13, true),  // Limited
            (14, true),  // Limited
            (15, false), // Refilled (1)
            (16, true),  // Limited
            (60, false), // Refilled (3)
            (60, false), // Burst capacity
            (60, false), // Burst capacity
            (60, true),  // Limited
        ];
        for (secs, expected) in steps {
            let limited =
                limiter.limit(addr.clone(), Some(&nid), &tokens, LocalTime::from_secs(secs));
            assert_eq!(limited, expected, "unexpected result at t={secs}s");
        }
    }

    #[test]
    fn test_limitter_multi() {
        let tokens: (usize, f64) = (1, 1.0); // One token per second. One token burst.
        let nid = arbitrary::gen::<NodeId>(1);
        let com = HostName::Dns(String::from("seed.radicle.example.com"));
        let net = HostName::Dns(String::from("seed.radicle.example.net"));
        let mut limiter = RateLimiter::default();

        // Each step is `(address, time in seconds, expected to be limited)`.
        let steps = [
            (&com, 0, false),
            (&com, 0, true),
            (&net, 0, false),
            (&net, 0, true),
            (&com, 1, false),
            (&com, 1, true),
            (&net, 1, false),
            (&net, 1, true),
        ];
        for (addr, secs, expected) in steps {
            let limited =
                limiter.limit(addr.clone(), Some(&nid), &tokens, LocalTime::from_secs(secs));
            assert_eq!(limited, expected, "unexpected result for {addr:?} at t={secs}s");
        }
    }

    #[test]
    fn test_limitter_different_rates() {
        let slow: (usize, f64) = (1, 1.0); // One token per second. One token burst.
        let fast: (usize, f64) = (2, 2.0); // Two tokens per second. Two token burst.
        let nid = arbitrary::gen::<NodeId>(1);
        let com = HostName::Dns(String::from("seed.radicle.example.com"));
        let net = HostName::Dns(String::from("seed.radicle.example.net"));
        let mut limiter = RateLimiter::default();

        // Each step is `(address, bucket settings, time in seconds, expected to be limited)`.
        let steps = [
            (&com, &slow, 0, false),
            (&com, &slow, 0, true),
            (&net, &fast, 0, false),
            (&net, &fast, 0, false),
            (&net, &fast, 0, true),
            (&com, &slow, 1, false), // Refilled (1)
            (&com, &slow, 1, true),
            (&net, &fast, 1, false), // Refilled (2)
            (&net, &fast, 1, false),
            (&net, &fast, 1, true),
        ];
        for (addr, tokens, secs, expected) in steps {
            let limited =
                limiter.limit(addr.clone(), Some(&nid), tokens, LocalTime::from_secs(secs));
            assert_eq!(limited, expected, "unexpected result for {addr:?} at t={secs}s");
        }
    }
}