// lean_ctx/compound_lexer.rs

/// One piece of a compound shell command line, as produced by `split_compound`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Segment {
    /// A single command, trimmed of surrounding whitespace.
    Command(String),
    /// A control operator joining commands, e.g. `&&`, `||`, `;` or a pipe.
    Operator(String),
}

17pub fn split_compound(input: &str) -> Vec<Segment> {
18 let input = input.trim();
19 if input.is_empty() {
20 return vec![];
21 }
22
23 if contains_heredoc(input) {
24 return vec![Segment::Command(input.to_string())];
25 }
26
27 let chars: Vec<char> = input.chars().collect();
28 let mut segments: Vec<Segment> = Vec::new();
29 let mut current = String::new();
30 let mut i = 0;
31 let len = chars.len();
32
33 while i < len {
34 let ch = chars[i];
35
36 match ch {
37 '\'' => {
38 current.push(ch);
39 i += 1;
40 while i < len && chars[i] != '\'' {
41 current.push(chars[i]);
42 i += 1;
43 }
44 if i < len {
45 current.push('\'');
46 i += 1;
47 }
48 }
49 '"' => {
50 current.push(ch);
51 i += 1;
52 while i < len && chars[i] != '"' {
53 if chars[i] == '\\' && i + 1 < len {
54 current.push('\\');
55 current.push(chars[i + 1]);
56 i += 2;
57 continue;
58 }
59 current.push(chars[i]);
60 i += 1;
61 }
62 if i < len {
63 current.push('"');
64 i += 1;
65 }
66 }
67 '`' => {
68 current.push(ch);
69 i += 1;
70 while i < len && chars[i] != '`' {
71 current.push(chars[i]);
72 i += 1;
73 }
74 if i < len {
75 current.push('`');
76 i += 1;
77 }
78 }
79 '$' if i + 1 < len && chars[i + 1] == '(' => {
80 current.push('$');
81 current.push('(');
82 i += 2;
83 let mut depth = 1;
84 while i < len && depth > 0 {
85 if chars[i] == '(' {
86 depth += 1;
87 } else if chars[i] == ')' {
88 depth -= 1;
89 }
90 current.push(chars[i]);
91 i += 1;
92 }
93 }
94 '\\' if i + 1 < len => {
95 current.push('\\');
96 current.push(chars[i + 1]);
97 i += 2;
98 }
99 '&' if i + 1 < len && chars[i + 1] == '&' => {
100 push_command(&mut segments, ¤t);
101 current.clear();
102 segments.push(Segment::Operator("&&".to_string()));
103 i += 2;
104 }
105 '|' if i + 1 < len && chars[i + 1] == '|' => {
106 push_command(&mut segments, ¤t);
107 current.clear();
108 segments.push(Segment::Operator("||".to_string()));
109 i += 2;
110 }
111 '|' => {
112 push_command(&mut segments, ¤t);
113 current.clear();
114 segments.push(Segment::Operator("|".to_string()));
115 let rest: String = chars[i + 1..].iter().collect::<String>();
116 let rest = rest.trim().to_string();
117 if !rest.is_empty() {
118 segments.push(Segment::Command(rest));
119 }
120 return segments;
121 }
122 ';' => {
123 push_command(&mut segments, ¤t);
124 current.clear();
125 segments.push(Segment::Operator(";".to_string()));
126 i += 1;
127 }
128 _ => {
129 current.push(ch);
130 i += 1;
131 }
132 }
133 }
134
135 push_command(&mut segments, ¤t);
136 segments
137}
138
139fn push_command(segments: &mut Vec<Segment>, cmd: &str) {
140 let trimmed = cmd.trim();
141 if !trimmed.is_empty() {
142 segments.push(Segment::Command(trimmed.to_string()));
143 }
144}
145
/// Reports whether `input` contains a construct the splitter refuses to
/// split: `<<` (heredoc or here-string) or `$((` (arithmetic expansion).
/// The check is purely textual and ignores quoting.
fn contains_heredoc(input: &str) -> bool {
    const UNSPLITTABLE: [&str; 2] = ["<<", "$(("];
    UNSPLITTABLE.iter().any(|&marker| input.contains(marker))
}

