Skip to main content

dioxus_mdx/parser/
syntax.rs

//! Syntax highlighting for code blocks using syntect.
//!
//! Produces HTML fragments in which highlighted tokens are colored with inline
//! `style` attributes.

5use std::sync::LazyLock;
6use syntect::highlighting::ThemeSet;
7use syntect::html::highlighted_html_for_string;
8use syntect::parsing::SyntaxSet;
9
10/// Lazily loaded syntax set with default syntaxes.
11static SYNTAX_SET: LazyLock<SyntaxSet> = LazyLock::new(SyntaxSet::load_defaults_newlines);
12
13/// Lazily loaded theme set with default themes.
14static THEME_SET: LazyLock<ThemeSet> = LazyLock::new(ThemeSet::load_defaults);
15
16/// Apply syntax highlighting to code.
17///
18/// Returns HTML string with inline styles for syntax highlighting.
19/// Falls back to plain code wrapped in `<code>` if highlighting fails.
20pub fn highlight_code(code: &str, language: Option<&str>) -> String {
21    let lang = language.unwrap_or("txt");
22
23    // Map common language aliases
24    let syntax_name = map_language(lang);
25
26    // Find syntax definition
27    let syntax = SYNTAX_SET
28        .find_syntax_by_extension(syntax_name)
29        .or_else(|| SYNTAX_SET.find_syntax_by_name(syntax_name))
30        .or_else(|| SYNTAX_SET.find_syntax_by_extension(lang))
31        .or_else(|| SYNTAX_SET.find_syntax_by_name(lang))
32        .unwrap_or_else(|| SYNTAX_SET.find_syntax_plain_text());
33
34    // Use a dark theme suitable for dark mode
35    // "base16-ocean.dark" is a good dark theme included in syntect
36    let theme = THEME_SET
37        .themes
38        .get("base16-ocean.dark")
39        .or_else(|| THEME_SET.themes.get("InspiredGitHub"))
40        .unwrap_or_else(|| THEME_SET.themes.values().next().unwrap());
41
42    // Generate highlighted HTML
43    match highlighted_html_for_string(code, &SYNTAX_SET, syntax, theme) {
44        Ok(html) => {
45            // The output is wrapped in <pre style="..."><code>...</code></pre>
46            // We want just the inner content since we have our own wrapper
47            // Extract the content between <pre...> and </pre>
48            if let Some(start) = html.find('>')
49                && let Some(end) = html.rfind("</pre>")
50            {
51                // Trim leading/trailing whitespace from the extracted HTML
52                // syntect adds a newline after <pre> which we don't want
53                return html[start + 1..end].trim().to_string();
54            }
55            html
56        }
57        Err(_) => {
58            // Fallback: escape HTML and return plain code
59            escape_html(code)
60        }
61    }
62}
63
/// Map common code-fence language aliases to syntect syntax names.
///
/// Matching is ASCII case-insensitive (`"JS"`, `"js"` and `"JavaScript"` all map to
/// `"JavaScript"`). Returns a `'static` syntax name for a known alias; otherwise
/// returns `lang` unchanged, so the caller can still try it as an extension or name.
fn map_language(lang: &str) -> &str {
    // (syntect syntax name, aliases that select it). Aliases never overlap between
    // entries, so table order does not affect the result.
    const ALIASES: &[(&str, &[&str])] = &[
        // JavaScript / TypeScript variants
        ("JavaScript", &["js", "javascript"]),
        ("JavaScript (JSX)", &["jsx"]),
        ("TypeScript", &["ts", "typescript"]),
        ("TypeScript (TSX)", &["tsx"]),
        // Shell variants. Use syntect's actual syntax name (as `env` always did)
        // so a by-name lookup succeeds even for tags like `shell` that are not
        // file extensions.
        (
            "Bourne Again Shell (bash)",
            &["sh", "bash", "shell", "zsh", "env"],
        ),
        ("Rust", &["rs", "rust"]),
        ("Python", &["py", "python"]),
        ("Ruby", &["rb", "ruby"]),
        ("Go", &["go", "golang"]),
        ("JSON", &["json", "jsonc"]),
        ("YAML", &["yml", "yaml"]),
        // HTML / CSS
        ("HTML", &["html", "htm"]),
        ("CSS", &["css"]),
        ("SCSS", &["scss"]),
        ("Sass", &["sass"]),
        // Config files
        ("TOML", &["toml"]),
        ("INI", &["ini"]),
        ("Markdown", &["md", "markdown"]),
        ("SQL", &["sql"]),
        // C / C++
        ("C", &["c", "h"]),
        ("C++", &["cpp", "cc", "cxx", "hpp"]),
        ("Java", &["java"]),
        ("C#", &["cs", "csharp"]),
        ("PHP", &["php"]),
        ("Swift", &["swift"]),
        ("Kotlin", &["kt", "kotlin"]),
        ("Dockerfile", &["dockerfile", "docker"]),
        ("Plain Text", &["txt", "text"]),
    ];

    ALIASES
        .iter()
        .find(|(_, aliases)| aliases.iter().any(|alias| alias.eq_ignore_ascii_case(lang)))
        .map_or(lang, |&(name, _)| name)
}

