use serde::{Deserialize, Serialize};
use std::fmt::Write;
/// Errors raised while lowering a `bpf-fn` form to Rust source.
///
/// `UnknownHelper` and `UnknownReturnAction` are produced by `lower_expr`.
/// NOTE(review): `UnknownForm`, `BadLet`, `BadIf`, `BadCompare` and `BadMapOp`
/// are not constructed anywhere in this file; presumably a parser elsewhere
/// raises them — confirm.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum LowerError {
/// The top-level form did not match `(bpf-fn name (params) body…)`; carries
/// the offending text.
#[error("expected `(bpf-fn name (params) body…)`, got `{0}`")]
BadShape(String),
/// A form outside the supported subset; carries the form's name.
#[error("unknown form `{0}` — bpf-fn supports only the verifier-friendly subset")]
UnknownForm(String),
/// A helper name with no entry in `BPF_HELPERS`.
#[error("unknown helper `{0}` — add it to BPF_HELPERS or call directly via aya")]
UnknownHelper(String),
/// A return action name with no entry in `RETURN_ACTIONS`.
#[error("unknown return action `:{0}` — see RETURN_ACTIONS for the supported set")]
UnknownReturnAction(String),
/// A `(let …)` whose binding list is missing or empty.
#[error("`(let)` requires a binding-list with one or more (name expr) pairs")]
BadLet,
/// An `(if …)` that does not have exactly three sub-forms.
#[error("`(if cond then else)` requires exactly three sub-forms")]
BadIf,
/// A comparison with other than two operands; carries the operator.
#[error("comparison `{0}` requires exactly two operands")]
BadCompare(String),
/// Map-operation arity mismatch: (operation name, expected arg count,
/// actual arg count).
#[error("map operation `{0}` requires {1} args, got {2}")]
BadMapOp(&'static str, usize, usize),
}
/// One `bpf-fn` program: a function name, a context parameter name, and a
/// body of expressions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BpfFn {
/// Name of the generated Rust function; `lower` interpolates it into the
/// emitted `pub fn` signature.
pub name: String,
/// NOTE(review): presumably the name bound to the program context in the
/// source form — confirm against the parser.
pub ctx: String,
/// Body expressions. `lower` emits every one but the last followed by `;`, and
/// the last as the function's tail (value-producing) expression.
pub body: Vec<BpfExpr>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "form", rename_all = "kebab-case")]
pub enum BpfExpr {
Return { action: String },
Call { helper: String, args: Vec<BpfExpr> },
Let {
name: String,
value: Box<BpfExpr>,
body: Vec<BpfExpr>,
},
If {
cond: Box<BpfExpr>,
then: Box<BpfExpr>,
otherwise: Box<BpfExpr>,
},
Compare {
op: CompareOp,
left: Box<BpfExpr>,
right: Box<BpfExpr>,
},
MapGet { map: String, key: Box<BpfExpr> },
MapSet {
map: String,
key: Box<BpfExpr>,
value: Box<BpfExpr>,
},
Int(i64),
Var(String),
}
/// Binary comparison operator used by `BpfExpr::Compare`.
///
/// Serialized in kebab-case, i.e. as `"eq"`, `"ne"`, `"lt"`, `"gt"`, `"le"`,
/// `"ge"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum CompareOp {
/// `==`
Eq,
/// `!=`
Ne,
/// `<`
Lt,
/// `>`
Gt,
/// `<=`
Le,
/// `>=`
Ge,
}
impl CompareOp {
fn rust_op(self) -> &'static str {
match self {
Self::Eq => "==",
Self::Ne => "!=",
Self::Lt => "<",
Self::Gt => ">",
Self::Le => "<=",
Self::Ge => ">=",
}
}
}
/// Lookup table from the helper name in a `BpfExpr::Call` to the aya_ebpf
/// function path emitted in generated code.
///
/// Lookup is a linear scan and the first matching pair wins. The emitted call
/// is wrapped in `unsafe { … }` with the lowered arguments in order. Names not
/// listed here are rejected with `LowerError::UnknownHelper`.
/// NOTE(review): these helpers presumably return unsigned types (u32/u64),
/// while `BpfExpr::Int` lowers to `_i64` literals, so comparing a helper result
/// with a literal in generated code may not type-check — confirm.
pub const BPF_HELPERS: &[(&str, &str)] = &[
("get-current-cpu", "aya_ebpf::helpers::bpf_get_smp_processor_id"),
("get-current-pid-tgid", "aya_ebpf::helpers::bpf_get_current_pid_tgid"),
("get-current-uid-gid", "aya_ebpf::helpers::bpf_get_current_uid_gid"),
("get-prandom", "aya_ebpf::helpers::bpf_get_prandom_u32"),
("ktime-ns", "aya_ebpf::helpers::bpf_ktime_get_ns"),
];
/// Lookup table from a `BpfExpr::Return` action name to the Rust expression
/// emitted after `return`.
///
/// Lookup is a linear scan and the first matching pair wins. Names not listed
/// here are rejected with `LowerError::UnknownReturnAction`.
/// NOTE(review): `lower` always declares the generated function `-> u32`, but
/// the `tc-act-*` entries cast to `i32`, so returning them from that function
/// would not type-check — confirm how TC programs are meant to be supported.
pub const RETURN_ACTIONS: &[(&str, &str)] = &[
("xdp-pass", "aya_ebpf::bindings::xdp_action::XDP_PASS"),
("xdp-drop", "aya_ebpf::bindings::xdp_action::XDP_DROP"),
("xdp-tx", "aya_ebpf::bindings::xdp_action::XDP_TX"),
("xdp-redirect", "aya_ebpf::bindings::xdp_action::XDP_REDIRECT"),
("xdp-aborted", "aya_ebpf::bindings::xdp_action::XDP_ABORTED"),
("tc-act-ok", "aya_ebpf::bindings::TC_ACT_OK as i32"),
("tc-act-shot", "aya_ebpf::bindings::TC_ACT_SHOT as i32"),
("tc-act-redirect", "aya_ebpf::bindings::TC_ACT_REDIRECT as i32"),
("ok", "0"),
("err", "1"),
];
pub fn lower(f: &BpfFn) -> Result<String, LowerError> {
let mut out = String::new();
let _ = writeln!(out, "pub fn {}(ctx: aya_ebpf::programs::XdpContext) -> u32 {{", f.name);
let mut indent = 1;
let last = f.body.len().saturating_sub(1);
for (i, expr) in f.body.iter().enumerate() {
let line = lower_expr(expr, indent)?;
let suffix = if i == last { "" } else { ";" };
let _ = writeln!(out, "{}{line}{suffix}", " ".repeat(indent));
}
indent -= 1;
let _ = writeln!(out, "{}}}", " ".repeat(indent));
Ok(out)
}
/// Lowers one [`BpfExpr`] to a Rust expression string.
///
/// `indent` is the current nesting depth, rendered as one space per level. It
/// pads the body lines of a `let` block. `if` branches are lowered one level
/// deeper than their condition.
///
/// # Errors
///
/// Returns [`LowerError::UnknownReturnAction`] / [`LowerError::UnknownHelper`]
/// when a name is missing from [`RETURN_ACTIONS`] / [`BPF_HELPERS`]. Also
/// propagates the first error raised by any sub-expression, in evaluation order.
fn lower_expr(e: &BpfExpr, indent: usize) -> Result<String, LowerError> {
    let code = match e {
        BpfExpr::Int(value) => format!("{value}_i64"),
        BpfExpr::Var(name) => rust_name(name),
        BpfExpr::Return { action } => match lookup_path(RETURN_ACTIONS, action) {
            Some(target) => format!("return {target}"),
            None => return Err(LowerError::UnknownReturnAction(action.clone())),
        },
        BpfExpr::Call { helper, args } => {
            // Resolve the helper first so an unknown name is reported ahead of
            // any error inside the arguments.
            let path = lookup_path(BPF_HELPERS, helper)
                .ok_or_else(|| LowerError::UnknownHelper(helper.clone()))?;
            let mut rendered = Vec::with_capacity(args.len());
            for arg in args {
                rendered.push(lower_expr(arg, indent)?);
            }
            format!("unsafe {{ {path}({}) }}", rendered.join(", "))
        }
        BpfExpr::Let { name, value, body } => {
            let bound = lower_expr(value, indent)?;
            let pad = " ".repeat(indent);
            let mut block = format!("let {} = {bound};\n", rust_name(name));
            for (idx, stmt) in body.iter().enumerate() {
                let stmt_code = lower_expr(stmt, indent)?;
                // The final body expression is the block's value, so it gets no `;`.
                let terminator = if idx + 1 == body.len() { "" } else { ";" };
                block.push_str(&pad);
                block.push_str(&stmt_code);
                block.push_str(terminator);
                block.push('\n');
            }
            format!("{{ {} }}", block.trim_end())
        }
        BpfExpr::If {
            cond,
            then,
            otherwise,
        } => {
            let test = lower_expr(cond, indent)?;
            let yes = lower_expr(then, indent + 1)?;
            let no = lower_expr(otherwise, indent + 1)?;
            format!("if {test} {{ {yes} }} else {{ {no} }}")
        }
        BpfExpr::Compare { op, left, right } => {
            let lhs = lower_expr(left, indent)?;
            let rhs = lower_expr(right, indent)?;
            format!("({lhs} {} {rhs})", op.rust_op())
        }
        BpfExpr::MapGet { map, key } => {
            let key_code = lower_expr(key, indent)?;
            format!("unsafe {{ {}.get(&{key_code}) }}", rust_static_name(map))
        }
        BpfExpr::MapSet { map, key, value } => {
            let key_code = lower_expr(key, indent)?;
            let value_code = lower_expr(value, indent)?;
            format!(
                "unsafe {{ {}.insert(&{key_code}, &{value_code}, 0) }}",
                rust_static_name(map)
            )
        }
    };
    Ok(code)
}

