// aft/bash_rewrite/rules.rs

1use serde_json::{json, Value};
2
3const REGEX_SIZE_LIMIT: usize = 10 * 1024 * 1024;
4
5use crate::bash_rewrite::footer::{add_footer, add_grep_footer};
6use crate::bash_rewrite::parser::parse;
7use crate::bash_rewrite::RewriteRule;
8use crate::context::AppContext;
9use crate::protocol::{RawRequest, Response};
10
/// Rewrites a restricted subset of `grep` invocations into the `grep` tool.
pub struct GrepRule;
/// Rewrites a restricted subset of `rg` (ripgrep) invocations into the `grep` tool.
pub struct RgRule;
/// Rewrites `find PATH -name GLOB [-type f]` into the `glob` tool.
pub struct FindRule;
/// Rewrites `cat FILE` into the `read` tool.
pub struct CatRule;
/// Rewrites `cat <<EOF >> FILE` / `echo ... >> FILE` into an `edit_match` append.
pub struct CatAppendRule;
/// Rewrites `sed -n` line-range prints into a line-ranged `read`.
pub struct SedRule;
/// Rewrites `ls` of a directory (optionally with `-R`/`-a`) into a directory `read`.
pub struct LsRule;
18
19impl RewriteRule for GrepRule {
20    fn name(&self) -> &'static str {
21        "grep"
22    }
23
24    fn matches(&self, command: &str) -> bool {
25        grep_request(command, "grep").is_some()
26    }
27
28    fn rewrite(
29        &self,
30        command: &str,
31        session_id: Option<&str>,
32        ctx: &AppContext,
33    ) -> Result<Response, String> {
34        let params = grep_request(command, "grep").ok_or("not a grep rewrite")?;
35        try_call_and_grep_footer(
36            crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
37            ctx,
38        )
39    }
40}
41
42impl RewriteRule for RgRule {
43    fn name(&self) -> &'static str {
44        "rg"
45    }
46
47    fn matches(&self, command: &str) -> bool {
48        grep_request(command, "rg").is_some()
49    }
50
51    fn rewrite(
52        &self,
53        command: &str,
54        session_id: Option<&str>,
55        ctx: &AppContext,
56    ) -> Result<Response, String> {
57        let params = grep_request(command, "rg").ok_or("not an rg rewrite")?;
58        try_call_and_grep_footer(
59            crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
60            ctx,
61        )
62    }
63}
64
65impl RewriteRule for FindRule {
66    fn name(&self) -> &'static str {
67        "find"
68    }
69
70    fn matches(&self, command: &str) -> bool {
71        find_request(command).is_some()
72    }
73
74    fn rewrite(
75        &self,
76        command: &str,
77        session_id: Option<&str>,
78        ctx: &AppContext,
79    ) -> Result<Response, String> {
80        let params = find_request(command).ok_or("not a find rewrite")?;
81        try_call_and_footer(
82            crate::commands::glob::handle_glob(&request("glob", params, session_id), ctx),
83            "glob",
84        )
85    }
86}
87
88impl RewriteRule for CatRule {
89    fn name(&self) -> &'static str {
90        "cat"
91    }
92
93    fn matches(&self, command: &str) -> bool {
94        cat_read_request(command).is_some()
95    }
96
97    fn rewrite(
98        &self,
99        command: &str,
100        session_id: Option<&str>,
101        ctx: &AppContext,
102    ) -> Result<Response, String> {
103        let params = cat_read_request(command).ok_or("not a cat rewrite")?;
104        try_call_and_footer(
105            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
106            "read",
107        )
108    }
109}
110
111impl RewriteRule for CatAppendRule {
112    fn name(&self) -> &'static str {
113        "cat_append"
114    }
115
116    fn matches(&self, command: &str) -> bool {
117        append_request(command).is_some()
118    }
119
120    fn rewrite(
121        &self,
122        command: &str,
123        session_id: Option<&str>,
124        ctx: &AppContext,
125    ) -> Result<Response, String> {
126        let params = append_request(command).ok_or("not an append rewrite")?;
127        try_call_and_footer(
128            crate::commands::edit_match::handle_edit_match(
129                &request("edit_match", params, session_id),
130                ctx,
131            ),
132            "edit",
133        )
134    }
135}
136
137impl RewriteRule for SedRule {
138    fn name(&self) -> &'static str {
139        "sed"
140    }
141
142    fn matches(&self, command: &str) -> bool {
143        sed_request(command).is_some()
144    }
145
146    fn rewrite(
147        &self,
148        command: &str,
149        session_id: Option<&str>,
150        ctx: &AppContext,
151    ) -> Result<Response, String> {
152        let params = sed_request(command).ok_or("not a sed rewrite")?;
153        try_call_and_footer(
154            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
155            "read",
156        )
157    }
158}
159
160impl RewriteRule for LsRule {
161    fn name(&self) -> &'static str {
162        "ls"
163    }
164
165    fn matches(&self, command: &str) -> bool {
166        ls_request(command).is_some()
167    }
168
169    fn rewrite(
170        &self,
171        command: &str,
172        session_id: Option<&str>,
173        ctx: &AppContext,
174    ) -> Result<Response, String> {
175        let params = ls_request(command).ok_or("not an ls rewrite")?;
176        try_call_and_footer(
177            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
178            "read",
179        )
180    }
181}
182
183fn request(command: &str, params: Value, session_id: Option<&str>) -> RawRequest {
184    RawRequest {
185        id: "bash_rewrite".to_string(),
186        command: command.to_string(),
187        lsp_hints: None,
188        session_id: session_id.map(str::to_string),
189        params,
190    }
191}
192
193/// Run an underlying tool through the rewrite path. If the tool returned
194/// `success: false`, propagate as Err so dispatch falls through to actual bash
195/// — the agent's intent was bash, the rewrite is a transparent optimization.
196/// Returning a wrapped error response would surprise the agent (e.g. read's
197/// `outside project root` rejecting a sed that bash would have allowed).
198fn try_call_and_footer(response: Response, replacement_tool: &str) -> Result<Response, String> {
199    if let Some(err) = declined_error(&response, replacement_tool) {
200        return Err(err);
201    }
202    Ok(call_and_footer(response, replacement_tool))
203}
204
205/// Grep/rg variant: same decline handling, but the footer is the enforced
206/// code-search redirect, steering to `aft_search` when it's registered.
207fn try_call_and_grep_footer(response: Response, ctx: &AppContext) -> Result<Response, String> {
208    if let Some(err) = declined_error(&response, "grep") {
209        return Err(err);
210    }
211    let output = response_output(&response.data);
212    let footered = add_grep_footer(&output, ctx.config().aft_search_registered);
213    Ok(apply_footer(response, footered))
214}
215
216fn declined_error(response: &Response, replacement_tool: &str) -> Option<String> {
217    if response.success {
218        return None;
219    }
220    let message = response
221        .data
222        .get("message")
223        .and_then(Value::as_str)
224        .or_else(|| response.data.get("code").and_then(Value::as_str))
225        .unwrap_or("error");
226    Some(format!("{replacement_tool} declined: {message}"))
227}
228
229fn call_and_footer(response: Response, replacement_tool: &str) -> Response {
230    let output = response_output(&response.data);
231    let footered = add_footer(&output, replacement_tool);
232    apply_footer(response, footered)
233}
234
235fn apply_footer(mut response: Response, output: String) -> Response {
236    if let Some(object) = response.data.as_object_mut() {
237        object.insert("output".to_string(), Value::String(output.clone()));
238
239        for key in ["text", "content", "message"] {
240            if object.get(key).is_some_and(Value::is_string) {
241                object.insert(key.to_string(), Value::String(output.clone()));
242                break;
243            }
244        }
245    } else {
246        response.data = json!({ "output": output });
247    }
248
249    response
250}
251
252fn response_output(data: &Value) -> String {
253    if let Some(output) = data.get("output").and_then(Value::as_str) {
254        return output.to_string();
255    }
256    if let Some(text) = data.get("text").and_then(Value::as_str) {
257        return text.to_string();
258    }
259    if let Some(content) = data.get("content").and_then(Value::as_str) {
260        return content.to_string();
261    }
262    if let Some(message) = data.get("message").and_then(Value::as_str) {
263        return message.to_string();
264    }
265    if let Some(entries) = data.get("entries").and_then(Value::as_array) {
266        return entries
267            .iter()
268            .filter_map(Value::as_str)
269            .collect::<Vec<_>>()
270            .join("\n");
271    }
272    serde_json::to_string_pretty(data).unwrap_or_else(|_| data.to_string())
273}
274
275fn grep_request(command: &str, binary: &str) -> Option<Value> {
276    let parsed = parse(command)?;
277    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != binary {
278        return None;
279    }
280
281    let mut case_sensitive = true;
282    let mut word_match = false;
283    let mut index = 1;
284
285    while let Some(arg) = parsed.args.get(index) {
286        if !arg.starts_with('-') || arg == "-" {
287            break;
288        }
289        for flag in arg[1..].chars() {
290            match flag {
291                'n' | 'r' => {}
292                'i' => case_sensitive = false,
293                'w' => word_match = true,
294                _ => return None,
295            }
296        }
297        index += 1;
298    }
299
300    let pattern = parsed.args.get(index)?.clone();
301    let path = parsed.args.get(index + 1).cloned();
302    if parsed.args.len() > index + 2 {
303        return None;
304    }
305
306    let pattern = if word_match {
307        format!(r"\b(?:{})\b", pattern)
308    } else {
309        pattern
310    };
311
312    if regex::RegexBuilder::new(&pattern)
313        .size_limit(REGEX_SIZE_LIMIT)
314        .build()
315        .is_err()
316    {
317        return None;
318    }
319
320    let mut params = json!({
321        "pattern": pattern,
322        "case_sensitive": case_sensitive,
323        "max_results": 100,
324    });
325    if let Some(path) = path {
326        params["path"] = json!(path);
327    }
328    Some(params)
329}
330
331fn find_request(command: &str) -> Option<Value> {
332    let parsed = parse(command)?;
333    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "find" {
334        return None;
335    }
336    if parsed.args.len() != 4 && parsed.args.len() != 6 {
337        return None;
338    }
339
340    let path = parsed.args.get(1)?.clone();
341    let mut name = None;
342    let mut saw_type_file = false;
343    let mut index = 2;
344
345    while index < parsed.args.len() {
346        match parsed.args[index].as_str() {
347            "-name" if name.is_none() && index + 1 < parsed.args.len() => {
348                name = Some(parsed.args[index + 1].clone());
349                index += 2;
350            }
351            "-type" if !saw_type_file && index + 1 < parsed.args.len() => {
352                if parsed.args[index + 1] != "f" {
353                    return None;
354                }
355                saw_type_file = true;
356                index += 2;
357            }
358            _ => return None,
359        }
360    }
361
362    let name = name?;
363    let pattern = format!("**/{name}");
364    if path == "." {
365        Some(json!({ "pattern": pattern }))
366    } else {
367        Some(json!({ "path": path.trim_end_matches('/'), "pattern": pattern }))
368    }
369}
370
371fn cat_read_request(command: &str) -> Option<Value> {
372    let parsed = parse(command)?;
373    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
374        return None;
375    }
376    if parsed.args.len() != 2 || parsed.args.first()? != "cat" {
377        return None;
378    }
379    Some(json!({ "file": parsed.args[1] }))
380}
381
382fn append_request(command: &str) -> Option<Value> {
383    let parsed = parse(command)?;
384    let file = parsed.appends_to.clone()?;
385
386    let append_content = if parsed.args == ["cat"] {
387        parsed.heredoc?
388    } else if parsed.heredoc.is_none()
389        && parsed.args.first().is_some_and(|arg| arg == "echo")
390        && parsed.args.len() >= 2
391        && !parsed.args[1].starts_with('-')
392    {
393        format!("{}\n", parsed.args[1..].join(" "))
394    } else {
395        return None;
396    };
397
398    Some(json!({
399        "op": "append",
400        "file": file,
401        "append_content": append_content,
402        "create_dirs": true,
403    }))
404}
405
406fn sed_request(command: &str) -> Option<Value> {
407    let parsed = parse(command)?;
408    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
409        return None;
410    }
411    if parsed.args.len() != 4 || parsed.args.first()? != "sed" || parsed.args[1] != "-n" {
412        return None;
413    }
414
415    let range = parsed.args[2].strip_suffix('p')?;
416    let (start, end) = range.split_once(',')?;
417    let start_line = start.parse::<u32>().ok()?;
418    let end_line = end.parse::<u32>().ok()?;
419    if start_line == 0 || end_line < start_line {
420        return None;
421    }
422
423    Some(json!({
424        "file": parsed.args[3],
425        "start_line": start_line,
426        "end_line": end_line,
427    }))
428}
429
430fn ls_request(command: &str) -> Option<Value> {
431    let parsed = parse(command)?;
432    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "ls" {
433        return None;
434    }
435
436    let mut path = None;
437    for arg in parsed.args.iter().skip(1) {
438        if let Some(flags) = arg.strip_prefix('-') {
439            if flags.is_empty() {
440                return None;
441            }
442            for flag in flags.chars() {
443                match flag {
444                    // -R: recursive listing — `read` of a directory is
445                    // single-level only, but the result is still a useful
446                    // approximation of "what's in this tree".
447                    // -a: show hidden files — `read` of a directory already
448                    // includes hidden files via fs::read_dir(), so this is
449                    // a no-op compared to plain `ls`.
450                    'R' | 'a' => {}
451                    // -l: long format. Shows size, mtime, permissions, owner.
452                    // `read` returns directory entries (no metadata) or file
453                    // contents (not metadata at all). Rewriting drops the
454                    // info the user asked for, so fall through to real bash.
455                    // Reported by user dogfooding the v0.18 bash experimentals.
456                    _ => return None,
457                }
458            }
459        } else if path.is_none() {
460            path = Some(arg.clone());
461        } else {
462            return None;
463        }
464    }
465
466    // Even without -l, `ls FILE` and `read FILE` have entirely different
467    // semantics: `ls FILE` echoes the filename, `read FILE` dumps the file
468    // contents. The rewrite is only safe when the path resolves to a
469    // directory (or is missing/cwd, where `read` of cwd also makes sense).
470    // Stat the path and fall through to bash for files.
471    let target = path.clone().unwrap_or_else(|| ".".to_string());
472    if let Ok(metadata) = std::fs::metadata(&target) {
473        if !metadata.is_dir() {
474            return None;
475        }
476    }
477    // Path doesn't exist (yet)? Let bash handle the error itself — its
478    // wording is well-known to agents and we don't gain anything by
479    // rewriting a guaranteed-failing rewrite call.
480    else if path.is_some() {
481        return None;
482    }
483
484    Some(json!({ "file": target }))
485}
486
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::find_request;

    // `find` rewrites: the starting point becomes the glob `path` (trailing `/`
    // trimmed) unless it is `.`, and `-name GLOB` becomes `**/GLOB`.

    #[test]
    fn find_absolute_path_uses_glob_path_arg() {
        let expected = json!({ "path": "/tmp/foo", "pattern": "**/*.ts" });
        let actual = find_request(r#"find /tmp/foo -name "*.ts" -type f"#);
        assert_eq!(actual, Some(expected));
    }

    #[test]
    fn find_dot_keeps_project_root_relative_pattern() {
        let expected = json!({ "pattern": "**/*.ts" });
        let actual = find_request(r#"find . -name "*.ts" -type f"#);
        assert_eq!(actual, Some(expected));
    }

    #[test]
    fn find_relative_path_uses_glob_path_arg() {
        let expected = json!({ "path": "./src", "pattern": "**/*.go" });
        let actual = find_request(r#"find ./src -name "*.go""#);
        assert_eq!(actual, Some(expected));
    }

    #[test]
    fn find_trims_trailing_slash_from_path_arg() {
        let expected = json!({ "path": "/tmp/foo", "pattern": "**/*.ts" });
        let actual = find_request(r#"find /tmp/foo/ -name "*.ts""#);
        assert_eq!(actual, Some(expected));
    }
}