1use serde_json::{json, Value};
2
/// Compiled-size cap (10 MiB) handed to `regex::RegexBuilder::size_limit` when
/// `grep_request` checks that a grep/rg pattern compiles before rewriting it.
const REGEX_SIZE_LIMIT: usize = 10 * 1024 * 1024;
4
5use crate::bash_rewrite::footer::{add_footer, add_grep_footer};
6use crate::bash_rewrite::parser::parse;
7use crate::bash_rewrite::RewriteRule;
8use crate::context::AppContext;
9use crate::protocol::{RawRequest, Response};
10
/// Rewrites simple `grep` invocations into the native grep command.
pub struct GrepRule;
/// Rewrites simple `rg` invocations into the native grep command.
pub struct RgRule;
/// Rewrites `find <path> -name <glob> [-type f]` into the native glob command.
pub struct FindRule;
/// Rewrites `cat <file>` into the native read command.
pub struct CatRule;
/// Rewrites `cat >> file` heredocs and `echo ... >> file` into an append edit.
pub struct CatAppendRule;
/// Rewrites `sed -n 'N,Mp' file` into a ranged native read.
pub struct SedRule;
/// Rewrites `ls [-R] [-a] [dir]` into a native read of the directory.
pub struct LsRule;
18
19impl RewriteRule for GrepRule {
20 fn name(&self) -> &'static str {
21 "grep"
22 }
23
24 fn matches(&self, command: &str) -> bool {
25 grep_request(command, "grep").is_some()
26 }
27
28 fn rewrite(
29 &self,
30 command: &str,
31 session_id: Option<&str>,
32 ctx: &AppContext,
33 ) -> Result<Response, String> {
34 let params = grep_request(command, "grep").ok_or("not a grep rewrite")?;
35 try_call_and_grep_footer(
36 crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
37 ctx,
38 )
39 }
40}
41
42impl RewriteRule for RgRule {
43 fn name(&self) -> &'static str {
44 "rg"
45 }
46
47 fn matches(&self, command: &str) -> bool {
48 grep_request(command, "rg").is_some()
49 }
50
51 fn rewrite(
52 &self,
53 command: &str,
54 session_id: Option<&str>,
55 ctx: &AppContext,
56 ) -> Result<Response, String> {
57 let params = grep_request(command, "rg").ok_or("not an rg rewrite")?;
58 try_call_and_grep_footer(
59 crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx),
60 ctx,
61 )
62 }
63}
64
65impl RewriteRule for FindRule {
66 fn name(&self) -> &'static str {
67 "find"
68 }
69
70 fn matches(&self, command: &str) -> bool {
71 find_request(command).is_some()
72 }
73
74 fn rewrite(
75 &self,
76 command: &str,
77 session_id: Option<&str>,
78 ctx: &AppContext,
79 ) -> Result<Response, String> {
80 let params = find_request(command).ok_or("not a find rewrite")?;
81 try_call_and_footer(
82 crate::commands::glob::handle_glob(&request("glob", params, session_id), ctx),
83 "glob",
84 )
85 }
86}
87
88impl RewriteRule for CatRule {
89 fn name(&self) -> &'static str {
90 "cat"
91 }
92
93 fn matches(&self, command: &str) -> bool {
94 cat_read_request(command).is_some()
95 }
96
97 fn rewrite(
98 &self,
99 command: &str,
100 session_id: Option<&str>,
101 ctx: &AppContext,
102 ) -> Result<Response, String> {
103 let params = cat_read_request(command).ok_or("not a cat rewrite")?;
104 try_call_and_footer(
105 crate::commands::read::handle_read(&request("read", params, session_id), ctx),
106 "read",
107 )
108 }
109}
110
111impl RewriteRule for CatAppendRule {
112 fn name(&self) -> &'static str {
113 "cat_append"
114 }
115
116 fn matches(&self, command: &str) -> bool {
117 append_request(command).is_some()
118 }
119
120 fn rewrite(
121 &self,
122 command: &str,
123 session_id: Option<&str>,
124 ctx: &AppContext,
125 ) -> Result<Response, String> {
126 let params = append_request(command).ok_or("not an append rewrite")?;
127 try_call_and_footer(
128 crate::commands::edit_match::handle_edit_match(
129 &request("edit_match", params, session_id),
130 ctx,
131 ),
132 "edit",
133 )
134 }
135}
136
137impl RewriteRule for SedRule {
138 fn name(&self) -> &'static str {
139 "sed"
140 }
141
142 fn matches(&self, command: &str) -> bool {
143 sed_request(command).is_some()
144 }
145
146 fn rewrite(
147 &self,
148 command: &str,
149 session_id: Option<&str>,
150 ctx: &AppContext,
151 ) -> Result<Response, String> {
152 let params = sed_request(command).ok_or("not a sed rewrite")?;
153 try_call_and_footer(
154 crate::commands::read::handle_read(&request("read", params, session_id), ctx),
155 "read",
156 )
157 }
158}
159
160impl RewriteRule for LsRule {
161 fn name(&self) -> &'static str {
162 "ls"
163 }
164
165 fn matches(&self, command: &str) -> bool {
166 ls_request(command).is_some()
167 }
168
169 fn rewrite(
170 &self,
171 command: &str,
172 session_id: Option<&str>,
173 ctx: &AppContext,
174 ) -> Result<Response, String> {
175 let params = ls_request(command).ok_or("not an ls rewrite")?;
176 try_call_and_footer(
177 crate::commands::read::handle_read(&request("read", params, session_id), ctx),
178 "read",
179 )
180 }
181}
182
183fn request(command: &str, params: Value, session_id: Option<&str>) -> RawRequest {
184 RawRequest {
185 id: "bash_rewrite".to_string(),
186 command: command.to_string(),
187 lsp_hints: None,
188 session_id: session_id.map(str::to_string),
189 params,
190 }
191}
192
193fn try_call_and_footer(response: Response, replacement_tool: &str) -> Result<Response, String> {
199 if let Some(err) = declined_error(&response, replacement_tool) {
200 return Err(err);
201 }
202 Ok(call_and_footer(response, replacement_tool))
203}
204
205fn try_call_and_grep_footer(response: Response, ctx: &AppContext) -> Result<Response, String> {
208 if let Some(err) = declined_error(&response, "grep") {
209 return Err(err);
210 }
211 let output = response_output(&response.data);
212 let footered = add_grep_footer(&output, ctx.config().aft_search_registered);
213 Ok(apply_footer(response, footered))
214}
215
216fn declined_error(response: &Response, replacement_tool: &str) -> Option<String> {
217 if response.success {
218 return None;
219 }
220 let message = response
221 .data
222 .get("message")
223 .and_then(Value::as_str)
224 .or_else(|| response.data.get("code").and_then(Value::as_str))
225 .unwrap_or("error");
226 Some(format!("{replacement_tool} declined: {message}"))
227}
228
229fn call_and_footer(response: Response, replacement_tool: &str) -> Response {
230 let output = response_output(&response.data);
231 let footered = add_footer(&output, replacement_tool);
232 apply_footer(response, footered)
233}
234
235fn apply_footer(mut response: Response, output: String) -> Response {
236 if let Some(object) = response.data.as_object_mut() {
237 object.insert("output".to_string(), Value::String(output.clone()));
238
239 for key in ["text", "content", "message"] {
240 if object.get(key).is_some_and(Value::is_string) {
241 object.insert(key.to_string(), Value::String(output.clone()));
242 break;
243 }
244 }
245 } else {
246 response.data = json!({ "output": output });
247 }
248
249 response
250}
251
252fn response_output(data: &Value) -> String {
253 if let Some(output) = data.get("output").and_then(Value::as_str) {
254 return output.to_string();
255 }
256 if let Some(text) = data.get("text").and_then(Value::as_str) {
257 return text.to_string();
258 }
259 if let Some(content) = data.get("content").and_then(Value::as_str) {
260 return content.to_string();
261 }
262 if let Some(message) = data.get("message").and_then(Value::as_str) {
263 return message.to_string();
264 }
265 if let Some(entries) = data.get("entries").and_then(Value::as_array) {
266 return entries
267 .iter()
268 .filter_map(Value::as_str)
269 .collect::<Vec<_>>()
270 .join("\n");
271 }
272 serde_json::to_string_pretty(data).unwrap_or_else(|_| data.to_string())
273}
274
275fn grep_request(command: &str, binary: &str) -> Option<Value> {
276 let parsed = parse(command)?;
277 if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != binary {
278 return None;
279 }
280
281 let mut case_sensitive = true;
282 let mut word_match = false;
283 let mut index = 1;
284
285 while let Some(arg) = parsed.args.get(index) {
286 if !arg.starts_with('-') || arg == "-" {
287 break;
288 }
289 for flag in arg[1..].chars() {
290 match flag {
291 'n' | 'r' => {}
292 'i' => case_sensitive = false,
293 'w' => word_match = true,
294 _ => return None,
295 }
296 }
297 index += 1;
298 }
299
300 let pattern = parsed.args.get(index)?.clone();
301 let path = parsed.args.get(index + 1).cloned();
302 if parsed.args.len() > index + 2 {
303 return None;
304 }
305
306 let pattern = if word_match {
307 format!(r"\b(?:{})\b", pattern)
308 } else {
309 pattern
310 };
311
312 if regex::RegexBuilder::new(&pattern)
313 .size_limit(REGEX_SIZE_LIMIT)
314 .build()
315 .is_err()
316 {
317 return None;
318 }
319
320 let mut params = json!({
321 "pattern": pattern,
322 "case_sensitive": case_sensitive,
323 "max_results": 100,
324 });
325 if let Some(path) = path {
326 params["path"] = json!(path);
327 }
328 Some(params)
329}
330
331fn find_request(command: &str) -> Option<Value> {
332 let parsed = parse(command)?;
333 if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "find" {
334 return None;
335 }
336 if parsed.args.len() != 4 && parsed.args.len() != 6 {
337 return None;
338 }
339
340 let path = parsed.args.get(1)?.clone();
341 let mut name = None;
342 let mut saw_type_file = false;
343 let mut index = 2;
344
345 while index < parsed.args.len() {
346 match parsed.args[index].as_str() {
347 "-name" if name.is_none() && index + 1 < parsed.args.len() => {
348 name = Some(parsed.args[index + 1].clone());
349 index += 2;
350 }
351 "-type" if !saw_type_file && index + 1 < parsed.args.len() => {
352 if parsed.args[index + 1] != "f" {
353 return None;
354 }
355 saw_type_file = true;
356 index += 2;
357 }
358 _ => return None,
359 }
360 }
361
362 let name = name?;
363 let pattern = format!("**/{name}");
364 if path == "." {
365 Some(json!({ "pattern": pattern }))
366 } else {
367 let trimmed = path.trim_end_matches('/');
368 if trimmed.is_empty() {
369 None
375 } else {
376 Some(json!({ "path": trimmed, "pattern": pattern }))
377 }
378 }
379}
380
381fn cat_read_request(command: &str) -> Option<Value> {
382 let parsed = parse(command)?;
383 if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
384 return None;
385 }
386 if parsed.args.len() != 2 || parsed.args.first()? != "cat" {
387 return None;
388 }
389 Some(json!({ "file": parsed.args[1] }))
390}
391
392fn append_request(command: &str) -> Option<Value> {
393 let parsed = parse(command)?;
394 let file = parsed.appends_to.clone()?;
395
396 let append_content = if parsed.args == ["cat"] {
397 parsed.heredoc?
398 } else if parsed.heredoc.is_none()
399 && parsed.args.first().is_some_and(|arg| arg == "echo")
400 && parsed.args.len() >= 2
401 && !parsed.args[1].starts_with('-')
402 {
403 format!("{}\n", parsed.args[1..].join(" "))
404 } else {
405 return None;
406 };
407
408 Some(json!({
409 "op": "append",
410 "file": file,
411 "append_content": append_content,
412 "create_dirs": true,
413 }))
414}
415
416fn sed_request(command: &str) -> Option<Value> {
417 let parsed = parse(command)?;
418 if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
419 return None;
420 }
421 if parsed.args.len() != 4 || parsed.args.first()? != "sed" || parsed.args[1] != "-n" {
422 return None;
423 }
424
425 let range = parsed.args[2].strip_suffix('p')?;
426 let (start, end) = range.split_once(',')?;
427 let start_line = start.parse::<u32>().ok()?;
428 let end_line = end.parse::<u32>().ok()?;
429 if start_line == 0 || end_line < start_line {
430 return None;
431 }
432
433 Some(json!({
434 "file": parsed.args[3],
435 "start_line": start_line,
436 "end_line": end_line,
437 }))
438}
439
440fn ls_request(command: &str) -> Option<Value> {
441 let parsed = parse(command)?;
442 if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "ls" {
443 return None;
444 }
445
446 let mut path = None;
447 for arg in parsed.args.iter().skip(1) {
448 if let Some(flags) = arg.strip_prefix('-') {
449 if flags.is_empty() {
450 return None;
451 }
452 for flag in flags.chars() {
453 match flag {
454 'R' | 'a' => {}
461 _ => return None,
467 }
468 }
469 } else if path.is_none() {
470 path = Some(arg.clone());
471 } else {
472 return None;
473 }
474 }
475
476 let target = path.clone().unwrap_or_else(|| ".".to_string());
482 if let Ok(metadata) = std::fs::metadata(&target) {
483 if !metadata.is_dir() {
484 return None;
485 }
486 }
487 else if path.is_some() {
491 return None;
492 }
493
494 Some(json!({ "file": target }))
495}
496
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::find_request;

    #[test]
    fn find_absolute_path_uses_glob_path_arg() {
        let command = "find /tmp/foo -name \"*.ts\" -type f";
        let expected = json!({ "path": "/tmp/foo", "pattern": "**/*.ts" });
        assert_eq!(find_request(command), Some(expected));
    }

    #[test]
    fn find_dot_keeps_project_root_relative_pattern() {
        let command = "find . -name \"*.ts\" -type f";
        let expected = json!({ "pattern": "**/*.ts" });
        assert_eq!(find_request(command), Some(expected));
    }

    #[test]
    fn find_relative_path_uses_glob_path_arg() {
        let command = "find ./src -name \"*.go\"";
        let expected = json!({ "path": "./src", "pattern": "**/*.go" });
        assert_eq!(find_request(command), Some(expected));
    }

    #[test]
    fn find_trims_trailing_slash_from_path_arg() {
        let command = "find /tmp/foo/ -name \"*.ts\"";
        let expected = json!({ "path": "/tmp/foo", "pattern": "**/*.ts" });
        assert_eq!(find_request(command), Some(expected));
    }

    #[test]
    fn find_filesystem_root_is_not_rewritten() {
        // A single slash and a doubled slash both name the filesystem root.
        for command in ["find / -name \"*.rs\"", "find // -name \"*.rs\""] {
            assert_eq!(find_request(command), None, "{command}");
        }
    }
}