use crate::error::ErrorKind::*;
use crate::error::{Result, SeccompError};
use crate::{ensure_supported_api, ScmpArch, ScmpFilterContext, ScmpVersion};
use libseccomp_sys::*;
use std::os::unix::io::RawFd;
/// Reads the calling thread's last OS error code (`errno`).
///
/// Falls back to `0` when the standard library cannot report a raw OS code.
fn get_errno() -> i32 {
    let last = std::io::Error::last_os_error();
    last.raw_os_error().unwrap_or_default()
}
/// Verifies that the running libseccomp supports userspace notification.
///
/// Delegates to `ensure_supported_api` with API level 6 and minimum library
/// version 2.5.0; any error it reports is propagated unchanged.
fn notify_supported() -> Result<()> {
    let min_version = ScmpVersion::from((2, 5, 0));
    ensure_supported_api("seccomp notification", 6, min_version)?;
    Ok(())
}
/// File descriptor used for seccomp userspace notifications (an alias of `RawFd`).
pub type ScmpFd = RawFd;
/// Response flag (for `ScmpNotifResp::flags`) re-exported from
/// `SECCOMP_USER_NOTIF_FLAG_CONTINUE`.
///
/// In this module's tests, responding with this flag lets the intercepted
/// `dup3` run for real, so the caller sees the actual syscall result.
pub const NOTIF_FLAG_CONTINUE: u32 = SECCOMP_USER_NOTIF_FLAG_CONTINUE;
impl ScmpFilterContext {
    /// Returns the userspace-notification file descriptor for this filter.
    ///
    /// # Errors
    ///
    /// Fails if notification is unsupported by the installed libseccomp, or
    /// if `seccomp_notify_fd` returns a negative value. In that case the raw
    /// return value is wrapped in `Errno`.
    pub fn get_notify_fd(&self) -> Result<ScmpFd> {
        notify_supported()?;

        // SAFETY: `self.ctx` holds the live filter context owned by `self`.
        match unsafe { seccomp_notify_fd(self.ctx.as_ptr()) } {
            fd if fd >= 0 => Ok(fd),
            failure => Err(SeccompError::new(Errno(failure))),
        }
    }
}
/// Description of the syscall that triggered a userspace notification.
///
/// Built from libseccomp's raw `seccomp_data` by `ScmpNotifData::from_sys`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScmpNotifData {
    /// Syscall number (copied from `seccomp_data::nr`).
    pub syscall: i32,
    /// Architecture of the syscall, converted from `seccomp_data::arch`.
    pub arch: ScmpArch,
    /// Instruction pointer at the time of the syscall (`seccomp_data::instruction_pointer`).
    pub instr_pointer: u64,
    /// The six raw syscall arguments (`seccomp_data::args`).
    pub args: [u64; 6],
}
impl ScmpNotifData {
    /// Converts libseccomp's raw `seccomp_data` into an `ScmpNotifData`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ScmpArch::from_sys` when the architecture
    /// token is not recognised.
    fn from_sys(data: seccomp_data) -> Result<Self> {
        let arch = ScmpArch::from_sys(data.arch)?;
        Ok(Self {
            args: data.args,
            instr_pointer: data.instruction_pointer,
            syscall: data.nr,
            arch,
        })
    }
}
/// A userspace notification received from the kernel via `ScmpNotifReq::receive`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScmpNotifReq {
    /// Notification id. Pass it to `notify_id_valid` and echo it back in
    /// `ScmpNotifResp::id` when responding.
    pub id: u64,
    /// Copied from `seccomp_notif::pid`; presumably the pid of the task that
    /// issued the syscall (confirm against the kernel `seccomp_notif` docs).
    pub pid: u32,
    /// Flags copied from `seccomp_notif::flags`.
    pub flags: u32,
    /// Details of the intercepted syscall.
    pub data: ScmpNotifData,
}
impl ScmpNotifReq {
    /// Converts a raw libseccomp `seccomp_notif` into an `ScmpNotifReq`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ScmpNotifData::from_sys` (unrecognised
    /// architecture token).
    fn from_sys(req: seccomp_notif) -> Result<Self> {
        Ok(Self {
            id: req.id,
            pid: req.pid,
            flags: req.flags,
            data: ScmpNotifData::from_sys(req.data)?,
        })
    }

