Skip to main content

runmat_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::sync::{Mutex, OnceLock};
6use syn::parse::{Parse, ParseStream};
7use syn::{
8    parse_macro_input, AttributeArgs, Expr, FnArg, ItemConst, ItemFn, Lit, LitStr, Meta,
9    MetaNameValue, NestedMeta, Pat,
10};
11
/// Serialises the read-modify-write cycles performed on the generated wasm registry file.
static WASM_REGISTRY_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
/// Location of the generated wasm registry file, resolved once per process
/// (`None` when no enclosing directory containing a `Cargo.lock` was found).
static WASM_REGISTRY_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
/// Set once the registry file has been (re)created with its empty header in this process.
static WASM_REGISTRY_INIT: OnceLock<()> = OnceLock::new();
15
16/// Attribute used to mark functions as implementing a runtime builtin.
17///
18/// Example:
19/// ```rust,ignore
20/// use runmat_macros::runtime_builtin;
21///
22/// #[runtime_builtin(name = "plot")]
23/// pub fn plot_line(xs: &[f64]) {
24///     /* implementation */
25/// }
26/// ```
27///
28/// This registers the function with the `runmat-builtins` inventory
29/// so the runtime can discover it at start-up.
30#[proc_macro_attribute]
31pub fn runtime_builtin(args: TokenStream, input: TokenStream) -> TokenStream {
32    // Parse attribute arguments as `name = "..."`
33    let args = parse_macro_input!(args as AttributeArgs);
34    let mut name_lit: Option<Lit> = None;
35    let mut category_lit: Option<Lit> = None;
36    let mut summary_lit: Option<Lit> = None;
37    let mut keywords_lit: Option<Lit> = None;
38    let mut errors_lit: Option<Lit> = None;
39    let mut related_lit: Option<Lit> = None;
40    let mut introduced_lit: Option<Lit> = None;
41    let mut status_lit: Option<Lit> = None;
42    let mut examples_lit: Option<Lit> = None;
43    let mut accel_values: Vec<String> = Vec::new();
44    let mut builtin_path_lit: Option<LitStr> = None;
45    let mut type_resolver_path: Option<syn::Path> = None;
46    let mut type_resolver_ctx_path: Option<syn::Path> = None;
47    let mut sink_flag = false;
48    let mut suppress_auto_output_flag = false;
49    for arg in args {
50        match arg {
51            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
52                if path.is_ident("name") {
53                    name_lit = Some(lit);
54                } else if path.is_ident("category") {
55                    category_lit = Some(lit);
56                } else if path.is_ident("summary") {
57                    summary_lit = Some(lit);
58                } else if path.is_ident("keywords") {
59                    keywords_lit = Some(lit);
60                } else if path.is_ident("errors") {
61                    errors_lit = Some(lit);
62                } else if path.is_ident("related") {
63                    related_lit = Some(lit);
64                } else if path.is_ident("introduced") {
65                    introduced_lit = Some(lit);
66                } else if path.is_ident("status") {
67                    status_lit = Some(lit);
68                } else if path.is_ident("examples") {
69                    examples_lit = Some(lit);
70                } else if path.is_ident("accel") {
71                    if let Lit::Str(ls) = lit {
72                        accel_values.extend(
73                            ls.value()
74                                .split(|c: char| c == ',' || c == '|' || c.is_ascii_whitespace())
75                                .filter(|s| !s.is_empty())
76                                .map(|s| s.to_ascii_lowercase()),
77                        );
78                    }
79                } else if path.is_ident("sink") {
80                    if let Lit::Bool(lb) = lit {
81                        sink_flag = lb.value;
82                    }
83                } else if path.is_ident("suppress_auto_output") {
84                    if let Lit::Bool(lb) = lit {
85                        suppress_auto_output_flag = lb.value;
86                    }
87                } else if path.is_ident("builtin_path") {
88                    if let Lit::Str(ls) = lit {
89                        builtin_path_lit = Some(ls);
90                    } else {
91                        panic!("builtin_path must be a string literal");
92                    }
93                } else if path.is_ident("type_resolver") {
94                    if let Lit::Str(ls) = lit {
95                        let parsed: syn::Path = ls.parse().expect("type_resolver must be a path");
96                        type_resolver_path = Some(parsed);
97                    } else {
98                        panic!("type_resolver must be a string literal path");
99                    }
100                } else if path.is_ident("type_resolver_ctx") {
101                    if let Lit::Str(ls) = lit {
102                        let parsed: syn::Path =
103                            ls.parse().expect("type_resolver_ctx must be a path");
104                        type_resolver_ctx_path = Some(parsed);
105                    } else {
106                        panic!("type_resolver_ctx must be a string literal path");
107                    }
108                } else {
109                    // Gracefully ignore unknown parameters for better IDE experience
110                }
111            }
112            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver") => {
113                if list.nested.len() != 1 {
114                    panic!("type_resolver expects exactly one path argument");
115                }
116                let nested = list.nested.first().unwrap();
117                if let NestedMeta::Meta(Meta::Path(path)) = nested {
118                    type_resolver_path = Some(path.clone());
119                } else {
120                    panic!("type_resolver expects a path argument");
121                }
122            }
123            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver_ctx") => {
124                if list.nested.len() != 1 {
125                    panic!("type_resolver_ctx expects exactly one path argument");
126                }
127                let nested = list.nested.first().unwrap();
128                if let NestedMeta::Meta(Meta::Path(path)) = nested {
129                    type_resolver_ctx_path = Some(path.clone());
130                } else {
131                    panic!("type_resolver_ctx expects a path argument");
132                }
133            }
134            _ => {}
135        }
136    }
137    let name_lit = name_lit.expect("expected `name = \"...\"` argument");
138    let name_str = if let Lit::Str(ref s) = name_lit {
139        s.value()
140    } else {
141        panic!("name must be a string literal");
142    };
143
144    let func: ItemFn = parse_macro_input!(input as ItemFn);
145    let ident = &func.sig.ident;
146    let is_async = func.sig.asyncness.is_some();
147
148    // Extract param idents and types
149    let mut param_idents = Vec::new();
150    let mut param_types = Vec::new();
151    for arg in &func.sig.inputs {
152        match arg {
153            FnArg::Typed(pt) => {
154                // pattern must be ident
155                if let Pat::Ident(pi) = pt.pat.as_ref() {
156                    param_idents.push(pi.ident.clone());
157                } else {
158                    panic!("parameters must be simple identifiers");
159                }
160                param_types.push((*pt.ty).clone());
161            }
162            _ => panic!("self parameter not allowed"),
163        }
164    }
165    let param_len = param_idents.len();
166
167    // Infer parameter types for BuiltinFunction
168    let inferred_param_types: Vec<proc_macro2::TokenStream> =
169        param_types.iter().map(infer_builtin_type).collect();
170
171    // Infer return type for BuiltinFunction
172    let inferred_return_type = match &func.sig.output {
173        syn::ReturnType::Default => quote! { runmat_builtins::Type::Void },
174        syn::ReturnType::Type(_, ty) => infer_builtin_type(ty),
175    };
176
177    // Detect if last parameter is variadic Vec<Value>
178    let is_last_variadic = param_types
179        .last()
180        .map(|ty| {
181            // crude detection: type path starts with Vec and inner type is runmat_builtins::Value or Value
182            if let syn::Type::Path(tp) = ty {
183                if tp
184                    .path
185                    .segments
186                    .last()
187                    .map(|s| s.ident == "Vec")
188                    .unwrap_or(false)
189                {
190                    if let syn::PathArguments::AngleBracketed(ab) =
191                        &tp.path.segments.last().unwrap().arguments
192                    {
193                        if let Some(syn::GenericArgument::Type(syn::Type::Path(inner))) =
194                            ab.args.first()
195                        {
196                            return inner
197                                .path
198                                .segments
199                                .last()
200                                .map(|s| s.ident == "Value")
201                                .unwrap_or(false);
202                        }
203                    }
204                }
205            }
206            false
207        })
208        .unwrap_or(false);
209
210    // Generate wrapper ident
211    let wrapper_ident = format_ident!("__rt_wrap_{}", ident);
212
213    let conv_stmts: Vec<proc_macro2::TokenStream> = if is_last_variadic && param_len > 0 {
214        let mut stmts = Vec::new();
215        // Convert fixed params (all but last)
216        for (i, (ident, ty)) in param_idents
217            .iter()
218            .zip(param_types.iter())
219            .enumerate()
220            .take(param_len - 1)
221        {
222            stmts.push(quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; });
223        }
224        // Collect the rest into Vec<Value>
225        let last_ident = &param_idents[param_len - 1];
226        stmts.push(quote! {
227            let #last_ident : Vec<runmat_builtins::Value> = {
228                let mut v = Vec::new();
229                for j in (#param_len-1)..args.len() {
230                    let item : runmat_builtins::Value = std::convert::TryInto::try_into(&args[j])?;
231                    v.push(item);
232                }
233                v
234            };
235        });
236        stmts
237    } else {
238        param_idents
239            .iter()
240            .zip(param_types.iter())
241            .enumerate()
242            .map(|(i, (ident, ty))| {
243                quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; }
244            })
245            .collect()
246    };
247
248    let call_expr = if is_async {
249        quote! { #ident(#(#param_idents),*).await? }
250    } else {
251        quote! { #ident(#(#param_idents),*)? }
252    };
253
254    let wrapper = quote! {
255        fn #wrapper_ident(args: &[runmat_builtins::Value]) -> runmat_builtins::BuiltinFuture {
256            #![allow(unused_variables)]
257            let args = args.to_vec();
258            Box::pin(async move {
259                if #is_last_variadic {
260                    if args.len() < #param_len - 1 {
261                        return Err(std::convert::From::from(format!(
262                            "expected at least {} args, got {}",
263                            #param_len - 1,
264                            args.len()
265                        )));
266                    }
267                } else if args.len() != #param_len {
268                    return Err(std::convert::From::from(format!(
269                        "expected {} args, got {}",
270                        #param_len,
271                        args.len()
272                    )));
273                }
274                #(#conv_stmts)*
275                let value = #call_expr;
276                Ok(runmat_builtins::Value::from(value))
277            })
278        }
279    };
280
281    // Prepare tokens for defaults and options
282    let default_category = syn::LitStr::new("general", proc_macro2::Span::call_site());
283    let default_summary =
284        syn::LitStr::new("Runtime builtin function", proc_macro2::Span::call_site());
285
286    let category_tok: proc_macro2::TokenStream = match &category_lit {
287        Some(syn::Lit::Str(ls)) => quote! { #ls },
288        _ => quote! { #default_category },
289    };
290    let summary_tok: proc_macro2::TokenStream = match &summary_lit {
291        Some(syn::Lit::Str(ls)) => quote! { #ls },
292        _ => quote! { #default_summary },
293    };
294
295    fn opt_tok(lit: &Option<syn::Lit>) -> proc_macro2::TokenStream {
296        if let Some(syn::Lit::Str(ls)) = lit {
297            quote! { Some(#ls) }
298        } else {
299            quote! { None }
300        }
301    }
302    let category_opt_tok = opt_tok(&category_lit);
303    let summary_opt_tok = opt_tok(&summary_lit);
304    let keywords_opt_tok = opt_tok(&keywords_lit);
305    let errors_opt_tok = opt_tok(&errors_lit);
306    let related_opt_tok = opt_tok(&related_lit);
307    let introduced_opt_tok = opt_tok(&introduced_lit);
308    let status_opt_tok = opt_tok(&status_lit);
309    let examples_opt_tok = opt_tok(&examples_lit);
310
311    let accel_tokens: Vec<proc_macro2::TokenStream> = accel_values
312        .iter()
313        .map(|mode| match mode.as_str() {
314            "unary" => quote! { runmat_builtins::AccelTag::Unary },
315            "elementwise" => quote! { runmat_builtins::AccelTag::Elementwise },
316            "reduction" => quote! { runmat_builtins::AccelTag::Reduction },
317            "matmul" => quote! { runmat_builtins::AccelTag::MatMul },
318            "transpose" => quote! { runmat_builtins::AccelTag::Transpose },
319            "array_construct" => quote! { runmat_builtins::AccelTag::ArrayConstruct },
320            _ => quote! {},
321        })
322        .filter(|ts| !ts.is_empty())
323        .collect();
324    let accel_slice = if accel_tokens.is_empty() {
325        quote! { &[] as &[runmat_builtins::AccelTag] }
326    } else {
327        quote! { &[#(#accel_tokens),*] }
328    };
329    let type_resolver_expr = if let Some(path) = type_resolver_ctx_path.as_ref() {
330        quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
331    } else if let Some(path) = type_resolver_path.as_ref() {
332        quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
333    } else {
334        quote! { None }
335    };
336    let sink_bool = sink_flag;
337    let suppress_auto_output_bool = suppress_auto_output_flag;
338
339    let builtin_expr = quote! {
340        runmat_builtins::BuiltinFunction::new(
341            #name_str,
342            #summary_tok,
343            #category_tok,
344            "",
345            "",
346            vec![#(#inferred_param_types),*],
347            #inferred_return_type,
348            #type_resolver_expr,
349            #wrapper_ident,
350            #accel_slice,
351            #sink_bool,
352            #suppress_auto_output_bool,
353        )
354    };
355
356    let doc_expr = quote! {
357        runmat_builtins::BuiltinDoc {
358            name: #name_str,
359            category: #category_opt_tok,
360            summary: #summary_opt_tok,
361            keywords: #keywords_opt_tok,
362            errors: #errors_opt_tok,
363            related: #related_opt_tok,
364            introduced: #introduced_opt_tok,
365            status: #status_opt_tok,
366            examples: #examples_opt_tok,
367        }
368    };
369
370    let builtin_path_lit =
371        builtin_path_lit.expect("runtime_builtin requires `builtin_path = \"...\"`");
372    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
373        .expect("runtime_builtin `builtin_path` must be a valid path");
374    let helper_ident = format_ident!("__runmat_wasm_register_builtin_{}", ident);
375    let builtin_expr_helper = builtin_expr.clone();
376    let doc_expr_helper = doc_expr.clone();
377    let wasm_helper = quote! {
378        #[cfg(target_arch = "wasm32")]
379        #[allow(non_snake_case)]
380        pub(crate) fn #helper_ident() {
381            runmat_builtins::wasm_registry::submit_builtin_function(#builtin_expr_helper);
382            runmat_builtins::wasm_registry::submit_builtin_doc(#doc_expr_helper);
383        }
384    };
385    let register_native = quote! {
386        #[cfg(not(target_arch = "wasm32"))]
387        runmat_builtins::inventory::submit! { #builtin_expr }
388        #[cfg(not(target_arch = "wasm32"))]
389        runmat_builtins::inventory::submit! { #doc_expr }
390    };
391    append_wasm_block(quote! {
392        #builtin_path::#helper_ident();
393    });
394
395    TokenStream::from(quote! {
396        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
397        #func
398        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
399        #wrapper
400        #wasm_helper
401        #register_native
402    })
403}
404
405/// Attribute used to declare a runtime constant.
406///
407/// Example:
408/// ```rust,ignore
409/// use runmat_macros::runtime_constant;
410/// use runmat_builtins::Value;
411///
412/// #[runtime_constant(name = "pi", value = std::f64::consts::PI)]
413/// const PI_CONSTANT: ();
414/// ```
415///
416/// This registers the constant with the `runmat-builtins` inventory
417/// so the runtime can discover it at start-up.
418#[proc_macro_attribute]
419pub fn runtime_constant(args: TokenStream, input: TokenStream) -> TokenStream {
420    let args = parse_macro_input!(args as AttributeArgs);
421    let mut name_lit: Option<Lit> = None;
422    let mut value_expr: Option<Expr> = None;
423    let mut builtin_path_lit: Option<LitStr> = None;
424
425    for arg in args {
426        match arg {
427            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
428                if path.is_ident("name") {
429                    name_lit = Some(lit);
430                } else if path.is_ident("builtin_path") {
431                    if let Lit::Str(ls) = lit {
432                        builtin_path_lit = Some(ls);
433                    } else {
434                        panic!("builtin_path must be a string literal");
435                    }
436                } else {
437                    panic!("Unknown attribute parameter: {}", quote!(#path));
438                }
439            }
440            NestedMeta::Meta(Meta::Path(path)) if path.is_ident("value") => {
441                panic!("value parameter requires assignment: value = expression");
442            }
443            NestedMeta::Lit(lit) => {
444                // This handles the case where value is provided as a literal
445                value_expr = Some(syn::parse_quote!(#lit));
446            }
447            _ => panic!("Invalid attribute syntax"),
448        }
449    }
450
451    let name = match name_lit {
452        Some(Lit::Str(s)) => s.value(),
453        _ => panic!("name parameter must be a string literal"),
454    };
455
456    let value = value_expr.unwrap_or_else(|| {
457        panic!("value parameter is required");
458    });
459
460    let builtin_path_lit =
461        builtin_path_lit.expect("runtime_constant requires `builtin_path = \"...\"` argument");
462    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
463        .expect("runtime_constant `builtin_path` must be a valid path");
464    let item = parse_macro_input!(input as syn::Item);
465
466    let constant_expr = quote! {
467        runmat_builtins::Constant {
468            name: #name,
469            value: #value,
470        }
471    };
472
473    let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name);
474    let constant_expr_helper = constant_expr.clone();
475    let wasm_helper = quote! {
476        #[cfg(target_arch = "wasm32")]
477        #[allow(non_snake_case)]
478        pub(crate) fn #helper_ident() {
479            runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
480        }
481    };
482    let register_native = quote! {
483        #[cfg(not(target_arch = "wasm32"))]
484        #[allow(non_upper_case_globals)]
485        runmat_builtins::inventory::submit! { #constant_expr }
486    };
487    append_wasm_block(quote! {
488        #builtin_path::#helper_ident();
489    });
490
491    TokenStream::from(quote! {
492        #item
493        #wasm_helper
494        #register_native
495    })
496}
497
498struct RegisterConstantArgs {
499    name: LitStr,
500    value: Expr,
501    builtin_path: LitStr,
502}
503
504impl syn::parse::Parse for RegisterConstantArgs {
505    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
506        let name: LitStr = input.parse()?;
507        input.parse::<syn::Token![,]>()?;
508        let value: Expr = input.parse()?;
509        input.parse::<syn::Token![,]>()?;
510        let builtin_path: LitStr = input.parse()?;
511        if input.peek(syn::Token![,]) {
512            input.parse::<syn::Token![,]>()?;
513        }
514        Ok(RegisterConstantArgs {
515            name,
516            value,
517            builtin_path,
518        })
519    }
520}
521
522#[proc_macro]
523pub fn register_constant(input: TokenStream) -> TokenStream {
524    let RegisterConstantArgs {
525        name,
526        value,
527        builtin_path,
528    } = parse_macro_input!(input as RegisterConstantArgs);
529    let constant_expr = quote! {
530        runmat_builtins::Constant {
531            name: #name,
532            value: #value,
533        }
534    };
535    let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name.value());
536    let builtin_path: syn::Path = syn::parse_str(&builtin_path.value())
537        .expect("register_constant `builtin_path` must be a valid path");
538    let constant_expr_helper = constant_expr.clone();
539    let wasm_helper = quote! {
540        #[cfg(target_arch = "wasm32")]
541        #[allow(non_snake_case)]
542        pub(crate) fn #helper_ident() {
543            runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
544        }
545    };
546    append_wasm_block(quote! {
547        #builtin_path::#helper_ident();
548    });
549    TokenStream::from(quote! {
550        #wasm_helper
551        #[cfg(not(target_arch = "wasm32"))]
552        runmat_builtins::inventory::submit! { #constant_expr }
553    })
554}
555
556struct RegisterSpecAttrArgs {
557    spec_expr: Option<Expr>,
558    builtin_path: Option<LitStr>,
559}
560
561impl Parse for RegisterSpecAttrArgs {
562    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
563        let mut spec_expr = None;
564        let mut builtin_path = None;
565        while !input.is_empty() {
566            let ident: syn::Ident = input.parse()?;
567            input.parse::<syn::Token![=]>()?;
568            if ident == "spec" {
569                spec_expr = Some(input.parse()?);
570            } else if ident == "builtin_path" {
571                let lit: LitStr = input.parse()?;
572                builtin_path = Some(lit);
573            } else {
574                return Err(syn::Error::new(ident.span(), "unknown attribute argument"));
575            }
576            if input.peek(syn::Token![,]) {
577                input.parse::<syn::Token![,]>()?;
578            }
579        }
580        Ok(Self {
581            spec_expr,
582            builtin_path,
583        })
584    }
585}
586
587#[proc_macro_attribute]
588pub fn register_gpu_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
589    let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
590    let RegisterSpecAttrArgs {
591        spec_expr,
592        builtin_path,
593    } = args;
594    let item_const = parse_macro_input!(item as ItemConst);
595    let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
596        let ident = &item_const.ident;
597        quote! { #ident }
598    });
599    let spec_for_native = spec_tokens.clone();
600    let builtin_path_lit =
601        builtin_path.expect("register_gpu_spec requires `builtin_path = \"...\"` argument");
602    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
603        .expect("register_gpu_spec `builtin_path` must be a valid path");
604    let helper_ident = format_ident!(
605        "__runmat_wasm_register_gpu_spec_{}",
606        item_const.ident.to_string()
607    );
608    let spec_tokens_helper = spec_tokens.clone();
609    let wasm_helper = quote! {
610        #[cfg(target_arch = "wasm32")]
611        #[allow(non_snake_case)]
612        pub(crate) fn #helper_ident() {
613            crate::builtins::common::spec::wasm_registry::submit_gpu_spec(&#spec_tokens_helper);
614        }
615    };
616    append_wasm_block(quote! {
617        #builtin_path::#helper_ident();
618    });
619    let expanded = quote! {
620        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
621        #item_const
622        #wasm_helper
623        #[cfg(not(target_arch = "wasm32"))]
624        inventory::submit! {
625            crate::builtins::common::spec::GpuSpecInventory { spec: &#spec_for_native }
626        }
627    };
628    expanded.into()
629}
630
631#[proc_macro_attribute]
632pub fn register_fusion_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
633    let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
634    let RegisterSpecAttrArgs {
635        spec_expr,
636        builtin_path,
637    } = args;
638    let item_const = parse_macro_input!(item as ItemConst);
639    let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
640        let ident = &item_const.ident;
641        quote! { #ident }
642    });
643    let spec_for_native = spec_tokens.clone();
644    let builtin_path_lit =
645        builtin_path.expect("register_fusion_spec requires `builtin_path = \"...\"` argument");
646    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
647        .expect("register_fusion_spec `builtin_path` must be a valid path");
648    let helper_ident = format_ident!(
649        "__runmat_wasm_register_fusion_spec_{}",
650        item_const.ident.to_string()
651    );
652    let spec_tokens_helper = spec_tokens.clone();
653    let wasm_helper = quote! {
654        #[cfg(target_arch = "wasm32")]
655        #[allow(non_snake_case)]
656        pub(crate) fn #helper_ident() {
657            crate::builtins::common::spec::wasm_registry::submit_fusion_spec(&#spec_tokens_helper);
658        }
659    };
660    append_wasm_block(quote! {
661        #builtin_path::#helper_ident();
662    });
663    let expanded = quote! {
664        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
665        #item_const
666        #wasm_helper
667        #[cfg(not(target_arch = "wasm32"))]
668        inventory::submit! {
669            crate::builtins::common::spec::FusionSpecInventory { spec: &#spec_for_native }
670        }
671    };
672    expanded.into()
673}
674
675fn append_wasm_block(block: proc_macro2::TokenStream) {
676    if !should_generate_wasm_registry() {
677        return;
678    }
679    let path = match wasm_registry_path() {
680        Some(p) => p,
681        None => return,
682    };
683    let _guard = wasm_registry_lock().lock().unwrap();
684    initialize_registry_file(path);
685    let mut contents = fs::read_to_string(path).expect("failed to read wasm registry file");
686    let insertion = format!("    {}\n", block);
687    if let Some(pos) = contents.rfind('}') {
688        contents.insert_str(pos, &insertion);
689    } else {
690        contents.push_str(&insertion);
691        contents.push_str("}\n");
692    }
693    fs::write(path, contents).expect("failed to update wasm registry file");
694}
695
696fn wasm_registry_path() -> Option<&'static PathBuf> {
697    WASM_REGISTRY_PATH
698        .get_or_init(workspace_registry_path)
699        .as_ref()
700}
701
702fn wasm_registry_lock() -> &'static Mutex<()> {
703    WASM_REGISTRY_LOCK.get_or_init(|| Mutex::new(()))
704}
705
706fn initialize_registry_file(path: &Path) {
707    WASM_REGISTRY_INIT.get_or_init(|| {
708        if let Some(parent) = path.parent() {
709            let _ = fs::create_dir_all(parent);
710        }
711        const HEADER: &str = "pub fn register_all() {\n}\n";
712        fs::write(path, HEADER).expect("failed to create wasm registry file");
713    });
714}
715
/// Returns `true` only when `RUNMAT_GENERATE_WASM_REGISTRY` is set to exactly `"1"`.
///
/// The registry must not be written on every build, because that causes an infinite rebuild
/// loop:
/// 1. proc-macros modify the file;
/// 2. cargo detects the change and build.rs re-runs;
/// 3. runmat-runtime recompiles and the proc-macros run again;
/// 4. repeat.
///
/// Callers (e.g. test-wasm-headless.sh) therefore set the variable only for the dedicated
/// `cargo check` regeneration step. They unset it before `wasm-pack test`, so the pre-generated
/// file is used as-is without triggering rebuilds.
fn should_generate_wasm_registry() -> bool {
    std::env::var("RUNMAT_GENERATE_WASM_REGISTRY").as_deref() == Ok("1")
}
728
/// Locates `<workspace>/target/runmat_wasm_registry.rs`.
///
/// The workspace is the nearest directory, starting at `CARGO_MANIFEST_DIR` and walking
/// upwards, that contains a `Cargo.lock`. Returns `None` when the variable is unset (or not
/// UTF-8) or when no such directory exists.
fn workspace_registry_path() -> Option<PathBuf> {
    let manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").ok()?);
    manifest_dir
        .ancestors()
        .find(|dir| dir.join("Cargo.lock").exists())
        .map(|root| root.join("target").join("runmat_wasm_registry.rs"))
}
740
741fn helper_ident_from_name(prefix: &str, name: &str) -> proc_macro2::Ident {
742    let mut sanitized = String::new();
743    for ch in name.chars() {
744        if ch.is_ascii_alphanumeric() || ch == '_' {
745            sanitized.push(ch);
746        } else {
747            sanitized.push('_');
748        }
749    }
750    format_ident!("{}{}", prefix, sanitized)
751}
752
753/// Smart type inference from Rust types to our enhanced Type enum
754fn infer_builtin_type(ty: &syn::Type) -> proc_macro2::TokenStream {
755    use syn::Type;
756
757    match ty {
758        // Basic primitive types
759        Type::Path(type_path) => {
760            if let Some(ident) = type_path.path.get_ident() {
761                match ident.to_string().as_str() {
762                    "i32" | "i64" | "isize" => quote! { runmat_builtins::Type::Int },
763                    "f32" | "f64" => quote! { runmat_builtins::Type::Num },
764                    "bool" => quote! { runmat_builtins::Type::Bool },
765                    "String" => quote! { runmat_builtins::Type::String },
766                    _ => infer_complex_type(type_path),
767                }
768            } else {
769                infer_complex_type(type_path)
770            }
771        }
772
773        // Reference types like &str, &Value, &Matrix
774        Type::Reference(type_ref) => match type_ref.elem.as_ref() {
775            Type::Path(type_path) => {
776                if let Some(ident) = type_path.path.get_ident() {
777                    match ident.to_string().as_str() {
778                        "str" => quote! { runmat_builtins::Type::String },
779                        _ => infer_builtin_type(&type_ref.elem),
780                    }
781                } else {
782                    infer_builtin_type(&type_ref.elem)
783                }
784            }
785            _ => infer_builtin_type(&type_ref.elem),
786        },
787
788        // Slice types like &[Value], &[f64]
789        Type::Slice(type_slice) => {
790            let element_type = infer_builtin_type(&type_slice.elem);
791            quote! { runmat_builtins::Type::Cell {
792                element_type: Some(Box::new(#element_type)),
793                length: None
794            } }
795        }
796
797        // Array types like [f64; N]
798        Type::Array(type_array) => {
799            let element_type = infer_builtin_type(&type_array.elem);
800            // Try to extract length if it's a literal
801            if let syn::Expr::Lit(expr_lit) = &type_array.len {
802                if let syn::Lit::Int(lit_int) = &expr_lit.lit {
803                    if let Ok(length) = lit_int.base10_parse::<usize>() {
804                        return quote! { runmat_builtins::Type::Cell {
805                            element_type: Some(Box::new(#element_type)),
806                            length: Some(#length)
807                        } };
808                    }
809                }
810            }
811            // Fallback to unknown length
812            quote! { runmat_builtins::Type::Cell {
813                element_type: Some(Box::new(#element_type)),
814                length: None
815            } }
816        }
817
818        // Generic or complex types
819        _ => quote! { runmat_builtins::Type::Unknown },
820    }
821}
822
823/// Infer types for complex path types like Result<T, E>, Option<T>, Matrix, Value
824fn infer_complex_type(type_path: &syn::TypePath) -> proc_macro2::TokenStream {
825    let path_str = quote! { #type_path }.to_string();
826
827    // Handle common patterns
828    if path_str.contains("Matrix") || path_str.contains("Tensor") {
829        quote! { runmat_builtins::Type::tensor() }
830    } else if path_str.contains("Value") {
831        quote! { runmat_builtins::Type::Unknown } // Value can be anything
832    } else if path_str.starts_with("Result") {
833        // Extract the Ok type from Result<T, E>
834        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
835            &type_path.path.segments.last().unwrap().arguments
836        {
837            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
838                return infer_builtin_type(ty);
839            }
840        }
841        quote! { runmat_builtins::Type::Unknown }
842    } else if path_str.starts_with("Option") {
843        // Extract the Some type from Option<T>
844        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
845            &type_path.path.segments.last().unwrap().arguments
846        {
847            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
848                return infer_builtin_type(ty);
849            }
850        }
851        quote! { runmat_builtins::Type::Unknown }
852    } else if path_str.starts_with("Vec") {
853        // Extract element type from Vec<T>
854        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
855            &type_path.path.segments.last().unwrap().arguments
856        {
857            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
858                let element_type = infer_builtin_type(ty);
859                return quote! { runmat_builtins::Type::Cell {
860                    element_type: Some(Box::new(#element_type)),
861                    length: None
862                } };
863            }
864        }
865        quote! { runmat_builtins::Type::cell() }
866    } else {
867        // Unknown type
868        quote! { runmat_builtins::Type::Unknown }
869    }
870}