use compact_str::CompactString;
use crossterm::style::Color;
use crate::ui::theme;
/// A run of text paired with the color the highlighter assigned to it.
#[derive(Debug, Clone)]
pub struct Span {
/// The text covered by this span.
pub text: CompactString,
/// The color assigned to `text`.
pub color: Color,
}
/// Highlights `code` line by line for the language tag `lang`.
///
/// The input is split on `'\n'` and trailing `'\r'` characters are removed from
/// each piece, so the result always holds at least one row (an empty input
/// yields a single row). When the tag resolves to a known rule table each row
/// is produced by `tokenize_line`, with the "inside a block comment" flag
/// carried from one line to the next. For any other tag every row is a single
/// span holding the whole line in `theme::tool()`.
pub fn highlight_code(code: &str, lang: &str) -> Vec<Vec<Span>> {
    let lines = code.split('\n').map(|raw| raw.trim_end_matches('\r'));

    match rules_for(normalize_lang(lang)) {
        Some(rules) => {
            // Block-comment state threads through the lines in order.
            let mut inside_comment = false;
            lines
                .map(|line| {
                    let (row, carry) = tokenize_line(line, rules, inside_comment);
                    inside_comment = carry;
                    row
                })
                .collect()
        }
        None => lines
            .map(|line| {
                vec![Span {
                    text: CompactString::new(line),
                    color: theme::tool(),
                }]
            })
            .collect(),
    }
}
/// Returns `true` when the language tag `lang` (after alias normalization)
/// has a rule table, i.e. when `highlight_code` would tokenize it rather than
/// fall back to a single uniformly colored span per line.
pub fn supports(lang: &str) -> bool {
    let canonical = normalize_lang(lang);
    rules_for(canonical).is_some()
}
/// Maps a code-fence language tag to the canonical name used by `rules_for`.
///
/// Only the first token of `lang` is considered: leading whitespace is skipped
/// and the tag ends at the first comma or whitespace character, so
/// `"rust,no_run"` and `"rust ignore"` both resolve to `"rust"`. Aliases are
/// matched ASCII case-insensitively (`"Python"`, `"JSON"`, `"C++"`).
///
/// Tags that are not a known alias are returned unchanged (original case
/// preserved) so the caller can decide how to treat them.
fn normalize_lang(lang: &str) -> &str {
    // (canonical name, accepted spellings)
    const ALIASES: &[(&str, &[&str])] = &[
        ("typescript", &["ts", "tsx", "typescript", "typescriptreact"]),
        (
            "javascript",
            &["js", "jsx", "javascript", "javascriptreact", "mjs", "cjs"],
        ),
        ("python", &["py", "python", "py3", "python3"]),
        ("bash", &["sh", "bash", "shell", "zsh"]),
        ("clojure", &["clj", "cljs", "cljc", "clojure", "edn"]),
        ("go", &["go", "golang"]),
        ("ruby", &["rb", "ruby"]),
        ("rust", &["rs", "rust"]),
        ("java", &["java"]),
        ("c", &["c", "h"]),
        ("cpp", &["cpp", "cc", "cxx", "hpp", "hh", "hxx", "c++"]),
        ("json", &["json", "jsonc"]),
        ("yaml", &["yaml", "yml"]),
        ("toml", &["toml"]),
        ("sql", &["sql"]),
        ("markdown", &["md", "markdown"]),
    ];

    let head = lang
        .trim_start()
        .split(|c: char| c == ',' || c.is_whitespace())
        .next()
        .unwrap_or("");

    ALIASES
        .iter()
        .find(|(_, names)| names.iter().any(|alias| head.eq_ignore_ascii_case(alias)))
        .map_or(head, |&(canonical, _)| canonical)
}
/// Lexical description of one language, consulted by `tokenize_line`.
struct Rules {
/// Words colored as keywords; compared against whole identifiers, case-sensitively.
keywords: &'static [&'static str],
/// Words colored as types; checked only after the keyword list did not match.
types: &'static [&'static str],
/// Marker that starts a comment running to the end of the line, if the language has one.
line_comment: Option<&'static str>,
/// `(open, close)` markers of a block comment, if the language has one.
block_comment: Option<(&'static str, &'static str)>,
/// Characters that open a string literal; the same character closes it.
string_delims: &'static [char],
/// When `true`, an identifier directly preceded by `#` is colored like a keyword
/// (unless an earlier color rule already matched it).
hash_directive: bool,
}
/// Looks up the rule table for a canonical language name.
///
/// `lang` is expected to be a name as produced by `normalize_lang`; the
/// comparison is exact. Returns `None` for every name without a table.
fn rules_for(lang: &str) -> Option<&'static Rules> {
    let table: &'static Rules = match lang {
        // JavaScript and TypeScript share one table.
        "typescript" | "javascript" => &JS_RULES,
        "python" => &PY_RULES,
        "bash" => &BASH_RULES,
        "clojure" => &CLJ_RULES,
        "go" => &GO_RULES,
        "ruby" => &RUBY_RULES,
        "rust" => &RUST_RULES,
        "java" => &JAVA_RULES,
        "c" => &C_RULES,
        "cpp" => &CPP_RULES,
        "json" => &JSON_RULES,
        "yaml" => &YAML_RULES,
        "toml" => &TOML_RULES,
        "sql" => &SQL_RULES,
        _ => return None,
    };
    Some(table)
}
fn tokenize_line(line: &str, rules: &Rules, mut in_block: bool) -> (Vec<Span>, bool) {
let mut spans: Vec<Span> = Vec::new();
let bytes = line.as_bytes();
let mut i = 0usize;
if in_block && let Some((_open, close)) = rules.block_comment {
let close_b = close.as_bytes();
if let Some(pos) = find_subseq(&bytes[i..], close_b) {
let end = i + pos + close_b.len();
spans.push(Span {
text: CompactString::new(&line[i..end]),
color: theme::dim(),
});
i = end;
in_block = false;
let _ = in_block; } else {
spans.push(Span {
text: CompactString::new(&line[i..]),
color: theme::dim(),
});
return (spans, true);
}
}
while i < bytes.len() {
let ch = bytes[i] as char;
if let Some(marker) = rules.line_comment {
let mb = marker.as_bytes();
if bytes[i..].starts_with(mb) {
spans.push(Span {
text: CompactString::new(&line[i..]),
color: theme::dim(),
});
return (spans, false);
}
}
if let Some((open, _close)) = rules.block_comment {
let ob = open.as_bytes();
if bytes[i..].starts_with(ob) {
let close = rules.block_comment.unwrap().1;
let close_b = close.as_bytes();
if let Some(pos) = find_subseq(&bytes[i + ob.len()..], close_b) {
let end = i + ob.len() + pos + close_b.len();
spans.push(Span {
text: CompactString::new(&line[i..end]),
color: theme::dim(),
});
i = end;
continue;
} else {
spans.push(Span {
text: CompactString::new(&line[i..]),
color: theme::dim(),
});
return (spans, true);
}
}
}
if rules.string_delims.contains(&ch) {
let delim = ch;
let start = i;
i += 1;
while i < bytes.len() {
let c = bytes[i] as char;
if c == '\\' && i + 1 < bytes.len() {
i += 2;
continue;
}
if c == delim {
i += 1;
break;
}
let step = utf8_char_len(bytes[i]);
i += step.max(1);
}
spans.push(Span {
text: CompactString::new(&line[start..i]),
color: theme::accent(),
});
continue;
}
if ch.is_ascii_digit() {
let start = i;
if ch == '0' && i + 1 < bytes.len() {
let next = bytes[i + 1] as char;
if matches!(next, 'x' | 'X' | 'b' | 'B' | 'o' | 'O') {
i += 2;
while i < bytes.len() && (bytes[i] as char).is_ascii_alphanumeric() {
i += 1;
}
spans.push(Span {
text: CompactString::new(&line[start..i]),
color: theme::warn(),
});
continue;
}
}
while i < bytes.len()
&& ((bytes[i] as char).is_ascii_digit() || bytes[i] == b'.' || bytes[i] == b'_')
{
i += 1;
}
while i < bytes.len() && ((bytes[i] as char).is_ascii_alphanumeric()) {
i += 1;
}
spans.push(Span {
text: CompactString::new(&line[start..i]),
color: theme::warn(),
});
continue;
}
if is_ident_start(ch, rules) {
let start = i;
while i < bytes.len() {
let Some(next) = line[i..].chars().next() else {
break;
};
if !is_ident_cont(next, rules) {
break;
}
i += next.len_utf8();
}
let word = &line[start..i];
let color = if rules.keywords.contains(&word) {
theme::user()
} else if rules.types.contains(&word) || looks_like_type(word, rules) {
theme::header()
} else if i < bytes.len() && bytes[i] == b'(' {
theme::tool()
} else if rules.hash_directive && start > 0 && bytes[start - 1] == b'#' {
theme::user()
} else {
theme::agent()
};
spans.push(Span {
text: CompactString::new(word),
color,
});
continue;
}
let start = i;
while i < bytes.len() {
let c = bytes[i] as char;
if is_ident_start(c, rules) || c.is_ascii_digit() || rules.string_delims.contains(&c) {
break;
}
if let Some(marker) = rules.line_comment
&& bytes[i..].starts_with(marker.as_bytes())
{
break;
}
if let Some((open, _)) = rules.block_comment
&& bytes[i..].starts_with(open.as_bytes())
{
break;
}
i += utf8_char_len(bytes[i]).max(1);
}
if i == start {
i += 1;
}
spans.push(Span {
text: CompactString::new(&line[start..i.min(bytes.len())]),
color: theme::agent(),
});
}
(spans, false)
}
/// Reports whether `c` can begin an identifier: an ASCII letter, `_` or `$`.
///
/// The rule set is accepted for symmetry with `is_ident_cont` but is not consulted.
fn is_ident_start(c: char, _rules: &Rules) -> bool {
    matches!(c, 'a'..='z' | 'A'..='Z' | '_' | '$')
}
/// Reports whether `c` may continue an identifier that has already started.
///
/// Accepted: ASCII letters and digits, `_`, `$`, and any non-ASCII character
/// that is neither a control character nor whitespace. `-` is accepted only
/// when the rule set declares no string delimiters.
///
/// NOTE(review): every rule table visible in this file declares at least one
/// string delimiter, so the `-` case looks unreachable at present — confirm
/// whether it was meant to be keyed on something else.
fn is_ident_cont(c: char, rules: &Rules) -> bool {
    match c {
        '_' | '$' => true,
        '-' => rules.string_delims.is_empty(),
        _ if c.is_ascii() => c.is_ascii_alphanumeric(),
        _ => !(c.is_control() || c.is_whitespace()),
    }
}
/// Heuristic for type-like names: at least three bytes long, starting with an
/// ASCII uppercase letter and containing an ASCII lowercase letter somewhere
/// after it (so `Foo` qualifies, `FOO` and `Ok` do not).
///
/// The rule set is not consulted.
fn looks_like_type(word: &str, _rules: &Rules) -> bool {
    match word.as_bytes() {
        [first, rest @ ..] if rest.len() >= 2 => {
            first.is_ascii_uppercase() && rest.iter().any(u8::is_ascii_lowercase)
        }
        _ => false,
    }
}
/// Returns how many bytes to advance for a character whose first byte is
/// `first_byte`, judged from that byte alone.
///
/// Bytes below `0xC0` (ASCII and continuation bytes) count as 1; the result is
/// always in `1..=4`.
fn utf8_char_len(first_byte: u8) -> usize {
    match first_byte {
        0x00..=0xBF => 1,
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xFF => 4,
    }
}
/// Finds the first occurrence of `needle` in `haystack` and returns its
/// starting index.
///
/// Returns `None` when there is no occurrence, when `needle` is longer than
/// `haystack`, and also when `needle` is empty.
fn find_subseq(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    let width = needle.len();
    if width == 0 || width > haystack.len() {
        return None;
    }
    let last_start = haystack.len() - width;
    (0..=last_start).find(|&at| &haystack[at..at + width] == needle)
}
/// Rule table shared by JavaScript and TypeScript (`rules_for` maps both names here).
static JS_RULES: Rules = Rules {
// One combined list; TypeScript words such as `interface`, `readonly` and
// `satisfies` sit alongside the JavaScript ones.
keywords: &[
"abstract",
"as",
"async",
"await",
"break",
"case",
"catch",
"class",
"const",
"continue",
"debugger",
"default",
"delete",
"do",
"else",
"enum",
"export",
"extends",
"false",
"finally",
"for",
"from",
"function",
"get",
"if",
"implements",
"import",
"in",
"instanceof",
"interface",
"is",
"let",
"new",
"null",
"of",
"package",
"private",
"protected",
"public",
"readonly",
"return",
"satisfies",
"set",
"static",
"super",
"switch",
"this",
"throw",
"true",
"try",
"type",
"typeof",
"undefined",
"var",
"void",
"while",
"with",
"yield",
],
// Primitive type names plus a few common built-in classes.
types: &[
"string", "number", "boolean", "object", "any", "unknown", "never", "bigint", "symbol",
"Promise", "Array", "Map", "Set", "Date", "RegExp", "Error",
],
line_comment: Some("//"),
block_comment: Some(("/*", "*/")),
// Double quote, single quote and backtick all open a string.
string_delims: &['"', '\'', '`'],
hash_directive: false,
};
/// Rule table for Python: `#` line comments, no block comment syntax.
static PY_RULES: Rules = Rules {
keywords: &[
"False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
"continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
"if", "import", "in", "is", "lambda", "match", "nonlocal", "not", "or", "pass", "raise",
"return", "try", "while", "with", "yield",
],
// `None` also appears in `keywords`; the tokenizer checks keywords first,
// so that entry decides its color.
types: &[
"int", "float", "str", "bool", "list", "tuple", "dict", "set", "bytes", "None",
],
line_comment: Some("#"),
block_comment: None,
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for shell scripts (`sh`, `bash`, `shell` and `zsh` all normalize to `bash`).
static BASH_RULES: Rules = Rules {
// Reserved words together with common builtins.
keywords: &[
"if", "then", "else", "elif", "fi", "case", "esac", "for", "select", "while", "until",
"do", "done", "in", "function", "return", "break", "continue", "exit", "export", "local",
"readonly", "declare", "typeset", "unset", "alias", "trap", "source", "eval", "exec",
],
types: &[],
line_comment: Some("#"),
block_comment: None,
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for Clojure (also used for `cljs`, `cljc` and `edn` tags).
///
/// NOTE(review): several entries contain `-` (`defn-`, `if-not`, `when-let`,
/// `in-ns`, ...). `is_ident_cont` only treats `-` as an identifier character
/// when `string_delims` is empty, which is not the case for this table, so
/// those entries look like they can never match a whole token — confirm.
static CLJ_RULES: Rules = Rules {
keywords: &[
"def",
"defn",
"defn-",
"defmacro",
"defmulti",
"defmethod",
"defprotocol",
"defrecord",
"deftype",
"defstruct",
"deflinked-type",
"definterface",
"defonce",
"defproject",
"fn",
"let",
"letfn",
"do",
"quote",
"var",
"if",
"if-not",
"if-let",
"if-some",
"when",
"when-not",
"when-let",
"when-some",
"cond",
"condp",
"case",
"loop",
"recur",
"try",
"catch",
"finally",
"throw",
"and",
"or",
"not",
"nil",
"true",
"false",
"ns",
"require",
"import",
"use",
"in-ns",
],
types: &[],
line_comment: Some(";"),
block_comment: None,
string_delims: &['"'],
hash_directive: false,
};
/// Rule table for Go.
static GO_RULES: Rules = Rules {
// Language keywords followed by the predeclared `nil`, `true`, `false` and `iota`.
keywords: &[
"break",
"case",
"chan",
"const",
"continue",
"default",
"defer",
"else",
"fallthrough",
"for",
"func",
"go",
"goto",
"if",
"import",
"interface",
"map",
"package",
"range",
"return",
"select",
"struct",
"switch",
"type",
"var",
"nil",
"true",
"false",
"iota",
],
// Predeclared type names.
types: &[
"bool",
"byte",
"complex64",
"complex128",
"error",
"float32",
"float64",
"int",
"int8",
"int16",
"int32",
"int64",
"rune",
"string",
"uint",
"uint8",
"uint16",
"uint32",
"uint64",
"uintptr",
],
line_comment: Some("//"),
block_comment: Some(("/*", "*/")),
// Backtick is included so raw string literals are colored; the tokenizer
// works one line at a time, so a literal left open is not continued on the next line.
string_delims: &['"', '\'', '`'],
hash_directive: false,
};
/// Rule table for Ruby.
///
/// NOTE(review): `defined?` ends in `?`, which `is_ident_cont` never accepts,
/// so that entry looks like it cannot match a whole token — confirm.
static RUBY_RULES: Rules = Rules {
// Reserved words, followed by a handful of commonly used methods
// (`require`, `include`, `attr_accessor`, ...) that are colored the same way.
keywords: &[
"BEGIN",
"END",
"alias",
"and",
"begin",
"break",
"case",
"class",
"def",
"defined?",
"do",
"else",
"elsif",
"end",
"ensure",
"false",
"for",
"if",
"in",
"module",
"next",
"nil",
"not",
"or",
"redo",
"rescue",
"retry",
"return",
"self",
"super",
"then",
"true",
"undef",
"unless",
"until",
"when",
"while",
"yield",
"require",
"require_relative",
"include",
"extend",
"attr_accessor",
"attr_reader",
"attr_writer",
],
types: &[],
line_comment: Some("#"),
block_comment: None,
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for Rust.
///
/// NOTE(review): `'` is listed as a string delimiter so character literals are
/// colored, but the tokenizer's string scan starts at any delimiter, so a
/// lifetime such as `'a` would also begin a "string" that runs to the next `'`
/// or the end of the line — confirm whether that is acceptable.
static RUST_RULES: Rules = Rules {
keywords: &[
"as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
"extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
"mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
"true", "type", "union", "unsafe", "use", "where", "while", "yield",
],
// Primitive types, common standard-library types, and the `Ok`/`Err`/`Some`/`None`
// constructors (the two-letter ones are too short for `looks_like_type`).
types: &[
"bool", "char", "f32", "f64", "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16",
"u32", "u64", "u128", "usize", "str", "String", "Vec", "Option", "Result", "Box", "Rc",
"Ok", "Err", "Some", "None", "Arc", "RefCell", "Cell", "HashMap", "HashSet", "BTreeMap",
"BTreeSet",
],
line_comment: Some("//"),
block_comment: Some(("/*", "*/")),
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for Java.
///
/// NOTE(review): `non-sealed` contains `-`, which `is_ident_cont` does not
/// accept for this table (its `string_delims` is non-empty), so that entry
/// looks like it cannot match a whole token — confirm.
static JAVA_RULES: Rules = Rules {
// Primitive type names (`int`, `boolean`, ...) are listed here, so they get
// the keyword color rather than the type color.
keywords: &[
"abstract",
"assert",
"boolean",
"break",
"byte",
"case",
"catch",
"char",
"class",
"const",
"continue",
"default",
"do",
"double",
"else",
"enum",
"extends",
"final",
"finally",
"float",
"for",
"goto",
"if",
"implements",
"import",
"instanceof",
"int",
"interface",
"long",
"native",
"new",
"null",
"package",
"private",
"protected",
"public",
"return",
"short",
"static",
"strictfp",
"super",
"switch",
"synchronized",
"this",
"throw",
"throws",
"transient",
"true",
"false",
"try",
"void",
"volatile",
"while",
"yield",
"record",
"sealed",
"permits",
"non-sealed",
],
types: &[
"String", "Object", "Integer", "Long", "Double", "Boolean", "List", "Map", "Set",
],
line_comment: Some("//"),
block_comment: Some(("/*", "*/")),
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for C (the `c` and `h` tags).
static C_RULES: Rules = Rules {
// Language keywords, followed by preprocessor directive names (`include`,
// `define`, `ifdef`, ...). The directive names are matched like any other
// keyword, i.e. also when they are not preceded by `#`.
keywords: &[
"auto",
"break",
"case",
"char",
"const",
"continue",
"default",
"do",
"double",
"else",
"enum",
"extern",
"float",
"for",
"goto",
"if",
"int",
"long",
"register",
"return",
"short",
"signed",
"sizeof",
"static",
"struct",
"switch",
"typedef",
"union",
"unsigned",
"void",
"volatile",
"while",
"inline",
"restrict",
"_Bool",
"_Complex",
"_Imaginary",
"include",
"define",
"ifdef",
"ifndef",
"endif",
"pragma",
"error",
"undef",
"elif",
],
// Common library typedefs, plus `FILE` and `NULL`.
types: &[
"size_t",
"ssize_t",
"ptrdiff_t",
"intptr_t",
"uintptr_t",
"int8_t",
"int16_t",
"int32_t",
"int64_t",
"uint8_t",
"uint16_t",
"uint32_t",
"uint64_t",
"FILE",
"NULL",
],
line_comment: Some("//"),
block_comment: Some(("/*", "*/")),
string_delims: &['"', '\''],
// Identifiers written directly after `#` are colored like keywords.
hash_directive: true,
};
/// Rule table for C++ (`cpp`, `cc`, `cxx`, `hpp`, `hh`, `hxx` and `c++` tags).
static CPP_RULES: Rules = Rules {
// Unlike the C table, preprocessor directive names are not listed here;
// they are colored through `hash_directive` instead.
keywords: &[
"alignas",
"alignof",
"and",
"and_eq",
"asm",
"auto",
"bitand",
"bitor",
"bool",
"break",
"case",
"catch",
"char",
"char16_t",
"char32_t",
"class",
"compl",
"const",
"constexpr",
"const_cast",
"continue",
"decltype",
"default",
"delete",
"do",
"double",
"dynamic_cast",
"else",
"enum",
"explicit",
"export",
"extern",
"false",
"float",
"for",
"friend",
"goto",
"if",
"inline",
"int",
"long",
"mutable",
"namespace",
"new",
"noexcept",
"not",
"not_eq",
"nullptr",
"operator",
"or",
"or_eq",
"private",
"protected",
"public",
"register",
"reinterpret_cast",
"return",
"short",
"signed",
"sizeof",
"static",
"static_assert",
"static_cast",
"struct",
"switch",
"template",
"this",
"thread_local",
"throw",
"true",
"try",
"typedef",
"typeid",
"typename",
"union",
"unsigned",
"using",
"virtual",
"void",
"volatile",
"wchar_t",
"while",
"xor",
"xor_eq",
"concept",
"requires",
"co_await",
"co_return",
"co_yield",
],
// Common typedefs and standard-library container / smart-pointer names.
types: &[
"size_t",
"ssize_t",
"ptrdiff_t",
"string",
"vector",
"map",
"set",
"unordered_map",
"unordered_set",
"shared_ptr",
"unique_ptr",
"weak_ptr",
"optional",
"variant",
],
line_comment: Some("//"),
block_comment: Some(("/*", "*/")),
string_delims: &['"', '\''],
// Identifiers written directly after `#` are colored like keywords.
hash_directive: true,
};
/// Rule table for JSON (also used for the `jsonc` tag; no comment markers are configured).
static JSON_RULES: Rules = Rules {
keywords: &["true", "false", "null"],
types: &[],
line_comment: None,
block_comment: None,
string_delims: &['"'],
hash_directive: false,
};
/// Rule table for YAML: the boolean/null words are treated as keywords.
static YAML_RULES: Rules = Rules {
keywords: &["true", "false", "null", "yes", "no", "on", "off"],
types: &[],
line_comment: Some("#"),
block_comment: None,
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for TOML: only the boolean literals are treated as keywords.
static TOML_RULES: Rules = Rules {
keywords: &["true", "false"],
types: &[],
line_comment: Some("#"),
block_comment: None,
string_delims: &['"', '\''],
hash_directive: false,
};
/// Rule table for SQL.
///
/// The tokenizer compares words case-sensitively, so every keyword and every
/// type name is listed twice: once in upper case and once in lower case.
/// Mixed-case spellings (`Select`) are not matched.
static SQL_RULES: Rules = Rules {
    keywords: &[
        // Upper-case spellings.
        "SELECT",
        "FROM",
        "WHERE",
        "GROUP",
        "BY",
        "ORDER",
        "HAVING",
        "JOIN",
        "LEFT",
        "RIGHT",
        "INNER",
        "OUTER",
        "FULL",
        "ON",
        "AS",
        "AND",
        "OR",
        "NOT",
        "NULL",
        "IS",
        "IN",
        "LIKE",
        "BETWEEN",
        "INSERT",
        "UPDATE",
        "DELETE",
        "INTO",
        "VALUES",
        "SET",
        "CREATE",
        "TABLE",
        "DROP",
        "ALTER",
        "INDEX",
        "VIEW",
        "PRIMARY",
        "KEY",
        "FOREIGN",
        "REFERENCES",
        "DEFAULT",
        "UNIQUE",
        "CHECK",
        "CONSTRAINT",
        "CASCADE",
        "TRUE",
        "FALSE",
        "LIMIT",
        "OFFSET",
        "WITH",
        "UNION",
        "ALL",
        "DISTINCT",
        "CASE",
        "WHEN",
        "THEN",
        "ELSE",
        "END",
        "BEGIN",
        "COMMIT",
        "ROLLBACK",
        "TRANSACTION",
        // Lower-case spellings of the same words.
        "select",
        "from",
        "where",
        "group",
        "by",
        "order",
        "having",
        "join",
        "left",
        "right",
        "inner",
        "outer",
        "full",
        "on",
        "as",
        "and",
        "or",
        "not",
        "null",
        "is",
        "in",
        "like",
        "between",
        "insert",
        "update",
        "delete",
        "into",
        "values",
        "set",
        "create",
        "table",
        "drop",
        "alter",
        "index",
        "view",
        "primary",
        "key",
        "foreign",
        "references",
        "default",
        "unique",
        "check",
        "constraint",
        "cascade",
        "true",
        "false",
        "limit",
        "offset",
        "with",
        "union",
        "all",
        "distinct",
        "case",
        "when",
        "then",
        "else",
        "end",
        "begin",
        "commit",
        "rollback",
        "transaction",
    ],
    types: &[
        // Upper-case spellings.
        "INT",
        "INTEGER",
        "VARCHAR",
        "TEXT",
        "BOOLEAN",
        "DATE",
        "TIMESTAMP",
        "FLOAT",
        "DOUBLE",
        // Lower-case spellings, mirroring the keyword list so that
        // `create table t (id int)` colors `int` as a type as well.
        "int",
        "integer",
        "varchar",
        "text",
        "boolean",
        "date",
        "timestamp",
        "float",
        "double",
    ],
    line_comment: Some("--"),
    block_comment: Some(("/*", "*/")),
    // Only single quotes open a string literal; double quotes are left alone.
    string_delims: &['\''],
    hash_directive: false,
};
// Unit tests for the highlighter: token coloring per category, block-comment
// carry-over between lines, the unsupported-language fallback, and language
// tag normalization.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand for running the public entry point.
fn render(code: &str, lang: &str) -> Vec<Vec<Span>> {
highlight_code(code, lang)
}
// Keywords get the `theme::user()` color.
#[test]
fn rust_basic_keywords_colored() {
let lines = render("fn main() {}", "rust");
assert_eq!(lines.len(), 1);
let row = &lines[0];
let fn_span = row.iter().find(|s| s.text == "fn").expect("fn span");
assert_eq!(fn_span.color, theme::user());
}
// A string literal, quotes included, is one span in `theme::accent()`.
#[test]
fn strings_get_accent_color() {
let lines = render(r#"let s = "hello";"#, "rust");
let row = &lines[0];
let str_span = row.iter().find(|s| s.text == "\"hello\"").expect("string");
assert_eq!(str_span.color, theme::accent());
}
// A trailing `//` comment is colored with `theme::dim()`.
#[test]
fn line_comments_dim() {
let lines = render("let x = 1; // comment", "rust");
let row = &lines[0];
let com = row
.iter()
.find(|s| s.text.contains("comment"))
.expect("comment");
assert_eq!(com.color, theme::dim());
}
// A block comment opened on one line keeps the following line dimmed
// until its terminator is found.
#[test]
fn block_comment_spans_lines() {
let lines = render("a\n/* multi\nline */\nb", "rust");
let line2 = &lines[1];
let line3 = &lines[2];
assert!(line2.iter().all(|s| s.color == theme::dim()));
assert!(line3.iter().any(|s| s.color == theme::dim()));
}
// Languages without a rule table produce one `theme::tool()` span per line.
#[test]
fn unknown_lang_falls_back_to_uniform_color() {
let lines = render("nonsense ::: gibberish", "fortran");
assert_eq!(lines.len(), 1);
assert_eq!(lines[0].len(), 1);
assert_eq!(lines[0][0].color, theme::tool());
}
// Decimal, float, hex, binary and underscore-separated literals all start
// with a `theme::warn()` span.
#[test]
fn number_literals_colored() {
for n in &["42", "3.14", "0xDEADBEEF", "0b1010", "1_000_000"] {
let lines = render(n, "rust");
assert_eq!(lines[0][0].color, theme::warn(), "for {n}");
}
}
// `Vec` is colored as a type (`theme::header()`).
#[test]
fn capitalized_words_get_type_color() {
let lines = render("let v: Vec<String> = Vec::new();", "rust");
let row = &lines[0];
assert!(
row.iter()
.any(|s| s.text == "Vec" && s.color == theme::header())
);
}
// `def` is recognized as a keyword in Python input.
#[test]
fn python_def_keyword_colored() {
let lines = render("def hello():\n pass", "python");
let def = lines[0].iter().find(|s| s.text == "def").expect("def");
assert_eq!(def.color, theme::user());
}
// SQL keywords are matched in both all-upper and all-lower case.
#[test]
fn sql_keywords_case_variants_both_match() {
let lines = render("SELECT * FROM t WHERE x = 1", "sql");
let row = &lines[0];
assert!(
row.iter()
.any(|s| s.text == "SELECT" && s.color == theme::user())
);
let lines = render("select * from t where x = 1", "sql");
let row = &lines[0];
assert!(
row.iter()
.any(|s| s.text == "select" && s.color == theme::user())
);
}
// Aliases map to canonical names, and anything after a comma or space in
// the tag is ignored.
#[test]
fn lang_aliases_normalize() {
assert_eq!(normalize_lang("ts"), "typescript");
assert_eq!(normalize_lang("tsx"), "typescript");
assert_eq!(normalize_lang("py"), "python");
assert_eq!(normalize_lang("c++"), "cpp");
assert_eq!(normalize_lang("rust,no_run"), "rust");
assert_eq!(normalize_lang("rust ignore"), "rust");
}
}