nf_tables 0.1.0

Pure Rust crate to interact with the Linux nf_tables subsystem
Documentation
use std::net::IpAddr;

use crate::rule::{
    Bitwise, Cmp, CmpOp, Expr, Meta, MetaProperty, Payload, PayloadBase, PayloadOp, Register,
    Verdict,
};

/// High-level builder for rule expressions.
#[derive(Clone, Debug, Default)]
pub struct RuleBuilder {
    ipv4: bool,
    ipv6: bool,
    ip_src_addr: Option<(IpAddr, u8)>,
    ip_dst_addr: Option<(IpAddr, u8)>,
    ether_saddr: Option<[u8; 6]>,
    ether_daddr: Option<[u8; 6]>,
    verdict: Option<Verdict>,
}

impl RuleBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a filter matching if ethernet frame contains the given source address.
    ///
    /// This is the equivalent of `ether saddr <addr>`.
    pub fn with_ether_saddr(mut self, addr: [u8; 6]) -> Self {
        self.ether_saddr = Some(addr);
        self
    }

    /// Adds a filter matching if the ethernet frame contains the given destination address.
    ///
    /// This is the equivalent of `ether daddr <addr>`.
    pub fn with_ether_daddr(mut self, addr: [u8; 6]) -> Self {
        self.ether_daddr = Some(addr);
        self
    }

    /// Adds a filter matching if the packet contains an IPv4 packet.
    pub fn with_ipv4(mut self) -> Self {
        self.ipv4 = true;
        self
    }

    /// Adds a filter matching if the packet contains an IPv6 packet.
    pub fn with_ipv6(mut self) -> Self {
        self.ipv6 = true;
        self
    }

    /// Add a filter matching if the packet contains the given source address.
    ///
    /// This is the equivalent of `ip saddr <addr>` or `ip6 saddr <addr>`.
    pub fn with_ip_saddr(self, addr: IpAddr) -> Self {
        match addr {
            IpAddr::V4(_) => self.with_ip_saddr_prefix(addr, 32),
            IpAddr::V6(_) => self.with_ip_saddr_prefix(addr, 128),
        }
    }

    /// Adds a filter matching if the given `prefix` bits match the source address of the packet.
    ///
    /// This is the equivalent of `ip saddr <addr>` or `ip6 saddr <addr>` where `<addr>` is in CIDR
    /// notation (e.g. `127.0.0.0/8`).
    ///
    /// # Panics
    ///
    /// - If the given [`IpAddr`] is an IPv4 address: panics if `prefix > 32`.
    /// - If the given [`IpAddr`] is an IPv6 address: panics if `prefix > 128`.
    pub fn with_ip_saddr_prefix(mut self, addr: IpAddr, prefix: u8) -> Self {
        self.ip_src_addr = Some((addr, prefix));
        match addr {
            IpAddr::V4(_) => {
                debug_assert!(prefix <= 32);
                self.ipv4 = true;
            }
            IpAddr::V6(_) => {
                debug_assert!(prefix <= 128);
                self.ipv6 = true;
            }
        }

        self
    }

    /// Adds a filter matching if the packet contains the given destination address.
    ///
    /// This is the equivalent of `ip daddr <addr>` or `ip6a addr <addr>`.
    pub fn with_ip_daddr(self, addr: IpAddr) -> Self {
        match addr {
            IpAddr::V4(_) => self.with_ip_daddr_prefix(addr, 32),
            IpAddr::V6(_) => self.with_ip_daddr_prefix(addr, 128),
        }
    }

    /// Adds a filter matching if the given `prefix` bits match the destination address of the
    /// packet.
    ///
    /// This is the equivalent of `ip saddr <addr>` or `ip6 saddr <addr>` where `<addr>` is in CIDR
    /// notation (e.g. `127.0.0.0/8`).
    ///
    /// # Panics
    ///
    /// - If the given [`IpAddr`] is an IPv4 address: panics if `prefix > 32`.
    /// - If the given [`IpAddr`] is an IPv6 address: panics if `prefix > 128`.
    pub fn with_ip_daddr_prefix(mut self, addr: IpAddr, prefix: u8) -> Self {
        self.ip_dst_addr = Some((addr, prefix));
        match addr {
            IpAddr::V4(_) => {
                debug_assert!(prefix <= 32);
                self.ipv4 = true;
            }
            IpAddr::V6(_) => {
                debug_assert!(prefix <= 128);
                self.ipv6 = true;
            }
        }

        self
    }

    /// Sets the [`Verdict`] applied to all packets matching the filters.
    ///
    /// Note that the verdict is always applied at the very end. If you add additional
    /// filters after calling `with_verdict` the verdict will be placed after those filters.
    /// This is in contrast to `nftables`, which allows to place filters after a verdict.
    pub fn with_verdict(mut self, verdict: Verdict) -> Self {
        self.verdict = Some(verdict);
        self
    }

    pub fn build(&self) -> Vec<Expr> {
        let mut exprs = Vec::new();

        if self.ether_saddr.is_some() || self.ether_daddr.is_some() {
            exprs.extend([
                Expr::Meta(Meta {
                    key: MetaProperty::IIfType,
                    register: Register::Reg1,
                    src_register: false,
                }),
                Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: libc::ARPHRD_ETHER.to_ne_bytes().to_vec(),
                }),
            ]);

            if let Some(addr) = self.ether_saddr {
                exprs.extend([
                    Expr::Payload(Payload {
                        op: PayloadOp::Load(Register::Reg1),
                        base: PayloadBase::LinkLayer,
                        offset: 6,
                        len: 6,
                    }),
                    Expr::Cmp(Cmp {
                        op: CmpOp::Equal,
                        register: Register::Reg1,
                        data: addr.to_vec(),
                    }),
                ]);
            }

            if let Some(addr) = self.ether_daddr {
                exprs.extend([
                    Expr::Payload(Payload {
                        op: PayloadOp::Load(Register::Reg1),
                        base: PayloadBase::LinkLayer,
                        offset: 0,
                        len: 6,
                    }),
                    Expr::Cmp(Cmp {
                        op: CmpOp::Equal,
                        register: Register::Reg1,
                        data: addr.to_vec(),
                    }),
                ]);
            }
        }

        if self.ipv4 {
            exprs.extend([
                Expr::Meta(Meta {
                    key: MetaProperty::NfProto,
                    register: Register::Reg1,
                    src_register: false,
                }),
                Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: vec![libc::NFPROTO_IPV4 as u8],
                }),
            ]);
        }

        if self.ipv6 {
            exprs.extend([
                Expr::Meta(Meta {
                    key: MetaProperty::NfProto,
                    register: Register::Reg1,
                    src_register: false,
                }),
                Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: vec![libc::NFPROTO_IPV6 as u8],
                }),
            ]);
        }

        match self.ip_src_addr {
            Some((IpAddr::V4(addr), prefix)) => {
                exprs.push(Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 12,
                    len: 4,
                }));

                let mut addr = addr.to_bits();

                if prefix != 32 {
                    let mask = ((1_u32 << prefix) - 1).reverse_bits();
                    addr &= mask;

                    exprs.push(Expr::Bitwise(Bitwise {
                        src_register: Register::Reg1,
                        dst_register: Register::Reg1,
                        mask: mask.to_be_bytes().to_vec(),
                        xor: Vec::new(),
                        len: 4,
                    }));
                }

                exprs.extend([Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: addr.to_be_bytes().to_vec(),
                })]);
            }
            Some((IpAddr::V6(addr), prefix)) => {
                exprs.push(Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 8,
                    len: 16,
                }));

                let mut addr = addr.to_bits();

                if prefix != 128 {
                    let mask = ((1_u128 << prefix) - 1).reverse_bits();
                    addr &= mask;

                    exprs.push(Expr::Bitwise(Bitwise {
                        src_register: Register::Reg1,
                        dst_register: Register::Reg1,
                        mask: mask.to_be_bytes().to_vec(),
                        xor: Vec::new(),
                        len: 16,
                    }));
                }

                exprs.push(Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: addr.to_be_bytes().to_vec(),
                }));
            }
            None => {}
        }

        match self.ip_dst_addr {
            Some((IpAddr::V4(addr), prefix)) => {
                exprs.push(Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 16,
                    len: 4,
                }));

                let mut addr = addr.to_bits();

                if prefix != 32 {
                    let mask = ((1_u32 << prefix) - 1).reverse_bits();
                    addr &= mask;

                    exprs.push(Expr::Bitwise(Bitwise {
                        src_register: Register::Reg1,
                        dst_register: Register::Reg1,
                        mask: mask.to_be_bytes().to_vec(),
                        xor: Vec::new(),
                        len: 4,
                    }));
                }

                exprs.extend([Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: addr.to_be_bytes().to_vec(),
                })]);
            }
            Some((IpAddr::V6(addr), prefix)) => {
                exprs.push(Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 24,
                    len: 16,
                }));

                let mut addr = addr.to_bits();

                if prefix != 128 {
                    let mask = ((1_u128 << prefix) - 1).reverse_bits();
                    addr &= mask;

                    exprs.push(Expr::Bitwise(Bitwise {
                        src_register: Register::Reg1,
                        dst_register: Register::Reg1,
                        mask: mask.to_be_bytes().to_vec(),
                        xor: Vec::new(),
                        len: 16,
                    }));
                }

                exprs.push(Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: addr.to_be_bytes().to_vec(),
                }));
            }
            None => {}
        }

        if let Some(verdict) = self.verdict.clone() {
            exprs.push(Expr::Verdict(verdict));
        }

        exprs
    }
}

