use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::date::DateFormat;
use crate::duration;
use crate::error::{Error, Result};
/// Top-level configuration document.
///
/// `global` has no serde default, so the `[global]` table must be present.
/// `logging` and `jail` fall back to their `Default` values when absent.
/// Instances built through `Config::from_file` / `Config::parse` have been
/// validated (at least one jail, at least one enabled jail, per-jail checks).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Daemon-wide settings (the `[global]` table).
    pub global: GlobalConfig,
    /// Logging settings (the `[logging]` table); every field is optional.
    #[serde(default)]
    pub logging: LoggingConfig,
    /// Jails keyed by jail name (the `[jail.<name>]` tables).
    #[serde(default)]
    pub jail: HashMap<String, JailConfig>,
}
/// Logging settings (the `[logging]` table).
///
/// Every field is optional and defaults to `None`. How each value is
/// interpreted is decided by the consumer of this struct, not here.
///
/// `Debug` is implemented by hand so that `api_key` is never printed: the
/// derived impl would emit the secret verbatim, including through
/// `Config`'s derived `Debug`. `Serialize` still writes `api_key` as-is.
#[derive(Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct LoggingConfig {
    pub destination: Option<String>,
    pub endpoint: Option<String>,
    /// Credential value; redacted in `Debug` output.
    pub api_key: Option<String>,
    pub level: Option<String>,
    pub service: Option<String>,
}