150pub fn rewrite_compound<F>(input: &str, rewrite_fn: F) -> Option<String>
155where
156 F: Fn(&str) -> Option<String>,
157{
158 let segments = split_compound(input);
159 if segments.len() <= 1 {
160 return None;
161 }
162
163 let mut any_rewritten = false;
164 let mut result = String::new();
165 let mut after_pipe = false;
166
167 for seg in &segments {
168 match seg {
169 Segment::Operator(op) => {
170 if op == "|" {
171 after_pipe = true;
172 }
173 if !result.is_empty() && !result.ends_with(' ') {
174 result.push(' ');
175 }
176 result.push_str(op);
177 result.push(' ');
178 }
179 Segment::Command(cmd) => {
180 if after_pipe {
181 result.push_str(cmd);
182 } else if let Some(rewritten) = rewrite_fn(cmd) {
183 any_rewritten = true;
184 result.push_str(&rewritten);
185 } else {
186 result.push_str(cmd);
187 }
188 }
189 }
190 }
191
192 if any_rewritten {
193 Some(result.trim().to_string())
194 } else {
195 None
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a command segment.
    fn cmd(text: &str) -> Segment {
        Segment::Command(text.to_string())
    }

    /// Shorthand for an operator segment.
    fn op(text: &str) -> Segment {
        Segment::Operator(text.to_string())
    }

    /// Rewriter used by the `rewrite_compound` tests: prefixes `git` commands
    /// with `rtk` and declines everything else.
    fn wrap_git(command: &str) -> Option<String> {
        command
            .starts_with("git ")
            .then(|| format!("rtk {command}"))
    }

    #[test]
    fn simple_command() {
        assert_eq!(split_compound("git status"), vec![cmd("git status")]);
    }

    #[test]
    fn and_chain() {
        let expected = vec![
            cmd("cd src"),
            op("&&"),
            cmd("git status"),
            op("&&"),
            cmd("echo done"),
        ];
        assert_eq!(split_compound("cd src && git status && echo done"), expected);
    }

    #[test]
    fn pipe_stops_at_right() {
        let expected = vec![cmd("git log --oneline"), op("|"), cmd("grep fix")];
        assert_eq!(split_compound("git log --oneline | grep fix"), expected);
    }

    #[test]
    fn pipe_in_chain() {
        let expected = vec![
            cmd("cd src"),
            op("&&"),
            cmd("git log"),
            op("|"),
            cmd("head -5"),
        ];
        assert_eq!(split_compound("cd src && git log | head -5"), expected);
    }

    #[test]
    fn semicolons() {
        let expected = vec![cmd("git add ."), op(";"), cmd("git commit -m 'fix'")];
        assert_eq!(split_compound("git add .; git commit -m 'fix'"), expected);
    }

    #[test]
    fn or_chain() {
        let expected = vec![cmd("git pull"), op("||"), cmd("echo failed")];
        assert_eq!(split_compound("git pull || echo failed"), expected);
    }

    #[test]
    fn quoted_ampersand_not_split() {
        assert_eq!(
            split_compound("echo 'foo && bar'"),
            vec![cmd("echo 'foo && bar'")]
        );
    }

    #[test]
    fn double_quoted_pipe_not_split() {
        let line = r#"echo "hello | world""#;
        assert_eq!(split_compound(line), vec![cmd(line)]);
    }

    #[test]
    fn heredoc_kept_whole() {
        let line = "cat <<EOF\nhello\nEOF && echo done";
        assert_eq!(split_compound(line), vec![cmd(line)]);
    }

    #[test]
    fn subshell_not_split() {
        let line = "echo $(git status && echo ok)";
        assert_eq!(split_compound(line), vec![cmd(line)]);
    }

    #[test]
    fn rewrite_compound_and_chain() {
        assert_eq!(
            rewrite_compound("cd src && git status && echo done", wrap_git),
            Some("cd src && rtk git status && echo done".to_string())
        );
    }

    #[test]
    fn rewrite_compound_pipe_preserves_right() {
        assert_eq!(
            rewrite_compound("git log | head -5", wrap_git),
            Some("rtk git log | head -5".to_string())
        );
    }

    #[test]
    fn rewrite_compound_no_match_returns_none() {
        assert_eq!(rewrite_compound("cd src && echo done", |_| None), None);
    }

    #[test]
    fn rewrite_single_command_returns_none() {
        assert_eq!(rewrite_compound("git status", wrap_git), None);
    }
}