#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use crate::rule::builder::RuleBuilder;
    use crate::rule::{
        Bitwise, Cmp, CmpOp, Expr, Meta, MetaProperty, Payload, PayloadBase, PayloadOp, Register,
    };

    /// Expected `meta nfproto` load into register 1.
    fn expect_nfproto() -> Expr {
        Expr::Meta(Meta {
            key: MetaProperty::NfProto,
            register: Register::Reg1,
            src_register: false,
        })
    }

    /// Expected equality comparison of register 1 against `data`.
    fn expect_cmp(data: Vec<u8>) -> Expr {
        Expr::Cmp(Cmp {
            op: CmpOp::Equal,
            register: Register::Reg1,
            data,
        })
    }

    #[test]
    fn build_ipv4_with_prefix() {
        let exprs = RuleBuilder::new()
            .with_ip_saddr_prefix(IpAddr::V4(Ipv4Addr::new(10, 20, 255, 40)), 19)
            .build();

        assert_eq!(
            exprs,
            [
                Expr::Meta(Meta {
                    key: MetaProperty::NfProto,
                    register: Register::Reg1,
                    src_register: false,
                }),
                Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: vec![2],
                }),
                Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 12,
                    len: 4,
                }),
                Expr::Bitwise(Bitwise {
                    src_register: Register::Reg1,
                    dst_register: Register::Reg1,
                    mask: vec![0b1111_1111, 0b1111_1111, 0b1110_0000, 0b0000_0000],
                    xor: vec![],
                    len: 4,
                }),
                Expr::Cmp(Cmp {
                    op: CmpOp::Equal,
                    register: Register::Reg1,
                    data: vec![10, 20, 224, 0],
                }),
            ]
        );
    }

    /// A full-width prefix compares the loaded address directly, without a `Bitwise` step.
    #[test]
    fn build_ipv4_exact_daddr() {
        let exprs = RuleBuilder::new()
            .with_ip_daddr(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)))
            .build();

        assert_eq!(
            exprs,
            [
                expect_nfproto(),
                expect_cmp(vec![2]),
                Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 16,
                    len: 4,
                }),
                expect_cmp(vec![192, 168, 1, 1]),
            ]
        );
    }

    /// A zero-length prefix masks every bit, so the comparison is against the zero address.
    #[test]
    fn build_ipv4_zero_prefix() {
        let exprs = RuleBuilder::new()
            .with_ip_saddr_prefix(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 0)
            .build();

        assert_eq!(
            exprs,
            [
                expect_nfproto(),
                expect_cmp(vec![2]),
                Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 12,
                    len: 4,
                }),
                Expr::Bitwise(Bitwise {
                    src_register: Register::Reg1,
                    dst_register: Register::Reg1,
                    mask: vec![0, 0, 0, 0],
                    xor: vec![],
                    len: 4,
                }),
                expect_cmp(vec![0, 0, 0, 0]),
            ]
        );
    }

    #[test]
    fn build_ipv6_daddr_with_prefix() {
        let addr = IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0xabcd, 0x0012, 0, 0, 0, 1));
        let exprs = RuleBuilder::new().with_ip_daddr_prefix(addr, 48).build();

        let mut mask = vec![0xff; 6];
        mask.resize(16, 0);
        let mut data = vec![0x20, 0x01, 0x0d, 0xb8, 0xab, 0xcd];
        data.resize(16, 0);

        assert_eq!(
            exprs,
            [
                expect_nfproto(),
                expect_cmp(vec![10]),
                Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::Network,
                    offset: 24,
                    len: 16,
                }),
                Expr::Bitwise(Bitwise {
                    src_register: Register::Reg1,
                    dst_register: Register::Reg1,
                    mask,
                    xor: vec![],
                    len: 16,
                }),
                expect_cmp(data),
            ]
        );
    }

    /// Source MAC is read at link-layer offset 6 and destination MAC at offset 0; the source
    /// match is emitted first regardless of call order.
    #[test]
    fn build_ether_addresses() {
        let exprs = RuleBuilder::new()
            .with_ether_daddr([1, 2, 3, 4, 5, 6])
            .with_ether_saddr([7, 8, 9, 10, 11, 12])
            .build();

        assert_eq!(
            exprs,
            [
                Expr::Meta(Meta {
                    key: MetaProperty::IIfType,
                    register: Register::Reg1,
                    src_register: false,
                }),
                expect_cmp(libc::ARPHRD_ETHER.to_ne_bytes().to_vec()),
                Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::LinkLayer,
                    offset: 6,
                    len: 6,
                }),
                expect_cmp(vec![7, 8, 9, 10, 11, 12]),
                Expr::Payload(Payload {
                    op: PayloadOp::Load(Register::Reg1),
                    base: PayloadBase::LinkLayer,
                    offset: 0,
                    len: 6,
                }),
                expect_cmp(vec![1, 2, 3, 4, 5, 6]),
            ]
        );
    }
}