Skip to main content

arbiter_audit/
middleware.rs

1//! Audit capture middleware: wraps a proxied request with timing and context.
2//!
3//! The proxy creates an [`AuditCapture`] at the start of each request, fills in
4//! context as it becomes available, then finalizes the capture after the upstream
5//! response. The resulting [`AuditEntry`] is written to the configured sink.
6
7use std::time::Instant;
8
9use uuid::Uuid;
10
11use std::sync::Arc;
12
13use crate::entry::AuditEntry;
14use crate::redaction::{CompiledRedaction, RedactionConfig};
15
16/// Captures audit context across a single proxied request lifecycle.
17///
18/// # Usage
19///
20/// ```ignore
21/// let compiled = redaction_config.compile();
22/// let mut capture = AuditCapture::begin_compiled(Arc::new(compiled));
23/// capture.set_agent_id("agent-1");
24/// capture.set_tool_called("/tools/call");
25/// // … proxy the request …
26/// let entry = capture.finalize(Some(200));
27/// audit_sink.write(&entry).await?;
28/// ```
29pub struct AuditCapture {
30    start: Instant,
31    entry: AuditEntry,
32    compiled_redaction: Arc<CompiledRedaction>,
33}
34
35impl AuditCapture {
36    /// Begin a new audit capture with a fresh request ID.
37    ///
38    /// Compiles redaction patterns on each call. For hot-path usage, prefer
39    /// [`begin_compiled`](Self::begin_compiled) with a pre-compiled config.
40    pub fn begin(redaction_config: RedactionConfig) -> Self {
41        Self {
42            start: Instant::now(),
43            entry: AuditEntry::new(Uuid::new_v4()),
44            compiled_redaction: Arc::new(redaction_config.compile()),
45        }
46    }
47
48    /// Begin a new audit capture with pre-compiled redaction patterns.
49    pub fn begin_compiled(compiled: Arc<CompiledRedaction>) -> Self {
50        Self {
51            start: Instant::now(),
52            entry: AuditEntry::new(Uuid::new_v4()),
53            compiled_redaction: compiled,
54        }
55    }
56
57    /// Begin a new audit capture with a caller-supplied request ID.
58    pub fn begin_with_id(request_id: Uuid, redaction_config: RedactionConfig) -> Self {
59        Self {
60            start: Instant::now(),
61            entry: AuditEntry::new(request_id),
62            compiled_redaction: Arc::new(redaction_config.compile()),
63        }
64    }
65
66    /// Begin a new audit capture with a caller-supplied request ID and
67    /// pre-compiled redaction patterns.
68    pub fn begin_with_id_compiled(request_id: Uuid, compiled: Arc<CompiledRedaction>) -> Self {
69        Self {
70            start: Instant::now(),
71            entry: AuditEntry::new(request_id),
72            compiled_redaction: compiled,
73        }
74    }
75
76    pub fn set_agent_id(&mut self, agent_id: impl Into<String>) {
77        self.entry.agent_id = agent_id.into();
78    }
79
80    pub fn set_delegation_chain(&mut self, chain: impl Into<String>) {
81        self.entry.delegation_chain = chain.into();
82    }
83
84    pub fn set_task_session_id(&mut self, session_id: impl Into<String>) {
85        self.entry.task_session_id = session_id.into();
86    }
87
88    pub fn set_tool_called(&mut self, tool: impl Into<String>) {
89        self.entry.tool_called = tool.into();
90    }
91
92    pub fn set_arguments(&mut self, args: serde_json::Value) {
93        self.entry.arguments = args;
94    }
95
96    pub fn set_authorization_decision(&mut self, decision: impl Into<String>) {
97        self.entry.authorization_decision = decision.into();
98    }
99
100    pub fn set_policy_matched(&mut self, policy: impl Into<String>) {
101        self.entry.policy_matched = Some(policy.into());
102    }
103
104    pub fn set_anomaly_flags(&mut self, flags: Vec<String>) {
105        self.entry.anomaly_flags = flags;
106    }
107
108    pub fn set_failure_category(&mut self, category: impl Into<String>) {
109        self.entry.failure_category = Some(category.into());
110    }
111
112    pub fn add_inspection_findings(&mut self, findings: Vec<String>) {
113        self.entry.inspection_findings = findings;
114    }
115
116    /// Finalize the capture: compute latency, apply redaction, return the entry.
117    pub fn finalize(mut self, upstream_status: Option<u16>) -> AuditEntry {
118        self.entry.latency_ms = self.start.elapsed().as_millis() as u64;
119        self.entry.upstream_status = upstream_status;
120        self.entry.arguments = self.compiled_redaction.redact(&self.entry.arguments);
121        self.entry
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Latency is measured from construction to `finalize`, and the upstream
    /// status is passed through unchanged.
    #[test]
    fn captures_latency() {
        let capture = AuditCapture::begin(RedactionConfig::default());
        // Simulate some work.
        std::thread::sleep(std::time::Duration::from_millis(5));
        let entry = capture.finalize(Some(200));

        assert!(entry.latency_ms >= 5, "latency should be at least 5ms");
        assert_eq!(entry.upstream_status, Some(200));
    }

    /// Sensitive argument values are redacted at `finalize`; other fields and
    /// non-sensitive values are left alone.
    #[test]
    fn redacts_arguments_on_finalize() {
        let mut capture = AuditCapture::begin(RedactionConfig::default());
        capture.set_arguments(json!({
            "path": "/etc/hosts",
            "api_key": "sk-secret-123"
        }));
        capture.set_tool_called("read_file");

        let entry = capture.finalize(Some(200));

        assert_eq!(entry.arguments["path"], "/etc/hosts");
        assert_eq!(entry.arguments["api_key"], "[REDACTED]");
        assert_eq!(entry.tool_called, "read_file");
    }

    /// Every setter on `AuditCapture` lands in the corresponding entry field.
    #[test]
    fn sets_all_fields() {
        let id = Uuid::new_v4();
        let mut capture = AuditCapture::begin_with_id(id, RedactionConfig { patterns: vec![] });
        capture.set_agent_id("agent-42");
        capture.set_delegation_chain("human>agent-42");
        capture.set_task_session_id("session-abc");
        capture.set_tool_called("write_file");
        capture.set_authorization_decision("allow");
        capture.set_policy_matched("policy-write");
        capture.set_anomaly_flags(vec!["high_frequency".into()]);
        capture.set_failure_category("upstream_error");
        capture.add_inspection_findings(vec!["suspicious_path".into()]);
        capture.set_arguments(json!({"content": "hello"}));

        let entry = capture.finalize(Some(201));

        assert_eq!(entry.request_id, id);
        assert_eq!(entry.agent_id, "agent-42");
        assert_eq!(entry.delegation_chain, "human>agent-42");
        assert_eq!(entry.task_session_id, "session-abc");
        assert_eq!(entry.tool_called, "write_file");
        assert_eq!(entry.authorization_decision, "allow");
        assert_eq!(entry.policy_matched, Some("policy-write".into()));
        assert_eq!(entry.anomaly_flags, vec!["high_frequency"]);
        assert_eq!(entry.failure_category, Some("upstream_error".into()));
        assert_eq!(entry.inspection_findings, vec!["suspicious_path"]);
        assert_eq!(entry.upstream_status, Some(201));
        assert_eq!(entry.arguments["content"], "hello");
    }

    /// The pre-compiled constructors share one compiled redaction set, honour
    /// a caller-supplied ID, and otherwise mint a fresh one.
    #[test]
    fn precompiled_constructors_share_redaction() {
        let compiled = Arc::new(RedactionConfig::default().compile());
        let id = Uuid::new_v4();

        let mut pinned = AuditCapture::begin_with_id_compiled(id, Arc::clone(&compiled));
        pinned.set_arguments(json!({"api_key": "sk-secret-123"}));
        let mut fresh = AuditCapture::begin_compiled(compiled);
        fresh.set_arguments(json!({"api_key": "sk-secret-456"}));

        let pinned = pinned.finalize(None);
        let fresh = fresh.finalize(Some(500));

        assert_eq!(pinned.request_id, id);
        assert_eq!(pinned.upstream_status, None);
        assert_eq!(pinned.arguments["api_key"], "[REDACTED]");
        assert_ne!(fresh.request_id, id);
        assert_eq!(fresh.upstream_status, Some(500));
        assert_eq!(fresh.arguments["api_key"], "[REDACTED]");
    }
}