Skip to main content

runmat_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::sync::{Mutex, OnceLock};
6use syn::parse::{Parse, ParseStream};
7use syn::{
8    parse_macro_input, AttributeArgs, Expr, FnArg, ItemConst, ItemFn, Lit, LitStr, Meta,
9    MetaNameValue, NestedMeta, Pat,
10};
11
/// Cached location of the generated wasm registry file, resolved once per process by
/// `wasm_registry_path`; `None` when no directory containing `Cargo.lock` was found.
static WASM_REGISTRY_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
/// Serialises the read-modify-write updates done by `append_wasm_block` within this
/// process.
static WASM_REGISTRY_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
/// Set once `initialize_registry_file` has (re)created the registry file in this process.
static WASM_REGISTRY_INIT: OnceLock<()> = OnceLock::new();
15
16/// Attribute used to mark functions as implementing a runtime builtin.
17///
18/// Example:
19/// ```rust,ignore
20/// use runmat_macros::runtime_builtin;
21///
22/// #[runtime_builtin(name = "plot")]
23/// pub fn plot_line(xs: &[f64]) {
24///     /* implementation */
25/// }
26/// ```
27///
28/// This registers the function with the `runmat-builtins` inventory
29/// so the runtime can discover it at start-up.
30#[proc_macro_attribute]
31pub fn runtime_builtin(args: TokenStream, input: TokenStream) -> TokenStream {
32    // Parse attribute arguments as `name = "..."`
33    let args = parse_macro_input!(args as AttributeArgs);
34    let mut name_lit: Option<Lit> = None;
35    let mut category_lit: Option<Lit> = None;
36    let mut summary_lit: Option<Lit> = None;
37    let mut keywords_lit: Option<Lit> = None;
38    let mut errors_lit: Option<Lit> = None;
39    let mut related_lit: Option<Lit> = None;
40    let mut introduced_lit: Option<Lit> = None;
41    let mut status_lit: Option<Lit> = None;
42    let mut examples_lit: Option<Lit> = None;
43    let mut accel_values: Vec<String> = Vec::new();
44    let mut builtin_path_lit: Option<LitStr> = None;
45    let mut type_resolver_path: Option<syn::Path> = None;
46    let mut type_resolver_ctx_path: Option<syn::Path> = None;
47    let mut descriptor_path: Option<syn::Path> = None;
48    let mut sink_flag = false;
49    let mut suppress_auto_output_flag = false;
50    for arg in args {
51        match arg {
52            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
53                if path.is_ident("name") {
54                    name_lit = Some(lit);
55                } else if path.is_ident("category") {
56                    category_lit = Some(lit);
57                } else if path.is_ident("summary") {
58                    summary_lit = Some(lit);
59                } else if path.is_ident("keywords") {
60                    keywords_lit = Some(lit);
61                } else if path.is_ident("errors") {
62                    errors_lit = Some(lit);
63                } else if path.is_ident("related") {
64                    related_lit = Some(lit);
65                } else if path.is_ident("introduced") {
66                    introduced_lit = Some(lit);
67                } else if path.is_ident("status") {
68                    status_lit = Some(lit);
69                } else if path.is_ident("examples") {
70                    examples_lit = Some(lit);
71                } else if path.is_ident("accel") {
72                    if let Lit::Str(ls) = lit {
73                        accel_values.extend(
74                            ls.value()
75                                .split(|c: char| c == ',' || c == '|' || c.is_ascii_whitespace())
76                                .filter(|s| !s.is_empty())
77                                .map(|s| s.to_ascii_lowercase()),
78                        );
79                    }
80                } else if path.is_ident("sink") {
81                    if let Lit::Bool(lb) = lit {
82                        sink_flag = lb.value;
83                    }
84                } else if path.is_ident("suppress_auto_output") {
85                    if let Lit::Bool(lb) = lit {
86                        suppress_auto_output_flag = lb.value;
87                    }
88                } else if path.is_ident("builtin_path") {
89                    if let Lit::Str(ls) = lit {
90                        builtin_path_lit = Some(ls);
91                    } else {
92                        panic!("builtin_path must be a string literal");
93                    }
94                } else if path.is_ident("type_resolver") {
95                    if let Lit::Str(ls) = lit {
96                        let parsed: syn::Path = ls.parse().expect("type_resolver must be a path");
97                        type_resolver_path = Some(parsed);
98                    } else {
99                        panic!("type_resolver must be a string literal path");
100                    }
101                } else if path.is_ident("type_resolver_ctx") {
102                    if let Lit::Str(ls) = lit {
103                        let parsed: syn::Path =
104                            ls.parse().expect("type_resolver_ctx must be a path");
105                        type_resolver_ctx_path = Some(parsed);
106                    } else {
107                        panic!("type_resolver_ctx must be a string literal path");
108                    }
109                } else {
110                    // Gracefully ignore unknown parameters for better IDE experience
111                }
112            }
113            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver") => {
114                if list.nested.len() != 1 {
115                    panic!("type_resolver expects exactly one path argument");
116                }
117                let nested = list.nested.first().unwrap();
118                if let NestedMeta::Meta(Meta::Path(path)) = nested {
119                    type_resolver_path = Some(path.clone());
120                } else {
121                    panic!("type_resolver expects a path argument");
122                }
123            }
124            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver_ctx") => {
125                if list.nested.len() != 1 {
126                    panic!("type_resolver_ctx expects exactly one path argument");
127                }
128                let nested = list.nested.first().unwrap();
129                if let NestedMeta::Meta(Meta::Path(path)) = nested {
130                    type_resolver_ctx_path = Some(path.clone());
131                } else {
132                    panic!("type_resolver_ctx expects a path argument");
133                }
134            }
135            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("descriptor") => {
136                if list.nested.len() != 1 {
137                    panic!("descriptor expects exactly one path argument");
138                }
139                let nested = list.nested.first().unwrap();
140                if let NestedMeta::Meta(Meta::Path(path)) = nested {
141                    descriptor_path = Some(path.clone());
142                } else {
143                    panic!("descriptor expects a path argument");
144                }
145            }
146            _ => {}
147        }
148    }
149    let name_lit = name_lit.expect("expected `name = \"...\"` argument");
150    let name_str = if let Lit::Str(ref s) = name_lit {
151        s.value()
152    } else {
153        panic!("name must be a string literal");
154    };
155
156    let func: ItemFn = parse_macro_input!(input as ItemFn);
157    let ident = &func.sig.ident;
158    let is_async = func.sig.asyncness.is_some();
159
160    // Extract param idents and types
161    let mut param_idents = Vec::new();
162    let mut param_types = Vec::new();
163    for arg in &func.sig.inputs {
164        match arg {
165            FnArg::Typed(pt) => {
166                // pattern must be ident
167                if let Pat::Ident(pi) = pt.pat.as_ref() {
168                    param_idents.push(pi.ident.clone());
169                } else {
170                    panic!("parameters must be simple identifiers");
171                }
172                param_types.push((*pt.ty).clone());
173            }
174            _ => panic!("self parameter not allowed"),
175        }
176    }
177    let param_len = param_idents.len();
178
179    // Infer parameter types for BuiltinFunction
180    let inferred_param_types: Vec<proc_macro2::TokenStream> =
181        param_types.iter().map(infer_builtin_type).collect();
182
183    // Infer return type for BuiltinFunction
184    let inferred_return_type = match &func.sig.output {
185        syn::ReturnType::Default => quote! { runmat_builtins::Type::Void },
186        syn::ReturnType::Type(_, ty) => infer_builtin_type(ty),
187    };
188
189    // Detect if last parameter is variadic Vec<Value>
190    let is_last_variadic = param_types
191        .last()
192        .map(|ty| {
193            // crude detection: type path starts with Vec and inner type is runmat_builtins::Value or Value
194            if let syn::Type::Path(tp) = ty {
195                if tp
196                    .path
197                    .segments
198                    .last()
199                    .map(|s| s.ident == "Vec")
200                    .unwrap_or(false)
201                {
202                    if let syn::PathArguments::AngleBracketed(ab) =
203                        &tp.path.segments.last().unwrap().arguments
204                    {
205                        if let Some(syn::GenericArgument::Type(syn::Type::Path(inner))) =
206                            ab.args.first()
207                        {
208                            return inner
209                                .path
210                                .segments
211                                .last()
212                                .map(|s| s.ident == "Value")
213                                .unwrap_or(false);
214                        }
215                    }
216                }
217            }
218            false
219        })
220        .unwrap_or(false);
221
222    // Generate wrapper ident
223    let wrapper_ident = format_ident!("__rt_wrap_{}", ident);
224
225    let conv_stmts: Vec<proc_macro2::TokenStream> = if is_last_variadic && param_len > 0 {
226        let mut stmts = Vec::new();
227        // Convert fixed params (all but last)
228        for (i, (ident, ty)) in param_idents
229            .iter()
230            .zip(param_types.iter())
231            .enumerate()
232            .take(param_len - 1)
233        {
234            stmts.push(quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; });
235        }
236        // Collect the rest into Vec<Value>
237        let last_ident = &param_idents[param_len - 1];
238        stmts.push(quote! {
239            let #last_ident : Vec<runmat_builtins::Value> = {
240                let mut v = Vec::new();
241                for j in (#param_len-1)..args.len() {
242                    let item : runmat_builtins::Value = std::convert::TryInto::try_into(&args[j])?;
243                    v.push(item);
244                }
245                v
246            };
247        });
248        stmts
249    } else {
250        param_idents
251            .iter()
252            .zip(param_types.iter())
253            .enumerate()
254            .map(|(i, (ident, ty))| {
255                quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; }
256            })
257            .collect()
258    };
259
260    let call_expr = if is_async {
261        quote! { #ident(#(#param_idents),*).await? }
262    } else {
263        quote! { #ident(#(#param_idents),*)? }
264    };
265
266    let wrapper = quote! {
267        fn #wrapper_ident(args: &[runmat_builtins::Value]) -> runmat_builtins::BuiltinFuture {
268            #![allow(unused_variables)]
269            let args = args.to_vec();
270            Box::pin(async move {
271                if #is_last_variadic {
272                    if args.len() < #param_len - 1 {
273                        return Err(std::convert::From::from(format!(
274                            "expected at least {} args, got {}",
275                            #param_len - 1,
276                            args.len()
277                        )));
278                    }
279                } else if args.len() != #param_len {
280                    return Err(std::convert::From::from(format!(
281                        "expected {} args, got {}",
282                        #param_len,
283                        args.len()
284                    )));
285                }
286                #(#conv_stmts)*
287                let value = #call_expr;
288                Ok(runmat_builtins::Value::from(value))
289            })
290        }
291    };
292
293    // Prepare tokens for defaults and options
294    let default_category = syn::LitStr::new("general", proc_macro2::Span::call_site());
295    let default_summary =
296        syn::LitStr::new("Runtime builtin function", proc_macro2::Span::call_site());
297
298    let category_tok: proc_macro2::TokenStream = match &category_lit {
299        Some(syn::Lit::Str(ls)) => quote! { #ls },
300        _ => quote! { #default_category },
301    };
302    let summary_tok: proc_macro2::TokenStream = match &summary_lit {
303        Some(syn::Lit::Str(ls)) => quote! { #ls },
304        _ => quote! { #default_summary },
305    };
306
307    fn opt_tok(lit: &Option<syn::Lit>) -> proc_macro2::TokenStream {
308        if let Some(syn::Lit::Str(ls)) = lit {
309            quote! { Some(#ls) }
310        } else {
311            quote! { None }
312        }
313    }
314    let category_opt_tok = opt_tok(&category_lit);
315    let summary_opt_tok = opt_tok(&summary_lit);
316    let keywords_opt_tok = opt_tok(&keywords_lit);
317    let errors_opt_tok = opt_tok(&errors_lit);
318    let related_opt_tok = opt_tok(&related_lit);
319    let introduced_opt_tok = opt_tok(&introduced_lit);
320    let status_opt_tok = opt_tok(&status_lit);
321    let examples_opt_tok = opt_tok(&examples_lit);
322
323    let accel_tokens: Vec<proc_macro2::TokenStream> = accel_values
324        .iter()
325        .map(|mode| match mode.as_str() {
326            "unary" => quote! { runmat_builtins::AccelTag::Unary },
327            "binary" => quote! { runmat_builtins::AccelTag::Elementwise },
328            "elementwise" => quote! { runmat_builtins::AccelTag::Elementwise },
329            "reduction" => quote! { runmat_builtins::AccelTag::Reduction },
330            "matmul" => quote! { runmat_builtins::AccelTag::MatMul },
331            "transpose" => quote! { runmat_builtins::AccelTag::Transpose },
332            "array_construct" => quote! { runmat_builtins::AccelTag::ArrayConstruct },
333            _ => quote! {},
334        })
335        .filter(|ts| !ts.is_empty())
336        .collect();
337    let accel_slice = if accel_tokens.is_empty() {
338        quote! { &[] as &[runmat_builtins::AccelTag] }
339    } else {
340        quote! { &[#(#accel_tokens),*] }
341    };
342    let type_resolver_expr = if let Some(path) = type_resolver_ctx_path.as_ref() {
343        quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
344    } else if let Some(path) = type_resolver_path.as_ref() {
345        quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
346    } else {
347        quote! { None }
348    };
349    let sink_bool = sink_flag;
350    let suppress_auto_output_bool = suppress_auto_output_flag;
351    let descriptor_expr = if let Some(path) = descriptor_path.as_ref() {
352        quote! { Some(&#path) }
353    } else {
354        quote! { None }
355    };
356
357    let builtin_expr = quote! {
358        runmat_builtins::BuiltinFunction::new(
359            #name_str,
360            #summary_tok,
361            #category_tok,
362            "",
363            "",
364            vec![#(#inferred_param_types),*],
365            #inferred_return_type,
366            #type_resolver_expr,
367            #wrapper_ident,
368            #accel_slice,
369            #sink_bool,
370            #suppress_auto_output_bool,
371        ).with_descriptor_option(#descriptor_expr)
372    };
373
374    let doc_expr = quote! {
375        runmat_builtins::BuiltinDoc {
376            name: #name_str,
377            category: #category_opt_tok,
378            summary: #summary_opt_tok,
379            keywords: #keywords_opt_tok,
380            errors: #errors_opt_tok,
381            related: #related_opt_tok,
382            introduced: #introduced_opt_tok,
383            status: #status_opt_tok,
384            examples: #examples_opt_tok,
385        }
386    };
387
388    let builtin_path_lit =
389        builtin_path_lit.expect("runtime_builtin requires `builtin_path = \"...\"`");
390    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
391        .expect("runtime_builtin `builtin_path` must be a valid path");
392    let helper_ident = format_ident!("__runmat_wasm_register_builtin_{}", ident);
393    let builtin_expr_helper = builtin_expr.clone();
394    let doc_expr_helper = doc_expr.clone();
395    let wasm_helper = quote! {
396        #[cfg(target_arch = "wasm32")]
397        #[allow(non_snake_case)]
398        pub(crate) fn #helper_ident() {
399            runmat_builtins::wasm_registry::submit_builtin_function(#builtin_expr_helper);
400            runmat_builtins::wasm_registry::submit_builtin_doc(#doc_expr_helper);
401        }
402    };
403    let register_native = quote! {
404        #[cfg(not(target_arch = "wasm32"))]
405        runmat_builtins::inventory::submit! { #builtin_expr }
406        #[cfg(not(target_arch = "wasm32"))]
407        runmat_builtins::inventory::submit! { #doc_expr }
408    };
409    append_wasm_block(quote! {
410        #builtin_path::#helper_ident();
411    });
412
413    TokenStream::from(quote! {
414        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
415        #func
416        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
417        #wrapper
418        #wasm_helper
419        #register_native
420    })
421}
422
423/// Attribute used to declare a runtime constant.
424///
425/// Example:
426/// ```rust,ignore
427/// use runmat_macros::runtime_constant;
428/// use runmat_builtins::Value;
429///
430/// #[runtime_constant(name = "pi", value = std::f64::consts::PI)]
431/// const PI_CONSTANT: ();
432/// ```
433///
434/// This registers the constant with the `runmat-builtins` inventory
435/// so the runtime can discover it at start-up.
436#[proc_macro_attribute]
437pub fn runtime_constant(args: TokenStream, input: TokenStream) -> TokenStream {
438    let args = parse_macro_input!(args as AttributeArgs);
439    let mut name_lit: Option<Lit> = None;
440    let mut value_expr: Option<Expr> = None;
441    let mut builtin_path_lit: Option<LitStr> = None;
442
443    for arg in args {
444        match arg {
445            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
446                if path.is_ident("name") {
447                    name_lit = Some(lit);
448                } else if path.is_ident("builtin_path") {
449                    if let Lit::Str(ls) = lit {
450                        builtin_path_lit = Some(ls);
451                    } else {
452                        panic!("builtin_path must be a string literal");
453                    }
454                } else {
455                    panic!("Unknown attribute parameter: {}", quote!(#path));
456                }
457            }
458            NestedMeta::Meta(Meta::Path(path)) if path.is_ident("value") => {
459                panic!("value parameter requires assignment: value = expression");
460            }
461            NestedMeta::Lit(lit) => {
462                // This handles the case where value is provided as a literal
463                value_expr = Some(syn::parse_quote!(#lit));
464            }
465            _ => panic!("Invalid attribute syntax"),
466        }
467    }
468
469    let name = match name_lit {
470        Some(Lit::Str(s)) => s.value(),
471        _ => panic!("name parameter must be a string literal"),
472    };
473
474    let value = value_expr.unwrap_or_else(|| {
475        panic!("value parameter is required");
476    });
477
478    let builtin_path_lit =
479        builtin_path_lit.expect("runtime_constant requires `builtin_path = \"...\"` argument");
480    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
481        .expect("runtime_constant `builtin_path` must be a valid path");
482    let item = parse_macro_input!(input as syn::Item);
483
484    let constant_expr = quote! {
485        runmat_builtins::Constant {
486            name: #name,
487            value: #value,
488        }
489    };
490
491    let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name);
492    let constant_expr_helper = constant_expr.clone();
493    let wasm_helper = quote! {
494        #[cfg(target_arch = "wasm32")]
495        #[allow(non_snake_case)]
496        pub(crate) fn #helper_ident() {
497            runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
498        }
499    };
500    let register_native = quote! {
501        #[cfg(not(target_arch = "wasm32"))]
502        #[allow(non_upper_case_globals)]
503        runmat_builtins::inventory::submit! { #constant_expr }
504    };
505    append_wasm_block(quote! {
506        #builtin_path::#helper_ident();
507    });
508
509    TokenStream::from(quote! {
510        #item
511        #wasm_helper
512        #register_native
513    })
514}
515
516struct RegisterConstantArgs {
517    name: LitStr,
518    value: Expr,
519    builtin_path: LitStr,
520}
521
522impl syn::parse::Parse for RegisterConstantArgs {
523    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
524        let name: LitStr = input.parse()?;
525        input.parse::<syn::Token![,]>()?;
526        let value: Expr = input.parse()?;
527        input.parse::<syn::Token![,]>()?;
528        let builtin_path: LitStr = input.parse()?;
529        if input.peek(syn::Token![,]) {
530            input.parse::<syn::Token![,]>()?;
531        }
532        Ok(RegisterConstantArgs {
533            name,
534            value,
535            builtin_path,
536        })
537    }
538}
539
540#[proc_macro]
541pub fn register_constant(input: TokenStream) -> TokenStream {
542    let RegisterConstantArgs {
543        name,
544        value,
545        builtin_path,
546    } = parse_macro_input!(input as RegisterConstantArgs);
547    let constant_expr = quote! {
548        runmat_builtins::Constant {
549            name: #name,
550            value: #value,
551        }
552    };
553    let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name.value());
554    let builtin_path: syn::Path = syn::parse_str(&builtin_path.value())
555        .expect("register_constant `builtin_path` must be a valid path");
556    let constant_expr_helper = constant_expr.clone();
557    let wasm_helper = quote! {
558        #[cfg(target_arch = "wasm32")]
559        #[allow(non_snake_case)]
560        pub(crate) fn #helper_ident() {
561            runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
562        }
563    };
564    append_wasm_block(quote! {
565        #builtin_path::#helper_ident();
566    });
567    TokenStream::from(quote! {
568        #wasm_helper
569        #[cfg(not(target_arch = "wasm32"))]
570        runmat_builtins::inventory::submit! { #constant_expr }
571    })
572}
573
574struct RegisterSpecAttrArgs {
575    spec_expr: Option<Expr>,
576    builtin_path: Option<LitStr>,
577}
578
579impl Parse for RegisterSpecAttrArgs {
580    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
581        let mut spec_expr = None;
582        let mut builtin_path = None;
583        while !input.is_empty() {
584            let ident: syn::Ident = input.parse()?;
585            input.parse::<syn::Token![=]>()?;
586            if ident == "spec" {
587                spec_expr = Some(input.parse()?);
588            } else if ident == "builtin_path" {
589                let lit: LitStr = input.parse()?;
590                builtin_path = Some(lit);
591            } else {
592                return Err(syn::Error::new(ident.span(), "unknown attribute argument"));
593            }
594            if input.peek(syn::Token![,]) {
595                input.parse::<syn::Token![,]>()?;
596            }
597        }
598        Ok(Self {
599            spec_expr,
600            builtin_path,
601        })
602    }
603}
604
605#[proc_macro_attribute]
606pub fn register_gpu_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
607    let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
608    let RegisterSpecAttrArgs {
609        spec_expr,
610        builtin_path,
611    } = args;
612    let item_const = parse_macro_input!(item as ItemConst);
613    let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
614        let ident = &item_const.ident;
615        quote! { #ident }
616    });
617    let spec_for_native = spec_tokens.clone();
618    let builtin_path_lit =
619        builtin_path.expect("register_gpu_spec requires `builtin_path = \"...\"` argument");
620    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
621        .expect("register_gpu_spec `builtin_path` must be a valid path");
622    let helper_ident = format_ident!(
623        "__runmat_wasm_register_gpu_spec_{}",
624        item_const.ident.to_string()
625    );
626    let spec_tokens_helper = spec_tokens.clone();
627    let wasm_helper = quote! {
628        #[cfg(target_arch = "wasm32")]
629        #[allow(non_snake_case)]
630        pub(crate) fn #helper_ident() {
631            crate::builtins::common::spec::wasm_registry::submit_gpu_spec(&#spec_tokens_helper);
632        }
633    };
634    append_wasm_block(quote! {
635        #builtin_path::#helper_ident();
636    });
637    let expanded = quote! {
638        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
639        #item_const
640        #wasm_helper
641        #[cfg(not(target_arch = "wasm32"))]
642        inventory::submit! {
643            crate::builtins::common::spec::GpuSpecInventory { spec: &#spec_for_native }
644        }
645    };
646    expanded.into()
647}
648
649#[proc_macro_attribute]
650pub fn register_fusion_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
651    let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
652    let RegisterSpecAttrArgs {
653        spec_expr,
654        builtin_path,
655    } = args;
656    let item_const = parse_macro_input!(item as ItemConst);
657    let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
658        let ident = &item_const.ident;
659        quote! { #ident }
660    });
661    let spec_for_native = spec_tokens.clone();
662    let builtin_path_lit =
663        builtin_path.expect("register_fusion_spec requires `builtin_path = \"...\"` argument");
664    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
665        .expect("register_fusion_spec `builtin_path` must be a valid path");
666    let helper_ident = format_ident!(
667        "__runmat_wasm_register_fusion_spec_{}",
668        item_const.ident.to_string()
669    );
670    let spec_tokens_helper = spec_tokens.clone();
671    let wasm_helper = quote! {
672        #[cfg(target_arch = "wasm32")]
673        #[allow(non_snake_case)]
674        pub(crate) fn #helper_ident() {
675            crate::builtins::common::spec::wasm_registry::submit_fusion_spec(&#spec_tokens_helper);
676        }
677    };
678    append_wasm_block(quote! {
679        #builtin_path::#helper_ident();
680    });
681    let expanded = quote! {
682        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
683        #item_const
684        #wasm_helper
685        #[cfg(not(target_arch = "wasm32"))]
686        inventory::submit! {
687            crate::builtins::common::spec::FusionSpecInventory { spec: &#spec_for_native }
688        }
689    };
690    expanded.into()
691}
692
693fn append_wasm_block(block: proc_macro2::TokenStream) {
694    if !should_generate_wasm_registry() {
695        return;
696    }
697    let path = match wasm_registry_path() {
698        Some(p) => p,
699        None => return,
700    };
701    let _guard = wasm_registry_lock().lock().unwrap();
702    initialize_registry_file(path);
703    let mut contents = fs::read_to_string(path).expect("failed to read wasm registry file");
704    let insertion = format!("    {}\n", block);
705    if let Some(pos) = contents.rfind('}') {
706        contents.insert_str(pos, &insertion);
707    } else {
708        contents.push_str(&insertion);
709        contents.push_str("}\n");
710    }
711    fs::write(path, contents).expect("failed to update wasm registry file");
712}
713
714fn wasm_registry_path() -> Option<&'static PathBuf> {
715    WASM_REGISTRY_PATH
716        .get_or_init(workspace_registry_path)
717        .as_ref()
718}
719
720fn wasm_registry_lock() -> &'static Mutex<()> {
721    WASM_REGISTRY_LOCK.get_or_init(|| Mutex::new(()))
722}
723
724fn initialize_registry_file(path: &Path) {
725    WASM_REGISTRY_INIT.get_or_init(|| {
726        if let Some(parent) = path.parent() {
727            let _ = fs::create_dir_all(parent);
728        }
729        const HEADER: &str = "pub fn register_all() {\n}\n";
730        fs::write(path, HEADER).expect("failed to create wasm registry file");
731    });
732}
733
/// Returns `true` only when `RUNMAT_GENERATE_WASM_REGISTRY` is set to exactly `"1"`.
fn should_generate_wasm_registry() -> bool {
    // Only write to the registry file when explicitly requested. Writing on every
    // build causes an infinite rebuild loop:
    //   proc-macros modify the file → cargo detects the change → build.rs re-runs →
    //   runmat-runtime recompiles → proc-macros run again → repeat.
    // Callers (e.g. test-wasm-headless.sh) set RUNMAT_GENERATE_WASM_REGISTRY=1 for the
    // dedicated `cargo check` regeneration step. They unset it before wasm-pack test,
    // so the pre-generated file is used as-is without triggering rebuilds.
    std::env::var("RUNMAT_GENERATE_WASM_REGISTRY").is_ok_and(|flag| flag == "1")
}
746
/// Walks up from `CARGO_MANIFEST_DIR` (inclusive) to the nearest directory that
/// contains a `Cargo.lock`, and returns `<that dir>/target/runmat_wasm_registry.rs`.
///
/// Returns `None` if the variable is unset or not valid UTF-8, or if no such
/// directory exists.
///
/// NOTE(review): the `target` directory name is hard-coded, so `CARGO_TARGET_DIR`
/// is not honoured.
fn workspace_registry_path() -> Option<PathBuf> {
    let manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").ok()?);
    manifest_dir
        .ancestors()
        .find(|candidate| candidate.join("Cargo.lock").exists())
        .map(|root| root.join("target").join("runmat_wasm_registry.rs"))
}
758
759fn helper_ident_from_name(prefix: &str, name: &str) -> proc_macro2::Ident {
760    let mut sanitized = String::new();
761    for ch in name.chars() {
762        if ch.is_ascii_alphanumeric() || ch == '_' {
763            sanitized.push(ch);
764        } else {
765            sanitized.push('_');
766        }
767    }
768    format_ident!("{}{}", prefix, sanitized)
769}
770
771/// Smart type inference from Rust types to our enhanced Type enum
772fn infer_builtin_type(ty: &syn::Type) -> proc_macro2::TokenStream {
773    use syn::Type;
774
775    match ty {
776        // Basic primitive types
777        Type::Path(type_path) => {
778            if let Some(ident) = type_path.path.get_ident() {
779                match ident.to_string().as_str() {
780                    "i32" | "i64" | "isize" => quote! { runmat_builtins::Type::Int },
781                    "f32" | "f64" => quote! { runmat_builtins::Type::Num },
782                    "bool" => quote! { runmat_builtins::Type::Bool },
783                    "String" => quote! { runmat_builtins::Type::String },
784                    _ => infer_complex_type(type_path),
785                }
786            } else {
787                infer_complex_type(type_path)
788            }
789        }
790
791        // Reference types like &str, &Value, &Matrix
792        Type::Reference(type_ref) => match type_ref.elem.as_ref() {
793            Type::Path(type_path) => {
794                if let Some(ident) = type_path.path.get_ident() {
795                    match ident.to_string().as_str() {
796                        "str" => quote! { runmat_builtins::Type::String },
797                        _ => infer_builtin_type(&type_ref.elem),
798                    }
799                } else {
800                    infer_builtin_type(&type_ref.elem)
801                }
802            }
803            _ => infer_builtin_type(&type_ref.elem),
804        },
805
806        // Slice types like &[Value], &[f64]
807        Type::Slice(type_slice) => {
808            let element_type = infer_builtin_type(&type_slice.elem);
809            quote! { runmat_builtins::Type::Cell {
810                element_type: Some(Box::new(#element_type)),
811                length: None
812            } }
813        }
814
815        // Array types like [f64; N]
816        Type::Array(type_array) => {
817            let element_type = infer_builtin_type(&type_array.elem);
818            // Try to extract length if it's a literal
819            if let syn::Expr::Lit(expr_lit) = &type_array.len {
820                if let syn::Lit::Int(lit_int) = &expr_lit.lit {
821                    if let Ok(length) = lit_int.base10_parse::<usize>() {
822                        return quote! { runmat_builtins::Type::Cell {
823                            element_type: Some(Box::new(#element_type)),
824                            length: Some(#length)
825                        } };
826                    }
827                }
828            }
829            // Fallback to unknown length
830            quote! { runmat_builtins::Type::Cell {
831                element_type: Some(Box::new(#element_type)),
832                length: None
833            } }
834        }
835
836        // Generic or complex types
837        _ => quote! { runmat_builtins::Type::Unknown },
838    }
839}
840
841/// Infer types for complex path types like Result<T, E>, Option<T>, Matrix, Value
842fn infer_complex_type(type_path: &syn::TypePath) -> proc_macro2::TokenStream {
843    let path_str = quote! { #type_path }.to_string();
844
845    // Handle common patterns
846    if path_str.contains("Matrix") || path_str.contains("Tensor") {
847        quote! { runmat_builtins::Type::tensor() }
848    } else if path_str.contains("Value") {
849        quote! { runmat_builtins::Type::Unknown } // Value can be anything
850    } else if path_str.starts_with("Result") {
851        // Extract the Ok type from Result<T, E>
852        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
853            &type_path.path.segments.last().unwrap().arguments
854        {
855            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
856                return infer_builtin_type(ty);
857            }
858        }
859        quote! { runmat_builtins::Type::Unknown }
860    } else if path_str.starts_with("Option") {
861        // Extract the Some type from Option<T>
862        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
863            &type_path.path.segments.last().unwrap().arguments
864        {
865            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
866                return infer_builtin_type(ty);
867            }
868        }
869        quote! { runmat_builtins::Type::Unknown }
870    } else if path_str.starts_with("Vec") {
871        // Extract element type from Vec<T>
872        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
873            &type_path.path.segments.last().unwrap().arguments
874        {
875            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
876                let element_type = infer_builtin_type(ty);
877                return quote! { runmat_builtins::Type::Cell {
878                    element_type: Some(Box::new(#element_type)),
879                    length: None
880                } };
881            }
882        }
883        quote! { runmat_builtins::Type::cell() }
884    } else {
885        // Unknown type
886        quote! { runmat_builtins::Type::Unknown }
887    }
888}