use std::collections::HashMap;
use ratatui::style::Style as RatatuiStyle;
use ratatui::text::{Line, Span};
use tree_sitter_highlight::{HighlightConfiguration, HighlightEvent, Highlighter as TsHighlighter};
use crate::languages::Language;
use crate::theme::Theme;
/// Errors produced while registering languages or highlighting source text.
#[derive(Debug)]
pub enum HighlightError {
    /// No language with the given name has been registered with the highlighter.
    UnknownLanguage(String),
    /// tree-sitter-highlight reported an error while starting or iterating a highlight.
    Highlight(String),
    /// Building the highlight configuration for a language failed.
    Config(String),
}

impl std::fmt::Display for HighlightError {
    /// Renders the error as a category prefix followed by the carried detail text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownLanguage(name) => write!(f, "Unknown language: {}", name),
            Self::Highlight(msg) => write!(f, "Highlight error: {}", msg),
            Self::Config(msg) => write!(f, "Config error: {}", msg),
        }
    }
}

impl std::error::Error for HighlightError {}
/// Capture names the highlighter recognises; passed to
/// `HighlightConfiguration::configure` for every registered language.
///
/// Order matters: highlight events identify a capture by its index into this
/// slice, and `Highlighter::highlight` maps that index back to a name with
/// `CAPTURE_NAMES.get(index)` before asking the theme for a style.
const CAPTURE_NAMES: &[&str] = &[
    // Attributes, literals, comments.
    "attribute", "boolean", "comment", "comment.documentation",
    "constant", "constant.builtin", "constructor", "embedded", "escape",
    // Functions.
    "function", "function.builtin", "function.call", "function.macro", "function.method",
    // Keywords.
    "keyword", "keyword.control", "keyword.control.conditional", "keyword.control.import",
    "keyword.control.repeat", "keyword.control.return",
    "keyword.directive", "keyword.function", "keyword.operator", "keyword.special",
    "keyword.storage", "keyword.storage.modifier", "keyword.storage.type",
    // Miscellaneous identifiers and operators.
    "label", "namespace", "number", "operator", "property",
    // Punctuation.
    "punctuation", "punctuation.bracket", "punctuation.delimiter", "punctuation.special",
    // Strings.
    "special", "string", "string.escape", "string.regexp", "string.special",
    // Tags and types.
    "tag", "type", "type.builtin",
    // Variables.
    "variable", "variable.builtin", "variable.parameter",
];
/// Per-language state stored by `Highlighter`: the tree-sitter highlight
/// configuration built (and configured with `CAPTURE_NAMES`) in
/// `Highlighter::register_language`.
struct LanguageConfig {
/// Configuration passed to tree-sitter-highlight when highlighting this language.
config: HighlightConfiguration,
}
/// Syntax highlighter that turns source text into styled ratatui `Line`s.
///
/// A language must be registered with `register_language` before it can be
/// highlighted; capture styles are looked up in the current `Theme`.
pub struct Highlighter {
/// Theme used to map capture names to styles.
theme: Theme,
/// tree-sitter-highlight highlighter shared by every `highlight` call.
ts_highlighter: TsHighlighter,
/// Registered languages, keyed by the language's name.
languages: HashMap<String, LanguageConfig>,
}
impl Highlighter {
    /// Creates a highlighter that styles captures with `theme`.
    ///
    /// No languages are registered initially; call
    /// [`register_language`](Self::register_language) before highlighting.
    pub fn new(theme: Theme) -> Self {
        Self {
            theme,
            ts_highlighter: TsHighlighter::new(),
            languages: HashMap::new(),
        }
    }

    /// Builds a highlight configuration for `language` and registers it under
    /// `language.name`, replacing any language already registered with that name.
    ///
    /// # Errors
    ///
    /// Returns [`HighlightError::Config`] if `HighlightConfiguration::new` fails
    /// for the language's highlights, injections or locals query.
    pub fn register_language(&mut self, language: Language) -> Result<(), HighlightError> {
        let mut config = HighlightConfiguration::new(
            language.ts_language,
            language.name,
            language.highlights_query,
            language.injections_query,
            language.locals_query,
        )
        .map_err(|e| HighlightError::Config(e.to_string()))?;
        // Highlight indices reported later are positions in CAPTURE_NAMES;
        // `highlight` relies on this to recover the capture name.
        config.configure(CAPTURE_NAMES);
        self.languages
            .insert(language.name.to_string(), LanguageConfig { config });
        Ok(())
    }

