use std::path::Path;
use std::sync::OnceLock;
use syntect::parsing::{ParseState, ScopeStack, SyntaxSet};
use crate::diff::{FileDiff, LineKind, TokenKind};
/// Returns the process-wide [`SyntaxSet`].
///
/// The set is built on first use from syntect's bundled default syntaxes (the
/// newline-terminated variant) and reused for every later call.
fn syntax_set() -> &'static SyntaxSet {
    static SYNTAXES: OnceLock<SyntaxSet> = OnceLock::new();
    SYNTAXES.get_or_init(SyntaxSet::load_defaults_newlines)
}
pub fn highlight_files(files: &mut [FileDiff]) {
let ps = syntax_set();
for file in files {
let syntax = syntax_for_path(&file.path, ps);
let mut parse_state = ParseState::new(syntax);
let mut stack = ScopeStack::new();
for line in &mut file.lines {
match line.kind {
LineKind::HunkHeader | LineKind::FoldDown | LineKind::FoldUp => {
parse_state = ParseState::new(syntax);
stack = ScopeStack::new();
}
LineKind::Context | LineKind::Added | LineKind::Removed => {
let text = format!("{}\n", line.text);
let Ok(ops) = parse_state.parse_line(&text, ps) else {
continue;
};
let mut segments: Vec<(TokenKind, String)> = Vec::new();
let mut last = 0usize;
let mut push = |kind: TokenKind, slice: &str| {
if slice.is_empty() {
return;
}
let trimmed = slice.trim_end_matches('\n');
if trimmed.is_empty() {
return;
}
segments.push((kind, trimmed.to_string()));
};
for (pos, op) in &ops {
if *pos > last {
let kind = classify(&stack);
push(kind, &text[last..*pos]);
}
stack.apply(op).ok();
last = *pos;
}
if last < text.len() {
let kind = classify(&stack);
push(kind, &text[last..]);
}
let mut merged: Vec<(TokenKind, String)> = Vec::with_capacity(segments.len());
for (k, s) in segments {
if let Some(last) = merged.last_mut() {
if last.0 == k {
last.1.push_str(&s);
continue;
}
}
merged.push((k, s));
}
line.segments = merged;
}
}
}
}
}
/// Maps the scopes currently on `stack` to a single [`TokenKind`].
///
/// Scopes are examined from innermost to outermost, and the first one that
/// `kind_for_scope` recognises decides the result. Returns `TokenKind::Default`
/// when no scope on the stack is recognised.
fn classify(stack: &ScopeStack) -> TokenKind {
    stack
        .as_slice()
        .iter()
        .rev()
        .find_map(|scope| kind_for_scope(&scope.to_string()))
        .unwrap_or(TokenKind::Default)
}

/// Classifies one dotted scope name, or returns `None` if no rule applies.
///
/// More specific prefixes are tested before their parents: `constant.numeric`
/// before `constant`, `keyword.operator` before `keyword`, and the function
/// rules (including `variable.function`) before plain `variable`.
fn kind_for_scope(name: &str) -> Option<TokenKind> {
    let is = |prefix: &str| starts_with_part(name, prefix);
    let kind = if is("comment") {
        TokenKind::Comment
    } else if is("string") {
        TokenKind::String
    } else if is("constant.numeric") {
        TokenKind::Number
    } else if is("constant") {
        TokenKind::Constant
    } else if is("keyword.operator") {
        TokenKind::Operator
    } else if is("keyword") || is("storage") {
        TokenKind::Keyword
    } else if is("entity.name.function")
        || is("support.function")
        || is("variable.function")
        || is("meta.function-call")
    {
        TokenKind::Function
    } else if is("entity.name.type")
        || is("entity.name.class")
        || is("support.type")
        || is("support.class")
    {
        TokenKind::Type
    } else if is("punctuation") {
        TokenKind::Punctuation
    } else if is("variable") {
        TokenKind::Variable
    } else {
        return None;
    };
    Some(kind)
}
/// Reports whether `scope` begins with the dotted component path `prefix`.
///
/// `prefix` must be followed either by the end of `scope` or by a `.`. For
/// example, `string` matches `string` and `string.quoted`, but not `stringy`.
fn starts_with_part(scope: &str, prefix: &str) -> bool {
    scope
        .strip_prefix(prefix)
        .is_some_and(|rest| rest.is_empty() || rest.starts_with('.'))
}
/// Chooses the syntax definition used to highlight the file at `path`.
///
/// Lookup order:
/// 1. The path's extension, via `find_syntax_by_extension`.
/// 2. The whole file name, via `find_syntax_by_token`. This presumably covers
///    extension-less names; confirm against syntect's matching rules.
/// 3. The plain-text syntax, which always exists as the final fallback.
fn syntax_for_path<'a>(path: &str, ps: &'a SyntaxSet) -> &'a syntect::parsing::SyntaxReference {
    let path = Path::new(path);
    path.extension()
        .and_then(|ext| ext.to_str())
        .and_then(|ext| ps.find_syntax_by_extension(ext))
        .or_else(|| {
            path.file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| ps.find_syntax_by_token(name))
        })
        .unwrap_or_else(|| ps.find_syntax_plain_text())
}