1use std::collections::{HashMap, HashSet};
2
3use localtime::LocalTime;
4use radicle::node::{address, config, HostName, NodeId};
5use serde::Serialize;
6
7#[derive(Debug, Default, Serialize)]
14pub struct RateLimiter {
15 pub buckets: HashMap<HostName, TokenBucket>,
16 pub bypass: HashSet<NodeId>,
17}
18
19impl RateLimiter {
20 pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
22 Self {
23 buckets: HashMap::default(),
24 bypass: bypass.into_iter().collect(),
25 }
26 }
27
28 pub fn limit<T: AsTokens>(
34 &mut self,
35 addr: HostName,
36 nid: Option<&NodeId>,
37 tokens: &T,
38 now: LocalTime,
39 ) -> bool {
40 if let Some(nid) = nid {
41 if self.bypass.contains(nid) {
42 return false;
43 }
44 }
45 if let HostName::Ip(ip) = addr {
46 if !address::is_routable(&ip) {
48 return false;
49 }
50 }
51 !self
52 .buckets
53 .entry(addr)
54 .or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
55 .take(now)
56 }
57}
58
/// Source of the parameters used to create a token bucket.
pub trait AsTokens {
    /// Number of tokens a new bucket starts with; also the most it can ever hold.
    fn capacity(&self) -> usize;

    /// Number of tokens added back to a bucket per second of elapsed time.
    fn rate(&self) -> f64;
}
67
68impl AsTokens for config::RateLimit {
69 fn rate(&self) -> f64 {
70 self.fill_rate
71 }
72
73 fn capacity(&self) -> usize {
74 self.capacity
75 }
76}
77
78impl AsTokens for config::LimitRateInbound {
79 fn capacity(&self) -> usize {
80 config::RateLimit::from(*self).capacity()
81 }
82
83 fn rate(&self) -> f64 {
84 config::RateLimit::from(*self).rate()
85 }
86}
87
88impl AsTokens for config::LimitRateOutbound {
89 fn capacity(&self) -> usize {
90 config::RateLimit::from(*self).capacity()
91 }
92
93 fn rate(&self) -> f64 {
94 config::RateLimit::from(*self).rate()
95 }
96}
97
98#[derive(Debug, serde::Serialize)]
99#[serde(rename_all = "camelCase")]
100pub struct TokenBucket {
101 rate: f64,
103 capacity: f64,
105 tokens: f64,
107 refilled_at: LocalTime,
109}
110
111impl TokenBucket {
112 fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
113 Self {
114 rate,
115 capacity: tokens as f64,
116 tokens: tokens as f64,
117 refilled_at: now,
118 }
119 }
120
121 fn refill(&mut self, now: LocalTime) {
122 let elapsed = now.duration_since(self.refilled_at);
123 let tokens = elapsed.as_secs() as f64 * self.rate;
124
125 self.tokens = (self.tokens + tokens).min(self.capacity);
126 self.refilled_at = now;
127 }
128
129 fn take(&mut self, now: LocalTime) -> bool {
130 self.refill(now);
131
132 if self.tokens >= 1.0 {
133 self.tokens -= 1.0;
134 true
135 } else {
136 false
137 }
138 }
139}
140
#[cfg(test)]
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
mod test {
    use radicle::test::arbitrary;

    use super::*;

    /// Lets a plain `(capacity, rate)` tuple stand in for a rate-limit config.
    impl AsTokens for (usize, f64) {
        fn capacity(&self) -> usize {
            self.0
        }

        fn rate(&self) -> f64 {
            self.1
        }
    }

    /// A single host drains its burst capacity and then earns tokens back over time.
    #[test]
    fn test_limiter_refill() {
        let mut r = RateLimiter::default();
        // Capacity of three tokens, refilled at 0.2 tokens per second.
        let t: (usize, f64) = (3, 0.2);
        let a = HostName::Dns(String::from("seed.radicle.example.com"));
        let n = arbitrary::gen::<NodeId>(1);
        let n = Some(&n);

        // (time in seconds, expected result of `limit`), in call order.
        let expected = [
            (0, false),
            (1, false),
            (2, false),
            (3, true),
            (4, true),
            (5, false),
            (6, true),
            (7, true),
            (8, true),
            (9, true),
            (10, false),
            (11, true),
            (12, true),
            (13, true),
            (14, true),
            (15, false),
            (16, true),
            (60, false),
            (60, false),
            (60, false),
            (60, true),
        ];

        for &(secs, limited) in expected.iter() {
            assert_eq!(
                r.limit(a.clone(), n, &t, LocalTime::from_secs(secs)),
                limited
            );
        }
    }

    /// Two hosts sharing the same parameters are limited independently.
    #[test]
    fn test_limiter_multi() {
        let t: (usize, f64) = (1, 1.0);
        let n = arbitrary::gen::<NodeId>(1);
        let n = Some(&n);
        let mut r = RateLimiter::default();
        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));

        // (host, time in seconds, expected result of `limit`), in call order.
        let expected = [
            (&addr1, 0, false),
            (&addr1, 0, true),
            (&addr2, 0, false),
            (&addr2, 0, true),
            (&addr1, 1, false),
            (&addr1, 1, true),
            (&addr2, 1, false),
            (&addr2, 1, true),
        ];

        for &(addr, secs, limited) in expected.iter() {
            assert_eq!(
                r.limit(addr.clone(), n, &t, LocalTime::from_secs(secs)),
                limited
            );
        }
    }

    /// Each host's bucket uses the parameters it was created with.
    #[test]
    fn test_limiter_different_rates() {
        let t1: (usize, f64) = (1, 1.0);
        let t2: (usize, f64) = (2, 2.0);
        let n = arbitrary::gen::<NodeId>(1);
        let n = Some(&n);
        let mut r = RateLimiter::default();
        let addr1 = HostName::Dns(String::from("seed.radicle.example.com"));
        let addr2 = HostName::Dns(String::from("seed.radicle.example.net"));

        // (host, parameters, time in seconds, expected result of `limit`), in call order.
        let expected = [
            (&addr1, &t1, 0, false),
            (&addr1, &t1, 0, true),
            (&addr2, &t2, 0, false),
            (&addr2, &t2, 0, false),
            (&addr2, &t2, 0, true),
            (&addr1, &t1, 1, false),
            (&addr1, &t1, 1, true),
            (&addr2, &t2, 1, false),
            (&addr2, &t2, 1, false),
            (&addr2, &t2, 1, true),
        ];

        for &(addr, tokens, secs, limited) in expected.iter() {
            assert_eq!(
                r.limit(addr.clone(), n, tokens, LocalTime::from_secs(secs)),
                limited
            );
        }
    }
}