    /// Returns the theme currently used for styling.
    pub fn theme(&self) -> &Theme {
        &self.theme
    }

    /// Replaces the theme used by subsequent `highlight` calls.
    pub fn set_theme(&mut self, theme: Theme) {
        self.theme = theme;
    }

    /// Highlights `source` with the registered language named `language`,
    /// returning one styled [`Line`] per line of `source` (at least one).
    ///
    /// Nested captures take the style of the innermost capture; text outside any
    /// capture uses the default style. Language injections are not followed.
    ///
    /// # Errors
    ///
    /// Returns [`HighlightError::UnknownLanguage`] if `language` has not been
    /// registered, or [`HighlightError::Highlight`] if tree-sitter-highlight
    /// fails while starting or iterating the highlight.
    pub fn highlight(
        &mut self,
        language: &str,
        source: &str,
    ) -> Result<Vec<Line<'static>>, HighlightError> {
        let lang_config = self
            .languages
            .get(language)
            .ok_or_else(|| HighlightError::UnknownLanguage(language.to_string()))?;
        // The injection callback always answers `None`, so embedded languages
        // are not highlighted.
        let highlights = self
            .ts_highlighter
            .highlight(&lang_config.config, source.as_bytes(), None, |_| None)
            .map_err(|e| HighlightError::Highlight(e.to_string()))?;

