use libc::{setsockopt, SOL_SOCKET};
use std::{
io, mem,
os::unix::prelude::{AsRawFd, RawFd},
};
use thiserror::Error;
use crate::{
generated::{bpf_prog_type::BPF_PROG_TYPE_SOCKET_FILTER, SO_ATTACH_BPF, SO_DETACH_BPF},
programs::{load_program, Link, OwnedLink, ProgramData, ProgramError},
};
#[derive(Debug, Error)]
pub enum SocketFilterError {
#[error("setsockopt SO_ATTACH_BPF failed")]
SoAttachBpfError {
#[source]
io_error: io::Error,
},
}
/// A `BPF_PROG_TYPE_SOCKET_FILTER` eBPF program.
///
/// Once loaded with [`SocketFilter::load`], the program is attached to individual sockets with
/// [`SocketFilter::attach`], which issues `setsockopt(SOL_SOCKET, SO_ATTACH_BPF)`.
#[doc(alias = "BPF_PROG_TYPE_SOCKET_FILTER")]
#[derive(Debug)]
pub struct SocketFilter {
    /// Program state; `attach` records each attachment in `data.links` and `detach` removes it.
    pub(crate) data: ProgramData<SocketFilterLink>,
}
impl SocketFilter {
pub fn load(&mut self) -> Result<(), ProgramError> {
load_program(BPF_PROG_TYPE_SOCKET_FILTER, &mut self.data)
}
pub fn attach<T: AsRawFd>(&mut self, socket: T) -> Result<SocketFilterLinkId, ProgramError> {
let prog_fd = self.data.fd_or_err()?;
let socket = socket.as_raw_fd();
let ret = unsafe {
setsockopt(
socket,
SOL_SOCKET,
SO_ATTACH_BPF as i32,
&prog_fd as *const _ as *const _,
mem::size_of::<RawFd>() as u32,
)
};
if ret < 0 {
return Err(SocketFilterError::SoAttachBpfError {
io_error: io::Error::last_os_error(),
}
.into());
}
self.data.links.insert(SocketFilterLink { socket, prog_fd })
}
pub fn detach(&mut self, link_id: SocketFilterLinkId) -> Result<(), ProgramError> {
self.data.links.remove(link_id)
}
pub fn take_link(
&mut self,
link_id: SocketFilterLinkId,
) -> Result<OwnedLink<SocketFilterLink>, ProgramError> {
Ok(OwnedLink::new(self.data.take_link(link_id)?))
}
}
/// Identifies a [`SocketFilterLink`] by the descriptors it was created with.
///
/// Returned by [`SocketFilter::attach`] and accepted by [`SocketFilter::detach`] and
/// [`SocketFilter::take_link`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SocketFilterLinkId(
    // Raw descriptor of the socket the program is attached to.
    RawFd,
    // Raw descriptor of the attached program.
    RawFd,
);
/// An attachment of a [`SocketFilter`] program to a single socket.
///
/// Stores plain `RawFd` integers for both the socket and the program.
#[derive(Debug)]
pub struct SocketFilterLink {
    /// Raw descriptor of the socket the program was attached to.
    socket: RawFd,
    /// Raw descriptor of the attached program.
    prog_fd: RawFd,
}
impl Link for SocketFilterLink {
type Id = SocketFilterLinkId;
fn id(&self) -> Self::Id {
SocketFilterLinkId(self.socket, self.prog_fd)
}
fn detach(self) -> Result<(), ProgramError> {
unsafe {
setsockopt(
self.socket,
SOL_SOCKET,
SO_DETACH_BPF as i32,
&self.prog_fd as *const _ as *const _,
mem::size_of::<RawFd>() as u32,
);
}
Ok(())
}
}