impl std::fmt::Debug for LoggingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoggingConfig")
            .field("destination", &self.destination)
            .field("endpoint", &self.endpoint)
            // Show only whether a key is set, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("level", &self.level)
            .field("service", &self.service)
            .finish()
    }
}
/// Daemon-wide settings (the `[global]` table).
///
/// Every field has a serde default, but this type does not derive `Default`
/// and `Config.global` has no `serde(default)`. The table itself must
/// therefore appear in the TOML, even if it is empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalConfig {
    /// Defaults to `/var/lib/fail2ban-rs/state.bin`.
    #[serde(default = "default_state_file")]
    pub state_file: PathBuf,
    /// Defaults to `/var/run/fail2ban-rs/fail2ban-rs.sock`.
    #[serde(default = "default_socket_path")]
    pub socket_path: PathBuf,
    /// Defaults to `"info"`. Not validated in this module.
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Defaults to 1024. NOTE(review): presumably the capacity of an internal
    /// channel; confirm at the use site.
    #[serde(default = "default_channel_size")]
    pub channel_size: usize,
}
/// Settings for one jail (a `[jail.<name>]` table).
///
/// `log_path` and `filter` have no serde default and are required. All other
/// fields may be omitted. The constraints noted below are enforced by
/// `Config::validate_jail` only when the jail is enabled; a disabled jail is
/// checked only for a well-formed name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JailConfig {
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Log file path for this jail. Required.
    pub log_path: PathBuf,
    /// Defaults to `DateFormat::Syslog`.
    #[serde(default = "default_date_format")]
    pub date_format: DateFormat,
    /// Match patterns. Required; must be non-empty, and every pattern must
    /// contain the `<HOST>` placeholder.
    pub filter: Vec<String>,
    /// Defaults to 5; must be > 0. NOTE(review): presumably the number of
    /// matches within `find_time` that triggers a ban.
    #[serde(default = "default_max_retry")]
    pub max_retry: u32,
    /// Parsed by `duration::deserialize_duration`. Defaults to 600; must be > 0.
    /// NOTE(review): units presumably seconds; confirm in `duration`.
    #[serde(
        default = "default_find_time",
        deserialize_with = "duration::deserialize_duration"
    )]
    pub find_time: i64,
    /// Parsed by `duration::deserialize_duration`. Defaults to 3600. Must be
    /// positive, or -1 for a permanent ban.
    #[serde(
        default = "default_ban_time",
        deserialize_with = "duration::deserialize_duration"
    )]
    pub ban_time: i64,
    /// Port numbers as strings; each must parse as a `u16`. Defaults to empty.
    #[serde(default)]
    pub port: Vec<String>,
    /// One of `tcp`, `udp`, `sctp`, `dccp`. Defaults to `tcp`.
    #[serde(default = "default_protocol")]
    pub protocol: String,
    /// Defaults to `false`.
    #[serde(default)]
    pub bantime_increment: bool,
    /// Defaults to 1.0; must be finite and > 0.
    #[serde(default = "default_bantime_factor")]
    pub bantime_factor: f64,
    /// Defaults to empty. Not validated in this module.
    #[serde(default)]
    pub bantime_multipliers: Vec<u32>,
    /// Parsed by `duration::deserialize_duration`. Defaults to 604800.
    /// Not validated in this module.
    #[serde(
        default = "default_bantime_maxtime",
        deserialize_with = "duration::deserialize_duration"
    )]
    pub bantime_maxtime: i64,
    /// Defaults to `LogBackend::File`.
    #[serde(default)]
    pub log_backend: LogBackend,
    /// Defaults to empty. NOTE(review): presumably only meaningful with the
    /// systemd log backend; confirm.
    #[serde(default)]
    pub journalmatch: Vec<String>,
    /// Defaults to `Backend::Nftables`.
    #[serde(default)]
    pub backend: Backend,
    /// Defaults to empty. Not validated in this module.
    #[serde(default)]
    pub ignoreregex: Vec<String>,
    /// Defaults to empty. Not validated in this module.
    #[serde(default)]
    pub ignoreip: Vec<String>,
    /// Defaults to `true`.
    #[serde(default = "default_true")]
    pub ignoreself: bool,
    /// Optional; `None` when absent.
    #[serde(default)]
    pub webhook: Option<String>,
}
/// Ban backend selected per jail (`backend = ...`).
///
/// Variant names are lowercased in TOML: `"nftables"`, `"iptables"`,
/// `"script"`. The default is `Nftables`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Backend {
    #[default]
    Nftables,
    Iptables,
    /// User-supplied commands. Validation requires both to be non-empty after
    /// trimming. NOTE(review): how they are executed is defined elsewhere.
    Script {
        ban_cmd: String,
        unban_cmd: String,
    },
}
/// Source of log input for a jail (`log_backend = ...`).
///
/// Variant names are lowercased in TOML. The default is `File`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogBackend {
    #[default]
    File,
    /// Present only when the crate is built with the `systemd` feature;
    /// otherwise `"systemd"` fails to deserialize.
    #[cfg(feature = "systemd")]
    Systemd,
}
/// Serde default for `GlobalConfig::state_file`.
fn default_state_file() -> PathBuf {
    "/var/lib/fail2ban-rs/state.bin".into()
}

/// Serde default for `GlobalConfig::socket_path`.
fn default_socket_path() -> PathBuf {
    "/var/run/fail2ban-rs/fail2ban-rs.sock".into()
}

/// Serde default for `GlobalConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `GlobalConfig::channel_size`.
fn default_channel_size() -> usize {
    1_024
}

/// Serde default for boolean fields that are on unless explicitly disabled
/// (`JailConfig::enabled`, `JailConfig::ignoreself`).
fn default_true() -> bool {
    true
}
/// Serde default for `JailConfig::date_format`: syslog-style timestamps.
fn default_date_format() -> DateFormat {
    DateFormat::Syslog
}
/// Serde default for `JailConfig::max_retry`.
fn default_max_retry() -> u32 {
    5
}

/// Serde default for `JailConfig::find_time` (10 * 60 = 600).
fn default_find_time() -> i64 {
    10 * 60
}

/// Serde default for `JailConfig::ban_time` (60 * 60 = 3600).
fn default_ban_time() -> i64 {
    60 * 60
}

/// Serde default for `JailConfig::protocol`.
fn default_protocol() -> String {
    String::from("tcp")
}

/// Serde default for `JailConfig::bantime_factor`.
fn default_bantime_factor() -> f64 {
    1.0
}