    /// Waits for the next notification on `fd` and returns it.
    ///
    /// Calls interrupted by a signal (`errno == EINTR`) are retried.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - notification is unsupported by the installed libseccomp;
    /// - `seccomp_notify_alloc` fails;
    /// - `seccomp_notify_receive` fails with an errno other than `EINTR`
    ///   (the libseccomp return code is wrapped in `Errno`);
    /// - the received architecture cannot be converted.
    pub fn receive(fd: ScmpFd) -> Result<Self> {
        /// Owns a request buffer from `seccomp_notify_alloc` and releases it
        /// with `seccomp_notify_free` when dropped, so no exit path can leak it.
        struct ReqBuf(*mut seccomp_notif);

        impl Drop for ReqBuf {
            fn drop(&mut self) {
                // SAFETY: the pointer came from a successful `seccomp_notify_alloc`
                // and is freed exactly once, here.
                unsafe { seccomp_notify_free(self.0, std::ptr::null_mut()) };
            }
        }

        notify_supported()?;

        let mut req_ptr: *mut seccomp_notif = std::ptr::null_mut();
        // SAFETY: `req_ptr` is a valid out-pointer; passing null for the
        // response asks libseccomp to allocate only the request buffer.
        let ret = unsafe { seccomp_notify_alloc(&mut req_ptr, std::ptr::null_mut()) };
        if ret != 0 {
            return Err(SeccompError::new(Errno(ret)));
        }
        let buf = ReqBuf(req_ptr);

        loop {
            // SAFETY: `buf.0` points to the live buffer allocated above.
            let ret = unsafe { seccomp_notify_receive(fd, buf.0) };
            if ret == 0 {
                break;
            }
            // errno is only meaningful after a failure; retry signal interruptions.
            if get_errno() != libc::EINTR {
                return Err(SeccompError::new(Errno(ret)));
            }
        }

        // SAFETY: the receive succeeded, so the buffer holds an initialised
        // `seccomp_notif`; it is copied out before `buf` frees the memory.
        let req = unsafe { std::ptr::read(buf.0) };
        drop(buf);

        Self::from_sys(req)
    }
}
/// A response to a userspace notification, sent with `ScmpNotifResp::respond`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScmpNotifResp {
    /// Id of the notification being answered (`ScmpNotifReq::id`).
    pub id: u64,
    /// Value the intercepted syscall should return. In the tests, `10` makes
    /// `dup3` return 10.
    pub val: i64,
    /// Error to report. The tests expect `-1` to make the syscall return -1.
    /// NOTE(review): presumably a negated errno; confirm against the kernel
    /// `seccomp_notif_resp` documentation.
    pub error: i32,
    /// Response flags, e.g. `NOTIF_FLAG_CONTINUE`.
    pub flags: u32,
}
impl ScmpNotifResp {
    /// Copies every field of this response into a raw `seccomp_notif_resp`.
    ///
    /// # Safety
    ///
    /// `resp` must be non-null, properly aligned, and point to writable memory
    /// large enough for a `seccomp_notif_resp` (e.g. a buffer obtained from
    /// `seccomp_notify_alloc`).
    unsafe fn to_sys(self, resp: *mut seccomp_notif_resp) {
        (*resp).id = self.id;
        (*resp).val = self.val;
        (*resp).error = self.error;
        (*resp).flags = self.flags;
    }

    /// Creates a response from its four raw fields.
    pub fn new(id: u64, val: i64, error: i32, flags: u32) -> Self {
        Self {
            id,
            val,
            error,
            flags,
        }
    }

