use std::net::Ipv4Addr;
use std::process::Command;
use ipnetwork::Ipv4Network;
use crate::error::{NetError, Result};
/// Prefix of the custom iptables chains managed by this module; the chains are
/// named `<prefix>_PREROUTING`, `<prefix>_POSTROUTING` and `<prefix>_FORWARD`.
const CHAIN_PREFIX: &str = "ARCBOX";
/// Command-line firewall tool used to program NAT and forwarding rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FirewallBackend {
    /// Drive the firewall through the `iptables` binary (the default).
    #[default]
    Iptables,
    /// Drive the firewall through the `nft` binary.
    Nftables,
}

impl FirewallBackend {
    /// Picks a backend by probing for a working `nft` binary.
    ///
    /// Returns [`FirewallBackend::Nftables`] only when `nft --version` can be
    /// spawned *and* exits successfully. A binary that starts but exits with a
    /// failure status is treated as unusable, so the caller falls back to
    /// [`FirewallBackend::Iptables`] instead of selecting a backend that the
    /// subsequent availability check would reject.
    pub fn detect() -> Self {
        let nft_works = Command::new("nft")
            .arg("--version")
            .output()
            .map(|out| out.status.success())
            .unwrap_or(false);
        if nft_works {
            Self::Nftables
        } else {
            Self::Iptables
        }
    }
}
/// Transport protocol selector for port-forwarding rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// TCP only.
    Tcp,
    /// UDP only.
    Udp,
    /// TCP and UDP together.
    Both,
}

impl Protocol {
    /// Returns the lowercase keyword for this selector: `"tcp"`, `"udp"`, or
    /// `"all"` for [`Protocol::Both`].
    fn as_str(&self) -> &'static str {
        match *self {
            Protocol::Both => "all",
            Protocol::Udp => "udp",
            Protocol::Tcp => "tcp",
        }
    }
}
/// Kind of source NAT a [`NatRule`] requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NatType {
/// Emitted as an iptables `MASQUERADE` target / nft `masquerade` statement.
Masquerade,
/// Emitted as an iptables `SNAT` target / nft `snat to <addr>` statement.
Snat,
}
/// Description of a source-NAT rule installed in the postrouting chain.
#[derive(Debug, Clone)]
pub struct NatRule {
/// Source network the rule matches on.
pub source: Ipv4Network,
/// Outgoing interface name the rule matches on.
pub out_interface: String,
/// Whether to masquerade or rewrite to a fixed address.
pub nat_type: NatType,
/// Address to rewrite to; the rule builders require it when `nat_type` is
/// [`NatType::Snat`].
pub snat_addr: Option<Ipv4Addr>,
}
impl NatRule {
    /// Builds a masquerade rule for traffic from `source` leaving through
    /// `out_interface`. No fixed SNAT address is recorded.
    #[must_use]
    pub fn masquerade(source: Ipv4Network, out_interface: impl Into<String>) -> Self {
        let out_interface = out_interface.into();
        Self {
            nat_type: NatType::Masquerade,
            snat_addr: None,
            source,
            out_interface,
        }
    }

    /// Builds an SNAT rule for traffic from `source` leaving through
    /// `out_interface`, rewriting the source address to `snat_addr`.
    #[must_use]
    pub fn snat(
        source: Ipv4Network,
        out_interface: impl Into<String>,
        snat_addr: Ipv4Addr,
    ) -> Self {
        // Start from the masquerade shape and override the two NAT-specific fields.
        let mut rule = Self::masquerade(source, out_interface);
        rule.nat_type = NatType::Snat;
        rule.snat_addr = Some(snat_addr);
        rule
    }
}
/// Description of a destination-NAT (port forwarding) rule installed in the
/// prerouting chain.
#[derive(Debug, Clone)]
pub struct DnatRule {
/// Protocol(s) the rule applies to.
pub protocol: Protocol,
/// Destination port matched on the host side.
pub host_port: u16,
/// Optional incoming interface to restrict the match to.
pub host_interface: Option<String>,
/// Address traffic is redirected to.
pub guest_ip: Ipv4Addr,
/// Port traffic is redirected to.
pub guest_port: u16,
}
impl DnatRule {
    /// Builds a port-forwarding rule from `host_port` to `guest_ip:guest_port`
    /// for `protocol`, not bound to any particular incoming interface.
    #[must_use]
    pub fn new(protocol: Protocol, host_port: u16, guest_ip: Ipv4Addr, guest_port: u16) -> Self {
        Self {
            host_interface: None,
            protocol,
            host_port,
            guest_ip,
            guest_port,
        }
    }

