use std::{
io,
os::fd::{AsFd, AsRawFd as _, RawFd},
};
use aya_obj::generated::{
SO_ATTACH_BPF, SO_DETACH_BPF, bpf_prog_type::BPF_PROG_TYPE_SOCKET_FILTER,
};
use libc::{SOL_SOCKET, setsockopt};
use thiserror::Error;
use crate::programs::{Link, ProgramData, ProgramError, ProgramType, id_as_key, load_program};
/// Errors specific to attaching a [`SocketFilter`] to a socket.
#[derive(Debug, Error)]
pub enum SocketFilterError {
    /// The `setsockopt(SOL_SOCKET, SO_ATTACH_BPF)` call returned a negative status.
    #[error("setsockopt SO_ATTACH_BPF failed")]
    SoAttachEbpfError {
        /// The OS error (`errno`) captured immediately after the failing call.
        #[source]
        io_error: io::Error,
    },
}
/// A BPF program that is attached to sockets via `setsockopt(SOL_SOCKET, SO_ATTACH_BPF)`.
///
/// The program is loaded as `BPF_PROG_TYPE_SOCKET_FILTER`, and each successful
/// attachment is tracked as a [`SocketFilterLink`] inside `data`.
#[derive(Debug)]
#[doc(alias = "BPF_PROG_TYPE_SOCKET_FILTER")]
pub struct SocketFilter {
    // Shared program state; `load` passes it to `load_program`, and its `links`
    // map records the attachments created by `attach`.
    pub(crate) data: ProgramData<SocketFilterLink>,
}
impl SocketFilter {
pub const PROGRAM_TYPE: ProgramType = ProgramType::SocketFilter;
pub fn load(&mut self) -> Result<(), ProgramError> {
load_program(BPF_PROG_TYPE_SOCKET_FILTER, &mut self.data)
}
pub fn attach<T: AsFd>(&mut self, socket: T) -> Result<SocketFilterLinkId, ProgramError> {
let prog_fd = self.fd()?;
let prog_fd = prog_fd.as_fd();
let prog_fd = prog_fd.as_raw_fd();
let socket = socket.as_fd();
let socket = socket.as_raw_fd();
let ret = unsafe {
setsockopt(
socket,
SOL_SOCKET,
SO_ATTACH_BPF as i32,
std::ptr::from_ref(&prog_fd).cast(),
size_of_val(&prog_fd) as u32,
)
};
if ret < 0 {
return Err(SocketFilterError::SoAttachEbpfError {
io_error: io::Error::last_os_error(),
}
.into());
}
self.data.links.insert(SocketFilterLink { socket, prog_fd })
}
pub fn detach(&mut self, link_id: SocketFilterLinkId) -> Result<(), ProgramError> {
self.data.links.remove(link_id)
}
pub fn take_link(
&mut self,
link_id: SocketFilterLinkId,
) -> Result<SocketFilterLink, ProgramError> {
self.data.links.forget(link_id)
}
}
/// Identifies one socket-filter attachment by its raw `(socket, program)`
/// descriptor pair.
///
/// Returned by [`SocketFilter::attach`]; pass it back to [`SocketFilter::detach`]
/// or [`SocketFilter::take_link`].
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct SocketFilterLinkId(RawFd, RawFd);
/// A single socket-filter attachment, created by [`SocketFilter::attach`].
///
/// Stores raw descriptor numbers only; it does not own or keep either
/// descriptor open.
#[derive(Debug)]
pub struct SocketFilterLink {
    // Raw fd of the socket the program was attached to.
    socket: RawFd,
    // Raw fd of the program, passed as the `SO_ATTACH_BPF` option value.
    prog_fd: RawFd,
}
impl Link for SocketFilterLink {
    type Id = SocketFilterLinkId;

    /// Returns the `(socket, prog_fd)` pair this attachment was recorded under.
    fn id(&self) -> Self::Id {
        let &Self { socket, prog_fd } = self;
        SocketFilterLinkId(socket, prog_fd)
    }

    /// Issues `setsockopt(SOL_SOCKET, SO_DETACH_BPF)` on the recorded socket.
    ///
    /// Best-effort: the `setsockopt` status is discarded and this always
    /// returns `Ok(())`.
    fn detach(self) -> Result<(), ProgramError> {
        let Self { socket, prog_fd } = self;
        // NOTE(review): `socket` is a raw fd number captured at attach time; nothing
        // here guarantees it still names the same socket (it may have been closed
        // and the number reused). Presumably SO_DETACH_BPF removes whichever filter
        // is currently installed on the socket, not specifically `prog_fd` — verify
        // against socket(7).
        //
        // SAFETY: `optval` points at the local `prog_fd` and `optlen` is exactly its
        // size, so the kernel only reads initialised, in-bounds memory.
        let _ = unsafe {
            setsockopt(
                socket,
                SOL_SOCKET,
                SO_DETACH_BPF as i32,
                std::ptr::from_ref(&prog_fd).cast(),
                size_of_val(&prog_fd) as u32,
            )
        };
        Ok(())
    }
}
// Associates `SocketFilterLink` with its `SocketFilterLinkId` key type. The macro
// comes from `crate::programs`; presumably it implements whatever the link map
// needs to store and look up links by id — verify against its definition.
id_as_key!(SocketFilterLink, SocketFilterLinkId);