shuire 0.1.1

Vim-like TUI git diff viewer
use std::path::Path;
use std::sync::OnceLock;

use syntect::parsing::{ParseState, ScopeStack, SyntaxSet};

use crate::diff::{FileDiff, LineKind, TokenKind};

/// Returns the process-wide syntect syntax set.
///
/// The bundled default syntaxes are loaded once, on first call, using the `_newlines`
/// variant. That variant expects every parsed line to end in `\n`.
fn syntax_set() -> &'static SyntaxSet {
    static DEFAULT_SYNTAXES: OnceLock<SyntaxSet> = OnceLock::new();
    DEFAULT_SYNTAXES.get_or_init(SyntaxSet::load_defaults_newlines)
}

pub fn highlight_files(files: &mut [FileDiff]) {
    let ps = syntax_set();

    for file in files {
        let syntax = syntax_for_path(&file.path, ps);
        let mut parse_state = ParseState::new(syntax);
        let mut stack = ScopeStack::new();

        for line in &mut file.lines {
            match line.kind {
                LineKind::HunkHeader | LineKind::FoldDown | LineKind::FoldUp => {
                    parse_state = ParseState::new(syntax);
                    stack = ScopeStack::new();
                }
                LineKind::Context | LineKind::Added | LineKind::Removed => {
                    let text = format!("{}\n", line.text);
                    let Ok(ops) = parse_state.parse_line(&text, ps) else {
                        continue;
                    };

                    let mut segments: Vec<(TokenKind, String)> = Vec::new();
                    let mut last = 0usize;
                    let mut push = |kind: TokenKind, slice: &str| {
                        if slice.is_empty() {
                            return;
                        }
                        let trimmed = slice.trim_end_matches('\n');
                        if trimmed.is_empty() {
                            return;
                        }
                        segments.push((kind, trimmed.to_string()));
                    };

                    for (pos, op) in &ops {
                        if *pos > last {
                            let kind = classify(&stack);
                            push(kind, &text[last..*pos]);
                        }
                        stack.apply(op).ok();
                        last = *pos;
                    }
                    if last < text.len() {
                        let kind = classify(&stack);
                        push(kind, &text[last..]);
                    }

                    // Merge adjacent same-kind segments to cut span count.
                    let mut merged: Vec<(TokenKind, String)> = Vec::with_capacity(segments.len());
                    for (k, s) in segments {
                        if let Some(last) = merged.last_mut() {
                            if last.0 == k {
                                last.1.push_str(&s);
                                continue;
                            }
                        }
                        merged.push((k, s));
                    }
                    line.segments = merged;
                }
            }
        }
    }
}

/// Maps the most specific recognisable scope on `stack` to a [`TokenKind`].
///
/// Scopes are examined from the top of the stack downwards, and the first one that matches
/// any rule decides the kind. Within a single scope the rules are tried in a fixed order,
/// so e.g. `constant.numeric` wins over the broader `constant`. If nothing matches, the
/// result is [`TokenKind::Default`].
fn classify(stack: &ScopeStack) -> TokenKind {
    stack
        .as_slice()
        .iter()
        .rev()
        .find_map(|scope| kind_for_scope(&scope.to_string()))
        .unwrap_or(TokenKind::Default)
}

/// Classifies a single dotted scope name, or returns `None` if no rule applies.
fn kind_for_scope(name: &str) -> Option<TokenKind> {
    // True when `name` equals, or is a dotted child of, any of `prefixes`.
    fn matches_any(name: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|prefix| starts_with_part(name, prefix))
    }

    let kind = if matches_any(name, &["comment"]) {
        TokenKind::Comment
    } else if matches_any(name, &["string"]) {
        TokenKind::String
    } else if matches_any(name, &["constant.numeric"]) {
        TokenKind::Number
    } else if matches_any(name, &["constant"]) {
        TokenKind::Constant
    } else if matches_any(name, &["keyword.operator"]) {
        TokenKind::Operator
    } else if matches_any(name, &["keyword", "storage"]) {
        TokenKind::Keyword
    } else if matches_any(
        name,
        &[
            "entity.name.function",
            "support.function",
            "variable.function",
            "meta.function-call",
        ],
    ) {
        TokenKind::Function
    } else if matches_any(
        name,
        &[
            "entity.name.type",
            "entity.name.class",
            "support.type",
            "support.class",
        ],
    ) {
        TokenKind::Type
    } else if matches_any(name, &["punctuation"]) {
        TokenKind::Punctuation
    } else if matches_any(name, &["variable"]) {
        TokenKind::Variable
    } else {
        return None;
    };
    Some(kind)
}

/// Returns whether `scope` equals `prefix` or is a dotted child of it.
///
/// `"keyword"` matches both `"keyword"` and `"keyword.control.rust"`, but not
/// `"keywordish"`.
fn starts_with_part(scope: &str, prefix: &str) -> bool {
    scope
        .strip_prefix(prefix)
        .is_some_and(|rest| rest.is_empty() || rest.starts_with('.'))
}

/// Picks the syntax used to highlight the file at `path`.
///
/// The lookup order is:
/// 1. the file extension;
/// 2. the whole file name, looked up as a syntect token;
/// 3. plain text, as the final fallback.
///
/// Non-UTF-8 extensions or names are skipped.
fn syntax_for_path<'a>(path: &str, ps: &'a SyntaxSet) -> &'a syntect::parsing::SyntaxReference {
    let file = Path::new(path);
    let by_extension = file
        .extension()
        .and_then(|ext| ext.to_str())
        .and_then(|ext| ps.find_syntax_by_extension(ext));

    by_extension
        .or_else(|| {
            file.file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| ps.find_syntax_by_token(name))
        })
        .unwrap_or_else(|| ps.find_syntax_plain_text())
}