Skip to main content

aft/bash_rewrite/
rules.rs

1use serde_json::{json, Value};
2
/// Compiled-size ceiling (10 MiB, i.e. `10 << 20` bytes) applied when
/// validating a grep/rg pattern; a pattern whose compiled program would exceed
/// it fails to build and the command is therefore not rewritten.
const REGEX_SIZE_LIMIT: usize = 10 << 20;
4
5use crate::bash_rewrite::footer::add_footer;
6use crate::bash_rewrite::parser::parse;
7use crate::bash_rewrite::RewriteRule;
8use crate::context::AppContext;
9use crate::protocol::{RawRequest, Response};
10
/// Rewrite rule for simple `grep` searches, served by the native grep tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct GrepRule;

/// Rewrite rule for simple `rg` searches, served by the native grep tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct RgRule;

/// Rewrite rule for `find PATH -name NAME [-type f]`, served by the native glob tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct FindRule;

/// Rewrite rule for `cat FILE`, served by the native read tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatRule;

/// Rewrite rule for `cat <<EOF >> FILE` / `echo ... >> FILE`, served by the
/// native edit tool's append operation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatAppendRule;

/// Rewrite rule for `sed -n '<range>p' FILE`, served by a ranged native read.
#[derive(Debug, Clone, Copy, Default)]
pub struct SedRule;

