use super::{expr::write_expressions, types::*, *};
use crate::netlink::{
attr::AttrIter,
builder::MessageBuilder,
connection::Connection,
error::{Error, Result},
message::{
MessageIter, NLM_F_ACK, NLM_F_CREATE, NLM_F_DUMP, NLM_F_EXCL, NLM_F_REPLACE, NLM_F_REQUEST,
NlMsgError,
},
protocol::Nftables,
};
impl Connection<Nftables> {
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_table"))]
pub async fn add_table(&self, name: &str, family: Family) -> Result<()> {
self.add_table_with_flags(name, family, 0).await
}
#[tracing::instrument(
level = "debug",
skip_all,
fields(method = "add_table_with_flags", flags)
)]
pub async fn add_table_with_flags(
&self,
name: &str,
family: Family,
flags: u32,
) -> Result<()> {
if name.is_empty() || name.len() > 256 {
return Err(Error::InvalidMessage(
"table name must be 1-256 characters".into(),
));
}
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWTABLE),
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_TABLE_NAME, name);
if flags != 0 {
builder.append_attr_u32_be(NFTA_TABLE_FLAGS, flags);
}
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "list_tables"))]
pub async fn list_tables(&self) -> Result<Vec<Table>> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_GETTABLE), NLM_F_REQUEST | NLM_F_DUMP);
let nfgenmsg = NfGenMsg {
nfgen_family: 0, version: 0,
res_id: 0,
};
builder.append(&nfgenmsg);
let responses = self.nft_dump(builder).await?;
let mut tables = Vec::new();
for (family_byte, payload) in &responses {
let family = Family::from_u8(*family_byte).unwrap_or(Family::Inet);
if let Some(table) = parse_table(payload, family) {
tables.push(table);
}
}
Ok(tables)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_flowtable"))]
pub async fn add_flowtable(&self, ft: &super::types::Flowtable) -> Result<()> {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWFLOWTABLE),
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(ft.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_FLOWTABLE_TABLE, &ft.table);
builder.append_attr_str(NFTA_FLOWTABLE_NAME, &ft.name);
let hook = builder.nest_start(NFTA_FLOWTABLE_HOOK | 0x8000);
builder.append_attr_u32_be(NFTA_FLOWTABLE_HOOK_NUM, NF_NETDEV_INGRESS);
builder.append_attr_u32_be(NFTA_FLOWTABLE_HOOK_PRIORITY, ft.priority as u32);
if !ft.devs.is_empty() {
let devs = builder.nest_start(NFTA_FLOWTABLE_HOOK_DEVS | 0x8000);
for dev in &ft.devs {
let dev_nest = builder.nest_start(1u16 | 0x8000); builder.append_attr_str(NFTA_DEVICE_NAME, dev);
builder.nest_end(dev_nest);
}
builder.nest_end(devs);
}
builder.nest_end(hook);
if ft.flags != 0 {
builder.append_attr_u32_be(NFTA_FLOWTABLE_FLAGS, ft.flags);
}
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_flowtable"))]
pub async fn del_flowtable(
&self,
family: Family,
table: &str,
name: &str,
) -> Result<()> {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_DELFLOWTABLE),
NLM_F_REQUEST | NLM_F_ACK,
);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_FLOWTABLE_TABLE, table);
builder.append_attr_str(NFTA_FLOWTABLE_NAME, name);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "list_flowtables"))]
pub async fn list_flowtables(&self) -> Result<Vec<super::types::Flowtable>> {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_GETFLOWTABLE),
NLM_F_REQUEST | NLM_F_DUMP,
);
let nfgenmsg = NfGenMsg {
nfgen_family: 0,
version: 0,
res_id: 0,
};
builder.append(&nfgenmsg);
let responses = self.nft_dump(builder).await?;
let mut out = Vec::new();
for (family_byte, payload) in &responses {
let family = Family::from_u8(*family_byte).unwrap_or(Family::Inet);
if let Some(ft) = parse_flowtable(payload, family) {
out.push(ft);
}
}
Ok(out)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_table"))]
pub async fn del_table(&self, name: &str, family: Family) -> Result<()> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_DELTABLE), NLM_F_REQUEST | NLM_F_ACK);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_TABLE_NAME, name);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "flush_table"))]
pub async fn flush_table(&self, name: &str, family: Family) -> Result<()> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_DELRULE), NLM_F_REQUEST | NLM_F_ACK);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, name);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_chain"))]
pub async fn add_chain(&self, chain: Chain) -> Result<()> {
if chain.hook.is_some() && chain.chain_type.is_none() {
return Err(Error::InvalidMessage(
"base chains with a hook require chain_type".into(),
));
}
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWCHAIN),
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(chain.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_CHAIN_TABLE, &chain.table);
builder.append_attr_str(NFTA_CHAIN_NAME, &chain.name);
if let Some(chain_type) = chain.chain_type {
builder.append_attr_str(NFTA_CHAIN_TYPE, chain_type.as_str());
}
if let Some(hook) = chain.hook {
let hook_nest = builder.nest_start(NFTA_CHAIN_HOOK | 0x8000);
builder.append_attr_u32_be(NFTA_HOOK_HOOKNUM, hook.to_u32());
let priority = chain.priority.unwrap_or(Priority::Filter).to_i32();
builder.append_attr_u32_be(NFTA_HOOK_PRIORITY, priority as u32);
builder.nest_end(hook_nest);
}
if let Some(policy) = chain.policy {
builder.append_attr_u32_be(NFTA_CHAIN_POLICY, policy.to_u32());
}
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "list_chains"))]
pub async fn list_chains(&self) -> Result<Vec<ChainInfo>> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_GETCHAIN), NLM_F_REQUEST | NLM_F_DUMP);
let nfgenmsg = NfGenMsg {
nfgen_family: 0,
version: 0,
res_id: 0,
};
builder.append(&nfgenmsg);
let responses = self.nft_dump(builder).await?;
let mut chains = Vec::new();
for (family_byte, payload) in &responses {
let family = Family::from_u8(*family_byte).unwrap_or(Family::Inet);
if let Some(chain) = parse_chain(payload, family) {
chains.push(chain);
}
}
Ok(chains)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_chain"))]
pub async fn del_chain(&self, table: &str, name: &str, family: Family) -> Result<()> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_DELCHAIN), NLM_F_REQUEST | NLM_F_ACK);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_CHAIN_TABLE, table);
builder.append_attr_str(NFTA_CHAIN_NAME, name);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_rule"))]
pub async fn add_rule(&self, rule: Rule) -> Result<()> {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWRULE),
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE,
);
let nfgenmsg = NfGenMsg::new(rule.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, &rule.table);
builder.append_attr_str(NFTA_RULE_CHAIN, &rule.chain);
if let Some(pos) = rule.position {
builder.append_attr_u64_be(NFTA_RULE_POSITION, pos);
}
if !rule.exprs.is_empty() {
write_expressions(&mut builder, &rule.exprs);
}
if let Some(comment) = &rule.comment
&& let Some(udata) = super::userdata::encode_nlink_comment(comment)
{
builder.append_attr(NFTA_RULE_USERDATA, &udata);
}
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "list_rules"))]
pub async fn list_rules(&self, table: &str, family: Family) -> Result<Vec<RuleInfo>> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_GETRULE), NLM_F_REQUEST | NLM_F_DUMP);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, table);
let responses = self.nft_dump(builder).await?;
let mut rules = Vec::new();
for (family_byte, payload) in &responses {
let family = Family::from_u8(*family_byte).unwrap_or(Family::Inet);
if let Some(rule) = parse_rule(payload, family) {
rules.push(rule);
}
}
Ok(rules)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_rule"))]
pub async fn del_rule(
&self,
table: &str,
chain: &str,
family: Family,
handle: u64,
) -> Result<()> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_DELRULE), NLM_F_REQUEST | NLM_F_ACK);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, table);
builder.append_attr_str(NFTA_RULE_CHAIN, chain);
builder.append_attr_u64_be(NFTA_RULE_HANDLE, handle);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_set"))]
pub async fn add_set(&self, set: Set) -> Result<()> {
if set.name.is_empty() || set.name.len() > 256 {
return Err(Error::InvalidMessage(
"set name must be 1-256 characters".into(),
));
}
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWSET),
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(set.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_SET_TABLE, &set.table);
builder.append_attr_str(NFTA_SET_NAME, &set.name);
builder.append_attr_u32_be(NFTA_SET_KEY_TYPE, set.key_type.type_id());
builder.append_attr_u32_be(NFTA_SET_KEY_LEN, set.key_type.len());
builder.append_attr_u32_be(NFTA_SET_FLAGS, set.flags);
builder.append_attr_u32_be(NFTA_SET_ID, 1);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "list_sets"))]
pub async fn list_sets(&self, family: Family) -> Result<Vec<SetInfo>> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_GETSET), NLM_F_REQUEST | NLM_F_DUMP);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
let responses = self.nft_dump(builder).await?;
let mut sets = Vec::new();
for (family_byte, payload) in &responses {
let family = Family::from_u8(*family_byte).unwrap_or(Family::Inet);
if let Some(set) = parse_set(payload, family) {
sets.push(set);
}
}
Ok(sets)
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_set"))]
pub async fn del_set(&self, table: &str, name: &str, family: Family) -> Result<()> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_DELSET), NLM_F_REQUEST | NLM_F_ACK);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_SET_TABLE, table);
builder.append_attr_str(NFTA_SET_NAME, name);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "add_set_elements"))]
pub async fn add_set_elements(
&self,
table: &str,
set: &str,
family: Family,
elements: &[SetElement],
) -> Result<()> {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWSETELEM),
NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE,
);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_SET_ELEM_LIST_TABLE, table);
builder.append_attr_str(NFTA_SET_ELEM_LIST_SET, set);
let elems_nest = builder.nest_start(NFTA_SET_ELEM_LIST_ELEMENTS | 0x8000);
for elem in elements {
let elem_nest = builder.nest_start(NFTA_LIST_ELEM | 0x8000);
let key_nest = builder.nest_start(NFTA_SET_ELEM_KEY | 0x8000);
builder.append_attr(NFTA_DATA_VALUE, &elem.key);
builder.nest_end(key_nest);
builder.nest_end(elem_nest);
}
builder.nest_end(elems_nest);
self.nft_request_ack(builder).await
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "del_set_elements"))]
pub async fn del_set_elements(
&self,
table: &str,
set: &str,
family: Family,
elements: &[SetElement],
) -> Result<()> {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_DELSETELEM), NLM_F_REQUEST | NLM_F_ACK);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_SET_ELEM_LIST_TABLE, table);
builder.append_attr_str(NFTA_SET_ELEM_LIST_SET, set);
let elems_nest = builder.nest_start(NFTA_SET_ELEM_LIST_ELEMENTS | 0x8000);
for elem in elements {
let elem_nest = builder.nest_start(NFTA_LIST_ELEM | 0x8000);
let key_nest = builder.nest_start(NFTA_SET_ELEM_KEY | 0x8000);
builder.append_attr(NFTA_DATA_VALUE, &elem.key);
builder.nest_end(key_nest);
builder.nest_end(elem_nest);
}
builder.nest_end(elems_nest);
self.nft_request_ack(builder).await
}
pub fn transaction(&self) -> Transaction {
Transaction::new()
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "flush_ruleset"))]
pub async fn flush_ruleset(&self) -> Result<()> {
let tables = self.list_tables().await?;
for table in tables {
self.del_table(&table.name, table.family).await?;
}
Ok(())
}
async fn send_batch(&self, messages: Vec<Vec<u8>>) -> Result<()> {
if messages.is_empty() {
return Ok(());
}
let mut batch = Vec::new();
let mut begin = MessageBuilder::new(NFNL_MSG_BATCH_BEGIN, NLM_F_REQUEST);
let nfgenmsg = NfGenMsg {
nfgen_family: 0,
version: 0,
res_id: 10u16.to_be(), };
begin.append(&nfgenmsg);
let seq = self.socket().next_seq();
begin.set_seq(seq);
begin.set_pid(self.socket().pid());
batch.extend_from_slice(&begin.finish());
for msg_data in &messages {
batch.extend_from_slice(msg_data);
}
let mut end = MessageBuilder::new(NFNL_MSG_BATCH_END, NLM_F_REQUEST);
let nfgenmsg = NfGenMsg {
nfgen_family: 0,
version: 0,
res_id: 10u16.to_be(),
};
end.append(&nfgenmsg);
end.set_seq(self.socket().next_seq());
end.set_pid(self.socket().pid());
batch.extend_from_slice(&end.finish());
self.socket().send(&batch).await?;
loop {
let data: Vec<u8> = self.socket().recv_msg().await?;
for msg_result in MessageIter::new(&data) {
let (header, payload) = msg_result?;
if header.is_error() {
let err = NlMsgError::from_bytes(payload)?;
if err.is_ack() {
return Ok(());
}
return Err(err.into_error(payload));
}
if header.is_done() {
return Ok(());
}
}
}
}
async fn nft_request_ack(&self, mut builder: MessageBuilder) -> Result<()> {
let seq = self.socket().next_seq();
builder.set_seq(seq);
builder.set_pid(self.socket().pid());
self.send_batch(vec![builder.finish()]).await
}
#[tracing::instrument(level = "info", skip(self), fields(groups = ?groups))]
pub fn subscribe(&mut self, groups: &[super::events::NftablesGroup]) -> Result<()> {
for g in groups {
self.socket_mut().add_membership(g.to_kernel_group())?;
}
Ok(())
}
pub fn subscribe_all(&mut self) -> Result<()> {
self.subscribe(&[super::events::NftablesGroup::All])
}
async fn nft_dump(&self, mut builder: MessageBuilder) -> Result<Vec<(u8, Vec<u8>)>> {
let seq = self.socket().next_seq();
builder.set_seq(seq);
builder.set_pid(self.socket().pid());
let msg = builder.finish();
self.socket().send(&msg).await?;
let mut results = Vec::new();
loop {
let data: Vec<u8> = self.socket().recv_msg().await?;
let mut done = false;
for msg_result in MessageIter::new(&data) {
let (header, payload) = msg_result?;
if header.nlmsg_seq != seq {
continue;
}
if header.is_error() {
let err = NlMsgError::from_bytes(payload)?;
if !err.is_ack() {
return Err(err.into_error(payload));
}
continue;
}
if header.is_done() {
done = true;
break;
}
if payload.len() >= NFGENMSG_HDRLEN {
let family = payload[0];
results.push((family, payload[NFGENMSG_HDRLEN..].to_vec()));
}
}
if done {
break;
}
}
Ok(results)
}
}
pub(crate) fn parse_table(data: &[u8], family: Family) -> Option<Table> {
let mut table = Table {
name: String::new(),
family,
flags: 0,
use_count: 0,
handle: 0,
};
for (attr_type, payload) in AttrIter::new(data) {
match attr_type & 0x7FFF {
NFTA_TABLE_NAME => {
table.name = attr_str(payload)?;
}
NFTA_TABLE_FLAGS if payload.len() >= 4 => {
table.flags = u32::from_be_bytes(payload[..4].try_into().unwrap());
}
NFTA_TABLE_USE if payload.len() >= 4 => {
table.use_count = u32::from_be_bytes(payload[..4].try_into().unwrap());
}
NFTA_TABLE_HANDLE if payload.len() >= 8 => {
table.handle = u64::from_be_bytes(payload[..8].try_into().unwrap());
}
_ => {}
}
}
if table.name.is_empty() {
None
} else {
Some(table)
}
}
pub(crate) fn parse_chain(data: &[u8], family: Family) -> Option<ChainInfo> {
let mut chain = ChainInfo {
table: String::new(),
name: String::new(),
family,
hook: None,
priority: None,
chain_type: None,
policy: None,
handle: 0,
};
for (attr_type, payload) in AttrIter::new(data) {
match attr_type & 0x7FFF {
NFTA_CHAIN_TABLE => {
chain.table = attr_str(payload).unwrap_or_default();
}
NFTA_CHAIN_NAME => {
chain.name = attr_str(payload).unwrap_or_default();
}
NFTA_CHAIN_HANDLE if payload.len() >= 8 => {
chain.handle = u64::from_be_bytes(payload[..8].try_into().unwrap());
}
NFTA_CHAIN_HOOK => {
for (hook_attr, hook_payload) in AttrIter::new(payload) {
match hook_attr & 0x7FFF {
NFTA_HOOK_HOOKNUM if hook_payload.len() >= 4 => {
chain.hook =
Some(u32::from_be_bytes(hook_payload[..4].try_into().unwrap()));
}
NFTA_HOOK_PRIORITY if hook_payload.len() >= 4 => {
chain.priority =
Some(i32::from_be_bytes(hook_payload[..4].try_into().unwrap()));
}
_ => {}
}
}
}
NFTA_CHAIN_POLICY if payload.len() >= 4 => {
chain.policy = Some(u32::from_be_bytes(payload[..4].try_into().unwrap()));
}
NFTA_CHAIN_TYPE => {
chain.chain_type = attr_str(payload);
}
_ => {}
}
}
if chain.name.is_empty() {
None
} else {
Some(chain)
}
}
pub(crate) fn parse_rule(data: &[u8], family: Family) -> Option<RuleInfo> {
let mut rule = RuleInfo {
table: String::new(),
chain: String::new(),
family,
handle: 0,
position: None,
comment: None,
userdata_raw: None,
expression_bytes: Vec::new(),
};
for (attr_type, payload) in AttrIter::new(data) {
match attr_type & 0x7FFF {
NFTA_RULE_TABLE => {
rule.table = attr_str(payload).unwrap_or_default();
}
NFTA_RULE_CHAIN => {
rule.chain = attr_str(payload).unwrap_or_default();
}
NFTA_RULE_HANDLE if payload.len() >= 8 => {
rule.handle = u64::from_be_bytes(payload[..8].try_into().unwrap());
}
NFTA_RULE_POSITION if payload.len() >= 8 => {
rule.position = Some(u64::from_be_bytes(payload[..8].try_into().unwrap()));
}
NFTA_RULE_EXPRESSIONS => {
rule.expression_bytes = payload.to_vec();
}
NFTA_RULE_USERDATA => {
rule.userdata_raw = Some(payload.to_vec());
rule.comment = super::userdata::parse_nlink_comment(payload);
}
_ => {}
}
}
if rule.table.is_empty() {
None
} else {
Some(rule)
}
}
fn parse_set(data: &[u8], family: Family) -> Option<SetInfo> {
let mut set = SetInfo {
table: String::new(),
name: String::new(),
family,
flags: 0,
key_type: 0,
key_len: 0,
handle: 0,
};
for (attr_type, payload) in AttrIter::new(data) {
match attr_type & 0x7FFF {
NFTA_SET_TABLE => {
set.table = attr_str(payload).unwrap_or_default();
}
NFTA_SET_NAME => {
set.name = attr_str(payload).unwrap_or_default();
}
NFTA_SET_FLAGS if payload.len() >= 4 => {
set.flags = u32::from_be_bytes(payload[..4].try_into().unwrap());
}
NFTA_SET_KEY_TYPE if payload.len() >= 4 => {
set.key_type = u32::from_be_bytes(payload[..4].try_into().unwrap());
}
NFTA_SET_KEY_LEN if payload.len() >= 4 => {
set.key_len = u32::from_be_bytes(payload[..4].try_into().unwrap());
}
NFTA_SET_HANDLE if payload.len() >= 8 => {
set.handle = u64::from_be_bytes(payload[..8].try_into().unwrap());
}
_ => {}
}
}
if set.name.is_empty() { None } else { Some(set) }
}
fn attr_str(payload: &[u8]) -> Option<String> {
if payload.is_empty() {
return None;
}
let s = std::str::from_utf8(payload)
.unwrap_or("")
.trim_end_matches('\0');
if s.is_empty() {
None
} else {
Some(s.to_string())
}
}
#[must_use = "builders do nothing unless used"]
pub struct Transaction {
messages: Vec<Vec<u8>>,
seq_counter: u32,
}
impl Transaction {
fn new() -> Self {
Self {
messages: Vec::new(),
seq_counter: 1,
}
}
fn next_seq(&mut self) -> u32 {
let seq = self.seq_counter;
self.seq_counter += 1;
seq
}
pub fn add_table(mut self, name: &str, family: Family) -> Self {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWTABLE),
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_TABLE_NAME, name);
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn add_chain(mut self, chain: Chain) -> Self {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWCHAIN),
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(chain.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_CHAIN_TABLE, &chain.table);
builder.append_attr_str(NFTA_CHAIN_NAME, &chain.name);
if let Some(chain_type) = chain.chain_type {
builder.append_attr_str(NFTA_CHAIN_TYPE, chain_type.as_str());
}
if let Some(hook) = chain.hook {
let hook_nest = builder.nest_start(NFTA_CHAIN_HOOK | 0x8000);
builder.append_attr_u32_be(NFTA_HOOK_HOOKNUM, hook.to_u32());
let priority = chain.priority.unwrap_or(Priority::Filter).to_i32();
builder.append_attr_u32_be(NFTA_HOOK_PRIORITY, priority as u32);
builder.nest_end(hook_nest);
}
if let Some(policy) = chain.policy {
builder.append_attr_u32_be(NFTA_CHAIN_POLICY, policy.to_u32());
}
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn add_rule(mut self, rule: Rule) -> Self {
let mut builder =
MessageBuilder::new(nft_msg_type(NFT_MSG_NEWRULE), NLM_F_REQUEST | NLM_F_CREATE);
let nfgenmsg = NfGenMsg::new(rule.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, &rule.table);
builder.append_attr_str(NFTA_RULE_CHAIN, &rule.chain);
if let Some(pos) = rule.position {
builder.append_attr_u64_be(NFTA_RULE_POSITION, pos);
}
if !rule.exprs.is_empty() {
write_expressions(&mut builder, &rule.exprs);
}
if let Some(comment) = &rule.comment
&& let Some(udata) = super::userdata::encode_nlink_comment(comment)
{
builder.append_attr(NFTA_RULE_USERDATA, &udata);
}
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn replace_rule(mut self, rule: Rule, handle: u64) -> Self {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWRULE),
NLM_F_REQUEST | NLM_F_REPLACE,
);
let nfgenmsg = NfGenMsg::new(rule.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, &rule.table);
builder.append_attr_str(NFTA_RULE_CHAIN, &rule.chain);
builder.append_attr_u64_be(NFTA_RULE_HANDLE, handle);
if !rule.exprs.is_empty() {
write_expressions(&mut builder, &rule.exprs);
}
if let Some(comment) = &rule.comment
&& let Some(udata) = super::userdata::encode_nlink_comment(comment)
{
builder.append_attr(NFTA_RULE_USERDATA, &udata);
}
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn del_table(mut self, name: &str, family: Family) -> Self {
let mut builder = MessageBuilder::new(nft_msg_type(NFT_MSG_DELTABLE), NLM_F_REQUEST);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_TABLE_NAME, name);
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn add_table_with_flags(mut self, name: &str, family: Family, flags: u32) -> Self {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWTABLE),
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_TABLE_NAME, name);
if flags != 0 {
builder.append_attr_u32_be(NFTA_TABLE_FLAGS, flags);
}
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn del_chain(mut self, table: &str, name: &str, family: Family) -> Self {
let mut builder = MessageBuilder::new(nft_msg_type(NFT_MSG_DELCHAIN), NLM_F_REQUEST);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_CHAIN_TABLE, table);
builder.append_attr_str(NFTA_CHAIN_NAME, name);
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn del_rule(mut self, table: &str, chain: &str, family: Family, handle: u64) -> Self {
let mut builder = MessageBuilder::new(nft_msg_type(NFT_MSG_DELRULE), NLM_F_REQUEST);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_RULE_TABLE, table);
builder.append_attr_str(NFTA_RULE_CHAIN, chain);
builder.append_attr_u64_be(NFTA_RULE_HANDLE, handle);
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn add_flowtable(mut self, ft: &super::types::Flowtable) -> Self {
let mut builder = MessageBuilder::new(
nft_msg_type(NFT_MSG_NEWFLOWTABLE),
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
);
let nfgenmsg = NfGenMsg::new(ft.family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_FLOWTABLE_TABLE, &ft.table);
builder.append_attr_str(NFTA_FLOWTABLE_NAME, &ft.name);
let hook = builder.nest_start(NFTA_FLOWTABLE_HOOK | 0x8000);
builder.append_attr_u32_be(NFTA_FLOWTABLE_HOOK_NUM, NF_NETDEV_INGRESS);
builder.append_attr_u32_be(NFTA_FLOWTABLE_HOOK_PRIORITY, ft.priority as u32);
if !ft.devs.is_empty() {
let devs = builder.nest_start(NFTA_FLOWTABLE_HOOK_DEVS | 0x8000);
for dev in &ft.devs {
let dev_nest = builder.nest_start(1u16 | 0x8000); builder.append_attr_str(NFTA_DEVICE_NAME, dev);
builder.nest_end(dev_nest);
}
builder.nest_end(devs);
}
builder.nest_end(hook);
if ft.flags != 0 {
builder.append_attr_u32_be(NFTA_FLOWTABLE_FLAGS, ft.flags);
}
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
pub fn del_flowtable(mut self, family: Family, table: &str, name: &str) -> Self {
let mut builder = MessageBuilder::new(nft_msg_type(NFT_MSG_DELFLOWTABLE), NLM_F_REQUEST);
let nfgenmsg = NfGenMsg::new(family);
builder.append(&nfgenmsg);
builder.append_attr_str(NFTA_FLOWTABLE_TABLE, table);
builder.append_attr_str(NFTA_FLOWTABLE_NAME, name);
builder.set_seq(self.next_seq());
self.messages.push(builder.finish());
self
}
#[tracing::instrument(level = "debug", skip_all, fields(method = "commit"))]
pub async fn commit(self, conn: &Connection<Nftables>) -> Result<()> {
conn.send_batch(self.messages).await
}
}
pub(crate) fn parse_flowtable(data: &[u8], family: Family) -> Option<super::types::Flowtable> {
let mut ft = super::types::Flowtable {
family,
table: String::new(),
name: String::new(),
devs: Vec::new(),
priority: 0,
flags: 0,
use_count: 0,
handle: 0,
};
for (attr_type, payload) in AttrIter::new(data) {
match attr_type & 0x7FFF {
NFTA_FLOWTABLE_TABLE => {
ft.table = attr_str(payload)?;
}
NFTA_FLOWTABLE_NAME => {
ft.name = attr_str(payload)?;
}
NFTA_FLOWTABLE_USE if payload.len() >= 4 => {
ft.use_count = u32::from_be_bytes(payload[..4].try_into().ok()?);
}
NFTA_FLOWTABLE_HANDLE if payload.len() >= 8 => {
ft.handle = u64::from_be_bytes(payload[..8].try_into().ok()?);
}
NFTA_FLOWTABLE_FLAGS if payload.len() >= 4 => {
ft.flags = u32::from_be_bytes(payload[..4].try_into().ok()?);
}
NFTA_FLOWTABLE_HOOK => {
for (h_attr, h_payload) in AttrIter::new(payload) {
match h_attr & 0x7FFF {
NFTA_FLOWTABLE_HOOK_PRIORITY if h_payload.len() >= 4 => {
ft.priority = i32::from_be_bytes(
h_payload[..4].try_into().ok()?,
);
}
NFTA_FLOWTABLE_HOOK_DEVS => {
for (_le_attr, le_payload) in AttrIter::new(h_payload) {
for (d_attr, d_payload) in AttrIter::new(le_payload) {
if d_attr & 0x7FFF == NFTA_DEVICE_NAME
&& let Some(s) = attr_str(d_payload)
{
ft.devs.push(s);
}
}
}
}
_ => {}
}
}
}
_ => {}
}
}
if ft.name.is_empty() {
return None;
}
Some(ft)
}
#[cfg(test)]
mod transaction_tests {
use super::super::*;
use super::*;
fn new_tx() -> Transaction {
Transaction::new()
}
fn assert_header(msg: &[u8], expected_type: u16, expected_flags: u16) {
assert!(msg.len() >= 16, "msg too short for nlmsghdr: {}", msg.len());
let ty = u16::from_ne_bytes([msg[4], msg[5]]);
let flags = u16::from_ne_bytes([msg[6], msg[7]]);
assert_eq!(ty, expected_type, "nlmsg_type mismatch");
assert_eq!(flags, expected_flags, "nlmsg_flags mismatch");
}
fn find_attr(payload: &[u8], wanted_type: u16) -> Option<Vec<u8>> {
let mut offset = 0;
while offset + 4 <= payload.len() {
let len = u16::from_ne_bytes([payload[offset], payload[offset + 1]]) as usize;
let ty = u16::from_ne_bytes([payload[offset + 2], payload[offset + 3]]) & 0x7FFF;
if len < 4 || offset + len > payload.len() {
return None;
}
if ty == wanted_type {
return Some(payload[offset + 4..offset + len].to_vec());
}
offset += (len + 3) & !3;
}
None
}
fn body_after_nfgenmsg(msg: &[u8]) -> &[u8] {
&msg[16 + 4..]
}
#[test]
fn del_chain_emits_correct_wire_message() {
let tx = new_tx().del_chain("filter", "input", Family::Inet);
assert_eq!(tx.messages.len(), 1);
let msg = &tx.messages[0];
assert_header(msg, nft_msg_type(NFT_MSG_DELCHAIN), NLM_F_REQUEST);
let body = body_after_nfgenmsg(msg);
let table = find_attr(body, NFTA_CHAIN_TABLE).expect("NFTA_CHAIN_TABLE missing");
let name = find_attr(body, NFTA_CHAIN_NAME).expect("NFTA_CHAIN_NAME missing");
assert_eq!(&table[..table.len().saturating_sub(1)], b"filter");
assert_eq!(&name[..name.len().saturating_sub(1)], b"input");
}
#[test]
fn del_rule_emits_correct_wire_message_with_handle() {
let tx = new_tx().del_rule("filter", "input", Family::Inet, 0xDEAD_BEEF);
assert_eq!(tx.messages.len(), 1);
let msg = &tx.messages[0];
assert_header(msg, nft_msg_type(NFT_MSG_DELRULE), NLM_F_REQUEST);
let body = body_after_nfgenmsg(msg);
let handle = find_attr(body, NFTA_RULE_HANDLE).expect("NFTA_RULE_HANDLE missing");
assert_eq!(handle.len(), 8, "handle must be u64 big-endian");
assert_eq!(u64::from_be_bytes(handle.try_into().unwrap()), 0xDEAD_BEEF);
}
#[test]
fn add_flowtable_emits_nested_hook_block() {
let ft = super::super::types::Flowtable {
family: Family::Inet,
table: "filter".into(),
name: "ft".into(),
devs: vec!["eth0".into()],
priority: -300,
flags: NFT_FLOWTABLE_HW_OFFLOAD,
use_count: 0,
handle: 0,
};
let tx = new_tx().add_flowtable(&ft);
assert_eq!(tx.messages.len(), 1);
let msg = &tx.messages[0];
assert_header(
msg,
nft_msg_type(NFT_MSG_NEWFLOWTABLE),
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
);
let body = body_after_nfgenmsg(msg);
assert!(find_attr(body, NFTA_FLOWTABLE_TABLE).is_some());
assert!(find_attr(body, NFTA_FLOWTABLE_NAME).is_some());
let mut hook_found_with_nested_flag = false;
let mut offset = 0;
while offset + 4 <= body.len() {
let len = u16::from_ne_bytes([body[offset], body[offset + 1]]) as usize;
let raw_ty = u16::from_ne_bytes([body[offset + 2], body[offset + 3]]);
if len < 4 || offset + len > body.len() {
break;
}
if (raw_ty & 0x7FFF) == NFTA_FLOWTABLE_HOOK && (raw_ty & 0x8000) != 0 {
hook_found_with_nested_flag = true;
}
offset += (len + 3) & !3;
}
assert!(hook_found_with_nested_flag, "hook block missing NLA_F_NESTED flag");
let flags = find_attr(body, NFTA_FLOWTABLE_FLAGS).expect("flags missing");
assert_eq!(u32::from_be_bytes(flags.try_into().unwrap()), NFT_FLOWTABLE_HW_OFFLOAD);
}
#[test]
fn del_flowtable_emits_table_plus_name() {
let tx = new_tx().del_flowtable(Family::Inet, "filter", "ft");
assert_eq!(tx.messages.len(), 1);
let msg = &tx.messages[0];
assert_header(msg, nft_msg_type(NFT_MSG_DELFLOWTABLE), NLM_F_REQUEST);
let body = body_after_nfgenmsg(msg);
assert!(find_attr(body, NFTA_FLOWTABLE_TABLE).is_some());
assert!(find_attr(body, NFTA_FLOWTABLE_NAME).is_some());
}
#[test]
fn add_table_with_flags_emits_flags_attr() {
let tx = new_tx().add_table_with_flags(
"filter",
Family::Inet,
NFT_TABLE_F_DORMANT,
);
assert_eq!(tx.messages.len(), 1);
let msg = &tx.messages[0];
assert_header(
msg,
nft_msg_type(NFT_MSG_NEWTABLE),
NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
);
let body = body_after_nfgenmsg(msg);
let flags = find_attr(body, NFTA_TABLE_FLAGS).expect("NFTA_TABLE_FLAGS missing");
assert_eq!(
u32::from_be_bytes(flags.try_into().unwrap()),
NFT_TABLE_F_DORMANT
);
}
#[test]
fn add_table_with_flags_omits_flags_attr_when_zero() {
let tx = new_tx().add_table_with_flags("filter", Family::Inet, 0);
let body = body_after_nfgenmsg(&tx.messages[0]);
assert!(find_attr(body, NFTA_TABLE_FLAGS).is_none());
}
#[test]
fn chained_batch_preserves_message_order_and_seq_numbers() {
let tx = new_tx()
.del_rule("filter", "input", Family::Inet, 1)
.del_chain("filter", "input", Family::Inet)
.del_table("filter", Family::Inet);
assert_eq!(tx.messages.len(), 3);
let seqs: Vec<u32> = tx
.messages
.iter()
.map(|m| u32::from_ne_bytes([m[8], m[9], m[10], m[11]]))
.collect();
assert_eq!(seqs, vec![1, 2, 3]);
let types: Vec<u16> = tx
.messages
.iter()
.map(|m| u16::from_ne_bytes([m[4], m[5]]))
.collect();
assert_eq!(
types,
vec![
nft_msg_type(NFT_MSG_DELRULE),
nft_msg_type(NFT_MSG_DELCHAIN),
nft_msg_type(NFT_MSG_DELTABLE),
]
);
}
}
use crate::netlink::dump_stream::DumpStream;
use crate::netlink::parse::{FromNetlink, PResult};
impl FromNetlink for RuleInfo {
fn write_dump_header(buf: &mut Vec<u8>) {
let nfgenmsg = NfGenMsg {
nfgen_family: 0, version: 0,
res_id: 0,
};
buf.extend_from_slice(nfgenmsg.as_bytes());
}
fn parse(input: &mut &[u8]) -> PResult<Self> {
let consumed = *input;
*input = &input[input.len()..];
Self::from_bytes(consumed).map_err(|_| {
winnow::error::ErrMode::Cut(winnow::error::ContextError::new())
})
}
fn from_bytes(payload: &[u8]) -> crate::Result<Self> {
if payload.len() < NFGENMSG_HDRLEN {
return Err(crate::Error::InvalidMessage(
"nft rule body shorter than nfgenmsg".into(),
));
}
let family = Family::from_u8(payload[0]).unwrap_or(Family::Inet);
let attrs = &payload[NFGENMSG_HDRLEN..];
parse_rule(attrs, family).ok_or_else(|| {
crate::Error::InvalidMessage("nft rule parse failed".into())
})
}
}
impl Connection<Nftables> {
pub async fn stream_rules(
&self,
table: &str,
family: Family,
) -> Result<DumpStream<'_, Nftables, RuleInfo>> {
let mut body = Vec::with_capacity(4 + 4 + table.len() + 1);
let nfgenmsg = NfGenMsg::new(family);
body.extend_from_slice(nfgenmsg.as_bytes());
let str_len = table.len() + 1;
let attr_len = 4 + str_len;
body.extend_from_slice(&(attr_len as u16).to_le_bytes());
body.extend_from_slice(&NFTA_RULE_TABLE.to_le_bytes());
body.extend_from_slice(table.as_bytes());
body.push(0); let padding = (4 - (attr_len % 4)) % 4;
body.resize(body.len() + padding, 0);
self.dump_stream_with_body::<RuleInfo>(
nft_msg_type(NFT_MSG_GETRULE),
&body,
)
.await
}
}
#[cfg(test)]
mod stream_tests {
use super::*;
#[test]
fn rule_write_dump_header_emits_4byte_nfgenmsg() {
let mut buf = Vec::new();
<RuleInfo as FromNetlink>::write_dump_header(&mut buf);
assert_eq!(buf.len(), NFGENMSG_HDRLEN);
assert_eq!(buf[0], 0); }
#[test]
fn rule_from_bytes_rejects_truncated_payload() {
let payload = vec![0u8; 2];
assert!(<RuleInfo as FromNetlink>::from_bytes(&payload).is_err());
}
#[test]
fn rule_from_bytes_parses_family_from_nfgenmsg() {
let mut body = Vec::new();
body.push(2); body.push(0); body.extend_from_slice(&0u16.to_be_bytes()); let table = b"filter\0";
let attr_len = 4 + table.len();
body.extend_from_slice(&(attr_len as u16).to_le_bytes());
body.extend_from_slice(&NFTA_RULE_TABLE.to_le_bytes());
body.extend_from_slice(table);
let pad = (4 - body.len() % 4) % 4;
body.resize(body.len() + pad, 0);
let chain = b"input\0";
let attr_len2 = 4 + chain.len();
body.extend_from_slice(&(attr_len2 as u16).to_le_bytes());
body.extend_from_slice(&NFTA_RULE_CHAIN.to_le_bytes());
body.extend_from_slice(chain);
let pad = (4 - body.len() % 4) % 4;
body.resize(body.len() + pad, 0);
body.extend_from_slice(&12u16.to_le_bytes()); body.extend_from_slice(&NFTA_RULE_HANDLE.to_le_bytes());
body.extend_from_slice(&7u64.to_be_bytes());
let rule = <RuleInfo as FromNetlink>::from_bytes(&body).expect("parse");
assert_eq!(rule.family, Family::Ip);
assert_eq!(rule.table, "filter");
assert_eq!(rule.chain, "input");
assert_eq!(rule.handle, 7);
}
}
#[cfg(test)]
mod userdata_roundtrip_tests {
use super::*;
use crate::netlink::nftables::types::Rule;
fn body_after_nfgenmsg(msg: &[u8]) -> &[u8] {
&msg[20..]
}
#[test]
fn comment_round_trips_through_transaction_add_rule() {
let rule = Rule::new("filter", "input")
.family(Family::Inet)
.comment("ssh-accept");
let tx = Transaction::new().add_rule(rule);
let messages = &tx.messages;
assert_eq!(messages.len(), 1, "expected exactly one rule msg");
let body = body_after_nfgenmsg(&messages[0]);
let parsed = super::parse_rule(body, Family::Inet)
.expect("parse_rule should succeed on a well-formed body");
assert_eq!(parsed.table, "filter");
assert_eq!(parsed.chain, "input");
assert_eq!(
parsed.comment.as_deref(),
Some("ssh-accept"),
"comment should round-trip from emit through parse",
);
assert!(
parsed.userdata_raw.is_some(),
"raw userdata should also be preserved",
);
}
#[test]
fn rule_without_comment_has_none_after_parse() {
let rule = Rule::new("filter", "input").family(Family::Inet);
let tx = Transaction::new().add_rule(rule);
let body = body_after_nfgenmsg(&tx.messages[0]);
let parsed = super::parse_rule(body, Family::Inet).expect("parse");
assert!(parsed.comment.is_none());
assert!(parsed.userdata_raw.is_none());
}
#[test]
fn replace_rule_carries_comment_and_handle() {
let rule = Rule::new("filter", "input")
.family(Family::Inet)
.comment("ssh-accept");
let tx = Transaction::new().replace_rule(rule, 42);
let body = body_after_nfgenmsg(&tx.messages[0]);
let parsed = super::parse_rule(body, Family::Inet).expect("parse");
assert_eq!(parsed.handle, 42);
assert_eq!(parsed.comment.as_deref(), Some("ssh-accept"));
}
}