use crate::types::{AppError, InterfaceConfig, NetworkState};
use log::{debug, info};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use tokio::sync::Mutex as AsyncMutex;
use std::borrow::Cow;
use nftables::{
batch::Batch,
helper, expr::Expression, schema::{NfCmd, NfListObject, NfObject, Table, Set, Element, FlushObject, SetTypeValue, SetFlag, Nftables}, types::NfFamily, };
pub struct NftablesManager {
#[allow(dead_code)] config: Arc<AsyncMutex<Vec<InterfaceConfig>>>,
table_name: String,
}
impl NftablesManager {
pub async fn new(config: Arc<AsyncMutex<Vec<InterfaceConfig>>>) -> Result<Self, AppError> {
let manager = Self {
config,
table_name: "filter".to_string(),
};
Ok(manager)
}
pub async fn load_rules(&self) -> Result<(), AppError> {
info!("[NFTABLES-RS] Ensuring base nftables structure");
let mut batch = Batch::new();
batch.add(NfListObject::Table(Table {
family: NfFamily::INet,
name: Cow::Borrowed(&self.table_name),
handle: None, }));
let config_lock = self.config.lock().await;
let unique_zones: HashSet<String> = config_lock.iter()
.filter_map(|iface| iface.nftables_zone.clone())
.collect();
drop(config_lock);
for zone_name in unique_zones {
let ipv4_set_name = format!("{}_ips", zone_name);
batch.add(NfListObject::Set(Box::new(Set {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv4_set_name),
handle: None,
set_type: nftables::schema::SetTypeValue::Single(nftables::schema::SetType::Ipv4Addr),
policy: None,
flags: Some(HashSet::from([nftables::schema::SetFlag::Dynamic])),
comment: None,
elem: None,
gc_interval: None,
size: None,
timeout: None,
})));
let ipv6_set_name = format!("{}_ipv6", zone_name);
batch.add(NfListObject::Set(Box::new(Set {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv6_set_name),
handle: None,
set_type: nftables::schema::SetTypeValue::Single(nftables::schema::SetType::Ipv6Addr),
policy: None,
flags: Some(HashSet::from([nftables::schema::SetFlag::Dynamic])),
comment: None,
elem: None,
gc_interval: None,
size: None,
timeout: None,
})));
}
let ruleset = batch.to_nftables();
debug!("[NFTABLES-RS] Load ruleset generated: {:?}", ruleset);
helper::apply_ruleset(&ruleset).map_err(AppError::NftablesError)?;
info!("[NFTABLES-RS] Base table '{}' and required sets ensured.", self.table_name);
Ok(())
}
pub async fn apply_rules(&self, network_state: &NetworkState) -> Result<(), AppError> {
info!("[NFTABLES-RS] Applying nftables rules (flush and add elements)");
let config_lock = self.config.lock().await;
let mut zone_to_ips: HashMap<String, HashSet<IpAddr>> = HashMap::new();
for interface_config in config_lock.iter() {
if let Some(zone) = &interface_config.nftables_zone {
if let Some(ips) = network_state.interface_ips.get(&interface_config.name) {
let zone_ips = zone_to_ips.entry(zone.clone()).or_default();
for ip in ips {
zone_ips.insert(*ip); }
}
}
}
drop(config_lock);
let mut flush_commands: Vec<NfObject> = Vec::new();
for zone_name in zone_to_ips.keys() {
let ipv4_set_name = format!("{}_ips", zone_name);
let set_to_flush_v4 = Box::new(Set {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv4_set_name),
handle: None,
set_type: nftables::schema::SetTypeValue::Single(nftables::schema::SetType::Ipv4Addr),
policy: None,
flags: None,
comment: None,
elem: None,
gc_interval: None,
size: None,
timeout: None,
});
flush_commands.push(NfObject::CmdObject(NfCmd::Flush(FlushObject::Set(set_to_flush_v4))));
let ipv6_set_name = format!("{}_ipv6", zone_name);
let set_to_flush_v6 = Box::new(Set {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv6_set_name),
handle: None,
set_type: nftables::schema::SetTypeValue::Single(nftables::schema::SetType::Ipv6Addr),
policy: None,
flags: None,
comment: None,
elem: None,
gc_interval: None,
size: None,
timeout: None,
});
flush_commands.push(NfObject::CmdObject(NfCmd::Flush(FlushObject::Set(set_to_flush_v6))));
}
if !flush_commands.is_empty() {
let flush_nftables = Nftables { objects: Cow::Owned(flush_commands) };
debug!("[NFTABLES-RS] Flush commands generated: {:?}", flush_nftables);
helper::apply_ruleset(&flush_nftables).map_err(AppError::NftablesError)?;
info!("[NFTABLES-RS] Successfully flushed nftables sets");
} else {
info!("[NFTABLES-RS] No sets to flush.");
}
let mut add_batch = Batch::new();
for (zone_name, ips) in zone_to_ips {
let ipv4_set_name = format!("{}_ips", zone_name);
let ipv4_elements: Vec<Element> = ips.iter()
.filter_map(|ip| match ip {
IpAddr::V4(v4) => Some(Element {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv4_set_name.clone()),
elem: Cow::Owned(vec![Expression::String(v4.to_string().into())]),
}),
_ => None,
})
.collect();
if !ipv4_elements.is_empty() {
add_batch.add(NfListObject::Element(Element {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv4_set_name),
elem: Cow::Owned(ipv4_elements.into_iter().flat_map(|e| e.elem.into_owned()).collect()),
}));
}
let ipv6_set_name = format!("{}_ipv6", zone_name);
let ipv6_elements: Vec<Element> = ips.iter()
.filter_map(|ip| match ip {
IpAddr::V6(v6) => Some(Element {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv6_set_name.clone()),
elem: Cow::Owned(vec![Expression::String(v6.to_string().into())]),
}),
_ => None,
})
.collect();
if !ipv6_elements.is_empty() {
add_batch.add(NfListObject::Element(Element {
family: NfFamily::INet,
table: Cow::Borrowed(&self.table_name),
name: Cow::Owned(ipv6_set_name),
elem: Cow::Owned(ipv6_elements.into_iter().flat_map(|e| e.elem.into_owned()).collect()),
}));
}
}
let add_ruleset = add_batch.to_nftables();
if !add_ruleset.objects.is_empty() {
debug!("[NFTABLES-RS] Add elements ruleset generated: {:?}", add_ruleset);
helper::apply_ruleset(&add_ruleset).map_err(AppError::NftablesError)?;
info!("[NFTABLES-RS] Successfully added elements to nftables sets");
} else {
info!("[NFTABLES-RS] No elements to add.");
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{InterfaceConfig, AppStateShared}; use std::net::{Ipv4Addr, IpAddr};
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::Mutex as AsyncMutex;
use std::collections::HashMap;
fn create_mock_config() -> Arc<AsyncMutex<Vec<InterfaceConfig>>> {
Arc::new(AsyncMutex::new(vec![
InterfaceConfig {
name: "eth0".to_string(),
dhcp: Some(true),
address: None,
nftables_zone: Some("wan".to_string()),
},
InterfaceConfig {
name: "eth1".to_string(),
dhcp: None,
address: Some("192.168.1.1/24".to_string()),
nftables_zone: Some("lan".to_string()),
},
]))
}
fn create_test_network_state() -> NetworkState {
let mut state = NetworkState::default(); let wan_ips = vec![
IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)),
];
state.interface_ips.insert("eth0".to_string(), wan_ips);
let lan_ips = vec![IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))]; state.interface_ips.insert("eth1".to_string(), lan_ips);
state
}
#[test]
#[ignore] fn test_nftables_manager_init() {
let rt = Runtime::new().unwrap();
rt.block_on(async {
let config = create_mock_config();
let manager_result = NftablesManager::new(config).await;
assert!(manager_result.is_ok(), "NftablesManager::new should succeed: {:?}", manager_result.err());
});
}
#[test]
#[ignore] fn test_nftables_table_creation() {
let rt = Runtime::new().unwrap();
rt.block_on(async {
let config = create_mock_config();
let manager = NftablesManager::new(config).await.expect("Failed to create NftablesManager");
let result = manager.load_rules().await;
assert!(result.is_ok(), "load_rules should succeed: {:?}", result.err());
});
}
#[test]
#[ignore] fn test_nftables_set_management() {
let rt = Runtime::new().unwrap();
rt.block_on(async {
let config = create_mock_config();
let manager = NftablesManager::new(config).await.expect("Failed to create NftablesManager");
let result = manager.load_rules().await;
assert!(result.is_ok(), "load_rules should succeed: {:?}", result.err());
let network_state = create_test_network_state();
let apply_result = manager.apply_rules(&network_state).await;
assert!(apply_result.is_ok(), "apply_rules should succeed: {:?}", apply_result.err());
});
}
}