rustables/
batch.rs

1use std::os::fd::AsRawFd;
2
3use libc;
4
5use thiserror::Error;
6
7use crate::error::QueryError;
8use crate::nlmsg::{NfNetlinkObject, NfNetlinkWriter};
9use crate::sys::{NFNL_SUBSYS_NFTABLES, NLM_F_ACK};
10use crate::{MsgType, ProtocolFamily};
11
12use nix::sys::socket::{
13    self, AddressFamily, MsgFlags, NetlinkAddr, SockFlag, SockProtocol, SockType,
14};
15
16/// Error while communicating with netlink.
17#[derive(Error, Debug)]
18#[error("Error while communicating with netlink")]
19pub struct NetlinkError(());
20
21/// A batch of netfilter messages to be performed in one atomic operation.
22pub struct Batch {
23    buf: Box<Vec<u8>>,
24    // the 'static lifetime here is a cheat, as the writer can only be used as long
25    // as `self.buf` exists. This is why this member must never be exposed directly to
26    // the rest of the crate (let alone publicly).
27    writer: NfNetlinkWriter<'static>,
28    seq: u32,
29}
30
31impl Batch {
32    /// Creates a new nftnl batch with the [default page size].
33    ///
34    /// [default page size]: fn.default_batch_page_size.html
35    pub fn new() -> Self {
36        // TODO: use a pinned Box ?
37        let mut buf = Box::new(Vec::with_capacity(default_batch_page_size() as usize));
38        // Safe because we hold onto the buffer for as long as `writer` exists
39        let mut writer = NfNetlinkWriter::new(unsafe {
40            std::mem::transmute(Box::as_mut(&mut buf) as *mut Vec<u8>)
41        });
42        let seq = 0;
43        writer.write_header(
44            libc::NFNL_MSG_BATCH_BEGIN as u16,
45            ProtocolFamily::Unspec,
46            NLM_F_ACK as u16,
47            seq,
48            Some(libc::NFNL_SUBSYS_NFTABLES as u16),
49        );
50        writer.finalize_writing_object();
51        Batch {
52            buf,
53            writer,
54            seq: seq + 1,
55        }
56    }
57
58    /// Adds the given message to this batch.
59    pub fn add<T: NfNetlinkObject>(&mut self, msg: &T, msg_type: MsgType) {
60        trace!("Writing NlMsg with seq {} to batch", self.seq);
61        msg.add_or_remove(&mut self.writer, msg_type, self.seq);
62        self.seq += 1;
63    }
64
65    /// Adds all the messages in the given iterator to this batch.
66    pub fn add_iter<T: NfNetlinkObject, I: Iterator<Item = T>>(
67        &mut self,
68        msg_iter: I,
69        msg_type: MsgType,
70    ) {
71        for msg in msg_iter {
72            self.add(&msg, msg_type);
73        }
74    }
75
76    /// Adds the final end message to the batch and returns a [`FinalizedBatch`] that can be used
77    /// to send the messages to netfilter.
78    ///
79    /// Return None if there is no object in the batch (this could block forever).
80    ///
81    /// [`FinalizedBatch`]: struct.FinalizedBatch.html
82    pub fn finalize(mut self) -> Vec<u8> {
83        self.writer.write_header(
84            libc::NFNL_MSG_BATCH_END as u16,
85            ProtocolFamily::Unspec,
86            0,
87            self.seq,
88            Some(NFNL_SUBSYS_NFTABLES as u16),
89        );
90        self.writer.finalize_writing_object();
91        *self.buf
92    }
93
94    pub fn send(self) -> Result<(), QueryError> {
95        use crate::query::{recv_and_process, socket_close_wrapper};
96
97        let sock = socket::socket(
98            AddressFamily::Netlink,
99            SockType::Raw,
100            SockFlag::empty(),
101            SockProtocol::NetlinkNetFilter,
102        )
103        .map_err(QueryError::NetlinkOpenError)?;
104
105        let max_seq = self.seq - 1;
106
107        let addr = NetlinkAddr::new(0, 0);
108        // while this bind() is not strictly necessary, strace have trouble decoding the messages
109        // if we don't
110        socket::bind(sock.as_raw_fd(), &addr).map_err(|_| QueryError::BindFailed)?;
111
112        let to_send = self.finalize();
113        let mut sent = 0;
114        while sent != to_send.len() {
115            sent += socket::send(sock.as_raw_fd(), &to_send[sent..], MsgFlags::empty())
116                .map_err(QueryError::NetlinkSendError)?;
117        }
118
119        Ok(socket_close_wrapper(sock.as_raw_fd(), move |sock| {
120            recv_and_process(sock, Some(max_seq), None, &mut ())
121        })?)
122    }
123}
124
125/// Selected batch page is 256 Kbytes long to load ruleset of half a million rules without hitting
126/// -EMSGSIZE due to large iovec.
127pub fn default_batch_page_size() -> u32 {
128    unsafe { libc::sysconf(libc::_SC_PAGESIZE) as u32 * 32 }
129}