use anyhow::{bail, Result};
use regex::Regex;
use std::collections::HashSet;
/// Decides which syscalls are traced.
///
/// Built either by `SyscallFilter::all` (trace everything) or from a
/// `trace=SPEC` expression via `SyscallFilter::from_expr`. Exclusion rules are
/// checked first and always win; the remaining fields decide inclusion
/// (see `SyscallFilter::should_trace`).
#[derive(Debug, Clone)]
pub struct SyscallFilter {
// Explicit allow-list of syscall names. `None` means no allow-list was given:
// a name is then admitted unless include regexes exist and none of them match.
include: Option<HashSet<String>>,
// Syscall names that are never traced; checked before any include rule.
exclude: HashSet<String>,
// Patterns that admit a syscall in addition to the `include` set.
include_regex: Vec<Regex>,
// Patterns that reject a syscall; checked before any include rule.
exclude_regex: Vec<Regex>,
}
impl SyscallFilter {
pub fn all() -> Self {
Self {
include: None,
exclude: HashSet::new(),
include_regex: Vec::new(),
exclude_regex: Vec::new(),
}
}
pub fn from_expr(expr: &str) -> Result<Self> {
if let Some(trace_spec) = expr.strip_prefix("trace=") {
Self::from_trace_spec(trace_spec)
} else {
bail!("Invalid filter expression: {expr}. Expected format: trace=SPEC");
}
}
fn from_trace_spec(spec: &str) -> Result<Self> {
validate_trace_spec(spec)?;
let (include_syscalls, exclude_syscalls, include_regex, exclude_regex, has_includes) =
parse_syscall_sets(spec)?;
let include = if has_includes {
Some(include_syscalls)
} else {
None };
Ok(Self { include, exclude: exclude_syscalls, include_regex, exclude_regex })
}
pub fn should_trace(&self, syscall_name: &str) -> bool {
contract_pre_error_handling!(syscall_name);
contract_post_error_handling!(&"ok");
if self.exclude.contains(syscall_name) {
return false;
}
for pattern in &self.exclude_regex {
if pattern.is_match(syscall_name) {
return false;
}
}
match &self.include {
None => {
if self.include_regex.is_empty() {
true
} else {
self.include_regex.iter().any(|pattern| pattern.is_match(syscall_name))
}
}
Some(set) => {
set.contains(syscall_name)
|| self.include_regex.iter().any(|pattern| pattern.is_match(syscall_name))
}
}
}
}
/// Rejects trace specs that are meaningless before any parsing happens.
///
/// The only form rejected here is a bare `!` (optionally surrounded by
/// whitespace), which negates nothing. An empty spec is accepted.
///
/// # Errors
/// Returns an "Invalid negation syntax" error for a bare `!`.
fn validate_trace_spec(spec: &str) -> Result<()> {
    // An empty spec trims to "", never to "!", so it needs no special case.
    let is_bare_negation = spec.trim() == "!";
    if is_bare_negation {
        bail!("Invalid negation syntax: '!' must be followed by syscall name or class");
    }
    Ok(())
}
type ParseResult = (
HashSet<String>, HashSet<String>, Vec<Regex>, Vec<Regex>, bool, );
fn parse_syscall_sets(spec: &str) -> Result<ParseResult> {
let mut include_syscalls = HashSet::new();
let mut exclude_syscalls = HashSet::new();
let mut include_regex = Vec::new();
let mut exclude_regex = Vec::new();
let mut has_includes = false;
if spec.is_empty() {
return Ok((include_syscalls, exclude_syscalls, include_regex, exclude_regex, true));
}
for part in spec.split(',') {
let part = part.trim();
let (is_negation, syscall_part) = if let Some(s) = part.strip_prefix('!') {
(true, s)
} else {
has_includes = true;
(false, part)
};
if let Some(pattern) = parse_regex_pattern(syscall_part)? {
if is_negation {
exclude_regex.push(pattern);
} else {
include_regex.push(pattern);
}
} else {
let syscalls_to_add = expand_syscall_class(syscall_part);
if is_negation {
exclude_syscalls.extend(syscalls_to_add);
} else {
include_syscalls.extend(syscalls_to_add);
}
}
}
Ok((include_syscalls, exclude_syscalls, include_regex, exclude_regex, has_includes))
}
/// Interprets `input` as a regex if it is delimited by slashes (`/pattern/`).
///
/// Returns `Ok(None)` for inputs that are not slash-delimited, including a lone
/// `/`. Returns `Ok(Some(regex))` for a compiled pattern; `//` yields the empty
/// pattern.
///
/// # Errors
/// Returns an error if the text between the slashes is not a valid regex.
fn parse_regex_pattern(input: &str) -> Result<Option<Regex>> {
    let delimited = input
        .strip_prefix('/')
        .and_then(|rest| rest.strip_suffix('/'));
    let pattern = match delimited {
        Some(inner) => inner,
        None => return Ok(None),
    };
    match Regex::new(pattern) {
        Ok(compiled) => Ok(Some(compiled)),
        Err(e) => bail!("Invalid regex pattern '{pattern}': {e}"),
    }
}
/// Expands a syscall class name into its member syscall names.
///
/// The recognized classes are `file`, `network`, `process` and `memory`. Any
/// other name is returned unchanged as a single-element list, so individual
/// syscall names pass straight through. Members are returned in a fixed order.
fn expand_syscall_class(name: &str) -> Vec<String> {
    // One static table per class; a single conversion below replaces the
    // per-arm `.iter().map(..).collect()` chains (and the needless `vec!`).
    let members: &[&str] = match name {
        "file" => &[
            "open",
            "openat",
            "close",
            "read",
            "write",
            "lseek",
            "stat",
            "fstat",
            "newfstatat",
            "access",
            "mkdir",
            "rmdir",
            "unlink",
        ],
        "network" => &[
            "socket",
            "connect",
            "accept",
            "bind",
            "listen",
            "send",
            "recv",
            "sendto",
            "recvfrom",
            "setsockopt",
            "getsockopt",
        ],
        "process" => &[
            "fork",
            "vfork",
            "clone",
            "execve",
            "exit",
            "exit_group",
            "wait4",
            "waitid",
            "kill",
            "tkill",
            "tgkill",
        ],
        // NOTE(review): on Linux `sbrk` is a libc wrapper around `brk` rather
        // than a kernel syscall; kept so existing filters behave the same.
        "memory" => &["mmap", "munmap", "mprotect", "mremap", "brk", "sbrk"],
        _ => return vec![name.to_string()],
    };
    members.iter().copied().map(String::from).collect()
}
// Compile-time assertion that `SyscallFilter` implements `Send` and `Sync`, so a
// filter can be moved to or shared between threads. The build fails if a future
// field breaks either bound.
static_assertions::assert_impl_all!(SyscallFilter: Send, Sync);
// Unit tests covering expression parsing, class expansion, negation and
// regex-based filtering.
#[cfg(test)]
mod tests {
use super::*;
// --- Positive selection by name and by class ---
#[test]
fn test_filter_all_traces_everything() {
let filter = SyscallFilter::all();
assert!(filter.should_trace("open"));
assert!(filter.should_trace("write"));
assert!(filter.should_trace("anything"));
}
#[test]
fn test_filter_individual_syscalls() {
let filter = SyscallFilter::from_expr("trace=open,read,write").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(filter.should_trace("write"));
assert!(!filter.should_trace("close"));
}
#[test]
fn test_filter_file_class() {
let filter = SyscallFilter::from_expr("trace=file").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("openat"));
assert!(filter.should_trace("read"));
assert!(filter.should_trace("write"));
assert!(!filter.should_trace("socket"));
}
#[test]
fn test_filter_network_class() {
let filter = SyscallFilter::from_expr("trace=network").expect("test");
assert!(filter.should_trace("socket"));
assert!(filter.should_trace("connect"));
assert!(!filter.should_trace("open"));
}
#[test]
fn test_filter_mixed() {
let filter = SyscallFilter::from_expr("trace=file,socket").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("socket"));
assert!(!filter.should_trace("clone"));
}
// Expressions without the `trace=` prefix are rejected.
#[test]
fn test_invalid_expression() {
let result = SyscallFilter::from_expr("invalid");
assert!(result.is_err());
}
#[test]
fn test_filter_process_class() {
let filter = SyscallFilter::from_expr("trace=process").expect("test");
assert!(filter.should_trace("fork"));
assert!(filter.should_trace("clone"));
assert!(filter.should_trace("execve"));
assert!(filter.should_trace("exit"));
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("socket"));
}
#[test]
fn test_filter_memory_class() {
let filter = SyscallFilter::from_expr("trace=memory").expect("test");
assert!(filter.should_trace("mmap"));
assert!(filter.should_trace("munmap"));
assert!(filter.should_trace("mprotect"));
assert!(filter.should_trace("brk"));
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("fork"));
}
#[test]
fn test_filter_multiple_classes() {
let filter = SyscallFilter::from_expr("trace=file,network,process").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(filter.should_trace("socket"));
assert!(filter.should_trace("connect"));
assert!(filter.should_trace("fork"));
assert!(filter.should_trace("execve"));
assert!(!filter.should_trace("mmap"));
}
// Derived trait impls (Clone, Debug) behave as expected.
#[test]
fn test_filter_clone() {
let filter1 = SyscallFilter::from_expr("trace=open,read").expect("test");
let filter2 = filter1.clone();
assert!(filter2.should_trace("open"));
assert!(filter2.should_trace("read"));
assert!(!filter2.should_trace("write"));
}
#[test]
fn test_filter_debug() {
let filter = SyscallFilter::all();
let debug_str = format!("{:?}", filter);
assert!(debug_str.contains("SyscallFilter"));
}
// `trace=` with nothing after it selects no syscalls at all.
#[test]
fn test_filter_empty_trace_spec() {
let filter = SyscallFilter::from_expr("trace=").expect("test");
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("read"));
}
#[test]
fn test_filter_whitespace_handling() {
let filter = SyscallFilter::from_expr("trace=open, read , write").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(filter.should_trace("write"));
assert!(!filter.should_trace("close"));
}
// --- Negation (`!name`, `!class`) ---
#[test]
fn test_negation_single_syscall() {
let filter = SyscallFilter::from_expr("trace=!close").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(!filter.should_trace("close"));
}
#[test]
fn test_negation_multiple_syscalls() {
let filter = SyscallFilter::from_expr("trace=!open,!close").expect("test");
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("close"));
assert!(filter.should_trace("read"));
assert!(filter.should_trace("write"));
}
#[test]
fn test_negation_syscall_class() {
let filter = SyscallFilter::from_expr("trace=!file").expect("test");
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("read"));
assert!(!filter.should_trace("write"));
assert!(filter.should_trace("socket"));
assert!(filter.should_trace("fork"));
}
#[test]
fn test_mixed_positive_negative() {
let filter = SyscallFilter::from_expr("trace=file,!close").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(!filter.should_trace("close")); assert!(!filter.should_trace("socket")); }
#[test]
fn test_negation_invalid_syntax() {
let result = SyscallFilter::from_expr("trace=!");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Invalid negation syntax"));
}
#[test]
fn test_negation_preserves_original_behavior() {
let filter = SyscallFilter::from_expr("trace=open,read").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(!filter.should_trace("write"));
}
// --- expand_syscall_class called directly ---
#[test]
fn test_expand_syscall_class_file() {
let syscalls = expand_syscall_class("file");
assert!(syscalls.contains(&"open".to_string()));
assert!(syscalls.contains(&"close".to_string()));
assert!(syscalls.contains(&"read".to_string()));
}
#[test]
fn test_expand_syscall_class_network() {
let syscalls = expand_syscall_class("network");
assert!(syscalls.contains(&"socket".to_string()));
assert!(syscalls.contains(&"connect".to_string()));
}
#[test]
fn test_expand_syscall_class_individual() {
let syscalls = expand_syscall_class("custom_syscall");
assert_eq!(syscalls, vec!["custom_syscall".to_string()]);
}
// --- Regex entries (`/pattern/`), alone and combined with names/negation ---
#[test]
fn test_regex_pattern_basic() {
let filter = SyscallFilter::from_expr("trace=/^open.*/").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("openat"));
assert!(!filter.should_trace("close"));
assert!(!filter.should_trace("read"));
}
#[test]
fn test_regex_pattern_suffix() {
let filter = SyscallFilter::from_expr("trace=/.*at$/").expect("test");
assert!(filter.should_trace("openat"));
assert!(filter.should_trace("newfstatat"));
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("close"));
}
#[test]
fn test_regex_pattern_or() {
let filter = SyscallFilter::from_expr("trace=/read|write/").expect("test");
assert!(filter.should_trace("read"));
assert!(filter.should_trace("write"));
assert!(!filter.should_trace("open"));
assert!(!filter.should_trace("close"));
}
#[test]
fn test_regex_pattern_case_insensitive() {
let filter = SyscallFilter::from_expr("trace=/(?i)OPEN/").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("OPEN"));
assert!(!filter.should_trace("close"));
}
#[test]
fn test_regex_mixed_with_literal() {
let filter = SyscallFilter::from_expr("trace=/^open.*/,close").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("openat"));
assert!(filter.should_trace("close"));
assert!(!filter.should_trace("read"));
}
#[test]
fn test_regex_mixed_with_negation() {
let filter = SyscallFilter::from_expr("trace=/^open.*/,!/openat/").expect("test");
assert!(filter.should_trace("open"));
assert!(!filter.should_trace("openat")); assert!(!filter.should_trace("close"));
}
#[test]
fn test_regex_negation_pattern() {
let filter = SyscallFilter::from_expr("trace=!/close/").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(!filter.should_trace("close")); }
#[test]
fn test_regex_invalid_pattern() {
let result = SyscallFilter::from_expr("trace=/[invalid/");
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("regex") || err_msg.contains("invalid"));
}
// --- parse_regex_pattern called directly ---
#[test]
fn test_parse_regex_pattern_valid() {
let result = parse_regex_pattern("/^test.*/");
assert!(result.is_ok());
let pattern = result.expect("test");
assert!(pattern.is_some());
let regex = pattern.expect("test");
assert!(regex.is_match("test123"));
assert!(!regex.is_match("other"));
}
#[test]
fn test_parse_regex_pattern_not_regex() {
let result = parse_regex_pattern("open");
assert!(result.is_ok());
assert!(result.expect("test").is_none());
}
#[test]
fn test_parse_regex_pattern_empty() {
let result = parse_regex_pattern("//");
assert!(result.is_ok());
let pattern = result.expect("test");
assert!(pattern.is_some());
}
// Regex entries combined with class entries, as includes and as exclusions.
#[test]
fn test_regex_with_syscall_class() {
let filter = SyscallFilter::from_expr("trace=file,/socket|connect/").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(filter.should_trace("socket"));
assert!(filter.should_trace("connect"));
assert!(!filter.should_trace("fork"));
}
#[test]
fn test_regex_exclude_with_include_class() {
let filter = SyscallFilter::from_expr("trace=file,!/.*at$/").expect("test");
assert!(filter.should_trace("open"));
assert!(filter.should_trace("read"));
assert!(!filter.should_trace("openat")); assert!(!filter.should_trace("newfstatat")); }
}