use std::collections::{HashMap, HashSet};
use localtime::LocalTime;
use radicle::node::{HostName, NodeId, address, config};
/// Per-host rate limiter built on token buckets.
///
/// Each host gets its own [`TokenBucket`], created the first time that host is
/// passed to [`RateLimiter::limit`]. Nodes listed in `bypass` are never limited.
#[derive(Debug, Default)]
pub struct RateLimiter {
    /// Token bucket for every host seen so far, keyed by host name.
    pub buckets: HashMap<HostName, TokenBucket>,
    /// Nodes that are exempt from rate limiting.
    pub bypass: HashSet<NodeId>,
}
impl RateLimiter {
pub fn new(bypass: impl IntoIterator<Item = NodeId>) -> Self {
Self {
buckets: HashMap::default(),
bypass: bypass.into_iter().collect(),
}
}
pub fn limit<T: AsTokens>(
&mut self,
addr: HostName,
nid: Option<&NodeId>,
tokens: &T,
now: LocalTime,
) -> bool {
if let Some(nid) = nid {
if self.bypass.contains(nid) {
return false;
}
}
if let HostName::Ip(ip) = addr {
if !address::is_routable(&ip) {
return false;
}
}
!self
.buckets
.entry(addr)
.or_insert_with(|| TokenBucket::new(tokens.capacity(), tokens.rate(), now))
.take(now)
}
}
/// Describes the token-bucket parameters of a rate limit.
pub trait AsTokens {
    /// Maximum number of tokens a bucket can hold; a new bucket starts with this many.
    fn capacity(&self) -> usize;

    /// Number of tokens added to a bucket per second.
    fn rate(&self) -> f64;
}
/// A configured rate limit maps directly onto token-bucket parameters.
impl AsTokens for config::RateLimit {
    /// The configured `capacity`.
    fn capacity(&self) -> usize {
        self.capacity
    }

    /// The configured `fill_rate`.
    fn rate(&self) -> f64 {
        self.fill_rate
    }
}
/// Inbound limits are converted to a `config::RateLimit` and read through it.
impl AsTokens for config::LimitRateInbound {
    /// Capacity of the equivalent `config::RateLimit`.
    fn capacity(&self) -> usize {
        let limit = config::RateLimit::from(*self);
        limit.capacity()
    }

    /// Fill rate of the equivalent `config::RateLimit`.
    fn rate(&self) -> f64 {
        let limit = config::RateLimit::from(*self);
        limit.rate()
    }
}
/// Outbound limits are converted to a `config::RateLimit` and read through it.
impl AsTokens for config::LimitRateOutbound {
    /// Capacity of the equivalent `config::RateLimit`.
    fn capacity(&self) -> usize {
        let limit = config::RateLimit::from(*self);
        limit.capacity()
    }

    /// Fill rate of the equivalent `config::RateLimit`.
    fn rate(&self) -> f64 {
        let limit = config::RateLimit::from(*self);
        limit.rate()
    }
}
/// State of a single token bucket.
///
/// Serialized with camelCase field names (e.g. `refilledAt`).
#[derive(Debug, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct TokenBucket {
    /// Tokens added per second of elapsed time.
    rate: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens currently available; one is spent per granted request.
    tokens: f64,
    /// Time of the most recent refill.
    refilled_at: LocalTime,
}
impl TokenBucket {
    /// Create a bucket that starts full.
    ///
    /// `tokens` is both the initial and the maximum balance, `rate` is the
    /// number of tokens regained per second, and `now` is the instant from
    /// which the first refill is measured.
    fn new(tokens: usize, rate: f64, now: LocalTime) -> Self {
        Self {
            rate,
            capacity: tokens as f64,
            tokens: tokens as f64,
            refilled_at: now,
        }
    }

    /// Credit the tokens earned between the last refill and `now`, capped at
    /// the bucket's capacity, and record `now` as the new refill time.
    fn refill(&mut self, now: LocalTime) {
        // Nothing is earned if the clock stood still or stepped backwards.
        // Re-anchor on `now` so refills resume immediately, and never ask for
        // the duration since a later instant.
        if now <= self.refilled_at {
            self.refilled_at = now;
            return;
        }
        let elapsed = now.duration_since(self.refilled_at);
        // Millisecond precision matters: truncating to whole seconds turns
        // every sub-second interval into zero tokens while `refilled_at` still
        // moves forward, so a bucket polled more than once per second would
        // never refill. Whole-second intervals give the same result as before.
        let seconds = elapsed.as_millis() as f64 / 1000.0;
        let earned = seconds * self.rate;

        self.tokens = (self.tokens + earned).min(self.capacity);
        self.refilled_at = now;
    }

