1use std::collections::{HashMap, HashSet};
2
3use localtime::LocalTime;
4use radicle::node::{address, config, HostName, NodeId};
5
6#[derive(Debug, Default)]
13pub struct RateLimiter {
14 pub buckets: HashMap<HostName, TokenBucket>,
15 pub bypass: HashSet<NodeId>,
16}
17
18impl RateLimiter {
19 pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
21 Self {
22 buckets: HashMap::default(),
23 bypass: bypass.into_iter().collect(),
24 }
25 }
26
27 pub fn limit<T: AsTokens>(
33 &mut self,
34 addr: HostName,
35 nid: Option<&NodeId>,
36 tokens: &T,
37 now: LocalTime,
38 ) -> bool {
39 if let Some(nid) = nid {
40 if self.bypass.contains(nid) {
41 return false;
42 }
43 }
44 if let HostName::Ip(ip) = addr {
45 if !address::is_routable(&ip) {
47 return false;
48 }
49 }
50 !self
51 .buckets
52 .entry(addr)
53 .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
54 .take(now)
55 }
56}
57
/// Describes the token bucket that should be created for a host.
pub trait AsTokens {
    /// Maximum number of tokens the bucket can hold; a new bucket starts
    /// with this many tokens.
    fn capacity(&self) -> usize;

    /// Refill rate, in tokens per second of elapsed time.
    fn rate(&self) -> f64;
}
66
67impl AsTokens for config::RateLimit {
68 fn rate(&self) -> f64 {
69 self.fill_rate
70 }
71
72 fn capacity(&self) -> usize {
73 self.capacity
74 }
75}
76
77#[derive(Debug, serde::Serialize)]
78#[serde(rename_all = "camelCase")]
79pub struct TokenBucket {
80 rate: f64,
82 capacity: f64,
84 tokens: f64,
86 refilled_at: LocalTime,
88}
89
90impl TokenBucket {
91 fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
92 Self {
93 rate,
94 capacity: tokens as f64,
95 tokens: tokens as f64,
96 refilled_at: now,
97 }
98 }
99
100 fn refill(&mut self, now: LocalTime) {
101 let elapsed = now.duration_since(self.refilled_at);
102 let tokens = elapsed.as_secs() as f64 * self.rate;
103
104 self.tokens = (self.tokens + tokens).min(self.capacity);
105 self.refilled_at = now;
106 }
107
108 fn take(&mut self, now: LocalTime) -> bool {
109 self.refill(now);
110
111 if self.tokens >= 1.0 {
112 self.tokens -= 1.0;
113 true
114 } else {
115 false
116 }
117 }
118}
119
#[cfg(test)]
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
mod test {
    use radicle::test::arbitrary;

    use super::*;

    // Lets the tests describe a bucket as a plain `(capacity, rate)` pair.
    impl AsTokens for (usize, f64) {
        fn capacity(&self) -> usize {
            self.0
        }

        fn rate(&self) -> f64 {
            self.1
        }
    }

    #[test]
    fn test_limitter_refill() {
        let mut limiter = RateLimiter::default();
        let bucket: (usize, f64) = (3, 0.2);
        let host = HostName::Dns(String::from("seed.radicle.example.com"));
        let node = arbitrary::gen::<NodeId>(1);
        let node = Some(&node);

        // Each entry is (time in seconds, expected result of `limit`).
        // Three tokens to start with, then one new token every five seconds.
        let schedule: [(u64, bool); 21] = [
            (0, false),
            (1, false),
            (2, false),
            (3, true),
            (4, true),
            (5, false),
            (6, true),
            (7, true),
            (8, true),
            (9, true),
            (10, false),
            (11, true),
            (12, true),
            (13, true),
            (14, true),
            (15, false),
            (16, true),
            // After a long pause the bucket is full again, but not over capacity.
            (60, false),
            (60, false),
            (60, false),
            (60, true),
        ];

        for (step, &(secs, expected)) in schedule.iter().enumerate() {
            let limited = limiter.limit(host.clone(), node, &bucket, LocalTime::from_secs(secs));
            assert_eq!(limited, expected, "step {} (t = {}s)", step, secs);
        }
    }

    #[test]
    #[rustfmt::skip]
    fn test_limitter_multi() {
        let t: (usize, f64) = (1, 1.0);
        let n = arbitrary::gen::<NodeId>(1);
        let n = Some(&n);
        let mut r = RateLimiter::default();
        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));

        let mut limited = |addr: &HostName, secs: u64| {
            r.limit(addr.clone(), n, &t, LocalTime::from_secs(secs))
        };

        // Each host has its own single-token bucket.
        assert_eq!(limited(&addr1, 0), false);
        assert_eq!(limited(&addr1, 0), true);
        assert_eq!(limited(&addr2, 0), false);
        assert_eq!(limited(&addr2, 0), true);
        // One second later both buckets hold one token again.
        assert_eq!(limited(&addr1, 1), false);
        assert_eq!(limited(&addr1, 1), true);
        assert_eq!(limited(&addr2, 1), false);
        assert_eq!(limited(&addr2, 1), true);
    }

    #[test]
    #[rustfmt::skip]
    fn test_limitter_different_rates() {
        let t1: (usize, f64) = (1, 1.0);
        let t2: (usize, f64) = (2, 2.0);
        let n = arbitrary::gen::<NodeId>(1);
        let n = Some(&n);
        let mut r = RateLimiter::default();
        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));

        let mut limited = |addr: &HostName, t: &(usize, f64), secs: u64| {
            r.limit(addr.clone(), n, t, LocalTime::from_secs(secs))
        };

        // First host: capacity 1, one token per second.
        assert_eq!(limited(&addr1, &t1, 0), false);
        assert_eq!(limited(&addr1, &t1, 0), true);
        // Second host: capacity 2, two tokens per second.
        assert_eq!(limited(&addr2, &t2, 0), false);
        assert_eq!(limited(&addr2, &t2, 0), false);
        assert_eq!(limited(&addr2, &t2, 0), true);
        // One second later each bucket is back at its own capacity.
        assert_eq!(limited(&addr1, &t1, 1), false);
        assert_eq!(limited(&addr1, &t1, 1), true);
        assert_eq!(limited(&addr2, &t2, 1), false);
        assert_eq!(limited(&addr2, &t2, 1), false);
        assert_eq!(limited(&addr2, &t2, 1), true);
    }
}