1use serde_json::{json, Value};
2
/// Compiled-size limit (10 MiB) handed to `regex::RegexBuilder::size_limit` when
/// validating grep/rg patterns; patterns that exceed it are not rewritten.
const REGEX_SIZE_LIMIT: usize = 10 << 20;
4
5use crate::bash_rewrite::footer::add_footer;
6use crate::bash_rewrite::parser::parse;
7use crate::bash_rewrite::RewriteRule;
8use crate::context::AppContext;
9use crate::protocol::{RawRequest, Response};
10
/// Rewrites `grep` command lines into a call to the grep tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct GrepRule;

/// Rewrites `rg` (ripgrep) command lines into a call to the grep tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct RgRule;

/// Rewrites simple `find <path> -name <glob>` searches into a call to the glob tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct FindRule;

/// Rewrites `cat <file>` into a call to the read tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatRule;

/// Rewrites appends to a file (a heredoc fed to a bare `cat`, or a plain `echo`)
/// into an `edit_match` append operation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CatAppendRule;

/// Rewrites `sed -n` line-range prints into a ranged call to the read tool.
#[derive(Debug, Clone, Copy, Default)]
pub struct SedRule;

/// Rewrites simple `ls` listings into a call to the read tool on the listed path.
#[derive(Debug, Clone, Copy, Default)]
pub struct LsRule;
18
19impl RewriteRule for GrepRule {
20 fn name(&self) -> &'static str {
21 "grep"
22 }
23
24 fn matches(&self, command: &str) -> bool {
25 grep_request(command, "grep").is_some()
26 }
27
28 fn rewrite(
29 &self,
30 command: &str,
31 session_id: Option<&str>,
32 ctx: &AppContext,
33 ) -> Result<Response, String> {
34 let params = grep_request(command, "grep").ok_or("not a grep rewrite")?;
35 try_call_and_footer(crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx), "grep")
36 }
37}
38
39impl RewriteRule for RgRule {
40 fn name(&self) -> &'static str {
41 "rg"
42 }
43
44 fn matches(&self, command: &str) -> bool {
45 grep_request(command, "rg").is_some()
46 }
47
48 fn rewrite(
49 &self,
50 command: &str,
51 session_id: Option<&str>,
52 ctx: &AppContext,
53 ) -> Result<Response, String> {
54 let params = grep_request(command, "rg").ok_or("not an rg rewrite")?;
55 try_call_and_footer(crate::commands::grep::handle_grep(&request("grep", params, session_id), ctx), "grep")
56 }
57}
58
59impl RewriteRule for FindRule {
60 fn name(&self) -> &'static str {
61 "find"
62 }
63
64 fn matches(&self, command: &str) -> bool {
65 find_request(command).is_some()
66 }
67
68 fn rewrite(
69 &self,
70 command: &str,
71 session_id: Option<&str>,
72 ctx: &AppContext,
73 ) -> Result<Response, String> {
74 let params = find_request(command).ok_or("not a find rewrite")?;
75 try_call_and_footer(crate::commands::glob::handle_glob(&request("glob", params, session_id), ctx), "glob")
76 }
77}
78
79impl RewriteRule for CatRule {
80 fn name(&self) -> &'static str {
81 "cat"
82 }
83
84 fn matches(&self, command: &str) -> bool {
85 cat_read_request(command).is_some()
86 }
87
88 fn rewrite(
89 &self,
90 command: &str,
91 session_id: Option<&str>,
92 ctx: &AppContext,
93 ) -> Result<Response, String> {
94 let params = cat_read_request(command).ok_or("not a cat rewrite")?;
95 try_call_and_footer(crate::commands::read::handle_read(&request("read", params, session_id), ctx), "read")
96 }
97}
98
99impl RewriteRule for CatAppendRule {
100 fn name(&self) -> &'static str {
101 "cat_append"
102 }
103
104 fn matches(&self, command: &str) -> bool {
105 append_request(command).is_some()
106 }
107
108 fn rewrite(
109 &self,
110 command: &str,
111 session_id: Option<&str>,
112 ctx: &AppContext,
113 ) -> Result<Response, String> {
114 let params = append_request(command).ok_or("not an append rewrite")?;
115 try_call_and_footer(crate::commands::edit_match::handle_edit_match(
116 &request("edit_match", params, session_id),
117 ctx,
118 ), "edit")
119 }
120}
121
122impl RewriteRule for SedRule {
123 fn name(&self) -> &'static str {
124 "sed"
125 }
126
127 fn matches(&self, command: &str) -> bool {
128 sed_request(command).is_some()
129 }
130
131 fn rewrite(
132 &self,
133 command: &str,
134 session_id: Option<&str>,
135 ctx: &AppContext,
136 ) -> Result<Response, String> {
137 let params = sed_request(command).ok_or("not a sed rewrite")?;
138 try_call_and_footer(crate::commands::read::handle_read(&request("read", params, session_id), ctx), "read")
139 }
140}
141
142impl RewriteRule for LsRule {
143 fn name(&self) -> &'static str {
144 "ls"
145 }
146
147 fn matches(&self, command: &str) -> bool {
148 ls_request(command).is_some()
149 }
150
151 fn rewrite(
152 &self,
153 command: &str,
154 session_id: Option<&str>,
155 ctx: &AppContext,
156 ) -> Result<Response, String> {
157 let params = ls_request(command).ok_or("not an ls rewrite")?;
158 try_call_and_footer(crate::commands::read::handle_read(&request("read", params, session_id), ctx), "read")
159 }
160}
161
162fn request(command: &str, params: Value, session_id: Option<&str>) -> RawRequest {
163 RawRequest {
164 id: "bash_rewrite".to_string(),
165 command: command.to_string(),
166 lsp_hints: None,
167 session_id: session_id.map(str::to_string),
168 params,
169 }
170}
171
172fn try_call_and_footer(response: Response, replacement_tool: &str) -> Result<Response, String> {
178 if !response.success {
179 let message = response
180 .data
181 .get("message")
182 .and_then(Value::as_str)
183 .or_else(|| response.data.get("code").and_then(Value::as_str))
184 .unwrap_or("error");
185 return Err(format!("{} declined: {}", replacement_tool, message));
186 }
187 Ok(call_and_footer(response, replacement_tool))
188}
189
190fn call_and_footer(mut response: Response, replacement_tool: &str) -> Response {
191 let output = response_output(&response.data);
192 let output = add_footer(&output, replacement_tool);
193
194 if let Some(object) = response.data.as_object_mut() {
195 object.insert("output".to_string(), Value::String(output.clone()));
196
197 for key in ["text", "content", "message"] {
198 if object.get(key).is_some_and(Value::is_string) {
199 object.insert(key.to_string(), Value::String(output.clone()));
200 break;
201 }
202 }
203 } else {
204 response.data = json!({ "output": output });
205 }
206
207 response
208}
209
210fn response_output(data: &Value) -> String {
211 if let Some(output) = data.get("output").and_then(Value::as_str) {
212 return output.to_string();
213 }
214 if let Some(text) = data.get("text").and_then(Value::as_str) {
215 return text.to_string();
216 }
217 if let Some(content) = data.get("content").and_then(Value::as_str) {
218 return content.to_string();
219 }
220 if let Some(message) = data.get("message").and_then(Value::as_str) {
221 return message.to_string();
222 }
223 if let Some(entries) = data.get("entries").and_then(Value::as_array) {
224 return entries
225 .iter()
226 .filter_map(Value::as_str)
227 .collect::<Vec<_>>()
228 .join("\n");
229 }
230 serde_json::to_string_pretty(data).unwrap_or_else(|_| data.to_string())
231}
232
233fn grep_request(command: &str, binary: &str) -> Option<Value> {
234 let parsed = parse(command)?;
235 if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != binary {
236 return None;
237 }
238
239 let mut case_sensitive = true;
240 let mut word_match = false;
241 let mut index = 1;
242
243 while let Some(arg) = parsed.args.get(index) {
244 if !arg.starts_with('-') || arg == "-" {
245 break;
246 }
247 for flag in arg[1..].chars() {
248 match flag {
249 'n' | 'r' => {}
250 'i' => case_sensitive = false,
251 'w' => word_match = true,
252 _ => return None,
253 }
254 }
255 index += 1;
256 }
257
258 let pattern = parsed.args.get(index)?.clone();
259 let path = parsed.args.get(index + 1).cloned();
260 if parsed.args.len() > index + 2 {
261 return None;
262 }
263
264 let pattern = if word_match {
265 format!(r"\b(?:{})\b", pattern)
266 } else {
267 pattern
268 };
269
270 if regex::RegexBuilder::new(&pattern)
271 .size_limit(REGEX_SIZE_LIMIT)
272 .build()
273 .is_err()
274 {
275 return None;
276 }
277
278 let mut params = json!({
279 "pattern": pattern,
280 "case_sensitive": case_sensitive,
281 "max_results": 100,
282 });
283 if let Some(path) = path {
284 params["path"] = json!(path);
285 }
286 Some(params)
287}
288
289fn find_request(command: &str) -> Option<Value> {
290 let parsed = parse(command)?;
291 if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "find" {
292 return None;
293 }
294 if parsed.args.len() != 4 && parsed.args.len() != 6 {
295 return None;
296 }
297
298 let path = parsed.args.get(1)?.clone();
299 let mut name = None;
300 let mut saw_type_file = false;
301 let mut index = 2;
302
303 while index < parsed.args.len() {
304 match parsed.args[index].as_str() {
305 "-name" if name.is_none() && index + 1 < parsed.args.len() => {
306 name = Some(parsed.args[index + 1].clone());
307 index += 2;
308 }
309 "-type" if !saw_type_file && index + 1 < parsed.args.len() => {
310 if parsed.args[index + 1] != "f" {
311 return None;
312 }
313 saw_type_file = true;
314 index += 2;
315 }
316 _ => return None,
317 }
318 }
319
320 let name = name?;
321 let pattern = if path == "." {
322 format!("**/{name}")
323 } else {
324 format!("{}/**/{name}", path.trim_end_matches('/'))
325 };
326
327 Some(json!({ "pattern": pattern }))
328}
329
330fn cat_read_request(command: &str) -> Option<Value> {
331 let parsed = parse(command)?;
332 if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
333 return None;
334 }
335 if parsed.args.len() != 2 || parsed.args.first()? != "cat" {
336 return None;
337 }
338 Some(json!({ "file": parsed.args[1] }))
339}
340
341fn append_request(command: &str) -> Option<Value> {
342 let parsed = parse(command)?;
343 let file = parsed.appends_to.clone()?;
344
345 let append_content = if parsed.args == ["cat"] {
346 parsed.heredoc?
347 } else if parsed.heredoc.is_none()
348 && parsed.args.first().is_some_and(|arg| arg == "echo")
349 && parsed.args.len() >= 2
350 && !parsed.args[1].starts_with('-')
351 {
352 format!("{}\n", parsed.args[1..].join(" "))
353 } else {
354 return None;
355 };
356
357 Some(json!({
358 "op": "append",
359 "file": file,
360 "append_content": append_content,
361 "create_dirs": true,
362 }))
363}
364
365fn sed_request(command: &str) -> Option<Value> {
366 let parsed = parse(command)?;
367 if parsed.appends_to.is_some() || parsed.heredoc.is_some() {
368 return None;
369 }
370 if parsed.args.len() != 4 || parsed.args.first()? != "sed" || parsed.args[1] != "-n" {
371 return None;
372 }
373
374 let range = parsed.args[2].strip_suffix('p')?;
375 let (start, end) = range.split_once(',')?;
376 let start_line = start.parse::<u32>().ok()?;
377 let end_line = end.parse::<u32>().ok()?;
378 if start_line == 0 || end_line < start_line {
379 return None;
380 }
381
382 Some(json!({
383 "file": parsed.args[3],
384 "start_line": start_line,
385 "end_line": end_line,
386 }))
387}
388
389fn ls_request(command: &str) -> Option<Value> {
390 let parsed = parse(command)?;
391 if parsed.appends_to.is_some() || parsed.heredoc.is_some() || parsed.args.first()? != "ls" {
392 return None;
393 }
394
395 let mut path = None;
396 for arg in parsed.args.iter().skip(1) {
397 if let Some(flags) = arg.strip_prefix('-') {
398 if flags.is_empty() || !flags.chars().all(|flag| matches!(flag, 'l' | 'R' | 'a')) {
399 return None;
400 }
401 } else if path.is_none() {
402 path = Some(arg.clone());
403 } else {
404 return None;
405 }
406 }
407
408 Some(json!({ "file": path.unwrap_or_else(|| ".".to_string()) }))
409}