use std::collections::HashMap;
use crate::config::JailConfig;
/// Per-jail settings used for ban tracking, copied out of a [`JailConfig`] by
/// [`build_jail_params`].
pub(crate) struct JailParams {
    // ---- offence detection -------------------------------------------------
    /// Copied from `JailConfig::max_retry`. Presumably the number of matches that triggers a
    /// ban; it is not used in this file, so TODO confirm against the tracker.
    pub(crate) max_retry: u32,
    /// Copied from `JailConfig::find_time`. Presumably the look-back window (likely seconds)
    /// for counting matches; TODO confirm.
    pub(crate) find_time: i64,

    // ---- ban duration ------------------------------------------------------
    /// Base ban duration, likely in seconds (TODO confirm). [`calc_ban_time`] returns any
    /// negative base unchanged, and the tests treat `-1` as a permanent ban.
    pub(crate) ban_time: i64,
    /// When `false`, [`calc_ban_time`] performs no escalation and returns the base unchanged.
    pub(crate) bantime_increment: bool,
    /// Extra multiplier applied to every escalated ban duration.
    pub(crate) bantime_factor: f64,
    /// Explicit per-count multipliers. The last entry repeats once the count runs past the end.
    /// When empty, escalation doubles per count, with the exponent clamped to 20.
    pub(crate) bantime_multipliers: Vec<u32>,
    /// Upper bound on an escalated ban duration. Zero or negative disables the cap.
    pub(crate) bantime_maxtime: i64,

    // ---- notification ------------------------------------------------------
    /// Optional webhook target copied from the jail config. Presumably a URL notified on ban
    /// events; not used in this file.
    pub(crate) webhook: Option<String>,
}
/// Computes the effective ban duration for an offender that has been banned `count` times before.
///
/// - Escalation is skipped, and `base` is returned unchanged, when `params.bantime_increment`
///   is off or `base` is negative. The tests use `-1` as a permanent ban.
/// - Otherwise `base` is multiplied by a per-count step and then by `params.bantime_factor`:
///   - With no explicit `bantime_multipliers`, the step is `2^count`, with the exponent clamped
///     to 20.
///   - With explicit multipliers, the step is `bantime_multipliers[count]`. The final entry is
///     reused once `count` runs past the end of the table.
/// - A positive `params.bantime_maxtime` caps the result; zero or negative means no cap.
///
/// The float-to-integer conversion uses `as`, which saturates at the `i64` bounds and maps NaN
/// to 0.
///
/// NOTE(review): a negative or NaN `bantime_factor` makes a positive base come out negative or
/// zero. A negative result is indistinguishable from the permanent-ban sentinel, so confirm the
/// factor is validated upstream.
pub(crate) fn calc_ban_time(base: i64, count: u32, params: &JailParams) -> i64 {
    let escalate = params.bantime_increment && base >= 0;
    if !escalate {
        return base;
    }

    let table = params.bantime_multipliers.as_slice();
    let step = match table {
        // Default schedule: double per prior ban. The clamp keeps the shift (and result) bounded.
        [] => f64::from(1_u32 << count.min(20)),
        // Explicit schedule: past the end of the table, keep using the final entry.
        [.., last] => f64::from(*table.get(count as usize).unwrap_or(last)),
    };

    let scaled = (base as f64 * step * params.bantime_factor) as i64;
    match params.bantime_maxtime {
        cap if cap > 0 => scaled.min(cap),
        _ => scaled,
    }
}
/// Converts every jail's [`JailConfig`] into the [`JailParams`] subset used for ban tracking.
///
/// The returned map has exactly the same keys as `configs`. Scalar settings are copied, and
/// `webhook` and `bantime_multipliers` are cloned, so the result does not borrow from `configs`.
pub(crate) fn build_jail_params(
    configs: &HashMap<String, JailConfig>,
) -> HashMap<String, JailParams> {
    let mut by_jail = HashMap::with_capacity(configs.len());
    for (jail_name, jail_cfg) in configs {
        let jail_params = JailParams {
            max_retry: jail_cfg.max_retry,
            find_time: jail_cfg.find_time,
            ban_time: jail_cfg.ban_time,
            webhook: jail_cfg.webhook.clone(),
            bantime_increment: jail_cfg.bantime_increment,
            bantime_factor: jail_cfg.bantime_factor,
            bantime_multipliers: jail_cfg.bantime_multipliers.clone(),
            bantime_maxtime: jail_cfg.bantime_maxtime,
        };
        by_jail.insert(jail_name.clone(), jail_params);
    }
    by_jail
}
#[cfg(test)]
// These restriction lints are presumably enabled crate-wide. Tests legitimately index into maps,
// unwrap and panic on failure, so they are relaxed for this module only.
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::needless_pass_by_value
)]
mod tests {
    use std::collections::HashMap;

