volli-manager 0.1.12

Manager for volli
Documentation
use lru::LruCache;
use std::{collections::VecDeque, net::IpAddr, num::NonZeroUsize};
use tokio::time::{Duration, Instant};

/// Sliding-window limiter allowing at most `max` calls per `interval` for each client IP.
///
/// Per-IP history is kept in a bounded LRU cache, so once it is full the least recently used
/// IP's history is dropped.
pub struct RateLimiter {
    /// Calls permitted per IP inside one window; `0` turns limiting off.
    max: usize,
    /// Length of the sliding window.
    interval: Duration,
    /// Timestamps of each IP's accepted calls, oldest first. Entries older than `interval`
    /// are pruned lazily on the next check for that IP.
    entries: LruCache<IpAddr, VecDeque<Instant>>,
}

impl RateLimiter {
    /// Number of distinct IPs tracked by [`RateLimiter::new`].
    const DEFAULT_CAPACITY: usize = 1024;

    /// Creates a limiter allowing `max` calls per `interval` for each IP, remembering up to
    /// 1024 distinct IPs. A `max` of `0` disables limiting entirely.
    pub fn new(max: usize, interval: Duration) -> Self {
        let capacity = NonZeroUsize::new(Self::DEFAULT_CAPACITY)
            .expect("DEFAULT_CAPACITY is a non-zero constant");
        Self::with_capacity(max, interval, capacity)
    }

    /// Like [`RateLimiter::new`], but remembers up to `capacity` distinct IPs.
    ///
    /// When more than `capacity` IPs are active, the least recently used IP is evicted and its
    /// call history forgotten. A larger capacity makes the limit harder to evade by
    /// interleaving requests from many source addresses.
    pub fn with_capacity(max: usize, interval: Duration, capacity: NonZeroUsize) -> Self {
        Self {
            entries: LruCache::new(capacity),
            max,
            interval,
        }
    }

    /// Registers a call from `ip` and returns `true` if it is within the limit.
    ///
    /// Timestamps older than `interval` are dropped first. The call is then accepted and
    /// recorded only if fewer than `max` calls remain in the window. Rejected calls are not
    /// recorded, so hammering while limited does not extend the block.
    pub fn check(&mut self, ip: IpAddr) -> bool {
        if self.max == 0 {
            return true;
        }
        let now = Instant::now();
        let deque = self.entries.get_or_insert_mut(ip, VecDeque::new);
        // Timestamps are appended in call order, so expired ones are always at the front.
        while let Some(&ts) = deque.front() {
            if now.duration_since(ts) > self.interval {
                deque.pop_front();
            } else {
                break;
            }
        }
        if deque.len() >= self.max {
            return false;
        }
        deque.push_back(now);
        true
    }
}

/// Backoff bookkeeping for one IP.
#[derive(Copy, Clone)]
struct BackoffEntry {
    /// Until this instant, `AuthBackoff::allow` reports the IP as blocked.
    next_allowed: Instant,
    /// Consecutive failures recorded while this entry has existed.
    failures: u32,
}

/// Per-IP exponential backoff applied after failed authentication attempts.
pub struct AuthBackoff {
    /// Lockout imposed after the first failure.
    base_delay: Duration,
    /// Ceiling on any single lockout.
    max_delay: Duration,
    /// Backoff state per IP, held in a bounded LRU cache.
    entries: LruCache<IpAddr, BackoffEntry>,
}

impl AuthBackoff {
    /// Number of distinct IPs tracked by [`AuthBackoff::new`].
    const DEFAULT_CAPACITY: usize = 1024;

    /// Creates a backoff tracker remembering up to 1024 distinct IPs.
    ///
    /// The first failure locks an IP out for `base_delay`. Each further consecutive failure
    /// doubles the lockout, which never exceeds `max_delay`.
    pub fn new(base_delay: Duration, max_delay: Duration) -> Self {
        let capacity = NonZeroUsize::new(Self::DEFAULT_CAPACITY)
            .expect("DEFAULT_CAPACITY is a non-zero constant");
        Self::with_capacity(base_delay, max_delay, capacity)
    }

