Skip to main content

aft/bash_rewrite/
rules.rs

1use serde_json::{json, Value};
2
3use crate::bash_rewrite::footer::add_footer;
4use crate::bash_rewrite::parser::parse;
5use crate::bash_rewrite::RewriteRule;
6use crate::context::AppContext;
7use crate::protocol::{RawRequest, Response};
8
/// Rewrite rule for plain `grep` searches; served by the native grep tool.
pub struct GrepRule;
/// Rewrite rule for plain `rg` searches; served by the native grep tool.
pub struct RgRule;
/// Rewrite rule for `find <path> -name <glob> [-type f]`; served by the native glob tool.
pub struct FindRule;
/// Rewrite rule for `cat <file>`; served by the native read tool.
pub struct CatRule;
/// Rewrite rule for `cat <<heredoc >> file` and `echo ... >> file`; served as an
/// `edit_match` append.
pub struct CatAppendRule;
/// Rewrite rule for `sed -n '<range>p' <file>` line prints; served by a ranged read.
pub struct SedRule;
/// Rewrite rule for `ls [-l|-R|-a...] [path]`; served by the native read tool.
pub struct LsRule;
16
17impl RewriteRule for GrepRule {
18    fn name(&self) -> &'static str {
19        "grep"
20    }
21
22    fn matches(&self, command: &str) -> bool {
23        grep_request(command, "grep").is_some()
24    }
25
26    fn rewrite(
27        &self,
28        command: &str,
29        session_id: Option<&str>,
30        ctx: &AppContext,
31    ) -> Result<Response, String> {
32        let params = grep_request(command, "grep").ok_or("not a grep rewrite")?;
33        try_call_and_footer(
34            crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
35            command,
36            "grep",
37        )
38    }
39}
40
41impl RewriteRule for RgRule {
42    fn name(&self) -> &'static str {
43        "rg"
44    }
45
46    fn matches(&self, command: &str) -> bool {
47        grep_request(command, "rg").is_some()
48    }
49
50    fn rewrite(
51        &self,
52        command: &str,
53        session_id: Option<&str>,
54        ctx: &AppContext,
55    ) -> Result<Response, String> {
56        let params = grep_request(command, "rg").ok_or("not an rg rewrite")?;
57        try_call_and_footer(
58            crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
59            command,
60            "grep",
61        )
62    }
63}
64
65impl RewriteRule for FindRule {
66    fn name(&self) -> &'static str {
67        "find"
68    }
69
70    fn matches(&self, command: &str) -> bool {
71        find_request(command).is_some()
72    }
73
74    fn rewrite(
75        &self,
76        command: &str,
77        session_id: Option<&str>,
78        ctx: &AppContext,
79    ) -> Result<Response, String> {
80        let params = find_request(command).ok_or("not a find rewrite")?;
81        try_call_and_footer(
82            crate::commands::glob::handle_glob(&request("glob", params, session_id), ctx),
83            command,
84            "glob",
85        )
86    }
87}
88
89impl RewriteRule for CatRule {
90    fn name(&self) -> &'static str {
91        "cat"
92    }
93
94    fn matches(&self, command: &str) -> bool {
95        cat_read_request(command).is_some()
96    }
97
98    fn rewrite(
99        &self,
100        command: &str,
101        session_id: Option<&str>,
102        ctx: &AppContext,
103    ) -> Result<Response, String> {
104        let params = cat_read_request(command).ok_or("not a cat rewrite")?;
105        try_call_and_footer(
106            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
107            command,
108            "read",
109        )
110    }
111}
112
113impl RewriteRule for CatAppendRule {
114    fn name(&self) -> &'static str {
115        "cat_append"
116    }
117
118    fn matches(&self, command: &str) -> bool {
119        append_request(command).is_some()
120    }
121
122    fn rewrite(
123        &self,
124        command: &str,
125        session_id: Option<&str>,
126        ctx: &AppContext,
127    ) -> Result<Response, String> {
128        let params = append_request(command).ok_or("not an append rewrite")?;
129        try_call_and_footer(
130            crate::commands::edit_match::handle_edit_match(
131                &request("edit_match", params, session_id),
132                ctx,
133            ),
134            command,
135            "edit",
136        )
137    }
138}
139
140impl RewriteRule for SedRule {
141    fn name(&self) -> &'static str {
142        "sed"
143    }
144
145    fn matches(&self, command: &str) -> bool {
146        sed_request(command).is_some()
147    }
148
149    fn rewrite(
150        &self,
151        command: &str,
152        session_id: Option<&str>,
153        ctx: &AppContext,
154    ) -> Result<Response, String> {
155        let params = sed_request(command).ok_or("not a sed rewrite")?;
156        try_call_and_footer(
157            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
158            command,
159            "read",
160        )
161    }
162}
163
164impl RewriteRule for LsRule {
165    fn name(&self) -> &'static str {
166        "ls"
167    }
168
169    fn matches(&self, command: &str) -> bool {
170        ls_request(command).is_some()
171    }
172
173    fn rewrite(
174        &self,
175        command: &str,
176        session_id: Option<&str>,
177        ctx: &AppContext,
178    ) -> Result<Response, String> {
179        let params = ls_request(command).ok_or("not an ls rewrite")?;
180        try_call_and_footer(
181            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
182            command,
183            "read",
184        )
185    }
186}
187
188fn request(command: &str, params: Value, session_id: Option<&str>) -> RawRequest {
189    RawRequest {
190        id: "bash_rewrite".to_string(),
191        command: command.to_string(),
192        lsp_hints: None,
193        session_id: session_id.map(str::to_string),
194        params,
195    }
196}
197
198/// Run an underlying tool through the rewrite path. If the tool returned
199/// `success: false`, propagate as Err so dispatch falls through to actual bash
200/// — the agent's intent was bash, the rewrite is a transparent optimization.
201/// Returning a wrapped error response would surprise the agent (e.g. read's
202/// `outside project root` rejecting a sed that bash would have allowed).
203fn try_call_and_footer(
204    response: Response,
205    original_command: &str,
206    replacement_tool: &str,
207) -> Result<Response, String> {
208    if !response.success {
209        let message = response
210            .data
211            .get("message")
212            .and_then(Value::as_str)
213            .or_else(|| response.data.get("code").and_then(Value::as_str))
214            .unwrap_or("error");
215        return Err(format!("{} declined: {}", replacement_tool, message));
216    }
217    Ok(call_and_footer(
218        response,
219        original_command,
220        replacement_tool,
221    ))
222}
223
224fn call_and_footer(
225    mut response: Response,
226    original_command: &str,
227    replacement_tool: &str,
228) -> Response {
229    let output = response_output(&response.data);
230    let output = add_footer(&output, original_command, replacement_tool);
231
232    if let Some(object) = response.data.as_object_mut() {
233        object.insert("output".to_string(), Value::String(output.clone()));
234
235        for key in ["text", "content", "message"] {
236            if object.get(key).is_some_and(Value::is_string) {
237                object.insert(key.to_string(), Value::String(output.clone()));
238                break;
239            }
240        }
241    } else {
242        response.data = json!({ "output": output });
243    }
244
245    response
246}
247
248fn response_output(data: &Value) -> String {
249    if let Some(output) = data.get("output").and_then(Value::as_str) {
250        return output.to_string();
251    }
252    if let Some(text) = data.get("text").and_then(Value::as_str) {
253        return text.to_string();
254    }
255    if let Some(content) = data.get("content").and_then(Value::as_str) {
256        return content.to_string();
257    }
258    if let Some(message) = data.get("message").and_then(Value::as_str) {
259        return message.to_string();
260    }
261    if let Some(entries) = data.get("entries").and_then(Value::as_array) {
262        return entries
263            .iter()
264            .filter_map(Value::as_str)
265            .collect::<Vec<_>>()
266            .join("\n");
267    }
268    serde_json::to_string_pretty(data).unwrap_or_else(|_| data.to_string())
269}
270
271fn grep_request(command: &str, binary: &str) -> Option<Value> {
272    let parsed = parse(command)?;
273    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != binary {
274        return None;
275    }
276
277    let mut case_sensitive = true;
278    let mut word_match = false;
279    let mut index = 1;
280
281    while let Some(arg) = parsed.args.get(index) {
282        if !arg.starts_with('-') || arg == "-" {
283            break;
284        }
285        for flag in arg[1..].chars() {
286            match flag {
287                'n' | 'r' => {}
288                'i' => case_sensitive = false,
289                'w' => word_match = true,
290                _ => return None,
291            }
292        }
293        index += 1;
294    }
295
296    let pattern = parsed.args.get(index)?.clone();
297    let path = parsed.args.get(index + 1).cloned();
298    if parsed.args.len() > index + 2 {
299        return None;
300    }
301
302    let pattern = if word_match {
303        format!(r"\b(?:{})\b", pattern)
304    } else {
305        pattern
306    };
307
308    let mut params = json!({
309        "pattern": pattern,
310        "case_sensitive": case_sensitive,
311        "max_results": 100,
312    });
313    if let Some(path) = path {
314        params["path"] = json!(path);
315    }
316    Some(params)
317}
318
319fn find_request(command: &str) -> Option<Value> {
320    let parsed = parse(command)?;
321    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "find" {
322        return None;
323    }
324    if parsed.args.len() != 4 && parsed.args.len() != 6 {
325        return None;
326    }
327
328    let path = parsed.args.get(1)?.clone();
329    let mut name = None;
330    let mut saw_type_file = false;
331    let mut index = 2;
332
333    while index < parsed.args.len() {
334        match parsed.args[index].as_str() {
335            "-name" if name.is_none() && index + 1 < parsed.args.len() => {
336                name = Some(parsed.args[index + 1].clone());
337                index += 2;
338            }
339            "-type" if !saw_type_file && index + 1 < parsed.args.len() => {
340                if parsed.args[index + 1] != "f" {
341                    return None;
342                }
343                saw_type_file = true;
344                index += 2;
345            }
346            _ => return None,
347        }
348    }
349
350    let name = name?;
351    let pattern = if path == "." {
352        format!("**/{name}")
353    } else {
354        format!("{}/**/{name}", path.trim_end_matches('/'))
355    };
356
357    Some(json!({ "pattern": pattern }))
358}
359
360fn cat_read_request(command: &str) -> Option<Value> {
361    let parsed = parse(command)?;
362    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
363        return None;
364    }
365    if parsed.args.len() != 2 || parsed.args.first()? != "cat" {
366        return None;
367    }
368    Some(json!({ "file": parsed.args[1] }))
369}
370
371fn append_request(command: &str) -> Option<Value> {
372    let parsed = parse(command)?;
373    let file = parsed.appends_to.clone()?;
374
375    let append_content = if parsed.args == ["cat"] {
376        parsed.heredoc?
377    } else if parsed.heredoc.is_none()
378        && parsed.args.first().is_some_and(|arg| arg == "echo")
379        && parsed.args.len() >= 2
380        && !parsed.args[1].starts_with('-')
381    {
382        format!("{}\n", parsed.args[1..].join(" "))
383    } else {
384        return None;
385    };
386
387    Some(json!({
388        "op": "append",
389        "file": file,
390        "append_content": append_content,
391        "create_dirs": true,
392    }))
393}
394
395fn sed_request(command: &str) -> Option<Value> {
396    let parsed = parse(command)?;
397    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
398        return None;
399    }
400    if parsed.args.len() != 4 || parsed.args.first()? != "sed" || parsed.args[1] != "-n" {
401        return None;
402    }
403
404    let range = parsed.args[2].strip_suffix('p')?;
405    let (start, end) = range.split_once(',')?;
406    let start_line = start.parse::<u32>().ok()?;
407    let end_line = end.parse::<u32>().ok()?;
408    if start_line == 0 || end_line < start_line {
409        return None;
410    }
411
412    Some(json!({
413        "file": parsed.args[3],
414        "start_line": start_line,
415        "end_line": end_line,
416    }))
417}
418
419fn ls_request(command: &str) -> Option<Value> {
420    let parsed = parse(command)?;
421    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "ls" {
422        return None;
423    }
424
425    let mut path = None;
426    for arg in parsed.args.iter().skip(1) {
427        if let Some(flags) = arg.strip_prefix('-') {
428            if flags.is_empty() || !flags.chars().all(|flag| matches!(flag, 'l' | 'R' | 'a')) {
429                return None;
430            }
431        } else if path.is_none() {
432            path = Some(arg.clone());
433        } else {
434            return None;
435        }
436    }
437
438    Some(json!({ "file": path.unwrap_or_else(|| ".".to_string()) }))
439}