    use crate::config::JailConfig;

    // `super` is always the module this inline test module lives in, so the import survives a
    // file move or rename.
    use super::{JailParams, build_jail_params, calc_ban_time};

    /// Escalation disabled, default factor, no explicit multipliers, one-week cap.
    fn base_params() -> JailParams {
        JailParams {
            max_retry: 3,
            find_time: 600,
            ban_time: 60,
            webhook: None,
            bantime_increment: false,
            bantime_factor: 1.0,
            bantime_multipliers: vec![],
            bantime_maxtime: 604_800,
        }
    }

    #[test]
    fn no_increment_returns_base() {
        let params = base_params();
        assert_eq!(calc_ban_time(60, 0, &params), 60);
        assert_eq!(calc_ban_time(60, 5, &params), 60);
        assert_eq!(calc_ban_time(60, 100, &params), 60);
    }

    #[test]
    fn exponential_escalation() {
        let mut params = base_params();
        params.bantime_increment = true;
        assert_eq!(calc_ban_time(60, 0, &params), 60); // 60 * 2^0
        assert_eq!(calc_ban_time(60, 1, &params), 120); // 60 * 2^1
        assert_eq!(calc_ban_time(60, 2, &params), 240); // 60 * 2^2
        assert_eq!(calc_ban_time(60, 3, &params), 480); // 60 * 2^3
        assert_eq!(calc_ban_time(60, 4, &params), 960); // 60 * 2^4
    }