    /// Sends this response on the notification file descriptor `fd`.
    ///
    /// Calls interrupted by a signal (`errno == EINTR`) are retried.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - notification is unsupported by the installed libseccomp;
    /// - `seccomp_notify_alloc` fails;
    /// - `seccomp_notify_respond` fails with an errno other than `EINTR`
    ///   (the libseccomp return code is wrapped in `Errno`).
    pub fn respond(&self, fd: ScmpFd) -> Result<()> {
        /// Owns a response buffer from `seccomp_notify_alloc` and releases it
        /// with `seccomp_notify_free` when dropped, so no exit path can leak it.
        struct RespBuf(*mut seccomp_notif_resp);

        impl Drop for RespBuf {
            fn drop(&mut self) {
                // SAFETY: the pointer came from a successful `seccomp_notify_alloc`
                // and is freed exactly once, here.
                unsafe { seccomp_notify_free(std::ptr::null_mut(), self.0) };
            }
        }

        notify_supported()?;

        let mut resp_ptr: *mut seccomp_notif_resp = std::ptr::null_mut();
        // SAFETY: `resp_ptr` is a valid out-pointer; passing null for the
        // request asks libseccomp to allocate only the response buffer.
        let ret = unsafe { seccomp_notify_alloc(std::ptr::null_mut(), &mut resp_ptr) };
        if ret != 0 {
            return Err(SeccompError::new(Errno(ret)));
        }
        let buf = RespBuf(resp_ptr);

        // SAFETY: `buf.0` is a live, writable buffer from `seccomp_notify_alloc`.
        unsafe { self.to_sys(buf.0) };

        loop {
            // SAFETY: `buf.0` points to the initialised response buffer.
            let ret = unsafe { seccomp_notify_respond(fd, buf.0) };
            if ret == 0 {
                return Ok(());
            }
            // errno is only meaningful after a failure; retry signal interruptions.
            if get_errno() != libc::EINTR {
                return Err(SeccompError::new(Errno(ret)));
            }
        }
    }
}
/// Checks, via `seccomp_notify_id_valid`, whether notification `id` is still
/// valid on the notification file descriptor `fd`.
///
/// Calls interrupted by a signal (`errno == EINTR`) are retried.
///
/// # Errors
///
/// Fails if notification is unsupported by the installed libseccomp, or if
/// `seccomp_notify_id_valid` fails with an errno other than `EINTR`. In that
/// case the libseccomp return code is wrapped in `Errno`.
pub fn notify_id_valid(fd: ScmpFd, id: u64) -> Result<()> {
    notify_supported()?;

    // Keep asking until the call succeeds or fails for a reason other than EINTR.
    let rc = loop {
        // SAFETY: plain FFI call taking only integer arguments.
        let attempt = unsafe { seccomp_notify_id_valid(fd, id) };
        if attempt == 0 || get_errno() != libc::EINTR {
            break attempt;
        }
    };

    if rc == 0 {
        Ok(())
    } else {
        Err(SeccompError::new(Errno(rc)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{get_syscall_from_name, ScmpAction, ScmpArch, ScmpFilterContext};
    use libc::{dup3, O_CLOEXEC};
    use std::thread;

    // Returns early from the enclosing test (after printing a note) when
    // `notify_supported()` reports that userspace notification is unavailable.
    macro_rules! skip_if_not_supported {
        () => {
            if notify_supported().is_err() {
                println!("Skip tests for userspace notification");
                return;
            }
        };
    }

    // One notification round-trip: the syscall a helper thread issues, what the
    // notification is expected to contain, and how the test responds.
    #[derive(Debug)]
    struct TestData {
        // Expected syscall number in the received notification.
        syscall: i32,
        // Arguments passed to `dup3`; also compared against `req.data.args`.
        args: Vec<u64>,
        // Expected architecture in the received notification.
        arch: ScmpArch,
        // Fields used to build the `ScmpNotifResp`.
        resp_val: i64,
        resp_err: i32,
        resp_flags: u32,
        // Return value the helper thread's `dup3` call is expected to observe.
        expected_val: i64,
    }

    #[test]
    fn test_user_notification() {
        skip_if_not_supported!();
        // Allow everything except `dup3`, which is routed to userspace via Notify.
        let mut ctx = ScmpFilterContext::new_filter(ScmpAction::Allow).unwrap();
        let syscall = get_syscall_from_name("dup3", None).unwrap();
        let arch = ScmpArch::native().unwrap();
        ctx.add_arch(arch).unwrap();
        ctx.add_rule(ScmpAction::Notify, syscall).unwrap();
        let tests = &[
            // Respond with a fake return value: dup3 should return 10.
            TestData {
                syscall,
                args: vec![0, 100, O_CLOEXEC as u64],
                arch,
                resp_val: 10,
                resp_err: 0,
                resp_flags: 0,
                expected_val: 10,
            },
            // Respond with an error: dup3 should return -1.
            TestData {
                syscall,
                args: vec![0, 100, O_CLOEXEC as u64],
                arch,
                resp_val: 0,
                resp_err: -1,
                resp_flags: 0,
                expected_val: -1,
            },
            // Respond with NOTIF_FLAG_CONTINUE: dup3 runs for real and returns newfd (100).
            TestData {
                syscall,
                args: vec![0, 100, O_CLOEXEC as u64],
                arch,
                resp_val: 0,
                resp_err: 0,
                resp_flags: NOTIF_FLAG_CONTINUE,
                expected_val: 100,
            },
        ];
        // The filter is loaded on this thread before any helper thread is spawned;
        // the helpers are expected to inherit it.
        ctx.load().unwrap();
        let fd = ctx.get_notify_fd().unwrap();
        let mut handlers = vec![];
        // Each iteration spawns exactly one helper and handles exactly one
        // notification before moving on, so notifications pair with test cases in order.
        for test in tests.iter() {
            let args: (i32, i32, i32) = (
                test.args[0] as i32,
                test.args[1] as i32,
                test.args[2] as i32,
            );
            handlers.push(thread::spawn(move || unsafe {
                dup3(args.0, args.1, args.2)
            }));
            // Blocks until the helper's dup3 is intercepted.
            let req = ScmpNotifReq::receive(fd).unwrap();
            assert_eq!(req.data.arch, test.arch);
            assert_eq!(req.data.syscall, test.syscall,);
            for (i, test_val) in test.args.iter().enumerate() {
                assert_eq!(&req.data.args[i], test_val);
            }
            // The id must still be valid because no response has been sent yet.
            assert!(notify_id_valid(fd, req.id).is_ok());
            let resp = ScmpNotifResp::new(req.id, test.resp_val, test.resp_err, test.resp_flags);
            resp.respond(fd).unwrap();
        }
        // Each helper's dup3 result must match the response it was given.
        for (i, handler) in handlers.into_iter().enumerate() {
            let ret_val = handler.join().unwrap();
            assert_eq!(tests[i].expected_val as i32, ret_val);
        }
    }

    #[test]
    fn test_error() {
        skip_if_not_supported!();
        // No filter is loaded and fd 0 is not a notification fd, so every
        // notification API is expected to fail.
        let ctx = ScmpFilterContext::new_filter(ScmpAction::Allow).unwrap();
        let resp = ScmpNotifResp::new(0, 0, 0, 0);
        assert!(ctx.get_notify_fd().is_err());
        assert!(ScmpNotifReq::receive(0).is_err());
        assert!(resp.respond(0).is_err());
        assert!(notify_id_valid(0, 0).is_err());
    }
}