mod constants;
pub mod commands;
pub mod rule;
use std::ffi::CString;
use std::fmt::{self, Display, Formatter};
use std::io::{self, IoSlice};
use bytes::{Buf, BufMut, Bytes};
use libc::{
AF_NETLINK, AF_UNSPEC, NETLINK_NETFILTER, NFNL_SUBSYS_NFTABLES, NLM_F_ACK, NLM_F_APPEND,
NLM_F_CREATE, NLM_F_ECHO, NLM_F_REQUEST, SOCK_RAW,
};
use socket2::{Domain, MaybeUninitSlice, MsgHdr, MsgHdrMut, Protocol, Socket, Type};
use crate::commands::{AddChain, AddRule, AddTable, DelChain, DelRule, DelTable};
/// Error returned by the netlink operations in this module.
///
/// Wraps an underlying [`io::Error`]; its `Display` output is exactly that of
/// the wrapped error.
#[derive(Debug)]
pub struct Error(ErrorImpl);

impl Display for Error {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        match inner {
            ErrorImpl::Io(io_err) => Display::fmt(io_err, formatter),
        }
    }
}

impl std::error::Error for Error {}

/// Private representation behind [`Error`].
#[derive(Debug)]
enum ErrorImpl {
    /// A socket failure or an errno reported by the kernel.
    Io(io::Error),
}
/// Opaque 64-bit object handle reported by the kernel (for example for
/// chains and rules added through a batch).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Handle(u64);

impl Handle {
    /// Returns the raw handle value.
    #[inline]
    pub const fn to_bits(self) -> u64 {
        let Self(bits) = self;
        bits
    }

    /// Wraps a raw handle value.
    #[inline]
    pub const fn from_bits(bits: u64) -> Self {
        Handle(bits)
    }
}
#[derive(Debug)]
pub struct Connection {
socket: Socket,
page_size: usize,
}
impl Connection {
pub fn new() -> Result<Self, Error> {
let socket = Socket::new(
Domain::from(AF_NETLINK),
Type::from(SOCK_RAW),
Some(Protocol::from(NETLINK_NETFILTER)),
)
.map_err(|err| Error(ErrorImpl::Io(err)))?;
let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) } as usize;
Ok(Self { socket, page_size })
}
pub fn execute(&mut self, batch: &Batch<'_>) -> Result<Vec<CommandResult>, Error> {
let mut buf = Vec::new();
write_batch_begin(&mut buf, 1);
let mut seq: u32 = 2;
for cmd in &batch.cmds {
let offset = buf.len();
write_cmd(&mut buf, cmd.id(), cmd.proto() as u8, 0);
cmd.encode(&mut buf);
let len = (buf.len() - offset) as u32;
buf[offset..offset + 4].copy_from_slice(&len.to_ne_bytes());
buf[offset + 8..offset + 12].copy_from_slice(&seq.to_ne_bytes());
seq += 1;
}
write_batch_end(&mut buf, seq);
self.socket
.sendmsg(&MsgHdr::new().with_buffers(&[IoSlice::new(&buf)]), 0)
.map_err(|err| Error(ErrorImpl::Io(err)))?;
let mut results = Vec::new();
let mut buf_size = 8_192;
if buf_size % self.page_size != 0 {
buf_size += self.page_size % buf_size;
}
let mut resp = Vec::with_capacity(buf_size);
'outer: loop {
resp.clear();
let mut buffers = [MaybeUninitSlice::new(resp.spare_capacity_mut())];
let mut hdr = MsgHdrMut::new().with_buffers(&mut buffers);
let count = self
.socket
.recvmsg(&mut hdr, 0)
.map_err(|err| Error(ErrorImpl::Io(err)))?;
unsafe {
resp.set_len(count);
}
let mut resp_buf = &resp[..];
while resp_buf.has_remaining() {
let header = Header::decode(&mut resp_buf).unwrap();
let len = usize::min(
(header.len as usize).saturating_sub(Header::SIZE),
resp_buf.len(),
);
let mut resp = &resp_buf[..len];
match header.ty as i32 {
libc::NLMSG_NOOP => {}
libc::NLMSG_ERROR => {
let err = resp.get_i32_ne().abs();
if err != 0 {
let err = io::Error::from_raw_os_error(err);
if cfg!(debug_assertions) {
let err = self
.socket
.recvmsg(&mut MsgHdrMut::new(), libc::MSG_DONTWAIT)
.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
}
return Err(Error(ErrorImpl::Io(err)));
}
}
libc::NLMSG_DONE => {}
_ => {
resp.advance(4);
let index = header.seq as usize - 2;
match batch.cmds[index] {
Command::AddTable(_) => {}
Command::DelTable(_) => {}
Command::AddChain(_) => {
let handle = AddChain::read_handle(&mut resp).unwrap();
results.push(CommandResult { index, handle });
}
Command::DelChain(_) => {}
Command::AddRule(_) => {
let handle = AddRule::read_handle(&mut resp).unwrap();
results.push(CommandResult { index, handle });
}
Command::DelRule(_) => {}
}
}
}
if header.seq == seq {
break 'outer;
}
resp_buf.advance(len);
}
}
if cfg!(debug_assertions) {
let err = self
.socket
.recvmsg(&mut MsgHdrMut::new(), libc::MSG_DONTWAIT)
.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
}
Ok(results)
}
}
/// An ordered list of nftables commands that [`Connection::execute`] sends
/// to the kernel in a single request.
#[derive(Clone, Debug, Default)]
pub struct Batch<'a> {
    cmds: Vec<Command<'a>>,
}