/// Rewrite rule for `ls [-lRa] [PATH]`, served by the native read tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct LsRule;
18
19impl RewriteRule for GrepRule {
20    fn name(&self) -> &'static str {
21        "grep"
22    }
23
24    fn matches(&self, command: &str) -> bool {
25        grep_request(command, "grep").is_some()
26    }
27
28    fn rewrite(
29        &self,
30        command: &str,
31        session_id: Option<&str>,
32        ctx: &AppContext,
33    ) -> Result<Response, String> {
34        let params = grep_request(command, "grep").ok_or("not a grep rewrite")?;
35        try_call_and_footer(crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx), "grep")
36    }
37}
38
39impl RewriteRule for RgRule {
40    fn name(&self) -> &'static str {
41        "rg"
42    }
43
44    fn matches(&self, command: &str) -> bool {
45        grep_request(command, "rg").is_some()
46    }
47
48    fn rewrite(
49        &self,
50        command: &str,
51        session_id: Option<&str>,
52        ctx: &AppContext,
53    ) -> Result<Response, String> {
54        let params = grep_request(command, "rg").ok_or("not an rg rewrite")?;
55        try_call_and_footer(crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx), "grep")
56    }
57}
58
59impl RewriteRule for FindRule {
60    fn name(&self) -> &'static str {
61        "find"
62    }
63
64    fn matches(&self, command: &str) -> bool {
65        find_request(command).is_some()
66    }
67
68    fn rewrite(
69        &self,
70        command: &str,
71        session_id: Option<&str>,
72        ctx: &AppContext,
73    ) -> Result<Response, String> {
74        let params = find_request(command).ok_or("not a find rewrite")?;
75        try_call_and_footer(crate::commands::glob::handle_glob(&request("glob", params, session_id), ctx), "glob")
76    }
77}
78
79impl RewriteRule for CatRule {
80    fn name(&self) -> &'static str {
81        "cat"
82    }
83
84    fn matches(&self, command: &str) -> bool {
85        cat_read_request(command).is_some()
86    }
87
88    fn rewrite(
89        &self,
90        command: &str,
91        session_id: Option<&str>,
92        ctx: &AppContext,
93    ) -> Result<Response, String> {
94        let params = cat_read_request(command).ok_or("not a cat rewrite")?;
95        try_call_and_footer(crate::commands::read::handle_read(&request("read", params, session_id), ctx), "read")
96    }
97}
98
99impl RewriteRule for CatAppendRule {
100    fn name(&self) -> &'static str {
101        "cat_append"
102    }
103
104    fn matches(&self, command: &str) -> bool {
105        append_request(command).is_some()
106    }
107
108    fn rewrite(
109        &self,
110        command: &str,
111        session_id: Option<&str>,
112        ctx: &AppContext,
113    ) -> Result<Response, String> {
114        let params = append_request(command).ok_or("not an append rewrite")?;
115        try_call_and_footer(crate::commands::edit_match::handle_edit_match(
116            &request("edit_match", params, session_id),
117            ctx,
118        ), "edit")
119    }
120}
121
122impl RewriteRule for SedRule {
123    fn name(&self) -> &'static str {
124        "sed"
125    }
126
127    fn matches(&self, command: &str) -> bool {
128        sed_request(command).is_some()
129    }
130
131    fn rewrite(
132        &self,
133        command: &str,
134        session_id: Option<&str>,
135        ctx: &AppContext,
136    ) -> Result<Response, String> {
137        let params = sed_request(command).ok_or("not a sed rewrite")?;
138        try_call_and_footer(crate::commands::read::handle_read(&request("read", params, session_id), ctx), "read")
139    }
140}
141
142impl RewriteRule for LsRule {
143    fn name(&self) -> &'static str {
144        "ls"
145    }
146
147    fn matches(&self, command: &str) -> bool {
148        ls_request(command).is_some()
149    }
150
151    fn rewrite(
152        &self,
153        command: &str,
154        session_id: Option<&str>,
155        ctx: &AppContext,
156    ) -> Result<Response, String> {
157        let params = ls_request(command).ok_or("not an ls rewrite")?;
158        try_call_and_footer(crate::commands::read::handle_read(&request("read", params, session_id), ctx), "read")
159    }
160}
161
162fn request(command: &str, params: Value, session_id: Option<&str>) -> RawRequest {
163    RawRequest {
164        id: "bash_rewrite".to_string(),
165        command: command.to_string(),
166        lsp_hints: None,
167        session_id: session_id.map(str::to_string),
168        params,
169    }
170}
171
172/// Run an underlying tool through the rewrite path. If the tool returned
173/// `success: false`, propagate as Err so dispatch falls through to actual bash
174/// — the agent's intent was bash, the rewrite is a transparent optimization.
175/// Returning a wrapped error response would surprise the agent (e.g. read's
176/// `outside project root` rejecting a sed that bash would have allowed).
177fn try_call_and_footer(response: Response, replacement_tool: &str) -> Result<Response, String> {
178    if !response.success {
179        let message = response
180            .data
181            .get("message")
182            .and_then(Value::as_str)
183            .or_else(|| response.data.get("code").and_then(Value::as_str))
184            .unwrap_or("error");
185        return Err(format!("{} declined: {}", replacement_tool, message));
186    }
187    Ok(call_and_footer(response, replacement_tool))
188}
189
190fn call_and_footer(mut response: Response, replacement_tool: &str) -> Response {
191    let output = response_output(&response.data);
192    let output = add_footer(&output, replacement_tool);
193
194    if let Some(object) = response.data.as_object_mut() {
195        object.insert("output".to_string(), Value::String(output.clone()));
196
197        for key in ["text", "content", "message"] {
198            if object.get(key).is_some_and(Value::is_string) {
199                object.insert(key.to_string(), Value::String(output.clone()));
200                break;
201            }
202        }
203    } else {
204        response.data = json!({ "output": output });
205    }
206
207    response
208}
209
210fn response_output(data: &Value) -> String {
211    if let Some(output) = data.get("output").and_then(Value::as_str) {
212        return output.to_string();
213    }
214    if let Some(text) = data.get("text").and_then(Value::as_str) {
215        return text.to_string();
216    }
217    if let Some(content) = data.get("content").and_then(Value::as_str) {
218        return content.to_string();
219    }
220    if let Some(message) = data.get("message").and_then(Value::as_str) {
221        return message.to_string();
222    }
223    if let Some(entries) = data.get("entries").and_then(Value::as_array) {
224        return entries
225            .iter()
226            .filter_map(Value::as_str)
227            .collect::<Vec<_>>()
228            .join("\n");
229    }
230    serde_json::to_string_pretty(data).unwrap_or_else(|_| data.to_string())
231}
232
233fn grep_request(command: &str, binary: &str) -> Option<Value> {
234    let parsed = parse(command)?;
235    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != binary {
236        return None;
237    }
238
239    let mut case_sensitive = true;
240    let mut word_match = false;
241    let mut index = 1;
242
243    while let Some(arg) = parsed.args.get(index) {
244        if !arg.starts_with('-') || arg == "-" {
245            break;
246        }
247        for flag in arg[1..].chars() {
248            match flag {
249                'n' | 'r' => {}
250                'i' => case_sensitive = false,
251                'w' => word_match = true,
252                _ => return None,
253            }
254        }
255        index += 1;
256    }
257
258    let pattern = parsed.args.get(index)?.clone();
259    let path = parsed.args.get(index + 1).cloned();
260    if parsed.args.len() > index + 2 {
261        return None;
262    }
263
264    let pattern = if word_match {
265        format!(r"\b(?:{})\b", pattern)
266    } else {
267        pattern
268    };
269
270    if regex::RegexBuilder::new(&pattern)
271        .size_limit(REGEX_SIZE_LIMIT)
272        .build()
273        .is_err()
274    {
275        return None;
276    }
277
278    let mut params = json!({
279        "pattern": pattern,
280        "case_sensitive": case_sensitive,
281        "max_results": 100,
282    });
283    if let Some(path) = path {
284        params["path"] = json!(path);
285    }
286    Some(params)
287}
288
289fn find_request(command: &str) -> Option<Value> {
290    let parsed = parse(command)?;
291    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "find" {
292        return None;
293    }
294    if parsed.args.len() != 4 && parsed.args.len() != 6 {
295        return None;
296    }
297
298    let path = parsed.args.get(1)?.clone();
299    let mut name = None;
300    let mut saw_type_file = false;
301    let mut index = 2;
302
303    while index < parsed.args.len() {
304        match parsed.args[index].as_str() {
305            "-name" if name.is_none() && index + 1 < parsed.args.len() => {
306                name = Some(parsed.args[index + 1].clone());
307                index += 2;
308            }
309            "-type" if !saw_type_file && index + 1 < parsed.args.len() => {
310                if parsed.args[index + 1] != "f" {
311                    return None;
312                }
313                saw_type_file = true;
314                index += 2;
315            }
316            _ => return None,
317        }
318    }
319
320    let name = name?;
321    let pattern = if path == "." {
322        format!("**/{name}")
323    } else {
324        format!("{}/**/{name}", path.trim_end_matches('/'))
325    };
326
327    Some(json!({ "pattern": pattern }))
328}
329
330fn cat_read_request(command: &str) -> Option<Value> {
331    let parsed = parse(command)?;
332    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
333        return None;
334    }
335    if parsed.args.len() != 2 || parsed.args.first()? != "cat" {
336        return None;
337    }
338    Some(json!({ "file": parsed.args[1] }))
339}
340
341fn append_request(command: &str) -> Option<Value> {
342    let parsed = parse(command)?;
343    let file = parsed.appends_to.clone()?;
344
345    let append_content = if parsed.args == ["cat"] {
346        parsed.heredoc?
347    } else if parsed.heredoc.is_none()
348        && parsed.args.first().is_some_and(|arg| arg == "echo")
349        && parsed.args.len() >= 2
350        && !parsed.args[1].starts_with('-')
351    {
352        format!("{}\n", parsed.args[1..].join(" "))
353    } else {
354        return None;
355    };
356
357    Some(json!({
358        "op": "append",
359        "file": file,
360        "append_content": append_content,
361        "create_dirs": true,
362    }))
363}
364
365fn sed_request(command: &str) -> Option<Value> {
366    let parsed = parse(command)?;
367    if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
368        return None;
369    }
370    if parsed.args.len() != 4 || parsed.args.first()? != "sed" || parsed.args[1] != "-n" {
371        return None;
372    }
373
374    let range = parsed.args[2].strip_suffix('p')?;
375    let (start, end) = range.split_once(',')?;
376    let start_line = start.parse::<u32>().ok()?;
377    let end_line = end.parse::<u32>().ok()?;
378    if start_line == 0 || end_line < start_line {
379        return None;
380    }
381
382    Some(json!({
383        "file": parsed.args[3],
384        "start_line": start_line,
385        "end_line": end_line,
386    }))
387}
388
389fn ls_request(command: &str) -> Option<Value> {
390    let parsed = parse(command)?;
391    if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "ls" {
392        return None;
393    }
394
395    let mut path = None;
396    for arg in parsed.args.iter().skip(1) {
397        if let Some(flags) = arg.strip_prefix('-') {
398            if flags.is_empty() || !flags.chars().all(|flag| matches!(flag, 'l' | 'R' | 'a')) {
399                return None;
400            }
401        } else if path.is_none() {
402            path = Some(arg.clone());
403        } else {
404            return None;
405        }
406    }
407
408    Some(json!({ "file": path.unwrap_or_else(|| ".".to_string()) }))
409}