use std::collections::HashMap;
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use crate::config::{Backend, JailConfig};
use crate::error::{Error, Result};
use crate::executor_iptables::IptablesBackend;
use crate::executor_nftables::NftablesBackend;
use crate::executor_script::ScriptBackend;
use crate::state::{self, BanRecord, StateSnapshot};
/// A request handled by the firewall executor task (see [`run`]).
///
/// Requests that carry a `done` channel get the backend's result sent back
/// through it; the others are fire-and-forget and only logged on failure.
#[derive(Debug)]
pub enum FirewallCmd {
    /// Ban `ip` in the jail `jail_id`.
    ///
    /// `banned_at` and `expires_at` travel with the request, but
    /// [`FirewallBackend::ban`] does not take them.
    /// NOTE(review): presumably Unix-second timestamps — confirm with the sender.
    Ban {
        ip: IpAddr,
        jail_id: String,
        banned_at: i64,
        expires_at: Option<i64>,
    },
    /// Remove the ban on `ip` from the jail `jail_id`.
    Unban {
        ip: IpAddr,
        jail_id: String,
    },
    /// Prepare the firewall for `jail_id` using the given `ports` and `protocol`.
    ///
    /// `done` receives the backend's result. It receives `Ok(())` when no
    /// backend is configured for the jail.
    InitJail {
        jail_id: String,
        ports: Vec<String>,
        protocol: String,
        done: oneshot::Sender<Result<()>>,
    },
    /// Undo the firewall setup for `jail_id`.
    ///
    /// `done` receives the backend's result. It receives `Ok(())` when no
    /// backend is configured for the jail.
    TeardownJail {
        jail_id: String,
        done: oneshot::Sender<Result<()>>,
    },
    /// Persist `snapshot` to the executor's state file.
    SaveState {
        snapshot: StateSnapshot,
    },
}
/// A firewall implementation that bans and unbans addresses per jail.
///
/// Implementations must be `Send + Sync`, because the executor holds them in a
/// shared map across `.await` points.
#[async_trait::async_trait]
pub trait FirewallBackend: Send + Sync {
    /// Short human-readable backend name, used in the executor's startup log.
    fn name(&self) -> &str;

    /// Set up whatever the backend needs for `jail` before bans are applied.
    ///
    /// NOTE(review): the exact effect of `ports`/`protocol` is backend-specific —
    /// confirm in each implementation.
    async fn init(&self, jail: &str, ports: &[String], protocol: &str) -> Result<()>;

    /// Undo the setup performed by [`FirewallBackend::init`] for `jail`.
    async fn teardown(&self, jail: &str) -> Result<()>;

    /// Block `ip` within `jail`.
    async fn ban(&self, ip: &IpAddr, jail: &str) -> Result<()>;

    /// Stop blocking `ip` within `jail`.
    async fn unban(&self, ip: &IpAddr, jail: &str) -> Result<()>;