        let mut spans: Vec<(usize, usize, RatatuiStyle)> = Vec::new();
        // The bottom entry is the style for text outside any capture.
        let mut style_stack: Vec<RatatuiStyle> = vec![RatatuiStyle::default()];
        for event in highlights {
            match event.map_err(|e| HighlightError::Highlight(e.to_string()))? {
                HighlightEvent::Source { start, end } => {
                    let current_style = style_stack.last().copied().unwrap_or_default();
                    spans.push((start, end, current_style));
                }
                HighlightEvent::HighlightStart(highlight) => {
                    let capture_name = CAPTURE_NAMES.get(highlight.0).copied().unwrap_or("text");
                    style_stack.push(self.theme.style_for(capture_name));
                }
                HighlightEvent::HighlightEnd => {
                    style_stack.pop();
                }
            }
        }
        Ok(self.spans_to_lines(source, &spans))
    }

    /// Splits styled byte ranges of `source` into per-line ratatui [`Line`]s.
    ///
    /// `spans` holds `(start, end, style)` byte ranges into `source` and is
    /// expected in ascending, non-overlapping order (as built by `highlight`).
    /// Spans crossing a line break are clipped to each line, line terminators
    /// (`\n` / `\r\n`, per `str::lines`) are dropped, and bytes not covered by
    /// any span are emitted unstyled. Always returns at least one line.
    fn spans_to_lines(
        &self,
        source: &str,
        spans: &[(usize, usize, RatatuiStyle)],
    ) -> Vec<Line<'static>> {
        // Byte offset at which each line begins: 0, then one past every '\n'.
        let line_starts =
            std::iter::once(0).chain(source.match_indices('\n').map(|(i, _)| i + 1));
        let mut result: Vec<Line<'static>> = Vec::new();
        // Index of the first span that can still reach the current line. Spans
        // are ordered, so one ending before this line starts is skipped for good
        // instead of being rescanned on every later line, which made the whole
        // pass O(lines × spans).
        let mut first_span = 0;

        for (line_start, line_text) in line_starts.zip(source.lines()) {
            let line_end = line_start + line_text.len();
            first_span += spans[first_span..]
                .iter()
                .take_while(|&&(_, span_end, _)| span_end <= line_start)
                .count();

            let mut line_spans: Vec<Span<'static>> = Vec::new();
            let mut current_pos = line_start;
            for &(span_start, span_end, style) in &spans[first_span..] {
                if span_end <= line_start {
                    continue;
                }
                if span_start >= line_end {
                    break;
                }
                let clipped_start = span_start.max(line_start);
                let clipped_end = span_end.min(line_end);
                if clipped_start > current_pos {
                    // Bytes before this span that no span covered.
                    line_spans.push(Span::raw(source[current_pos..clipped_start].to_string()));
                }
                if clipped_end > clipped_start {
                    line_spans.push(Span::styled(
                        source[clipped_start..clipped_end].to_string(),
                        style,
                    ));
                    current_pos = clipped_end;
                }
            }
            if current_pos < line_end {
                // Uncovered tail of the line.
                line_spans.push(Span::raw(source[current_pos..line_end].to_string()));
            }
            if line_spans.is_empty() {
                line_spans.push(Span::raw(String::new()));
            }
            result.push(Line::from(line_spans));
        }

        if result.is_empty() {
            result.push(Line::from(vec![Span::raw(String::new())]));
        }
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::languages::{html, sql, Language};
    use crate::themes;

    /// Creates a One Dark highlighter with `language` already registered.
    fn highlighter_with(language: Language) -> Highlighter {
        let mut highlighter = Highlighter::new(themes::one_dark());
        highlighter.register_language(language).unwrap();
        highlighter
    }

    /// Concatenates the text of every span in `line`.
    fn rendered_text(line: &Line<'_>) -> String {
        line.spans.iter().map(|span| &*span.content).collect()
    }

    #[test]
    fn test_highlighter_creation() {
        let highlighter = Highlighter::new(themes::one_dark());
        assert!(highlighter.languages.is_empty());
    }

    #[test]
    fn test_register_language() {
        let highlighter = highlighter_with(sql());
        assert!(highlighter.languages.contains_key("sql"));
    }

    #[test]
    fn test_highlight_simple_sql() {
        let mut highlighter = highlighter_with(sql());
        let lines = highlighter.highlight("sql", "SELECT * FROM users").unwrap();
        assert_eq!(lines.len(), 1);
        assert!(!lines[0].spans.is_empty());
    }

    #[test]
    fn test_highlight_multiline_sql() {
        let mut highlighter = highlighter_with(sql());
        let sql = "SELECT *\nFROM users\nWHERE id = 1";
        let lines = highlighter.highlight("sql", sql).unwrap();
        assert_eq!(lines.len(), 3);
    }

    #[test]
    fn test_highlight_preserves_source_text() {
        // Every byte of each line must appear exactly once across its spans.
        let mut highlighter = highlighter_with(sql());
        let source = "SELECT *\nFROM users\nWHERE id = 1";
        let lines = highlighter.highlight("sql", source).unwrap();
        let rendered: Vec<String> = lines.iter().map(|line| rendered_text(line)).collect();
        assert_eq!(rendered, source.lines().collect::<Vec<_>>());
    }

    #[test]
    fn test_highlight_crlf_line_endings() {
        let mut highlighter = highlighter_with(sql());
        let lines = highlighter.highlight("sql", "SELECT *\r\nFROM users").unwrap();
        assert_eq!(lines.len(), 2);
        assert_eq!(rendered_text(&lines[0]), "SELECT *");
        assert_eq!(rendered_text(&lines[1]), "FROM users");
    }

    #[test]
    fn test_highlight_trailing_newline() {
        let mut highlighter = highlighter_with(sql());
        let lines = highlighter.highlight("sql", "SELECT 1\n").unwrap();
        assert_eq!(lines.len(), 1);
        assert_eq!(rendered_text(&lines[0]), "SELECT 1");
    }

    #[test]
    fn test_highlight_empty_source_yields_single_empty_line() {
        let mut highlighter = highlighter_with(sql());
        let lines = highlighter.highlight("sql", "").unwrap();
        assert_eq!(lines.len(), 1);
        assert_eq!(rendered_text(&lines[0]), "");
    }

    #[test]
    fn test_unknown_language_error() {
        let mut highlighter = Highlighter::new(themes::one_dark());
        let result = highlighter.highlight("unknown", "some code");
        assert!(matches!(result, Err(HighlightError::UnknownLanguage(_))));
    }

    #[test]
    fn test_highlight_html() {
        let mut highlighter = highlighter_with(html());
        let html_content = "<html><head><title>Test</title></head><body><p>Hello</p></body></html>";
        let lines = highlighter.highlight("html", html_content).unwrap();
        assert_eq!(lines.len(), 1);
        assert!(!lines[0].spans.is_empty());
    }

    #[test]
    fn test_highlight_multiline_html() {
        let mut highlighter = highlighter_with(html());
        let html_content = r#"<html>
<head>
<title>Test</title>
</head>
<body>
<p>Hello</p>
</body>
</html>"#;
        let lines = highlighter.highlight("html", html_content).unwrap();
        assert_eq!(lines.len(), 8);
    }
}