Skip to main content

runmat_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::sync::{Mutex, OnceLock};
6use syn::parse::{Parse, ParseStream};
7use syn::{
8    parse_macro_input, AttributeArgs, Expr, FnArg, ItemConst, ItemFn, Lit, LitStr, Meta,
9    MetaNameValue, NestedMeta, Pat,
10};
11
12static WASM_REGISTRY_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
13static WASM_REGISTRY_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
14static WASM_REGISTRY_INIT: OnceLock<()> = OnceLock::new();
15
16/// Attribute used to mark functions as implementing a runtime builtin.
17///
18/// Example:
19/// ```rust,ignore
20/// use runmat_macros::runtime_builtin;
21///
22/// #[runtime_builtin(name = "plot")]
23/// pub fn plot_line(xs: &[f64]) {
24///     /* implementation */
25/// }
26/// ```
27///
28/// This registers the function with the `runmat-builtins` inventory
29/// so the runtime can discover it at start-up.
30#[proc_macro_attribute]
31pub fn runtime_builtin(args: TokenStream, input: TokenStream) -> TokenStream {
32    // Parse attribute arguments as `name = "..."`
33    let args = parse_macro_input!(args as AttributeArgs);
34    let mut name_lit: Option<Lit> = None;
35    let mut category_lit: Option<Lit> = None;
36    let mut summary_lit: Option<Lit> = None;
37    let mut keywords_lit: Option<Lit> = None;
38    let mut errors_lit: Option<Lit> = None;
39    let mut related_lit: Option<Lit> = None;
40    let mut introduced_lit: Option<Lit> = None;
41    let mut status_lit: Option<Lit> = None;
42    let mut examples_lit: Option<Lit> = None;
43    let mut accel_values: Vec<String> = Vec::new();
44    let mut builtin_path_lit: Option<LitStr> = None;
45    let mut type_resolver_path: Option<syn::Path> = None;
46    let mut type_resolver_ctx_path: Option<syn::Path> = None;
47    let mut sink_flag = false;
48    let mut suppress_auto_output_flag = false;
49    for arg in args {
50        match arg {
51            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
52                if path.is_ident("name") {
53                    name_lit = Some(lit);
54                } else if path.is_ident("category") {
55                    category_lit = Some(lit);
56                } else if path.is_ident("summary") {
57                    summary_lit = Some(lit);
58                } else if path.is_ident("keywords") {
59                    keywords_lit = Some(lit);
60                } else if path.is_ident("errors") {
61                    errors_lit = Some(lit);
62                } else if path.is_ident("related") {
63                    related_lit = Some(lit);
64                } else if path.is_ident("introduced") {
65                    introduced_lit = Some(lit);
66                } else if path.is_ident("status") {
67                    status_lit = Some(lit);
68                } else if path.is_ident("examples") {
69                    examples_lit = Some(lit);
70                } else if path.is_ident("accel") {
71                    if let Lit::Str(ls) = lit {
72                        accel_values.extend(
73                            ls.value()
74                                .split(|c: char| c == ',' || c == '|' || c.is_ascii_whitespace())
75                                .filter(|s| !s.is_empty())
76                                .map(|s| s.to_ascii_lowercase()),
77                        );
78                    }
79                } else if path.is_ident("sink") {
80                    if let Lit::Bool(lb) = lit {
81                        sink_flag = lb.value;
82                    }
83                } else if path.is_ident("suppress_auto_output") {
84                    if let Lit::Bool(lb) = lit {
85                        suppress_auto_output_flag = lb.value;
86                    }
87                } else if path.is_ident("builtin_path") {
88                    if let Lit::Str(ls) = lit {
89                        builtin_path_lit = Some(ls);
90                    } else {
91                        panic!("builtin_path must be a string literal");
92                    }
93                } else if path.is_ident("type_resolver") {
94                    if let Lit::Str(ls) = lit {
95                        let parsed: syn::Path = ls.parse().expect("type_resolver must be a path");
96                        type_resolver_path = Some(parsed);
97                    } else {
98                        panic!("type_resolver must be a string literal path");
99                    }
100                } else if path.is_ident("type_resolver_ctx") {
101                    if let Lit::Str(ls) = lit {
102                        let parsed: syn::Path =
103                            ls.parse().expect("type_resolver_ctx must be a path");
104                        type_resolver_ctx_path = Some(parsed);
105                    } else {
106                        panic!("type_resolver_ctx must be a string literal path");
107                    }
108                } else {
109                    // Gracefully ignore unknown parameters for better IDE experience
110                }
111            }
112            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver") => {
113                if list.nested.len() != 1 {
114                    panic!("type_resolver expects exactly one path argument");
115                }
116                let nested = list.nested.first().unwrap();
117                if let NestedMeta::Meta(Meta::Path(path)) = nested {
118                    type_resolver_path = Some(path.clone());
119                } else {
120                    panic!("type_resolver expects a path argument");
121                }
122            }
123            NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver_ctx") => {
124                if list.nested.len() != 1 {
125                    panic!("type_resolver_ctx expects exactly one path argument");
126                }
127                let nested = list.nested.first().unwrap();
128                if let NestedMeta::Meta(Meta::Path(path)) = nested {
129                    type_resolver_ctx_path = Some(path.clone());
130                } else {
131                    panic!("type_resolver_ctx expects a path argument");
132                }
133            }
134            _ => {}
135        }
136    }
137    let name_lit = name_lit.expect("expected `name = \"...\"` argument");
138    let name_str = if let Lit::Str(ref s) = name_lit {
139        s.value()
140    } else {
141        panic!("name must be a string literal");
142    };
143
144    let func: ItemFn = parse_macro_input!(input as ItemFn);
145    let ident = &func.sig.ident;
146    let is_async = func.sig.asyncness.is_some();
147
148    // Extract param idents and types
149    let mut param_idents = Vec::new();
150    let mut param_types = Vec::new();
151    for arg in &func.sig.inputs {
152        match arg {
153            FnArg::Typed(pt) => {
154                // pattern must be ident
155                if let Pat::Ident(pi) = pt.pat.as_ref() {
156                    param_idents.push(pi.ident.clone());
157                } else {
158                    panic!("parameters must be simple identifiers");
159                }
160                param_types.push((*pt.ty).clone());
161            }
162            _ => panic!("self parameter not allowed"),
163        }
164    }
165    let param_len = param_idents.len();
166
167    // Infer parameter types for BuiltinFunction
168    let inferred_param_types: Vec<proc_macro2::TokenStream> =
169        param_types.iter().map(infer_builtin_type).collect();
170
171    // Infer return type for BuiltinFunction
172    let inferred_return_type = match &func.sig.output {
173        syn::ReturnType::Default => quote! { runmat_builtins::Type::Void },
174        syn::ReturnType::Type(_, ty) => infer_builtin_type(ty),
175    };
176
177    // Detect if last parameter is variadic Vec<Value>
178    let is_last_variadic = param_types
179        .last()
180        .map(|ty| {
181            // crude detection: type path starts with Vec and inner type is runmat_builtins::Value or Value
182            if let syn::Type::Path(tp) = ty {
183                if tp
184                    .path
185                    .segments
186                    .last()
187                    .map(|s| s.ident == "Vec")
188                    .unwrap_or(false)
189                {
190                    if let syn::PathArguments::AngleBracketed(ab) =
191                        &tp.path.segments.last().unwrap().arguments
192                    {
193                        if let Some(syn::GenericArgument::Type(syn::Type::Path(inner))) =
194                            ab.args.first()
195                        {
196                            return inner
197                                .path
198                                .segments
199                                .last()
200                                .map(|s| s.ident == "Value")
201                                .unwrap_or(false);
202                        }
203                    }
204                }
205            }
206            false
207        })
208        .unwrap_or(false);
209
210    // Generate wrapper ident
211    let wrapper_ident = format_ident!("__rt_wrap_{}", ident);
212
213    let conv_stmts: Vec<proc_macro2::TokenStream> = if is_last_variadic && param_len > 0 {
214        let mut stmts = Vec::new();
215        // Convert fixed params (all but last)
216        for (i, (ident, ty)) in param_idents
217            .iter()
218            .zip(param_types.iter())
219            .enumerate()
220            .take(param_len - 1)
221        {
222            stmts.push(quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; });
223        }
224        // Collect the rest into Vec<Value>
225        let last_ident = &param_idents[param_len - 1];
226        stmts.push(quote! {
227            let #last_ident : Vec<runmat_builtins::Value> = {
228                let mut v = Vec::new();
229                for j in (#param_len-1)..args.len() {
230                    let item : runmat_builtins::Value = std::convert::TryInto::try_into(&args[j])?;
231                    v.push(item);
232                }
233                v
234            };
235        });
236        stmts
237    } else {
238        param_idents
239            .iter()
240            .zip(param_types.iter())
241            .enumerate()
242            .map(|(i, (ident, ty))| {
243                quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; }
244            })
245            .collect()
246    };
247
248    let call_expr = if is_async {
249        quote! { #ident(#(#param_idents),*).await? }
250    } else {
251        quote! { #ident(#(#param_idents),*)? }
252    };
253
254    let wrapper = quote! {
255        fn #wrapper_ident(args: &[runmat_builtins::Value]) -> runmat_builtins::BuiltinFuture {
256            #![allow(unused_variables)]
257            let args = args.to_vec();
258            Box::pin(async move {
259                if #is_last_variadic {
260                    if args.len() < #param_len - 1 {
261                        return Err(std::convert::From::from(format!(
262                            "expected at least {} args, got {}",
263                            #param_len - 1,
264                            args.len()
265                        )));
266                    }
267                } else if args.len() != #param_len {
268                    return Err(std::convert::From::from(format!(
269                        "expected {} args, got {}",
270                        #param_len,
271                        args.len()
272                    )));
273                }
274                #(#conv_stmts)*
275                let value = #call_expr;
276                Ok(runmat_builtins::Value::from(value))
277            })
278        }
279    };
280
281    // Prepare tokens for defaults and options
282    let default_category = syn::LitStr::new("general", proc_macro2::Span::call_site());
283    let default_summary =
284        syn::LitStr::new("Runtime builtin function", proc_macro2::Span::call_site());
285
286    let category_tok: proc_macro2::TokenStream = match &category_lit {
287        Some(syn::Lit::Str(ls)) => quote! { #ls },
288        _ => quote! { #default_category },
289    };
290    let summary_tok: proc_macro2::TokenStream = match &summary_lit {
291        Some(syn::Lit::Str(ls)) => quote! { #ls },
292        _ => quote! { #default_summary },
293    };
294
295    fn opt_tok(lit: &Option<syn::Lit>) -> proc_macro2::TokenStream {
296        if let Some(syn::Lit::Str(ls)) = lit {
297            quote! { Some(#ls) }
298        } else {
299            quote! { None }
300        }
301    }
302    let category_opt_tok = opt_tok(&category_lit);
303    let summary_opt_tok = opt_tok(&summary_lit);
304    let keywords_opt_tok = opt_tok(&keywords_lit);
305    let errors_opt_tok = opt_tok(&errors_lit);
306    let related_opt_tok = opt_tok(&related_lit);
307    let introduced_opt_tok = opt_tok(&introduced_lit);
308    let status_opt_tok = opt_tok(&status_lit);
309    let examples_opt_tok = opt_tok(&examples_lit);
310
311    let accel_tokens: Vec<proc_macro2::TokenStream> = accel_values
312        .iter()
313        .map(|mode| match mode.as_str() {
314            "unary" => quote! { runmat_builtins::AccelTag::Unary },
315            "binary" => quote! { runmat_builtins::AccelTag::Elementwise },
316            "elementwise" => quote! { runmat_builtins::AccelTag::Elementwise },
317            "reduction" => quote! { runmat_builtins::AccelTag::Reduction },
318            "matmul" => quote! { runmat_builtins::AccelTag::MatMul },
319            "transpose" => quote! { runmat_builtins::AccelTag::Transpose },
320            "array_construct" => quote! { runmat_builtins::AccelTag::ArrayConstruct },
321            _ => quote! {},
322        })
323        .filter(|ts| !ts.is_empty())
324        .collect();
325    let accel_slice = if accel_tokens.is_empty() {
326        quote! { &[] as &[runmat_builtins::AccelTag] }
327    } else {
328        quote! { &[#(#accel_tokens),*] }
329    };
330    let type_resolver_expr = if let Some(path) = type_resolver_ctx_path.as_ref() {
331        quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
332    } else if let Some(path) = type_resolver_path.as_ref() {
333        quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
334    } else {
335        quote! { None }
336    };
337    let sink_bool = sink_flag;
338    let suppress_auto_output_bool = suppress_auto_output_flag;
339
340    let builtin_expr = quote! {
341        runmat_builtins::BuiltinFunction::new(
342            #name_str,
343            #summary_tok,
344            #category_tok,
345            "",
346            "",
347            vec![#(#inferred_param_types),*],
348            #inferred_return_type,
349            #type_resolver_expr,
350            #wrapper_ident,
351            #accel_slice,
352            #sink_bool,
353            #suppress_auto_output_bool,
354        )
355    };
356
357    let doc_expr = quote! {
358        runmat_builtins::BuiltinDoc {
359            name: #name_str,
360            category: #category_opt_tok,
361            summary: #summary_opt_tok,
362            keywords: #keywords_opt_tok,
363            errors: #errors_opt_tok,
364            related: #related_opt_tok,
365            introduced: #introduced_opt_tok,
366            status: #status_opt_tok,
367            examples: #examples_opt_tok,
368        }
369    };
370
371    let builtin_path_lit =
372        builtin_path_lit.expect("runtime_builtin requires `builtin_path = \"...\"`");
373    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
374        .expect("runtime_builtin `builtin_path` must be a valid path");
375    let helper_ident = format_ident!("__runmat_wasm_register_builtin_{}", ident);
376    let builtin_expr_helper = builtin_expr.clone();
377    let doc_expr_helper = doc_expr.clone();
378    let wasm_helper = quote! {
379        #[cfg(target_arch = "wasm32")]
380        #[allow(non_snake_case)]
381        pub(crate) fn #helper_ident() {
382            runmat_builtins::wasm_registry::submit_builtin_function(#builtin_expr_helper);
383            runmat_builtins::wasm_registry::submit_builtin_doc(#doc_expr_helper);
384        }
385    };
386    let register_native = quote! {
387        #[cfg(not(target_arch = "wasm32"))]
388        runmat_builtins::inventory::submit! { #builtin_expr }
389        #[cfg(not(target_arch = "wasm32"))]
390        runmat_builtins::inventory::submit! { #doc_expr }
391    };
392    append_wasm_block(quote! {
393        #builtin_path::#helper_ident();
394    });
395
396    TokenStream::from(quote! {
397        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
398        #func
399        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
400        #wrapper
401        #wasm_helper
402        #register_native
403    })
404}
405
406/// Attribute used to declare a runtime constant.
407///
408/// Example:
409/// ```rust,ignore
410/// use runmat_macros::runtime_constant;
411/// use runmat_builtins::Value;
412///
413/// #[runtime_constant(name = "pi", value = std::f64::consts::PI)]
414/// const PI_CONSTANT: ();
415/// ```
416///
417/// This registers the constant with the `runmat-builtins` inventory
418/// so the runtime can discover it at start-up.
419#[proc_macro_attribute]
420pub fn runtime_constant(args: TokenStream, input: TokenStream) -> TokenStream {
421    let args = parse_macro_input!(args as AttributeArgs);
422    let mut name_lit: Option<Lit> = None;
423    let mut value_expr: Option<Expr> = None;
424    let mut builtin_path_lit: Option<LitStr> = None;
425
426    for arg in args {
427        match arg {
428            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
429                if path.is_ident("name") {
430                    name_lit = Some(lit);
431                } else if path.is_ident("builtin_path") {
432                    if let Lit::Str(ls) = lit {
433                        builtin_path_lit = Some(ls);
434                    } else {
435                        panic!("builtin_path must be a string literal");
436                    }
437                } else {
438                    panic!("Unknown attribute parameter: {}", quote!(#path));
439                }
440            }
441            NestedMeta::Meta(Meta::Path(path)) if path.is_ident("value") => {
442                panic!("value parameter requires assignment: value = expression");
443            }
444            NestedMeta::Lit(lit) => {
445                // This handles the case where value is provided as a literal
446                value_expr = Some(syn::parse_quote!(#lit));
447            }
448            _ => panic!("Invalid attribute syntax"),
449        }
450    }
451
452    let name = match name_lit {
453        Some(Lit::Str(s)) => s.value(),
454        _ => panic!("name parameter must be a string literal"),
455    };
456
457    let value = value_expr.unwrap_or_else(|| {
458        panic!("value parameter is required");
459    });
460
461    let builtin_path_lit =
462        builtin_path_lit.expect("runtime_constant requires `builtin_path = \"...\"` argument");
463    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
464        .expect("runtime_constant `builtin_path` must be a valid path");
465    let item = parse_macro_input!(input as syn::Item);
466
467    let constant_expr = quote! {
468        runmat_builtins::Constant {
469            name: #name,
470            value: #value,
471        }
472    };
473
474    let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name);
475    let constant_expr_helper = constant_expr.clone();
476    let wasm_helper = quote! {
477        #[cfg(target_arch = "wasm32")]
478        #[allow(non_snake_case)]
479        pub(crate) fn #helper_ident() {
480            runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
481        }
482    };
483    let register_native = quote! {
484        #[cfg(not(target_arch = "wasm32"))]
485        #[allow(non_upper_case_globals)]
486        runmat_builtins::inventory::submit! { #constant_expr }
487    };
488    append_wasm_block(quote! {
489        #builtin_path::#helper_ident();
490    });
491
492    TokenStream::from(quote! {
493        #item
494        #wasm_helper
495        #register_native
496    })
497}
498
499struct RegisterConstantArgs {
500    name: LitStr,
501    value: Expr,
502    builtin_path: LitStr,
503}
504
505impl syn::parse::Parse for RegisterConstantArgs {
506    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
507        let name: LitStr = input.parse()?;
508        input.parse::<syn::Token![,]>()?;
509        let value: Expr = input.parse()?;
510        input.parse::<syn::Token![,]>()?;
511        let builtin_path: LitStr = input.parse()?;
512        if input.peek(syn::Token![,]) {
513            input.parse::<syn::Token![,]>()?;
514        }
515        Ok(RegisterConstantArgs {
516            name,
517            value,
518            builtin_path,
519        })
520    }
521}
522
523#[proc_macro]
524pub fn register_constant(input: TokenStream) -> TokenStream {
525    let RegisterConstantArgs {
526        name,
527        value,
528        builtin_path,
529    } = parse_macro_input!(input as RegisterConstantArgs);
530    let constant_expr = quote! {
531        runmat_builtins::Constant {
532            name: #name,
533            value: #value,
534        }
535    };
536    let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name.value());
537    let builtin_path: syn::Path = syn::parse_str(&builtin_path.value())
538        .expect("register_constant `builtin_path` must be a valid path");
539    let constant_expr_helper = constant_expr.clone();
540    let wasm_helper = quote! {
541        #[cfg(target_arch = "wasm32")]
542        #[allow(non_snake_case)]
543        pub(crate) fn #helper_ident() {
544            runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
545        }
546    };
547    append_wasm_block(quote! {
548        #builtin_path::#helper_ident();
549    });
550    TokenStream::from(quote! {
551        #wasm_helper
552        #[cfg(not(target_arch = "wasm32"))]
553        runmat_builtins::inventory::submit! { #constant_expr }
554    })
555}
556
557struct RegisterSpecAttrArgs {
558    spec_expr: Option<Expr>,
559    builtin_path: Option<LitStr>,
560}
561
562impl Parse for RegisterSpecAttrArgs {
563    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
564        let mut spec_expr = None;
565        let mut builtin_path = None;
566        while !input.is_empty() {
567            let ident: syn::Ident = input.parse()?;
568            input.parse::<syn::Token![=]>()?;
569            if ident == "spec" {
570                spec_expr = Some(input.parse()?);
571            } else if ident == "builtin_path" {
572                let lit: LitStr = input.parse()?;
573                builtin_path = Some(lit);
574            } else {
575                return Err(syn::Error::new(ident.span(), "unknown attribute argument"));
576            }
577            if input.peek(syn::Token![,]) {
578                input.parse::<syn::Token![,]>()?;
579            }
580        }
581        Ok(Self {
582            spec_expr,
583            builtin_path,
584        })
585    }
586}
587
588#[proc_macro_attribute]
589pub fn register_gpu_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
590    let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
591    let RegisterSpecAttrArgs {
592        spec_expr,
593        builtin_path,
594    } = args;
595    let item_const = parse_macro_input!(item as ItemConst);
596    let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
597        let ident = &item_const.ident;
598        quote! { #ident }
599    });
600    let spec_for_native = spec_tokens.clone();
601    let builtin_path_lit =
602        builtin_path.expect("register_gpu_spec requires `builtin_path = \"...\"` argument");
603    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
604        .expect("register_gpu_spec `builtin_path` must be a valid path");
605    let helper_ident = format_ident!(
606        "__runmat_wasm_register_gpu_spec_{}",
607        item_const.ident.to_string()
608    );
609    let spec_tokens_helper = spec_tokens.clone();
610    let wasm_helper = quote! {
611        #[cfg(target_arch = "wasm32")]
612        #[allow(non_snake_case)]
613        pub(crate) fn #helper_ident() {
614            crate::builtins::common::spec::wasm_registry::submit_gpu_spec(&#spec_tokens_helper);
615        }
616    };
617    append_wasm_block(quote! {
618        #builtin_path::#helper_ident();
619    });
620    let expanded = quote! {
621        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
622        #item_const
623        #wasm_helper
624        #[cfg(not(target_arch = "wasm32"))]
625        inventory::submit! {
626            crate::builtins::common::spec::GpuSpecInventory { spec: &#spec_for_native }
627        }
628    };
629    expanded.into()
630}
631
632#[proc_macro_attribute]
633pub fn register_fusion_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
634    let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
635    let RegisterSpecAttrArgs {
636        spec_expr,
637        builtin_path,
638    } = args;
639    let item_const = parse_macro_input!(item as ItemConst);
640    let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
641        let ident = &item_const.ident;
642        quote! { #ident }
643    });
644    let spec_for_native = spec_tokens.clone();
645    let builtin_path_lit =
646        builtin_path.expect("register_fusion_spec requires `builtin_path = \"...\"` argument");
647    let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
648        .expect("register_fusion_spec `builtin_path` must be a valid path");
649    let helper_ident = format_ident!(
650        "__runmat_wasm_register_fusion_spec_{}",
651        item_const.ident.to_string()
652    );
653    let spec_tokens_helper = spec_tokens.clone();
654    let wasm_helper = quote! {
655        #[cfg(target_arch = "wasm32")]
656        #[allow(non_snake_case)]
657        pub(crate) fn #helper_ident() {
658            crate::builtins::common::spec::wasm_registry::submit_fusion_spec(&#spec_tokens_helper);
659        }
660    };
661    append_wasm_block(quote! {
662        #builtin_path::#helper_ident();
663    });
664    let expanded = quote! {
665        #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
666        #item_const
667        #wasm_helper
668        #[cfg(not(target_arch = "wasm32"))]
669        inventory::submit! {
670            crate::builtins::common::spec::FusionSpecInventory { spec: &#spec_for_native }
671        }
672    };
673    expanded.into()
674}
675
676fn append_wasm_block(block: proc_macro2::TokenStream) {
677    if !should_generate_wasm_registry() {
678        return;
679    }
680    let path = match wasm_registry_path() {
681        Some(p) => p,
682        None => return,
683    };
684    let _guard = wasm_registry_lock().lock().unwrap();
685    initialize_registry_file(path);
686    let mut contents = fs::read_to_string(path).expect("failed to read wasm registry file");
687    let insertion = format!("    {}\n", block);
688    if let Some(pos) = contents.rfind('}') {
689        contents.insert_str(pos, &insertion);
690    } else {
691        contents.push_str(&insertion);
692        contents.push_str("}\n");
693    }
694    fs::write(path, contents).expect("failed to update wasm registry file");
695}
696
697fn wasm_registry_path() -> Option<&'static PathBuf> {
698    WASM_REGISTRY_PATH
699        .get_or_init(workspace_registry_path)
700        .as_ref()
701}
702
703fn wasm_registry_lock() -> &'static Mutex<()> {
704    WASM_REGISTRY_LOCK.get_or_init(|| Mutex::new(()))
705}
706
707fn initialize_registry_file(path: &Path) {
708    WASM_REGISTRY_INIT.get_or_init(|| {
709        if let Some(parent) = path.parent() {
710            let _ = fs::create_dir_all(parent);
711        }
712        const HEADER: &str = "pub fn register_all() {\n}\n";
713        fs::write(path, HEADER).expect("failed to create wasm registry file");
714    });
715}
716
/// Reports whether the wasm registry file should be written during this
/// expansion: true only when `RUNMAT_GENERATE_WASM_REGISTRY` is exactly `"1"`.
fn should_generate_wasm_registry() -> bool {
    // Registry generation is strictly opt-in. Writing the file on every build
    // would create an endless rebuild cycle: the proc-macros touch the file,
    // cargo notices the change, build.rs re-runs, runmat-runtime recompiles,
    // and the proc-macros run once more. Callers such as
    // test-wasm-headless.sh set RUNMAT_GENERATE_WASM_REGISTRY=1 for the
    // dedicated `cargo check` regeneration step only, and unset it before
    // `wasm-pack test` so the pre-generated file is consumed unchanged.
    std::env::var("RUNMAT_GENERATE_WASM_REGISTRY").map_or(false, |flag| flag == "1")
}
729
/// Locates the registry file path by walking up from `CARGO_MANIFEST_DIR` to
/// the nearest directory containing a `Cargo.lock`, and returning
/// `<that dir>/target/runmat_wasm_registry.rs`.
///
/// Returns `None` when the environment variable is unavailable or no ancestor
/// directory contains a `Cargo.lock`.
fn workspace_registry_path() -> Option<PathBuf> {
    let manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").ok()?);
    manifest_dir
        .ancestors()
        .find(|candidate| candidate.join("Cargo.lock").exists())
        .map(|root| root.join("target").join("runmat_wasm_registry.rs"))
}
741
742fn helper_ident_from_name(prefix: &str, name: &str) -> proc_macro2::Ident {
743    let mut sanitized = String::new();
744    for ch in name.chars() {
745        if ch.is_ascii_alphanumeric() || ch == '_' {
746            sanitized.push(ch);
747        } else {
748            sanitized.push('_');
749        }
750    }
751    format_ident!("{}{}", prefix, sanitized)
752}
753
754/// Smart type inference from Rust types to our enhanced Type enum
755fn infer_builtin_type(ty: &syn::Type) -> proc_macro2::TokenStream {
756    use syn::Type;
757
758    match ty {
759        // Basic primitive types
760        Type::Path(type_path) => {
761            if let Some(ident) = type_path.path.get_ident() {
762                match ident.to_string().as_str() {
763                    "i32" | "i64" | "isize" => quote! { runmat_builtins::Type::Int },
764                    "f32" | "f64" => quote! { runmat_builtins::Type::Num },
765                    "bool" => quote! { runmat_builtins::Type::Bool },
766                    "String" => quote! { runmat_builtins::Type::String },
767                    _ => infer_complex_type(type_path),
768                }
769            } else {
770                infer_complex_type(type_path)
771            }
772        }
773
774        // Reference types like &str, &Value, &Matrix
775        Type::Reference(type_ref) => match type_ref.elem.as_ref() {
776            Type::Path(type_path) => {
777                if let Some(ident) = type_path.path.get_ident() {
778                    match ident.to_string().as_str() {
779                        "str" => quote! { runmat_builtins::Type::String },
780                        _ => infer_builtin_type(&type_ref.elem),
781                    }
782                } else {
783                    infer_builtin_type(&type_ref.elem)
784                }
785            }
786            _ => infer_builtin_type(&type_ref.elem),
787        },
788
789        // Slice types like &[Value], &[f64]
790        Type::Slice(type_slice) => {
791            let element_type = infer_builtin_type(&type_slice.elem);
792            quote! { runmat_builtins::Type::Cell {
793                element_type: Some(Box::new(#element_type)),
794                length: None
795            } }
796        }
797
798        // Array types like [f64; N]
799        Type::Array(type_array) => {
800            let element_type = infer_builtin_type(&type_array.elem);
801            // Try to extract length if it's a literal
802            if let syn::Expr::Lit(expr_lit) = &type_array.len {
803                if let syn::Lit::Int(lit_int) = &expr_lit.lit {
804                    if let Ok(length) = lit_int.base10_parse::<usize>() {
805                        return quote! { runmat_builtins::Type::Cell {
806                            element_type: Some(Box::new(#element_type)),
807                            length: Some(#length)
808                        } };
809                    }
810                }
811            }
812            // Fallback to unknown length
813            quote! { runmat_builtins::Type::Cell {
814                element_type: Some(Box::new(#element_type)),
815                length: None
816            } }
817        }
818
819        // Generic or complex types
820        _ => quote! { runmat_builtins::Type::Unknown },
821    }
822}
823
824/// Infer types for complex path types like Result<T, E>, Option<T>, Matrix, Value
825fn infer_complex_type(type_path: &syn::TypePath) -> proc_macro2::TokenStream {
826    let path_str = quote! { #type_path }.to_string();
827
828    // Handle common patterns
829    if path_str.contains("Matrix") || path_str.contains("Tensor") {
830        quote! { runmat_builtins::Type::tensor() }
831    } else if path_str.contains("Value") {
832        quote! { runmat_builtins::Type::Unknown } // Value can be anything
833    } else if path_str.starts_with("Result") {
834        // Extract the Ok type from Result<T, E>
835        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
836            &type_path.path.segments.last().unwrap().arguments
837        {
838            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
839                return infer_builtin_type(ty);
840            }
841        }
842        quote! { runmat_builtins::Type::Unknown }
843    } else if path_str.starts_with("Option") {
844        // Extract the Some type from Option<T>
845        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
846            &type_path.path.segments.last().unwrap().arguments
847        {
848            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
849                return infer_builtin_type(ty);
850            }
851        }
852        quote! { runmat_builtins::Type::Unknown }
853    } else if path_str.starts_with("Vec") {
854        // Extract element type from Vec<T>
855        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
856            &type_path.path.segments.last().unwrap().arguments
857        {
858            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
859                let element_type = infer_builtin_type(ty);
860                return quote! { runmat_builtins::Type::Cell {
861                    element_type: Some(Box::new(#element_type)),
862                    length: None
863                } };
864            }
865        }
866        quote! { runmat_builtins::Type::cell() }
867    } else {
868        // Unknown type
869        quote! { runmat_builtins::Type::Unknown }
870    }
871}