1use std::collections::HashMap;
4
5use ratatui::style::Style as RatatuiStyle;
6use ratatui::text::{Line, Span};
7use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter as TsHighlighter};
8
9use crate::languages::Language;
10use crate::theme::Theme;
11
/// Errors produced while registering languages or highlighting source text.
#[derive(Debug)]
pub enum HighlightError {
    /// No language with this name has been registered.
    UnknownLanguage(String),
    /// tree-sitter reported a failure while producing highlight events.
    Highlight(String),
    /// A language's highlight configuration could not be built from its queries.
    Config(String),
}

impl std::fmt::Display for HighlightError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use HighlightError as E;

        // Each variant renders as "<label>: <payload>".
        match self {
            E::Config(detail) => write!(f, "Config error: {}", detail),
            E::Highlight(detail) => write!(f, "Highlight error: {}", detail),
            E::UnknownLanguage(lang) => write!(f, "Unknown language: {}", lang),
        }
    }
}

/// Marker impl so `HighlightError` works with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for HighlightError {}
34
/// Capture names recognised by the highlighter.
///
/// This slice is passed to `HighlightConfiguration::configure` for every registered
/// language. tree-sitter-highlight resolves each query capture to one of these entries
/// according to its own `configure` matching rules. The `Highlight` index carried by a
/// `HighlightEvent::HighlightStart` is an index into this slice. The name found at that
/// index is what gets passed to `Theme::style_for`.
const CAPTURE_NAMES: &[&str] = &[
    "attribute",
    "boolean",
    "comment",
    "comment.documentation",
    "constant",
    "constant.builtin",
    "constructor",
    "embedded",
    "escape",
    "function",
    "function.builtin",
    "function.call",
    "function.macro",
    "function.method",
    "keyword",
    "keyword.control",
    "keyword.control.conditional",
    "keyword.control.import",
    "keyword.control.repeat",
    "keyword.control.return",
    "keyword.directive",
    "keyword.function",
    "keyword.operator",
    "keyword.special",
    "keyword.storage",
    "keyword.storage.modifier",
    "keyword.storage.type",
    "label",
    "namespace",
    "number",
    "operator",
    "property",
    "punctuation",
    "punctuation.bracket",
    "punctuation.delimiter",
    "punctuation.special",
    "special",
    "string",
    "string.escape",
    "string.regexp",
    "string.special",
    "tag",
    "type",
    "type.builtin",
    "variable",
    "variable.builtin",
    "variable.parameter",
];
86
/// Per-language state stored by [`Highlighter`].
struct LanguageConfig {
    /// Highlight configuration built from the language's queries and configured with
    /// [`CAPTURE_NAMES`].
    config: HighlightConfiguration,
}

