use std::sync::Arc;
use crate::tools::time::TimeMillis;
/// Accumulated score paired with the time it was last written.
///
/// [`DdosScore::increment`] adds points and [`DdosScore::current`] reads the
/// value back; both first subtract `decay_per_second * elapsed seconds` from
/// the stored score and clamp the result at zero.
pub struct DdosScore {
/// Score as it stood at `last_updated`, before any later decay is applied.
score: f64,
/// Timestamp passed to the most recent `increment`; `None` until the first one.
last_updated: Option<TimeMillis>,
}
impl Default for DdosScore {
fn default() -> Self {
Self::new()
}
}
impl DdosScore {
    /// Creates a score of zero with no recorded update time.
    pub fn new() -> Self {
        Self {
            score: 0.0,
            last_updated: None,
        }
    }

    /// Seconds between the last recorded update and `now`.
    ///
    /// Yields `0.0` when nothing has been recorded yet. The millisecond
    /// difference is taken with `saturating_sub` and floored at zero, so a
    /// `now` that is not later than the last update also yields `0.0`.
    fn elapsed_secs(&self, now: TimeMillis) -> f64 {
        self.last_updated.map_or(0.0, |prev| {
            let elapsed_ms = now.0.saturating_sub(prev.0).max(0);
            (elapsed_ms as f64) / 1000.0
        })
    }

    /// Applies decay up to `now`, adds `points`, and stores the result.
    ///
    /// The stored timestamp is overwritten with `now` unconditionally, even
    /// when `now` is earlier than the previous one.
    ///
    /// Returns the new stored score.
    pub fn increment(&mut self, points: f64, decay_per_second: f64, now: TimeMillis) -> f64 {
        // Decay first so `points` are added on top of the already-clamped value.
        let decayed = self.current(decay_per_second, now);
        self.score = decayed + points;
        self.last_updated = Some(now);
        self.score
    }

    /// Returns the score as seen at `now` without modifying anything.
    ///
    /// The stored score is reduced by `decay_per_second` for every elapsed
    /// second since the last update and never drops below `0.0`.
    pub fn current(&self, decay_per_second: f64, now: TimeMillis) -> f64 {
        let drained = decay_per_second * self.elapsed_secs(now);
        (self.score - drained).max(0.0)
    }
}
/// Handle pairing an IP string with the [`DdosProtection`] backend it was
/// acquired from.
///
/// [`DdosConnectionGuard::try_new`] builds one only when the backend's
/// `try_acquire_connection` returns `true`; the `Drop` impl calls
/// `release_connection` with the same IP.
pub struct DdosConnectionGuard {
/// IP string passed to every backend call made through this guard.
ip: String,
/// Shared backend that receives those calls.
ddos: Arc<dyn DdosProtection>,
}
impl DdosConnectionGuard {
    /// Asks `ddos` to admit a connection for `ip`.
    ///
    /// Returns the guard when `try_acquire_connection` answers `true`, and
    /// `None` otherwise (in which case nothing needs to be released).
    pub fn try_new(ddos: Arc<dyn DdosProtection>, ip: impl Into<String>) -> Option<Self> {
        let ip = ip.into();
        let acquired = ddos.try_acquire_connection(&ip);
        acquired.then(|| Self { ip, ddos })
    }

    /// The IP string this guard was created with.
    pub fn ip(&self) -> &str {
        self.ip.as_str()
    }

    /// Forwards to the backend's `allow_request` for this guard's IP and
    /// returns its answer unchanged.
    pub fn allow_request(&self) -> bool {
        let backend = &self.ddos;
        backend.allow_request(&self.ip)
    }

    /// Forwards to the backend's `report_bad_request` for this guard's IP.
    pub fn report_bad_request(&self) {
        let backend = &self.ddos;
        backend.report_bad_request(&self.ip)
    }
}
impl Drop for DdosConnectionGuard {
    /// Calls `release_connection` on the backend with this guard's IP,
    /// mirroring the `try_acquire_connection` call made in `try_new`.
    fn drop(&mut self) {
        let ip: &str = self.ip.as_str();
        self.ddos.release_connection(ip);
    }
}
/// Backend interface that [`DdosConnectionGuard`] delegates to; every method
/// identifies the client by its IP as a string.
///
/// Requires `Send + Sync`; the guard holds implementations as
/// `Arc<dyn DdosProtection>`.
pub trait DdosProtection: Send + Sync {
/// Returns the implementation's verdict for a request from `ip`
/// (presumably `true` means "allow" — confirm against implementors).
fn allow_request(&self, ip: &str) -> bool;
/// Notifies the implementation of a bad request attributed to `ip`.
// NOTE(review): how this affects later `allow_request` results is not visible here.
fn report_bad_request(&self, ip: &str);
/// Asks for a connection from `ip` to be admitted.
/// `DdosConnectionGuard::try_new` builds a guard only when this returns `true`.
fn try_acquire_connection(&self, ip: &str) -> bool;
/// Called from `DdosConnectionGuard`'s `Drop` impl with the IP the guard
/// was created with.
fn release_connection(&self, ip: &str);
}
#[cfg(test)]
mod tests {
    use super::DdosScore;
    use crate::tools::time::TimeMillis;

    /// With no previous timestamp there is nothing to decay, whatever the rate.
    #[test]
    fn fresh_score_has_no_decay_on_first_increment() {
        let mut fresh = DdosScore::new();
        let after_first = fresh.increment(5.0, 1000.0, TimeMillis(1_000_000));
        assert_eq!(after_first, 5.0);
    }

    /// Decay is proportional to elapsed time and clamps at zero.
    #[test]
    fn score_decays_with_elapsed_time() {
        const RATE: f64 = 2.0;
        let mut tracked = DdosScore::new();

        let initial = tracked.increment(10.0, RATE, TimeMillis(0));
        assert_eq!(initial, 10.0);

        // 3 s at 2 points/s drains 6 of the 10 points.
        assert_eq!(tracked.current(RATE, TimeMillis(3_000)), 4.0);
        // Reading at the update instant itself sees the full score.
        assert_eq!(tracked.current(RATE, TimeMillis(0)), 10.0);
        // After 10 s the old score is fully drained before the new point is added.
        assert_eq!(tracked.increment(1.0, RATE, TimeMillis(10_000)), 1.0);
    }

    /// A decay rate of zero leaves the score untouched no matter how long passes.
    #[test]
    fn zero_decay_never_drains() {
        let mut pinned = DdosScore::new();
        pinned.increment(3.0, 0.0, TimeMillis(0));
        let much_later = pinned.current(0.0, TimeMillis(1_000_000_000));
        assert_eq!(much_later, 3.0);
    }
}