Skip to main content

atd_runtime/
middleware.rs

1//! Result-middleware pipeline.
2//!
3//! A `Middleware` is invoked after a tool returns, with a mutable reference
4//! to the egress value. SP-12 shipped one built-in (`RedactPathsMiddleware`)
5//! to demonstrate the shape; the v3 brief's full suite (pii_redact,
6//! source_device_tag, compress, audit_log, rate_shape) is deferred.
7//!
8//! Two hooks (SP-observability-completeness-v1 Axis A):
9//! - `on_result` — the SUCCESS path, and the `ExecutionFailed` exit (whose
10//!   wire shape is a `ToolResultResponse { success: false, result }`, i.e.
11//!   a result Value).
12//! - `on_error` — the `Response::Error` path (`InvalidArgs` /
13//!   `InternalError`), whose wire shape is a bare `message: String` +
14//!   optional `details`. Default no-op; security-sensitive middleware
15//!   override it. Before this SP, error paths bypassed middleware entirely
16//!   (SP-12 §8 Q4) — that let a tool's failure text (an arg echo, a panic
17//!   message naming a patient) reach the LLM unredacted, a real PHI leak.
18
19use atd_protocol::ToolDefinition;
20
21/// A result-rewriting hook. Must be deterministic and side-effect-free
22/// beyond the `result` / error mutation + any internal audit sinks the impl
23/// owns.
24pub trait Middleware: Send + Sync {
25    fn name(&self) -> &'static str;
26
27    fn on_result(&self, tool_id: &str, tool_def: &ToolDefinition, result: &mut serde_json::Value);
28
29    /// SP-observability-completeness-v1 Axis A. Egress redaction for the
30    /// FAILURE wire shape `Response::Error { message, details }` — the
31    /// `InvalidArgs` / `InternalError` dispatch exits. Default is a no-op,
32    /// preserving pre-SP behaviour for middleware that only rewrite success
33    /// results. **Security-sensitive middleware (PHI / PII redaction) MUST
34    /// override this** — a tool's failure text reaches the LLM verbatim and
35    /// may carry PHI (an arg echo, a panic message naming a patient).
36    /// `details` is the optional structured error payload; redact both.
37    ///
38    /// Note: the `ExecutionFailed` exit returns a `ToolResultResponse`
39    /// whose `result` is a Value and so runs through `on_result`, not this
40    /// hook. This hook is only for the bare-`message` `Response::Error`.
41    fn on_error(
42        &self,
43        tool_id: &str,
44        tool_def: &ToolDefinition,
45        message: &mut String,
46        details: &mut Option<serde_json::Value>,
47    ) {
48        let _ = (tool_id, tool_def, message, details);
49    }
50}
51
52/// Walk a JSON value, applying `f` to every string leaf (including strings
53/// inside arrays and nested objects). Non-string leaves are untouched.
54fn walk_strings(value: &mut serde_json::Value, f: &mut impl FnMut(&mut String)) {
55    match value {
56        serde_json::Value::String(s) => f(s),
57        serde_json::Value::Array(arr) => {
58            for v in arr.iter_mut() {
59                walk_strings(v, f);
60            }
61        }
62        serde_json::Value::Object(obj) => {
63            for (_k, v) in obj.iter_mut() {
64                walk_strings(v, f);
65            }
66        }
67        _ => {}
68    }
69}
70
71/// Redact absolute filesystem paths from tool output. Applies each
72/// `(pattern, replacement)` pair in order to every string leaf in the
73/// result. Default construction via `with_home_default()` redacts
74/// `$HOME/...` paths — a low-effort demonstration of the pattern, not a
75/// comprehensive PII scrubber.
76pub struct RedactPathsMiddleware {
77    patterns: Vec<(regex::Regex, String)>,
78}
79
80impl RedactPathsMiddleware {
81    pub fn new(patterns: Vec<(regex::Regex, String)>) -> Self {
82        Self { patterns }
83    }
84
85    /// Redact the current user's home directory. If `$HOME` is unset (rare
86    /// on CI, possible in containers), returns a middleware with an empty
87    /// pattern set rather than panicking — it becomes a no-op.
88    pub fn with_home_default() -> Self {
89        let patterns = match std::env::var("HOME") {
90            Ok(home) if !home.is_empty() => {
91                // Escape regex metacharacters in the path before compiling.
92                let escaped = regex::escape(&home);
93                match regex::Regex::new(&escaped) {
94                    Ok(re) => vec![(re, "<redacted:home>".to_string())],
95                    Err(_) => vec![],
96                }
97            }
98            _ => vec![],
99        };
100        Self { patterns }
101    }
102}
103
104impl Middleware for RedactPathsMiddleware {
105    fn name(&self) -> &'static str {
106        "redact_paths"
107    }
108
109    fn on_result(
110        &self,
111        _tool_id: &str,
112        _tool_def: &ToolDefinition,
113        result: &mut serde_json::Value,
114    ) {
115        if self.patterns.is_empty() {
116            return;
117        }
118        let patterns = &self.patterns;
119        walk_strings(result, &mut |s| {
120            for (re, rep) in patterns {
121                *s = re.replace_all(s, rep.as_str()).into_owned();
122            }
123        });
124    }
125
126    /// SP-observability-completeness-v1 Axis A — the `$HOME`/path scrub is
127    /// as relevant to error text as to success results (a failure message
128    /// can echo an absolute path). Apply the same patterns to the bare
129    /// `message` string and walk the optional `details` value.
130    fn on_error(
131        &self,
132        _tool_id: &str,
133        _tool_def: &ToolDefinition,
134        message: &mut String,
135        details: &mut Option<serde_json::Value>,
136    ) {
137        if self.patterns.is_empty() {
138            return;
139        }
140        for (re, rep) in &self.patterns {
141            *message = re.replace_all(message, rep.as_str()).into_owned();
142        }
143        if let Some(d) = details {
144            let patterns = &self.patterns;
145            walk_strings(d, &mut |s| {
146                for (re, rep) in patterns {
147                    *s = re.replace_all(s, rep.as_str()).into_owned();
148                }
149            });
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal `ToolDefinition` fixture. `RedactPathsMiddleware` ignores
    /// the definition, so the field values are arbitrary.
    fn tool_def() -> ToolDefinition {
        ToolDefinition {
            id: "test:mw".into(),
            name: "mw".into(),
            description: "middleware test fixture".into(),
            version: "0.0.0".into(),
            capability: ToolCapability {
                domain: "test".into(),
                actions: vec![],
                tags: vec![],
                intent_examples: vec![],
            },
            input_schema: serde_json::json!({}),
            output_schema: serde_json::json!({}),
            bindings: vec![ToolBinding {
                protocol: BindingProtocol::Cli,
                config: serde_json::json!({}),
            }],
            safety: ToolSafety {
                level: SafetyLevel::Read,
                dry_run: false,
                side_effects: vec![],
                data_sensitivity: None,
            },
            resources: ToolResources {
                timeout_ms: 1000,
                max_concurrent: 1,
                rate_limit_per_min: None,
                estimated_tokens: None,
            },
            trust: ToolTrust {
                publisher: "test".into(),
                trust_level: TrustLevel::L0Unverified,
                signature: None,
            },
            visibility: ToolVisibility::Read,
            required_capabilities: vec![],
            tier: None,
            errors: vec![],
        }
    }

    /// Single-pattern middleware; panics if `pattern` is not a valid regex.
    fn mw_with(pattern: &str, rep: &str) -> RedactPathsMiddleware {
        let re = regex::Regex::new(pattern).unwrap();
        RedactPathsMiddleware::new(vec![(re, rep.to_string())])
    }

    #[test]
    fn redacts_pattern_in_top_level_string() {
        let mw = mw_with(r"/home/[^/]+", "<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({"path": "/home/alice/x.txt"});
        mw.on_result("test:mw", &def, &mut v);
        assert_eq!(v["path"], "<redacted>/x.txt");
    }

    #[test]
    fn redacts_in_nested_object() {
        let mw = mw_with(r"secret", "***");
        let def = tool_def();
        let mut v = serde_json::json!({
            "outer": {"inner": "this is a secret value"}
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["outer"]["inner"], "this is a *** value");
    }

    #[test]
    fn redacts_in_array_elements() {
        let mw = mw_with(r"password=\w+", "password=<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({
            "entries": ["password=hunter2", "normal line", "password=correct horse"]
        });
        mw.on_result("t", &def, &mut v);
        let arr = v["entries"].as_array().unwrap();
        assert_eq!(arr[0], "password=<redacted>");
        assert_eq!(arr[1], "normal line");
        // `\w+` stops at the space, so only the first word is consumed.
        assert_eq!(arr[2], "password=<redacted> horse");
    }

    #[test]
    fn leaves_non_string_leaves_untouched() {
        let mw = mw_with(r"\d+", "N");
        let def = tool_def();
        let mut v = serde_json::json!({
            "num": 42,
            "bool": true,
            "null": null,
            "str_with_num": "port 42"
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["num"], 42);
        assert_eq!(v["bool"], true);
        assert_eq!(v["null"], serde_json::Value::Null);
        assert_eq!(v["str_with_num"], "port N");
    }

    #[test]
    fn applies_multiple_patterns_in_order() {
        let p1 = (regex::Regex::new(r"aaa").unwrap(), "bbb".to_string());
        let p2 = (regex::Regex::new(r"bbb").unwrap(), "ccc".to_string());
        // First 'aaa' -> 'bbb', then 'bbb' -> 'ccc'. End state: 'ccc'.
        let mw = RedactPathsMiddleware::new(vec![p1, p2]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "aaa"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "ccc");
    }

    #[test]
    fn name_is_stable() {
        let mw = RedactPathsMiddleware::new(vec![]);
        assert_eq!(mw.name(), "redact_paths");
    }

    // ---- SP-observability-completeness-v1 Axis A: on_error ----

    #[test]
    fn on_error_redacts_message_and_details() {
        let mw = mw_with(r"SECRET\w*", "<redacted>");
        let def = tool_def();
        let mut message = "leak SECRET123 in error".to_string();
        let mut details = Some(serde_json::json!({"ctx": "also SECRET456 here"}));
        mw.on_error("t", &def, &mut message, &mut details);
        assert_eq!(message, "leak <redacted> in error");
        assert_eq!(details.unwrap()["ctx"], "also <redacted> here");
    }

    #[test]
    fn on_error_handles_none_details() {
        let mw = mw_with(r"SECRET", "<redacted>");
        let def = tool_def();
        let mut message = "SECRET leaked".to_string();
        let mut details = None;
        mw.on_error("t", &def, &mut message, &mut details);
        assert_eq!(message, "<redacted> leaked");
        assert!(details.is_none());
    }

    #[test]
    fn default_on_error_is_noop() {
        // A middleware that does NOT override on_error leaves error text
        // alone — the additive default preserves pre-SP behaviour.
        struct Noop;
        impl Middleware for Noop {
            fn name(&self) -> &'static str {
                "noop"
            }
            fn on_result(&self, _: &str, _: &ToolDefinition, _: &mut serde_json::Value) {}
        }
        let def = tool_def();
        let mut message = "untouched SECRET".to_string();
        let mut details = Some(serde_json::json!({"k": "untouched"}));
        Noop.on_error("t", &def, &mut message, &mut details);
        assert_eq!(message, "untouched SECRET");
        assert_eq!(details.unwrap()["k"], "untouched");
    }

    #[test]
    fn empty_middleware_is_a_noop() {
        let mw = RedactPathsMiddleware::new(vec![]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "unchanged"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "unchanged");
    }

    #[test]
    fn with_home_default_handles_home_path_or_is_noop_when_unset() {
        /// Puts `HOME` back to its captured state when dropped, so the
        /// variable is restored even if an assertion below panics.
        struct RestoreHome(Option<std::ffi::OsString>);

        impl Drop for RestoreHome {
            fn drop(&mut self) {
                // SAFETY: same assumption as the mutations in the test
                // body — no other thread touches the environment while
                // this test runs. TODO confirm for the whole test binary.
                unsafe {
                    match self.0.take() {
                        Some(home) => std::env::set_var("HOME", home),
                        None => std::env::remove_var("HOME"),
                    }
                }
            }
        }

        // We mutate HOME just for this test; other tests in this module
        // do not rely on it. If HOME was unset to begin with, the
        // middleware must still be constructed without panic and act as a
        // no-op.
        let _restore = RestoreHome(std::env::var_os("HOME"));
        let def = tool_def();

        // Case A: HOME set.
        // SAFETY: no other test in this module reads or writes the
        // environment; assumes nothing else in the test binary does so
        // concurrently.
        unsafe {
            std::env::set_var("HOME", "/tmp/fakehome-sp12");
        }
        let mw = RedactPathsMiddleware::with_home_default();
        let mut v = serde_json::json!({"p": "/tmp/fakehome-sp12/secret"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["p"], "<redacted:home>/secret");

        // Case B: HOME unset.
        // SAFETY: as above.
        unsafe {
            std::env::remove_var("HOME");
        }
        let mw2 = RedactPathsMiddleware::with_home_default();
        let mut v2 = serde_json::json!({"p": "/tmp/anything"});
        mw2.on_result("t", &def, &mut v2);
        // No-op.
        assert_eq!(v2["p"], "/tmp/anything");

        // HOME is restored by `_restore` when it goes out of scope.
    }
}
363}