use std::net::IpAddr;
use super::bpf::{BpfFilter, BuildError};
use super::ipnet::IpNet;
/// One primitive match condition recorded by [`BpfFilterBuilder`].
///
/// Each variant stores exactly the argument passed to the corresponding builder
/// method. Turning a fragment into BPF instructions is the job of the compiler
/// module that `BpfFilterBuilder::build` delegates to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum MatchFrag {
    /// EtherType value (recorded by `eth_type`, `ipv4`, `ipv6`, `arp`).
    EthType(u16),
    /// VLAN tag condition with no specific ID (recorded by `vlan`).
    Vlan,
    /// VLAN ID as passed to `vlan_id`; stored unvalidated.
    VlanId(u16),
    /// IP protocol number (recorded by `ip_proto`, `tcp`, `udp`, `icmp`).
    IpProto(u8),
    /// Source host address (recorded by `src_host`).
    SrcHost(IpAddr),
    /// Destination host address (recorded by `dst_host`).
    DstHost(IpAddr),
    /// Host address without a direction (recorded by `host`).
    AnyHost(IpAddr),
    /// Source network (recorded by `src_net`).
    SrcNet(IpNet),
    /// Destination network (recorded by `dst_net`).
    DstNet(IpNet),
    /// Network without a direction (recorded by `net`).
    AnyNet(IpNet),
    /// Source port (recorded by `src_port`).
    SrcPort(u16),
    /// Destination port (recorded by `dst_port`).
    DstPort(u16),
    /// Port without a direction (recorded by `port`).
    AnyPort(u16),
}
/// Fluent builder that accumulates match fragments for a BPF filter.
///
/// Every combinator takes the builder by value and returns it, so calls chain
/// naturally. `build` hands the accumulated state to the filter compiler.
#[must_use]
#[derive(Clone, Debug, Default)]
pub struct BpfFilterBuilder {
    /// Conditions recorded by the builder methods, in the order they were called.
    pub(crate) fragments: Vec<MatchFrag>,
    /// Alternative sub-filters added through `or`, each started from an empty builder.
    pub(crate) or_branches: Vec<BpfFilterBuilder>,
    /// Negation flag. `negate` flips it, so two calls cancel out.
    pub(crate) negated: bool,
}
impl BpfFilterBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn eth_type(mut self, ty: u16) -> Self {
self.fragments.push(MatchFrag::EthType(ty));
self
}
pub fn ipv4(self) -> Self {
self.eth_type(0x0800)
}
pub fn ipv6(self) -> Self {
self.eth_type(0x86dd)
}
pub fn arp(self) -> Self {
self.eth_type(0x0806)
}
pub fn vlan(mut self) -> Self {
self.fragments.push(MatchFrag::Vlan);
self
}
pub fn vlan_id(mut self, id: u16) -> Self {
self.fragments.push(MatchFrag::VlanId(id));
self
}
pub fn ip_proto(mut self, proto: u8) -> Self {
self.fragments.push(MatchFrag::IpProto(proto));
self
}
pub fn tcp(self) -> Self {
self.ip_proto(6)
}
pub fn udp(self) -> Self {
self.ip_proto(17)
}
pub fn icmp(self) -> Self {
self.ip_proto(1)
}
pub fn src_host(mut self, addr: IpAddr) -> Self {
self.fragments.push(MatchFrag::SrcHost(addr));
self
}
pub fn dst_host(mut self, addr: IpAddr) -> Self {
self.fragments.push(MatchFrag::DstHost(addr));
self
}
pub fn host(mut self, addr: IpAddr) -> Self {
self.fragments.push(MatchFrag::AnyHost(addr));
self
}
pub fn src_net(mut self, net: IpNet) -> Self {
self.fragments.push(MatchFrag::SrcNet(net));
self
}
pub fn dst_net(mut self, net: IpNet) -> Self {
self.fragments.push(MatchFrag::DstNet(net));
self
}
pub fn net(mut self, net: IpNet) -> Self {
self.fragments.push(MatchFrag::AnyNet(net));
self
}
pub fn src_port(mut self, port: u16) -> Self {
self.fragments.push(MatchFrag::SrcPort(port));
self
}
pub fn dst_port(mut self, port: u16) -> Self {
self.fragments.push(MatchFrag::DstPort(port));
self
}
pub fn port(mut self, port: u16) -> Self {
self.fragments.push(MatchFrag::AnyPort(port));
self
}
pub fn negate(mut self) -> Self {
self.negated = !self.negated;
self
}
pub fn or(mut self, build: impl FnOnce(BpfFilterBuilder) -> BpfFilterBuilder) -> Self {
self.or_branches.push(build(BpfFilterBuilder::new()));
self
}
pub fn build(self) -> Result<BpfFilter, BuildError> {
super::bpf_compile::compile(self)
}
}
/// Unit tests for `BpfFilterBuilder`: fragment recording, negation, `or`
/// branches, and a smoke test that an empty builder compiles.
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::IpAddr;

    #[test]
    fn empty_builder_compiles() {
        let f = BpfFilterBuilder::new().build().unwrap();
        assert!(!f.is_empty());
    }

    #[test]
    fn new_builder_is_blank() {
        let b = BpfFilterBuilder::new();
        assert!(b.fragments.is_empty());
        assert!(b.or_branches.is_empty());
        assert!(!b.negated);
    }

    #[test]
    fn eth_type_records_fragment() {
        let b = BpfFilterBuilder::new().eth_type(0x0800);
        assert_eq!(b.fragments, vec![MatchFrag::EthType(0x0800)]);
    }

    #[test]
    fn ethertype_shortcuts_use_expected_values() {
        assert_eq!(BpfFilterBuilder::new().ipv4().fragments, vec![MatchFrag::EthType(0x0800)]);
        assert_eq!(BpfFilterBuilder::new().ipv6().fragments, vec![MatchFrag::EthType(0x86dd)]);
        assert_eq!(BpfFilterBuilder::new().arp().fragments, vec![MatchFrag::EthType(0x0806)]);
    }

    #[test]
    fn tcp_records_ip_proto_6() {
        let b = BpfFilterBuilder::new().tcp();
        assert_eq!(b.fragments, vec![MatchFrag::IpProto(6)]);
    }

    #[test]
    fn ip_proto_shortcuts_use_iana_numbers() {
        assert_eq!(BpfFilterBuilder::new().udp().fragments, vec![MatchFrag::IpProto(17)]);
        assert_eq!(BpfFilterBuilder::new().icmp().fragments, vec![MatchFrag::IpProto(1)]);
    }

    #[test]
    fn vlan_fragments_are_recorded_in_order() {
        let b = BpfFilterBuilder::new().vlan().vlan_id(100);
        assert_eq!(b.fragments, vec![MatchFrag::Vlan, MatchFrag::VlanId(100)]);
    }

    #[test]
    fn host_and_port_variants_are_distinct() {
        let addr: IpAddr = "192.0.2.1".parse().unwrap();
        let b = BpfFilterBuilder::new()
            .src_host(addr)
            .dst_host(addr)
            .host(addr)
            .src_port(1)
            .dst_port(2)
            .port(3);
        assert_eq!(
            b.fragments,
            vec![
                MatchFrag::SrcHost(addr),
                MatchFrag::DstHost(addr),
                MatchFrag::AnyHost(addr),
                MatchFrag::SrcPort(1),
                MatchFrag::DstPort(2),
                MatchFrag::AnyPort(3),
            ]
        );
    }

    #[test]
    fn negate_toggles_flag() {
        let b = BpfFilterBuilder::new().tcp().negate();
        assert!(b.negated);
        let b = b.negate();
        assert!(!b.negated);
    }

    #[test]
    fn or_collects_branch() {
        let b = BpfFilterBuilder::new().tcp().or(|b| b.udp().port(53));
        assert_eq!(b.or_branches.len(), 1);
        assert_eq!(
            b.or_branches[0].fragments,
            vec![MatchFrag::IpProto(17), MatchFrag::AnyPort(53)]
        );
    }

    #[test]
    fn or_branch_starts_from_empty_builder() {
        let b = BpfFilterBuilder::new().tcp().negate().or(|branch| {
            // The closure must not inherit the parent's fragments or negation.
            assert!(branch.fragments.is_empty());
            assert!(branch.or_branches.is_empty());
            assert!(!branch.negated);
            branch.udp()
        });
        assert_eq!(b.fragments, vec![MatchFrag::IpProto(6)]);
        assert!(b.negated);
        assert!(!b.or_branches[0].negated);
    }

    #[test]
    fn chained_methods_match_capture_builder_style() {
        let b = BpfFilterBuilder::new()
            .ipv4()
            .tcp()
            .dst_port(80)
            .src_host("10.0.0.1".parse().unwrap());
        assert_eq!(
            b.fragments,
            vec![
                MatchFrag::EthType(0x0800),
                MatchFrag::IpProto(6),
                MatchFrag::DstPort(80),
                MatchFrag::SrcHost("10.0.0.1".parse().unwrap()),
            ]
        );
    }
}