Skip to main content

dioxus_code_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::env;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use macro_string::MacroString;
9use proc_macro::TokenStream;
10use proc_macro_crate::{FoundCrate, crate_name};
11use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
12use quote::{format_ident, quote, quote_spanned};
13use syn::parse::{Parse, ParseStream};
14use syn::spanned::Spanned;
15use syn::{Expr, LitStr, Token, parse_macro_input};
16
17/// Compile-time syntax highlighting.
18///
19/// Reads a source file relative to the consumer's `CARGO_MANIFEST_DIR`, parses
20/// it with [`arborium`], and expands to the resulting span tree. Pass the path
21/// as a string literal, `concat!(...)`, or `env!(...)`. Pass
22/// [`CodeOptions::builder`] with [`CodeOptions::with_language`] to name the
23/// language explicitly; otherwise it is inferred from the file extension.
24///
25/// [`CodeOptions::builder`]: https://docs.rs/dioxus-code/latest/dioxus_code/struct.CodeOptions.html#method.builder
26/// [`CodeOptions::with_language`]: https://docs.rs/dioxus-code/latest/dioxus_code/struct.CodeOptions.html#method.with_language
27#[proc_macro]
28pub fn code(input: TokenStream) -> TokenStream {
29    let input = parse_macro_input!(input as CodeInput);
30
31    match expand_code(input) {
32        Ok(tokens) => tokens.into(),
33        Err(error) => error.to_compile_error().into(),
34    }
35}
36
37struct CodeInput {
38    path: String,
39    options: Option<Expr>,
40}
41
42impl Parse for CodeInput {
43    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
44        let MacroString(path) = input.parse()?;
45        let mut options = None;
46
47        if input.peek(Token![,]) {
48            input.parse::<Token![,]>()?;
49            if !input.is_empty() {
50                let expr: Expr = input.parse()?;
51                if input.peek(Token![,]) {
52                    input.parse::<Token![,]>()?;
53                }
54                if !input.is_empty() {
55                    return Err(input.error("unexpected tokens after code macro options"));
56                }
57                options = Some(expr);
58            }
59        }
60
61        Ok(Self { path, options })
62    }
63}
64
65fn try_extract_language(expr: &Expr) -> Option<String> {
66    match expr {
67        Expr::Group(group) => try_extract_language(&group.expr),
68        Expr::Paren(paren) => try_extract_language(&paren.expr),
69        Expr::MethodCall(method) => {
70            if method.method == "with_language"
71                && method.args.len() == 1
72                && let Some(slug) = try_parse_language_arg(method.args.first().unwrap())
73            {
74                return Some(slug);
75            }
76            try_extract_language(&method.receiver)
77        }
78        _ => None,
79    }
80}
81
82fn try_parse_language_arg(expr: &Expr) -> Option<String> {
83    match expr {
84        Expr::Group(group) => try_parse_language_arg(&group.expr),
85        Expr::Paren(paren) => try_parse_language_arg(&paren.expr),
86        Expr::Call(call) if is_some_call(call) && call.args.len() == 1 => {
87            try_parse_language_arg(call.args.first().unwrap())
88        }
89        Expr::Path(path) if is_none_path(path) => None,
90        Expr::Path(path) => language_slug_from_path(path).map(str::to_string),
91        _ => None,
92    }
93}
94
95fn is_some_call(call: &syn::ExprCall) -> bool {
96    let Expr::Path(path) = call.func.as_ref() else {
97        return false;
98    };
99    path.path
100        .segments
101        .last()
102        .is_some_and(|segment| segment.ident == "Some")
103}
104
105fn is_none_path(path: &syn::ExprPath) -> bool {
106    path.path
107        .segments
108        .last()
109        .is_some_and(|segment| segment.ident == "None")
110}
111
/// Pairs of (`Language` enum variant name, language slug used for highlighting).
///
/// The table is scanned linearly in both directions:
/// - variant → slug by `language_slug_from_path`, when reading
///   `with_language(Language::X)` out of the macro options;
/// - slug → variant by `language_variant_for_slug`, when emitting
///   `Language::X` into the expansion.
///
/// Lookups return the first match, so each variant name and each slug should
/// appear only once.
// NOTE(review): maintained by hand — presumably must stay in sync with the
// `dioxus_code::Language` enum and with the slugs arborium accepts/returns; confirm
// when either side adds or renames a language.
const LANGUAGE_VARIANTS: &[(&str, &str)] = &[
    ("Rust", "rust"),
    ("Ada", "ada"),
    ("Agda", "agda"),
    ("Asciidoc", "asciidoc"),
    ("Asm", "asm"),
    ("Awk", "awk"),
    ("Bash", "bash"),
    ("Batch", "batch"),
    ("C", "c"),
    ("CSharp", "c-sharp"),
    ("Caddy", "caddy"),
    ("Capnp", "capnp"),
    ("Cedar", "cedar"),
    ("CedarSchema", "cedarschema"),
    ("Clojure", "clojure"),
    ("CMake", "cmake"),
    ("Cobol", "cobol"),
    ("CommonLisp", "commonlisp"),
    ("Cpp", "cpp"),
    ("Css", "css"),
    ("D", "d"),
    ("Dart", "dart"),
    ("DeviceTree", "devicetree"),
    ("Diff", "diff"),
    ("Dockerfile", "dockerfile"),
    ("Dot", "dot"),
    ("Elisp", "elisp"),
    ("Elixir", "elixir"),
    ("Elm", "elm"),
    ("Erlang", "erlang"),
    ("Fish", "fish"),
    ("FSharp", "fsharp"),
    ("Gleam", "gleam"),
    ("Glsl", "glsl"),
    ("Go", "go"),
    ("GraphQL", "graphql"),
    ("Groovy", "groovy"),
    ("Haskell", "haskell"),
    ("Hcl", "hcl"),
    ("Hlsl", "hlsl"),
    ("Html", "html"),
    ("Idris", "idris"),
    ("Ini", "ini"),
    ("Java", "java"),
    ("JavaScript", "javascript"),
    ("Jinja2", "jinja2"),
    ("Jq", "jq"),
    ("Json", "json"),
    ("Julia", "julia"),
    ("Kotlin", "kotlin"),
    ("Lean", "lean"),
    ("Lua", "lua"),
    ("Markdown", "markdown"),
    ("Matlab", "matlab"),
    ("Meson", "meson"),
    ("Nginx", "nginx"),
    ("Ninja", "ninja"),
    ("Nix", "nix"),
    ("ObjectiveC", "objc"),
    ("OCaml", "ocaml"),
    ("Perl", "perl"),
    ("Php", "php"),
    ("PostScript", "postscript"),
    ("PowerShell", "powershell"),
    ("Prolog", "prolog"),
    ("Python", "python"),
    ("Query", "query"),
    ("R", "r"),
    ("Rego", "rego"),
    ("Rescript", "rescript"),
    ("Ron", "ron"),
    ("Ruby", "ruby"),
    ("Scala", "scala"),
    ("Scheme", "scheme"),
    ("Scss", "scss"),
    ("Solidity", "solidity"),
    ("Sparql", "sparql"),
    ("Sql", "sql"),
    ("SshConfig", "ssh-config"),
    ("Starlark", "starlark"),
    ("Styx", "styx"),
    ("Svelte", "svelte"),
    ("Swift", "swift"),
    ("Textproto", "textproto"),
    ("Thrift", "thrift"),
    ("TlaPlus", "tlaplus"),
    ("Toml", "toml"),
    ("Tsx", "tsx"),
    ("TypeScript", "typescript"),
    ("Typst", "typst"),
    ("Uiua", "uiua"),
    ("VisualBasic", "vb"),
    ("Verilog", "verilog"),
    ("Vhdl", "vhdl"),
    ("Vim", "vim"),
    ("Vue", "vue"),
    ("Wit", "wit"),
    ("X86Asm", "x86asm"),
    ("Xml", "xml"),
    ("Yaml", "yaml"),
    ("Yuri", "yuri"),
    ("Zig", "zig"),
    ("Zsh", "zsh"),
];
217
218fn language_slug_from_path(path: &syn::ExprPath) -> Option<&'static str> {
219    let variant = path.path.segments.last()?.ident.to_string();
220    LANGUAGE_VARIANTS
221        .iter()
222        .find(|(name, _)| *name == variant)
223        .map(|(_, slug)| *slug)
224}
225
226fn language_variant_for_slug(slug: &str) -> Option<&'static str> {
227    LANGUAGE_VARIANTS
228        .iter()
229        .find(|(_, s)| *s == slug)
230        .map(|(name, _)| *name)
231}
232
233fn expand_code(input: CodeInput) -> syn::Result<TokenStream2> {
234    let manifest_dir = env::var("CARGO_MANIFEST_DIR")
235        .map_err(|error| syn::Error::new(Span::call_site(), error.to_string()))?;
236    let manifest_dir = PathBuf::from(manifest_dir);
237    let macro_path = input.path;
238    let absolute_path = resolve_manifest_path(&manifest_dir, &macro_path);
239    let crate_path = dioxus_code_crate_path()?;
240
241    let options_check = input.options.as_ref().map(|expr| {
242        quote_spanned! { expr.span() =>
243            const _: fn() = || {
244                let _: #crate_path::CodeOptions = #expr;
245            };
246        }
247    });
248
249    let source = fs::read_to_string(&absolute_path).map_err(|error| {
250        syn::Error::new(
251            Span::call_site(),
252            format!("failed to read `{}`: {error}", absolute_path.display()),
253        )
254    })?;
255
256    let Some(language) = input
257        .options
258        .as_ref()
259        .and_then(try_extract_language)
260        .or_else(|| arborium::detect_language(&macro_path).map(str::to_string))
261    else {
262        let message = format!(
263            "could not detect language for `{macro_path}`; pass `CodeOptions::builder().with_language(Language::Rust)`"
264        );
265        return Ok(quote! {{
266            #options_check
267            compile_error!(#message);
268        }});
269    };
270
271    let mut highlighter = arborium::Highlighter::new();
272    let spans = highlighter
273        .highlight_spans(&language, &source)
274        .map_err(|error| syn::Error::new(Span::call_site(), error.to_string()))?;
275
276    let Some(variant) = language_variant_for_slug(&language) else {
277        let message = format!("language `{language}` has no `Language` variant");
278        return Ok(quote! {{
279            #options_check
280            compile_error!(#message);
281        }});
282    };
283    let variant_ident = Ident::new(variant, Span::call_site());
284    let absolute_lit = LitStr::new(&absolute_path.to_string_lossy(), Span::call_site());
285    let spans = normalize_spans(spans).into_iter().map(|span| {
286        let start = span.start;
287        let end = span.end;
288        let tag = LitStr::new(span.tag, Span::call_site());
289
290        quote! {
291            #crate_path::advanced::HighlightSpan::new(#start..#end, #tag)
292        }
293    });
294
295    Ok(quote! {{
296        #options_check
297        const SOURCE: &str = include_str!(#absolute_lit);
298        static SPANS: &[#crate_path::advanced::HighlightSpan] = &[#(#spans),*];
299        #crate_path::advanced::HighlightedSource::from_static_parts(
300            SOURCE,
301            #crate_path::Language::#variant_ident,
302            SPANS,
303        )
304    }})
305}
306
307struct NormalizedSpan {
308    start: u32,
309    end: u32,
310    tag: &'static str,
311}
312
313struct RawSpan {
314    start: u32,
315    end: u32,
316    tag: Option<&'static str>,
317    pattern_index: u32,
318}
319
320fn normalize_spans(spans: Vec<arborium::advanced::Span>) -> Vec<NormalizedSpan> {
321    use std::collections::HashMap;
322
323    let mut deduped: HashMap<(u32, u32), RawSpan> = HashMap::new();
324    for span in spans {
325        let span = RawSpan {
326            start: span.start,
327            end: span.end,
328            tag: arborium_theme::tag_for_capture(&span.capture),
329            pattern_index: span.pattern_index,
330        };
331        let key = (span.start, span.end);
332
333        if let Some(existing) = deduped.get(&key) {
334            let should_replace = match (span.tag.is_some(), existing.tag.is_some()) {
335                (true, false) => true,
336                (false, true) => false,
337                _ => span.pattern_index >= existing.pattern_index,
338            };
339            if should_replace {
340                deduped.insert(key, span);
341            }
342        } else {
343            deduped.insert(key, span);
344        }
345    }
346
347    let mut spans: Vec<_> = deduped
348        .into_values()
349        .filter_map(|span| {
350            Some(NormalizedSpan {
351                start: span.start,
352                end: span.end,
353                tag: span.tag?,
354            })
355        })
356        .collect();
357
358    spans.sort_by_key(|span| (span.start, span.end));
359
360    let mut coalesced: Vec<NormalizedSpan> = Vec::with_capacity(spans.len());
361    for span in spans {
362        if let Some(last) = coalesced.last_mut()
363            && span.tag == last.tag
364            && span.start <= last.end
365        {
366            last.end = last.end.max(span.end);
367            continue;
368        }
369        coalesced.push(span);
370    }
371
372    coalesced
373}
374
375fn dioxus_code_crate_path() -> syn::Result<TokenStream2> {
376    match crate_name("dioxus-code") {
377        Ok(FoundCrate::Itself) => Ok(quote!(::dioxus_code)),
378        Ok(FoundCrate::Name(name)) => {
379            let ident = format_ident!("{}", name);
380            Ok(quote!(::#ident))
381        }
382        Err(error) => Err(syn::Error::new(Span::call_site(), error.to_string())),
383    }
384}
385
/// Resolves the path given to `code!` relative to the consumer's manifest directory.
///
/// An absolute path is kept verbatim when it exists on disk or already lies under
/// `manifest_dir`. Every other path is joined onto `manifest_dir`, after dropping a
/// single leading `/`. This makes `"/src/main.rs"` mean "`src/main.rs` in the
/// crate root".
fn resolve_manifest_path(manifest_dir: &Path, path: &str) -> PathBuf {
    let candidate = Path::new(path);
    let keep_verbatim = candidate.is_absolute()
        && (candidate.exists() || candidate.starts_with(manifest_dir));
    if keep_verbatim {
        return candidate.to_path_buf();
    }

    let crate_relative = path.strip_prefix('/').unwrap_or(path);
    manifest_dir.join(crate_relative)
}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a Rust expression and runs `try_extract_language` on it.
    fn extracted(source: &str) -> Option<String> {
        let expr: Expr =
            syn::parse_str(source).expect("test input should parse as an expression");
        try_extract_language(&expr)
    }

    #[test]
    fn extracts_language_variant_options() {
        let inputs = [
            "CodeOptions::builder().with_language(Language::Rust)",
            "CodeOptions::builder().with_language(Some(Language::Rust))",
        ];
        for input in inputs {
            assert_eq!(extracted(input).as_deref(), Some("rust"), "input: {input}");
        }
    }

    #[test]
    fn extracts_none_language_option() {
        let result = extracted("CodeOptions::builder().with_language(None)");
        assert!(result.is_none(), "expected no language, got {result:?}");
    }

    #[test]
    fn unknown_method_chains_fall_back_silently() {
        let inputs = [
            "CodeOptions::builder()",
            "CodeOptions::builder().with_themes(Language::Rust)",
        ];
        for input in inputs {
            assert_eq!(extracted(input), None, "input: {input}");
        }
    }
}