/// Syntax highlighter that turns source text into styled ratatui [`Line`]s via tree-sitter.
pub struct Highlighter {
    /// Theme that maps capture names to ratatui styles.
    theme: Theme,
    /// tree-sitter highlighter instance, reused across `highlight` calls.
    ts_highlighter: TsHighlighter,
    /// Registered languages, keyed by the `Language::name` given at registration.
    languages: HashMap<String, LanguageConfig>,
}
101
102impl Highlighter {
103 pub fn new(theme: Theme) -> Self {
105 Self {
106 theme,
107 ts_highlighter: TsHighlighter::new(),
108 languages: HashMap::new(),
109 }
110 }
111
112 pub fn register_language(&mut self, language: Language) -> Result<(), HighlightError> {
114 let mut config = HighlightConfiguration::new(
115 language.ts_language,
116 language.name,
117 language.highlights_query,
118 language.injections_query,
119 language.locals_query,
120 )
121 .map_err(|e| HighlightError::Config(e.to_string()))?;
122
123 config.configure(CAPTURE_NAMES);
125
126 self.languages
127 .insert(language.name.to_string(), LanguageConfig { config });
128
129 Ok(())
130 }
131
132 pub fn theme(&self) -> &Theme {
134 &self.theme
135 }
136
137 pub fn set_theme(&mut self, theme: Theme) {
139 self.theme = theme;
140 }
141
142 pub fn highlight(
151 &mut self,
152 language: &str,
153 source: &str,
154 ) -> Result<Vec<Line<'static>>, HighlightError> {
155 let lang_config = self
156 .languages
157 .get(language)
158 .ok_or_else(|| HighlightError::UnknownLanguage(language.to_string()))?;
159
160 let highlights = self
161 .ts_highlighter
162 .highlight(&lang_config.config, source.as_bytes(), None, |_| None)
163 .map_err(|e| HighlightError::Highlight(e.to_string()))?;
164
165 let mut spans: Vec<(usize, usize, RatatuiStyle)> = Vec::new();
167 let mut style_stack: Vec<RatatuiStyle> = vec![RatatuiStyle::default()];
168
169 for event in highlights {
170 match event.map_err(|e| HighlightError::Highlight(e.to_string()))? {
171 HighlightEvent::Source { start, end } => {
172 let current_style = *style_stack.last().unwrap_or(&RatatuiStyle::default());
173 spans.push((start, end, current_style));
174 }
175 HighlightEvent::HighlightStart(highlight) => {
176 let capture_name = CAPTURE_NAMES.get(highlight.0).copied().unwrap_or("text");
177 let style = self.theme.style_for(capture_name);
178 style_stack.push(style);
179 }
180 HighlightEvent::HighlightEnd => {
181 style_stack.pop();
182 }
183 }
184 }
185
186 Ok(self.spans_to_lines(source, &spans))
188 }
189
190 fn spans_to_lines(
192 &self,
193 source: &str,
194 spans: &[(usize, usize, RatatuiStyle)],
195 ) -> Vec<Line<'static>> {
196 let lines: Vec<&str> = source.lines().collect();
197 let mut result: Vec<Line<'static>> = Vec::with_capacity(lines.len());
198
199 let mut line_starts: Vec<usize> = vec![0];
201 for (i, c) in source.char_indices() {
202 if c == '\n' {
203 line_starts.push(i + 1);
204 }
205 }
206
207 if !source.ends_with('\n') && !source.is_empty() {
209 }
211
212 for (line_idx, line_text) in lines.iter().enumerate() {
213 let line_start = line_starts.get(line_idx).copied().unwrap_or(0);
214 let line_end = line_start + line_text.len();
215
216 let mut line_spans: Vec<Span<'static>> = Vec::new();
217 let mut current_pos = line_start;
218
219 for &(span_start, span_end, style) in spans {
221 if span_end <= line_start {
223 continue;
224 }
225 if span_start >= line_end {
227 break;
228 }
229
230 let clipped_start = span_start.max(line_start);
232 let clipped_end = span_end.min(line_end);
233
234 if clipped_start > current_pos {
236 let text = &source[current_pos..clipped_start];
237 line_spans.push(Span::raw(text.to_string()));
238 }
239
240 if clipped_end > clipped_start {
242 let text = &source[clipped_start..clipped_end];
243 line_spans.push(Span::styled(text.to_string(), style));
244 current_pos = clipped_end;
245 }
246 }
247
248 if current_pos < line_end {
250 let text = &source[current_pos..line_end];
251 line_spans.push(Span::raw(text.to_string()));
252 }
253
254 if line_spans.is_empty() {
256 line_spans.push(Span::raw(String::new()));
257 }
258
259 result.push(Line::from(line_spans));
260 }
261
262 if result.is_empty() {
264 result.push(Line::from(vec![Span::raw(String::new())]));
265 }
266
267 result
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::languages::{html, sql};
    use crate::themes;

    /// Concatenates the text of every span in `line`.
    fn line_text(line: &Line<'_>) -> String {
        line.spans.iter().map(|span| span.content.as_ref()).collect()
    }

    /// Builds a One Dark highlighter with the SQL grammar registered.
    fn sql_highlighter() -> Highlighter {
        let mut highlighter = Highlighter::new(themes::one_dark());
        highlighter.register_language(sql()).unwrap();
        highlighter
    }

    #[test]
    fn test_highlighter_creation() {
        let theme = themes::one_dark();
        let highlighter = Highlighter::new(theme);
        assert!(highlighter.languages.is_empty());
    }

    #[test]
    fn test_register_language() {
        let highlighter = sql_highlighter();
        assert!(highlighter.languages.contains_key("sql"));
    }

    #[test]
    fn test_highlight_simple_sql() {
        let mut highlighter = sql_highlighter();

        let lines = highlighter.highlight("sql", "SELECT * FROM users").unwrap();
        assert_eq!(lines.len(), 1);
        assert!(!lines[0].spans.is_empty());
    }

    #[test]
    fn test_highlight_multiline_sql() {
        let mut highlighter = sql_highlighter();

        let sql = "SELECT *\nFROM users\nWHERE id = 1";
        let lines = highlighter.highlight("sql", sql).unwrap();
        assert_eq!(lines.len(), 3);
    }

    #[test]
    fn test_spans_reproduce_source_lines() {
        let mut highlighter = sql_highlighter();

        // Includes an empty line to exercise the placeholder-span path.
        let source = "SELECT *\nFROM users\n\nWHERE id = 1";
        let lines = highlighter.highlight("sql", source).unwrap();
        let rendered: Vec<String> = lines.iter().map(line_text).collect();
        assert_eq!(rendered, source.lines().collect::<Vec<_>>());
    }

    #[test]
    fn test_crlf_line_endings_are_stripped() {
        let mut highlighter = sql_highlighter();

        let lines = highlighter.highlight("sql", "SELECT *\r\nFROM users\r\n").unwrap();
        let rendered: Vec<String> = lines.iter().map(line_text).collect();
        assert_eq!(rendered, vec!["SELECT *", "FROM users"]);
    }

    #[test]
    fn test_highlight_empty_source_yields_one_empty_line() {
        let mut highlighter = sql_highlighter();

        let lines = highlighter.highlight("sql", "").unwrap();
        assert_eq!(lines.len(), 1);
        assert_eq!(line_text(&lines[0]), "");
    }

    #[test]
    fn test_unknown_language_error() {
        let theme = themes::one_dark();
        let mut highlighter = Highlighter::new(theme);

        let result = highlighter.highlight("unknown", "some code");
        assert!(matches!(result, Err(HighlightError::UnknownLanguage(_))));
    }

    #[test]
    fn test_highlight_html() {
        let theme = themes::one_dark();
        let mut highlighter = Highlighter::new(theme);
        highlighter.register_language(html()).unwrap();

        let html_content = "<html><head><title>Test</title></head><body><p>Hello</p></body></html>";
        let lines = highlighter.highlight("html", html_content).unwrap();
        assert_eq!(lines.len(), 1);
        assert!(!lines[0].spans.is_empty());
        assert_eq!(line_text(&lines[0]), html_content);
    }

    #[test]
    fn test_highlight_multiline_html() {
        let theme = themes::one_dark();
        let mut highlighter = Highlighter::new(theme);
        highlighter.register_language(html()).unwrap();

        let html_content = r#"<html>
<head>
  <title>Test</title>
</head>
<body>
  <p>Hello</p>
</body>
</html>"#;
        let lines = highlighter.highlight("html", html_content).unwrap();
        assert_eq!(lines.len(), 8);
    }
}