    #[test]
    fn explicit_multipliers() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_multipliers = vec![1, 2, 4, 8, 16];
        assert_eq!(calc_ban_time(60, 0, &params), 60);
        assert_eq!(calc_ban_time(60, 1, &params), 120);
        assert_eq!(calc_ban_time(60, 2, &params), 240);
        assert_eq!(calc_ban_time(60, 3, &params), 480);
        assert_eq!(calc_ban_time(60, 4, &params), 960);
        // Past the end of the table, the last multiplier is reused.
        assert_eq!(calc_ban_time(60, 5, &params), 960);
        assert_eq!(calc_ban_time(60, 99, &params), 960);
    }

    #[test]
    fn maxtime_cap() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_maxtime = 300;
        assert_eq!(calc_ban_time(60, 3, &params), 300); // 480 capped to 300
        assert_eq!(calc_ban_time(60, 10, &params), 300); // 61440 capped to 300
    }

    #[test]
    fn factor_applied() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_factor = 1.5;
        assert_eq!(calc_ban_time(60, 0, &params), 90); // 60 * 1 * 1.5
        assert_eq!(calc_ban_time(60, 1, &params), 180); // 60 * 2 * 1.5
    }

    #[test]
    fn permanent_ban_bypasses_increment() {
        let mut params = base_params();
        params.bantime_increment = true;
        assert_eq!(calc_ban_time(-1, 0, &params), -1);
        assert_eq!(calc_ban_time(-1, 10, &params), -1);
    }

    #[test]
    fn zero_maxtime_means_no_cap() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_maxtime = 0;
        assert_eq!(calc_ban_time(60, 10, &params), 61440); // 60 * 2^10
    }

    #[test]
    fn permanent_ban_never_downgrades() {
        let mut params = base_params();
        params.bantime_increment = true;
        // Settings that would shrink or cap a finite ban must not affect a permanent one.
        params.bantime_factor = 0.5;
        params.bantime_maxtime = 300;
        for count in 0..20 {
            assert_eq!(
                calc_ban_time(-1, count, &params),
                -1,
                "permanent ban should stay permanent at count={count}"
            );
        }
    }

    #[test]
    fn escalation_sequence_monotonically_increases() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_maxtime = 0; // no cap, so every step is visible
        let mut prev = 0i64;
        for count in 0..15 {
            let t = calc_ban_time(60, count, &params);
            assert!(
                t >= prev,
                "ban time should not decrease: count={count}, prev={prev}, got={t}"
            );
            prev = t;
        }
    }

    #[test]
    fn multiplier_sequence_monotonically_increases() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_multipliers = vec![1, 2, 4, 8, 16, 32];
        params.bantime_maxtime = 0;
        let mut prev = 0i64;
        for count in 0..10 {
            let t = calc_ban_time(60, count, &params);
            assert!(
                t >= prev,
                "ban time should not decrease with multipliers: count={count}"
            );
            prev = t;
        }
    }

    #[test]
    fn high_count_does_not_panic() {
        let mut params = base_params();
        params.bantime_increment = true;
        params.bantime_maxtime = 0;
        let result = calc_ban_time(60, 100, &params);
        // The exponent is clamped to 20.
        assert_eq!(result, 60 * (1 << 20));
    }

    fn test_jail_config() -> JailConfig {
        JailConfig {
            log_path: "/tmp/test.log".into(),
            filter: vec!["from <HOST>".to_string()],
            ..JailConfig::default()
        }
    }

    #[test]
    fn build_params_maps_all_jails() {
        let mut configs = HashMap::new();
        configs.insert("sshd".to_string(), test_jail_config());
        let mut nginx = test_jail_config();
        nginx.max_retry = 10;
        nginx.ban_time = 7200;
        configs.insert("nginx".to_string(), nginx);
        let params = build_jail_params(&configs);
        assert_eq!(params.len(), 2);
        assert!(params.contains_key("sshd"));
        assert!(params.contains_key("nginx"));
    }

    #[test]
    fn build_params_copies_values_correctly() {
        let mut configs = HashMap::new();
        let mut jail = test_jail_config();
        jail.max_retry = 7;
        jail.find_time = 300;
        jail.ban_time = 1800;
        jail.bantime_increment = true;
        jail.bantime_factor = 2.0;
        jail.bantime_multipliers = vec![1, 5, 10];
        jail.bantime_maxtime = 86400;
        jail.webhook = Some("https://example.com/hook".to_string());
        configs.insert("test".to_string(), jail);
        let params = build_jail_params(&configs);
        let p = &params["test"];
        assert_eq!(p.max_retry, 7);
        assert_eq!(p.find_time, 300);
        assert_eq!(p.ban_time, 1800);
        assert!(p.bantime_increment);
        assert!((p.bantime_factor - 2.0).abs() < f64::EPSILON);
        assert_eq!(p.bantime_multipliers, vec![1, 5, 10]);
        assert_eq!(p.bantime_maxtime, 86400);
        assert_eq!(p.webhook, Some("https://example.com/hook".to_string()));
    }

    #[test]
    fn build_params_empty_configs() {
        let configs = HashMap::new();
        let params = build_jail_params(&configs);
        assert!(params.is_empty());
    }

    #[test]
    fn build_params_preserves_defaults() {
        let mut configs = HashMap::new();
        configs.insert("default".to_string(), test_jail_config());
        let params = build_jail_params(&configs);
        let p = &params["default"];
        assert_eq!(p.max_retry, 5);
        assert_eq!(p.find_time, 600);
        assert_eq!(p.ban_time, 3600);
        assert!(!p.bantime_increment);
        assert!((p.bantime_factor - 1.0).abs() < f64::EPSILON);
        assert!(p.bantime_multipliers.is_empty());
        assert_eq!(p.bantime_maxtime, 604_800);
        assert!(p.webhook.is_none());
    }
}