    /// Returns the rule restricted to traffic arriving on `interface`.
    #[must_use]
    pub fn with_interface(self, interface: impl Into<String>) -> Self {
        Self {
            host_interface: Some(interface.into()),
            ..self
        }
    }
}
/// Pair of interface names describing a forwarding path.
#[derive(Debug, Clone)]
pub struct ForwardRule {
/// Name of the incoming interface.
pub in_interface: String,
/// Name of the outgoing interface.
pub out_interface: String,
}
/// Manager that installs and removes NAT / forwarding rules by invoking the
/// `iptables` or `nft` command-line tools.
pub struct LinuxFirewall {
// Tool used for every rule operation.
backend: FirewallBackend,
// Set by `setup`, cleared by `teardown`; only consulted when the value is dropped.
setup_complete: bool,
}
impl LinuxFirewall {
/// Creates a firewall manager using the backend chosen by
/// [`FirewallBackend::detect`].
///
/// # Errors
/// Returns whatever [`LinuxFirewall::with_backend`] returns for the detected
/// backend.
pub fn new() -> Result<Self> {
    let detected = FirewallBackend::detect();
    Self::with_backend(detected)
}
/// Creates a firewall manager for an explicitly chosen `backend`.
///
/// The backend's binary is probed with `--version` before the manager is
/// returned.
///
/// # Errors
/// Returns [`NetError::Firewall`] when the binary cannot be started or when
/// the probe exits with a failure status.
pub fn with_backend(backend: FirewallBackend) -> Result<Self> {
    let tool = match backend {
        FirewallBackend::Nftables => "nft",
        FirewallBackend::Iptables => "iptables",
    };

    let probe = match Command::new(tool).arg("--version").output() {
        Ok(out) => out,
        Err(e) => {
            return Err(NetError::Firewall(format!("{} not available: {}", tool, e)));
        }
    };

    if probe.status.success() {
        tracing::info!("Using firewall backend: {:?}", backend);
        Ok(Self {
            backend,
            setup_complete: false,
        })
    } else {
        Err(NetError::Firewall(format!("{} check failed", tool)))
    }
}
/// Returns the firewall backend this manager operates on.
#[must_use]
pub fn backend(&self) -> FirewallBackend {
self.backend
}
/// Creates the backend-specific chains/table and marks setup as complete.
///
/// # Errors
/// Propagates the backend setup error; in that case the manager is not
/// marked as set up.
pub fn setup(&mut self) -> Result<()> {
    let outcome = match self.backend {
        FirewallBackend::Nftables => self.setup_nftables(),
        FirewallBackend::Iptables => self.setup_iptables(),
    };
    outcome?;
    self.setup_complete = true;
    Ok(())
}
/// Creates and flushes the custom iptables chains, then makes sure each
/// built-in chain jumps to its custom counterpart.
///
/// # Errors
/// Returns the `iptables` error if flushing a chain or inserting a jump rule
/// fails. A failing chain creation (`-N`) is ignored.
fn setup_iptables(&self) -> Result<()> {
    // (table, built-in chain, custom chain)
    let hooks = [
        ("nat", "PREROUTING", format!("{}_PREROUTING", CHAIN_PREFIX)),
        ("nat", "POSTROUTING", format!("{}_POSTROUTING", CHAIN_PREFIX)),
        ("filter", "FORWARD", format!("{}_FORWARD", CHAIN_PREFIX)),
    ];

    // Pass 1: create every custom chain, then flush it. The creation result is
    // discarded — presumably because the chain may already exist.
    for (table, _, custom) in &hooks {
        let (table, custom) = (*table, custom.as_str());
        let _ = self.run_iptables(&["-t", table, "-N", custom]);
        self.run_iptables(&["-t", table, "-F", custom])?;
    }

    // Pass 2: insert the jump at position 1 only when `-C` says it is missing.
    for (table, builtin, custom) in &hooks {
        let (table, builtin, custom) = (*table, *builtin, custom.as_str());
        let already_linked = self
            .run_iptables(&["-t", table, "-C", builtin, "-j", custom])
            .is_ok();
        if !already_linked {
            self.run_iptables(&["-t", table, "-I", builtin, "1", "-j", custom])?;
        }
    }

    tracing::debug!("iptables chains set up");
    Ok(())
}
/// Creates the `arcbox` nftables table with its `prerouting`, `postrouting`
/// and `forward` base chains.
///
/// The commands use `add`, so they are issued unconditionally on every call.
///
/// # Errors
/// Returns the `nft` error if the table or any chain cannot be created.
/// Failures are no longer swallowed: reporting success while the table or
/// chains are missing would make every later rule insertion fail instead.
fn setup_nftables(&self) -> Result<()> {
    self.run_nft(&["add", "table", "ip", "arcbox"])?;

    // Each `{ ... }` chain specification is passed as ONE argument. Split on
    // whitespace, the `-100;` priority becomes a standalone argv word starting
    // with `-`, which — NOTE(review): on older nft releases, confirm — the
    // option parser may treat as command-line flags.
    let chains = [
        ("prerouting", "{ type nat hook prerouting priority -100; }"),
        ("postrouting", "{ type nat hook postrouting priority 100; }"),
        ("forward", "{ type filter hook forward priority 0; }"),
    ];
    for &(name, spec) in chains.iter() {
        self.run_nft(&["add", "chain", "ip", "arcbox", name, spec])?;
    }

    tracing::debug!("nftables table and chains set up");
    Ok(())
}
/// Removes the backend-specific chains/table and clears the setup flag.
///
/// # Errors
/// Propagates the backend teardown error; in that case the setup flag is
/// left unchanged.
pub fn teardown(&mut self) -> Result<()> {
    let outcome = match self.backend {
        FirewallBackend::Nftables => self.teardown_nftables(),
        FirewallBackend::Iptables => self.teardown_iptables(),
    };
    outcome?;
    self.setup_complete = false;
    Ok(())
}
/// Removes the jump rules from the built-in chains, then flushes and deletes
/// the custom chains.
///
/// Every `iptables` failure is ignored, so this always returns `Ok(())`.
fn teardown_iptables(&self) -> Result<()> {
    // (table, built-in chain, custom chain)
    let hooks = [
        ("nat", "PREROUTING", format!("{}_PREROUTING", CHAIN_PREFIX)),
        ("nat", "POSTROUTING", format!("{}_POSTROUTING", CHAIN_PREFIX)),
        ("filter", "FORWARD", format!("{}_FORWARD", CHAIN_PREFIX)),
    ];

    // All jumps are removed first, before any custom chain is deleted.
    for (table, builtin, custom) in &hooks {
        let (table, builtin, custom) = (*table, *builtin, custom.as_str());
        let _ = self.run_iptables(&["-t", table, "-D", builtin, "-j", custom]);
    }

    for (table, _, custom) in &hooks {
        let (table, custom) = (*table, custom.as_str());
        let _ = self.run_iptables(&["-t", table, "-F", custom]);
        let _ = self.run_iptables(&["-t", table, "-X", custom]);
    }

    tracing::debug!("iptables chains torn down");
    Ok(())
}
/// Deletes the `arcbox` nftables table.
///
/// A failing `nft` invocation is ignored, so this always returns `Ok(())`.
fn teardown_nftables(&self) -> Result<()> {
    // Best effort: the outcome of the delete is deliberately discarded.
    self.run_nft(&["delete", "table", "ip", "arcbox"]).ok();
    tracing::debug!("nftables table deleted");
    Ok(())
}
/// Enables IPv4 forwarding by writing `1` to `/proc/sys/net/ipv4/ip_forward`.
///
/// # Errors
/// Returns [`NetError::Firewall`] when the write fails.
pub fn enable_ip_forward(&self) -> Result<()> {
    match std::fs::write("/proc/sys/net/ipv4/ip_forward", "1") {
        Ok(()) => {
            tracing::debug!("IP forwarding enabled");
            Ok(())
        }
        Err(e) => Err(NetError::Firewall(format!(
            "failed to enable IP forwarding: {}",
            e
        ))),
    }
}
/// Disables IPv4 forwarding by writing `0` to `/proc/sys/net/ipv4/ip_forward`.
///
/// # Errors
/// Returns [`NetError::Firewall`] when the write fails.
pub fn disable_ip_forward(&self) -> Result<()> {
    match std::fs::write("/proc/sys/net/ipv4/ip_forward", "0") {
        Ok(()) => {
            tracing::debug!("IP forwarding disabled");
            Ok(())
        }
        Err(e) => Err(NetError::Firewall(format!(
            "failed to disable IP forwarding: {}",
            e
        ))),
    }
}
/// Installs the source-NAT `rule` using the active backend.
///
/// # Errors
/// Propagates the error of the backend-specific implementation.
pub fn add_nat_rule(&mut self, rule: &NatRule) -> Result<()> {
    let backend = self.backend;
    match backend {
        FirewallBackend::Nftables => self.add_nat_rule_nftables(rule),
        FirewallBackend::Iptables => self.add_nat_rule_iptables(rule),
    }
}
fn add_nat_rule_iptables(&self, rule: &NatRule) -> Result<()> {
let chain = format!("{}_POSTROUTING", CHAIN_PREFIX);
let source = rule.source.to_string();
let mut args = vec![
"-t",
"nat",
"-A",
&chain,
"-s",
&source,
"-o",
&rule.out_interface,
];
let target;
match rule.nat_type {
NatType::Masquerade => {
args.extend(&["-j", "MASQUERADE"]);
}
NatType::Snat => {
target = format!(
"--to-source {}",
rule.snat_addr.expect("SNAT requires snat_addr")
);
args.extend(&["-j", "SNAT", &target]);
}
}
self.run_iptables(&args)?;
tracing::debug!("Added NAT rule: {:?}", rule);
Ok(())
}
fn add_nat_rule_nftables(&self, rule: &NatRule) -> Result<()> {
let source = rule.source.to_string();
let rule_str = match rule.nat_type {
NatType::Masquerade => {
format!(
"add rule ip arcbox postrouting ip saddr {} oifname \"{}\" masquerade",
source, rule.out_interface
)
}
NatType::Snat => {
let snat_addr = rule.snat_addr.expect("SNAT requires snat_addr");
format!(
"add rule ip arcbox postrouting ip saddr {} oifname \"{}\" snat to {}",
source, rule.out_interface, snat_addr
)
}
};
self.run_nft(&rule_str.split_whitespace().collect::<Vec<_>>())?;
tracing::debug!("Added NAT rule: {:?}", rule);
Ok(())
}
/// Removes the source-NAT `rule` using the active backend.
///
/// # Errors
/// Propagates the error of the backend-specific implementation.
pub fn remove_nat_rule(&mut self, rule: &NatRule) -> Result<()> {
    let backend = self.backend;
    match backend {
        FirewallBackend::Nftables => self.remove_nat_rule_nftables(rule),
        FirewallBackend::Iptables => self.remove_nat_rule_iptables(rule),
    }
}
fn remove_nat_rule_nftables(&self, rule: &NatRule) -> Result<()> {
let source_pattern = rule.source.to_string();
let nat_pattern = match rule.nat_type {
NatType::Masquerade => "masquerade".to_string(),
NatType::Snat => format!(
"snat to {}",
rule.snat_addr.expect("SNAT requires snat_addr")
),
};
self.delete_nft_rule_by_pattern("postrouting", &[&source_pattern, &nat_pattern])?;
tracing::debug!("Removed NAT rule: {:?}", rule);
Ok(())
}
/// Deletes every rule in `chain` of the `arcbox` table whose listing line
/// contains all of `patterns`.
///
/// Each pattern is compared as a run of whole whitespace-separated tokens
/// rather than as a raw substring, so `dport 80` no longer matches
/// `dport 8080` and `...:80` no longer matches `...:8080`.
///
/// Listing failures with a non-zero exit status and individual delete
/// failures are tolerated (the latter are logged as warnings).
///
/// # Errors
/// Returns [`NetError::Firewall`] only when `nft` cannot be started for the
/// listing.
fn delete_nft_rule_by_pattern(&self, chain: &str, patterns: &[&str]) -> Result<()> {
    let output = Command::new("nft")
        .args(["-a", "list", "chain", "ip", "arcbox", chain])
        .output()
        .map_err(|e| NetError::Firewall(format!("Failed to list nft rules: {}", e)))?;
    if !output.status.success() {
        // Best effort: when the chain cannot be listed there is nothing to delete.
        return Ok(());
    }

    // Tokenize the patterns once instead of once per listed line.
    let needles: Vec<Vec<&str>> = patterns
        .iter()
        .map(|p| p.split_whitespace().collect())
        .collect();

    let stdout = String::from_utf8_lossy(&output.stdout);
    for line in stdout.lines() {
        let tokens: Vec<&str> = line.split_whitespace().collect();
        let is_match = needles
            .iter()
            .all(|needle| Self::nft_tokens_contain(&tokens, needle));
        if !is_match {
            continue;
        }
        let handle = match Self::extract_nft_handle(line) {
            Some(h) => h,
            None => continue,
        };
        tracing::debug!(
            "Found matching rule with handle {}: {}",
            handle,
            line.trim()
        );

        let handle_arg = handle.to_string();
        let result = Command::new("nft")
            .args([
                "delete",
                "rule",
                "ip",
                "arcbox",
                chain,
                "handle",
                handle_arg.as_str(),
            ])
            .output();
        match result {
            Ok(out) if out.status.success() => {
                tracing::debug!("Deleted nft rule with handle {}", handle);
            }
            Ok(out) => {
                let stderr = String::from_utf8_lossy(&out.stderr);
                tracing::warn!("Failed to delete nft rule {}: {}", handle, stderr);
            }
            Err(e) => {
                tracing::warn!("Failed to run nft delete: {}", e);
            }
        }
    }
    Ok(())
}

/// Returns `true` when `needle` occurs in `tokens` as a run of consecutive,
/// whole tokens. An empty needle matches everything.
fn nft_tokens_contain(tokens: &[&str], needle: &[&str]) -> bool {
    // The emptiness check must come first: `windows(0)` panics.
    needle.is_empty() || tokens.windows(needle.len()).any(|window| window == needle)
}
/// Extracts the numeric rule handle from one line of `nft -a list` output.
///
/// Looks for the first whitespace-separated `handle` token and parses the
/// token that follows it. Returns `None` when there is no such token, when it
/// is the last token on the line, or when the following token is not a valid
/// `u64`.
fn extract_nft_handle(line: &str) -> Option<u64> {
    let mut words = line.split_whitespace();
    // Consume tokens up to and including the first `handle` marker...
    words.by_ref().find(|word| *word == "handle")?;
    // ...so the next token, if any, is the handle value.
    words.next().and_then(|value| value.parse().ok())
}
fn remove_nat_rule_iptables(&self, rule: &NatRule) -> Result<()> {
let chain = format!("{}_POSTROUTING", CHAIN_PREFIX);
let source = rule.source.to_string();
let mut args = vec![
"-t",
"nat",
"-D",
&chain,
"-s",
&source,
"-o",
&rule.out_interface,
];
match rule.nat_type {
NatType::Masquerade => {
args.extend(&["-j", "MASQUERADE"]);
}
NatType::Snat => {
let target = format!(
"--to-source {}",
rule.snat_addr.expect("SNAT requires snat_addr")
);
args.extend(&["-j", "SNAT"]);
args.push(&target);
}
}
self.run_iptables(&args)?;
tracing::debug!("Removed NAT rule: {:?}", rule);
Ok(())
}
/// Installs the port-forwarding `rule` using the active backend.
///
/// # Errors
/// Propagates the error of the backend-specific implementation.
pub fn add_dnat_rule(&mut self, rule: &DnatRule) -> Result<()> {
    let backend = self.backend;
    match backend {
        FirewallBackend::Nftables => self.add_dnat_rule_nftables(rule),
        FirewallBackend::Iptables => self.add_dnat_rule_iptables(rule),
    }
}
/// Appends one DNAT rule per protocol to the custom iptables prerouting
/// chain, forwarding `host_port` to `guest_ip:guest_port`.
///
/// # Errors
/// Returns the first `iptables` error; rules added for earlier protocols
/// stay in place.
fn add_dnat_rule_iptables(&self, rule: &DnatRule) -> Result<()> {
    let chain = format!("{}_PREROUTING", CHAIN_PREFIX);
    let dport = rule.host_port.to_string();
    let to_dest = format!("{}:{}", rule.guest_ip, rule.guest_port);
    let protocols: &[&str] = match rule.protocol {
        Protocol::Tcp => &["tcp"],
        Protocol::Udp => &["udp"],
        Protocol::Both => &["tcp", "udp"],
    };

    for &proto in protocols {
        let mut args: Vec<&str> = vec!["-t", "nat", "-A", chain.as_str()];
        // The interface match is only emitted when the rule is bound to one.
        if let Some(iface) = rule.host_interface.as_deref() {
            args.push("-i");
            args.push(iface);
        }
        args.extend_from_slice(&[
            "-p",
            proto,
            "--dport",
            dport.as_str(),
            "-j",
            "DNAT",
            "--to-destination",
            to_dest.as_str(),
        ]);
        self.run_iptables(&args)?;
    }

    tracing::debug!("Added DNAT rule: {:?}", rule);
    Ok(())
}
/// Adds one DNAT rule per protocol to the `prerouting` chain of the `arcbox`
/// nftables table, forwarding `host_port` to `guest_ip:guest_port`.
///
/// # Errors
/// Returns the first `nft` error; rules added for earlier protocols stay in
/// place.
fn add_dnat_rule_nftables(&self, rule: &DnatRule) -> Result<()> {
    let protocols: &[&str] = match rule.protocol {
        Protocol::Tcp => &["tcp"],
        Protocol::Udp => &["udp"],
        Protocol::Both => &["tcp", "udp"],
    };
    // Optional interface clause; empty when the rule is not bound to one.
    let iif_clause = rule
        .host_interface
        .as_ref()
        .map(|iface| format!(" iifname \"{}\"", iface))
        .unwrap_or_default();

    for &proto in protocols {
        let match_and_target = format!(
            " {} dport {} dnat to {}:{}",
            proto, rule.host_port, rule.guest_ip, rule.guest_port
        );
        let rule_str = [
            "add rule ip arcbox prerouting",
            iif_clause.as_str(),
            match_and_target.as_str(),
        ]
        .concat();
        // The command line is handed to `nft` one word per argument.
        let argv: Vec<&str> = rule_str.split_whitespace().collect();
        self.run_nft(&argv)?;
    }

    tracing::debug!("Added DNAT rule: {:?}", rule);
    Ok(())
}
/// Removes the port-forwarding `rule` using the active backend.
///
/// # Errors
/// Propagates the error of the backend-specific implementation.
pub fn remove_dnat_rule(&mut self, rule: &DnatRule) -> Result<()> {
    let backend = self.backend;
    match backend {
        FirewallBackend::Nftables => self.remove_dnat_rule_nftables(rule),
        FirewallBackend::Iptables => self.remove_dnat_rule_iptables(rule),
    }
}
/// Removes the nftables prerouting rule(s) corresponding to `rule`, one
/// lookup per protocol.
///
/// A listed rule is selected when it contains the protocol, the destination
/// port match and the DNAT target that `add_dnat_rule_nftables` writes. When
/// the rule is bound to a host interface, the `iifname` clause must match as
/// well, so identical forwards on other interfaces are left in place.
///
/// # Errors
/// Propagates the error from the rule lookup.
fn remove_dnat_rule_nftables(&self, rule: &DnatRule) -> Result<()> {
    let protocols: &[&str] = match rule.protocol {
        Protocol::Tcp => &["tcp"],
        Protocol::Udp => &["udp"],
        Protocol::Both => &["tcp", "udp"],
    };
    // These patterns do not depend on the protocol, so build them once.
    let dport_pattern = format!("dport {}", rule.host_port);
    let dnat_pattern = format!("dnat to {}:{}", rule.guest_ip, rule.guest_port);
    let iif_pattern = rule
        .host_interface
        .as_ref()
        .map(|iface| format!("iifname \"{}\"", iface));

    for &proto in protocols {
        let mut patterns: Vec<&str> = Vec::with_capacity(4);
        if let Some(iif) = iif_pattern.as_deref() {
            patterns.push(iif);
        }
        patterns.extend_from_slice(&[proto, dport_pattern.as_str(), dnat_pattern.as_str()]);
        self.delete_nft_rule_by_pattern("prerouting", &patterns)?;
    }

    tracing::debug!("Removed DNAT rule: {:?}", rule);
    Ok(())
}
/// Deletes one DNAT rule per protocol from the custom iptables prerouting
/// chain.
///
/// `iptables` failures are ignored, so this always returns `Ok(())`.
fn remove_dnat_rule_iptables(&self, rule: &DnatRule) -> Result<()> {
    let chain = format!("{}_PREROUTING", CHAIN_PREFIX);
    let dport = rule.host_port.to_string();
    let to_dest = format!("{}:{}", rule.guest_ip, rule.guest_port);
    let protocols: &[&str] = match rule.protocol {
        Protocol::Tcp => &["tcp"],
        Protocol::Udp => &["udp"],
        Protocol::Both => &["tcp", "udp"],
    };

    for &proto in protocols {
        let mut args: Vec<&str> = vec!["-t", "nat", "-D", chain.as_str()];
        // The interface match is only emitted when the rule is bound to one.
        if let Some(iface) = rule.host_interface.as_deref() {
            args.push("-i");
            args.push(iface);
        }
        args.extend_from_slice(&[
            "-p",
            proto,
            "--dport",
            dport.as_str(),
            "-j",
            "DNAT",
            "--to-destination",
            to_dest.as_str(),
        ]);
        // Best effort: the outcome of the delete is deliberately discarded.
        let _ = self.run_iptables(&args);
    }

    tracing::debug!("Removed DNAT rule: {:?}", rule);
    Ok(())
}
/// Allows forwarding between `in_if` and `out_if`.
///
/// Two rules are added, in this order:
/// 1. accept related/established traffic arriving on `out_if` and leaving
///    through `in_if`;
/// 2. accept all traffic arriving on `in_if` and leaving through `out_if`.
///
/// # Errors
/// Returns the first backend error; a rule added before the failure stays in
/// place.
pub fn add_forward_rule(&mut self, in_if: &str, out_if: &str) -> Result<()> {
    match self.backend {
        FirewallBackend::Nftables => {
            let return_path = format!(
                "add rule ip arcbox forward iifname \"{}\" oifname \"{}\" ct state related,established accept",
                out_if, in_if
            );
            let outbound = format!(
                "add rule ip arcbox forward iifname \"{}\" oifname \"{}\" accept",
                in_if, out_if
            );
            for command in [return_path, outbound].iter() {
                // The command line is handed to `nft` one word per argument.
                let argv: Vec<&str> = command.split_whitespace().collect();
                self.run_nft(&argv)?;
            }
        }
        FirewallBackend::Iptables => {
            let chain = format!("{}_FORWARD", CHAIN_PREFIX);
            let return_path = [
                "-t",
                "filter",
                "-A",
                chain.as_str(),
                "-i",
                out_if,
                "-o",
                in_if,
                "-m",
                "state",
                "--state",
                "RELATED,ESTABLISHED",
                "-j",
                "ACCEPT",
            ];
            let outbound = [
                "-t",
                "filter",
                "-A",
                chain.as_str(),
                "-i",
                in_if,
                "-o",
                out_if,
                "-j",
                "ACCEPT",
            ];
            self.run_iptables(&return_path)?;
            self.run_iptables(&outbound)?;
        }
    }

    tracing::debug!("Added forward rule: {} <-> {}", in_if, out_if);
    Ok(())
}
/// Runs `iptables` with `args` and waits for it to finish.
///
/// # Errors
/// Returns [`NetError::Firewall`] when the process cannot be started, or when
/// it exits with a failure status (the message then includes the arguments
/// and the captured stderr).
fn run_iptables(&self, args: &[&str]) -> Result<()> {
    let output = match Command::new("iptables").args(args).output() {
        Ok(out) => out,
        Err(e) => {
            return Err(NetError::Firewall(format!("failed to run iptables: {}", e)));
        }
    };

    if output.status.success() {
        return Ok(());
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    Err(NetError::Firewall(format!(
        "iptables {} failed: {}",
        args.join(" "),
        stderr
    )))
}
/// Runs `nft` with `args` and waits for it to finish.
///
/// # Errors
/// Returns [`NetError::Firewall`] when the process cannot be started, or when
/// it exits with a failure status (the message then includes the arguments
/// and the captured stderr).
fn run_nft(&self, args: &[&str]) -> Result<()> {
    let output = match Command::new("nft").args(args).output() {
        Ok(out) => out,
        Err(e) => {
            return Err(NetError::Firewall(format!("failed to run nft: {}", e)));
        }
    };

    if output.status.success() {
        return Ok(());
    }

    let stderr = String::from_utf8_lossy(&output.stderr);
    Err(NetError::Firewall(format!(
        "nft {} failed: {}",
        args.join(" "),
        stderr
    )))
}
}
impl Drop for LinuxFirewall {
    /// Logs a debug message when the manager is dropped while set up.
    ///
    /// Dropping does not run any teardown; installed rules are left as they are.
    fn drop(&mut self) {
        if !self.setup_complete {
            return;
        }
        tracing::debug!("LinuxFirewall dropped, rules still active");
    }
}
// Unit tests for the rule builders and protocol naming. None of them invokes
// `iptables` or `nft`.
#[cfg(test)]
mod tests {
use super::*;
// A masquerade rule records its interface and carries no SNAT address.
#[test]
fn test_nat_rule_masquerade() {
let rule = NatRule::masquerade("192.168.64.0/24".parse().unwrap(), "eth0");
assert_eq!(rule.nat_type, NatType::Masquerade);
assert_eq!(rule.out_interface, "eth0");
assert!(rule.snat_addr.is_none());
}
// An SNAT rule records the address it rewrites to.
#[test]
fn test_nat_rule_snat() {
let rule = NatRule::snat(
"192.168.64.0/24".parse().unwrap(),
"eth0",
"10.0.0.1".parse().unwrap(),
);
assert_eq!(rule.nat_type, NatType::Snat);
assert_eq!(rule.snat_addr, Some("10.0.0.1".parse().unwrap()));
}
// `DnatRule::new` stores protocol and both ports unchanged.
#[test]
fn test_dnat_rule() {
let rule = DnatRule::new(Protocol::Tcp, 8080, "192.168.64.2".parse().unwrap(), 80);
assert_eq!(rule.protocol, Protocol::Tcp);
assert_eq!(rule.host_port, 8080);
assert_eq!(rule.guest_port, 80);
}
// `Protocol::Both` maps to the keyword "all".
#[test]
fn test_protocol_as_str() {
assert_eq!(Protocol::Tcp.as_str(), "tcp");
assert_eq!(Protocol::Udp.as_str(), "udp");
assert_eq!(Protocol::Both.as_str(), "all");
}
}