use std::sync::OnceLock;
use syntect::highlighting::ThemeSet;
use syntect::html::highlighted_html_for_string;
use syntect::parsing::SyntaxSet;
/// Syntax definitions shared by every highlighting call; populated on first use.
static SYNTAX_SET: OnceLock<SyntaxSet> = OnceLock::new();

/// Returns the shared [`SyntaxSet`].
///
/// The first call loads syntect's bundled defaults (the `_newlines` variant).
/// Every later call returns the same cached instance.
fn syntax_set() -> &'static SyntaxSet {
    SYNTAX_SET.get_or_init(SyntaxSet::load_defaults_newlines)
}

/// Themes shared by every highlighting call; populated on first use.
static THEME_SET: OnceLock<ThemeSet> = OnceLock::new();

/// Returns the shared [`ThemeSet`].
///
/// The first call loads syntect's bundled default themes. Every later call
/// returns the same cached instance.
fn theme_set() -> &'static ThemeSet {
    THEME_SET.get_or_init(ThemeSet::load_defaults)
}
/// Renders `code` as highlighted HTML using the `InspiredGitHub` theme.
///
/// This is the counterpart of [`highlight_code_dark`]. `lang` is resolved with
/// `SyntaxSet::find_syntax_by_token`. If `InspiredGitHub` is not among the
/// loaded themes, the first available theme is used instead.
///
/// Returns `None` in any of these cases:
/// - `lang` is empty or matches no syntax;
/// - no theme is loaded at all;
/// - syntect reports a rendering error.
pub fn highlight_code(code: &str, lang: &str) -> Option<String> {
    highlight_with_theme(code, lang, "InspiredGitHub")
}

/// Renders `code` as highlighted HTML using the `base16-ocean.dark` theme.
///
/// Behaves exactly like [`highlight_code`] except for the preferred theme. If
/// `base16-ocean.dark` is missing, the first available theme is used instead.
pub fn highlight_code_dark(code: &str, lang: &str) -> Option<String> {
    highlight_with_theme(code, lang, "base16-ocean.dark")
}

/// Shared implementation of [`highlight_code`] and [`highlight_code_dark`].
///
/// Looks up the syntax for `lang` and prefers the theme named `theme_name`,
/// falling back to the first loaded theme. It then renders `code` to HTML.
fn highlight_with_theme(code: &str, lang: &str, theme_name: &str) -> Option<String> {
    if lang.is_empty() {
        return None;
    }
    let ss = syntax_set();
    let syntax = ss.find_syntax_by_token(lang)?;
    // Themes are only touched once we know the language is highlightable.
    let themes = &theme_set().themes;
    let theme = themes
        .get(theme_name)
        .or_else(|| themes.values().next())?;
    highlighted_html_for_string(code, ss, syntax, theme).ok()
}
/// Replaces each `<pre><code …>…</code></pre>` block in `html` with
/// syntax-highlighted HTML.
///
/// The language comes from a `language-xxx` marker in the opening tag. The
/// block body is HTML-entity-decoded before it is passed to the highlighter.
/// `dark` selects [`highlight_code_dark`]; otherwise [`highlight_code`] is used.
///
/// Content is copied through unchanged in these cases:
/// - a block with no language;
/// - a block the highlighter rejects;
/// - everything from an unterminated opening tag or an unclosed block onward.
pub fn apply_syntax_highlighting(html: &str, dark: bool) -> String {
    const OPEN: &str = "<pre><code";
    const CLOSE: &str = "</code></pre>";

    let mut result = String::with_capacity(html.len() + 512);
    let mut remaining = html;

    while let Some(pre_start) = remaining.find(OPEN) {
        result.push_str(&remaining[..pre_start]);
        remaining = &remaining[pre_start..];

        // The first `>` after the `<code` prefix closes the opening code tag.
        let Some(tag_end) = remaining[OPEN.len()..].find('>').map(|i| OPEN.len() + i) else {
            break;
        };
        let code_start = tag_end + 1;
        let Some(close_pos) = remaining[code_start..].find(CLOSE).map(|i| code_start + i) else {
            break;
        };
        let block_end = close_pos + CLOSE.len();

        // Decode only when there is a language to highlight with. Blocks
        // without one are copied verbatim and never need the decoded text.
        let highlighted = extract_language(&remaining[..=tag_end]).and_then(|lang| {
            let code = decode_html_entities(&remaining[code_start..close_pos]);
            if dark {
                highlight_code_dark(&code, &lang)
            } else {
                highlight_code(&code, &lang)
            }
        });

        match highlighted {
            Some(hl) => result.push_str(&hl),
            None => result.push_str(&remaining[..block_end]),
        }
        remaining = &remaining[block_end..];
    }

    // Keep everything after the last block, or the malformed tail that stopped the scan.
    result.push_str(remaining);
    result
}
/// Extracts the language name from an opening code tag such as
/// `<pre><code class="language-rust">`.
///
/// The name runs from just after the first `language-` marker up to the next
/// `"`, `'` or whitespace character, or to the end of `tag`. Returns `None`
/// when the marker is absent or the resulting name is empty.
fn extract_language(tag: &str) -> Option<String> {
    const MARKER: &str = "language-";
    let start = tag.find(MARKER)? + MARKER.len();
    let name: String = tag[start..]
        .chars()
        .take_while(|&c| c != '"' && c != '\'' && !c.is_whitespace())
        .collect();
    (!name.is_empty()).then_some(name)
}
/// Decodes the HTML entities that HTML renderers use when escaping code-block
/// text, for feeding the raw source to the highlighter.
///
/// The decoded entities are `&amp;`, `&lt;`, `&gt;`, `&quot;`, `&#39;` and
/// `&#x27;`. The input is scanned once, left to right. An escaped entity such
/// as `&amp;lt;` therefore decodes to the literal text `&lt;`, not to `<`.
/// Any `&` that does not start one of these entities is kept unchanged.
fn decode_html_entities(s: &str) -> String {
    const ENTITIES: [(&str, char); 6] = [
        ("&amp;", '&'),
        ("&lt;", '<'),
        ("&gt;", '>'),
        ("&quot;", '"'),
        ("&#39;", '\''),
        ("&#x27;", '\''),
    ];

    // Decoding never lengthens the text, so the input length is an upper bound.
    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(pos) = rest.find('&') {
        out.push_str(&rest[..pos]);
        let tail = &rest[pos..];
        match ENTITIES.iter().find(|(ent, _)| tail.starts_with(*ent)) {
            Some((ent, ch)) => {
                out.push(*ch);
                rest = &tail[ent.len()..];
            }
            None => {
                // Not a known entity: keep the '&' (one byte) and continue after it.
                out.push('&');
                rest = &tail[1..];
            }
        }
    }
    out.push_str(rest);
    out
}