1use std::collections::HashMap;
2
3use polyfont_core::{Position, Range, TokenInfo};
4use thiserror::Error;
5use tracing::warn;
6
/// Unit in which the `column` of a reported position is measured.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OffsetEncoding {
    /// Columns count UTF-8 bytes from the start of the line.
    Utf8,
    /// Columns count UTF-16 code units from the start of the line.
    Utf16,
}
12
13#[derive(Debug, Error)]
14pub enum ParseError {
15 #[error("unsupported language: {0}")]
16 UnsupportedLanguage(String),
17 #[error("tree-sitter parsing failed for language '{language}': {message}")]
18 ParseFailed { language: String, message: String },
19}
20
21pub trait LanguageSupport: Send + Sync {
22 fn language_name(&self) -> &str;
23 fn language_id(&self) -> &str;
24 fn parse(
25 &self,
26 text: &str,
27 offset_encoding: OffsetEncoding,
28 ) -> Result<Vec<TokenInfo>, ParseError>;
29}
30
31pub fn byte_offset_to_position(
32 text: &str,
33 byte_offset: usize,
34 offset_encoding: OffsetEncoding,
35) -> Position {
36 let bytes = text.as_bytes();
37 let offset = byte_offset.min(bytes.len());
38
39 let mut line: u32 = 0;
40 let mut line_start: usize = 0;
41
42 for (i, &byte) in bytes.iter().enumerate() {
43 if byte == b'\n' {
44 line += 1;
45 line_start = i + 1;
46 }
47 if i == offset {
48 break;
49 }
50 }
51
52 let line_text = if offset >= line_start && offset <= bytes.len() {
53 &text[line_start..offset]
54 } else {
55 ""
56 };
57
58 let column = match offset_encoding {
59 OffsetEncoding::Utf8 => line_text.len() as u32,
60 OffsetEncoding::Utf16 => line_text.encode_utf16().count() as u32,
61 };
62
63 Position { line, column }
64}
65
66#[allow(dead_code)]
67fn byte_offset_to_position_safe(
68 text: &str,
69 byte_offset: usize,
70 offset_encoding: OffsetEncoding,
71) -> Position {
72 if byte_offset <= text.len() {
73 byte_offset_to_position(text, byte_offset, offset_encoding)
74 } else {
75 byte_offset_to_position(text, text.len(), offset_encoding)
76 }
77}
78
/// Builds a dotted scope string from highlight name segments.
///
/// Segments are concatenated in order with `.` between neighbours; an empty slice
/// yields an empty string.
pub fn scope_from_highlights(highlight_names: &[&str]) -> String {
    let mut scope = String::new();
    for (index, segment) in highlight_names.iter().enumerate() {
        if index > 0 {
            scope.push('.');
        }
        scope.push_str(segment);
    }
    scope
}
82
83#[allow(dead_code)]
84struct HighlightParser {
85 config: tree_sitter_highlight::HighlightConfiguration,
86 language_name: String,
87 language_id: String,
88}
89
90#[allow(dead_code)]
91impl HighlightParser {
92 fn new(
93 language: tree_sitter::Language,
94 language_name: &str,
95 language_id: &str,
96 highlights_query: &str,
97 injections_query: &str,
98 locals_query: &str,
99 ) -> Result<Self, ParseError> {
100 let config = tree_sitter_highlight::HighlightConfiguration::new(
101 language,
102 language_name,
103 highlights_query,
104 injections_query,
105 locals_query,
106 )
107 .map_err(|e| ParseError::ParseFailed {
108 language: language_name.to_owned(),
109 message: e.to_string(),
110 })?;
111
112 Ok(Self {
113 config,
114 language_name: language_name.to_owned(),
115 language_id: language_id.to_owned(),
116 })
117 }
118
119 fn parse_impl(
120 &self,
121 text: &str,
122 offset_encoding: OffsetEncoding,
123 ) -> Result<Vec<TokenInfo>, ParseError> {
124 let mut highlighter = tree_sitter_highlight::Highlighter::new();
125 let source = text.as_bytes();
126
127 let events: Vec<tree_sitter_highlight::HighlightEvent> =
128 match highlighter.highlight(&self.config, source, None, |_| None) {
129 Ok(iter) => iter.filter_map(|e| e.ok()).collect(),
130 Err(e) => {
131 warn!(
132 language = %self.language_name,
133 error = %e,
134 "highlighting failed, returning empty tokens"
135 );
136 return Ok(Vec::new());
137 }
138 };
139
140 let mut tokens = Vec::new();
141 let mut scope_stack: Vec<String> = Vec::new();
142 let mut byte_start_stack: Vec<usize> = Vec::new();
143 let mut current_source_start: usize = 0;
144 let mut current_source_end: usize = 0;
145
146 for event in events {
147 match event {
148 tree_sitter_highlight::HighlightEvent::Source { start, end } => {
149 current_source_start = start;
150 current_source_end = end;
151 }
152 tree_sitter_highlight::HighlightEvent::HighlightStart(capture) => {
153 let capture_idx = capture.0;
154 let name = self
155 .config
156 .names()
157 .get(capture_idx)
158 .copied()
159 .unwrap_or("unknown");
160 scope_stack.push(name.to_owned());
161 byte_start_stack.push(current_source_start);
162 }
163 tree_sitter_highlight::HighlightEvent::HighlightEnd => {
164 let scope = scope_stack.pop().unwrap_or_default();
165 let byte_start = byte_start_stack.pop().unwrap_or(current_source_start);
166 let byte_end = current_source_end;
167
168 if byte_end > byte_start {
169 let safe_start = byte_start.min(text.len());
170 let safe_end = byte_end.min(text.len());
171 if safe_end > safe_start
172 && text.is_char_boundary(safe_start)
173 && text.is_char_boundary(safe_end)
174 {
175 let token_text = text[safe_start..safe_end].to_owned();
176 let start_pos =
177 byte_offset_to_position_safe(text, safe_start, offset_encoding);
178 let end_pos =
179 byte_offset_to_position_safe(text, safe_end, offset_encoding);
180
181 tokens.push(TokenInfo {
182 text: token_text,
183 range: Range {
184 start: start_pos,
185 end: end_pos,
186 },
187 scope,
188 modifiers: Vec::new(),
189 });
190 }
191 }
192 }
193 }
194 }
195
196 Ok(tokens)
197 }
198}
199
200impl LanguageSupport for HighlightParser {
201 fn language_name(&self) -> &str {
202 &self.language_name
203 }
204
205 fn language_id(&self) -> &str {
206 &self.language_id
207 }
208
209 fn parse(
210 &self,
211 text: &str,
212 offset_encoding: OffsetEncoding,
213 ) -> Result<Vec<TokenInfo>, ParseError> {
214 self.parse_impl(text, offset_encoding)
215 }
216}
217
/// Registers one tree-sitter language in `$registry` when its Cargo feature is enabled.
///
/// Arguments, in order: the registry map, the language id used as the map key, the
/// display name, the name of the feature gating the registration, the grammar
/// expression, and the highlights, injections and locals query sources.
///
/// The whole registration is compiled out when the feature is disabled. If the
/// highlighter cannot be built, the language is skipped and a warning is logged.
macro_rules! register_language {
    (
        $registry:expr,
        $id:expr,
        $name:expr,
        $feature_gate:expr,
        $grammar:expr,
        $highlights:expr,
        $injections:expr,
        $locals:expr
    ) => {
        #[cfg(feature = $feature_gate)]
        {
            let built =
                HighlightParser::new($grammar, $name, $id, $highlights, $injections, $locals);
            match built {
                Err(err) => {
                    warn!(
                        language = $name,
                        error = %err,
                        "failed to create highlighter for language, skipping"
                    );
                }
                Ok(parser) => {
                    $registry.insert($id.to_owned(), Box::new(parser));
                }
            }
        }
    };
}
237
238pub struct TokenParser {
239 languages: HashMap<String, Box<dyn LanguageSupport>>,
240}
241
242impl TokenParser {
243 pub fn new() -> Self {
244 #[allow(unused_mut)]
245 let mut languages: HashMap<String, Box<dyn LanguageSupport>> = HashMap::new();
246
247 register_language!(
248 languages,
249 "rust",
250 "Rust",
251 "rust",
252 tree_sitter_rust::language(),
253 tree_sitter_rust::HIGHLIGHTS_QUERY,
254 tree_sitter_rust::INJECTIONS_QUERY,
255 ""
256 );
257
258 register_language!(
259 languages,
260 "typescript",
261 "TypeScript",
262 "typescript",
263 tree_sitter_typescript::language_typescript(),
264 tree_sitter_typescript::HIGHLIGHTS_QUERY,
265 tree_sitter_typescript::INJECTIONS_QUERY,
266 tree_sitter_typescript::LOCALS_QUERY
267 );
268
269 register_language!(
270 languages,
271 "javascript",
272 "JavaScript",
273 "javascript",
274 tree_sitter_typescript::language_typescript(),
275 tree_sitter_typescript::HIGHLIGHTS_QUERY,
276 tree_sitter_typescript::INJECTIONS_QUERY,
277 tree_sitter_typescript::LOCALS_QUERY
278 );
279
280 register_language!(
281 languages,
282 "python",
283 "Python",
284 "python",
285 tree_sitter_python::language(),
286 tree_sitter_python::HIGHLIGHTS_QUERY,
287 "",
288 tree_sitter_python::LOCALS_QUERY
289 );
290
291 register_language!(
292 languages,
293 "go",
294 "Go",
295 "go",
296 tree_sitter_go::language(),
297 tree_sitter_go::HIGHLIGHTS_QUERY,
298 "",
299 ""
300 );
301
302 register_language!(
303 languages,
304 "c",
305 "C",
306 "c",
307 tree_sitter_c::language(),
308 tree_sitter_c::HIGHLIGHTS_QUERY,
309 "",
310 ""
311 );
312
313 register_language!(
314 languages,
315 "cpp",
316 "C++",
317 "cpp",
318 tree_sitter_cpp::language(),
319 tree_sitter_cpp::HIGHLIGHTS_QUERY,
320 "",
321 ""
322 );
323
324 register_language!(
325 languages,
326 "json",
327 "JSON",
328 "json",
329 tree_sitter_json::language(),
330 tree_sitter_json::HIGHLIGHTS_QUERY,
331 "",
332 ""
333 );
334
335 register_language!(
336 languages,
337 "toml",
338 "TOML",
339 "toml",
340 tree_sitter_toml::language(),
341 tree_sitter_toml::HIGHLIGHTS_QUERY,
342 "",
343 ""
344 );
345
346 register_language!(
347 languages,
348 "lua",
349 "Lua",
350 "lua",
351 tree_sitter_lua::language(),
352 tree_sitter_lua::HIGHLIGHTS_QUERY,
353 "",
354 ""
355 );
356
357 Self { languages }
358 }
359
360 pub fn supported_languages(&self) -> Vec<&str> {
361 let mut langs: Vec<&str> = self.languages.keys().map(String::as_str).collect();
362 langs.sort();
363 langs
364 }
365
366 pub fn parse_tokens(
367 &self,
368 text: &str,
369 language_id: &str,
370 offset_encoding: OffsetEncoding,
371 ) -> Result<Vec<TokenInfo>, ParseError> {
372 let support = self
373 .languages
374 .get(language_id)
375 .ok_or_else(|| ParseError::UnsupportedLanguage(language_id.to_owned()))?;
376 support.parse(text, offset_encoding)
377 }
378}
379
380impl Default for TokenParser {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
#[cfg(test)]
mod tests {
    //! Unit tests for offset conversion, scope building and the language registry.
    //!
    //! Tests whose expectation depends on which language features are compiled in
    //! are gated accordingly, so the suite passes with any feature combination.

    use super::*;

    #[test]
    fn test_byte_offset_to_position_single_line() {
        let text = "hello world";
        let pos = byte_offset_to_position(text, 5, OffsetEncoding::Utf8);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 5);
    }

    #[test]
    fn test_byte_offset_to_position_multiline() {
        let text = "line1\nline2\nline3";
        let pos = byte_offset_to_position(text, 6, OffsetEncoding::Utf8);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 0);

        let pos = byte_offset_to_position(text, 12, OffsetEncoding::Utf8);
        assert_eq!(pos.line, 2);
        assert_eq!(pos.column, 0);
    }

    #[test]
    fn test_byte_offset_to_position_utf16() {
        let text = "hello\nworld";
        let pos = byte_offset_to_position(text, 6, OffsetEncoding::Utf16);
        assert_eq!(pos.line, 1);
        assert_eq!(pos.column, 0);

        // `é` occupies two UTF-8 bytes but a single UTF-16 code unit.
        let pos = byte_offset_to_position("h\u{e9}llo", 3, OffsetEncoding::Utf16);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 2);
    }

    #[test]
    fn test_byte_offset_to_position_end_of_text() {
        let text = "abc";
        let pos = byte_offset_to_position(text, 3, OffsetEncoding::Utf8);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 3);
    }

    #[test]
    fn test_byte_offset_to_position_empty_text() {
        let pos = byte_offset_to_position("", 0, OffsetEncoding::Utf8);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 0);
    }

    #[test]
    fn test_byte_offset_to_position_safe_clamps() {
        let text = "abc";
        let pos = byte_offset_to_position_safe(text, 100, OffsetEncoding::Utf8);
        assert_eq!(pos.line, 0);
        assert_eq!(pos.column, 3);
    }

    #[test]
    fn test_scope_from_highlights_single() {
        let scope = scope_from_highlights(&["keyword"]);
        assert_eq!(scope, "keyword");
    }

    #[test]
    fn test_scope_from_highlights_multiple() {
        let scope = scope_from_highlights(&["keyword", "control"]);
        assert_eq!(scope, "keyword.control");
    }

    #[test]
    fn test_scope_from_highlights_empty() {
        let scope = scope_from_highlights(&[]);
        assert_eq!(scope, "");
    }

    #[test]
    fn test_scope_from_highlights_three_levels() {
        let scope = scope_from_highlights(&["entity", "name", "function"]);
        assert_eq!(scope, "entity.name.function");
    }

    // Only meaningful when no language feature is compiled in.
    #[cfg(not(any(
        feature = "rust",
        feature = "typescript",
        feature = "javascript",
        feature = "python",
        feature = "go",
        feature = "c",
        feature = "cpp",
        feature = "json",
        feature = "toml",
        feature = "lua"
    )))]
    #[test]
    fn test_token_parser_no_features_by_default() {
        let parser = TokenParser::new();
        assert!(parser.supported_languages().is_empty());
    }

    #[test]
    fn test_token_parser_unsupported_language() {
        // Uses an id that no feature registers, so the result does not depend on
        // the enabled feature set.
        let parser = TokenParser::new();
        let result = parser.parse_tokens("fn main() {}", "cobol", OffsetEncoding::Utf8);
        assert!(result.is_err());
        match result.unwrap_err() {
            ParseError::UnsupportedLanguage(lang) => assert_eq!(lang, "cobol"),
            other => panic!("expected UnsupportedLanguage, got {other}"),
        }
    }

    // With the `rust` feature enabled the id is supported and this lookup succeeds.
    #[cfg(not(feature = "rust"))]
    #[test]
    fn test_token_parser_empty_input() {
        let parser = TokenParser::new();
        let result = parser.parse_tokens("", "rust", OffsetEncoding::Utf8);
        assert!(result.is_err());
    }

    #[test]
    fn test_token_parser_unknown_language_ids() {
        let parser = TokenParser::new();
        for id in &["brainfuck", "cobol", "fortran", "haskell", "zig"] {
            let result = parser.parse_tokens("", id, OffsetEncoding::Utf8);
            assert!(result.is_err());
        }
    }
}