Skip to main content

rgx/ui/
syntax_highlight.rs

1use ratatui::{
2    style::{Color, Style},
3    text::Span,
4};
5use regex_syntax::ast::{Ast, LiteralKind};
6
7use super::theme;
8
/// Highlighting class assigned to a piece of regex syntax.
///
/// `Hash` is derived (in addition to `Eq`) so categories can key a `HashMap`/`HashSet`,
/// e.g. for per-category style overrides.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyntaxCategory {
    /// A verbatim literal character.
    Literal,
    /// Group delimiters (`(`, `(?:`, `(?P<name>`, `)`) and inline flag sets like `(?i)`.
    Group,
    /// A repetition operator such as `*`, `+?` or `{2,3}`.
    Quantifier,
    /// A Perl (`\d`), Unicode (`\p{L}`) or bracketed (`[a-z]`) character class.
    CharClass,
    /// An assertion such as `^`, `$` or `\b`; the `.` wildcard is also classified here.
    Anchor,
    /// Any non-verbatim literal, i.e. an escape sequence such as `\n` or `\.`.
    Escape,
    /// The `|` separating alternation branches.
    Alternation,
}
19
20#[derive(Debug, Clone)]
21pub struct SyntaxToken {
22    pub start: usize,
23    pub end: usize,
24    pub category: SyntaxCategory,
25}
26
27pub fn highlight(pattern: &str) -> Vec<SyntaxToken> {
28    if pattern.is_empty() {
29        return vec![];
30    }
31
32    let Some(ast) = crate::explain::parse_ast(pattern) else {
33        return vec![];
34    };
35
36    let mut tokens = Vec::new();
37    collect_tokens(&ast, &mut tokens);
38    tokens.sort_by_key(|t| t.start);
39    tokens
40}
41
42fn collect_tokens(ast: &Ast, tokens: &mut Vec<SyntaxToken>) {
43    match ast {
44        Ast::Empty(_) => {}
45        Ast::Literal(lit) => {
46            let start = lit.span.start.offset;
47            let end = lit.span.end.offset;
48            let category = match lit.kind {
49                LiteralKind::Verbatim => SyntaxCategory::Literal,
50                _ => SyntaxCategory::Escape,
51            };
52            tokens.push(SyntaxToken {
53                start,
54                end,
55                category,
56            });
57        }
58        Ast::Dot(span) => {
59            tokens.push(SyntaxToken {
60                start: span.start.offset,
61                end: span.end.offset,
62                category: SyntaxCategory::Anchor,
63            });
64        }
65        Ast::Assertion(a) => {
66            tokens.push(SyntaxToken {
67                start: a.span.start.offset,
68                end: a.span.end.offset,
69                category: SyntaxCategory::Anchor,
70            });
71        }
72        Ast::ClassPerl(c) => {
73            tokens.push(SyntaxToken {
74                start: c.span.start.offset,
75                end: c.span.end.offset,
76                category: SyntaxCategory::CharClass,
77            });
78        }
79        Ast::ClassUnicode(c) => {
80            tokens.push(SyntaxToken {
81                start: c.span.start.offset,
82                end: c.span.end.offset,
83                category: SyntaxCategory::CharClass,
84            });
85        }
86        Ast::ClassBracketed(c) => {
87            tokens.push(SyntaxToken {
88                start: c.span.start.offset,
89                end: c.span.end.offset,
90                category: SyntaxCategory::CharClass,
91            });
92        }
93        Ast::Repetition(rep) => {
94            collect_tokens(&rep.ast, tokens);
95            tokens.push(SyntaxToken {
96                start: rep.op.span.start.offset,
97                end: rep.op.span.end.offset,
98                category: SyntaxCategory::Quantifier,
99            });
100        }
101        Ast::Group(group) => {
102            // Token for the opening delimiter (everything up to the inner AST)
103            let group_start = group.span.start.offset;
104            let inner_start = group.ast.span().start.offset;
105            let inner_end = group.ast.span().end.offset;
106            let group_end = group.span.end.offset;
107
108            if inner_start > group_start {
109                tokens.push(SyntaxToken {
110                    start: group_start,
111                    end: inner_start,
112                    category: SyntaxCategory::Group,
113                });
114            }
115
116            collect_tokens(&group.ast, tokens);
117
118            // Token for the closing delimiter
119            if group_end > inner_end {
120                tokens.push(SyntaxToken {
121                    start: inner_end,
122                    end: group_end,
123                    category: SyntaxCategory::Group,
124                });
125            }
126        }
127        Ast::Alternation(alt) => {
128            // Visit children and derive `|` positions from gaps between siblings
129            for (i, child) in alt.asts.iter().enumerate() {
130                collect_tokens(child, tokens);
131                if i + 1 < alt.asts.len() {
132                    let pipe_start = child.span().end.offset;
133                    let next_start = alt.asts[i + 1].span().start.offset;
134                    // The `|` should be between the end of this child and start of next
135                    if next_start > pipe_start {
136                        tokens.push(SyntaxToken {
137                            start: pipe_start,
138                            end: pipe_start + 1,
139                            category: SyntaxCategory::Alternation,
140                        });
141                    }
142                }
143            }
144        }
145        Ast::Concat(concat) => {
146            for child in &concat.asts {
147                collect_tokens(child, tokens);
148            }
149        }
150        Ast::Flags(flags) => {
151            tokens.push(SyntaxToken {
152                start: flags.span.start.offset,
153                end: flags.span.end.offset,
154                category: SyntaxCategory::Group,
155            });
156        }
157    }
158}
159
160pub fn category_color(cat: SyntaxCategory) -> Color {
161    match cat {
162        SyntaxCategory::Literal => theme::TEXT,
163        SyntaxCategory::Group => theme::BLUE,
164        SyntaxCategory::Quantifier => theme::MAUVE,
165        SyntaxCategory::CharClass => theme::GREEN,
166        SyntaxCategory::Anchor => theme::TEAL,
167        SyntaxCategory::Escape => theme::PEACH,
168        SyntaxCategory::Alternation => theme::YELLOW,
169    }
170}
171
172pub fn build_highlighted_spans<'a>(pattern: &'a str, tokens: &[SyntaxToken]) -> Vec<Span<'a>> {
173    let mut spans = Vec::new();
174    let mut pos = 0;
175
176    for token in tokens {
177        // Skip overlapping or out-of-order tokens
178        if token.start < pos {
179            continue;
180        }
181        // Gap before this token → plain text
182        if token.start > pos {
183            spans.push(Span::styled(
184                &pattern[pos..token.start],
185                Style::default().fg(theme::TEXT),
186            ));
187        }
188        let end = token.end.min(pattern.len());
189        if end > token.start {
190            spans.push(Span::styled(
191                &pattern[token.start..end],
192                Style::default().fg(category_color(token.category)),
193            ));
194        }
195        pos = end;
196    }
197
198    // Remaining text after last token
199    if pos < pattern.len() {
200        spans.push(Span::styled(
201            &pattern[pos..],
202            Style::default().fg(theme::TEXT),
203        ));
204    }
205
206    spans
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the categories of `tokens`, in order.
    fn categories(tokens: &[SyntaxToken]) -> Vec<SyntaxCategory> {
        tokens.iter().map(|t| t.category).collect()
    }

    /// True if `tokens` contains a token with exactly this byte range and category.
    fn has_token(tokens: &[SyntaxToken], start: usize, end: usize, cat: SyntaxCategory) -> bool {
        tokens
            .iter()
            .any(|t| t.start == start && t.end == end && t.category == cat)
    }

    /// Highlights `pattern`, renders it to spans and concatenates the span text.
    fn reconstruct(pattern: &str) -> String {
        let tokens = highlight(pattern);
        build_highlighted_spans(pattern, &tokens)
            .iter()
            .map(|s| s.content.as_ref())
            .collect()
    }

    #[test]
    fn test_empty_pattern() {
        assert!(highlight("").is_empty());
    }

    #[test]
    fn test_invalid_pattern_returns_empty() {
        assert!(highlight("(unclosed").is_empty());
    }

    #[test]
    fn test_literal_only() {
        let tokens = highlight("hello");
        // One token per verbatim character.
        assert_eq!(tokens.len(), 5);
        assert!(tokens.iter().all(|t| t.category == SyntaxCategory::Literal));
    }

    #[test]
    fn test_perl_class() {
        let tokens = highlight(r"\d+");
        assert_eq!(tokens.len(), 2);
        assert!(has_token(&tokens, 0, 2, SyntaxCategory::CharClass));
        assert!(has_token(&tokens, 2, 3, SyntaxCategory::Quantifier));
    }

    #[test]
    fn test_group_and_quantifier() {
        let tokens = highlight(r"(\w+)");
        assert_eq!(tokens.len(), 4);
        assert!(has_token(&tokens, 0, 1, SyntaxCategory::Group));
        assert!(has_token(&tokens, 1, 3, SyntaxCategory::CharClass));
        assert!(has_token(&tokens, 3, 4, SyntaxCategory::Quantifier));
        assert!(has_token(&tokens, 4, 5, SyntaxCategory::Group));
    }

    #[test]
    fn test_alternation() {
        let tokens = highlight("foo|bar");
        assert_eq!(tokens.len(), 7);
        assert!(has_token(&tokens, 3, 4, SyntaxCategory::Alternation));
        assert!(categories(&tokens).contains(&SyntaxCategory::Literal));
    }

    #[test]
    fn test_empty_alternation_branches() {
        let tokens = highlight("a||b");
        let pipes: Vec<_> = tokens
            .iter()
            .filter(|t| t.category == SyntaxCategory::Alternation)
            .collect();
        assert_eq!(pipes.len(), 2);
        assert!(has_token(&tokens, 1, 2, SyntaxCategory::Alternation));
        assert!(has_token(&tokens, 2, 3, SyntaxCategory::Alternation));
    }

    #[test]
    fn test_anchors() {
        let tokens = highlight(r"^hello$");
        assert!(has_token(&tokens, 0, 1, SyntaxCategory::Anchor));
        assert!(has_token(&tokens, 6, 7, SyntaxCategory::Anchor));
        assert!(categories(&tokens).contains(&SyntaxCategory::Literal));
    }

    #[test]
    fn test_escape_sequences() {
        let tokens = highlight(r"\n\t");
        // Guard against passing vacuously on an empty result.
        assert_eq!(tokens.len(), 2);
        assert!(has_token(&tokens, 0, 2, SyntaxCategory::Escape));
        assert!(has_token(&tokens, 2, 4, SyntaxCategory::Escape));
    }

    #[test]
    fn test_dot() {
        let tokens = highlight("a.b");
        assert!(has_token(&tokens, 1, 2, SyntaxCategory::Anchor));
    }

    #[test]
    fn test_bracketed_class() {
        let tokens = highlight("[a-z]+");
        assert_eq!(tokens.len(), 2);
        assert!(has_token(&tokens, 0, 5, SyntaxCategory::CharClass));
        assert!(has_token(&tokens, 5, 6, SyntaxCategory::Quantifier));
    }

    #[test]
    fn test_tokens_sorted_by_start() {
        let tokens = highlight(r"(\w+)|x*");
        assert!(!tokens.is_empty());
        assert!(tokens.windows(2).all(|w| w[0].start <= w[1].start));
    }

    #[test]
    fn test_build_highlighted_spans_covers_full_pattern() {
        for pattern in [
            r"(\w+)@(\w+)",
            "foo|bar",
            "^a.b$",
            r"(?P<name>\w+)",
            "a||b",
            "héllo|wörld",
        ] {
            assert_eq!(reconstruct(pattern), pattern);
        }
    }

    #[test]
    fn test_lazy_quantifier() {
        let tokens = highlight(r"\d+?");
        let quant: Vec<_> = tokens
            .iter()
            .filter(|t| t.category == SyntaxCategory::Quantifier)
            .collect();
        assert_eq!(quant.len(), 1);
        // Should cover both `+` and `?`
        assert_eq!(quant[0].end - quant[0].start, 2);
    }

    #[test]
    fn test_named_group() {
        let tokens = highlight(r"(?P<name>\w+)");
        let groups: Vec<_> = tokens
            .iter()
            .filter(|t| t.category == SyntaxCategory::Group)
            .collect();
        // Opening `(?P<name>` and closing `)`
        assert_eq!(groups.len(), 2);
        assert!(has_token(&tokens, 0, 9, SyntaxCategory::Group));
        assert!(has_token(&tokens, 12, 13, SyntaxCategory::Group));
    }
}