Skip to main content

radicle_protocol/service/
limiter.rs

1use std::collections::{HashMap, HashSet};
2
3use localtime::LocalTime;
4use radicle::node::{address, config, HostName, NodeId};
5use serde::Serialize;
6
7/// Peer rate limiter.
8///
9/// Uses a token bucket algorithm, where each address starts with a certain amount of tokens,
10/// and every request from that address consumes one token. Tokens refill at a predefined
11/// rate. This mechanism allows for consistent request rates with potential bursts up to the
12/// bucket's capacity.
13#[derive(Debug, Default, Serialize)]
14pub struct RateLimiter {
15    pub buckets: HashMap<HostName, TokenBucket>,
16    pub bypass: HashSet<NodeId>,
17}
18
19impl RateLimiter {
20    /// Create a new rate limiter with a bypass list. Nodes in the bypass list are not limited.
21    pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
22        Self {
23            buckets: HashMap::default(),
24            bypass: bypass.into_iter().collect(),
25        }
26    }
27
28    /// Call this when the address has performed some rate-limited action.
29    /// Returns whether the action is rate-limited or not.
30    ///
31    /// Supplying a different amount of tokens per address is useful if for eg. a peer
32    /// is outbound vs. inbound.
33    pub fn limit<T: AsTokens>(
34        &mut self,
35        addr: HostName,
36        nid: Option<&NodeId>,
37        tokens: &T,
38        now: LocalTime,
39    ) -> bool {
40        if let Some(nid) = nid {
41            if self.bypass.contains(nid) {
42                return false;
43            }
44        }
45        if let HostName::Ip(ip) = addr {
46            // Don't limit LAN addresses.
47            if !address::is_routable(&ip) {
48                return false;
49            }
50        }
51        !self
52            .buckets
53            .entry(addr)
54            .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
55            .take(now)
56    }
57}
58
/// Source of token-bucket parameters for rate limiting.
///
/// Implementors describe how large a bucket should be and how quickly it refills.
pub trait AsTokens {
    /// Maximum number of tokens a bucket can hold, i.e. the largest permitted burst.
    fn capacity(&self) -> usize;
    /// Number of tokens restored per second; `1.0` restores one token every second.
    fn rate(&self) -> f64;
}
67
68impl AsTokens for config::RateLimit {
69    fn rate(&self) -> f64 {
70        self.fill_rate
71    }
72
73    fn capacity(&self) -> usize {
74        self.capacity
75    }
76}
77
78impl AsTokens for config::LimitRateInbound {
79    fn capacity(&self) -> usize {
80        config::RateLimit::from(*self).capacity()
81    }
82
83    fn rate(&self) -> f64 {
84        config::RateLimit::from(*self).rate()
85    }
86}
87
88impl AsTokens for config::LimitRateOutbound {
89    fn capacity(&self) -> usize {
90        config::RateLimit::from(*self).capacity()
91    }
92
93    fn rate(&self) -> f64 {
94        config::RateLimit::from(*self).rate()
95    }
96}
97
98#[derive(Debug, serde::Serialize)]
99#[serde(rename_all = "camelCase")]
100pub struct TokenBucket {
101    /// Token refill rate per second.
102    rate: f64,
103    /// Token capacity.
104    capacity: f64,
105    /// Tokens remaining.
106    tokens: f64,
107    /// Time of last token refill.
108    refilled_at: LocalTime,
109}
110
111impl TokenBucket {
112    fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
113        Self {
114            rate,
115            capacity: tokens as f64,
116            tokens: tokens as f64,
117            refilled_at: now,
118        }
119    }
120
121    fn refill(&mut self, now: LocalTime) {
122        let elapsed = now.duration_since(self.refilled_at);
123        let tokens = elapsed.as_secs() as f64 * self.rate;
124
125        self.tokens = (self.tokens + tokens).min(self.capacity);
126        self.refilled_at = now;
127    }
128
129    fn take(&mut self, now: LocalTime) -> bool {
130        self.refill(now);
131
132        if self.tokens >= 1.0 {
133            self.tokens -= 1.0;
134            true
135        } else {
136            false
137        }
138    }
139}
140
141#[cfg(test)]
142#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
143mod test {
144    use radicle::test::arbitrary;
145
146    use super::*;
147
148    impl AsTokens for (usize, f64) {
149        fn capacity(&self) -> usize {
150            self.0
151        }
152
153        fn rate(&self) -> f64 {
154            self.1
155        }
156    }
157
158    #[test]
159    fn test_limiter_refill() {
160        let mut r = RateLimiter::default();
161        let t = (3, 0.2); // Three tokens burst. One token every 5 seconds.
162        let a = HostName::Dns(String::from("seed.radicle.example.com"));
163        let n = arbitrary::gen::<NodeId>(1);
164        let n = Some(&n);
165
166        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(0)), false); // Burst capacity
167        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(1)), false); // Burst capacity
168        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(2)), false); // Burst capacity
169        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(3)), true); // Limited
170        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(4)), true); // Limited
171        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(5)), false); // Refilled (1)
172        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(6)), true); // Limited
173        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(7)), true); // Limited
174        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(8)), true); // Limited
175        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(9)), true); // Limited
176        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(10)), false); // Refilled (1)
177        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(11)), true); // Limited
178        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(12)), true); // Limited
179        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(13)), true); // Limited
180        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(14)), true); // Limited
181        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(15)), false); // Refilled (1)
182        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(16)), true); // Limited
183        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(60)), false); // Refilled (3)
184        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(60)), false); // Burst capacity
185        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(60)), false); // Burst capacity
186        assert_eq!(r.limit(a.clone(), n, &t, LocalTime::from_secs(60)), true); // Limited
187    }
188
189    #[test]
190    #[rustfmt::skip]
191    fn test_limiter_multi() {
192        let t = (1, 1.0); // One token per second. One token burst.
193        let n = arbitrary::gen::<NodeId>(1);
194        let n = Some(&n);
195        let mut r = RateLimiter::default();
196        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
197        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));
198
199        assert_eq!(r.limit(addr1.clone(), n, &t, LocalTime::from_secs(0)), false);
200        assert_eq!(r.limit(addr1.clone(), n, &t, LocalTime::from_secs(0)), true);
201        assert_eq!(r.limit(addr2.clone(), n, &t, LocalTime::from_secs(0)), false);
202        assert_eq!(r.limit(addr2.clone(), n, &t, LocalTime::from_secs(0)), true);
203        assert_eq!(r.limit(addr1.clone(), n, &t, LocalTime::from_secs(1)), false);
204        assert_eq!(r.limit(addr1.clone(), n, &t, LocalTime::from_secs(1)), true);
205        assert_eq!(r.limit(addr2.clone(), n, &t, LocalTime::from_secs(1)), false);
206        assert_eq!(r.limit(addr2.clone(), n, &t, LocalTime::from_secs(1)), true);
207    }
208
209    #[test]
210    #[rustfmt::skip]
211    fn test_limiter_different_rates() {
212        let t1 = (1, 1.0); // One token per second. One token burst.
213        let t2 = (2, 2.0); // Two tokens per second. Two token burst.
214        let n = arbitrary::gen::<NodeId>(1);
215        let n = Some(&n);
216        let mut r = RateLimiter::default();
217        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
218        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));
219
220        assert_eq!(r.limit(addr1.clone(), n, &t1, LocalTime::from_secs(0)), false);
221        assert_eq!(r.limit(addr1.clone(), n, &t1, LocalTime::from_secs(0)), true);
222        assert_eq!(r.limit(addr2.clone(), n, &t2, LocalTime::from_secs(0)), false);
223        assert_eq!(r.limit(addr2.clone(), n, &t2, LocalTime::from_secs(0)), false);
224        assert_eq!(r.limit(addr2.clone(), n, &t2, LocalTime::from_secs(0)), true);
225        assert_eq!(r.limit(addr1.clone(), n, &t1, LocalTime::from_secs(1)), false); // Refilled (1)
226        assert_eq!(r.limit(addr1.clone(), n, &t1, LocalTime::from_secs(1)), true);
227        assert_eq!(r.limit(addr2.clone(), n, &t2, LocalTime::from_secs(1)), false); // Refilled (2)
228        assert_eq!(r.limit(addr2.clone(), n, &t2, LocalTime::from_secs(1)), false);
229        assert_eq!(r.limit(addr2.clone(), n, &t2, LocalTime::from_secs(1)), true);
230    }
231}