Skip to main content

atd_runtime/
middleware.rs

//! Result-middleware pipeline.
//!
//! A [`Middleware`] is invoked **on success** after a tool returns, with a
//! mutable reference to the result value. SP-12 ships one built-in
//! ([`RedactPathsMiddleware`]) to demonstrate the shape; the v3 brief's full
//! suite (pii_redact, source_device_tag, compress, audit_log, rate_shape)
//! is deferred.
//!
//! Error paths bypass middleware in SP-12 — spec §8 Q4. A future SP can
//! add an `on_error` hook once a real consumer exists.
11
12use atd_protocol::ToolDefinition;
13
14/// A result-rewriting hook. Must be deterministic and side-effect-free
15/// beyond the `result` mutation + any internal audit sinks the impl owns.
16pub trait Middleware: Send + Sync {
17    fn name(&self) -> &'static str;
18
19    fn on_result(&self, tool_id: &str, tool_def: &ToolDefinition, result: &mut serde_json::Value);
20}
21
22/// Walk a JSON value, applying `f` to every string leaf (including strings
23/// inside arrays and nested objects). Non-string leaves are untouched.
24fn walk_strings(value: &mut serde_json::Value, f: &mut impl FnMut(&mut String)) {
25    match value {
26        serde_json::Value::String(s) => f(s),
27        serde_json::Value::Array(arr) => {
28            for v in arr.iter_mut() {
29                walk_strings(v, f);
30            }
31        }
32        serde_json::Value::Object(obj) => {
33            for (_k, v) in obj.iter_mut() {
34                walk_strings(v, f);
35            }
36        }
37        _ => {}
38    }
39}
40
41/// Redact absolute filesystem paths from tool output. Applies each
42/// `(pattern, replacement)` pair in order to every string leaf in the
43/// result. Default construction via `with_home_default()` redacts
44/// `$HOME/...` paths โ€” a low-effort demonstration of the pattern, not a
45/// comprehensive PII scrubber.
46pub struct RedactPathsMiddleware {
47    patterns: Vec<(regex::Regex, String)>,
48}
49
50impl RedactPathsMiddleware {
51    pub fn new(patterns: Vec<(regex::Regex, String)>) -> Self {
52        Self { patterns }
53    }
54
55    /// Redact the current user's home directory. If `$HOME` is unset (rare
56    /// on CI, possible in containers), returns a middleware with an empty
57    /// pattern set rather than panicking โ€” it becomes a no-op.
58    pub fn with_home_default() -> Self {
59        let patterns = match std::env::var("HOME") {
60            Ok(home) if !home.is_empty() => {
61                // Escape regex metacharacters in the path before compiling.
62                let escaped = regex::escape(&home);
63                match regex::Regex::new(&escaped) {
64                    Ok(re) => vec![(re, "<redacted:home>".to_string())],
65                    Err(_) => vec![],
66                }
67            }
68            _ => vec![],
69        };
70        Self { patterns }
71    }
72}
73
74impl Middleware for RedactPathsMiddleware {
75    fn name(&self) -> &'static str {
76        "redact_paths"
77    }
78
79    fn on_result(
80        &self,
81        _tool_id: &str,
82        _tool_def: &ToolDefinition,
83        result: &mut serde_json::Value,
84    ) {
85        if self.patterns.is_empty() {
86            return;
87        }
88        let patterns = &self.patterns;
89        walk_strings(result, &mut |s| {
90            for (re, rep) in patterns {
91                *s = re.replace_all(s, rep.as_str()).into_owned();
92            }
93        });
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal `ToolDefinition` fixture. `RedactPathsMiddleware` ignores it,
    /// but the `Middleware::on_result` signature requires one.
    fn tool_def() -> ToolDefinition {
        ToolDefinition {
            id: "test:mw".into(),
            name: "mw".into(),
            description: "middleware test fixture".into(),
            version: "0.0.0".into(),
            capability: ToolCapability {
                domain: "test".into(),
                actions: vec![],
                tags: vec![],
                intent_examples: vec![],
            },
            input_schema: serde_json::json!({}),
            output_schema: serde_json::json!({}),
            bindings: vec![ToolBinding {
                protocol: BindingProtocol::Cli,
                config: serde_json::json!({}),
            }],
            safety: ToolSafety {
                level: SafetyLevel::Read,
                dry_run: false,
                side_effects: vec![],
                data_sensitivity: None,
            },
            resources: ToolResources {
                timeout_ms: 1000,
                max_concurrent: 1,
                rate_limit_per_min: None,
                estimated_tokens: None,
            },
            trust: ToolTrust {
                publisher: "test".into(),
                trust_level: TrustLevel::L0Unverified,
                signature: None,
            },
            visibility: ToolVisibility::Read,
            required_capabilities: vec![],
            tier: None,
            errors: vec![],
        }
    }

    /// Middleware with a single `(pattern, replacement)` rule.
    fn mw_with(pattern: &str, rep: &str) -> RedactPathsMiddleware {
        let re = regex::Regex::new(pattern).unwrap();
        RedactPathsMiddleware::new(vec![(re, rep.to_string())])
    }

    /// Restores `HOME` to its captured value when dropped, including when an
    /// assertion panics, so a failing test cannot leak a modified environment
    /// into the rest of the test binary.
    struct HomeGuard(Option<std::ffi::OsString>);

    impl HomeGuard {
        fn capture() -> Self {
            Self(std::env::var_os("HOME"))
        }
    }

    impl Drop for HomeGuard {
        fn drop(&mut self) {
            // SAFETY: `set_var`/`remove_var` are unsound if another thread
            // reads or writes the environment concurrently. No other test in
            // this module touches the environment; tests elsewhere in the
            // crate are assumed not to either (TODO confirm, or serialize
            // environment access behind a shared lock).
            unsafe {
                match self.0.take() {
                    Some(home) => std::env::set_var("HOME", home),
                    None => std::env::remove_var("HOME"),
                }
            }
        }
    }

    #[test]
    fn redacts_pattern_in_top_level_string() {
        let mw = mw_with(r"/home/[^/]+", "<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({"path": "/home/alice/x.txt"});
        mw.on_result("test:mw", &def, &mut v);
        assert_eq!(v["path"], "<redacted>/x.txt");
    }

    #[test]
    fn redacts_in_nested_object() {
        let mw = mw_with(r"secret", "***");
        let def = tool_def();
        let mut v = serde_json::json!({
            "outer": {"inner": "this is a secret value"}
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["outer"]["inner"], "this is a *** value");
    }

    #[test]
    fn redacts_in_array_elements() {
        let mw = mw_with(r"password=\w+", "password=<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({
            "entries": ["password=hunter2", "normal line", "password=correct horse"]
        });
        mw.on_result("t", &def, &mut v);
        let arr = v["entries"].as_array().unwrap();
        assert_eq!(arr[0], "password=<redacted>");
        assert_eq!(arr[1], "normal line");
        assert_eq!(arr[2], "password=<redacted> horse");
    }

    #[test]
    fn leaves_non_string_leaves_untouched() {
        let mw = mw_with(r"\d+", "N");
        let def = tool_def();
        let mut v = serde_json::json!({
            "num": 42,
            "bool": true,
            "null": null,
            "str_with_num": "port 42"
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["num"], 42);
        assert_eq!(v["bool"], true);
        assert_eq!(v["null"], serde_json::Value::Null);
        assert_eq!(v["str_with_num"], "port N");
    }

    #[test]
    fn applies_multiple_patterns_in_order() {
        let p1 = (regex::Regex::new(r"aaa").unwrap(), "bbb".to_string());
        let p2 = (regex::Regex::new(r"bbb").unwrap(), "ccc".to_string());
        // First 'aaa' -> 'bbb', then 'bbb' -> 'ccc'. End state: 'ccc'.
        let mw = RedactPathsMiddleware::new(vec![p1, p2]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "aaa"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "ccc");
    }

    #[test]
    fn name_is_stable() {
        let mw = RedactPathsMiddleware::new(vec![]);
        assert_eq!(mw.name(), "redact_paths");
    }

    #[test]
    fn empty_middleware_is_a_noop() {
        let mw = RedactPathsMiddleware::new(vec![]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "unchanged"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "unchanged");
    }

    #[test]
    fn with_home_default_handles_home_path_or_is_noop_when_unset() {
        // HOME is mutated for this test only; the guard restores the original
        // value (or absence) on every exit path, including assertion panics.
        let _restore = HomeGuard::capture();
        let def = tool_def();

        // Case A: HOME set.
        // SAFETY: no other test in this module touches the environment; see
        // `HomeGuard::drop` for the cross-module assumption.
        unsafe {
            std::env::set_var("HOME", "/tmp/fakehome-sp12");
        }
        let mw = RedactPathsMiddleware::with_home_default();
        let mut v = serde_json::json!({"p": "/tmp/fakehome-sp12/secret"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["p"], "<redacted:home>/secret");

        // Case B: HOME unset. The middleware must still be constructed
        // without panic and act as a no-op.
        // SAFETY: as above.
        unsafe {
            std::env::remove_var("HOME");
        }
        let mw2 = RedactPathsMiddleware::with_home_default();
        let mut v2 = serde_json::json!({"p": "/tmp/anything"});
        mw2.on_result("t", &def, &mut v2);
        assert_eq!(v2["p"], "/tmp/anything");
    }
}