use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::SystemTime;
/// Maximum number of violations kept by a store built with `SandboxViolationStore::new`.
///
/// Once this many entries are held, recording another one evicts the oldest.
pub const DEFAULT_RING_CAPACITY: usize = 100;
/// Category of a denied sandbox operation.
///
/// Serialized with serde's `snake_case` renaming (e.g. `FileRead` as `"file_read"`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ViolationKind {
    /// Denied file read; rendered as `deny file-read*`.
    FileRead,
    /// Denied file write; rendered as `deny file-write*`.
    FileWrite,
    /// Denied outbound network access; rendered as `deny network-outbound`.
    NetworkOutbound,
    /// Denied process execution; rendered as `deny process-exec*`.
    ProcessExec,
    /// Any other denied operation; rendered as a bare `deny`.
    Other,
}
/// A single denied sandbox operation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct Violation {
    /// Which category of operation was denied.
    pub kind: ViolationKind,
    /// The denied target, when known (e.g. a path like `/tmp/x` or a `host:port`).
    pub target: Option<String>,
    /// The command that triggered the denial, when known.
    pub command: Option<String>,
    /// Timestamp of the violation (`Violation::new` uses the current time).
    /// Serialized as fractional seconds since the Unix epoch.
    #[serde(with = "system_time_serde")]
    pub at: SystemTime,
}
impl Violation {
    /// Creates a violation timestamped with the current wall-clock time.
    #[must_use]
    pub fn new(kind: ViolationKind, target: Option<String>, command: Option<String>) -> Self {
        Self {
            kind,
            target,
            command,
            at: SystemTime::now(),
        }
    }

    /// Renders this violation as one `deny <operation> [<target>]` line.
    ///
    /// Control characters and the Unicode line/paragraph separators in the
    /// target are replaced by their escaped form (a newline becomes the two
    /// characters `\n`). A malformed or hostile target therefore cannot split
    /// the line or forge extra entries when lines are joined into a block.
    /// Targets without such characters are rendered verbatim.
    #[must_use]
    pub fn render_line(&self) -> String {
        let kind = match self.kind {
            ViolationKind::FileRead => "deny file-read*",
            ViolationKind::FileWrite => "deny file-write*",
            ViolationKind::NetworkOutbound => "deny network-outbound",
            ViolationKind::ProcessExec => "deny process-exec*",
            ViolationKind::Other => "deny",
        };
        match self.target.as_deref() {
            Some(t) => format!("{kind} {}", Self::escape_control_chars(t)),
            None => kind.to_string(),
        }
    }

    /// Returns `s` borrowed when it is already safe for single-line output,
    /// otherwise an owned copy with every offending character escaped.
    fn escape_control_chars(s: &str) -> std::borrow::Cow<'_, str> {
        fn needs_escape(c: char) -> bool {
            c.is_control() || matches!(c, '\u{2028}' | '\u{2029}')
        }

        if !s.chars().any(needs_escape) {
            return std::borrow::Cow::Borrowed(s);
        }
        let mut out = String::with_capacity(s.len() + 8);
        for c in s.chars() {
            if needs_escape(c) {
                out.extend(c.escape_default());
            } else {
                out.push(c);
            }
        }
        std::borrow::Cow::Owned(out)
    }
}
/// Bounded, thread-safe ring buffer of recent [`Violation`]s.
///
/// The buffer lives behind an `Arc<Mutex<…>>`, so clones share the same
/// storage. Once `capacity` entries are held, recording evicts the oldest one.
#[derive(Debug, Clone)]
pub struct SandboxViolationStore {
    inner: Arc<Mutex<VecDeque<Violation>>>,
    capacity: usize,
}

