Skip to main content

radicle_protocol/service/
limiter.rs

1use std::collections::{HashMap, HashSet};
2
3use localtime::LocalTime;
4use radicle::node::{address, config, HostName, NodeId};
5
6/// Peer rate limiter.
7///
8/// Uses a token bucket algorithm, where each address starts with a certain amount of tokens,
9/// and every request from that address consumes one token. Tokens refill at a predefined
10/// rate. This mechanism allows for consistent request rates with potential bursts up to the
11/// bucket's capacity.
12#[derive(Debug, Default)]
13pub struct RateLimiter {
14    pub buckets: HashMap<HostName, TokenBucket>,
15    pub bypass: HashSet<NodeId>,
16}
17
18impl RateLimiter {
19    /// Create a new rate limiter with a bypass list. Nodes in the bypass list are not limited.
20    pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
21        Self {
22            buckets: HashMap::default(),
23            bypass: bypass.into_iter().collect(),
24        }
25    }
26
27    /// Call this when the address has performed some rate-limited action.
28    /// Returns whether the action is rate-limited or not.
29    ///
30    /// Supplying a different amount of tokens per address is useful if for eg. a peer
31    /// is outbound vs. inbound.
32    pub fn limit<T: AsTokens>(
33        &mut self,
34        addr: HostName,
35        nid: Option<&NodeId>,
36        tokens: &T,
37        now: LocalTime,
38    ) -> bool {
39        if let Some(nid) = nid {
40            if self.bypass.contains(nid) {
41                return false;
42            }
43        }
44        if let HostName::Ip(ip) = addr {
45            // Don't limit LAN addresses.
46            if !address::is_routable(&ip) {
47                return false;
48            }
49        }
50        !self
51            .buckets
52            .entry(addr)
53            .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
54            .take(now)
55    }
56}
57
/// Types that describe a rate-limit token budget.
pub trait AsTokens {
    /// Maximum number of tokens a bucket may hold, ie. the largest burst allowed.
    fn capacity(&self) -> usize;
    /// Number of tokens regained per second: `1.0` is one token every second, `0.2` is one
    /// token every five seconds.
    fn rate(&self) -> f64;
}
66
67impl AsTokens for config::RateLimit {
68    fn rate(&self) -> f64 {
69        self.fill_rate
70    }
71
72    fn capacity(&self) -> usize {
73        self.capacity
74    }
75}
76
77impl AsTokens for config::LimitRateInbound {
78    fn capacity(&self) -> usize {
79        config::RateLimit::from(*self).capacity()
80    }
81
82    fn rate(&self) -> f64 {
83        config::RateLimit::from(*self).rate()
84    }
85}
86
87impl AsTokens for config::LimitRateOutbound {
88    fn capacity(&self) -> usize {
89        config::RateLimit::from(*self).capacity()
90    }
91
92    fn rate(&self) -> f64 {
93        config::RateLimit::from(*self).rate()
94    }
95}
96
97#[derive(Debug, serde::Serialize)]
98#[serde(rename_all = "camelCase")]
99pub struct TokenBucket {
100    /// Token refill rate per second.
101    rate: f64,
102    /// Token capacity.
103    capacity: f64,
104    /// Tokens remaining.
105    tokens: f64,
106    /// Time of last token refill.
107    refilled_at: LocalTime,
108}
109
110impl TokenBucket {
111    fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
112        Self {
113            rate,
114            capacity: tokens as f64,
115            tokens: tokens as f64,
116            refilled_at: now,
117        }
118    }
119
120    fn refill(&mut self, now: LocalTime) {
121        let elapsed = now.duration_since(self.refilled_at);
122        let tokens = elapsed.as_secs() as f64 * self.rate;
123
124        self.tokens = (self.tokens + tokens).min(self.capacity);
125        self.refilled_at = now;
126    }
127
128    fn take(&mut self, now: LocalTime) -> bool {
129        self.refill(now);
130
131        if self.tokens >= 1.0 {
132            self.tokens -= 1.0;
133            true
134        } else {
135            false
136        }
137    }
138}
139
#[cfg(test)]
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
mod test {
    use radicle::test::arbitrary;

    use super::*;

    /// `(capacity, rate)` pairs make concise token budgets for tests.
    impl AsTokens for (usize, f64) {
        fn capacity(&self) -> usize {
            self.0
        }

        fn rate(&self) -> f64 {
            self.1
        }
    }

    #[test]
    fn test_limiter_refill() {
        let mut limiter = RateLimiter::default();
        let budget = (3, 0.2); // Three tokens burst. One token every 5 seconds.
        let addr = HostName::Dns(String::from("seed.radicle.example.com"));
        let nid = arbitrary::gen::<NodeId>(1);

        // (time in seconds, expected result of `limit`)
        let schedule = [
            (0, false),  // Burst capacity
            (1, false),  // Burst capacity
            (2, false),  // Burst capacity
            (3, true),   // Limited
            (4, true),   // Limited
            (5, false),  // Refilled (1)
            (6, true),   // Limited
            (7, true),   // Limited
            (8, true),   // Limited
            (9, true),   // Limited
            (10, false), // Refilled (1)
            (11, true),  // Limited
            (12, true),  // Limited
            (13, true),  // Limited
            (14, true),  // Limited
            (15, false), // Refilled (1)
            (16, true),  // Limited
            (60, false), // Refilled (3)
            (60, false), // Burst capacity
            (60, false), // Burst capacity
            (60, true),  // Limited
        ];
        for (secs, limited) in schedule {
            let now = LocalTime::from_secs(secs);
            assert_eq!(
                limiter.limit(addr.clone(), Some(&nid), &budget, now),
                limited,
                "t = {secs}s"
            );
        }
    }

    #[test]
    fn test_limiter_multi() {
        let budget = (1, 1.0); // One token per second. One token burst.
        let nid = arbitrary::gen::<NodeId>(1);
        let mut limiter = RateLimiter::default();
        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));

        // Each address has its own bucket: one allowed action, then limited, every second.
        for secs in [0, 1] {
            let now = LocalTime::from_secs(secs);
            for addr in [&addr1, &addr2] {
                assert_eq!(limiter.limit(addr.clone(), Some(&nid), &budget, now), false);
                assert_eq!(limiter.limit(addr.clone(), Some(&nid), &budget, now), true);
            }
        }
    }

    #[test]
    fn test_limiter_different_rates() {
        let t1 = (1, 1.0); // One token per second. One token burst.
        let t2 = (2, 2.0); // Two tokens per second. Two token burst.
        let nid = arbitrary::gen::<NodeId>(1);
        let mut limiter = RateLimiter::default();
        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));

        // At t=1 both buckets are fully refilled, so the same pattern repeats.
        for secs in [0, 1] {
            let now = LocalTime::from_secs(secs);

            assert_eq!(limiter.limit(addr1.clone(), Some(&nid), &t1, now), false);
            assert_eq!(limiter.limit(addr1.clone(), Some(&nid), &t1, now), true);

            assert_eq!(limiter.limit(addr2.clone(), Some(&nid), &t2, now), false);
            assert_eq!(limiter.limit(addr2.clone(), Some(&nid), &t2, now), false);
            assert_eq!(limiter.limit(addr2.clone(), Some(&nid), &t2, now), true);
        }
    }
}
230}