use std::net::IpAddr;
use super::bpf::{BpfFilter, BpfInsn, BuildError};
use super::bpf_builder::{BpfFilterBuilder, MatchFrag};
use super::ipnet::IpNet;
// Classic BPF opcodes (the `code` field of a `struct sock_filter`), written as
// the combined class | size | mode | op byte so they can be emitted directly.

/// `ldh [k]` — load the big-endian half-word at absolute packet offset `k` into A.
pub(crate) const BPF_LD_H_ABS: u16 = 0x28;
/// `ldb [k]` — load the byte at absolute packet offset `k` into A.
pub(crate) const BPF_LD_B_ABS: u16 = 0x30;
/// `ld [k]` — load the big-endian word at absolute packet offset `k` into A.
pub(crate) const BPF_LD_W_ABS: u16 = 0x20;
/// `ldh [x + k]` — load the half-word at packet offset `X + k` into A.
pub(crate) const BPF_LD_H_IND: u16 = 0x48;
/// `ldxb 4*([k]&0xf)` — load X with the IPv4 header length from the IHL nibble at `k`.
pub(crate) const BPF_LDX_B_MSH: u16 = 0xb1;
/// `and #k` — A &= k.
pub(crate) const BPF_ALU_AND_K: u16 = 0x54;
/// `jeq #k, jt, jf` — branch on A == k.
pub(crate) const BPF_JMP_JEQ_K: u16 = 0x15;
/// `jset #k, jt, jf` — branch on (A & k) != 0.
pub(crate) const BPF_JMP_JSET_K: u16 = 0x45;
/// `ja k` — unconditional forward jump over `k` instructions.
pub(crate) const BPF_JMP_JA: u16 = 0x05;
/// `ret #k` — stop and return the constant `k`.
pub(crate) const BPF_RET_K: u16 = 0x06;
/// Return value emitted on a successful match.
pub(crate) const ACCEPT_RETVAL: u32 = 0xFFFF;
/// Return value emitted when the filter does not match (drop the packet).
pub(crate) const DROP_RETVAL: u32 = 0;
/// Symbolic branch target recorded while emitting instructions.
///
/// Labels are turned into concrete relative offsets by `resolve` once the
/// final program layout (accept/drop returns, branch starts) is known.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Label {
    /// Continue with the very next instruction.
    Fallthrough,
    /// Jump to the program's shared accept return.
    Accept,
    /// Fail the current match. `compile` retargets this to the next `or()`
    /// branch when one follows; otherwise it resolves to the drop return.
    Drop,
    /// Skip this many instructions after the current one.
    SkipNextN(u8),
    /// Jump to the first instruction of the block with this index.
    Branch(u32),
}
/// A BPF instruction whose branch targets are still symbolic [`Label`]s.
///
/// Field layout mirrors a classic BPF instruction: opcode, true target,
/// false target, immediate operand.
#[derive(Debug, Clone, Copy)]
pub(crate) struct SymInsn {
    /// Opcode, one of the `BPF_*` constants.
    pub code: u16,
    /// Target when a conditional jump's test holds; for `ja`, the jump target.
    pub jt: Label,
    /// Target when a conditional jump's test fails.
    pub jf: Label,
    /// Immediate operand (constant, packet offset, or mask depending on `code`).
    pub k: u32,
}
impl SymInsn {
fn straight(code: u16, k: u32) -> Self {
Self {
code,
jt: Label::Fallthrough,
jf: Label::Fallthrough,
k,
}
}
fn jump(code: u16, jt: Label, jf: Label, k: u32) -> Self {
Self { code, jt, jf, k }
}
}
/// Network-layer header layout that a block's packet offsets are computed for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum L3Family {
    /// IPv4 header offsets (the default).
    Ipv4,
    /// Fixed IPv6 header offsets; chosen when the block checks ethertype 0x86dd.
    Ipv6,
}
/// Mutable state threaded through the emission of one block's fragments.
#[derive(Debug, Clone, Copy)]
struct CompileCtx {
    /// Bytes added to every header offset: 4 once a `Vlan` (0x8100) fragment
    /// has been emitted, 0 otherwise.
    vlan_offset: u8,
    /// Header layout the L3/L4 offsets assume.
    l3: L3Family,
}
pub(crate) fn compile(builder: BpfFilterBuilder) -> Result<BpfFilter, BuildError> {
let BpfFilterBuilder {
fragments,
or_branches,
negated,
} = builder;
if fragments.is_empty() && or_branches.is_empty() {
return BpfFilter::new(vec![BpfInsn {
code: BPF_RET_K,
jt: 0,
jf: 0,
k: if negated { DROP_RETVAL } else { ACCEPT_RETVAL },
}]);
}
let mut blocks: Vec<Vec<SymInsn>> = Vec::with_capacity(1 + or_branches.len());
blocks.push(compile_block(&fragments)?);
for branch in &or_branches {
if branch.fragments.is_empty() {
return Err(BuildError::EmptyOr);
}
if !branch.or_branches.is_empty() || branch.negated {
return Err(BuildError::ConflictingProtocols {
a: "or() branch with its own or()/negate()",
b: "<flatten the chain>",
});
}
blocks.push(compile_block(&branch.fragments)?);
}
let mut sym: Vec<SymInsn> = Vec::new();
let mut branch_starts: Vec<usize> = Vec::with_capacity(blocks.len());
let n_blocks = blocks.len();
for (i, block) in blocks.into_iter().enumerate() {
branch_starts.push(sym.len());
let drop_target = if i + 1 < n_blocks {
Label::Branch((i + 1) as u32)
} else {
Label::Drop
};
let mut block = block;
for insn in block.iter_mut() {
if insn.jt == Label::Drop {
insn.jt = drop_target;
}
if insn.jf == Label::Drop {
insn.jf = drop_target;
}
}
sym.extend(block);
if i + 1 < n_blocks {
sym.push(SymInsn::jump(
BPF_JMP_JA,
Label::Accept,
Label::Fallthrough,
0,
));
}
}
let accept_pc = sym.len();
sym.push(SymInsn::straight(
BPF_RET_K,
if negated { DROP_RETVAL } else { ACCEPT_RETVAL },
));
let drop_pc = sym.len();
sym.push(SymInsn::straight(
BPF_RET_K,
if negated { ACCEPT_RETVAL } else { DROP_RETVAL },
));
let resolved = resolve(&sym, accept_pc, drop_pc, &branch_starts)?;
BpfFilter::new(resolved)
}
/// Normalizes one AND-chain of fragments and emits its instructions.
///
/// Failed checks in the returned code target [`Label::Drop`]; the caller
/// retargets those per block.
fn compile_block(fragments: &[MatchFrag]) -> Result<Vec<SymInsn>, BuildError> {
    let ordered = normalize(fragments.to_vec())?;
    let mut ctx = infer_ctx(&ordered);
    let mut emitted: Vec<SymInsn> = Vec::with_capacity(ordered.len() * 6);
    ordered
        .iter()
        .try_for_each(|frag| compile_fragment(frag, &mut ctx, &mut emitted))?;
    Ok(emitted)
}
fn normalize(fragments: Vec<MatchFrag>) -> Result<Vec<MatchFrag>, BuildError> {
let mut chosen_eth: Option<u16> = None;
let mut chosen_proto: Option<u8> = None;
for f in &fragments {
match f {
MatchFrag::EthType(t) => match chosen_eth {
Some(prev) if prev != *t => {
return Err(BuildError::ConflictingProtocols {
a: ethtype_label(prev),
b: ethtype_label(*t),
});
}
_ => chosen_eth = Some(*t),
},
MatchFrag::IpProto(p) => match chosen_proto {
Some(prev) if prev != *p => {
return Err(BuildError::ConflictingProtocols {
a: ipproto_label(prev),
b: ipproto_label(*p),
});
}
_ => chosen_proto = Some(*p),
},
_ => {}
}
}
let needs_ip = chosen_eth.is_none()
&& fragments.iter().any(|f| {
matches!(
f,
MatchFrag::IpProto(_)
| MatchFrag::SrcHost(_)
| MatchFrag::DstHost(_)
| MatchFrag::AnyHost(_)
| MatchFrag::SrcNet(_)
| MatchFrag::DstNet(_)
| MatchFrag::AnyNet(_)
| MatchFrag::SrcPort(_)
| MatchFrag::DstPort(_)
| MatchFrag::AnyPort(_)
)
});
let needs_ip_proto = chosen_proto.is_none()
&& fragments.iter().any(|f| {
matches!(
f,
MatchFrag::SrcPort(_) | MatchFrag::DstPort(_) | MatchFrag::AnyPort(_)
)
});
let mut out: Vec<MatchFrag> = Vec::with_capacity(fragments.len() + 2);
let has_vlan = fragments
.iter()
.filter(|f| matches!(f, MatchFrag::Vlan))
.count();
if has_vlan > 1 {
return Err(BuildError::ConflictingProtocols {
a: "vlan",
b: "vlan (Q-in-Q not supported)",
});
}
if has_vlan == 1 {
out.push(MatchFrag::Vlan);
}
for f in &fragments {
if matches!(f, MatchFrag::VlanId(_)) {
push_dedup(&mut out, f.clone());
}
}
if let Some(t) = chosen_eth {
out.push(MatchFrag::EthType(t));
} else if needs_ip {
out.push(MatchFrag::EthType(0x0800));
}
if let Some(p) = chosen_proto {
out.push(MatchFrag::IpProto(p));
} else if needs_ip_proto {
return Err(BuildError::ConflictingProtocols {
a: "port",
b: "<no IP protocol — call .tcp() or .udp()>",
});
}
for f in &fragments {
if matches!(
f,
MatchFrag::SrcHost(_)
| MatchFrag::DstHost(_)
| MatchFrag::AnyHost(_)
| MatchFrag::SrcNet(_)
| MatchFrag::DstNet(_)
| MatchFrag::AnyNet(_)
) {
push_dedup(&mut out, f.clone());
}
}
for f in &fragments {
if matches!(
f,
MatchFrag::SrcPort(_) | MatchFrag::DstPort(_) | MatchFrag::AnyPort(_)
) {
push_dedup(&mut out, f.clone());
}
}
Ok(out)
}
/// Appends `f` to `buf` unless an equal fragment is already present.
fn push_dedup(buf: &mut Vec<MatchFrag>, f: MatchFrag) {
    if buf.contains(&f) {
        return;
    }
    buf.push(f);
}
/// Short human-readable name for an ethertype, used in conflict errors.
///
/// Values not in the table map to the generic `"eth_type"`.
fn ethtype_label(t: u16) -> &'static str {
    const NAMES: [(u16, &str); 4] = [
        (0x0800, "ipv4"),
        (0x86dd, "ipv6"),
        (0x0806, "arp"),
        (0x8100, "vlan"),
    ];
    NAMES
        .iter()
        .find(|&&(code, _)| code == t)
        .map_or("eth_type", |&(_, name)| name)
}
/// Short human-readable name for an IP protocol number, used in conflict errors.
///
/// Values not in the table map to the generic `"ip_proto"`.
fn ipproto_label(p: u8) -> &'static str {
    const NAMES: [(u8, &str); 5] = [
        (1, "icmp"),
        (6, "tcp"),
        (17, "udp"),
        (47, "gre"),
        (58, "icmpv6"),
    ];
    NAMES
        .iter()
        .find(|&&(proto, _)| proto == p)
        .map_or("ip_proto", |&(_, name)| name)
}
/// Builds the starting context for a normalized block.
///
/// Offsets assume IPv4 unless the block contains an `EthType(0x86dd)` check.
/// The VLAN offset starts at 0; emitting a `Vlan` fragment raises it.
fn infer_ctx(fragments: &[MatchFrag]) -> CompileCtx {
    let mut l3 = L3Family::Ipv4;
    if fragments
        .iter()
        .any(|f| matches!(f, MatchFrag::EthType(0x86dd)))
    {
        l3 = L3Family::Ipv6;
    }
    CompileCtx { l3, vlan_offset: 0 }
}
fn compile_fragment(
frag: &MatchFrag,
ctx: &mut CompileCtx,
out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
match frag {
MatchFrag::EthType(t) => emit_eth_type(*t, ctx, out),
MatchFrag::Vlan => {
if ctx.vlan_offset != 0 {
return Err(BuildError::ConflictingProtocols {
a: "vlan",
b: "vlan (Q-in-Q not supported)",
});
}
out.push(SymInsn::straight(BPF_LD_H_ABS, 12));
out.push(SymInsn::jump(
BPF_JMP_JEQ_K,
Label::Fallthrough,
Label::Drop,
0x8100,
));
ctx.vlan_offset = 4;
Ok(())
}
MatchFrag::VlanId(id) => emit_vlan_id(*id, ctx, out),
MatchFrag::IpProto(p) => emit_ip_proto(*p, ctx, out),
MatchFrag::SrcHost(IpAddr::V4(a)) => {
emit_ipv4_host(SrcDst::Src, u32::from_be_bytes(a.octets()), ctx, out)
}
MatchFrag::DstHost(IpAddr::V4(a)) => {
emit_ipv4_host(SrcDst::Dst, u32::from_be_bytes(a.octets()), ctx, out)
}
MatchFrag::AnyHost(IpAddr::V4(a)) => {
emit_ipv4_host(SrcDst::Any, u32::from_be_bytes(a.octets()), ctx, out)
}
MatchFrag::SrcHost(IpAddr::V6(a)) => emit_ipv6_host(SrcDst::Src, a.octets(), ctx, out),
MatchFrag::DstHost(IpAddr::V6(a)) => emit_ipv6_host(SrcDst::Dst, a.octets(), ctx, out),
MatchFrag::AnyHost(IpAddr::V6(a)) => emit_ipv6_host(SrcDst::Any, a.octets(), ctx, out),
MatchFrag::SrcNet(net) | MatchFrag::DstNet(net) | MatchFrag::AnyNet(net) => {
let sd = match frag {
MatchFrag::SrcNet(_) => SrcDst::Src,
MatchFrag::DstNet(_) => SrcDst::Dst,
_ => SrcDst::Any,
};
if net.is_ipv6() {
return Err(BuildError::Ipv6ExtHeader);
}
emit_ipv4_net(sd, net, ctx, out)
}
MatchFrag::SrcPort(p) => emit_l4_port(SrcDst::Src, *p, ctx, out),
MatchFrag::DstPort(p) => emit_l4_port(SrcDst::Dst, *p, ctx, out),
MatchFrag::AnyPort(p) => emit_l4_port(SrcDst::Any, *p, ctx, out),
}
}
/// Emits a check that the VLAN id equals `id`. The id is the low 12 bits of
/// the half-word at frame offset 14.
///
/// # Errors
///
/// [`BuildError::ConflictingProtocols`] if no `Vlan` fragment was emitted earlier.
fn emit_vlan_id(id: u16, ctx: &CompileCtx, out: &mut Vec<SymInsn>) -> Result<(), BuildError> {
    const TCI_OFFSET: u32 = 14;
    const VID_MASK: u32 = 0x0FFF;

    let tagged = ctx.vlan_offset != 0;
    if !tagged {
        return Err(BuildError::ConflictingProtocols {
            a: "vlan_id",
            b: "<requires .vlan() earlier in the chain>",
        });
    }
    out.extend([
        SymInsn::straight(BPF_LD_H_ABS, TCI_OFFSET),
        SymInsn::straight(BPF_ALU_AND_K, VID_MASK),
        SymInsn::jump(
            BPF_JMP_JEQ_K,
            Label::Fallthrough,
            Label::Drop,
            u32::from(id),
        ),
    ]);
    Ok(())
}
/// Which address or port field of a packet a fragment constrains.
#[derive(Debug, Clone, Copy)]
enum SrcDst {
    /// The source field only.
    Src,
    /// The destination field only.
    Dst,
    /// Either the source or the destination field.
    Any,
}
/// Emits a check that the ethertype half-word equals `t`. The half-word is at
/// frame offset 12, plus the VLAN offset when a tag was matched.
fn emit_eth_type(t: u16, ctx: &CompileCtx, out: &mut Vec<SymInsn>) -> Result<(), BuildError> {
    let ethertype_at = u32::from(ctx.vlan_offset) + 12;
    out.extend([
        SymInsn::straight(BPF_LD_H_ABS, ethertype_at),
        SymInsn::jump(BPF_JMP_JEQ_K, Label::Fallthrough, Label::Drop, t.into()),
    ]);
    Ok(())
}
/// Emits a check on the IP protocol / next-header byte.
///
/// For IPv4 the byte is read at frame offset 23; for IPv6 at frame offset 20
/// (the fixed header's next-header field). Both are shifted by the VLAN
/// offset. Only that single byte is read, so IPv6 extension headers are not
/// walked.
fn emit_ip_proto(p: u8, ctx: &CompileCtx, out: &mut Vec<SymInsn>) -> Result<(), BuildError> {
    let base: u32 = if ctx.l3 == L3Family::Ipv6 { 20 } else { 23 };
    let off = base + u32::from(ctx.vlan_offset);
    out.push(SymInsn::straight(BPF_LD_B_ABS, off));
    out.push(SymInsn::jump(
        BPF_JMP_JEQ_K,
        Label::Fallthrough,
        Label::Drop,
        u32::from(p),
    ));
    Ok(())
}
/// Emits an IPv4 address equality check against the source address (frame
/// offset 26), the destination address (offset 30), or either of them. Both
/// offsets are shifted by the VLAN offset.
///
/// For [`SrcDst::Any`], a source match skips the destination load and compare.
fn emit_ipv4_host(
    sd: SrcDst,
    addr: u32,
    ctx: &CompileCtx,
    out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
    let shift = u32::from(ctx.vlan_offset);
    let (src, dst) = (26 + shift, 30 + shift);
    let load = |off| SymInsn::straight(BPF_LD_W_ABS, off);
    let must_equal = SymInsn::jump(BPF_JMP_JEQ_K, Label::Fallthrough, Label::Drop, addr);
    match sd {
        SrcDst::Src => out.extend([load(src), must_equal]),
        SrcDst::Dst => out.extend([load(dst), must_equal]),
        SrcDst::Any => out.extend([
            load(src),
            SymInsn::jump(
                BPF_JMP_JEQ_K,
                Label::SkipNextN(2),
                Label::Fallthrough,
                addr,
            ),
            load(dst),
            must_equal,
        ]),
    }
    Ok(())
}
/// Emits an IPv4 network-membership check: `(address & mask) == network`.
/// It tests the source, the destination, or either address.
///
/// For [`SrcDst::Any`], a source match skips the 3-instruction destination test.
///
/// # Errors
///
/// * [`BuildError::Ipv6ExtHeader`] if `net` is not an IPv4 network.
/// * [`BuildError::InvalidPrefix`] if the prefix length exceeds 32.
fn emit_ipv4_net(
    sd: SrcDst,
    net: &IpNet,
    ctx: &CompileCtx,
    out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
    if !net.is_ipv4() {
        return Err(BuildError::Ipv6ExtHeader);
    }
    if net.prefix > 32 {
        return Err(BuildError::InvalidPrefix(net.prefix));
    }
    let mask = net.ipv4_mask().expect("ipv4 net has mask");
    let target = net.as_ipv4_u32().expect("ipv4 net has u32 addr") & mask;

    let shift = u32::from(ctx.vlan_offset);
    let (src, dst) = (26 + shift, 30 + shift);
    let masked_load = |off: u32| {
        [
            SymInsn::straight(BPF_LD_W_ABS, off),
            SymInsn::straight(BPF_ALU_AND_K, mask),
        ]
    };
    let must_equal = SymInsn::jump(BPF_JMP_JEQ_K, Label::Fallthrough, Label::Drop, target);

    match sd {
        SrcDst::Src => {
            out.extend(masked_load(src));
            out.push(must_equal);
        }
        SrcDst::Dst => {
            out.extend(masked_load(dst));
            out.push(must_equal);
        }
        SrcDst::Any => {
            out.extend(masked_load(src));
            out.push(SymInsn::jump(
                BPF_JMP_JEQ_K,
                Label::SkipNextN(3),
                Label::Fallthrough,
                target,
            ));
            out.extend(masked_load(dst));
            out.push(must_equal);
        }
    }
    Ok(())
}
/// Emits a TCP/UDP port check using the L4 offset rules of the block's L3 family.
fn emit_l4_port(
    sd: SrcDst,
    port: u16,
    ctx: &CompileCtx,
    out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
    if ctx.l3 == L3Family::Ipv6 {
        return emit_ipv6_port(sd, port, ctx, out);
    }
    emit_ipv4_port(sd, port, ctx, out)
}
/// Emits a TCP/UDP port check for IPv4 packets.
///
/// Packets whose fragment-offset bits (mask 0x1FFF at frame offset 20) are
/// non-zero are rejected first. The IPv4 header length is then loaded into X
/// via `ldxb 4*([ip_start]&0xf)`. The source port is read at `X + ip_start`
/// and the destination port 2 bytes later.
fn emit_ipv4_port(
    sd: SrcDst,
    port: u16,
    ctx: &CompileCtx,
    out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
    let shift = u32::from(ctx.vlan_offset);
    let flags_frag = 20 + shift;
    let ip_start = 14 + shift;
    let port_k = u32::from(port);

    out.push(SymInsn::straight(BPF_LD_H_ABS, flags_frag));
    out.push(SymInsn::jump(
        BPF_JMP_JSET_K,
        Label::Drop,
        Label::Fallthrough,
        0x1FFF,
    ));
    out.push(SymInsn::straight(BPF_LDX_B_MSH, ip_start));

    let load_at = |rel: u32| SymInsn::straight(BPF_LD_H_IND, ip_start + rel);
    let must_equal = SymInsn::jump(BPF_JMP_JEQ_K, Label::Fallthrough, Label::Drop, port_k);
    match sd {
        SrcDst::Src => out.extend([load_at(0), must_equal]),
        SrcDst::Dst => out.extend([load_at(2), must_equal]),
        SrcDst::Any => out.extend([
            load_at(0),
            SymInsn::jump(
                BPF_JMP_JEQ_K,
                Label::SkipNextN(2),
                Label::Fallthrough,
                port_k,
            ),
            load_at(2),
            must_equal,
        ]),
    }
    Ok(())
}
/// Emits a TCP/UDP port check for IPv6 packets.
///
/// Ports are read at fixed frame offsets 54 (source) and 56 (destination),
/// shifted by the VLAN offset. This assumes the L4 header directly follows the
/// 40-byte fixed IPv6 header.
fn emit_ipv6_port(
    sd: SrcDst,
    port: u16,
    ctx: &CompileCtx,
    out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
    let l4_start = 54 + u32::from(ctx.vlan_offset);
    let port_k = u32::from(port);
    let load = |off| SymInsn::straight(BPF_LD_H_ABS, off);
    let must_equal = SymInsn::jump(BPF_JMP_JEQ_K, Label::Fallthrough, Label::Drop, port_k);
    match sd {
        SrcDst::Src => out.extend([load(l4_start), must_equal]),
        SrcDst::Dst => out.extend([load(l4_start + 2), must_equal]),
        SrcDst::Any => out.extend([
            load(l4_start),
            SymInsn::jump(
                BPF_JMP_JEQ_K,
                Label::SkipNextN(2),
                Label::Fallthrough,
                port_k,
            ),
            load(l4_start + 2),
            must_equal,
        ]),
    }
    Ok(())
}
/// Emits an IPv6 address equality check against the source address (frame
/// offset 22), the destination address (offset 38), or either of them. Both
/// offsets are shifted by the VLAN offset.
///
/// Each 128-bit address is compared as four big-endian 32-bit words, one
/// load/compare pair per word.
fn emit_ipv6_host(
    sd: SrcDst,
    addr_octets: [u8; 16],
    ctx: &CompileCtx,
    out: &mut Vec<SymInsn>,
) -> Result<(), BuildError> {
    /// Instructions emitted by `emit_one_dir`: 4 words x (load + compare).
    const DIR_TEST_LEN: u8 = 8;

    /// Emits four load/compare pairs; any mismatching word fails the match.
    fn emit_one_dir(out: &mut Vec<SymInsn>, base: u32, words: &[u32; 4]) {
        for (i, &w) in words.iter().enumerate() {
            out.push(SymInsn::straight(BPF_LD_W_ABS, base + (i as u32) * 4));
            out.push(SymInsn::jump(
                BPF_JMP_JEQ_K,
                Label::Fallthrough,
                Label::Drop,
                w,
            ));
        }
    }

    let src_off = 22u32 + u32::from(ctx.vlan_offset);
    let dst_off = 38u32 + u32::from(ctx.vlan_offset);

    let mut words = [0u32; 4];
    for (word, chunk) in words.iter_mut().zip(addr_octets.chunks_exact(4)) {
        *word = u32::from_be_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
    }

    match sd {
        SrcDst::Src => emit_one_dir(out, src_off, &words),
        SrcDst::Dst => emit_one_dir(out, dst_off, &words),
        SrcDst::Any => {
            // Source test. A mismatch on word i jumps over the remaining source
            // pairs to the start of the destination test. A match on the last
            // word skips the whole destination test.
            for (i, &word) in words.iter().enumerate() {
                let words_left = 3 - i as u8;
                out.push(SymInsn::straight(BPF_LD_W_ABS, src_off + (i as u32) * 4));
                let on_match = if words_left == 0 {
                    Label::SkipNextN(DIR_TEST_LEN)
                } else {
                    Label::Fallthrough
                };
                out.push(SymInsn::jump(
                    BPF_JMP_JEQ_K,
                    on_match,
                    Label::SkipNextN(words_left * 2),
                    word,
                ));
            }
            emit_one_dir(out, dst_off, &words);
        }
    }
    Ok(())
}
/// Lowers symbolic instructions to concrete BPF instructions.
///
/// Conditional-jump labels become 8-bit forward offsets in `jt`/`jf`. For a
/// `ja`, the target stored in `jt` becomes a 32-bit forward offset in `k`.
///
/// # Errors
///
/// [`BuildError::JumpTooFar`] if a target is not strictly after its
/// instruction, does not fit the offset field, or names an unknown branch.
fn resolve(
    sym: &[SymInsn],
    accept_pc: usize,
    drop_pc: usize,
    branch_starts: &[usize],
) -> Result<Vec<BpfInsn>, BuildError> {
    sym.iter()
        .enumerate()
        .map(|(pc, insn)| {
            if insn.code == BPF_JMP_JA {
                let target = label_target(insn.jt, pc, accept_pc, drop_pc, branch_starts)?;
                let rel = target.checked_sub(pc + 1).ok_or(BuildError::JumpTooFar)?;
                // Checked conversion instead of a silently truncating `as` cast.
                let k = u32::try_from(rel).map_err(|_| BuildError::JumpTooFar)?;
                Ok(BpfInsn {
                    code: insn.code,
                    jt: 0,
                    jf: 0,
                    k,
                })
            } else {
                Ok(BpfInsn {
                    code: insn.code,
                    jt: resolve_label(insn.jt, pc, accept_pc, drop_pc, branch_starts)?,
                    jf: resolve_label(insn.jf, pc, accept_pc, drop_pc, branch_starts)?,
                    k: insn.k,
                })
            }
        })
        .collect()
}
/// Resolves `label`, as seen from instruction `pc`, to the 8-bit forward
/// offset used by conditional jumps.
///
/// # Errors
///
/// [`BuildError::JumpTooFar`] if the target is not after `pc`, is more than
/// 255 instructions away, or is an unknown branch.
fn resolve_label(
    label: Label,
    pc: usize,
    accept_pc: usize,
    drop_pc: usize,
    branch_starts: &[usize],
) -> Result<u8, BuildError> {
    let target = label_target(label, pc, accept_pc, drop_pc, branch_starts)?;
    target
        .checked_sub(pc + 1)
        .and_then(|dist| u8::try_from(dist).ok())
        .ok_or(BuildError::JumpTooFar)
}
/// Maps a symbolic label, as seen from instruction `pc`, to an absolute
/// instruction index.
///
/// # Errors
///
/// [`BuildError::JumpTooFar`] if a [`Label::Branch`] index has no recorded start.
fn label_target(
    label: Label,
    pc: usize,
    accept_pc: usize,
    drop_pc: usize,
    branch_starts: &[usize],
) -> Result<usize, BuildError> {
    let absolute = match label {
        Label::Fallthrough => pc + 1,
        Label::Accept => accept_pc,
        Label::Drop => drop_pc,
        Label::SkipNextN(n) => pc.saturating_add(usize::from(n) + 1),
        Label::Branch(id) => {
            return branch_starts
                .get(id as usize)
                .copied()
                .ok_or(BuildError::JumpTooFar);
        }
    };
    Ok(absolute)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `b`, panicking if compilation fails.
    fn build(b: BpfFilterBuilder) -> BpfFilter {
        compile(b).unwrap()
    }

    /// Compiles `b`, panicking if compilation unexpectedly succeeds.
    fn build_err(b: BpfFilterBuilder) -> BuildError {
        compile(b).unwrap_err()
    }

    #[test]
    fn empty_builder_compiles_to_accept_only() {
        let filter = build(BpfFilterBuilder::new());
        assert_eq!(filter.len(), 1);
        let only = filter.instructions()[0];
        assert_eq!(only.code, BPF_RET_K);
        assert_eq!(only.k, ACCEPT_RETVAL);
    }

    #[test]
    fn ipv4_compiles() {
        let filter = build(BpfFilterBuilder::new().ipv4());
        assert_eq!(filter.len(), 4);
        let expected = [
            (BPF_LD_H_ABS, 12),
            (BPF_JMP_JEQ_K, 0x0800),
            (BPF_RET_K, ACCEPT_RETVAL),
            (BPF_RET_K, DROP_RETVAL),
        ];
        for (pc, &(code, k)) in expected.iter().enumerate() {
            let insn = filter.instructions()[pc];
            assert_eq!(insn.code, code, "opcode at pc {}", pc);
            assert_eq!(insn.k, k, "operand at pc {}", pc);
        }
    }

    #[test]
    fn tcp_auto_inserts_ipv4_check() {
        let filter = build(BpfFilterBuilder::new().tcp());
        assert_eq!(filter.len(), 6);
        let insns = filter.instructions();
        assert_eq!(insns[0].k, 12);
        assert_eq!(insns[1].k, 0x0800);
        assert_eq!(insns[2].code, BPF_LD_B_ABS);
        assert_eq!(insns[2].k, 23);
        assert_eq!(insns[3].k, 6);
    }

    #[test]
    fn conflicting_protocols_rejected() {
        let err = build_err(BpfFilterBuilder::new().tcp().udp());
        assert!(matches!(err, BuildError::ConflictingProtocols { .. }));
    }

    #[test]
    fn duplicate_tcp_dedups() {
        let filter = build(BpfFilterBuilder::new().tcp().tcp());
        assert_eq!(filter.len(), 6);
    }

    #[test]
    fn port_without_proto_errors() {
        let err = build_err(BpfFilterBuilder::new().port(80));
        assert!(matches!(err, BuildError::ConflictingProtocols { .. }));
    }

    #[test]
    fn jump_offsets_resolve() {
        let filter = build(BpfFilterBuilder::new().ipv4());
        let jeq = filter.instructions()[1];
        assert_eq!(jeq.code, BPF_JMP_JEQ_K);
        // True: fall through to accept; false: skip accept to reach drop.
        assert_eq!(jeq.jt, 0);
        assert_eq!(jeq.jf, 1);
    }

    #[test]
    fn or_two_branches_compiles() {
        let filter = build(
            BpfFilterBuilder::new()
                .tcp()
                .port(80)
                .or(|b| b.udp().port(53)),
        );
        assert!(filter.len() > 10);
        let first = filter.instructions()[0];
        assert_eq!(first.code, BPF_LD_H_ABS);
        assert_eq!(first.k, 12);
    }

    #[test]
    fn negate_produces_same_length_program() {
        let builder = BpfFilterBuilder::new().tcp();
        let plain = build(builder.clone());
        let negated = build(builder.negate());
        assert_eq!(plain.len(), negated.len());
    }

    #[test]
    fn empty_or_branch_rejected() {
        let err = build_err(BpfFilterBuilder::new().tcp().or(|b| b));
        assert!(matches!(err, BuildError::EmptyOr));
    }

    #[test]
    fn nested_or_in_branch_rejected() {
        let err = build_err(
            BpfFilterBuilder::new()
                .tcp()
                .or(|b| b.udp().or(|b| b.icmp())),
        );
        assert!(matches!(err, BuildError::ConflictingProtocols { .. }));
    }
}