    /// Report whether `ip` is currently blocked within `jail`.
    async fn is_banned(&self, ip: &IpAddr, jail: &str) -> Result<bool>;
}
const SYSTEM_DIRS: &[&str] = &["/usr/sbin", "/sbin", "/usr/bin", "/bin"];
pub fn resolve_binary(name: &str) -> Result<PathBuf> {
for dir in SYSTEM_DIRS {
let path = Path::new(dir).join(name);
if path.exists() {
return Ok(path);
}
}
Err(Error::firewall(format!(
"binary '{name}' not found in {}",
SYSTEM_DIRS.join(", ")
)))
}
/// Build the firewall backend selected by `backend`.
///
/// The nftables and iptables variants locate their binaries with
/// [`resolve_binary`]. The script variant takes copies of its configured
/// ban/unban commands.
///
/// # Errors
/// Propagates the error from [`resolve_binary`] when `nft`, `iptables` or
/// `ip6tables` cannot be found. For iptables, `iptables` is looked up first.
pub fn create_backend(backend: &Backend) -> Result<Box<dyn FirewallBackend>> {
    let built: Box<dyn FirewallBackend> = match backend {
        Backend::Nftables => Box::new(NftablesBackend::new(resolve_binary("nft")?)),
        Backend::Iptables => {
            let v4_binary = resolve_binary("iptables")?;
            let v6_binary = resolve_binary("ip6tables")?;
            Box::new(IptablesBackend::new(v4_binary, v6_binary))
        }
        Backend::Script { ban_cmd, unban_cmd } => {
            Box::new(ScriptBackend::new(ban_cmd.clone(), unban_cmd.clone()))
        }
    };
    Ok(built)
}
/// Create one backend per enabled jail, keyed by the jail's name.
///
/// Jails whose config has `enabled == false` are skipped.
///
/// # Errors
/// Returns the first error produced by [`create_backend`]. No partial map is
/// returned.
pub fn create_backends(
    jails: &HashMap<String, JailConfig>,
) -> Result<HashMap<String, Box<dyn FirewallBackend>>> {
    let mut by_jail = HashMap::new();
    for (jail_name, jail_cfg) in jails {
        if !jail_cfg.enabled {
            continue;
        }
        let backend = create_backend(&jail_cfg.backend)?;
        by_jail.insert(jail_name.clone(), backend);
    }
    Ok(by_jail)
}
pub async fn run(
mut rx: mpsc::Receiver<FirewallCmd>,
backends: HashMap<String, Box<dyn FirewallBackend>>,
state_path: PathBuf,
cancel: CancellationToken,
) {
let names: Vec<_> = backends
.iter()
.map(|(k, v)| format!("{k}={}", v.name()))
.collect();
info!(backends = ?names, "executor started");
loop {
tokio::select! {
_ = cancel.cancelled() => {
info!("executor shutting down");
break;
}
cmd = rx.recv() => {
match cmd {
Some(FirewallCmd::Ban { ip, jail_id, banned_at, expires_at }) => {
info!(%ip, jail = %jail_id, "banning");
if let Some(backend) = backends.get(&jail_id) {
if let Err(e) = backend.ban(&ip, &jail_id).await {
error!(%ip, jail = %jail_id, error = %e, "ban failed");
}
} else {
warn!(%ip, jail = %jail_id, "no backend for jail");
}
let _ = (banned_at, expires_at);
}
Some(FirewallCmd::Unban { ip, jail_id }) => {
info!(%ip, jail = %jail_id, "unbanning");
if let Some(backend) = backends.get(&jail_id) {
if let Err(e) = backend.unban(&ip, &jail_id).await {
warn!(%ip, jail = %jail_id, error = %e, "unban failed");
}
} else {
warn!(%ip, jail = %jail_id, "no backend for jail");
}
}
Some(FirewallCmd::InitJail { jail_id, ports, protocol, done }) => {
info!(jail = %jail_id, "initializing firewall");
let result = if let Some(backend) = backends.get(&jail_id) {
backend.init(&jail_id, &ports, &protocol).await
} else {
warn!(jail = %jail_id, "no backend for jail init");
Ok(())
};
if let Err(ref e) = result {
error!(jail = %jail_id, error = %e, "firewall init failed");
}
let _ = done.send(result);
}
Some(FirewallCmd::TeardownJail { jail_id, done }) => {
info!(jail = %jail_id, "tearing down firewall");
let result = if let Some(backend) = backends.get(&jail_id) {
backend.teardown(&jail_id).await
} else {
Ok(())
};
if let Err(ref e) = result {
warn!(jail = %jail_id, error = %e, "firewall teardown failed");
}
let _ = done.send(result);
}
Some(FirewallCmd::SaveState { snapshot }) => {
let ban_count = snapshot.bans.len();
if let Err(e) = state::save(&state_path, &snapshot) {
error!(error = %e, "state save failed");
} else {
info!(bans = ban_count, "state saved");
}
}
None => {
info!("executor channel closed");
break;
}
}
}
}
}
}
pub async fn restore_bans(
bans: &[BanRecord],
backends: &HashMap<String, Box<dyn FirewallBackend>>,
now: i64,
) -> Vec<BanRecord> {
let mut restored = Vec::new();
for ban in bans {
if let Some(expires) = ban.expires_at
&& expires <= now
{
continue;
}
let backend = match backends.get(&ban.jail_id) {
Some(b) => b,
None => {
warn!(ip = %ban.ip, jail = %ban.jail_id, "no backend for jail, skipping restore");
continue;
}
};
if let Err(e) = backend.ban(&ban.ip, &ban.jail_id).await {
warn!(ip = %ban.ip, jail = %ban.jail_id, error = %e, "failed to restore ban");
continue;
}
restored.push(ban.clone());
}
restored
}