1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3
4use std::env;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use macro_string::MacroString;
9use proc_macro::TokenStream;
10use proc_macro_crate::{FoundCrate, crate_name};
11use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
12use quote::{format_ident, quote, quote_spanned};
13use syn::parse::{Parse, ParseStream};
14use syn::spanned::Spanned;
15use syn::{Expr, LitStr, Token, parse_macro_input};
16
17#[proc_macro]
28pub fn code(input: TokenStream) -> TokenStream {
29 let input = parse_macro_input!(input as CodeInput);
30
31 match expand_code(input) {
32 Ok(tokens) => tokens.into(),
33 Err(error) => error.to_compile_error().into(),
34 }
35}
36
37struct CodeInput {
38 path: String,
39 options: Option<Expr>,
40}
41
42impl Parse for CodeInput {
43 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
44 let MacroString(path) = input.parse()?;
45 let mut options = None;
46
47 if input.peek(Token![,]) {
48 input.parse::<Token![,]>()?;
49 if !input.is_empty() {
50 let expr: Expr = input.parse()?;
51 if input.peek(Token![,]) {
52 input.parse::<Token![,]>()?;
53 }
54 if !input.is_empty() {
55 return Err(input.error("unexpected tokens after code macro options"));
56 }
57 options = Some(expr);
58 }
59 }
60
61 Ok(Self { path, options })
62 }
63}
64
65fn try_extract_language(expr: &Expr) -> Option<String> {
66 match expr {
67 Expr::Group(group) => try_extract_language(&group.expr),
68 Expr::Paren(paren) => try_extract_language(&paren.expr),
69 Expr::MethodCall(method) => {
70 if method.method == "with_language"
71 && method.args.len() == 1
72 && let Some(slug) = try_parse_language_arg(method.args.first().unwrap())
73 {
74 return Some(slug);
75 }
76 try_extract_language(&method.receiver)
77 }
78 _ => None,
79 }
80}
81
82fn try_parse_language_arg(expr: &Expr) -> Option<String> {
83 match expr {
84 Expr::Group(group) => try_parse_language_arg(&group.expr),
85 Expr::Paren(paren) => try_parse_language_arg(&paren.expr),
86 Expr::Call(call) if is_some_call(call) && call.args.len() == 1 => {
87 try_parse_language_arg(call.args.first().unwrap())
88 }
89 Expr::Path(path) if is_none_path(path) => None,
90 Expr::Path(path) => language_slug_from_path(path).map(str::to_string),
91 _ => None,
92 }
93}
94
95fn is_some_call(call: &syn::ExprCall) -> bool {
96 let Expr::Path(path) = call.func.as_ref() else {
97 return false;
98 };
99 path.path
100 .segments
101 .last()
102 .is_some_and(|segment| segment.ident == "Some")
103}
104
105fn is_none_path(path: &syn::ExprPath) -> bool {
106 path.path
107 .segments
108 .last()
109 .is_some_and(|segment| segment.ident == "None")
110}
111
/// Pairs of (`Language` enum variant name, language slug).
///
/// The variant name is matched against the last segment of a `Language::Variant` path. The
/// slug is the string handed to (and returned by) `arborium`. Both directions are searched
/// linearly and the first match wins, so every variant name and every slug should appear at
/// most once.
const LANGUAGE_VARIANTS: &'static [(&'static str, &'static str)] = &[
    ("Rust", "rust"),
    ("Ada", "ada"),
    ("Agda", "agda"),
    ("Asciidoc", "asciidoc"),
    ("Asm", "asm"),
    ("Awk", "awk"),
    ("Bash", "bash"),
    ("Batch", "batch"),
    ("C", "c"),
    ("CSharp", "c-sharp"),
    ("Caddy", "caddy"),
    ("Capnp", "capnp"),
    ("Cedar", "cedar"),
    ("CedarSchema", "cedarschema"),
    ("Clojure", "clojure"),
    ("CMake", "cmake"),
    ("Cobol", "cobol"),
    ("CommonLisp", "commonlisp"),
    ("Cpp", "cpp"),
    ("Css", "css"),
    ("D", "d"),
    ("Dart", "dart"),
    ("DeviceTree", "devicetree"),
    ("Diff", "diff"),
    ("Dockerfile", "dockerfile"),
    ("Dot", "dot"),
    ("Elisp", "elisp"),
    ("Elixir", "elixir"),
    ("Elm", "elm"),
    ("Erlang", "erlang"),
    ("Fish", "fish"),
    ("FSharp", "fsharp"),
    ("Gleam", "gleam"),
    ("Glsl", "glsl"),
    ("Go", "go"),
    ("GraphQL", "graphql"),
    ("Groovy", "groovy"),
    ("Haskell", "haskell"),
    ("Hcl", "hcl"),
    ("Hlsl", "hlsl"),
    ("Html", "html"),
    ("Idris", "idris"),
    ("Ini", "ini"),
    ("Java", "java"),
    ("JavaScript", "javascript"),
    ("Jinja2", "jinja2"),
    ("Jq", "jq"),
    ("Json", "json"),
    ("Julia", "julia"),
    ("Kotlin", "kotlin"),
    ("Lean", "lean"),
    ("Lua", "lua"),
    ("Markdown", "markdown"),
    ("Matlab", "matlab"),
    ("Meson", "meson"),
    ("Nginx", "nginx"),
    ("Ninja", "ninja"),
    ("Nix", "nix"),
    ("ObjectiveC", "objc"),
    ("OCaml", "ocaml"),
    ("Perl", "perl"),
    ("Php", "php"),
    ("PostScript", "postscript"),
    ("PowerShell", "powershell"),
    ("Prolog", "prolog"),
    ("Python", "python"),
    ("Query", "query"),
    ("R", "r"),
    ("Rego", "rego"),
    ("Rescript", "rescript"),
    ("Ron", "ron"),
    ("Ruby", "ruby"),
    ("Scala", "scala"),
    ("Scheme", "scheme"),
    ("Scss", "scss"),
    ("Solidity", "solidity"),
    ("Sparql", "sparql"),
    ("Sql", "sql"),
    ("SshConfig", "ssh-config"),
    ("Starlark", "starlark"),
    ("Styx", "styx"),
    ("Svelte", "svelte"),
    ("Swift", "swift"),
    ("Textproto", "textproto"),
    ("Thrift", "thrift"),
    ("TlaPlus", "tlaplus"),
    ("Toml", "toml"),
    ("Tsx", "tsx"),
    ("TypeScript", "typescript"),
    ("Typst", "typst"),
    ("Uiua", "uiua"),
    ("VisualBasic", "vb"),
    ("Verilog", "verilog"),
    ("Vhdl", "vhdl"),
    ("Vim", "vim"),
    ("Vue", "vue"),
    ("Wit", "wit"),
    ("X86Asm", "x86asm"),
    ("Xml", "xml"),
    ("Yaml", "yaml"),
    ("Yuri", "yuri"),
    ("Zig", "zig"),
    ("Zsh", "zsh"),
];
217
218fn language_slug_from_path(path: &syn::ExprPath) -> Option<&'static str> {
219 let variant = path.path.segments.last()?.ident.to_string();
220 LANGUAGE_VARIANTS
221 .iter()
222 .find(|(name, _)| *name == variant)
223 .map(|(_, slug)| *slug)
224}
225
226fn language_variant_for_slug(slug: &str) -> Option<&'static str> {
227 LANGUAGE_VARIANTS
228 .iter()
229 .find(|(_, s)| *s == slug)
230 .map(|(name, _)| *name)
231}
232
233fn expand_code(input: CodeInput) -> syn::Result<TokenStream2> {
234 let manifest_dir = env::var("CARGO_MANIFEST_DIR")
235 .map_err(|error| syn::Error::new(Span::call_site(), error.to_string()))?;
236 let manifest_dir = PathBuf::from(manifest_dir);
237 let macro_path = input.path;
238 let absolute_path = resolve_manifest_path(&manifest_dir, ¯o_path);
239 let crate_path = dioxus_code_crate_path()?;
240
241 let options_check = input.options.as_ref().map(|expr| {
242 quote_spanned! { expr.span() =>
243 const _: fn() = || {
244 let _: #crate_path::CodeOptions = #expr;
245 };
246 }
247 });
248
249 let source = fs::read_to_string(&absolute_path).map_err(|error| {
250 syn::Error::new(
251 Span::call_site(),
252 format!("failed to read `{}`: {error}", absolute_path.display()),
253 )
254 })?;
255
256 let Some(language) = input
257 .options
258 .as_ref()
259 .and_then(try_extract_language)
260 .or_else(|| arborium::detect_language(¯o_path).map(str::to_string))
261 else {
262 let message = format!(
263 "could not detect language for `{macro_path}`; pass `CodeOptions::builder().with_language(Language::Rust)`"
264 );
265 return Ok(quote! {{
266 #options_check
267 compile_error!(#message);
268 }});
269 };
270
271 let mut highlighter = arborium::Highlighter::new();
272 let spans = highlighter
273 .highlight_spans(&language, &source)
274 .map_err(|error| syn::Error::new(Span::call_site(), error.to_string()))?;
275
276 let Some(variant) = language_variant_for_slug(&language) else {
277 let message = format!("language `{language}` has no `Language` variant");
278 return Ok(quote! {{
279 #options_check
280 compile_error!(#message);
281 }});
282 };
283 let variant_ident = Ident::new(variant, Span::call_site());
284 let absolute_lit = LitStr::new(&absolute_path.to_string_lossy(), Span::call_site());
285 let spans = normalize_spans(spans).into_iter().map(|span| {
286 let start = span.start;
287 let end = span.end;
288 let tag = LitStr::new(span.tag, Span::call_site());
289
290 quote! {
291 #crate_path::advanced::HighlightSpan::new(#start..#end, #tag)
292 }
293 });
294
295 Ok(quote! {{
296 #options_check
297 const SOURCE: &str = include_str!(#absolute_lit);
298 static SPANS: &[#crate_path::advanced::HighlightSpan] = &[#(#spans),*];
299 #crate_path::advanced::HighlightedSource::from_static_parts(
300 SOURCE,
301 #crate_path::Language::#variant_ident,
302 SPANS,
303 )
304 }})
305}
306
/// A highlight span that survived normalization and will be embedded in the expansion.
struct NormalizedSpan {
    /// Start offset of the span as reported by `arborium` (presumably a byte offset — confirm).
    start: u32,
    /// End offset of the span, in the same units as `start`.
    end: u32,
    /// Theme tag for the span, as returned by `arborium_theme::tag_for_capture`.
    tag: &'static str,
}
312
/// A highlighter span before de-duplication, with its theme tag already resolved.
struct RawSpan {
    /// Start offset of the span as reported by `arborium`.
    start: u32,
    /// End offset of the span, in the same units as `start`.
    end: u32,
    /// Theme tag for the capture, or `None` when `tag_for_capture` has no tag for it.
    tag: Option<&'static str>,
    /// `pattern_index` from the `arborium` span. It breaks ties between spans with the same
    /// range; presumably the index of the query pattern that matched — confirm.
    pattern_index: u32,
}
319
320fn normalize_spans(spans: Vec<arborium::advanced::Span>) -> Vec<NormalizedSpan> {
321 use std::collections::HashMap;
322
323 let mut deduped: HashMap<(u32, u32), RawSpan> = HashMap::new();
324 for span in spans {
325 let span = RawSpan {
326 start: span.start,
327 end: span.end,
328 tag: arborium_theme::tag_for_capture(&span.capture),
329 pattern_index: span.pattern_index,
330 };
331 let key = (span.start, span.end);
332
333 if let Some(existing) = deduped.get(&key) {
334 let should_replace = match (span.tag.is_some(), existing.tag.is_some()) {
335 (true, false) => true,
336 (false, true) => false,
337 _ => span.pattern_index >= existing.pattern_index,
338 };
339 if should_replace {
340 deduped.insert(key, span);
341 }
342 } else {
343 deduped.insert(key, span);
344 }
345 }
346
347 let mut spans: Vec<_> = deduped
348 .into_values()
349 .filter_map(|span| {
350 Some(NormalizedSpan {
351 start: span.start,
352 end: span.end,
353 tag: span.tag?,
354 })
355 })
356 .collect();
357
358 spans.sort_by_key(|span| (span.start, span.end));
359
360 let mut coalesced: Vec<NormalizedSpan> = Vec::with_capacity(spans.len());
361 for span in spans {
362 if let Some(last) = coalesced.last_mut()
363 && span.tag == last.tag
364 && span.start <= last.end
365 {
366 last.end = last.end.max(span.end);
367 continue;
368 }
369 coalesced.push(span);
370 }
371
372 coalesced
373}
374
375fn dioxus_code_crate_path() -> syn::Result<TokenStream2> {
376 match crate_name("dioxus-code") {
377 Ok(FoundCrate::Itself) => Ok(quote!(::dioxus_code)),
378 Ok(FoundCrate::Name(name)) => {
379 let ident = format_ident!("{}", name);
380 Ok(quote!(::#ident))
381 }
382 Err(error) => Err(syn::Error::new(Span::call_site(), error.to_string())),
383 }
384}
385
/// Resolves a path given to `code!` against the crate's manifest directory.
///
/// An absolute path is returned unchanged when it exists or already lies under
/// `manifest_dir`. Anything else has at most one leading `/` stripped and is joined onto
/// `manifest_dir`, so `/src/lib.rs` and `src/lib.rs` both mean `<manifest_dir>/src/lib.rs`.
fn resolve_manifest_path(manifest_dir: &Path, path: &str) -> PathBuf {
    let candidate = Path::new(path);
    let keep_as_is =
        candidate.is_absolute() && (candidate.exists() || candidate.starts_with(manifest_dir));
    if keep_as_is {
        return candidate.to_path_buf();
    }

    let manifest_relative = path.strip_prefix('/').unwrap_or(path);
    manifest_dir.join(manifest_relative)
}
398
399#[cfg(test)]
400mod tests {
401 use super::*;
402
403 fn language(expr: &str) -> Option<String> {
404 let expr = syn::parse_str::<Expr>(expr).unwrap();
405 try_extract_language(&expr)
406 }
407
408 #[test]
409 fn extracts_language_variant_options() {
410 assert_eq!(
411 language("CodeOptions::builder().with_language(Language::Rust)").as_deref(),
412 Some("rust"),
413 );
414 assert_eq!(
415 language("CodeOptions::builder().with_language(Some(Language::Rust))").as_deref(),
416 Some("rust"),
417 );
418 }
419
420 #[test]
421 fn extracts_none_language_option() {
422 assert_eq!(
423 language("CodeOptions::builder().with_language(None)").as_deref(),
424 None,
425 );
426 }
427
428 #[test]
429 fn unknown_method_chains_fall_back_silently() {
430 assert_eq!(language("CodeOptions::builder()").as_deref(), None);
431 assert_eq!(
432 language("CodeOptions::builder().with_themes(Language::Rust)").as_deref(),
433 None,
434 );
435 }
436}