Skip to main content

arbiter_audit/
middleware.rs

//! Audit capture middleware: wraps a proxied request with timing and context.
//!
//! The proxy creates an [`AuditCapture`] at the start of each request, fills in
//! context as it becomes available, then finalizes the capture after the upstream
//! response. The resulting [`AuditEntry`] is written to the configured sink.
6
7use std::time::Instant;
8
9use uuid::Uuid;
10
11use std::sync::Arc;
12
13use crate::entry::AuditEntry;
14use crate::redaction::{CompiledRedaction, RedactionConfig};
15
16/// Captures audit context across a single proxied request lifecycle.
17///
18/// # Usage
19///
20/// ```ignore
21/// let compiled = redaction_config.compile();
22/// let mut capture = AuditCapture::begin_compiled(Arc::new(compiled));
23/// capture.set_agent_id("agent-1");
24/// capture.set_tool_called("/tools/call");
25/// // … proxy the request …
26/// let entry = capture.finalize(Some(200));
27/// audit_sink.write(&entry).await?;
28/// ```
29pub struct AuditCapture {
30    start: Instant,
31    entry: AuditEntry,
32    compiled_redaction: Arc<CompiledRedaction>,
33}
34
35impl AuditCapture {
36    /// Begin a new audit capture with a fresh request ID.
37    ///
38    /// Compiles redaction patterns on each call. For hot-path usage, prefer
39    /// [`begin_compiled`](Self::begin_compiled) with a pre-compiled config.
40    pub fn begin(redaction_config: RedactionConfig) -> Self {
41        Self {
42            start: Instant::now(),
43            entry: AuditEntry::new(Uuid::new_v4()),
44            compiled_redaction: Arc::new(redaction_config.compile()),
45        }
46    }
47
48    /// Begin a new audit capture with pre-compiled redaction patterns.
49    pub fn begin_compiled(compiled: Arc<CompiledRedaction>) -> Self {
50        Self {
51            start: Instant::now(),
52            entry: AuditEntry::new(Uuid::new_v4()),
53            compiled_redaction: compiled,
54        }
55    }
56
57    /// Begin a new audit capture with a caller-supplied request ID.
58    pub fn begin_with_id(request_id: Uuid, redaction_config: RedactionConfig) -> Self {
59        Self {
60            start: Instant::now(),
61            entry: AuditEntry::new(request_id),
62            compiled_redaction: Arc::new(redaction_config.compile()),
63        }
64    }
65
66    /// Begin a new audit capture with a caller-supplied request ID and
67    /// pre-compiled redaction patterns.
68    pub fn begin_with_id_compiled(request_id: Uuid, compiled: Arc<CompiledRedaction>) -> Self {
69        Self {
70            start: Instant::now(),
71            entry: AuditEntry::new(request_id),
72            compiled_redaction: compiled,
73        }
74    }
75
76    pub fn set_agent_id(&mut self, agent_id: impl Into<String>) {
77        self.entry.agent_id = agent_id.into();
78    }
79
80    pub fn set_delegation_chain(&mut self, chain: impl Into<String>) {
81        self.entry.delegation_chain = chain.into();
82    }
83
84    pub fn set_task_session_id(&mut self, session_id: impl Into<String>) {
85        self.entry.task_session_id = session_id.into();
86    }
87
88    pub fn set_tool_called(&mut self, tool: impl Into<String>) {
89        self.entry.tool_called = tool.into();
90    }
91
92    pub fn set_arguments(&mut self, args: serde_json::Value) {
93        self.entry.arguments = args;
94    }
95
96    /// Valid authorization decision values. Callers must use one of these.
97    const VALID_DECISIONS: &'static [&'static str] = &["allow", "deny", "escalate"];
98
99    pub fn set_authorization_decision(&mut self, decision: impl Into<String>) {
100        let decision = decision.into();
101        if !Self::VALID_DECISIONS.contains(&decision.as_str()) {
102            tracing::warn!(
103                decision = %decision,
104                "invalid authorization_decision value; expected one of: allow, deny, escalate"
105            );
106        }
107        self.entry.authorization_decision = decision;
108    }
109
110    pub fn set_policy_matched(&mut self, policy: impl Into<String>) {
111        self.entry.policy_matched = Some(policy.into());
112    }
113
114    pub fn set_anomaly_flags(&mut self, flags: Vec<String>) {
115        self.entry.anomaly_flags = flags;
116    }
117
118    pub fn set_failure_category(&mut self, category: impl Into<String>) {
119        self.entry.failure_category = Some(category.into());
120    }
121
122    pub fn add_inspection_findings(&mut self, findings: Vec<String>) {
123        self.entry.inspection_findings = findings;
124    }
125
126    /// Finalize the capture: compute latency, apply redaction, return the entry.
127    pub fn finalize(mut self, upstream_status: Option<u16>) -> AuditEntry {
128        self.entry.latency_ms = self.start.elapsed().as_millis() as u64;
129        self.entry.upstream_status = upstream_status;
130        self.entry.arguments = self.compiled_redaction.redact(&self.entry.arguments);
131        self.entry
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Latency covers the time between `begin` and `finalize`, and the
    /// upstream status is passed through unchanged.
    #[test]
    fn captures_latency() {
        let capture = AuditCapture::begin(RedactionConfig::default());
        // Simulate some work.
        std::thread::sleep(std::time::Duration::from_millis(5));
        let entry = capture.finalize(Some(200));

        assert!(entry.latency_ms >= 5, "latency should be at least 5ms");
        assert_eq!(entry.upstream_status, Some(200));
    }

    /// With the default redaction config, sensitive argument values are
    /// masked on finalize while other values are left alone.
    #[test]
    fn redacts_arguments_on_finalize() {
        let mut capture = AuditCapture::begin(RedactionConfig::default());
        capture.set_arguments(json!({
            "path": "/etc/hosts",
            "api_key": "sk-secret-123"
        }));
        capture.set_tool_called("read_file");

        let entry = capture.finalize(Some(200));

        assert_eq!(entry.arguments["path"], "/etc/hosts");
        assert_eq!(entry.arguments["api_key"], "[REDACTED]");
        assert_eq!(entry.tool_called, "read_file");
    }

    /// The pre-compiled constructor keeps the supplied request ID and applies
    /// the shared redaction rules just like the compiling constructors do.
    #[test]
    fn compiled_constructor_redacts_and_keeps_id() {
        let id = Uuid::new_v4();
        let compiled = Arc::new(RedactionConfig::default().compile());
        let mut capture = AuditCapture::begin_with_id_compiled(id, compiled);
        capture.set_arguments(json!({
            "path": "/etc/hosts",
            "api_key": "sk-secret-123"
        }));

        let entry = capture.finalize(None);

        assert_eq!(entry.request_id, id);
        assert_eq!(entry.upstream_status, None);
        assert_eq!(entry.arguments["path"], "/etc/hosts");
        assert_eq!(entry.arguments["api_key"], "[REDACTED]");
    }

    /// Every setter's value must show up on the finalized entry.
    #[test]
    fn sets_all_fields() {
        let id = Uuid::new_v4();
        let mut capture = AuditCapture::begin_with_id(id, RedactionConfig { patterns: vec![] });
        capture.set_agent_id("agent-42");
        capture.set_delegation_chain("human>agent-42");
        capture.set_task_session_id("session-abc");
        capture.set_tool_called("write_file");
        capture.set_authorization_decision("allow");
        capture.set_policy_matched("policy-write");
        capture.set_anomaly_flags(vec!["high_frequency".into()]);
        capture.set_failure_category("upstream_error");
        capture.add_inspection_findings(vec!["prompt_injection".into()]);
        capture.set_arguments(json!({"content": "hello"}));

        let entry = capture.finalize(Some(201));

        assert_eq!(entry.request_id, id);
        assert_eq!(entry.agent_id, "agent-42");
        assert_eq!(entry.delegation_chain, "human>agent-42");
        assert_eq!(entry.task_session_id, "session-abc");
        assert_eq!(entry.tool_called, "write_file");
        assert_eq!(entry.authorization_decision, "allow");
        assert_eq!(entry.policy_matched, Some("policy-write".into()));
        assert_eq!(entry.anomaly_flags, vec!["high_frequency"]);
        assert_eq!(entry.failure_category, Some("upstream_error".into()));
        assert_eq!(entry.inspection_findings, vec!["prompt_injection"]);
        assert_eq!(entry.upstream_status, Some(201));
        assert_eq!(entry.arguments["content"], "hello");
    }
}