use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::{Arc, RwLock};
/// Coarse category attached to a [`SyscallEvent`].
///
/// The mapping from concrete syscalls to a category is performed by whatever
/// code constructs the events; it is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyscallCategory {
/// File category (presumably filesystem syscalls — confirm at the event producer).
File,
/// Network category (presumably socket/connect-style syscalls — confirm at the event producer).
Network,
/// Process category (the unit test in this file tags an `execve` event with it).
Process,
/// Memory category (presumably allocation/mapping syscalls — confirm at the event producer).
Memory,
}
/// Description of a single syscall, passed by value to a [`PolicyCallback`].
///
/// All `Option` fields may be absent; which syscalls populate which fields is
/// decided by the event producer and is not visible in this file.
#[derive(Debug, Clone)]
pub struct SyscallEvent {
/// Name of the syscall (the unit test in this file uses `"execve"`).
pub syscall: String,
/// Coarse category of the syscall.
pub category: SyscallCategory,
/// Process id associated with the event.
pub pid: u32,
/// Parent process id, when known.
pub parent_pid: Option<u32>,
/// Path argument, when present; searched by [`SyscallEvent::path_contains`].
pub path: Option<String>,
/// IP address associated with the event, when present
/// (presumably the remote peer of a network syscall — TODO confirm).
pub host: Option<IpAddr>,
/// Port associated with the event, when present
/// (presumably paired with `host` — TODO confirm).
pub port: Option<u16>,
/// Size associated with the event, when present
/// (units and meaning are not established in this file — TODO confirm).
pub size: Option<u64>,
/// Argument vector, when present; searched by [`SyscallEvent::argv_contains`].
pub argv: Option<Vec<String>>,
/// Denial flag set by the producer.
/// NOTE(review): looks like "this syscall was already denied" — confirm with the producer.
pub denied: bool,
}
impl SyscallEvent {
    /// Reports whether the event carries a path that contains `s` as a substring.
    ///
    /// Returns `false` when the event has no path at all.
    pub fn path_contains(&self, s: &str) -> bool {
        match self.path.as_deref() {
            Some(path) => path.contains(s),
            None => false,
        }
    }

    /// Reports whether any single element of `argv` contains `s` as a substring.
    ///
    /// Each argument is checked on its own (matches never span two arguments).
    /// Returns `false` when the event has no argument vector.
    pub fn argv_contains(&self, s: &str) -> bool {
        if let Some(arguments) = &self.argv {
            arguments.iter().any(|argument| argument.contains(s))
        } else {
            false
        }
    }
}
/// The adjustable part of a policy: the values that [`PolicyContext`] reads and
/// rewrites through its `grant_*` / `restrict_*` methods.
///
/// The same type is also used as the "ceiling" that grants are capped against.
#[derive(Debug, Clone)]
pub struct LivePolicy {
/// Set of IP addresses in the policy
/// (presumably the destinations network syscalls may reach — confirm at the enforcement site).
pub allowed_ips: HashSet<IpAddr>,
/// Memory limit in bytes.
/// NOTE(review): whether `0` means "no memory" or "unlimited" is not established in this file.
pub max_memory_bytes: u64,
/// Process-count limit.
/// NOTE(review): whether `0` means "none" or "unlimited" is not established in this file.
pub max_processes: u32,
}
/// Handle given to a policy callback for inspecting and adjusting the policy.
///
/// The `Arc<RwLock<..>>` fields are shared with whoever supplied them to
/// `PolicyContext::new`, so changes made through this context are visible to
/// the other holders of those handles.
pub struct PolicyContext {
// Shared live policy; read by `current` and rewritten by `grant_*` / `restrict_*`.
live: Arc<RwLock<LivePolicy>>,
// Upper bound that the `grant_*` methods cap against; never modified here.
ceiling: LivePolicy,
// Names of the `LivePolicy` fields on which a `restrict_*` method has been
// called; `grant_*` on such a field returns `PolicyFnError::FieldRestricted`.
restricted: HashSet<&'static str>,
// Per-pid IP sets written by `restrict_pid_network` and removed by `clear_pid_override`.
pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
// Path strings added by `deny_path` and removed by `allow_path`.
denied_paths: Arc<RwLock<HashSet<String>>>,
}
impl PolicyContext {
    /// Builds a context over the given shared state with no restricted fields.
    ///
    /// `live`, `pid_overrides` and `denied_paths` are shared handles; `ceiling`
    /// is owned by the context and used as the cap for the `grant_*` methods.
    pub(crate) fn new(
        live: Arc<RwLock<LivePolicy>>,
        ceiling: LivePolicy,
        pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
        denied_paths: Arc<RwLock<HashSet<String>>>,
    ) -> Self {
        let restricted = HashSet::new();
        Self {
            live,
            ceiling,
            restricted,
            pid_overrides,
            denied_paths,
        }
    }

