Skip to main content

lean_ctx/
compound_lexer.rs

/// One piece of a compound shell command, as produced by [`split_compound`].
///
/// `split_compound` splits a command line into segments separated by `&&`,
/// `||` and `;`. Pipes (`|`) are treated specially: only the left side of a
/// pipe is eligible for rewriting (the right side consumes output format and
/// must stay raw).
///
/// Single quotes, double quotes, backtick-quotes and `$(...)` subshells are
/// respected, so operators inside them are not treated as separators.
///
/// The splitter returns a `Vec<Segment>` where each entry is either a command
/// segment or an operator token that should be emitted verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Segment {
    /// The text of one command, with surrounding whitespace trimmed.
    Command(String),
    /// A separator token (`&&`, `||`, `;` or `|`) to emit verbatim.
    Operator(String),
}
16
17pub fn split_compound(input: &str) -> Vec<Segment> {
18    let input = input.trim();
19    if input.is_empty() {
20        return vec![];
21    }
22
23    if contains_heredoc(input) {
24        return vec![Segment::Command(input.to_string())];
25    }
26
27    let chars: Vec<char> = input.chars().collect();
28    let mut segments: Vec<Segment> = Vec::new();
29    let mut current = String::new();
30    let mut i = 0;
31    let len = chars.len();
32
33    while i < len {
34        let ch = chars[i];
35
36        match ch {
37            '\'' => {
38                current.push(ch);
39                i += 1;
40                while i < len && chars[i] != '\'' {
41                    current.push(chars[i]);
42                    i += 1;
43                }
44                if i < len {
45                    current.push('\'');
46                    i += 1;
47                }
48            }
49            '"' => {
50                current.push(ch);
51                i += 1;
52                while i < len && chars[i] != '"' {
53                    if chars[i] == '\\' && i + 1 < len {
54                        current.push('\\');
55                        current.push(chars[i + 1]);
56                        i += 2;
57                        continue;
58                    }
59                    current.push(chars[i]);
60                    i += 1;
61                }
62                if i < len {
63                    current.push('"');
64                    i += 1;
65                }
66            }
67            '`' => {
68                current.push(ch);
69                i += 1;
70                while i < len && chars[i] != '`' {
71                    current.push(chars[i]);
72                    i += 1;
73                }
74                if i < len {
75                    current.push('`');
76                    i += 1;
77                }
78            }
79            '$' if i + 1 < len && chars[i + 1] == '(' => {
80                current.push('$');
81                current.push('(');
82                i += 2;
83                let mut depth = 1;
84                while i < len && depth > 0 {
85                    if chars[i] == '(' {
86                        depth += 1;
87                    } else if chars[i] == ')' {
88                        depth -= 1;
89                    }
90                    current.push(chars[i]);
91                    i += 1;
92                }
93            }
94            '\\' if i + 1 < len => {
95                current.push('\\');
96                current.push(chars[i + 1]);
97                i += 2;
98            }
99            '&' if i + 1 < len && chars[i + 1] == '&' => {
100                push_command(&mut segments, &current);
101                current.clear();
102                segments.push(Segment::Operator("&&".to_string()));
103                i += 2;
104            }
105            '|' if i + 1 < len && chars[i + 1] == '|' => {
106                push_command(&mut segments, &current);
107                current.clear();
108                segments.push(Segment::Operator("||".to_string()));
109                i += 2;
110            }
111            '|' => {
112                push_command(&mut segments, &current);
113                current.clear();
114                segments.push(Segment::Operator("|".to_string()));
115                let rest: String = chars[i + 1..].iter().collect::<String>();
116                let rest = rest.trim().to_string();
117                if !rest.is_empty() {
118                    segments.push(Segment::Command(rest));
119                }
120                return segments;
121            }
122            ';' => {
123                push_command(&mut segments, &current);
124                current.clear();
125                segments.push(Segment::Operator(";".to_string()));
126                i += 1;
127            }
128            _ => {
129                current.push(ch);
130                i += 1;
131            }
132        }
133    }
134
135    push_command(&mut segments, &current);
136    segments
137}
138
139fn push_command(segments: &mut Vec<Segment>, cmd: &str) {
140    let trimmed = cmd.trim();
141    if !trimmed.is_empty() {
142        segments.push(Segment::Command(trimmed.to_string()));
143    }
144}
145
/// Reports whether `input` contains a marker the splitter does not analyse.
///
/// Despite the name, two substrings are checked: `<<` (which also matches
/// `<<<`) and `$((`. The check is purely textual, so a marker inside quotes
/// counts as well.
fn contains_heredoc(input: &str) -> bool {
    const MARKERS: [&str; 2] = ["<<", "$(("];
    MARKERS.iter().any(|marker| input.contains(*marker))
}
149
150/// Rewrites a compound command by applying a rewrite function to each command segment.
151/// Operators and pipe-right-hand segments are preserved unchanged.
152/// `rewrite_fn` receives a command string and returns `Some(rewritten)` if it should
153/// be rewritten, or `None` to keep the original.
154pub fn rewrite_compound<F>(input: &str, rewrite_fn: F) -> Option<String>
155where
156    F: Fn(&str) -> Option<String>,
157{
158    let segments = split_compound(input);
159    if segments.len() <= 1 {
160        return None;
161    }
162
163    let mut any_rewritten = false;
164    let mut result = String::new();
165    let mut after_pipe = false;
166
167    for seg in &segments {
168        match seg {
169            Segment::Operator(op) => {
170                if op == "|" {
171                    after_pipe = true;
172                }
173                if !result.is_empty() && !result.ends_with(' ') {
174                    result.push(' ');
175                }
176                result.push_str(op);
177                result.push(' ');
178            }
179            Segment::Command(cmd) => {
180                if after_pipe {
181                    result.push_str(cmd);
182                } else if let Some(rewritten) = rewrite_fn(cmd) {
183                    any_rewritten = true;
184                    result.push_str(&rewritten);
185                } else {
186                    result.push_str(cmd);
187                }
188            }
189        }
190    }
191
192    if any_rewritten {
193        Some(result.trim().to_string())
194    } else {
195        None
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an expected `Segment::Command`.
    fn cmd(text: &str) -> Segment {
        Segment::Command(text.into())
    }

    /// Shorthand for an expected `Segment::Operator`.
    fn op(text: &str) -> Segment {
        Segment::Operator(text.into())
    }

    /// Rewrite rule shared by the `rewrite_compound` tests: prefix `git`
    /// commands with `rtk` and leave everything else alone.
    fn prefix_git(command: &str) -> Option<String> {
        command
            .starts_with("git ")
            .then(|| format!("rtk {command}"))
    }

    #[test]
    fn simple_command() {
        assert_eq!(split_compound("git status"), vec![cmd("git status")]);
    }

    #[test]
    fn and_chain() {
        let expected = vec![
            cmd("cd src"),
            op("&&"),
            cmd("git status"),
            op("&&"),
            cmd("echo done"),
        ];
        assert_eq!(
            split_compound("cd src && git status && echo done"),
            expected
        );
    }

    #[test]
    fn pipe_stops_at_right() {
        let expected = vec![cmd("git log --oneline"), op("|"), cmd("grep fix")];
        assert_eq!(split_compound("git log --oneline | grep fix"), expected);
    }

    #[test]
    fn pipe_in_chain() {
        let expected = vec![
            cmd("cd src"),
            op("&&"),
            cmd("git log"),
            op("|"),
            cmd("head -5"),
        ];
        assert_eq!(split_compound("cd src && git log | head -5"), expected);
    }

    #[test]
    fn semicolons() {
        let expected = vec![cmd("git add ."), op(";"), cmd("git commit -m 'fix'")];
        assert_eq!(split_compound("git add .; git commit -m 'fix'"), expected);
    }

    #[test]
    fn or_chain() {
        let expected = vec![cmd("git pull"), op("||"), cmd("echo failed")];
        assert_eq!(split_compound("git pull || echo failed"), expected);
    }

    #[test]
    fn quoted_ampersand_not_split() {
        let input = "echo 'foo && bar'";
        assert_eq!(split_compound(input), vec![cmd(input)]);
    }

    #[test]
    fn double_quoted_pipe_not_split() {
        let input = r#"echo "hello | world""#;
        assert_eq!(split_compound(input), vec![cmd(input)]);
    }

    #[test]
    fn heredoc_kept_whole() {
        let input = "cat <<EOF\nhello\nEOF && echo done";
        assert_eq!(split_compound(input), vec![cmd(input)]);
    }

    #[test]
    fn subshell_not_split() {
        let input = "echo $(git status && echo ok)";
        assert_eq!(split_compound(input), vec![cmd(input)]);
    }

    #[test]
    fn rewrite_compound_and_chain() {
        let rewritten = rewrite_compound("cd src && git status && echo done", prefix_git);
        assert_eq!(
            rewritten,
            Some("cd src && rtk git status && echo done".into())
        );
    }

    #[test]
    fn rewrite_compound_pipe_preserves_right() {
        let rewritten = rewrite_compound("git log | head -5", prefix_git);
        assert_eq!(rewritten, Some("rtk git log | head -5".into()));
    }

    #[test]
    fn rewrite_compound_no_match_returns_none() {
        let rewritten = rewrite_compound("cd src && echo done", |_| None);
        assert_eq!(rewritten, None);
    }

    #[test]
    fn rewrite_single_command_returns_none() {
        let rewritten = rewrite_compound("git status", prefix_git);
        assert_eq!(rewritten, None);
    }
}