/// Returns the Rust path paired with `key` in a `(name, path)` lookup table
/// (first match wins), or `None` when the name is not registered.
fn lookup_path(table: &[(&str, &'static str)], key: &str) -> Option<&'static str> {
    table
        .iter()
        .find(|(name, _)| *name == key)
        .map(|&(_, path)| path)
}
/// Converts a kebab-case source identifier to a snake_case Rust local name by
/// mapping every `-` to `_`; all other characters are kept unchanged.
fn rust_name(s: &str) -> String {
    s.chars()
        .map(|ch| match ch {
            '-' => '_',
            other => other,
        })
        .collect()
}
/// Converts a kebab-case map name to the SCREAMING_SNAKE_CASE name of its Rust
/// `static`.
fn rust_static_name(s: &str) -> String {
    // Uppercasing never produces or consumes `-`/`_`, so it is safe to
    // uppercase first and swap the separators afterwards.
    let upper = s.to_uppercase();
    upper.replace('-', "_")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `body` in an XDP function called `name` with the usual `ctx` parameter.
    fn xdp_fn(name: &str, body: Vec<BpfExpr>) -> BpfFn {
        BpfFn {
            name: name.into(),
            ctx: "ctx".into(),
            body,
        }
    }

    /// `BpfExpr::Return` for the given action name.
    fn ret(action: &str) -> BpfExpr {
        BpfExpr::Return {
            action: action.into(),
        }
    }

    #[test]
    fn lowers_literal_return() {
        let src = lower(&xdp_fn("drop_all", vec![ret("xdp-drop")])).unwrap();
        assert!(src.contains("pub fn drop_all"));
        assert!(src.contains("XDP_DROP"));
    }

    #[test]
    fn lowers_helper_call_in_let() {
        let read_cpu = BpfExpr::Call {
            helper: "get-current-cpu".into(),
            args: vec![],
        };
        let bind = BpfExpr::Let {
            name: "cpu-id".into(),
            value: Box::new(read_cpu),
            body: vec![ret("xdp-pass")],
        };
        let src = lower(&xdp_fn("tag_cpu", vec![bind])).unwrap();
        assert!(src.contains("let cpu_id = unsafe { aya_ebpf::helpers::bpf_get_smp_processor_id"));
        assert!(src.contains("XDP_PASS"));
    }

    #[test]
    fn lowers_if_with_compare() {
        let cond = BpfExpr::Compare {
            op: CompareOp::Eq,
            left: Box::new(BpfExpr::Int(42)),
            right: Box::new(BpfExpr::Int(42)),
        };
        let branch = BpfExpr::If {
            cond: Box::new(cond),
            then: Box::new(ret("xdp-pass")),
            otherwise: Box::new(ret("xdp-drop")),
        };
        let src = lower(&xdp_fn("branch", vec![branch])).unwrap();
        for needle in ["if (42_i64 == 42_i64)", "XDP_PASS", "XDP_DROP"] {
            assert!(src.contains(needle), "missing `{needle}` in:\n{src}");
        }
    }

    #[test]
    fn lowers_map_get_and_set() {
        let bump = BpfExpr::MapSet {
            map: "syn-counter".into(),
            key: Box::new(BpfExpr::Int(0)),
            value: Box::new(BpfExpr::Int(1)),
        };
        let src = lower(&xdp_fn("counter_inc", vec![bump, ret("xdp-pass")])).unwrap();
        assert!(src.contains("SYN_COUNTER.insert(&0_i64, &1_i64, 0)"));
    }

    #[test]
    fn rejects_unknown_helper() {
        let call = BpfExpr::Call {
            helper: "wat".into(),
            args: vec![],
        };
        let err = lower(&xdp_fn("bad", vec![call])).unwrap_err();
        assert!(matches!(err, LowerError::UnknownHelper(_)));
    }

    #[test]
    fn rejects_unknown_return_action() {
        let err = lower(&xdp_fn("bad", vec![ret("make-up-kernel")])).unwrap_err();
        assert!(matches!(err, LowerError::UnknownReturnAction(_)));
    }

    #[test]
    fn lowering_round_trips_via_serde_json() {
        let original = xdp_fn("round_trip", vec![ret("xdp-pass")]);
        let json = serde_json::to_value(&original).unwrap();
        let decoded: BpfFn = serde_json::from_value(json).unwrap();
        assert_eq!(original, decoded);
    }
}