use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
/// Rate-limit bookkeeping keyed by client address.
///
/// The value is a point in time expressed as nanoseconds elapsed since
/// `IpRate::origin`: `IpRate::is_ok` initialises it to "now" and pushes it
/// forward by the cost of every message it is asked about.
type Map = HashMap<Arc<std::net::Ipv6Addr>, u64>;
/// Per-IP byte rate limiter that also fronts an IP deny list.
///
/// Each address maps to a timestamp that is advanced by `bytes * limit`
/// nanoseconds per message; an address fails the check once that timestamp
/// is more than `burst` nanoseconds ahead of the current time.
pub struct IpRate {
/// Reference instant captured at construction; every stored timestamp is
/// measured in nanoseconds elapsed since this instant.
origin: tokio::time::Instant,
/// Per-address timestamps, shared behind a standard (blocking) mutex.
map: Arc<Mutex<Map>>,
/// When `true`, `is_ok` returns `true` immediately without touching `map`.
disabled: bool,
/// Cost charged per byte, in nanoseconds (from `Config::limit_ip_byte_nanos`).
limit: u64,
/// How far (in nanoseconds) an address's timestamp may run ahead of the
/// current time before `is_ok` reports a violation.
burst: u64,
/// Deny list queried by `is_blocked` and updated by `is_ok` on a violation.
ip_deny: crate::ip_deny::IpDeny,
}
impl IpRate {
    /// Entries whose timestamp lies at least this many nanoseconds in the
    /// past are dropped by [`IpRate::prune`] (10 seconds).
    const PRUNE_AFTER_NANOS: u64 = 10_000_000_000;

    /// Builds a rate limiter from `config`.
    ///
    /// The burst allowance is `limit_ip_byte_burst` bytes converted to
    /// nanoseconds using the per-byte cost. The product saturates at
    /// `u64::MAX` instead of wrapping (release) or panicking (debug) on
    /// extreme configuration values.
    pub fn new(config: Arc<crate::Config>) -> Self {
        let limit = config.limit_ip_byte_nanos() as u64;
        let burst = (config.limit_ip_byte_burst as u64).saturating_mul(limit);
        Self {
            origin: tokio::time::Instant::now(),
            map: Arc::new(Mutex::new(HashMap::new())),
            disabled: config.disable_rate_limiting,
            limit,
            burst,
            ip_deny: crate::ip_deny::IpDeny::new(config),
        }
    }

    /// Nanoseconds elapsed since `origin`, saturating at `u64::MAX` rather
    /// than silently truncating the `u128` returned by `as_nanos`.
    fn now_nanos(&self) -> u64 {
        u64::try_from(self.origin.elapsed().as_nanos()).unwrap_or(u64::MAX)
    }