/// Escape the five HTML-significant characters in `text`.
///
/// `&`, `<`, `>`, `"` and `'` become `&amp;`, `&lt;`, `&gt;`, `&quot;` and `&#39;`;
/// every other character is copied through unchanged.
fn escape_html(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_highlight_rust() {
        let code = r#"fn main() {
    println!("Hello, world!");
}"#;
        let html = highlight_code(code, Some("rust"));
        // Should contain syntax highlighting spans
        assert!(html.contains("<span"));
        assert!(html.contains("fn"));
    }

    #[test]
    fn test_highlight_strips_pre_wrapper() {
        // Callers provide their own wrapper, so syntect's <pre> must be removed.
        let html = highlight_code("let x = 1;", Some("rs"));
        assert!(!html.starts_with("<pre"));
        assert!(!html.contains("</pre>"));
        assert_eq!(html, html.trim());
    }

    #[test]
    fn test_highlight_empty_code() {
        assert_eq!(highlight_code("", Some("rust")), "");
    }

    #[test]
    fn test_highlight_javascript() {
        let code = "const x = 42;";
        let html = highlight_code(code, Some("js"));
        assert!(html.contains("<span"));
    }

    #[test]
    fn test_highlight_escapes_markup() {
        // Raw markup in the code must never reach the output unescaped.
        let html = highlight_code("<script>alert(1)</script>", Some("txt"));
        assert!(!html.contains("<script>"));
        assert!(html.contains("&lt;"));
    }

    #[test]
    fn test_highlight_unknown_language() {
        let code = "some text";
        let html = highlight_code(code, Some("unknown_lang_xyz"));
        // Unknown languages fall back to plain text but keep the content.
        assert!(!html.is_empty());
        assert!(html.contains("some text"));
    }

    #[test]
    fn test_highlight_no_language() {
        let code = "plain text";
        let html = highlight_code(code, None);
        assert!(!html.is_empty());
    }

    #[test]
    fn test_map_language() {
        assert_eq!(map_language("JS"), "JavaScript");
        assert_eq!(map_language("Python"), "Python");
        assert_eq!(map_language("txt"), "Plain Text");
        assert_eq!(map_language("unknown_lang_xyz"), "unknown_lang_xyz");
    }

    #[test]
    fn test_escape_html() {
        assert_eq!(escape_html("<div>"), "&lt;div&gt;");
        assert_eq!(escape_html("a & b"), "a &amp; b");
        assert_eq!(escape_html("\"quoted\""), "&quot;quoted&quot;");
        assert_eq!(escape_html("it's"), "it&#39;s");
        assert_eq!(escape_html(""), "");
    }
}