    /// Refill the bucket up to `now`, then try to spend one token.
    ///
    /// Returns `true` if a token was available and has been consumed, `false`
    /// if the bucket holds less than one token.
    fn take(&mut self, now: LocalTime) -> bool {
        self.refill(now);

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
#[cfg(test)]
// The original assertion style compared against boolean literals and cloned
// host names freely; the allowances are kept so the lint configuration of this
// module is unchanged.
#[allow(clippy::bool_assert_comparison, clippy::redundant_clone)]
mod test {
    use radicle::test::arbitrary;

    use super::*;

    // Lets tests describe a limit as a plain `(capacity, rate)` pair.
    impl AsTokens for (usize, f64) {
        fn capacity(&self) -> usize {
            self.0
        }

        fn rate(&self) -> f64 {
            self.1
        }
    }

    /// A single host drains its burst, then regains one token every five seconds.
    #[test]
    fn test_limiter_refill() {
        let mut limiter = RateLimiter::default();
        let limits: (usize, f64) = (3, 0.2);
        let host = HostName::Dns(String::from("seed.radicle.example.com"));
        let node = arbitrary::r#gen::<NodeId>(1);
        let node = Some(&node);

        // (seconds since start, expected return value of `limit`)
        let timeline: &[(u64, bool)] = &[
            // Initial burst of three.
            (0, false),
            (1, false),
            (2, false),
            // Out of tokens until the next one has accumulated.
            (3, true),
            (4, true),
            (5, false),
            (6, true),
            (7, true),
            (8, true),
            (9, true),
            (10, false),
            (11, true),
            (12, true),
            (13, true),
            (14, true),
            (15, false),
            (16, true),
            // After a long pause the bucket is full again, but not beyond capacity.
            (60, false),
            (60, false),
            (60, false),
            (60, true),
        ];

        for &(secs, limited) in timeline {
            assert_eq!(
                limiter.limit(host.clone(), node, &limits, LocalTime::from_secs(secs)),
                limited,
                "at t={secs}s"
            );
        }
    }

    /// Two hosts are tracked in independent buckets.
    #[test]
    #[rustfmt::skip]
    fn test_limiter_multi() {
        let limits: (usize, f64) = (1, 1.0);
        let node = arbitrary::r#gen::<NodeId>(1);
        let node = Some(&node);
        let mut limiter = RateLimiter::default();
        let com = HostName::Dns(String::from("seed.radicle.example.com"));
        let net = HostName::Dns(String::from("seed.radicle.example.net"));

        // (host, seconds since start, expected return value of `limit`)
        let timeline: &[(&HostName, u64, bool)] = &[
            (&com, 0, false),
            (&com, 0, true),
            (&net, 0, false),
            (&net, 0, true),
            (&com, 1, false),
            (&com, 1, true),
            (&net, 1, false),
            (&net, 1, true),
        ];

        for &(host, secs, limited) in timeline {
            assert_eq!(
                limiter.limit(host.clone(), node, &limits, LocalTime::from_secs(secs)),
                limited,
                "at t={secs}s"
            );
        }
    }

    /// Each host keeps the capacity and rate it was first seen with.
    #[test]
    #[rustfmt::skip]
    fn test_limiter_different_rates() {
        let slow: (usize, f64) = (1, 1.0);
        let fast: (usize, f64) = (2, 2.0);
        let node = arbitrary::r#gen::<NodeId>(1);
        let node = Some(&node);
        let mut limiter = RateLimiter::default();
        let com = HostName::Dns(String::from("seed.radicle.example.com"));
        let net = HostName::Dns(String::from("seed.radicle.example.net"));

        // (host, limits, seconds since start, expected return value of `limit`)
        let timeline: &[(&HostName, &(usize, f64), u64, bool)] = &[
            (&com, &slow, 0, false),
            (&com, &slow, 0, true),
            (&net, &fast, 0, false),
            (&net, &fast, 0, false),
            (&net, &fast, 0, true),
            (&com, &slow, 1, false),
            (&com, &slow, 1, true),
            (&net, &fast, 1, false),
            (&net, &fast, 1, false),
            (&net, &fast, 1, true),
        ];

        for &(host, limits, secs, limited) in timeline {
            assert_eq!(
                limiter.limit(host.clone(), node, limits, LocalTime::from_secs(secs)),
                limited,
                "at t={secs}s"
            );
        }
    }
}