1use atd_protocol::ToolDefinition;
13
14pub trait Middleware: Send + Sync {
17 fn name(&self) -> &'static str;
18
19 fn on_result(&self, tool_id: &str, tool_def: &ToolDefinition, result: &mut serde_json::Value);
20}
21
22fn walk_strings(value: &mut serde_json::Value, f: &mut impl FnMut(&mut String)) {
25 match value {
26 serde_json::Value::String(s) => f(s),
27 serde_json::Value::Array(arr) => {
28 for v in arr.iter_mut() {
29 walk_strings(v, f);
30 }
31 }
32 serde_json::Value::Object(obj) => {
33 for (_k, v) in obj.iter_mut() {
34 walk_strings(v, f);
35 }
36 }
37 _ => {}
38 }
39}
40
41pub struct RedactPathsMiddleware {
47 patterns: Vec<(regex::Regex, String)>,
48}
49
50impl RedactPathsMiddleware {
51 pub fn new(patterns: Vec<(regex::Regex, String)>) -> Self {
52 Self { patterns }
53 }
54
55 pub fn with_home_default() -> Self {
59 let patterns = match std::env::var("HOME") {
60 Ok(home) if !home.is_empty() => {
61 let escaped = regex::escape(&home);
63 match regex::Regex::new(&escaped) {
64 Ok(re) => vec![(re, "<redacted:home>".to_string())],
65 Err(_) => vec![],
66 }
67 }
68 _ => vec![],
69 };
70 Self { patterns }
71 }
72}
73
74impl Middleware for RedactPathsMiddleware {
75 fn name(&self) -> &'static str {
76 "redact_paths"
77 }
78
79 fn on_result(
80 &self,
81 _tool_id: &str,
82 _tool_def: &ToolDefinition,
83 result: &mut serde_json::Value,
84 ) {
85 if self.patterns.is_empty() {
86 return;
87 }
88 let patterns = &self.patterns;
89 walk_strings(result, &mut |s| {
90 for (re, rep) in patterns {
91 *s = re.replace_all(s, rep.as_str()).into_owned();
92 }
93 });
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal `ToolDefinition` fixture; `RedactPathsMiddleware` ignores its contents.
    fn tool_def() -> ToolDefinition {
        ToolDefinition {
            id: "test:mw".into(),
            name: "mw".into(),
            description: "middleware test fixture".into(),
            version: "0.0.0".into(),
            capability: ToolCapability {
                domain: "test".into(),
                actions: vec![],
                tags: vec![],
                intent_examples: vec![],
            },
            input_schema: serde_json::json!({}),
            output_schema: serde_json::json!({}),
            bindings: vec![ToolBinding {
                protocol: BindingProtocol::Cli,
                config: serde_json::json!({}),
            }],
            safety: ToolSafety {
                level: SafetyLevel::Read,
                dry_run: false,
                side_effects: vec![],
                data_sensitivity: None,
            },
            resources: ToolResources {
                timeout_ms: 1000,
                max_concurrent: 1,
                rate_limit_per_min: None,
                estimated_tokens: None,
            },
            trust: ToolTrust {
                publisher: "test".into(),
                trust_level: TrustLevel::L0Unverified,
                signature: None,
            },
            visibility: ToolVisibility::Read,
            required_capabilities: vec![],
            tier: None,
            errors: vec![],
        }
    }

    /// Builds a middleware with a single `(pattern, replacement)` pair.
    fn mw_with(pattern: &str, rep: &str) -> RedactPathsMiddleware {
        let re = regex::Regex::new(pattern).unwrap();
        RedactPathsMiddleware::new(vec![(re, rep.to_string())])
    }

    /// Captures `HOME` when created and restores it when dropped, so the process
    /// environment is put back even if an assertion in the owning test panics.
    struct HomeGuard(Option<std::ffi::OsString>);

    impl HomeGuard {
        fn capture() -> Self {
            Self(std::env::var_os("HOME"))
        }
    }

    impl Drop for HomeGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as in
            // `with_home_default_handles_home_path_or_is_noop_when_unset`, the only
            // test that creates a `HomeGuard`.
            unsafe {
                match self.0.take() {
                    Some(home) => std::env::set_var("HOME", home),
                    None => std::env::remove_var("HOME"),
                }
            }
        }
    }

    #[test]
    fn redacts_pattern_in_top_level_string() {
        let mw = mw_with(r"/home/[^/]+", "<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({"path": "/home/alice/x.txt"});
        mw.on_result("test:mw", &def, &mut v);
        assert_eq!(v["path"], "<redacted>/x.txt");
    }

    #[test]
    fn redacts_in_nested_object() {
        let mw = mw_with(r"secret", "***");
        let def = tool_def();
        let mut v = serde_json::json!({
            "outer": {"inner": "this is a secret value"}
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["outer"]["inner"], "this is a *** value");
    }

    #[test]
    fn redacts_in_array_elements() {
        let mw = mw_with(r"password=\w+", "password=<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({
            "entries": ["password=hunter2", "normal line", "password=correct horse"]
        });
        mw.on_result("t", &def, &mut v);
        let arr = v["entries"].as_array().unwrap();
        assert_eq!(arr[0], "password=<redacted>");
        assert_eq!(arr[1], "normal line");
        assert_eq!(arr[2], "password=<redacted> horse");
    }

    #[test]
    fn leaves_non_string_leaves_untouched() {
        let mw = mw_with(r"\d+", "N");
        let def = tool_def();
        let mut v = serde_json::json!({
            "num": 42,
            "bool": true,
            "null": null,
            "str_with_num": "port 42"
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["num"], 42);
        assert_eq!(v["bool"], true);
        assert_eq!(v["null"], serde_json::Value::Null);
        assert_eq!(v["str_with_num"], "port N");
    }

    #[test]
    fn applies_multiple_patterns_in_order() {
        let p1 = (regex::Regex::new(r"aaa").unwrap(), "bbb".to_string());
        let p2 = (regex::Regex::new(r"bbb").unwrap(), "ccc".to_string());
        let mw = RedactPathsMiddleware::new(vec![p1, p2]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "aaa"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "ccc");
    }

    #[test]
    fn name_is_stable() {
        let mw = RedactPathsMiddleware::new(vec![]);
        assert_eq!(mw.name(), "redact_paths");
    }

    #[test]
    fn empty_middleware_is_a_noop() {
        let mw = RedactPathsMiddleware::new(vec![]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "unchanged"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "unchanged");
    }

    #[test]
    fn with_home_default_handles_home_path_or_is_noop_when_unset() {
        // Restores the caller's HOME on scope exit, including on panic.
        let _restore_home = HomeGuard::capture();

        // SAFETY: mutating the environment is only sound while no other thread
        // reads or writes it. No other test in this module touches the environment.
        // NOTE(review): the harness runs tests in parallel, so confirm no test
        // elsewhere in this crate reads or writes the environment concurrently (or
        // serialize those tests behind a shared lock).
        unsafe {
            std::env::set_var("HOME", "/tmp/fakehome-sp12");
        }
        let mw = RedactPathsMiddleware::with_home_default();
        let def = tool_def();
        let mut v = serde_json::json!({"p": "/tmp/fakehome-sp12/secret"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["p"], "<redacted:home>/secret");

        // SAFETY: same invariant as the `set_var` above.
        unsafe {
            std::env::remove_var("HOME");
        }
        let mw2 = RedactPathsMiddleware::with_home_default();
        let mut v2 = serde_json::json!({"p": "/tmp/anything"});
        mw2.on_result("t", &def, &mut v2);
        assert_eq!(v2["p"], "/tmp/anything");
    }
}