    /// Removes every entry whose timestamp is at least
    /// `PRUNE_AFTER_NANOS` in the past.
    ///
    /// Entries whose timestamp is still in the future (the address is
    /// currently consuming its allowance) are always kept.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn prune(&self) {
        let now = self.now_nanos();
        // `saturating_sub` yields 0 for timestamps at or after `now`, which
        // keeps them, exactly like an explicit `now <= cur` branch would.
        self.map
            .lock()
            .unwrap()
            .retain(|_, cur| now.saturating_sub(*cur) < Self::PRUNE_AFTER_NANOS);
    }

    /// Returns whether `ip` is currently on the deny list.
    pub async fn is_blocked(&self, ip: &Arc<std::net::Ipv6Addr>) -> bool {
        self.ip_deny.is_blocked(ip).await
    }

    /// Charges `bytes` against the allowance of `ip` and reports whether the
    /// address is still within its limit.
    ///
    /// Returns `true` unconditionally when rate limiting is disabled. When
    /// the limit is exceeded the address is added to the deny list and
    /// `false` is returned; the charge is recorded either way.
    ///
    /// All arithmetic saturates: an absurdly large message can no longer wrap
    /// around to a small charge (bypassing the limit in release builds) or
    /// panic while the map lock is held (poisoning it in debug builds).
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub async fn is_ok(&self, ip: &Arc<std::net::Ipv6Addr>, bytes: usize) -> bool {
        if self.disabled {
            return true;
        }

        let rate_add = (bytes as u64).saturating_mul(self.limit);
        let now = self.now_nanos();

        // Keep the guard inside this block so the std mutex is released
        // before the `.await` on the deny list below.
        let is_ok = {
            let mut lock = self.map.lock().unwrap();
            let e = lock.entry(ip.clone()).or_insert(now);
            // Unused allowance does not accumulate: start from `now` if the
            // stored timestamp has already fallen into the past.
            let cur = std::cmp::max(*e, now).saturating_add(rate_add);
            *e = cur;
            // `cur >= now` by construction, so this cannot underflow.
            cur - now <= self.burst
        };

        if !is_ok {
            tracing::info!("IP rate limit exceeded for {ip}, blocking");
            self.ip_deny.block(ip).await;
        }

        is_ok
    }
}
/// Spawns a background task that calls [`IpRate::prune`] after every
/// five-second sleep.
///
/// The task only keeps a weak handle to the limiter, so it does not keep the
/// limiter alive: once every strong reference is gone, the next wake-up
/// fails to upgrade the handle and the task finishes.
pub fn spawn_prune_task(ip_rate: Arc<IpRate>) -> JoinHandle<()> {
    let weak = Arc::downgrade(&ip_rate);
    tokio::task::spawn(async move {
        loop {
            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
            match weak.upgrade() {
                // The temporary strong reference is released right after
                // pruning, before the next sleep starts.
                Some(limiter) => limiter.prune(),
                None => break,
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a limiter with hand-picked `limit` / `burst` values; the deny
    /// list is created from a default config.
    fn test_new(limit: u64, burst: u64) -> IpRate {
        IpRate {
            origin: tokio::time::Instant::now(),
            map: Arc::new(Mutex::new(Map::new())),
            disabled: false,
            limit,
            burst,
            ip_deny: crate::ip_deny::IpDeny::new(Arc::new(crate::Config::default())),
        }
    }

    /// The single client address used by every test.
    fn addr1() -> Arc<std::net::Ipv6Addr> {
        Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1))
    }

    /// Number of addresses currently tracked by `rate`.
    fn tracked(rate: &IpRate) -> usize {
        rate.map.lock().unwrap().len()
    }

    /// Moves the paused tokio clock forward by `nanos` nanoseconds.
    async fn advance_nanos(nanos: u64) {
        tokio::time::advance(Duration::from_nanos(nanos)).await;
    }

    /// Moves the paused tokio clock forward by `secs` seconds.
    async fn advance_secs(secs: u64) {
        tokio::time::advance(Duration::from_secs(secs)).await;
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_one_to_one() {
        let ip = addr1();
        let limiter = test_new(1, 1);

        // One byte per nanosecond matches the limit exactly, so each call passes.
        for _ in 0..10 {
            advance_nanos(1).await;
            assert!(limiter.is_ok(&ip, 1).await);
        }
        // A second byte within the same nanosecond overshoots the burst.
        assert!(!limiter.is_ok(&ip, 1).await);

        // The entry's timestamp is still ahead of the clock: kept.
        advance_nanos(1).await;
        limiter.prune();
        assert_eq!(1, tracked(&limiter));

        // One nanosecond short of the ten-second cut-off: still kept.
        advance_secs(10).await;
        limiter.prune();
        assert_eq!(1, tracked(&limiter));

        // Exactly at the cut-off: removed.
        advance_nanos(1).await;
        limiter.prune();
        assert_eq!(0, tracked(&limiter));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_burst() {
        let ip = addr1();
        let limiter = test_new(1, 5);

        // Five bytes fit in the burst without any time passing.
        for _ in 0..5 {
            assert!(limiter.is_ok(&ip, 1).await);
        }
        assert!(!limiter.is_ok(&ip, 1).await);

        // The rejected byte was still charged, so two nanoseconds must pass
        // before one more byte fits.
        advance_nanos(2).await;
        assert!(limiter.is_ok(&ip, 1).await);

        // One nanosecond short of the ten-second cut-off: kept.
        advance_secs(10).await;
        advance_nanos(4).await;
        limiter.prune();
        assert_eq!(1, tracked(&limiter));

        // Exactly at the cut-off: removed.
        advance_nanos(1).await;
        limiter.prune();
        assert_eq!(0, tracked(&limiter));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_limit_mult() {
        let ip = addr1();
        let limiter = test_new(3, 13);

        // Each 2-byte message costs 6 ns: 6 and 12 fit within 13, 18 does not.
        assert!(limiter.is_ok(&ip, 2).await);
        assert!(limiter.is_ok(&ip, 2).await);
        assert!(!limiter.is_ok(&ip, 2).await);

        // After a long pause the allowance starts over from the current time.
        advance_secs(10).await;
        assert!(limiter.is_ok(&ip, 2).await);
        assert!(limiter.is_ok(&ip, 2).await);
        assert!(!limiter.is_ok(&ip, 2).await);
    }
}