impl SandboxViolationStore {
    /// Creates a store holding at most [`DEFAULT_RING_CAPACITY`] violations.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_RING_CAPACITY)
    }

    /// Creates a store holding at most `capacity` violations.
    ///
    /// A capacity of `0` yields a store that discards everything. At most
    /// [`DEFAULT_RING_CAPACITY`] slots are allocated up front and the buffer
    /// grows on demand, so very large capacities (even `usize::MAX`) neither
    /// panic nor allocate eagerly.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Arc::new(Mutex::new(VecDeque::with_capacity(
                capacity.min(DEFAULT_RING_CAPACITY),
            ))),
            capacity,
        }
    }

    /// Locks the buffer, recovering it if a previous holder panicked.
    ///
    /// No critical section in this type can leave the deque half-updated, so
    /// a poisoned lock still guards a valid buffer. Recovering avoids silently
    /// dropping or hiding violations after an unrelated panic.
    fn buffer(&self) -> std::sync::MutexGuard<'_, VecDeque<Violation>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Appends `v`, evicting the oldest entry when the store is full.
    ///
    /// Does nothing for a store created with capacity `0`.
    pub fn record(&self, v: Violation) {
        if self.capacity == 0 {
            return;
        }
        let mut buf = self.buffer();
        if buf.len() >= self.capacity {
            buf.pop_front();
        }
        buf.push_back(v);
    }

    /// Returns a copy of the stored violations, oldest first, leaving the store intact.
    #[must_use]
    pub fn snapshot(&self) -> Vec<Violation> {
        self.buffer().iter().cloned().collect()
    }

    /// Removes and returns all stored violations, oldest first.
    #[must_use]
    pub fn drain(&self) -> Vec<Violation> {
        self.buffer().drain(..).collect()
    }

    /// Number of violations currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.buffer().len()
    }

    /// Whether the store currently holds no violations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buffer().is_empty()
    }

    /// Tallies the stored violations per [`ViolationKind`].
    ///
    /// Kinds with no stored violations are absent from the map.
    #[must_use]
    pub fn count_by_kind(&self) -> std::collections::HashMap<ViolationKind, usize> {
        let buf = self.buffer();
        let mut counts: std::collections::HashMap<ViolationKind, usize> =
            std::collections::HashMap::new();
        for v in buf.iter() {
            *counts.entry(v.kind.clone()).or_default() += 1;
        }
        counts
    }
}

