use std::collections::{BTreeSet, HashSet};
use std::hash::{BuildHasher as _, Hash as _, Hasher as _};
use std::path::{Path, PathBuf};
use std::process::Command;
/// Generates the BTF anchor header at `anchor_path`, or reuses it when its
/// embedded hash still matches the current inputs.
///
/// The function:
/// 1. discovers `.bpf.c` sources via `discover_sources_from_objects`,
/// 2. hashes the inputs (see `anchor_input_hash`) and returns early when
///    `read_anchor_hash` reports the same value for the existing file,
/// 3. otherwise collects dependency files via `collect_dep_files`, extracts
///    struct names from them, and writes a fresh header.
///
/// Returns the canonicalized anchor path (or `anchor_path` unchanged when it
/// cannot be canonicalized). Returns `None` when no sources, dependency files
/// or struct definitions are found, or when writing the header fails.
pub(crate) fn generate_btf_anchor(
    bpf_object_dir: &Path,
    clang: &str,
    cflags: &[String],
    anchor_path: &Path,
) -> Option<PathBuf> {
    // Both success paths hand back the same thing: the canonical location
    // when it resolves, otherwise the path exactly as the caller gave it.
    let resolved_anchor =
        || std::fs::canonicalize(anchor_path).unwrap_or_else(|_| anchor_path.to_path_buf());

    let mut bpf_sources = discover_sources_from_objects(bpf_object_dir);
    if bpf_sources.is_empty() {
        tracing::debug!("btf_anchor: no .bpf.c sources found via BTF");
        return None;
    }
    tracing::debug!(
        sources = bpf_sources.len(),
        "btf_anchor: discovered BPF sources via BTF"
    );
    // Discovery order is unspecified; sort so the hash input order is fixed.
    bpf_sources.sort();

    let input_hash = anchor_input_hash(bpf_object_dir, &bpf_sources, cflags);
    if read_anchor_hash(anchor_path) == Some(input_hash) {
        tracing::debug!("btf_anchor: cached anchor is current");
        return Some(resolved_anchor());
    }

    let dep_files = collect_dep_files(&bpf_sources, clang, cflags);
    if dep_files.is_empty() {
        tracing::debug!("btf_anchor: clang -M produced no dep files");
        return None;
    }
    tracing::debug!(
        files = dep_files.len(),
        "btf_anchor: collected dep files via clang -M"
    );

    let structs = extract_struct_names(&dep_files);
    if structs.is_empty() {
        tracing::debug!("btf_anchor: no struct definitions found");
        return None;
    }
    tracing::debug!(
        structs = structs.len(),
        "btf_anchor: extracted struct definitions"
    );

    // Best effort: if the directory cannot be created the write below fails
    // and reports the problem by returning `None`.
    if let Some(parent) = anchor_path.parent() {
        let _ = std::fs::create_dir_all(parent);
    }
    write_anchor_header(anchor_path, &structs, input_hash)?;
    Some(resolved_anchor())
}

