Skip to main content

lean_ctx/
compound_lexer.rs

/// One piece of a compound shell command, as produced by `split_compound`.
///
/// `split_compound` splits a compound shell command into segments separated by
/// `&&`, `||`, `;`. Pipes (`|`) are treated specially: only the left side of a
/// pipe is eligible for rewriting (the right side consumes output format and
/// must stay raw).
///
/// The splitter respects single quotes, double quotes, backtick-quotes, and
/// `$(...)` subshells so that operators inside quoted strings are not treated
/// as separators.
///
/// The result is a `Vec<Segment>` where each entry is either a command segment
/// or an operator token that should be emitted verbatim.
#[derive(Debug, Clone, PartialEq)]
pub enum Segment {
    /// Command text found between operators, stored with surrounding
    /// whitespace trimmed.
    Command(String),
    /// A separator token: `&&`, `||`, `;`, or `|`.
    Operator(String),
}
16
17pub fn split_compound(input: &str) -> Vec<Segment> {
18    let input = input.trim();
19    if input.is_empty() {
20        return vec![];
21    }
22
23    if contains_heredoc(input) {
24        return vec![Segment::Command(input.to_string())];
25    }
26
27    let bytes = input.as_bytes();
28    let mut segments: Vec<Segment> = Vec::new();
29    let mut current = String::new();
30    let mut i = 0;
31    let len = bytes.len();
32
33    while i < len {
34        let ch = bytes[i] as char;
35
36        match ch {
37            '\'' => {
38                current.push(ch);
39                i += 1;
40                while i < len && bytes[i] != b'\'' {
41                    current.push(bytes[i] as char);
42                    i += 1;
43                }
44                if i < len {
45                    current.push('\'');
46                    i += 1;
47                }
48            }
49            '"' => {
50                current.push(ch);
51                i += 1;
52                while i < len && bytes[i] != b'"' {
53                    if bytes[i] == b'\\' && i + 1 < len {
54                        current.push('\\');
55                        current.push(bytes[i + 1] as char);
56                        i += 2;
57                        continue;
58                    }
59                    current.push(bytes[i] as char);
60                    i += 1;
61                }
62                if i < len {
63                    current.push('"');
64                    i += 1;
65                }
66            }
67            '`' => {
68                current.push(ch);
69                i += 1;
70                while i < len && bytes[i] != b'`' {
71                    current.push(bytes[i] as char);
72                    i += 1;
73                }
74                if i < len {
75                    current.push('`');
76                    i += 1;
77                }
78            }
79            '$' if i + 1 < len && bytes[i + 1] == b'(' => {
80                let start = i;
81                i += 2;
82                let mut depth = 1;
83                while i < len && depth > 0 {
84                    if bytes[i] == b'(' {
85                        depth += 1;
86                    } else if bytes[i] == b')' {
87                        depth -= 1;
88                    }
89                    i += 1;
90                }
91                current.push_str(&input[start..i]);
92            }
93            '\\' if i + 1 < len => {
94                current.push('\\');
95                current.push(bytes[i + 1] as char);
96                i += 2;
97            }
98            '&' if i + 1 < len && bytes[i + 1] == b'&' => {
99                push_command(&mut segments, &current);
100                current.clear();
101                segments.push(Segment::Operator("&&".to_string()));
102                i += 2;
103            }
104            '|' if i + 1 < len && bytes[i + 1] == b'|' => {
105                push_command(&mut segments, &current);
106                current.clear();
107                segments.push(Segment::Operator("||".to_string()));
108                i += 2;
109            }
110            '|' => {
111                // Pipe: left side is a command, right side is NOT rewritten.
112                // We emit the left command, the pipe operator, and the entire
113                // rest of the input as a single opaque command segment.
114                push_command(&mut segments, &current);
115                current.clear();
116                segments.push(Segment::Operator("|".to_string()));
117                let rest = input[i + 1..].trim().to_string();
118                if !rest.is_empty() {
119                    segments.push(Segment::Command(rest));
120                }
121                return segments;
122            }
123            ';' => {
124                push_command(&mut segments, &current);
125                current.clear();
126                segments.push(Segment::Operator(";".to_string()));
127                i += 1;
128            }
129            _ => {
130                current.push(ch);
131                i += 1;
132            }
133        }
134    }
135
136    push_command(&mut segments, &current);
137    segments
138}
139
140fn push_command(segments: &mut Vec<Segment>, cmd: &str) {
141    let trimmed = cmd.trim();
142    if !trimmed.is_empty() {
143        segments.push(Segment::Command(trimmed.to_string()));
144    }
145}
146
/// Reports whether `input` contains `<<` or `$((` anywhere in its text.
///
/// This is a plain substring check: it does not look at quoting, so the
/// markers match even inside quoted strings.
fn contains_heredoc(input: &str) -> bool {
    const MARKERS: [&str; 2] = ["<<", "$(("];
    MARKERS.iter().any(|marker| input.contains(marker))
}
150
151/// Rewrites a compound command by applying a rewrite function to each command segment.
152/// Operators and pipe-right-hand segments are preserved unchanged.
153/// `rewrite_fn` receives a command string and returns `Some(rewritten)` if it should
154/// be rewritten, or `None` to keep the original.
155pub fn rewrite_compound<F>(input: &str, rewrite_fn: F) -> Option<String>
156where
157    F: Fn(&str) -> Option<String>,
158{
159    let segments = split_compound(input);
160    if segments.len() <= 1 {
161        return None;
162    }
163
164    let mut any_rewritten = false;
165    let mut result = String::new();
166    let mut after_pipe = false;
167
168    for seg in &segments {
169        match seg {
170            Segment::Operator(op) => {
171                if op == "|" {
172                    after_pipe = true;
173                }
174                if !result.is_empty() && !result.ends_with(' ') {
175                    result.push(' ');
176                }
177                result.push_str(op);
178                result.push(' ');
179            }
180            Segment::Command(cmd) => {
181                if after_pipe {
182                    result.push_str(cmd);
183                } else if let Some(rewritten) = rewrite_fn(cmd) {
184                    any_rewritten = true;
185                    result.push_str(&rewritten);
186                } else {
187                    result.push_str(cmd);
188                }
189            }
190        }
191    }
192
193    if any_rewritten {
194        Some(result.trim().to_string())
195    } else {
196        None
197    }
198}
199
// Unit tests for `split_compound` and `rewrite_compound`.
#[cfg(test)]
mod tests {
    use super::*;

