use std::os::{
    fd::{AsRawFd, IntoRawFd},
    unix::prelude::RawFd,
};
use nix::sys::socket::{self, AddressFamily, MsgFlags, SockFlag, SockProtocol, SockType};
use crate::{
    error::QueryError,
    nlmsg::{
        nft_nlmsg_maxsize, pad_netlink_object_with_variable_size, NfNetlinkAttribute,
        NfNetlinkObject, NfNetlinkWriter,
    },
    parser::{parse_nlmsg, NlMsg},
    sys::{NLM_F_DUMP, NLM_F_MULTI},
    ProtocolFamily,
};
/// Receives netlink datagrams from `sock` and processes every message they contain.
///
/// Each datagram returned by `recv` is walked message by message with [`parse_nlmsg`]:
/// - a `Done` message ends the exchange with `Ok(())`;
/// - an `Error` message with a non-zero code is returned as [`QueryError::NetlinkError`]
///   (a zero code is skipped);
/// - `Noop` messages are ignored;
/// - `NfGenMsg` messages are handed, as the raw bytes of that single message, to `cb`
///   (when provided) together with `working_data`.
///
/// Termination: when `max_seq` is `Some`, the function returns `Ok(())` as soon as a message
/// whose sequence number is `>= max_seq` has been processed. When it is `None`, every message
/// must carry `NLM_F_MULTI`; otherwise [`QueryError::UndecidableMessageTermination`] is
/// returned. A `recv` that returns 0 bytes also ends the loop with `Ok(())`.
///
/// # Errors
/// Propagates receive failures ([`QueryError::NetlinkRecvError`]), parse failures,
/// kernel-reported errors and any error returned by `cb`.
pub(crate) fn recv_and_process<T>(
    sock: RawFd,
    max_seq: Option<u32>,
    cb: Option<&dyn Fn(&[u8], &mut T) -> Result<(), QueryError>>,
    working_data: &mut T,
) -> Result<(), QueryError> {
    let mut msg_buffer = vec![0; 2 * nft_nlmsg_maxsize() as usize];
    loop {
        let nb_recv = socket::recv(sock, &mut msg_buffer[..], MsgFlags::empty())
            .map_err(QueryError::NetlinkRecvError)?;
        if nb_recv == 0 {
            return Ok(());
        }
        // Every message of a datagram is consumed below before the next `recv`, so no bytes
        // ever need to be carried over: each datagram is received at the start of the buffer.
        let datagram = &msg_buffer[..nb_recv];
        let mut offset = 0;
        // Looping on `offset < len` (instead of slicing `offset..len` unconditionally) also
        // avoids a panic if the padded length of the last message runs past the datagram end.
        while offset < datagram.len() {
            let buf = &datagram[offset..];
            debug!("Calling parse_nlmsg");
            let (nlmsghdr, msg) = parse_nlmsg(buf)?;
            debug!("Got a valid netlink message: {:?} {:?}", nlmsghdr, msg);
            let msg_len = nlmsghdr.nlmsg_len as usize;
            match msg {
                NlMsg::Done => return Ok(()),
                NlMsg::Error(e) => {
                    // netlink reports acknowledgements as error messages with code 0.
                    if e.error != 0 {
                        return Err(QueryError::NetlinkError(e));
                    }
                }
                NlMsg::Noop => {}
                NlMsg::NfGenMsg(_genmsg, _data) => {
                    if let Some(cb) = cb {
                        cb(&buf[..msg_len], working_data)?;
                    }
                }
            }
            match max_seq {
                // Caller told us which sequence number is the last one it expects.
                Some(max_seq) => {
                    if nlmsghdr.nlmsg_seq >= max_seq {
                        return Ok(());
                    }
                }
                // Without a sequence bound, only a multipart stream (ended by `Done`) lets us
                // know when to stop.
                None => {
                    if nlmsghdr.nlmsg_flags & NLM_F_MULTI as u16 == 0 {
                        return Err(QueryError::UndecidableMessageTermination);
                    }
                }
            }
            offset += pad_netlink_object_with_variable_size(msg_len);
        }
    }
}
/// Runs `cb` on `sock`, then closes `sock` whatever the outcome of `cb`.
///
/// `sock` is closed here, so the caller must not close it again afterwards (nor let an
/// owning wrapper of the same descriptor close it).
///
/// # Errors
/// If `cb` fails, its error (converted into [`QueryError`]) is returned, even when closing
/// the descriptor fails as well, because it describes the primary failure. Otherwise a
/// failure to close is reported as [`QueryError::CloseFailed`].
pub(crate) fn socket_close_wrapper<E>(
    sock: RawFd,
    cb: impl FnOnce(RawFd) -> Result<(), E>,
) -> Result<(), QueryError>
where
    QueryError: From<E>,
{
    let outcome = cb(sock);
    // Always close, even when the callback failed, so the descriptor never leaks.
    let closed = nix::unistd::close(sock).map_err(QueryError::CloseFailed);
    outcome?;
    closed
}
/// Serializes a netlink dump request (`NLM_F_DUMP`) of type `msg_type` with sequence
/// number `seq`, using the `Unspec` protocol family.
///
/// When `filter` is provided, its serialized attributes are appended as the request payload
/// (presumably so the kernel can restrict the dump — confirm per object type).
///
/// The current implementation never returns `Err`.
pub fn get_list_of_objects<T: NfNetlinkAttribute>(
    msg_type: u16,
    seq: u32,
    filter: Option<&T>,
) -> Result<Vec<u8>, QueryError> {
    let mut request = Vec::new();
    {
        // The writer mutably borrows `request` only for the duration of this scope.
        let mut writer = NfNetlinkWriter::new(&mut request);
        writer.write_header(msg_type, ProtocolFamily::Unspec, NLM_F_DUMP as u16, seq, None);
        if let Some(attr) = filter {
            let payload = writer.add_data_zeroed(attr.get_size());
            attr.write_payload(payload);
        }
        writer.finalize_writing_object();
    }
    Ok(request)
}
pub fn list_objects_with_data<'a, Object, Accumulator>(
data_type: u16,
cb: &dyn Fn(Object, &mut Accumulator) -> Result<(), QueryError>,
filter: Option<&Object>,
working_data: &'a mut Accumulator,
) -> Result<(), QueryError>
where
Object: NfNetlinkObject + NfNetlinkAttribute,
{
debug!("Listing objects of kind {}", data_type);
let sock = socket::socket(
AddressFamily::Netlink,
SockType::Raw,
SockFlag::empty(),
SockProtocol::NetlinkNetFilter,
)
.map_err(QueryError::NetlinkOpenError)?;
let seq = 0;
let chains_buf = get_list_of_objects(data_type, seq, filter)?;
socket::send(sock.as_raw_fd(), &chains_buf, MsgFlags::empty())
.map_err(QueryError::NetlinkSendError)?;
socket_close_wrapper(sock.as_raw_fd(), move |sock| {
recv_and_process(
sock,
None,
Some(&|buf: &[u8], working_data: &mut Accumulator| {
debug!("Calling Object::deserialize()");
cb(Object::deserialize(buf)?.0, working_data)
}),
working_data,
)
})
}