/// Hashes everything the anchor cache is keyed on, in a fixed order: the
/// crate version, each source path (callers pass them sorted), each cflag,
/// and the sorted `(file name, byte length)` of every `*.bpf.o` entry in
/// `bpf_object_dir`.
///
/// The hasher is built from constant seeds, presumably so the value is
/// reproducible from one run to the next.
fn anchor_input_hash(bpf_object_dir: &Path, bpf_sources: &[PathBuf], cflags: &[String]) -> u64 {
    let mut hasher =
        ahash::RandomState::with_seeds(0x6b74, 0x7374, 0x7200, 0x616e).build_hasher();

    env!("CARGO_PKG_VERSION").hash(&mut hasher);
    for source in bpf_sources {
        source.to_string_lossy().hash(&mut hasher);
    }
    for cflag in cflags {
        cflag.hash(&mut hasher);
    }

    // An unreadable object directory simply contributes nothing to the hash.
    if let Ok(entries) = std::fs::read_dir(bpf_object_dir) {
        let mut object_sizes: Vec<(String, u64)> = Vec::new();
        for entry in entries.flatten() {
            let name = entry.file_name().to_string_lossy().into_owned();
            if !name.ends_with(".bpf.o") {
                continue;
            }
            if let Ok(meta) = entry.metadata() {
                object_sizes.push((name, meta.len()));
            }
        }
        // Directory iteration order is unspecified; sort before hashing.
        object_sizes.sort();
        for (name, size) in &object_sizes {
            name.hash(&mut hasher);
            size.hash(&mut hasher);
        }
    }

    hasher.finish()
}
fn discover_sources_from_objects(dir: &Path) -> Vec<PathBuf> {
let mut sources: HashSet<PathBuf> = HashSet::new();
let Ok(entries) = std::fs::read_dir(dir) else {
return Vec::new();
};
for entry in entries.flatten() {
let path = entry.path();
let Some(name) = path.file_name().and_then(|n| n.to_str()) else {
continue;
};
if !name.ends_with(".bpf.o") || name == "bpf.bpf.o" {
continue;
}
let Ok(bytes) = std::fs::read(&path) else {
continue;
};
if let Some(btf_data) = find_btf_section_raw(&bytes) {
for s in btf_strings(btf_data) {
if s.ends_with(".bpf.c") {
let p = PathBuf::from(s);
if p.is_file()
&& let Ok(canonical) = std::fs::canonicalize(&p)
{
sources.insert(canonical);
}
}
}
}
}
sources.into_iter().collect()
}
/// Locates the `.BTF` section inside an ELF image and returns its raw bytes.
///
/// The header fields are read at the fixed offsets of the 64-bit ELF layout
/// and decoded as little-endian. The ELF magic, class and byte order are NOT
/// validated here.
///
/// Returns `None` when the image is shorter than 64 bytes, when the section
/// header table or section-name string table lies outside the image, or when
/// no section named exactly `.BTF` with at least 24 bytes of in-bounds data
/// exists.
///
/// Every offset and size comes from the (untrusted) file, so all arithmetic is
/// checked and all slicing is bounds-checked: a malformed object yields `None`
/// instead of an overflow panic or a wrapped-around offset.
fn find_btf_section_raw(bytes: &[u8]) -> Option<&[u8]> {
    /// Size of the ELF header that is read.
    const EHDR_LEN: usize = 64;
    /// Number of bytes of each section header that is read.
    const SHDR_LEN: usize = 64;
    /// Smallest section size accepted as BTF data.
    const MIN_BTF_LEN: usize = 24;
    /// Section name including its NUL terminator, so `.BTF.ext` cannot match.
    const BTF_NAME: &[u8] = b".BTF\0";

    /// `bytes[offset..offset + len]`, or `None` on overflow / out of bounds.
    fn region(bytes: &[u8], offset: usize, len: usize) -> Option<&[u8]> {
        bytes.get(offset..offset.checked_add(len)?)
    }
    /// Little-endian `u16` at `at`, widened to `usize`.
    fn le_u16(bytes: &[u8], at: usize) -> Option<usize> {
        let raw = region(bytes, at, 2)?;
        Some(usize::from(u16::from_le_bytes([raw[0], raw[1]])))
    }
    /// Little-endian `u32` at `at`; `None` if it does not fit in `usize`.
    fn le_u32(bytes: &[u8], at: usize) -> Option<usize> {
        let raw = region(bytes, at, 4)?;
        usize::try_from(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]])).ok()
    }
    /// Little-endian `u64` at `at`; `None` if it does not fit in `usize`.
    fn le_u64(bytes: &[u8], at: usize) -> Option<usize> {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(region(bytes, at, 8)?);
        usize::try_from(u64::from_le_bytes(raw)).ok()
    }
    /// The first `SHDR_LEN` bytes of section header number `index`.
    fn section_header(
        bytes: &[u8],
        e_shoff: usize,
        e_shentsize: usize,
        index: usize,
    ) -> Option<&[u8]> {
        let base = e_shoff.checked_add(index.checked_mul(e_shentsize)?)?;
        region(bytes, base, SHDR_LEN)
    }

    let ehdr = region(bytes, 0, EHDR_LEN)?;
    let e_shoff = le_u64(ehdr, 40)?;
    let e_shentsize = le_u16(ehdr, 58)?;
    let e_shnum = le_u16(ehdr, 60)?;
    let e_shstrndx = le_u16(ehdr, 62)?;
    if e_shstrndx >= e_shnum || e_shentsize < SHDR_LEN {
        return None;
    }

    // Section-name string table: sh_offset at +24, sh_size at +32.
    let strtab_hdr = section_header(bytes, e_shoff, e_shentsize, e_shstrndx)?;
    let strtab = region(bytes, le_u64(strtab_hdr, 24)?, le_u64(strtab_hdr, 32)?)?;

    for index in 0..e_shnum {
        // Headers are laid out at increasing offsets, so once one falls
        // outside the image every later one does too.
        let Some(shdr) = section_header(bytes, e_shoff, e_shentsize, index) else {
            break;
        };
        let name_off = le_u32(shdr, 0)?;
        if region(strtab, name_off, BTF_NAME.len()) != Some(BTF_NAME) {
            continue;
        }
        let size = le_u64(shdr, 32)?;
        if size < MIN_BTF_LEN {
            continue;
        }
        if let Some(section) = region(bytes, le_u64(shdr, 24)?, size) {
            return Some(section);
        }
    }
    None
}
/// Returns every non-empty, valid-UTF-8, NUL-separated string in the string
/// region of a BTF blob.
///
/// The region is located from three little-endian `u32` header fields: the
/// value at byte 4 plus the value at byte 16 gives its start, and the value
/// at byte 20 gives its length. An input shorter than 24 bytes, or a region
/// that extends past the end of `btf`, yields an empty vector.
fn btf_strings(btf: &[u8]) -> Vec<&str> {
    const HEADER_LEN: usize = 24;
    if btf.len() < HEADER_LEN {
        return Vec::new();
    }

    let field =
        |at: usize| u32::from_le_bytes([btf[at], btf[at + 1], btf[at + 2], btf[at + 3]]) as usize;
    let begin = field(4) + field(16);
    let end = begin + field(20);
    if end > btf.len() {
        return Vec::new();
    }

    btf[begin..end]
        .split(|&byte| byte == 0)
        .filter_map(|raw| std::str::from_utf8(raw).ok())
        .filter(|text| !text.is_empty())
        .collect()
}
fn collect_dep_files(sources: &[PathBuf], clang: &str, cflags: &[String]) -> Vec<PathBuf> {
let all_deps = std::sync::Mutex::new(HashSet::<PathBuf>::new());
std::thread::scope(|s| {
for source in sources {
let deps_ref = &all_deps;
s.spawn(move || {
let output = Command::new(clang)
.arg("-M")
.arg("-MG")
.arg("-target")
.arg("bpf")
.args(cflags)
.arg(source)
.output();
let Ok(output) = output else { return };
if !output.status.success() {
return;
}
let mut local = HashSet::new();
let stdout = String::from_utf8_lossy(&output.stdout);
let joined = stdout.replace("\\\n", " ");
for line in joined.lines() {
let deps_part = match line.split_once(':') {
Some((_, deps)) => deps,
None => line,
};
for token in deps_part.split_whitespace() {
let p = PathBuf::from(token);
if p.is_file()
&& let Ok(canonical) = std::fs::canonicalize(&p)
&& !is_system_header(&canonical)
{
local.insert(canonical);
}
}
}
deps_ref.lock().unwrap().extend(local);
});
}
});
all_deps.into_inner().unwrap().into_iter().collect()
}
/// Reports whether `path` should be treated as a system header.
///
/// A path qualifies when its lossy string form contains `/usr/include/`,
/// `/usr/lib/` or `scx_utils-bpf_h/`, or when its file name is exactly
/// `vmlinux.h` or `vmlinux.bpf.h`.
fn is_system_header(path: &Path) -> bool {
    const EXCLUDED_DIR_MARKERS: [&str; 3] = ["/usr/include/", "/usr/lib/", "scx_utils-bpf_h/"];

    let full = path.to_string_lossy();
    let base_name = path.file_name().and_then(|name| name.to_str());

    EXCLUDED_DIR_MARKERS
        .iter()
        .any(|marker| full.contains(*marker))
        || matches!(base_name, Some("vmlinux.h" | "vmlinux.bpf.h"))
}
/// Parses each file with the tree-sitter C grammar and returns the sorted,
/// de-duplicated set of struct names gathered by `collect_structs`.
///
/// Files are read as raw bytes rather than as `String`: a header containing a
/// single non-UTF-8 byte (for example a Latin-1 name in a comment) would
/// otherwise be dropped entirely, together with every struct it defines.
/// Struct names themselves are still UTF-8 validated in `collect_structs`.
///
/// Files that cannot be read or for which the parser returns no tree are
/// skipped.
///
/// # Panics
///
/// Panics if the tree-sitter C language cannot be loaded into the parser.
fn extract_struct_names(files: &[PathBuf]) -> BTreeSet<String> {
    let mut names = BTreeSet::new();
    let mut parser = tree_sitter::Parser::new();
    parser
        .set_language(&tree_sitter_c::LANGUAGE.into())
        .expect("tree-sitter-c language");
    for file in files {
        let Ok(content) = std::fs::read(file) else {
            continue;
        };
        let Some(tree) = parser.parse(&content, None) else {
            continue;
        };
        collect_structs(tree.root_node(), &content, &mut names);
    }
    names
}
fn collect_structs(node: tree_sitter::Node, source: &[u8], names: &mut BTreeSet<String>) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "struct_specifier"
&& child.child_by_field_name("body").is_some()
&& let Some(name_node) = child.child_by_field_name("name")
&& let Ok(name) = std::str::from_utf8(&source[name_node.byte_range()])
&& !name.is_empty()
&& !name.starts_with("__")
{
names.insert(name.to_string());
}
collect_structs(child, source, names);
}
}
/// Reads the hash recorded in an anchor header.
///
/// Looks at the first line that starts with `/* ktstr_hash=`; that line must
/// end with ` */` and carry a hexadecimal `u64` in between. Returns `None`
/// when the file cannot be read as UTF-8 text, no such line exists, or that
/// line is malformed (later lines are not consulted).
fn read_anchor_hash(path: &Path) -> Option<u64> {
    const MARKER_OPEN: &str = "/* ktstr_hash=";
    const MARKER_CLOSE: &str = " */";

    let text = std::fs::read_to_string(path).ok()?;
    let digits = text
        .lines()
        .find_map(|line| line.strip_prefix(MARKER_OPEN))?
        .strip_suffix(MARKER_CLOSE)?;
    u64::from_str_radix(digits, 16).ok()
}
/// Writes the anchor header to `path`.
///
/// The file starts with a `/* ktstr_hash=<16 hex digits> */` line, followed
/// by an include guard wrapping one weak pointer declaration per struct name
/// (`struct NAME __attribute__((weak)) *__ktstr_keep_<index>;`), in the
/// set's iteration order.
///
/// The content is first written to a sibling temporary file and then renamed
/// over `path`. The hash is the very first line, so an in-place write that
/// fails part-way (full disk, killed process) would leave a truncated header
/// whose hash still validates and which would then be reused as "current".
/// Renaming makes the header appear complete or not at all. Note that a
/// symlink at `path` is replaced rather than written through.
///
/// Returns `None` when the file cannot be written or moved into place.
fn write_anchor_header(path: &Path, structs: &BTreeSet<String>, hash: u64) -> Option<()> {
    let mut src = String::new();
    src.push_str(&format!("/* ktstr_hash={hash:016x} */\n"));
    src.push_str("#ifndef __KTSTR_BTF_ANCHOR_H\n");
    src.push_str("#define __KTSTR_BTF_ANCHOR_H\n");
    for (i, s) in structs.iter().enumerate() {
        src.push_str(&format!(
            "struct {s} __attribute__((weak)) *__ktstr_keep_{i};\n"
        ));
    }
    src.push_str("#endif\n");

    // Without a file name component there is no sibling to stage into; fall
    // back to a direct write and let it report the error.
    let Some(file_name) = path.file_name() else {
        return std::fs::write(path, &src).ok();
    };

    // Stage next to the target so the rename stays on one filesystem; the
    // process id keeps concurrent builds from sharing a staging file.
    let mut tmp_name = file_name.to_os_string();
    tmp_name.push(format!(".tmp.{}", std::process::id()));
    let tmp_path = path.with_file_name(tmp_name);

    let staged = std::fs::write(&tmp_path, &src).and_then(|()| std::fs::rename(&tmp_path, path));
    if staged.is_err() {
        // Best-effort cleanup; the staging file may not even exist.
        let _ = std::fs::remove_file(&tmp_path);
        return None;
    }
    Some(())
}