use std::{fmt::Debug, mem::size_of};
use crate::{
error::DecodeError,
sys::{
nfgenmsg, nlmsghdr, NFNETLINK_V0, NFNL_MSG_BATCH_BEGIN, NFNL_MSG_BATCH_END,
NFNL_SUBSYS_NFTABLES, NLMSG_ALIGNTO, NLM_F_ACK, NLM_F_CREATE,
},
MsgType, ProtocolFamily,
};
/// Returns `u16::MAX` plus the system page size, in bytes.
///
/// Judging by its name, this is the maximum nftables netlink message size used by the crate
/// (NOTE(review): confirm against callers).
pub fn nft_nlmsg_maxsize() -> u32 {
    // SAFETY: `sysconf` takes no pointers; it only queries a system configuration value.
    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    // NOTE(review): `sysconf` reports failure as -1, which `as u32` turns into `u32::MAX`, making
    // the addition below overflow. Left unchanged to preserve existing behavior.
    u32::from(u16::MAX) + page_size as u32
}
#[inline]
pub const fn pad_netlink_object_with_variable_size(size: usize) -> usize {
(size + (NLMSG_ALIGNTO as usize - 1)) & !(NLMSG_ALIGNTO as usize - 1)
}
/// Returns the in-memory size of `T` rounded up to the netlink alignment
/// (see `pad_netlink_object_with_variable_size`).
#[inline]
pub const fn pad_netlink_object<T>() -> usize {
    pad_netlink_object_with_variable_size(size_of::<T>())
}
/// Extracts the high byte of an `nlmsg_type` value.
///
/// `NfNetlinkWriter::write_header` stores the subsystem id in this byte.
pub fn get_subsystem_from_nlmsghdr_type(x: u16) -> u8 {
    let [subsystem, _operation] = x.to_be_bytes();
    subsystem
}
/// Extracts the low byte of an `nlmsg_type` value, which holds the operation (message) code.
pub fn get_operation_from_nlmsghdr_type(x: u16) -> u8 {
    let [_subsystem, operation] = x.to_be_bytes();
    operation
}
pub struct NfNetlinkWriter<'a> {
buf: &'a mut Vec<u8>,
headers: Option<(usize, usize)>,
}
impl<'a> NfNetlinkWriter<'a> {
pub fn new(buf: &'a mut Vec<u8>) -> NfNetlinkWriter<'a> {
NfNetlinkWriter { buf, headers: None }
}
pub fn add_data_zeroed<'b>(&'b mut self, size: usize) -> &'b mut [u8] {
let padded_size = pad_netlink_object_with_variable_size(size);
let start = self.buf.len();
self.buf.resize(start + padded_size, 0);
if let Some((msghdr_idx, _nfgenmsg_idx)) = self.headers {
let hdr: &mut nlmsghdr = unsafe {
std::mem::transmute(self.buf[msghdr_idx..].as_mut_ptr() as *mut nlmsghdr)
};
hdr.nlmsg_len += padded_size as u32;
}
&mut self.buf[start..start + size]
}
pub fn write_header(
&mut self,
msg_type: u16,
family: ProtocolFamily,
flags: u16,
seq: u32,
ressource_id: Option<u16>,
) {
if self.headers.is_some() {
error!("Calling write_header while still holding headers open!?");
}
let nlmsghdr_len = pad_netlink_object::<nlmsghdr>();
let nfgenmsg_len = pad_netlink_object::<nfgenmsg>();
let nlmsghdr_buf = self.add_data_zeroed(nlmsghdr_len);
let hdr: &mut nlmsghdr =
unsafe { std::mem::transmute(nlmsghdr_buf.as_mut_ptr() as *mut nlmsghdr) };
hdr.nlmsg_len = (nlmsghdr_len + nfgenmsg_len) as u32;
hdr.nlmsg_type = msg_type;
if msg_type != NFNL_MSG_BATCH_BEGIN as u16 && msg_type != NFNL_MSG_BATCH_END as u16 {
hdr.nlmsg_type |= (NFNL_SUBSYS_NFTABLES as u16) << 8;
}
hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16 | flags;
hdr.nlmsg_seq = seq;
let nfgenmsg_buf = self.add_data_zeroed(nfgenmsg_len);
let nfgenmsg: &mut nfgenmsg =
unsafe { std::mem::transmute(nfgenmsg_buf.as_mut_ptr() as *mut nfgenmsg) };
nfgenmsg.nfgen_family = family as u8;
nfgenmsg.version = NFNETLINK_V0 as u8;
nfgenmsg.res_id = ressource_id.unwrap_or(0);
self.headers = Some((
self.buf.len() - (nlmsghdr_len + nfgenmsg_len),
self.buf.len() - nfgenmsg_len,
));
}
pub fn finalize_writing_object(&mut self) {
self.headers = None;
}
}
/// Numeric type identifier of a netlink attribute, as passed to
/// `AttributeDecoder::decode_attribute`.
pub type NetlinkType = u16;
/// Types that can absorb individual netlink attributes during decoding.
pub trait AttributeDecoder {
    /// Decodes one attribute of type `attr_type` whose payload is `buf` into `self`.
    ///
    /// NOTE(review): presumably called once per attribute by the message parser; confirm whether
    /// `buf` includes the attribute header or only its payload.
    fn decode_attribute(&mut self, attr_type: NetlinkType, buf: &[u8]) -> Result<(), DecodeError>;
}
/// Types that can be parsed from a byte buffer.
pub trait NfNetlinkDeserializable: Sized {
    /// Parses a value from the start of `buf`.
    ///
    /// Returns the value together with a slice of `buf` (presumably the bytes left over after
    /// parsing — TODO confirm), or a `DecodeError`.
    fn deserialize(buf: &[u8]) -> Result<(Self, &[u8]), DecodeError>;
}
/// A top-level nftables object that can be serialized as an add or delete message.
pub trait NfNetlinkObject:
    Sized + AttributeDecoder + NfNetlinkDeserializable + NfNetlinkAttribute
{
    /// Raw message type used when adding this object (truncated to `u16` when written).
    const MSG_TYPE_ADD: u32;
    /// Raw message type used when deleting this object (truncated to `u16` when written).
    const MSG_TYPE_DEL: u32;

    /// Writes one complete add/delete message for `self` into `writer`.
    ///
    /// The message header carries the flags for `msg_type` with `NLM_F_ACK` added. The payload
    /// is a `get_size()`-byte region filled by `write_payload`. The message is closed afterwards.
    fn add_or_remove<'a>(&self, writer: &mut NfNetlinkWriter<'a>, msg_type: MsgType, seq: u32) {
        let family = self.get_family();
        let (type_id, base_flags) = match msg_type {
            MsgType::Add => (Self::MSG_TYPE_ADD, self.get_add_flags()),
            MsgType::Del => (Self::MSG_TYPE_DEL, self.get_del_flags()),
        };
        let header_flags = (base_flags | NLM_F_ACK) as u16;
        writer.write_header(type_id as u16, family, header_flags, seq, None);

        let payload = writer.add_data_zeroed(self.get_size());
        self.write_payload(payload);
        writer.finalize_writing_object();
    }

    /// Protocol family written into the message's `nfgenmsg` header.
    fn get_family(&self) -> ProtocolFamily;

    /// Sets the protocol family. The default implementation ignores the value; implementors
    /// that store a family should override it.
    fn set_family(&mut self, _family: ProtocolFamily) {}

    /// Builder-style wrapper around `set_family`.
    fn with_family(mut self, family: ProtocolFamily) -> Self {
        self.set_family(family);
        self
    }

    /// Netlink flags for add messages; defaults to `NLM_F_CREATE`.
    fn get_add_flags(&self) -> u32 {
        NLM_F_CREATE
    }

    /// Netlink flags for delete messages; defaults to no flags.
    fn get_del_flags(&self) -> u32 {
        0
    }
}
/// A value that can be serialized as a netlink attribute payload.
pub trait NfNetlinkAttribute: Debug + Sized {
    /// Whether this attribute is nested; `false` by default.
    ///
    /// NOTE(review): presumably marks attributes whose payload is itself a list of attributes —
    /// confirm with the attribute writer.
    fn is_nested(&self) -> bool {
        false
    }

    /// Number of payload bytes to reserve; defaults to the in-memory size of `Self`.
    fn get_size(&self) -> usize {
        // For a `Sized` type this equals `size_of::<Self>()`.
        std::mem::size_of_val(self)
    }

    /// Serializes the payload into `addr`.
    ///
    /// When called from `NfNetlinkObject::add_or_remove`, `addr` is a zeroed slice of exactly
    /// `get_size()` bytes.
    fn write_payload(&self, addr: &mut [u8]);
}