    // A command without any operator comes back as a single command segment.
    #[test]
    fn simple_command() {
        let segs = split_compound("git status");
        assert_eq!(segs, vec![Segment::Command("git status".into())]);
    }

    // `&&` separates commands; each command is trimmed.
    #[test]
    fn and_chain() {
        let segs = split_compound("cd src && git status && echo done");
        assert_eq!(
            segs,
            vec![
                Segment::Command("cd src".into()),
                Segment::Operator("&&".into()),
                Segment::Command("git status".into()),
                Segment::Operator("&&".into()),
                Segment::Command("echo done".into()),
            ]
        );
    }

    // A pipe yields the left command, the `|` operator, and the rest as one command.
    #[test]
    fn pipe_stops_at_right() {
        let segs = split_compound("git log --oneline | grep fix");
        assert_eq!(
            segs,
            vec![
                Segment::Command("git log --oneline".into()),
                Segment::Operator("|".into()),
                Segment::Command("grep fix".into()),
            ]
        );
    }

    // Operators before the pipe are still split normally.
    #[test]
    fn pipe_in_chain() {
        let segs = split_compound("cd src && git log | head -5");
        assert_eq!(
            segs,
            vec![
                Segment::Command("cd src".into()),
                Segment::Operator("&&".into()),
                Segment::Command("git log".into()),
                Segment::Operator("|".into()),
                Segment::Command("head -5".into()),
            ]
        );
    }

    // `;` is a separator even without surrounding whitespace.
    #[test]
    fn semicolons() {
        let segs = split_compound("git add .; git commit -m 'fix'");
        assert_eq!(
            segs,
            vec![
                Segment::Command("git add .".into()),
                Segment::Operator(";".into()),
                Segment::Command("git commit -m 'fix'".into()),
            ]
        );
    }

    // `||` is recognised as its own operator, not as two pipes.
    #[test]
    fn or_chain() {
        let segs = split_compound("git pull || echo failed");
        assert_eq!(
            segs,
            vec![
                Segment::Command("git pull".into()),
                Segment::Operator("||".into()),
                Segment::Command("echo failed".into()),
            ]
        );
    }

    // `&&` inside single quotes is literal text.
    #[test]
    fn quoted_ampersand_not_split() {
        let segs = split_compound("echo 'foo && bar'");
        assert_eq!(segs, vec![Segment::Command("echo 'foo && bar'".into())]);
    }

    // `|` inside double quotes is literal text.
    #[test]
    fn double_quoted_pipe_not_split() {
        let segs = split_compound(r#"echo "hello | world""#);
        assert_eq!(
            segs,
            vec![Segment::Command(r#"echo "hello | world""#.into())]
        );
    }

    // Input containing `<<` is returned whole, without any splitting.
    #[test]
    fn heredoc_kept_whole() {
        let segs = split_compound("cat <<EOF\nhello\nEOF && echo done");
        assert_eq!(
            segs,
            vec![Segment::Command(
                "cat <<EOF\nhello\nEOF && echo done".into()
            )]
        );
    }

    // Operators inside `$(...)` belong to the subshell and are not split.
    #[test]
    fn subshell_not_split() {
        let segs = split_compound("echo $(git status && echo ok)");
        assert_eq!(
            segs,
            vec![Segment::Command("echo $(git status && echo ok)".into())]
        );
    }

    // Only segments accepted by the rewrite closure change; the rest are kept.
    #[test]
    fn rewrite_compound_and_chain() {
        let result = rewrite_compound("cd src && git status && echo done", |cmd| {
            if cmd.starts_with("git ") {
                Some(format!("rtk {cmd}"))
            } else {
                None
            }
        });
        assert_eq!(result, Some("cd src && rtk git status && echo done".into()));
    }

    // The command to the right of a pipe is emitted unchanged.
    #[test]
    fn rewrite_compound_pipe_preserves_right() {
        let result = rewrite_compound("git log | head -5", |cmd| {
            if cmd.starts_with("git ") {
                Some(format!("rtk {cmd}"))
            } else {
                None
            }
        });
        assert_eq!(result, Some("rtk git log | head -5".into()));
    }

    // When the closure rewrites nothing, the result is `None`.
    #[test]
    fn rewrite_compound_no_match_returns_none() {
        let result = rewrite_compound("cd src && echo done", |_| None);
        assert_eq!(result, None);
    }

    // A single-segment input is never rewritten by `rewrite_compound`.
    #[test]
    fn rewrite_single_command_returns_none() {
        let result = rewrite_compound("git status", |cmd| {
            if cmd.starts_with("git ") {
                Some(format!("rtk {cmd}"))
            } else {
                None
            }
        });
        assert_eq!(result, None);
    }
}
355}