impl Default for SandboxViolationStore {
    /// Equivalent to [`SandboxViolationStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Formats `violations` as a `<sandbox_violations>` block, one line per entry.
///
/// The opening tag and each rendered line are newline-terminated, and the
/// closing tag has no trailing newline. Returns `None` for an empty slice so
/// callers can skip the block entirely.
#[must_use]
pub fn render_block(violations: &[Violation]) -> Option<String> {
    (!violations.is_empty()).then(|| {
        let lines: String = violations
            .iter()
            .map(|violation| violation.render_line() + "\n")
            .collect();
        format!("<sandbox_violations>\n{lines}</sandbox_violations>")
    })
}
/// Returns a handle to the process-wide violation store.
///
/// The store is created on first use with `SandboxViolationStore::new`. Every
/// call returns a clone of that single instance.
#[must_use]
pub fn global_store() -> SandboxViolationStore {
    static SHARED: OnceLock<SandboxViolationStore> = OnceLock::new();
    let store = SHARED.get_or_init(SandboxViolationStore::new);
    SandboxViolationStore::clone(store)
}
mod system_time_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::{SystemTime, UNIX_EPOCH};
pub fn serialize<S: Serializer>(t: &SystemTime, s: S) -> Result<S::Ok, S::Error> {
let secs = t
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs_f64())
.unwrap_or(0.0);
secs.serialize(s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<SystemTime, D::Error> {
let secs = f64::deserialize(d)?;
Ok(UNIX_EPOCH + std::time::Duration::from_secs_f64(secs))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn ring_caps_at_capacity() {
        let store = SandboxViolationStore::with_capacity(3);
        for i in 0..5 {
            store.record(Violation::new(
                ViolationKind::FileWrite,
                Some(format!("/tmp/v{i}")),
                None,
            ));
        }
        let snap = store.snapshot();
        assert_eq!(snap.len(), 3);
        assert_eq!(snap[0].target.as_deref(), Some("/tmp/v2"));
        assert_eq!(snap[2].target.as_deref(), Some("/tmp/v4"));
    }

    #[test]
    fn capacity_zero_records_nothing() {
        let store = SandboxViolationStore::with_capacity(0);
        store.record(Violation::new(ViolationKind::Other, None, None));
        assert!(store.is_empty());
    }

    #[test]
    fn drain_empties_the_buffer() {
        let store = SandboxViolationStore::new();
        store.record(Violation::new(ViolationKind::FileRead, None, None));
        store.record(Violation::new(ViolationKind::FileRead, None, None));
        let drained = store.drain();
        assert_eq!(drained.len(), 2);
        assert!(store.is_empty());
    }

    #[test]
    fn render_block_returns_none_when_empty() {
        assert!(render_block(&[]).is_none());
    }

    #[test]
    fn render_block_matches_cc_shape() {
        let v = Violation::new(
            ViolationKind::FileWrite,
            Some("/Users/me/.ssh/id_rsa".into()),
            None,
        );
        let block = render_block(&[v]).unwrap();
        assert!(block.starts_with("<sandbox_violations>\n"));
        assert!(block.contains("deny file-write* /Users/me/.ssh/id_rsa"));
        assert!(block.ends_with("</sandbox_violations>"));
    }

    #[test]
    fn store_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SandboxViolationStore>();
    }

    #[test]
    fn violation_serde_roundtrip() {
        let v = Violation::new(
            ViolationKind::NetworkOutbound,
            Some("evil.com:443".into()),
            Some("curl evil.com".into()),
        );
        let json = serde_json::to_string(&v).unwrap();
        let back: Violation = serde_json::from_str(&json).unwrap();
        assert_eq!(back.kind, v.kind);
        assert_eq!(back.target, v.target);
        assert_eq!(back.command, v.command);
        // Timestamps travel as f64 seconds, so allow sub-millisecond rounding.
        let drift = match back.at.duration_since(v.at) {
            Ok(d) => d,
            Err(e) => e.duration(),
        };
        assert!(drift < Duration::from_millis(1), "timestamp drifted by {drift:?}");
    }

    #[test]
    fn store_clones_share_inner_buffer() {
        let store = SandboxViolationStore::new();
        let twin = store.clone();
        store.record(Violation::new(ViolationKind::Other, None, None));
        assert_eq!(twin.len(), 1);
    }

    #[test]
    fn global_store_returns_same_handle_across_calls() {
        let a = global_store();
        let b = global_store();
        // Compare buffer identity instead of lengths. Other tests may use the
        // global store concurrently, and recording into a full ring keeps its
        // length constant. This also avoids mutating shared global state.
        assert!(Arc::ptr_eq(&a.inner, &b.inner));
    }

    #[test]
    fn count_by_kind_tallies_per_variant() {
        let s = SandboxViolationStore::new();
        assert!(
            s.count_by_kind().is_empty(),
            "empty store should yield empty map"
        );
        s.record(Violation::new(ViolationKind::FileRead, None, None));
        s.record(Violation::new(ViolationKind::FileRead, None, None));
        s.record(Violation::new(ViolationKind::FileWrite, None, None));
        s.record(Violation::new(ViolationKind::NetworkOutbound, None, None));
        let counts = s.count_by_kind();
        assert_eq!(counts.get(&ViolationKind::FileRead), Some(&2));
        assert_eq!(counts.get(&ViolationKind::FileWrite), Some(&1));
        assert_eq!(counts.get(&ViolationKind::NetworkOutbound), Some(&1));
        assert!(!counts.contains_key(&ViolationKind::ProcessExec));
        assert_eq!(counts.values().sum::<usize>(), s.len());
    }
}