    /// Like [`AuthBackoff::new`], but remembers up to `capacity` distinct IPs.
    ///
    /// When more than `capacity` IPs have state, the least recently used one is evicted,
    /// which also forgets its failure count.
    pub fn with_capacity(
        base_delay: Duration,
        max_delay: Duration,
        capacity: NonZeroUsize,
    ) -> Self {
        Self {
            entries: LruCache::new(capacity),
            base_delay,
            max_delay,
        }
    }

    /// Returns `true` if `ip` is not currently locked out.
    ///
    /// IPs with no recorded failures (or whose entry was evicted or cleared) are always allowed.
    pub fn allow(&mut self, ip: IpAddr) -> bool {
        match self.entries.get(&ip) {
            Some(entry) => Instant::now() >= entry.next_allowed,
            None => true,
        }
    }

    /// Records a failed attempt from `ip` and returns the lockout now imposed on it.
    ///
    /// The n-th consecutive failure yields `base_delay * 2^(n-1)`, clamped to `max_delay`.
    pub fn record_failure(&mut self, ip: IpAddr) -> Duration {
        let now = Instant::now();
        let failures = self
            .entries
            .get(&ip)
            .map_or(1, |entry| entry.failures.saturating_add(1));
        // Clamp the exponent only as far as needed to keep `1 << exp` inside a `u32`. A tighter
        // clamp would freeze growth at a fixed multiple of `base_delay`, making any `max_delay`
        // above that multiple unreachable.
        let exp = failures.saturating_sub(1).min(u32::BITS - 1);
        let delay = self
            .base_delay
            .checked_mul(1u32 << exp)
            // Overflow means the product is far past any ceiling, so use the ceiling.
            .unwrap_or(self.max_delay)
            .min(self.max_delay);
        self.entries.put(
            ip,
            BackoffEntry {
                failures,
                next_allowed: now + delay,
            },
        );
        delay
    }

    /// Forgets all backoff state for `ip` after a successful authentication.
    pub fn record_success(&mut self, ip: IpAddr) {
        self.entries.pop(&ip);
    }
}

#[cfg(test)]
mod tests {
    use super::{AuthBackoff, RateLimiter};
    use std::net::IpAddr;
    use std::str::FromStr;
    use tokio::time::Duration;

    /// A window far longer than any scheduler stall, so "still blocked" assertions cannot
    /// flake on a loaded CI machine.
    const LONG: Duration = Duration::from_secs(60);

    fn ip(addr: &str) -> IpAddr {
        IpAddr::from_str(addr).unwrap()
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_blocks_over_limit() {
        let mut limiter = RateLimiter::new(2, LONG);
        let ip = ip("127.0.0.1");
        assert!(limiter.check(ip));
        assert!(limiter.check(ip));
        assert!(!limiter.check(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_tracks_ips_independently() {
        let mut limiter = RateLimiter::new(1, LONG);
        let a = ip("10.0.0.1");
        let b = ip("10.0.0.2");
        assert!(limiter.check(a));
        assert!(!limiter.check(a));
        assert!(limiter.check(b));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rate_limiter_zero_max_disables_limiting() {
        let mut limiter = RateLimiter::new(0, LONG);
        let ip = ip("127.0.0.1");
        assert!((0..10).all(|_| limiter.check(ip)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_recovers_after_delay() {
        let mut backoff =
            AuthBackoff::new(Duration::from_millis(100), Duration::from_millis(400));
        let ip = ip("127.0.0.1");
        assert!(backoff.allow(ip));
        let delay = backoff.record_failure(ip);
        assert!(!backoff.allow(ip));
        // Sleep for the lockout actually imposed, plus a margin for timer granularity.
        tokio::time::sleep(delay + Duration::from_millis(10)).await;
        assert!(backoff.allow(ip));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_doubles_up_to_max() {
        let mut backoff = AuthBackoff::new(Duration::from_millis(5), Duration::from_millis(20));
        let ip = ip("127.0.0.1");
        let delays: Vec<Duration> = (0..4).map(|_| backoff.record_failure(ip)).collect();
        assert_eq!(delays, [5, 10, 20, 20].map(Duration::from_millis));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn auth_backoff_clears_on_success() {
        let mut backoff = AuthBackoff::new(LONG, LONG);
        let ip = ip("127.0.0.1");
        backoff.record_failure(ip);
        assert!(!backoff.allow(ip));
        backoff.record_success(ip);
        assert!(backoff.allow(ip));
    }
}