Skip to main content

dioxus_code_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::env;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use macro_string::MacroString;
9use proc_macro::TokenStream;
10use proc_macro_crate::{FoundCrate, crate_name};
11use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
12use quote::{format_ident, quote, quote_spanned};
13use syn::parse::{Parse, ParseStream};
14use syn::spanned::Spanned;
15use syn::{Expr, LitStr, Token, parse_macro_input};
16
17/// Compile-time syntax highlighting.
18///
19/// Reads a source file relative to the consumer's `CARGO_MANIFEST_DIR`, parses
20/// it with [`arborium`], and expands to the resulting span tree. Pass the path
21/// as a string literal, `concat!(...)`, or `env!(...)`. Pass
22/// [`CodeOptions::builder`] with [`CodeOptions::with_language`] to name the
23/// language explicitly; otherwise it is inferred from the file extension.
24///
25/// To highlight inline source instead of a file, use [`code_str!`].
26///
27/// [`CodeOptions::builder`]: https://docs.rs/dioxus-code/latest/dioxus_code/struct.CodeOptions.html#method.builder
28/// [`CodeOptions::with_language`]: https://docs.rs/dioxus-code/latest/dioxus_code/struct.CodeOptions.html#method.with_language
29#[proc_macro]
30pub fn code(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as CodeInput);
32
33    match expand_code(input) {
34        Ok(tokens) => tokens.into(),
35        Err(error) => error.to_compile_error().into(),
36    }
37}
38
39/// Compile-time syntax highlighting of an inline source string.
40///
41/// Parses a string literal containing source code with [`arborium`] and
42/// expands to the resulting span tree. Pass the source as a string literal,
43/// `concat!(...)`, `include_str!(...)`, or `env!(...)`. The language must be
44/// supplied via [`CodeOptions::builder`] with [`CodeOptions::with_language`]
45/// since there is no file extension to infer from.
46///
47/// To highlight a file on disk instead, use [`code!`].
48///
49/// [`CodeOptions::builder`]: https://docs.rs/dioxus-code/latest/dioxus_code/struct.CodeOptions.html#method.builder
50/// [`CodeOptions::with_language`]: https://docs.rs/dioxus-code/latest/dioxus_code/struct.CodeOptions.html#method.with_language
51#[proc_macro]
52pub fn code_str(input: TokenStream) -> TokenStream {
53    let input = parse_macro_input!(input as CodeStrInput);
54
55    match expand_code_str(input) {
56        Ok(tokens) => tokens.into(),
57        Err(error) => error.to_compile_error().into(),
58    }
59}
60
61struct CodeInput {
62    path: String,
63    options: Option<Expr>,
64}
65
66impl Parse for CodeInput {
67    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
68        let (path, options) = parse_string_and_options(input, "code macro")?;
69        Ok(Self { path, options })
70    }
71}
72
73struct CodeStrInput {
74    source: String,
75    options: Option<Expr>,
76}
77
78impl Parse for CodeStrInput {
79    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
80        let (source, options) = parse_string_and_options(input, "code_str macro")?;
81        Ok(Self { source, options })
82    }
83}
84
85fn parse_string_and_options(
86    input: ParseStream<'_>,
87    macro_label: &str,
88) -> syn::Result<(String, Option<Expr>)> {
89    let MacroString(value) = input.parse()?;
90    let mut options = None;
91
92    if input.peek(Token![,]) {
93        input.parse::<Token![,]>()?;
94        if !input.is_empty() {
95            let expr: Expr = input.parse()?;
96            if input.peek(Token![,]) {
97                input.parse::<Token![,]>()?;
98            }
99            if !input.is_empty() {
100                return Err(input.error(format!("unexpected tokens after {macro_label} options")));
101            }
102            options = Some(expr);
103        }
104    }
105
106    Ok((value, options))
107}
108
109fn try_extract_language(expr: &Expr) -> Option<String> {
110    match expr {
111        Expr::Group(group) => try_extract_language(&group.expr),
112        Expr::Paren(paren) => try_extract_language(&paren.expr),
113        Expr::MethodCall(method) => {
114            if method.method == "with_language"
115                && method.args.len() == 1
116                && let Some(slug) = try_parse_language_arg(method.args.first().unwrap())
117            {
118                return Some(slug);
119            }
120            try_extract_language(&method.receiver)
121        }
122        _ => None,
123    }
124}
125
126fn try_parse_language_arg(expr: &Expr) -> Option<String> {
127    match expr {
128        Expr::Group(group) => try_parse_language_arg(&group.expr),
129        Expr::Paren(paren) => try_parse_language_arg(&paren.expr),
130        Expr::Call(call) if is_some_call(call) && call.args.len() == 1 => {
131            try_parse_language_arg(call.args.first().unwrap())
132        }
133        Expr::Path(path) if is_none_path(path) => None,
134        Expr::Path(path) => language_slug_from_path(path).map(str::to_string),
135        _ => None,
136    }
137}
138
139fn is_some_call(call: &syn::ExprCall) -> bool {
140    let Expr::Path(path) = call.func.as_ref() else {
141        return false;
142    };
143    path.path
144        .segments
145        .last()
146        .is_some_and(|segment| segment.ident == "Some")
147}
148
149fn is_none_path(path: &syn::ExprPath) -> bool {
150    path.path
151        .segments
152        .last()
153        .is_some_and(|segment| segment.ident == "None")
154}
155
/// Lookup table pairing each `Language` variant name with its language slug.
///
/// Each entry is `(variant name, slug)`. The table is searched by variant
/// name in `language_slug_from_path` and by slug in
/// `language_variant_for_slug`.
const LANGUAGE_VARIANTS: &[(&str, &str)] = &[
    ("Rust", "rust"),
    ("Ada", "ada"),
    ("Agda", "agda"),
    ("Asciidoc", "asciidoc"),
    ("Asm", "asm"),
    ("Awk", "awk"),
    ("Bash", "bash"),
    ("Batch", "batch"),
    ("C", "c"),
    ("CSharp", "c-sharp"),
    ("Caddy", "caddy"),
    ("Capnp", "capnp"),
    ("Cedar", "cedar"),
    ("CedarSchema", "cedarschema"),
    ("Clojure", "clojure"),
    ("CMake", "cmake"),
    ("Cobol", "cobol"),
    ("CommonLisp", "commonlisp"),
    ("Cpp", "cpp"),
    ("Css", "css"),
    ("D", "d"),
    ("Dart", "dart"),
    ("DeviceTree", "devicetree"),
    ("Diff", "diff"),
    ("Dockerfile", "dockerfile"),
    ("Dot", "dot"),
    ("Elisp", "elisp"),
    ("Elixir", "elixir"),
    ("Elm", "elm"),
    ("Erlang", "erlang"),
    ("Fish", "fish"),
    ("FSharp", "fsharp"),
    ("Gleam", "gleam"),
    ("Glsl", "glsl"),
    ("Go", "go"),
    ("GraphQL", "graphql"),
    ("Groovy", "groovy"),
    ("Haskell", "haskell"),
    ("Hcl", "hcl"),
    ("Hlsl", "hlsl"),
    ("Html", "html"),
    ("Idris", "idris"),
    ("Ini", "ini"),
    ("Java", "java"),
    ("JavaScript", "javascript"),
    ("Jinja2", "jinja2"),
    ("Jq", "jq"),
    ("Json", "json"),
    ("Julia", "julia"),
    ("Kotlin", "kotlin"),
    ("Lean", "lean"),
    ("Lua", "lua"),
    ("Markdown", "markdown"),
    ("Matlab", "matlab"),
    ("Meson", "meson"),
    ("Nginx", "nginx"),
    ("Ninja", "ninja"),
    ("Nix", "nix"),
    ("ObjectiveC", "objc"),
    ("OCaml", "ocaml"),
    ("Perl", "perl"),
    ("Php", "php"),
    ("PostScript", "postscript"),
    ("PowerShell", "powershell"),
    ("Prolog", "prolog"),
    ("Python", "python"),
    ("Query", "query"),
    ("R", "r"),
    ("Rego", "rego"),
    ("Rescript", "rescript"),
    ("Ron", "ron"),
    ("Ruby", "ruby"),
    ("Scala", "scala"),
    ("Scheme", "scheme"),
    ("Scss", "scss"),
    ("Solidity", "solidity"),
    ("Sparql", "sparql"),
    ("Sql", "sql"),
    ("SshConfig", "ssh-config"),
    ("Starlark", "starlark"),
    ("Styx", "styx"),
    ("Svelte", "svelte"),
    ("Swift", "swift"),
    ("Textproto", "textproto"),
    ("Thrift", "thrift"),
    ("TlaPlus", "tlaplus"),
    ("Toml", "toml"),
    ("Tsx", "tsx"),
    ("TypeScript", "typescript"),
    ("Typst", "typst"),
    ("Uiua", "uiua"),
    ("VisualBasic", "vb"),
    ("Verilog", "verilog"),
    ("Vhdl", "vhdl"),
    ("Vim", "vim"),
    ("Vue", "vue"),
    ("Wit", "wit"),
    ("X86Asm", "x86asm"),
    ("Xml", "xml"),
    ("Yaml", "yaml"),
    ("Yuri", "yuri"),
    ("Zig", "zig"),
    ("Zsh", "zsh"),
];
261
262fn language_slug_from_path(path: &syn::ExprPath) -> Option<&'static str> {
263    let variant = path.path.segments.last()?.ident.to_string();
264    LANGUAGE_VARIANTS
265        .iter()
266        .find(|(name, _)| *name == variant)
267        .map(|(_, slug)| *slug)
268}
269
270fn language_variant_for_slug(slug: &str) -> Option<&'static str> {
271    LANGUAGE_VARIANTS
272        .iter()
273        .find(|(_, s)| *s == slug)
274        .map(|(name, _)| *name)
275}
276
277fn expand_code(input: CodeInput) -> syn::Result<TokenStream2> {
278    let manifest_dir = env::var("CARGO_MANIFEST_DIR")
279        .map_err(|error| syn::Error::new(Span::call_site(), error.to_string()))?;
280    let absolute_path = resolve_manifest_path(&PathBuf::from(manifest_dir), &input.path);
281    let source = fs::read_to_string(&absolute_path).map_err(|error| {
282        syn::Error::new(
283            Span::call_site(),
284            format!("failed to read `{}`: {error}", absolute_path.display()),
285        )
286    })?;
287
288    expand_shared(input.options, source, Some(absolute_path))
289}
290
291fn expand_code_str(input: CodeStrInput) -> syn::Result<TokenStream2> {
292    expand_shared(input.options, input.source, None)
293}
294
295fn expand_shared(
296    options: Option<Expr>,
297    source: String,
298    origin_path: Option<PathBuf>,
299) -> syn::Result<TokenStream2> {
300    let crate_path = dioxus_code_crate_path()?;
301    let options_check = options_check_tokens(&crate_path, options.as_ref());
302
303    let Some(language) = options.as_ref().and_then(try_extract_language).or_else(|| {
304        origin_path
305            .as_ref()
306            .and_then(|path| arborium::detect_language(&path.to_string_lossy()).map(str::to_string))
307    }) else {
308        let message = match origin_path.as_ref() {
309            Some(path) => format!(
310                "could not detect language for `{}`; pass `CodeOptions::builder().with_language(Language::Rust)`",
311                path.display()
312            ),
313            None => String::from(
314                "could not determine language for `code_str!`; pass `CodeOptions::builder().with_language(Language::Rust)`",
315            ),
316        };
317        return Ok(quote! {{
318            #options_check
319            compile_error!(#message);
320        }});
321    };
322
323    let mut highlighter = arborium::Highlighter::new();
324    let spans = highlighter
325        .highlight_spans(&language, &source)
326        .map_err(|error| syn::Error::new(Span::call_site(), error.to_string()))?;
327
328    let Some(variant) = language_variant_for_slug(&language) else {
329        let message = format!("language `{language}` has no `Language` variant");
330        return Ok(quote! {{
331            #options_check
332            compile_error!(#message);
333        }});
334    };
335    let variant_ident = Ident::new(variant, Span::call_site());
336
337    let source_expr = match origin_path {
338        Some(path) => {
339            let path_lit = LitStr::new(&path.to_string_lossy(), Span::call_site());
340            quote! { include_str!(#path_lit) }
341        }
342        None => {
343            let source_lit = LitStr::new(&source, Span::call_site());
344            quote! { #source_lit }
345        }
346    };
347
348    let span_tokens = normalize_spans(spans).into_iter().map(|span| {
349        let start = span.start;
350        let end = span.end;
351        let tag = LitStr::new(span.tag, Span::call_site());
352        quote! {
353            #crate_path::advanced::HighlightSpan::new(#start..#end, #tag)
354        }
355    });
356
357    Ok(quote! {{
358        #options_check
359        const SOURCE: &str = #source_expr;
360        const SPANS: &[#crate_path::advanced::HighlightSpan] = &[#(#span_tokens),*];
361        #crate_path::advanced::HighlightedSource::from_static_parts(
362            SOURCE,
363            #crate_path::Language::#variant_ident,
364            SPANS,
365        )
366    }})
367}
368
369fn options_check_tokens(crate_path: &TokenStream2, options: Option<&Expr>) -> Option<TokenStream2> {
370    options.map(|expr| {
371        quote_spanned! { expr.span() =>
372            const _: fn() = || {
373                let _: #crate_path::CodeOptions = #expr;
374            };
375        }
376    })
377}
378
/// A highlight span ready to be emitted as `HighlightSpan::new(start..end, tag)`.
struct NormalizedSpan {
    /// Start of the range (presumably an offset into the source text; confirm
    /// against `arborium`).
    start: u32,
    /// End of the range, used as the exclusive bound of `start..end`.
    end: u32,
    /// Theme tag obtained from `arborium_theme::tag_for_capture`.
    tag: &'static str,
}
384
/// A span as reported by the highlighter, before deduplication.
struct RawSpan {
    /// Start of the range, copied from the `arborium` span.
    start: u32,
    /// End of the range, copied from the `arborium` span.
    end: u32,
    /// Theme tag for the span's capture, or `None` when the capture maps to
    /// no tag.
    tag: Option<&'static str>,
    /// Pattern index copied from the `arborium` span; used to break ties
    /// between spans covering the same range.
    pattern_index: u32,
}
391
392fn normalize_spans(spans: Vec<arborium::advanced::Span>) -> Vec<NormalizedSpan> {
393    use std::collections::HashMap;
394
395    let mut deduped: HashMap<(u32, u32), RawSpan> = HashMap::new();
396    for span in spans {
397        let span = RawSpan {
398            start: span.start,
399            end: span.end,
400            tag: arborium_theme::tag_for_capture(&span.capture),
401            pattern_index: span.pattern_index,
402        };
403        let key = (span.start, span.end);
404
405        if let Some(existing) = deduped.get(&key) {
406            let should_replace = match (span.tag.is_some(), existing.tag.is_some()) {
407                (true, false) => true,
408                (false, true) => false,
409                _ => span.pattern_index >= existing.pattern_index,
410            };
411            if should_replace {
412                deduped.insert(key, span);
413            }
414        } else {
415            deduped.insert(key, span);
416        }
417    }
418
419    let mut spans: Vec<_> = deduped
420        .into_values()
421        .filter_map(|span| {
422            Some(NormalizedSpan {
423                start: span.start,
424                end: span.end,
425                tag: span.tag?,
426            })
427        })
428        .collect();
429
430    spans.sort_by_key(|span| (span.start, span.end));
431
432    let mut coalesced: Vec<NormalizedSpan> = Vec::with_capacity(spans.len());
433    for span in spans {
434        if let Some(last) = coalesced.last_mut()
435            && span.tag == last.tag
436            && span.start <= last.end
437        {
438            last.end = last.end.max(span.end);
439            continue;
440        }
441        coalesced.push(span);
442    }
443
444    coalesced
445}
446
447fn dioxus_code_crate_path() -> syn::Result<TokenStream2> {
448    match crate_name("dioxus-code") {
449        Ok(FoundCrate::Itself) => Ok(quote!(::dioxus_code)),
450        Ok(FoundCrate::Name(name)) => {
451            let ident = format_ident!("{}", name);
452            Ok(quote!(::#ident))
453        }
454        Err(error) => Err(syn::Error::new(Span::call_site(), error.to_string())),
455    }
456}
457
/// Resolves a macro path argument to the file that should be read.
///
/// * An absolute path that exists on disk, or that already lies under
///   `manifest_dir`, is returned unchanged.
/// * Anything else is treated as relative to `manifest_dir`. Leading `/`
///   characters are ignored, so `"/src/a.rs"`, `"//src/a.rs"` and
///   `"src/a.rs"` all name `<manifest_dir>/src/a.rs`.
fn resolve_manifest_path(manifest_dir: &Path, path: &str) -> PathBuf {
    let candidate = Path::new(path);
    if candidate.is_absolute() && (candidate.exists() || candidate.starts_with(manifest_dir)) {
        return candidate.to_path_buf();
    }

    // Strip every leading slash, not just one: `Path::join` discards the base
    // when given an absolute path, so a leftover `/` would escape
    // `manifest_dir`.
    manifest_dir.join(path.trim_start_matches('/'))
}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as an expression and runs language extraction on it.
    fn language(source: &str) -> Option<String> {
        let parsed: Expr = syn::parse_str(source).unwrap();
        try_extract_language(&parsed)
    }

    #[test]
    fn extracts_language_variant_options() {
        // Both the bare variant and the `Some(...)` form name the language.
        let cases = [
            "CodeOptions::builder().with_language(Language::Rust)",
            "CodeOptions::builder().with_language(Some(Language::Rust))",
        ];
        for case in cases {
            assert_eq!(language(case).as_deref(), Some("rust"));
        }
    }

    #[test]
    fn extracts_none_language_option() {
        let extracted = language("CodeOptions::builder().with_language(None)");
        assert_eq!(extracted.as_deref(), None);
    }

    #[test]
    fn unknown_method_chains_fall_back_silently() {
        // Neither chain contains a `with_language` call, so nothing is extracted.
        let cases = [
            "CodeOptions::builder()",
            "CodeOptions::builder().with_themes(Language::Rust)",
        ];
        for case in cases {
            assert_eq!(language(case).as_deref(), None);
        }
    }
}