use std::os::fd::{AsFd as _, BorrowedFd};
use object::Endianness;
use thiserror::Error;
use crate::{
generated::{bpf_attach_type::BPF_CGROUP_INET_INGRESS, bpf_prog_type::BPF_PROG_TYPE_EXT},
obj::btf::BtfKind,
programs::{
define_link_wrapper, load_program, FdLink, FdLinkId, ProgramData, ProgramError, ProgramFd,
},
sys::{self, bpf_link_create, LinkTarget, SyscallError},
Btf,
};
#[derive(Debug, Error)]
pub enum ExtensionError {
#[error("target BPF program does not have BTF loaded to the kernel")]
NoBTF,
}
/// An eBPF program of type `BPF_PROG_TYPE_EXT`.
///
/// It is loaded against a named function (`func_name`) of another, already-loaded program,
/// and is then attached to that function (or to a function of a different program via
/// [`Extension::attach_to_program`]).
#[doc(alias = "BPF_PROG_TYPE_EXT")]
#[derive(Debug)]
pub struct Extension {
    // Common program state. Links created by this program are tracked as `ExtensionLink`s.
    pub(crate) data: ProgramData<ExtensionLink>,
}
impl Extension {
pub fn load(&mut self, program: ProgramFd, func_name: &str) -> Result<(), ProgramError> {
let (btf_fd, btf_id) = get_btf_info(program.as_fd(), func_name)?;
self.data.attach_btf_obj_fd = Some(btf_fd);
self.data.attach_prog_fd = Some(program);
self.data.attach_btf_id = Some(btf_id);
load_program(BPF_PROG_TYPE_EXT, &mut self.data)
}
pub fn attach(&mut self) -> Result<ExtensionLinkId, ProgramError> {
let prog_fd = self.fd()?;
let prog_fd = prog_fd.as_fd();
let target_fd = self
.data
.attach_prog_fd
.as_ref()
.ok_or(ProgramError::NotLoaded)?;
let target_fd = target_fd.as_fd();
let btf_id = self.data.attach_btf_id.ok_or(ProgramError::NotLoaded)?;
let link_fd = bpf_link_create(
prog_fd,
LinkTarget::Fd(target_fd),
BPF_CGROUP_INET_INGRESS,
Some(btf_id),
0,
None,
)
.map_err(|(_, io_error)| SyscallError {
call: "bpf_link_create",
io_error,
})?;
self.data
.links
.insert(ExtensionLink::new(FdLink::new(link_fd)))
}
pub fn attach_to_program(
&mut self,
program: &ProgramFd,
func_name: &str,
) -> Result<ExtensionLinkId, ProgramError> {
let target_fd = program.as_fd();
let (_, btf_id) = get_btf_info(target_fd, func_name)?;
let prog_fd = self.fd()?;
let prog_fd = prog_fd.as_fd();
let link_fd = bpf_link_create(
prog_fd,
LinkTarget::Fd(target_fd),
BPF_CGROUP_INET_INGRESS,
Some(btf_id),
0,
None,
)
.map_err(|(_, io_error)| SyscallError {
call: "bpf_link_create",
io_error,
})?;
self.data
.links
.insert(ExtensionLink::new(FdLink::new(link_fd)))
}
pub fn detach(&mut self, link_id: ExtensionLinkId) -> Result<(), ProgramError> {
self.data.links.remove(link_id)
}
pub fn take_link(&mut self, link_id: ExtensionLinkId) -> Result<ExtensionLink, ProgramError> {
self.data.take_link(link_id)
}
}
/// Resolves `func_name` against the BTF loaded for the program behind `prog_fd`.
///
/// Returns a file descriptor for that program's BTF object together with the BTF type ID
/// of the `BtfKind::Func` entry named `func_name`.
///
/// # Errors
///
/// - `ProgramError::ExtensionError(ExtensionError::NoBTF)` if the program reports a BTF id
///   of 0 (no BTF in the kernel).
/// - `ProgramError::Btf` if the raw BTF cannot be parsed or has no function `func_name`.
/// - Errors from the program-info / BTF-fd / BTF-info syscalls, converted via `?`.
fn get_btf_info(
    prog_fd: BorrowedFd<'_>,
    func_name: &str,
) -> Result<(crate::MockableFd, u32), ProgramError> {
    let prog_info = sys::bpf_prog_get_info_by_fd(prog_fd, &mut [])?;
    if prog_info.btf_id == 0 {
        return Err(ProgramError::ExtensionError(ExtensionError::NoBTF));
    }
    let btf_fd = sys::bpf_btf_get_fd_by_id(prog_info.btf_id)?;

    // Fetch the raw BTF blob. Start with a 4 KiB buffer. When the kernel reports a blob
    // larger than the buffer, grow the buffer to that size and query again. Otherwise,
    // trim it down to exactly the reported size.
    let mut raw_btf = vec![0u8; 4096];
    loop {
        let reported_len = sys::btf_obj_get_info_by_fd(btf_fd.as_fd(), &mut raw_btf)?.btf_size as usize;
        if reported_len <= raw_btf.len() {
            raw_btf.truncate(reported_len);
            break;
        }
        raw_btf.resize(reported_len, 0u8);
    }

    // `Endianness::default()` is the host byte order; the blob comes from the running kernel.
    let btf = Btf::parse(&raw_btf, Endianness::default()).map_err(ProgramError::Btf)?;
    let func_btf_id = btf
        .id_by_type_name_kind(func_name, BtfKind::Func)
        .map_err(ProgramError::Btf)?;
    Ok((btf_fd, func_btf_id))
}
// Declares `ExtensionLink` (wrapping `FdLink`) and its ID type `ExtensionLinkId` (wrapping
// `FdLinkId`); these are the link types returned by `Extension::attach*` and `take_link`.
define_link_wrapper!(ExtensionLink, ExtensionLinkId, FdLink, FdLinkId);