impl<'a> Batch<'a> {
    /// Creates an empty batch.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a command to the end of the batch.
    pub fn push<T>(&mut self, cmd: T)
    where
        T: Into<Command<'a>>,
    {
        let command: Command<'a> = cmd.into();
        self.cmds.push(command);
    }

    /// Removes every queued command, keeping the allocated storage.
    pub fn clear(&mut self) {
        self.cmds.clear();
    }
}
/// A single nftables operation that can be queued in a [`Batch`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum Command<'a> {
    AddTable(AddTable<'a>),
    DelTable(DelTable<'a>),
    AddChain(AddChain<'a>),
    DelChain(DelChain<'a>),
    AddRule(AddRule<'a>),
    DelRule(DelRule<'a>),
}

impl Command<'_> {
    /// nftables message type of this command; combined with the subsystem id
    /// to form `nlmsg_type`.
    fn id(&self) -> u16 {
        match *self {
            Command::AddTable(..) => AddTable::ID,
            Command::DelTable(..) => DelTable::ID,
            Command::AddChain(..) => AddChain::ID,
            Command::DelChain(..) => DelChain::ID,
            Command::AddRule(..) => AddRule::ID,
            Command::DelRule(..) => DelRule::ID,
        }
    }

    /// Protocol family the wrapped command targets.
    fn proto(&self) -> ProtoFamily {
        match self {
            Command::AddTable(inner) => inner.proto(),
            Command::DelTable(inner) => inner.proto(),
            Command::AddChain(inner) => inner.proto(),
            Command::DelChain(inner) => inner.proto(),
            Command::AddRule(inner) => inner.proto(),
            Command::DelRule(inner) => inner.proto(),
        }
    }
}

impl Encode for Command<'_> {
    /// Delegates to the wrapped command's own encoding.
    fn encode<B>(&self, buf: B)
    where
        B: BufMut,
    {
        match self {
            Command::AddTable(inner) => inner.encode(buf),
            Command::DelTable(inner) => inner.encode(buf),
            Command::AddChain(inner) => inner.encode(buf),
            Command::DelChain(inner) => inner.encode(buf),
            Command::AddRule(inner) => inner.encode(buf),
            Command::DelRule(inner) => inner.encode(buf),
        }
    }
}
/// Result reported for one command of an executed [`Batch`].
#[derive(Clone, Debug)]
pub struct CommandResult {
    /// Position of the command within the batch.
    pub index: usize,
    /// Handle read from the kernel's reply, if the reply carried one.
    pub handle: Option<Handle>,
}
/// Appends a netlink attribute (`nlattr` header, payload, zero padding) to `buf`.
///
/// The header's length covers the 4-byte header plus `data`. It does not
/// cover the trailing zero padding that aligns the next attribute to a
/// 4-byte boundary.
///
/// # Panics
///
/// Panics if `data` is longer than `u16::MAX - 4` bytes, which cannot be
/// expressed in the 16-bit attribute length field.
fn write_attribute<B>(mut buf: B, ty: u16, data: &[u8])
where
    B: BufMut,
{
    // `as u16` would silently truncate an oversized length and emit a corrupt
    // attribute; fail loudly instead.
    let len = u16::try_from(data.len() + 4)
        .expect("netlink attribute payload too large for a u16 length");
    AttributeHeader { len, ty }.encode(&mut buf);
    buf.put_slice(data);
    let pad = (4 - data.len() % 4) % 4;
    buf.put_slice(&[0u8; 3][..pad]);
}
fn read_attribute<B>(mut buf: B) -> Result<(AttributeHeader, Bytes), Error>
where
B: Buf,
{
let len = buf.get_u16_le();
let ty = buf.get_u16_le();
let data = buf.copy_to_bytes(len.saturating_sub(4).into());
if len % 4 != 0 {
let pad = 4 - (len % 4);
buf.advance(pad.into());
}
Ok((AttributeHeader { len, ty }, data))
}
trait Message {
const ID: u16;
fn proto(&self) -> ProtoFamily;
fn read_handle<B>(_buf: B) -> Result<Option<Handle>, Error>
where
B: Buf,
{
Ok(None)
}
}
/// Writes the `nlmsghdr` and `nfgenmsg` headers that open an nftables command.
///
/// The header length is written as 0; the caller patches it once the
/// command's attributes have been appended. Every command is flagged
/// `NLM_F_REQUEST | NLM_F_ECHO | NLM_F_CREATE | NLM_F_APPEND`.
fn write_cmd<B>(mut buf: B, cmd: u16, family: u8, seq: u32)
where
    B: BufMut,
{
    // nfnetlink message types carry the subsystem id in the high byte.
    let ty = ((NFNL_SUBSYS_NFTABLES as u16) << 8) | cmd;
    let flags = (NLM_F_REQUEST | NLM_F_ECHO | NLM_F_CREATE | NLM_F_APPEND) as u16;
    let header = Header {
        len: 0,
        ty,
        flags,
        seq,
        pid: 0,
    };
    header.encode(&mut buf);
    let nf_header = NfHeader {
        nfgen_family: family,
        version: NF_VERSION,
        res_id: 0,
    };
    nf_header.encode(&mut buf);
}
#[derive(Copy, Clone, Debug)]
struct Header {
len: u32,
ty: u16,
flags: u16,
seq: u32,
pid: u32,
}
impl Header {
const SIZE: usize = size_of::<libc::nlmsghdr>();
}
impl Encode for Header {
fn encode<B>(&self, mut buf: B)
where
B: BufMut,
{
buf.put_u32_ne(self.len);
buf.put_u16_ne(self.ty);
buf.put_u16_ne(self.flags);
buf.put_u32_ne(self.seq);
buf.put_u32_ne(self.pid);
}
}
impl Decode for Header {
fn decode<B>(mut buf: B) -> Result<Self, Error>
where
B: Buf,
{
let len = buf.get_u32_ne();
let ty = buf.get_u16_ne();
let flags = buf.get_u16_ne();
let seq = buf.get_u32_ne();
let pid = buf.get_u32_ne();
Ok(Self {
len,
ty,
flags,
seq,
pid,
})
}
}
#[derive(Copy, Clone, Debug)]
struct NfHeader {
nfgen_family: u8,
version: u8,
res_id: u16,
}
impl Encode for NfHeader {
fn encode<B>(&self, mut buf: B)
where
B: BufMut,
{
buf.put_u8(self.nfgen_family);
buf.put_u8(self.version);
buf.put_u16_ne(self.res_id);
}
}
impl Decode for NfHeader {
fn decode<B>(mut buf: B) -> Result<Self, Error>
where
B: Buf,
{
let nfgen_family = buf.get_u8();
let version = buf.get_u8();
let res_id = buf.get_u16_ne();
Ok(Self {
nfgen_family,
version,
res_id,
})
}
}
/// nfnetlink protocol version written into every `nfgenmsg` header.
const NF_VERSION: u8 = 0;
/// Message type of the delimiter that opens an nfnetlink batch.
const NFT_MSG_BATCH_BEGIN: u16 = 0x10;
/// Message type of the delimiter that closes an nfnetlink batch.
const NFT_MSG_BATCH_END: u16 = NFT_MSG_BATCH_BEGIN + 1;
/// Netfilter protocol family (`NFPROTO_*`) a command targets; sent as the
/// `nfgen_family` byte of the `nfgenmsg` header.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(u16)]
pub enum ProtoFamily {
    /// `NFPROTO_UNSPEC`
    Unspec = libc::NFPROTO_UNSPEC as u16,
    /// `NFPROTO_INET`
    Inet = libc::NFPROTO_INET as u16,
    /// `NFPROTO_IPV4`
    Ipv4 = libc::NFPROTO_IPV4 as u16,
    /// `NFPROTO_ARP`
    Arp = libc::NFPROTO_ARP as u16,
    /// `NFPROTO_NETDEV`
    NetDev = libc::NFPROTO_NETDEV as u16,
    /// `NFPROTO_BRIDGE`
    Bridge = libc::NFPROTO_BRIDGE as u16,
    /// `NFPROTO_IPV6`
    Ipv6 = libc::NFPROTO_IPV6 as u16,
    /// `NFPROTO_DECNET`
    DecNet = libc::NFPROTO_DECNET as u16,
}
/// Header of a netlink attribute (`struct nlattr`), encoded in host byte order.
#[derive(Copy, Clone, Debug)]
struct AttributeHeader {
    /// Attribute length in bytes: this 4-byte header plus the payload,
    /// excluding alignment padding.
    len: u16,
    /// Attribute type.
    ty: u16,
}

impl Encode for AttributeHeader {
    fn encode<B>(&self, mut buf: B)
    where
        B: BufMut,
    {
        let Self { len, ty } = *self;
        buf.put_u16_ne(len);
        buf.put_u16_ne(ty);
    }
}
impl Encode for CString {
    /// Writes the string's bytes including its NUL terminator.
    fn encode<B>(&self, mut buf: B)
    where
        B: BufMut,
    {
        let with_nul: &[u8] = self.to_bytes_with_nul();
        buf.put_slice(with_nul);
    }
}
/// Netfilter verdict used as a policy: `NF_ACCEPT` or `NF_DROP`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(u32)]
pub enum Policy {
    /// `NF_ACCEPT`
    Accept = libc::NF_ACCEPT as u32,
    /// `NF_DROP`
    Drop = libc::NF_DROP as u32,
}
/// Netfilter hook point (`NF_INET_*`).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(u16)]
pub enum Hook {
    /// `NF_INET_PRE_ROUTING`
    PreRouting = libc::NF_INET_PRE_ROUTING as u16,
    /// `NF_INET_LOCAL_IN`
    In = libc::NF_INET_LOCAL_IN as u16,
    /// `NF_INET_FORWARD`
    Forward = libc::NF_INET_FORWARD as u16,
    /// `NF_INET_LOCAL_OUT`
    Out = libc::NF_INET_LOCAL_OUT as u16,
    /// `NF_INET_POST_ROUTING`
    PostRouting = libc::NF_INET_POST_ROUTING as u16,
    /// `NF_INET_INGRESS`
    Ingress = libc::NF_INET_INGRESS as u16,
}
/// Serialises a value into a netlink message buffer.
trait Encode {
    /// Appends the wire representation of `self` to `buf`.
    fn encode<B: BufMut>(&self, buf: B);
}
trait Decode: Sized {
fn decode<B>(buf: B) -> Result<Self, Error>
where
B: Buf;
}
/// Writes the `NFT_MSG_BATCH_BEGIN` delimiter that opens an nftables batch.
fn write_batch_begin(buf: &mut Vec<u8>, seq: u32) {
    write_batch_marker(buf, NFT_MSG_BATCH_BEGIN, NLM_F_REQUEST as u16, seq);
}

/// Writes the `NFT_MSG_BATCH_END` delimiter that closes an nftables batch.
///
/// It is flagged `NLM_F_ACK` so the kernel sends a reply carrying `seq`.
fn write_batch_end(buf: &mut Vec<u8>, seq: u32) {
    write_batch_marker(
        buf,
        NFT_MSG_BATCH_END,
        NLM_F_REQUEST as u16 | NLM_F_ACK as u16,
        seq,
    );
}

/// Writes a batch delimiter: an `nlmsghdr` followed by an `nfgenmsg`, with no
/// attributes.
fn write_batch_marker(buf: &mut Vec<u8>, ty: u16, flags: u16, seq: u32) {
    // 16-byte `nlmsghdr` + 4-byte `nfgenmsg`.
    const MARKER_LEN: u32 = 20;
    Header {
        len: MARKER_LEN,
        ty,
        flags,
        seq,
        pid: 0,
    }
    .encode(&mut *buf);
    NfHeader {
        nfgen_family: AF_UNSPEC as u8,
        version: NF_VERSION,
        // For batch delimiters `res_id` names the target subsystem and is
        // big-endian (`__be16`) on the wire. `NfHeader::encode` writes it in
        // native order, so store the big-endian representation here.
        res_id: (NFNL_SUBSYS_NFTABLES as u16).to_be(),
    }
    .encode(&mut *buf);
}