Skip to main content

aft/bash_rewrite/
rules.rs

1use serde_json::{json, Value};
2
/// Compiled-size cap (10 MiB) applied when validating rewritten grep patterns.
const REGEX_SIZE_LIMIT: usize = 10 << 20;
4
5use crate::bash_rewrite::footer::add_footer;
6use crate::bash_rewrite::parser::parse;
7use crate::bash_rewrite::RewriteRule;
8use crate::context::AppContext;
9use crate::protocol::{RawRequest, Response};
10
/// Rewrite rule serving simple `grep` searches through the grep tool.
#[derive(Debug, Clone, Copy)]
pub struct GrepRule;
/// Rewrite rule serving simple `rg` searches through the grep tool.
#[derive(Debug, Clone, Copy)]
pub struct RgRule;
/// Rewrite rule serving `find PATH -name NAME [-type f]` through the glob tool.
#[derive(Debug, Clone, Copy)]
pub struct FindRule;
/// Rewrite rule serving `cat FILE` through the read tool.
#[derive(Debug, Clone, Copy)]
pub struct CatRule;
/// Rewrite rule serving `cat >> FILE <<heredoc` / `echo ... >> FILE` through edit_match.
#[derive(Debug, Clone, Copy)]
pub struct CatAppendRule;
/// Rewrite rule serving `sed -n 'A,Bp' FILE` through the read tool.
#[derive(Debug, Clone, Copy)]
pub struct SedRule;
/// Rewrite rule serving simple `ls` listings through the read tool.
#[derive(Debug, Clone, Copy)]
pub struct LsRule;
18
19impl RewriteRule for GrepRule {
20    fn name(&self) -> &'static str {
21        "grep"
22    }
23
24    fn matches(&self, command: &str) -> bool {
25        grep_request(command, "grep").is_some()
26    }
27
28    fn rewrite(
29        &self,
30        command: &str,
31        session_id: Option<&str>,
32        ctx: &AppContext,
33    ) -> Result<Response, String> {
34        let params = grep_request(command, "grep").ok_or("not a grep rewrite")?;
35        try_call_and_footer(
36            crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
37            command,
38            "grep",
39        )
40    }
41}
42
43impl RewriteRule for RgRule {
44    fn name(&self) -> &'static str {
45        "rg"
46    }
47
48    fn matches(&self, command: &str) -> bool {
49        grep_request(command, "rg").is_some()
50    }
51
52    fn rewrite(
53        &self,
54        command: &str,
55        session_id: Option<&str>,
56        ctx: &AppContext,
57    ) -> Result<Response, String> {
58        let params = grep_request(command, "rg").ok_or("not an rg rewrite")?;
59        try_call_and_footer(
60            crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
61            command,
62            "grep",
63        )
64    }
65}
66
67impl RewriteRule for FindRule {
68    fn name(&self) -> &'static str {
69        "find"
70    }
71
72    fn matches(&self, command: &str) -> bool {
73        find_request(command).is_some()
74    }
75
76    fn rewrite(
77        &self,
78        command: &str,
79        session_id: Option<&str>,
80        ctx: &AppContext,
81    ) -> Result<Response, String> {
82        let params = find_request(command).ok_or("not a find rewrite")?;
83        try_call_and_footer(
84            crate::commands::glob::handle_glob(&request("glob", params, session_id), ctx),
85            command,
86            "glob",
87        )
88    }
89}
90
91impl RewriteRule for CatRule {
92    fn name(&self) -> &'static str {
93        "cat"
94    }
95
96    fn matches(&self, command: &str) -> bool {
97        cat_read_request(command).is_some()
98    }
99
100    fn rewrite(
101        &self,
102        command: &str,
103        session_id: Option<&str>,
104        ctx: &AppContext,
105    ) -> Result<Response, String> {
106        let params = cat_read_request(command).ok_or("not a cat rewrite")?;
107        try_call_and_footer(
108            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
109            command,
110            "read",
111        )
112    }
113}
114
115impl RewriteRule for CatAppendRule {
116    fn name(&self) -> &'static str {
117        "cat_append"
118    }
119
120    fn matches(&self, command: &str) -> bool {
121        append_request(command).is_some()
122    }
123
124    fn rewrite(
125        &self,
126        command: &str,
127        session_id: Option<&str>,
128        ctx: &AppContext,
129    ) -> Result<Response, String> {
130        let params = append_request(command).ok_or("not an append rewrite")?;
131        try_call_and_footer(
132            crate::commands::edit_match::handle_edit_match(
133                &request("edit_match", params, session_id),
134                ctx,
135            ),
136            command,
137            "edit",
138        )
139    }
140}
141
142impl RewriteRule for SedRule {
143    fn name(&self) -> &'static str {
144        "sed"
145    }
146
147    fn matches(&self, command: &str) -> bool {
148        sed_request(command).is_some()
149    }
150
151    fn rewrite(
152        &self,
153        command: &str,
154        session_id: Option<&str>,
155        ctx: &AppContext,
156    ) -> Result<Response, String> {
157        let params = sed_request(command).ok_or("not a sed rewrite")?;
158        try_call_and_footer(
159            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
160            command,
161            "read",
162        )
163    }
164}
165
166impl RewriteRule for LsRule {
167    fn name(&self) -> &'static str {
168        "ls"
169    }
170
171    fn matches(&self, command: &str) -> bool {
172        ls_request(command).is_some()
173    }
174
175    fn rewrite(
176        &self,
177        command: &str,
178        session_id: Option<&str>,
179        ctx: &AppContext,
180    ) -> Result<Response, String> {
181        let params = ls_request(command).ok_or("not an ls rewrite")?;
182        try_call_and_footer(
183            crate::commands::read::handle_read(&request("read", params, session_id), ctx),
184            command,
185            "read",
186        )
187    }
188}
189
190fn request(command: &str, params: Value, session_id: Option<&str>) -> RawRequest {
191    RawRequest {
192        id: "bash_rewrite".to_string(),
193        command: command.to_string(),
194        lsp_hints: None,
195        session_id: session_id.map(str::to_string),
196        params,
197    }
198}
199
200/// Run an underlying tool through the rewrite path. If the tool returned
201/// `success: false`, propagate as Err so dispatch falls through to actual bash
202/// — the agent's intent was bash, the rewrite is a transparent optimization.
203/// Returning a wrapped error response would surprise the agent (e.g. read's
204/// `outside project root` rejecting a sed that bash would have allowed).
205fn try_call_and_footer(
206    response: Response,
207    original_command: &str,
208    replacement_tool: &str,
209) -> Result<Response, String> {
210    if !response.success {
211        let message = response
212            .data
213            .get("message")
214            .and_then(Value::as_str)
215            .or_else(|| response.data.get("code").and_then(Value::as_str))
216            .unwrap_or("error");
217        return Err(format!("{} declined: {}", replacement_tool, message));
218    }
219    Ok(call_and_footer(
220        response,
221        original_command,
222        replacement_tool,
223    ))
224}
225
226fn call_and_footer(
227    mut response: Response,
228    original_command: &str,
229    replacement_tool: &str,
230) -> Response {
231    let output = response_output(&response.data);
232    let output = add_footer(&output, original_command, replacement_tool);
233
234    if let Some(object) = response.data.as_object_mut() {
235        object.insert("output".to_string(), Value::String(output.clone()));
236
237        for key in ["text", "content", "message"] {
238            if object.get(key).is_some_and(Value::is_string) {
239                object.insert(key.to_string(), Value::String(output.clone()));
240                break;
241            }
242        }
243    } else {
244        response.data = json!({ "output": output });
245    }
246
247    response
248}
249
250fn response_output(data: &Value) -> String {
251    if let Some(output) = data.get("output").and_then(Value::as_str) {
252        return output.to_string();
253    }
254    if let Some(text) = data.get("text").and_then(Value::as_str) {
255        return text.to_string();
256    }
257    if let Some(content) = data.get("content").and_then(Value::as_str) {
258        return content.to_string();
259    }
260    if let Some(message) = data.get("message").and_then(Value::as_str) {
261        return message.to_string();
262    }
263    if let Some(entries) = data.get("entries").and_then(Value::as_array) {
264        return entries
265            .iter()
266            .filter_map(Value::as_str)
267            .collect::<Vec<_>>()
268            .join("\n");
269    }
270    serde_json::to_string_pretty(data).unwrap_or_else(|_| data.to_string())
271}
272
273fn grep_request(command: &str, binary: &str) -> Option<Value> {
274    let parsed = parse(command)?;
275    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != binary {
276        return None;
277    }
278
279    let mut case_sensitive = true;
280    let mut word_match = false;
281    let mut index = 1;
282
283    while let Some(arg) = parsed.args.get(index) {
284        if !arg.starts_with('-') || arg == "-" {
285            break;
286        }
287        for flag in arg[1..].chars() {
288            match flag {
289                'n' | 'r' => {}
290                'i' => case_sensitive = false,
291                'w' => word_match = true,
292                _ => return None,
293            }
294        }
295        index += 1;
296    }
297
298    let pattern = parsed.args.get(index)?.clone();
299    let path = parsed.args.get(index + 1).cloned();
300    if parsed.args.len() > index + 2 {
301        return None;
302    }
303
304    let pattern = if word_match {
305        format!(r"\b(?:{})\b", pattern)
306    } else {
307        pattern
308    };
309
310    if regex::RegexBuilder::new(&pattern)
311        .size_limit(REGEX_SIZE_LIMIT)
312        .build()
313        .is_err()
314    {
315        return None;
316    }
317
318    let mut params = json!({
319        "pattern": pattern,
320        "case_sensitive": case_sensitive,
321        "max_results": 100,
322    });
323    if let Some(path) = path {
324        params["path"] = json!(path);
325    }
326    Some(params)
327}
328
329fn find_request(command: &str) -> Option<Value> {
330    let parsed = parse(command)?;
331    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "find" {
332        return None;
333    }
334    if parsed.args.len() != 4 && parsed.args.len() != 6 {
335        return None;
336    }
337
338    let path = parsed.args.get(1)?.clone();
339    let mut name = None;
340    let mut saw_type_file = false;
341    let mut index = 2;
342
343    while index < parsed.args.len() {
344        match parsed.args[index].as_str() {
345            "-name" if name.is_none() && index + 1 < parsed.args.len() => {
346                name = Some(parsed.args[index + 1].clone());
347                index += 2;
348            }
349            "-type" if !saw_type_file && index + 1 < parsed.args.len() => {
350                if parsed.args[index + 1] != "f" {
351                    return None;
352                }
353                saw_type_file = true;
354                index += 2;
355            }
356            _ => return None,
357        }
358    }
359
360    let name = name?;
361    let pattern = if path == "." {
362        format!("**/{name}")
363    } else {
364        format!("{}/**/{name}", path.trim_end_matches('/'))
365    };
366
367    Some(json!({ "pattern": pattern }))
368}
369
370fn cat_read_request(command: &str) -> Option<Value> {
371    let parsed = parse(command)?;
372    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
373        return None;
374    }
375    if parsed.args.len() != 2 || parsed.args.first()? != "cat" {
376        return None;
377    }
378    Some(json!({ "file": parsed.args[1] }))
379}
380
381fn append_request(command: &str) -> Option<Value> {
382    let parsed = parse(command)?;
383    let file = parsed.appends_to.clone()?;
384
385    let append_content = if parsed.args == ["cat"] {
386        parsed.heredoc?
387    } else if parsed.heredoc.is_none()
388        && parsed.args.first().is_some_and(|arg| arg == "echo")
389        && parsed.args.len() >= 2
390        && !parsed.args[1].starts_with('-')
391    {
392        format!("{}\n", parsed.args[1..].join(" "))
393    } else {
394        return None;
395    };
396
397    Some(json!({
398        "op": "append",
399        "file": file,
400        "append_content": append_content,
401        "create_dirs": true,
402    }))
403}
404
405fn sed_request(command: &str) -> Option<Value> {
406    let parsed = parse(command)?;
407    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
408        return None;
409    }
410    if parsed.args.len() != 4 || parsed.args.first()? != "sed" || parsed.args[1] != "-n" {
411        return None;
412    }
413
414    let range = parsed.args[2].strip_suffix('p')?;
415    let (start, end) = range.split_once(',')?;
416    let start_line = start.parse::<u32>().ok()?;
417    let end_line = end.parse::<u32>().ok()?;
418    if start_line == 0 || end_line < start_line {
419        return None;
420    }
421
422    Some(json!({
423        "file": parsed.args[3],
424        "start_line": start_line,
425        "end_line": end_line,
426    }))
427}
428
429fn ls_request(command: &str) -> Option<Value> {
430    let parsed = parse(command)?;
431    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "ls" {
432        return None;
433    }
434
435    let mut path = None;
436    for arg in parsed.args.iter().skip(1) {
437        if let Some(flags) = arg.strip_prefix('-') {
438            if flags.is_empty() || !flags.chars().all(|flag| matches!(flag, 'l' | 'R' | 'a')) {
439                return None;
440            }
441        } else if path.is_none() {
442            path = Some(arg.clone());
443        } else {
444            return None;
445        }
446    }
447
448    Some(json!({ "file": path.unwrap_or_else(|| ".".to_string()) }))
449}