1use std::os::fd::AsRawFd;
2
3use libc;
4
5use thiserror::Error;
6
7use crate::error::QueryError;
8use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
9use crate::sys::{NFNL_SUBSYS_NFTABLES, NLM_F_ACK};
10use crate::{MsgType, ProtocolFamily};
11
12use nix::sys::socket::{
13 self, AddressFamily, MsgFlags, NetlinkAddr, SockFlag, SockProtocol, SockType,
14};
15
16#[derive(Error, Debug)]
18#[error("Error while communicating with netlink")]
19pub struct NetlinkError(());
20
21pub struct Batch {
23 buf: Box<Vec<u8>>,
24 writer: NfNetlinkWriter<'static>,
28 seq: u32,
29}
30
31impl Batch {
32 pub fn new() -> Self {
36 let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
38 let mut writer = NfNetlinkWriter::new(unsafe {
40 std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
41 });
42 let seq = 0;
43 writer.write_header(
44 libc::NFNL_MSG_BATCH_BEGIN as u16,
45 ProtocolFamily::Unspec,
46 NLM_F_ACK as u16,
47 seq,
48 Some(libc::NFNL_SUBSYS_NFTABLES as u16),
49 );
50 writer.finalize_writing_object();
51 Batch {
52 buf,
53 writer,
54 seq: seq + 1,
55 }
56 }
57
58 pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) {
60 trace!("Writing NlMsg with seq {} to batch", self.seq);
61 msg.add_or_remove(&mut self.writer, msg_type, self.seq);
62 self.seq += 1;
63 }
64
65 pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>(
67 &mut self,
68 msg_iter: I,
69 msg_type: MsgType,
70 ) {
71 for msg in msg_iter {
72 self.add(&msg, msg_type);
73 }
74 }
75
76 pub fn finalize(mut self) -> Vec<u8> {
83 self.writer.write_header(
84 libc::NFNL_MSG_BATCH_END as u16,
85 ProtocolFamily::Unspec,
86 0,
87 self.seq,
88 Some(NFNL_SUBSYS_NFTABLES as u16),
89 );
90 self.writer.finalize_writing_object();
91 *self.buf
92 }
93
94 pub fn send(self) -> Result<(), QueryError> {
95 use crate::query::{recv_and_process, socket_close_wrapper};
96
97 let sock = socket::socket(
98 AddressFamily::Netlink,
99 SockType::Raw,
100 SockFlag::empty(),
101 SockProtocol::NetlinkNetFilter,
102 )
103 .map_err(QueryError::NetlinkOpenError)?;
104
105 let max_seq = self.seq - 1;
106
107 let addr = NetlinkAddr::new(0, 0);
108 socket::bind(sock.as_raw_fd(), &addr).map_err(|_| QueryError::BindFailed)?;
111
112 let to_send = self.finalize();
113 let mut sent = 0;
114 while sent != to_send.len() {
115 sent += socket::send(sock.as_raw_fd(), &to_send[sent..], MsgFlags::empty())
116 .map_err(QueryError::NetlinkSendError)?;
117 }
118
119 Ok(socket_close_wrapper(sock.as_raw_fd(), move |sock| {
120 recv_and_process(sock, Some(max_seq), None, &mut ())
121 })?)
122 }
123}
124
125pub fn default_batch_page_size() -> u32 {
128 unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u32 * 32 }
129}