Skip to main content

rgx/ui/
syntax_highlight.rs

1use ratatui::{
2    style::{Color, Style},
3    text::Span,
4};
5use regex_syntax::ast::{parse::Parser, Ast, LiteralKind};
6
7use super::theme;
8
/// Syntactic role of a highlighted region of a regex pattern.
///
/// Every variant is drawn in exactly one colour, chosen by [`category_color`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum SyntaxCategory {
    /// Characters matched verbatim, e.g. `abc`.
    Literal,
    /// Group delimiters such as `(`, `(?P<name>` and `)`, and inline flag groups like `(?i)`.
    Group,
    /// Repetition operators such as `*`, `+?` or `{2,5}`.
    Quantifier,
    /// Character classes: `\d`, `\p{L}`, `[a-z]`.
    CharClass,
    /// Zero-width assertions such as `^`, `$` and `\b`; `.` is drawn in this category too.
    Anchor,
    /// Escaped literals such as `\n`, `\.` or `\x41`.
    Escape,
    /// The `|` alternation operator.
    Alternation,
}
19
20#[derive(Debug, Clone)]
21pub struct SyntaxToken {
22    pub start: usize,
23    pub end: usize,
24    pub category: SyntaxCategory,
25}
26
27pub fn highlight(pattern: &str) -> Vec<SyntaxToken> {
28    if pattern.is_empty() {
29        return vec![];
30    }
31
32    let ast = match Parser::new().parse(pattern) {
33        Ok(ast) => ast,
34        Err(_) => return vec![],
35    };
36
37    let mut tokens = Vec::new();
38    collect_tokens(&ast, &mut tokens);
39    tokens.sort_by_key(|t| t.start);
40    tokens
41}
42
43fn collect_tokens(ast: &Ast, tokens: &mut Vec<SyntaxToken>) {
44    match ast {
45        Ast::Empty(_) => {}
46        Ast::Literal(lit) => {
47            let start = lit.span.start.offset;
48            let end = lit.span.end.offset;
49            let category = match lit.kind {
50                LiteralKind::Verbatim => SyntaxCategory::Literal,
51                _ => SyntaxCategory::Escape,
52            };
53            tokens.push(SyntaxToken {
54                start,
55                end,
56                category,
57            });
58        }
59        Ast::Dot(span) => {
60            tokens.push(SyntaxToken {
61                start: span.start.offset,
62                end: span.end.offset,
63                category: SyntaxCategory::Anchor,
64            });
65        }
66        Ast::Assertion(a) => {
67            tokens.push(SyntaxToken {
68                start: a.span.start.offset,
69                end: a.span.end.offset,
70                category: SyntaxCategory::Anchor,
71            });
72        }
73        Ast::ClassPerl(c) => {
74            tokens.push(SyntaxToken {
75                start: c.span.start.offset,
76                end: c.span.end.offset,
77                category: SyntaxCategory::CharClass,
78            });
79        }
80        Ast::ClassUnicode(c) => {
81            tokens.push(SyntaxToken {
82                start: c.span.start.offset,
83                end: c.span.end.offset,
84                category: SyntaxCategory::CharClass,
85            });
86        }
87        Ast::ClassBracketed(c) => {
88            tokens.push(SyntaxToken {
89                start: c.span.start.offset,
90                end: c.span.end.offset,
91                category: SyntaxCategory::CharClass,
92            });
93        }
94        Ast::Repetition(rep) => {
95            collect_tokens(&rep.ast, tokens);
96            tokens.push(SyntaxToken {
97                start: rep.op.span.start.offset,
98                end: rep.op.span.end.offset,
99                category: SyntaxCategory::Quantifier,
100            });
101        }
102        Ast::Group(group) => {
103            // Token for the opening delimiter (everything up to the inner AST)
104            let group_start = group.span.start.offset;
105            let inner_start = group.ast.span().start.offset;
106            let inner_end = group.ast.span().end.offset;
107            let group_end = group.span.end.offset;
108
109            if inner_start > group_start {
110                tokens.push(SyntaxToken {
111                    start: group_start,
112                    end: inner_start,
113                    category: SyntaxCategory::Group,
114                });
115            }
116
117            collect_tokens(&group.ast, tokens);
118
119            // Token for the closing delimiter
120            if group_end > inner_end {
121                tokens.push(SyntaxToken {
122                    start: inner_end,
123                    end: group_end,
124                    category: SyntaxCategory::Group,
125                });
126            }
127        }
128        Ast::Alternation(alt) => {
129            // Visit children and derive `|` positions from gaps between siblings
130            for (i, child) in alt.asts.iter().enumerate() {
131                collect_tokens(child, tokens);
132                if i + 1 < alt.asts.len() {
133                    let pipe_start = child.span().end.offset;
134                    let next_start = alt.asts[i + 1].span().start.offset;
135                    // The `|` should be between the end of this child and start of next
136                    if next_start > pipe_start {
137                        tokens.push(SyntaxToken {
138                            start: pipe_start,
139                            end: pipe_start + 1,
140                            category: SyntaxCategory::Alternation,
141                        });
142                    }
143                }
144            }
145        }
146        Ast::Concat(concat) => {
147            for child in &concat.asts {
148                collect_tokens(child, tokens);
149            }
150        }
151        Ast::Flags(flags) => {
152            tokens.push(SyntaxToken {
153                start: flags.span.start.offset,
154                end: flags.span.end.offset,
155                category: SyntaxCategory::Group,
156            });
157        }
158    }
159}
160
161pub fn category_color(cat: SyntaxCategory) -> Color {
162    match cat {
163        SyntaxCategory::Literal => theme::TEXT,
164        SyntaxCategory::Group => theme::BLUE,
165        SyntaxCategory::Quantifier => theme::MAUVE,
166        SyntaxCategory::CharClass => theme::GREEN,
167        SyntaxCategory::Anchor => theme::TEAL,
168        SyntaxCategory::Escape => theme::PEACH,
169        SyntaxCategory::Alternation => theme::YELLOW,
170    }
171}
172
173pub fn build_highlighted_spans<'a>(pattern: &'a str, tokens: &[SyntaxToken]) -> Vec<Span<'a>> {
174    let mut spans = Vec::new();
175    let mut pos = 0;
176
177    for token in tokens {
178        // Skip overlapping or out-of-order tokens
179        if token.start < pos {
180            continue;
181        }
182        // Gap before this token → plain text
183        if token.start > pos {
184            spans.push(Span::styled(
185                &pattern[pos..token.start],
186                Style::default().fg(theme::TEXT),
187            ));
188        }
189        let end = token.end.min(pattern.len());
190        if end > token.start {
191            spans.push(Span::styled(
192                &pattern[token.start..end],
193                Style::default().fg(category_color(token.category)),
194            ));
195        }
196        pos = end;
197    }
198
199    // Remaining text after last token
200    if pos < pattern.len() {
201        spans.push(Span::styled(
202            &pattern[pos..],
203            Style::default().fg(theme::TEXT),
204        ));
205    }
206
207    spans
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Categories of every token produced for `pattern`, in token order.
    fn categories_of(pattern: &str) -> Vec<SyntaxCategory> {
        highlight(pattern).iter().map(|t| t.category).collect()
    }

    /// Only the tokens of `pattern` whose category is `wanted`.
    fn tokens_of(pattern: &str, wanted: SyntaxCategory) -> Vec<SyntaxToken> {
        highlight(pattern)
            .into_iter()
            .filter(|t| t.category == wanted)
            .collect()
    }

    /// Asserts that `pattern` yields at least one token of each category in `expected`.
    fn assert_has(pattern: &str, expected: &[SyntaxCategory]) {
        let found = categories_of(pattern);
        for cat in expected {
            assert!(
                found.contains(cat),
                "{:?}: missing {:?} in {:?}",
                pattern,
                cat,
                found
            );
        }
    }

    #[test]
    fn test_empty_pattern() {
        assert_eq!(highlight("").len(), 0);
    }

    #[test]
    fn test_invalid_pattern_returns_empty() {
        assert_eq!(highlight("(unclosed").len(), 0);
    }

    #[test]
    fn test_literal_only() {
        let found = categories_of("hello");
        assert!(!found.is_empty());
        assert!(found.iter().all(|&c| c == SyntaxCategory::Literal));
    }

    #[test]
    fn test_perl_class() {
        assert_has(
            r"\d+",
            &[SyntaxCategory::CharClass, SyntaxCategory::Quantifier],
        );
    }

    #[test]
    fn test_group_and_quantifier() {
        assert_has(
            r"(\w+)",
            &[
                SyntaxCategory::Group,
                SyntaxCategory::CharClass,
                SyntaxCategory::Quantifier,
            ],
        );
    }

    #[test]
    fn test_alternation() {
        assert_has(
            "foo|bar",
            &[SyntaxCategory::Alternation, SyntaxCategory::Literal],
        );
    }

    #[test]
    fn test_anchors() {
        assert_has(r"^hello$", &[SyntaxCategory::Anchor, SyntaxCategory::Literal]);
    }

    #[test]
    fn test_escape_sequences() {
        assert!(categories_of(r"\n\t")
            .iter()
            .all(|&c| c == SyntaxCategory::Escape));
    }

    #[test]
    fn test_dot() {
        assert_has("a.b", &[SyntaxCategory::Anchor]);
    }

    #[test]
    fn test_bracketed_class() {
        assert_has(
            "[a-z]+",
            &[SyntaxCategory::CharClass, SyntaxCategory::Quantifier],
        );
    }

    #[test]
    fn test_build_highlighted_spans_covers_full_pattern() {
        let pattern = r"(\w+)@(\w+)";
        let spans = build_highlighted_spans(pattern, &highlight(pattern));
        let rebuilt = spans.iter().fold(String::new(), |mut acc, s| {
            acc.push_str(&s.content);
            acc
        });
        assert_eq!(rebuilt, pattern);
    }

    #[test]
    fn test_lazy_quantifier() {
        let quantifiers = tokens_of(r"\d+?", SyntaxCategory::Quantifier);
        assert_eq!(quantifiers.len(), 1);
        // The single quantifier token must span both `+` and the lazy `?`.
        let only = &quantifiers[0];
        assert_eq!(only.end - only.start, 2);
    }

    #[test]
    fn test_named_group() {
        // One token for the opening `(?P<name>` and one for the closing `)`.
        assert_eq!(tokens_of(r"(?P<name>\w+)", SyntaxCategory::Group).len(), 2);
    }
}