use std::os::fd::AsRawFd;
use libc;
use thiserror::Error;
use crate::error::QueryError;
use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
use crate::sys::{NFNL_SUBSYS_NFTABLES, NLM_F_ACK};
use crate::{MsgType, ProtocolFamily};
use nix::sys::socket::{
self, AddressFamily, MsgFlags, NetlinkAddr, SockFlag, SockProtocol, SockType,
};
/// Opaque error for a failure while communicating with netlink.
///
/// It carries no details. The field is private, so the error cannot be constructed outside
/// this module.
#[derive(Error, Debug)]
#[error("Error while communicating with netlink")]
pub struct NetlinkError(());
/// A buffer of nftables netlink messages to be sent to the kernel as a single batch.
///
/// The buffer is framed by `NFNL_MSG_BATCH_BEGIN` / `NFNL_MSG_BATCH_END` headers, and every
/// message written into it receives its own sequence number.
pub struct Batch {
    // Declared before `buf` on purpose. Struct fields are dropped in declaration order, and
    // `writer` holds a `'static` mutable borrow of the `Vec` owned by `buf`. The borrower must
    // therefore be dropped before the owner's allocation is freed.
    writer: NfNetlinkWriter<'static>,
    // Boxed so the `Vec` itself (not just its heap buffer) keeps a stable address while the
    // `Batch` is moved around; `writer` points at it.
    #[allow(clippy::box_collection)]
    buf: Box<Vec<u8>>,
    // Sequence number assigned to the next message written into the batch.
    seq: u32,
}
impl Batch {
    /// Creates a new batch whose buffer already holds the batch-begin header.
    ///
    /// The buffer is pre-allocated with [`default_batch_page_size`] bytes of capacity. The
    /// `NFNL_MSG_BATCH_BEGIN` header is written with sequence number 0 and the `NLM_F_ACK`
    /// flag. Messages added afterwards are numbered from 1.
    pub fn new() -> Self {
        let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
        let buf_ptr: *mut Vec<u8> = &mut *buf;
        // SAFETY: `buf_ptr` points into the heap allocation owned by `buf`. That allocation
        // does not move when the `Box` is moved into the returned `Batch`. The reference is
        // only used through `writer`, which is stored in the same `Batch` as `buf`, and
        // `finalize` ends `writer`'s lifetime before taking the `Vec` out of the box.
        // NOTE(review): the `'static` lifetime is only upheld by `Batch` keeping `buf` alive.
        // Moving the `Box` after deriving this pointer may also be flagged by Miri, because
        // Stacked Borrows treats `Box` as a unique pointer.
        let mut writer = NfNetlinkWriter::new(unsafe { &mut *buf_ptr });
        let seq = 0;
        writer.write_header(
            libc::NFNL_MSG_BATCH_BEGIN as u16,
            ProtocolFamily::Unspec,
            NLM_F_ACK as u16,
            seq,
            Some(NFNL_SUBSYS_NFTABLES as u16),
        );
        writer.finalize_writing_object();
        Batch {
            buf,
            writer,
            seq: seq + 1,
        }
    }

    /// Appends `msg` to the batch as a `msg_type` message, using the next sequence number.
    pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) {
        trace!("Writing NlMsg with seq {} to batch", self.seq);
        msg.add_or_remove(&mut self.writer, msg_type, self.seq);
        self.seq += 1;
    }

    /// Appends every message yielded by `msg_iter` as `msg_type` messages, in order.
    ///
    /// Accepts any `IntoIterator` (e.g. a `Vec` or an iterator adaptor).
    pub fn add_iter<T: NfNetlinkObject, I: IntoIterator<Item = T>>(
        &mut self,
        msg_iter: I,
        msg_type: MsgType,
    ) {
        for msg in msg_iter {
            self.add(&msg, msg_type);
        }
    }

    /// Writes the batch-end header and returns the raw bytes of the whole batch.
    ///
    /// The `NFNL_MSG_BATCH_END` header uses the sequence number following the last message.
    pub fn finalize(self) -> Vec<u8> {
        let Batch {
            buf,
            mut writer,
            seq,
        } = self;
        writer.write_header(
            libc::NFNL_MSG_BATCH_END as u16,
            ProtocolFamily::Unspec,
            0,
            seq,
            Some(NFNL_SUBSYS_NFTABLES as u16),
        );
        writer.finalize_writing_object();
        // `writer` holds a `'static` borrow of the `Vec` inside `buf`. End its lifetime before
        // moving the `Vec` out, so the returned buffer is never aliased. This also ensures the
        // writer is never dropped after the `Box` allocation has been freed.
        drop(writer);
        *buf
    }

    /// Finalizes the batch, sends it over a new netfilter netlink socket, and processes the
    /// kernel's replies.
    ///
    /// The replies are handed to `recv_and_process` together with the sequence number of the
    /// last message added to the batch.
    ///
    /// # Errors
    ///
    /// Returns `QueryError::NetlinkOpenError` if the socket cannot be created, and
    /// `QueryError::BindFailed` if binding it fails. Returns `QueryError::NetlinkSendError` if
    /// sending fails. Errors from receiving or processing replies, and from closing the
    /// socket, are propagated from `recv_and_process` / `socket_close_wrapper`.
    pub fn send(self) -> Result<(), QueryError> {
        use crate::query::{recv_and_process, socket_close_wrapper};
        let sock = socket::socket(
            AddressFamily::Netlink,
            SockType::Raw,
            SockFlag::empty(),
            SockProtocol::NetlinkNetFilter,
        )
        .map_err(QueryError::NetlinkOpenError)?;
        // The begin header uses 0 and messages are numbered from 1, so this is the sequence
        // number of the last message added.
        let max_seq = self.seq - 1;
        let addr = NetlinkAddr::new(0, 0);
        socket::bind(sock.as_raw_fd(), &addr).map_err(|_| QueryError::BindFailed)?;
        let to_send = self.finalize();
        let mut sent = 0;
        // `send` may report a partial write; keep going until the whole batch is out.
        while sent != to_send.len() {
            sent += socket::send(sock.as_raw_fd(), &to_send[sent..], MsgFlags::empty())
                .map_err(QueryError::NetlinkSendError)?;
        }
        // NOTE(review): only the raw fd is passed to `socket_close_wrapper`. If
        // `socket::socket` returns an owning fd type (nix >= 0.27) and the wrapper also closes
        // the descriptor, it is closed twice. If it returns a bare `RawFd`, the early returns
        // above leak it. Confirm against the nix version in use.
        Ok(socket_close_wrapper(sock.as_raw_fd(), move |sock| {
            recv_and_process(sock, Some(max_seq), None, &mut ())
        })?)
    }
}
pub fn default_batch_page_size() -> u32 {
unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u32 * 32 }
}