1use atd_protocol::ToolDefinition;
20
21pub trait Middleware: Send + Sync {
25 fn name(&self) -> &'static str;
26
27 fn on_result(&self, tool_id: &str, tool_def: &ToolDefinition, result: &mut serde_json::Value);
28
29 fn on_error(
42 &self,
43 tool_id: &str,
44 tool_def: &ToolDefinition,
45 message: &mut String,
46 details: &mut Option<serde_json::Value>,
47 ) {
48 let _ = (tool_id, tool_def, message, details);
49 }
50}
51
52fn walk_strings(value: &mut serde_json::Value, f: &mut impl FnMut(&mut String)) {
55 match value {
56 serde_json::Value::String(s) => f(s),
57 serde_json::Value::Array(arr) => {
58 for v in arr.iter_mut() {
59 walk_strings(v, f);
60 }
61 }
62 serde_json::Value::Object(obj) => {
63 for (_k, v) in obj.iter_mut() {
64 walk_strings(v, f);
65 }
66 }
67 _ => {}
68 }
69}
70
71pub struct RedactPathsMiddleware {
77 patterns: Vec<(regex::Regex, String)>,
78}
79
80impl RedactPathsMiddleware {
81 pub fn new(patterns: Vec<(regex::Regex, String)>) -> Self {
82 Self { patterns }
83 }
84
85 pub fn with_home_default() -> Self {
89 let patterns = match std::env::var("HOME") {
90 Ok(home) if !home.is_empty() => {
91 let escaped = regex::escape(&home);
93 match regex::Regex::new(&escaped) {
94 Ok(re) => vec![(re, "<redacted:home>".to_string())],
95 Err(_) => vec![],
96 }
97 }
98 _ => vec![],
99 };
100 Self { patterns }
101 }
102}
103
104impl Middleware for RedactPathsMiddleware {
105 fn name(&self) -> &'static str {
106 "redact_paths"
107 }
108
109 fn on_result(
110 &self,
111 _tool_id: &str,
112 _tool_def: &ToolDefinition,
113 result: &mut serde_json::Value,
114 ) {
115 if self.patterns.is_empty() {
116 return;
117 }
118 let patterns = &self.patterns;
119 walk_strings(result, &mut |s| {
120 for (re, rep) in patterns {
121 *s = re.replace_all(s, rep.as_str()).into_owned();
122 }
123 });
124 }
125
126 fn on_error(
131 &self,
132 _tool_id: &str,
133 _tool_def: &ToolDefinition,
134 message: &mut String,
135 details: &mut Option<serde_json::Value>,
136 ) {
137 if self.patterns.is_empty() {
138 return;
139 }
140 for (re, rep) in &self.patterns {
141 *message = re.replace_all(message, rep.as_str()).into_owned();
142 }
143 if let Some(d) = details {
144 let patterns = &self.patterns;
145 walk_strings(d, &mut |s| {
146 for (re, rep) in patterns {
147 *s = re.replace_all(s, rep.as_str()).into_owned();
148 }
149 });
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use atd_protocol::{
        BindingProtocol, SafetyLevel, ToolBinding, ToolCapability, ToolResources, ToolSafety,
        ToolTrust, ToolVisibility, TrustLevel,
    };

    /// Minimal `ToolDefinition` fixture; the middleware under test ignores its contents.
    fn tool_def() -> ToolDefinition {
        ToolDefinition {
            id: "test:mw".into(),
            name: "mw".into(),
            description: "middleware test fixture".into(),
            version: "0.0.0".into(),
            capability: ToolCapability {
                domain: "test".into(),
                actions: vec![],
                tags: vec![],
                intent_examples: vec![],
            },
            input_schema: serde_json::json!({}),
            output_schema: serde_json::json!({}),
            bindings: vec![ToolBinding {
                protocol: BindingProtocol::Cli,
                config: serde_json::json!({}),
            }],
            safety: ToolSafety {
                level: SafetyLevel::Read,
                dry_run: false,
                side_effects: vec![],
                data_sensitivity: None,
            },
            resources: ToolResources {
                timeout_ms: 1000,
                max_concurrent: 1,
                rate_limit_per_min: None,
                estimated_tokens: None,
            },
            trust: ToolTrust {
                publisher: "test".into(),
                trust_level: TrustLevel::L0Unverified,
                signature: None,
            },
            visibility: ToolVisibility::Read,
            required_capabilities: vec![],
            tier: None,
            errors: vec![],
        }
    }

    /// Builds a middleware with a single `(pattern, replacement)` pair.
    fn mw_with(pattern: &str, rep: &str) -> RedactPathsMiddleware {
        let re = regex::Regex::new(pattern).unwrap();
        RedactPathsMiddleware::new(vec![(re, rep.to_string())])
    }

    #[test]
    fn redacts_pattern_in_top_level_string() {
        let mw = mw_with(r"/home/[^/]+", "<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({"path": "/home/alice/x.txt"});
        mw.on_result("test:mw", &def, &mut v);
        assert_eq!(v["path"], "<redacted>/x.txt");
    }

    #[test]
    fn redacts_in_nested_object() {
        let mw = mw_with(r"secret", "***");
        let def = tool_def();
        let mut v = serde_json::json!({
            "outer": {"inner": "this is a secret value"}
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["outer"]["inner"], "this is a *** value");
    }

    #[test]
    fn redacts_in_array_elements() {
        let mw = mw_with(r"password=\w+", "password=<redacted>");
        let def = tool_def();
        let mut v = serde_json::json!({
            "entries": ["password=hunter2", "normal line", "password=correct horse"]
        });
        mw.on_result("t", &def, &mut v);
        let arr = v["entries"].as_array().unwrap();
        assert_eq!(arr[0], "password=<redacted>");
        assert_eq!(arr[1], "normal line");
        assert_eq!(arr[2], "password=<redacted> horse");
    }

    #[test]
    fn leaves_non_string_leaves_untouched() {
        let mw = mw_with(r"\d+", "N");
        let def = tool_def();
        let mut v = serde_json::json!({
            "num": 42,
            "bool": true,
            "null": null,
            "str_with_num": "port 42"
        });
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["num"], 42);
        assert_eq!(v["bool"], true);
        assert_eq!(v["null"], serde_json::Value::Null);
        assert_eq!(v["str_with_num"], "port N");
    }

    #[test]
    fn applies_multiple_patterns_in_order() {
        // The second pattern only matches text produced by the first, proving ordering.
        let p1 = (regex::Regex::new(r"aaa").unwrap(), "bbb".to_string());
        let p2 = (regex::Regex::new(r"bbb").unwrap(), "ccc".to_string());
        let mw = RedactPathsMiddleware::new(vec![p1, p2]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "aaa"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "ccc");
    }

    #[test]
    fn name_is_stable() {
        let mw = RedactPathsMiddleware::new(vec![]);
        assert_eq!(mw.name(), "redact_paths");
    }

    #[test]
    fn on_error_redacts_message_and_details() {
        let mw = mw_with(r"SECRET\w*", "<redacted>");
        let def = tool_def();
        let mut message = "leak SECRET123 in error".to_string();
        let mut details = Some(serde_json::json!({"ctx": "also SECRET456 here"}));
        mw.on_error("t", &def, &mut message, &mut details);
        assert_eq!(message, "leak <redacted> in error");
        assert_eq!(details.unwrap()["ctx"], "also <redacted> here");
    }

    #[test]
    fn on_error_handles_none_details() {
        let mw = mw_with(r"SECRET", "<redacted>");
        let def = tool_def();
        let mut message = "SECRET leaked".to_string();
        let mut details = None;
        mw.on_error("t", &def, &mut message, &mut details);
        assert_eq!(message, "<redacted> leaked");
        assert!(details.is_none());
    }

    #[test]
    fn default_on_error_is_noop() {
        // A middleware that only implements the required methods, so `on_error`
        // falls back to the trait's default.
        struct Noop;
        impl Middleware for Noop {
            fn name(&self) -> &'static str {
                "noop"
            }
            fn on_result(&self, _: &str, _: &ToolDefinition, _: &mut serde_json::Value) {}
        }
        let def = tool_def();
        let mut message = "untouched SECRET".to_string();
        let mut details = Some(serde_json::json!({"k": "untouched"}));
        Noop.on_error("t", &def, &mut message, &mut details);
        assert_eq!(message, "untouched SECRET");
        assert_eq!(details.unwrap()["k"], "untouched");
    }

    #[test]
    fn empty_middleware_is_a_noop() {
        let mw = RedactPathsMiddleware::new(vec![]);
        let def = tool_def();
        let mut v = serde_json::json!({"x": "unchanged"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["x"], "unchanged");
    }

    /// Snapshot of `HOME` that is written back when dropped.
    ///
    /// The environment is therefore restored even if an assertion in the test
    /// panics. Otherwise a fake or missing `HOME` would leak into later tests in
    /// the same process.
    struct HomeGuard(Option<std::ffi::OsString>);

    impl HomeGuard {
        fn capture() -> Self {
            Self(std::env::var_os("HOME"))
        }
    }

    impl Drop for HomeGuard {
        fn drop(&mut self) {
            // SAFETY: mutating the process environment is only sound while no other
            // thread reads or writes it. This module's only env-touching test is the
            // one that owns this guard. NOTE(review): tests elsewhere in the crate
            // may still race; consider a shared env lock if that becomes an issue.
            unsafe {
                match self.0.take() {
                    Some(home) => std::env::set_var("HOME", home),
                    None => std::env::remove_var("HOME"),
                }
            }
        }
    }

    #[test]
    fn with_home_default_handles_home_path_or_is_noop_when_unset() {
        let _restore_home = HomeGuard::capture();

        // SAFETY: see `HomeGuard::drop`; same single-threaded-access assumption.
        unsafe {
            std::env::set_var("HOME", "/tmp/fakehome-sp12");
        }
        let mw = RedactPathsMiddleware::with_home_default();
        let def = tool_def();
        let mut v = serde_json::json!({"p": "/tmp/fakehome-sp12/secret"});
        mw.on_result("t", &def, &mut v);
        assert_eq!(v["p"], "<redacted:home>/secret");

        // SAFETY: see `HomeGuard::drop`; same single-threaded-access assumption.
        unsafe {
            std::env::remove_var("HOME");
        }
        let mw2 = RedactPathsMiddleware::with_home_default();
        let mut v2 = serde_json::json!({"p": "/tmp/anything"});
        mw2.on_result("t", &def, &mut v2);
        assert_eq!(v2["p"], "/tmp/anything");
        // `_restore_home` puts the original HOME back here, on success or panic.
    }
}