1use std::collections::{HashMap, HashSet};
2
3use localtime::LocalTime;
4use radicle::node::{address, config, HostName, NodeId};
5
6#[derive(Debug, Default)]
13pub struct RateLimiter {
14 pub buckets: HashMap<HostName, TokenBucket>,
15 pub bypass: HashSet<NodeId>,
16}
17
18impl RateLimiter {
19 pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
21 Self {
22 buckets: HashMap::default(),
23 bypass: bypass.into_iter().collect(),
24 }
25 }
26
27 pub fn limit<T: AsTokens>(
33 &mut self,
34 addr: HostName,
35 nid: Option<&NodeId>,
36 tokens: &T,
37 now: LocalTime,
38 ) -> bool {
39 if let Some(nid) = nid {
40 if self.bypass.contains(nid) {
41 return false;
42 }
43 }
44 if let HostName::Ip(ip) = addr {
45 if !address::is_routable(&ip) {
47 return false;
48 }
49 }
50 !self
51 .buckets
52 .entry(addr)
53 .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
54 .take(now)
55 }
56}
57
/// Token-bucket parameters: how fast a bucket refills and how many tokens it can hold.
pub trait AsTokens {
    /// Tokens added back to a bucket per second of elapsed time.
    fn rate(&self) -> f64;
    /// Maximum number of tokens a bucket can hold; a freshly created bucket starts full.
    fn capacity(&self) -> usize;
}
66
67impl AsTokens for config::RateLimit {
68 fn rate(&self) -> f64 {
69 self.fill_rate
70 }
71
72 fn capacity(&self) -> usize {
73 self.capacity
74 }
75}
76
77impl AsTokens for config::LimitRateInbound {
78 fn capacity(&self) -> usize {
79 config::RateLimit::from(*self).capacity()
80 }
81
82 fn rate(&self) -> f64 {
83 config::RateLimit::from(*self).rate()
84 }
85}
86
87impl AsTokens for config::LimitRateOutbound {
88 fn capacity(&self) -> usize {
89 config::RateLimit::from(*self).capacity()
90 }
91
92 fn rate(&self) -> f64 {
93 config::RateLimit::from(*self).rate()
94 }
95}
96
97#[derive(Debug, serde::Serialize)]
98#[serde(rename_all = "camelCase")]
99pub struct TokenBucket {
100 rate: f64,
102 capacity: f64,
104 tokens: f64,
106 refilled_at: LocalTime,
108}
109
110impl TokenBucket {
111 fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
112 Self {
113 rate,
114 capacity: tokens as f64,
115 tokens: tokens as f64,
116 refilled_at: now,
117 }
118 }
119
120 fn refill(&mut self, now: LocalTime) {
121 let elapsed = now.duration_since(self.refilled_at);
122 let tokens = elapsed.as_secs() as f64 * self.rate;
123
124 self.tokens = (self.tokens + tokens).min(self.capacity);
125 self.refilled_at = now;
126 }
127
128 fn take(&mut self, now: LocalTime) -> bool {
129 self.refill(now);
130
131 if self.tokens >= 1.0 {
132 self.tokens -= 1.0;
133 true
134 } else {
135 false
136 }
137 }
138}
139
#[cfg(test)]
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
mod test {
    use radicle::test::arbitrary;

    use super::*;

    /// Lets tests describe bucket parameters as a plain `(capacity, rate)` pair.
    impl AsTokens for (usize, f64) {
        fn capacity(&self) -> usize {
            self.0
        }

        fn rate(&self) -> f64 {
            self.1
        }
    }

    #[test]
    fn test_limiter_refill() {
        let mut limiter = RateLimiter::default();
        // Three tokens, regaining one token every five seconds.
        let tokens = (3, 0.2);
        let host = HostName::Dns(String::from("seed.radicle.example.com"));
        let nid = arbitrary::gen::<NodeId>(1);
        let nid = Some(&nid);

        // (timestamp in seconds, expected "limited" result), checked in order.
        let steps = [
            (0, false),
            (1, false),
            (2, false),
            (3, true),
            (4, true),
            (5, false),
            (6, true),
            (7, true),
            (8, true),
            (9, true),
            (10, false),
            (11, true),
            (12, true),
            (13, true),
            (14, true),
            (15, false),
            (16, true),
            // After a long pause the bucket is full again, but never beyond capacity.
            (60, false),
            (60, false),
            (60, false),
            (60, true),
        ];
        for &(secs, limited) in &steps {
            assert_eq!(
                limiter.limit(host.clone(), nid, &tokens, LocalTime::from_secs(secs)),
                limited,
                "unexpected rate-limit decision at t={secs}s"
            );
        }
    }

    #[test]
    fn test_limiter_multi() {
        let tokens = (1, 1.0);
        let nid = arbitrary::gen::<NodeId>(1);
        let nid = Some(&nid);
        let mut limiter = RateLimiter::default();
        let com = HostName::Dns(String::from("seed.radicle.example.com"));
        let net = HostName::Dns(String::from("seed.radicle.example.net"));

        // Each host has its own single-token bucket, refilled once per second.
        for secs in 0..2 {
            let now = LocalTime::from_secs(secs);
            for host in [&com, &net] {
                assert_eq!(limiter.limit(host.clone(), nid, &tokens, now), false);
                assert_eq!(limiter.limit(host.clone(), nid, &tokens, now), true);
            }
        }
    }

    #[test]
    fn test_limiter_different_rates() {
        let slow = (1, 1.0);
        let fast = (2, 2.0);
        let nid = arbitrary::gen::<NodeId>(1);
        let nid = Some(&nid);
        let mut limiter = RateLimiter::default();
        let com = HostName::Dns(String::from("seed.radicle.example.com"));
        let net = HostName::Dns(String::from("seed.radicle.example.net"));

        for secs in 0..2 {
            let now = LocalTime::from_secs(secs);
            // The first host gets one token per second...
            assert_eq!(limiter.limit(com.clone(), nid, &slow, now), false);
            assert_eq!(limiter.limit(com.clone(), nid, &slow, now), true);
            // ...the second gets two.
            assert_eq!(limiter.limit(net.clone(), nid, &fast, now), false);
            assert_eq!(limiter.limit(net.clone(), nid, &fast, now), false);
            assert_eq!(limiter.limit(net.clone(), nid, &fast, now), true);
        }
    }
}