    /// Returns a snapshot (clone) of the live policy.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn current(&self) -> LivePolicy {
        let snapshot = self.live.read().unwrap();
        (*snapshot).clone()
    }

    /// Returns the ceiling that grants are capped against.
    pub fn ceiling(&self) -> &LivePolicy {
        let upper_bound = &self.ceiling;
        upper_bound
    }

    /// Adds to the live allow-list every address in `ips` that is also present
    /// in the ceiling's allow-list; addresses outside the ceiling are skipped
    /// silently.
    ///
    /// # Errors
    /// Returns [`PolicyFnError::FieldRestricted`] if `restrict_network` was
    /// called on this context before.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn grant_network(&mut self, ips: &[IpAddr]) -> Result<(), PolicyFnError> {
        self.check_not_restricted("allowed_ips")?;
        let permitted = &self.ceiling.allowed_ips;
        let mut policy = self.live.write().unwrap();
        policy
            .allowed_ips
            .extend(ips.iter().copied().filter(|candidate| permitted.contains(candidate)));
        Ok(())
    }

    /// Sets the live memory limit to `bytes`, capped at the ceiling's limit.
    ///
    /// # Errors
    /// Returns [`PolicyFnError::FieldRestricted`] if `restrict_max_memory` was
    /// called on this context before.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn grant_max_memory(&mut self, bytes: u64) -> Result<(), PolicyFnError> {
        self.check_not_restricted("max_memory_bytes")?;
        let capped = std::cmp::min(bytes, self.ceiling.max_memory_bytes);
        self.live.write().unwrap().max_memory_bytes = capped;
        Ok(())
    }

    /// Sets the live process limit to `n`, capped at the ceiling's limit.
    ///
    /// # Errors
    /// Returns [`PolicyFnError::FieldRestricted`] if `restrict_max_processes`
    /// was called on this context before.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn grant_max_processes(&mut self, n: u32) -> Result<(), PolicyFnError> {
        self.check_not_restricted("max_processes")?;
        let capped = std::cmp::min(n, self.ceiling.max_processes);
        self.live.write().unwrap().max_processes = capped;
        Ok(())
    }

    /// Replaces the live allow-list with exactly `ips` and marks the field as
    /// restricted, so later `grant_network` calls on this context fail.
    ///
    /// The new set is stored as given; it is not compared with the ceiling or
    /// with the previous allow-list.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn restrict_network(&mut self, ips: &[IpAddr]) {
        self.restricted.insert("allowed_ips");
        let replacement: HashSet<IpAddr> = ips.iter().copied().collect();
        self.live.write().unwrap().allowed_ips = replacement;
    }

    /// Sets the live memory limit to `bytes` and marks the field as restricted,
    /// so later `grant_max_memory` calls on this context fail.
    ///
    /// The value is stored as given; it is not compared with the ceiling or
    /// with the previous limit.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn restrict_max_memory(&mut self, bytes: u64) {
        self.restricted.insert("max_memory_bytes");
        self.live.write().unwrap().max_memory_bytes = bytes;
    }

    /// Sets the live process limit to `n` and marks the field as restricted,
    /// so later `grant_max_processes` calls on this context fail.
    ///
    /// The value is stored as given; it is not compared with the ceiling or
    /// with the previous limit.
    ///
    /// # Panics
    /// Panics if the live-policy lock is poisoned.
    pub fn restrict_max_processes(&mut self, n: u32) {
        self.restricted.insert("max_processes");
        self.live.write().unwrap().max_processes = n;
    }

    /// Records `ips` as the IP set for `pid`, replacing any earlier entry for
    /// that pid.
    ///
    /// # Panics
    /// Panics if the override-map lock is poisoned.
    pub fn restrict_pid_network(&self, pid: u32, ips: &[IpAddr]) {
        let per_pid: HashSet<IpAddr> = ips.iter().copied().collect();
        self.pid_overrides.write().unwrap().insert(pid, per_pid);
    }

    /// Removes the IP-set entry for `pid`, if one exists.
    ///
    /// # Panics
    /// Panics if the override-map lock is poisoned.
    pub fn clear_pid_override(&self, pid: u32) {
        self.pid_overrides.write().unwrap().remove(&pid);
    }

    /// Adds `path` to the shared denied-path set.
    ///
    /// # Panics
    /// Panics if the denied-path lock is poisoned.
    pub fn deny_path(&self, path: &str) {
        self.denied_paths.write().unwrap().insert(path.to_owned());
    }

    /// Removes `path` from the shared denied-path set (exact string match).
    ///
    /// # Panics
    /// Panics if the denied-path lock is poisoned.
    pub fn allow_path(&self, path: &str) {
        self.denied_paths.write().unwrap().remove(path);
    }

    /// Guard used by the `grant_*` methods: fails when `field` has been
    /// restricted on this context.
    fn check_not_restricted(&self, field: &str) -> Result<(), PolicyFnError> {
        if !self.restricted.contains(field) {
            return Ok(());
        }
        Err(PolicyFnError::FieldRestricted(field.to_string()))
    }
}
/// Errors returned by the `grant_*` methods of [`PolicyContext`].
#[derive(Debug, thiserror::Error)]
pub enum PolicyFnError {
/// A grant was attempted on a field that was previously restricted on the
/// same context. The payload is the field name as passed to the internal
/// check (`"allowed_ips"`, `"max_memory_bytes"` or `"max_processes"`).
#[error("cannot grant restricted field: {0}")]
FieldRestricted(String),
}
/// Decision produced by a policy callback for one [`SyscallEvent`].
///
/// The default verdict is [`Verdict::Allow`]. How each verdict is enforced is
/// decided by the code that receives it and is not visible in this file.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Verdict {
    /// Allow verdict; this is the value returned by `Verdict::default()`.
    #[default]
    Allow,
    /// Audit verdict (presumably "let it through but record it" — confirm at the consumer).
    Audit,
    /// Deny verdict.
    Deny,
    /// Deny verdict carrying an `i32`
    /// (presumably the errno to report to the caller — confirm at the consumer).
    DenyWith(i32),
}
/// Shared, thread-safe policy callback.
///
/// It receives each [`SyscallEvent`] by value together with a mutable
/// [`PolicyContext`] and returns the [`Verdict`] for that event.
pub type PolicyCallback = Arc<dyn Fn(SyscallEvent, &mut PolicyContext) -> Verdict + Send + Sync + 'static>;
/// Message sent over the channel returned by `spawn_policy_fn`.
pub struct PolicyEvent {
/// The syscall description handed to the callback.
pub event: SyscallEvent,
/// Optional reply channel. When it is `Some`, the policy thread sends the
/// callback's verdict through it; when it is `None`, the verdict is discarded.
pub gate: Option<tokio::sync::oneshot::Sender<Verdict>>,
}
pub(crate) fn spawn_policy_fn(
callback: PolicyCallback,
live: Arc<RwLock<LivePolicy>>,
ceiling: LivePolicy,
pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
denied_paths: Arc<RwLock<HashSet<String>>>,
) -> tokio::sync::mpsc::UnboundedSender<PolicyEvent> {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<PolicyEvent>();
std::thread::Builder::new()
.name("sandlock-policy-fn".to_string())
.spawn(move || {
let mut ctx = PolicyContext::new(live, ceiling, pid_overrides, denied_paths);
while let Some(pe) = rx.blocking_recv() {
let verdict = callback(pe.event, &mut ctx);
if let Some(gate) = pe.gate {
let _ = gate.send(verdict);
}
}
})
.expect("failed to spawn policy-fn thread");
tx
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A context under test plus the shared handles the assertions inspect.
    struct Fixture {
        ctx: PolicyContext,
        live: Arc<RwLock<LivePolicy>>,
        pid_overrides: Arc<RwLock<HashMap<u32, HashSet<IpAddr>>>>,
    }

    /// Parses an IP literal used by the tests.
    fn ip(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Policy used both as the ceiling and as the "fully granted" live state.
    fn test_live() -> LivePolicy {
        LivePolicy {
            allowed_ips: [ip("127.0.0.1"), ip("10.0.0.1")].iter().copied().collect(),
            max_memory_bytes: 1024 * 1024 * 1024,
            max_processes: 64,
        }
    }

    /// Live policy with nothing granted yet.
    fn empty_live() -> LivePolicy {
        LivePolicy {
            allowed_ips: HashSet::new(),
            max_memory_bytes: 0,
            max_processes: 0,
        }
    }

    /// Builds a context whose live policy starts as `initial` and whose
    /// ceiling is `test_live()`.
    fn fixture(initial: LivePolicy) -> Fixture {
        let live = Arc::new(RwLock::new(initial));
        let pid_overrides = Arc::new(RwLock::new(HashMap::new()));
        let denied_paths = Arc::new(RwLock::new(HashSet::new()));
        let ctx = PolicyContext::new(
            Arc::clone(&live),
            test_live(),
            Arc::clone(&pid_overrides),
            denied_paths,
        );
        Fixture {
            ctx,
            live,
            pid_overrides,
        }
    }

    #[test]
    fn test_grant_within_ceiling() {
        let Fixture { mut ctx, live, .. } = fixture(empty_live());
        let loopback = ip("127.0.0.1");
        ctx.grant_network(&[loopback]).unwrap();
        assert!(live.read().unwrap().allowed_ips.contains(&loopback));
    }

    #[test]
    fn test_grant_capped_to_ceiling() {
        let Fixture { mut ctx, live, .. } = fixture(empty_live());
        let foreign = ip("8.8.8.8");
        ctx.grant_network(&[foreign]).unwrap();
        assert!(!live.read().unwrap().allowed_ips.contains(&foreign));
    }

    #[test]
    fn test_restrict_then_grant_fails() {
        let Fixture { mut ctx, .. } = fixture(test_live());
        ctx.restrict_network(&[]);
        let loopback = ip("127.0.0.1");
        assert!(ctx.grant_network(&[loopback]).is_err());
    }

    #[test]
    fn test_restrict_max_memory() {
        let Fixture { mut ctx, live, .. } = fixture(test_live());
        let limit = 256 * 1024 * 1024;
        ctx.restrict_max_memory(limit);
        assert_eq!(live.read().unwrap().max_memory_bytes, limit);
    }

    #[test]
    fn test_pid_override() {
        let Fixture {
            ctx, pid_overrides, ..
        } = fixture(test_live());
        let loopback = ip("127.0.0.1");
        ctx.restrict_pid_network(1234, &[loopback]);
        let overrides = pid_overrides.read().unwrap();
        let pid_ips = overrides.get(&1234).unwrap();
        assert!(pid_ips.contains(&loopback));
        assert_eq!(pid_ips.len(), 1);
    }

    #[test]
    fn test_clear_pid_override() {
        let Fixture {
            ctx, pid_overrides, ..
        } = fixture(test_live());
        ctx.restrict_pid_network(1234, &[ip("127.0.0.1")]);
        ctx.clear_pid_override(1234);
        assert!(!pid_overrides.read().unwrap().contains_key(&1234));
    }

    #[test]
    fn test_event_path_contains() {
        let argv: Vec<String> = ["python3", "-c", "print(1)"]
            .iter()
            .map(|arg| arg.to_string())
            .collect();
        let event = SyscallEvent {
            syscall: String::from("execve"),
            category: SyscallCategory::Process,
            pid: 1,
            parent_pid: Some(0),
            path: Some(String::from("/usr/bin/python3")),
            host: None,
            port: None,
            size: None,
            argv: Some(argv),
            denied: false,
        };
        assert!(event.argv_contains("python3"));
        assert!(event.argv_contains("-c"));
        assert!(!event.argv_contains("ruby"));
        assert_eq!(event.category, SyscallCategory::Process);
        assert!(event.path_contains("python"));
        assert!(!event.path_contains("ruby"));
    }
}