/// Serde default for `JailConfig::bantime_maxtime` (7 * 24 * 60 * 60 = 604800).
fn default_bantime_maxtime() -> i64 {
    7 * 24 * 60 * 60
}
impl Config {
    /// Loads a configuration from `path`, applies `config.d` overlays, and validates it.
    ///
    /// The file at `path` is parsed as TOML. If the directory containing `path`
    /// has a `config.d` subdirectory, every entry there with a `toml` extension
    /// is parsed and deep-merged over the base document in sorted path order.
    /// Later files therefore override earlier ones on conflicting keys. The
    /// merged document is then deserialized into [`Config`] and validated.
    ///
    /// # Errors
    ///
    /// - `Error::ConfigNotFound` if `path` does not exist.
    /// - An I/O error if `path`, the `config.d` directory, or an overlay file
    ///   cannot be read.
    /// - A config error for TOML syntax errors, deserialization failures, or
    ///   validation failures.
    pub fn from_file(path: &Path) -> Result<Self> {
        let content = std::fs::read_to_string(path).map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                Error::ConfigNotFound {
                    path: path.to_path_buf(),
                }
            } else {
                Error::io(format!("reading config: {}", path.display()), e)
            }
        })?;
        let mut base: toml::Value = content
            .parse()
            .map_err(|e| Error::config(format!("TOML parse error: {e}")))?;
        if let Some(dir) = path.parent() {
            Self::merge_config_d(&mut base, &dir.join("config.d"))?;
        }
        let config: Config = base
            .try_into()
            .map_err(|e| Error::config(format!("config deserialization error: {e}")))?;
        config.validate()?;
        Ok(config)
    }

    /// Deep-merges every `*.toml` file in `config_d` into `base`, in sorted path order.
    ///
    /// Does nothing when `config_d` is not a directory.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the directory or an overlay file cannot be read,
    /// and a config error if an overlay is not valid TOML.
    fn merge_config_d(base: &mut toml::Value, config_d: &Path) -> Result<()> {
        if !config_d.is_dir() {
            return Ok(());
        }
        let mut overlays: Vec<PathBuf> = std::fs::read_dir(config_d)
            .map_err(|e| Error::io(format!("reading {}", config_d.display()), e))?
            // NOTE(review): directory entries that fail to read are skipped
            // silently rather than reported; kept to preserve existing behavior.
            .filter_map(|entry| entry.ok().map(|e| e.path()))
            .filter(|p| p.extension().is_some_and(|ext| ext == "toml"))
            .collect();
        // read_dir order is platform-dependent; sorting makes precedence deterministic.
        overlays.sort();
        for overlay_path in overlays {
            let overlay_content = std::fs::read_to_string(&overlay_path).map_err(|e| {
                Error::io(format!("reading overlay: {}", overlay_path.display()), e)
            })?;
            let overlay: toml::Value = overlay_content.parse().map_err(|e| {
                Error::config(format!(
                    "TOML parse error in {}: {e}",
                    overlay_path.display()
                ))
            })?;
            deep_merge(base, overlay);
        }
        Ok(())
    }

    /// Parses and validates a configuration from a TOML string.
    ///
    /// Unlike [`Config::from_file`], no `config.d` overlays are applied.
    ///
    /// # Errors
    ///
    /// Returns a config error if `content` is not valid TOML, does not match the
    /// [`Config`] schema, or fails validation.
    pub fn parse(content: &str) -> Result<Self> {
        let config: Config =
            toml::from_str(content).map_err(|e| Error::config(format!("TOML parse error: {e}")))?;
        config.validate()?;
        Ok(config)
    }

    /// Checks whole-config invariants, then validates each jail.
    ///
    /// Requires at least one jail, and at least one of them enabled.
    /// Jails are visited in `HashMap` order, so when several jails are invalid,
    /// which one is reported first is unspecified.
    fn validate(&self) -> Result<()> {
        if self.jail.is_empty() {
            return Err(Error::config("no jails defined"));
        }
        if !self.jail.values().any(|j| j.enabled) {
            return Err(Error::config("no enabled jails"));
        }
        for (name, jail) in &self.jail {
            Self::validate_jail(name, jail)?;
        }
        Ok(())
    }

    /// Validates a single jail.
    ///
    /// The name is always checked: 1-64 bytes, ASCII alphanumerics plus `-`/`_`.
    /// All remaining checks apply only to enabled jails.
    fn validate_jail(name: &str, jail: &JailConfig) -> Result<()> {
        if name.is_empty() || name.len() > 64 {
            return Err(Error::config(format!(
                "jail '{name}': name must be 1-64 characters"
            )));
        }
        if !name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
        {
            return Err(Error::config(format!(
                "jail '{name}': name must contain only alphanumeric, hyphen, underscore"
            )));
        }
        if !jail.enabled {
            return Ok(());
        }
        if jail.filter.is_empty() {
            return Err(Error::config(format!("jail '{name}': no filter patterns")));
        }
        if let Some(pattern) = jail.filter.iter().find(|p| !p.contains("<HOST>")) {
            return Err(Error::config(format!(
                "jail '{name}': pattern missing <HOST>: {pattern}"
            )));
        }
        if jail.max_retry == 0 {
            return Err(Error::config(format!(
                "jail '{name}': max_retry must be > 0"
            )));
        }
        if jail.find_time <= 0 {
            return Err(Error::config(format!(
                "jail '{name}': find_time must be > 0"
            )));
        }
        // -1 is the only meaningful negative value (permanent ban). Zero and
        // every other negative value are rejected, matching the error message.
        if jail.ban_time == 0 || jail.ban_time < -1 {
            return Err(Error::config(format!(
                "jail '{name}': ban_time must be > 0 or -1 for permanent"
            )));
        }
        if let Backend::Script {
            ban_cmd,
            unban_cmd,
        } = &jail.backend
        {
            if ban_cmd.trim().is_empty() {
                return Err(Error::config(format!(
                    "jail '{name}': script backend requires non-empty ban_cmd"
                )));
            }
            if unban_cmd.trim().is_empty() {
                return Err(Error::config(format!(
                    "jail '{name}': script backend requires non-empty unban_cmd"
                )));
            }
        }
        if let Some(port) = jail.port.iter().find(|p| p.parse::<u16>().is_err()) {
            return Err(Error::config(format!(
                "jail '{name}': invalid port: {port}"
            )));
        }
        if !["tcp", "udp", "sctp", "dccp"].contains(&jail.protocol.as_str()) {
            return Err(Error::config(format!(
                "jail '{name}': protocol must be tcp, udp, sctp, or dccp"
            )));
        }
        if !jail.bantime_factor.is_finite() || jail.bantime_factor <= 0.0 {
            return Err(Error::config(format!(
                "jail '{name}': bantime_factor must be finite and positive"
            )));
        }
        Ok(())
    }

    /// Iterates over enabled jails as `(name, config)` pairs.
    ///
    /// Iteration follows `HashMap` order, which is unspecified.
    pub fn enabled_jails(&self) -> impl Iterator<Item = (&str, &JailConfig)> {
        self.jail
            .iter()
            .filter(|(_, j)| j.enabled)
            .map(|(name, jail)| (name.as_str(), jail))
    }
}
/// Recursively merges `overlay` into `base`.
///
/// When both values are tables, `overlay` is applied key by key. Keys missing
/// from `base` are inserted as-is, and keys present in both are merged
/// recursively. In every other case, including arrays, `overlay` replaces
/// `base` wholesale, so an overlay array does not append to the base array.
fn deep_merge(base: &mut toml::Value, overlay: toml::Value) {
    match (base, overlay) {
        (toml::Value::Table(base_table), toml::Value::Table(overlay_table)) => {
            for (key, value) in overlay_table {
                match base_table.get_mut(&key) {
                    Some(existing) => deep_merge(existing, value),
                    None => {
                        base_table.insert(key, value);
                    }
                }
            }
        }
        (slot, replacement) => *slot = replacement,
    }
}