1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::fs;
4use std::path::{Path, PathBuf};
5use std::sync::{Mutex, OnceLock};
6use syn::parse::{Parse, ParseStream};
7use syn::{
8 parse_macro_input, AttributeArgs, Expr, FnArg, ItemConst, ItemFn, Lit, LitStr, Meta,
9 MetaNameValue, NestedMeta, Pat,
10};
11
// Cached location of the generated wasm registry source file. Filled lazily by
// `wasm_registry_path()`; holds `None` when `workspace_registry_path()` finds no root.
static WASM_REGISTRY_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
// Mutex held by `append_wasm_block()` around its read-modify-write of the registry file.
static WASM_REGISTRY_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
// Marker set by `initialize_registry_file()` so the file header is written only once
// for the lifetime of these statics.
static WASM_REGISTRY_INIT: OnceLock<()> = OnceLock::new();
15
16#[proc_macro_attribute]
31pub fn runtime_builtin(args: TokenStream, input: TokenStream) -> TokenStream {
32 let args = parse_macro_input!(args as AttributeArgs);
34 let mut name_lit: Option<Lit> = None;
35 let mut category_lit: Option<Lit> = None;
36 let mut summary_lit: Option<Lit> = None;
37 let mut keywords_lit: Option<Lit> = None;
38 let mut errors_lit: Option<Lit> = None;
39 let mut related_lit: Option<Lit> = None;
40 let mut introduced_lit: Option<Lit> = None;
41 let mut status_lit: Option<Lit> = None;
42 let mut examples_lit: Option<Lit> = None;
43 let mut accel_values: Vec<String> = Vec::new();
44 let mut builtin_path_lit: Option<LitStr> = None;
45 let mut type_resolver_path: Option<syn::Path> = None;
46 let mut type_resolver_ctx_path: Option<syn::Path> = None;
47 let mut sink_flag = false;
48 let mut suppress_auto_output_flag = false;
49 for arg in args {
50 match arg {
51 NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
52 if path.is_ident("name") {
53 name_lit = Some(lit);
54 } else if path.is_ident("category") {
55 category_lit = Some(lit);
56 } else if path.is_ident("summary") {
57 summary_lit = Some(lit);
58 } else if path.is_ident("keywords") {
59 keywords_lit = Some(lit);
60 } else if path.is_ident("errors") {
61 errors_lit = Some(lit);
62 } else if path.is_ident("related") {
63 related_lit = Some(lit);
64 } else if path.is_ident("introduced") {
65 introduced_lit = Some(lit);
66 } else if path.is_ident("status") {
67 status_lit = Some(lit);
68 } else if path.is_ident("examples") {
69 examples_lit = Some(lit);
70 } else if path.is_ident("accel") {
71 if let Lit::Str(ls) = lit {
72 accel_values.extend(
73 ls.value()
74 .split(|c: char| c == ',' || c == '|' || c.is_ascii_whitespace())
75 .filter(|s| !s.is_empty())
76 .map(|s| s.to_ascii_lowercase()),
77 );
78 }
79 } else if path.is_ident("sink") {
80 if let Lit::Bool(lb) = lit {
81 sink_flag = lb.value;
82 }
83 } else if path.is_ident("suppress_auto_output") {
84 if let Lit::Bool(lb) = lit {
85 suppress_auto_output_flag = lb.value;
86 }
87 } else if path.is_ident("builtin_path") {
88 if let Lit::Str(ls) = lit {
89 builtin_path_lit = Some(ls);
90 } else {
91 panic!("builtin_path must be a string literal");
92 }
93 } else if path.is_ident("type_resolver") {
94 if let Lit::Str(ls) = lit {
95 let parsed: syn::Path = ls.parse().expect("type_resolver must be a path");
96 type_resolver_path = Some(parsed);
97 } else {
98 panic!("type_resolver must be a string literal path");
99 }
100 } else if path.is_ident("type_resolver_ctx") {
101 if let Lit::Str(ls) = lit {
102 let parsed: syn::Path =
103 ls.parse().expect("type_resolver_ctx must be a path");
104 type_resolver_ctx_path = Some(parsed);
105 } else {
106 panic!("type_resolver_ctx must be a string literal path");
107 }
108 } else {
109 }
111 }
112 NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver") => {
113 if list.nested.len() != 1 {
114 panic!("type_resolver expects exactly one path argument");
115 }
116 let nested = list.nested.first().unwrap();
117 if let NestedMeta::Meta(Meta::Path(path)) = nested {
118 type_resolver_path = Some(path.clone());
119 } else {
120 panic!("type_resolver expects a path argument");
121 }
122 }
123 NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("type_resolver_ctx") => {
124 if list.nested.len() != 1 {
125 panic!("type_resolver_ctx expects exactly one path argument");
126 }
127 let nested = list.nested.first().unwrap();
128 if let NestedMeta::Meta(Meta::Path(path)) = nested {
129 type_resolver_ctx_path = Some(path.clone());
130 } else {
131 panic!("type_resolver_ctx expects a path argument");
132 }
133 }
134 _ => {}
135 }
136 }
137 let name_lit = name_lit.expect("expected `name = \"...\"` argument");
138 let name_str = if let Lit::Str(ref s) = name_lit {
139 s.value()
140 } else {
141 panic!("name must be a string literal");
142 };
143
144 let func: ItemFn = parse_macro_input!(input as ItemFn);
145 let ident = &func.sig.ident;
146 let is_async = func.sig.asyncness.is_some();
147
148 let mut param_idents = Vec::new();
150 let mut param_types = Vec::new();
151 for arg in &func.sig.inputs {
152 match arg {
153 FnArg::Typed(pt) => {
154 if let Pat::Ident(pi) = pt.pat.as_ref() {
156 param_idents.push(pi.ident.clone());
157 } else {
158 panic!("parameters must be simple identifiers");
159 }
160 param_types.push((*pt.ty).clone());
161 }
162 _ => panic!("self parameter not allowed"),
163 }
164 }
165 let param_len = param_idents.len();
166
167 let inferred_param_types: Vec<proc_macro2::TokenStream> =
169 param_types.iter().map(infer_builtin_type).collect();
170
171 let inferred_return_type = match &func.sig.output {
173 syn::ReturnType::Default => quote! { runmat_builtins::Type::Void },
174 syn::ReturnType::Type(_, ty) => infer_builtin_type(ty),
175 };
176
177 let is_last_variadic = param_types
179 .last()
180 .map(|ty| {
181 if let syn::Type::Path(tp) = ty {
183 if tp
184 .path
185 .segments
186 .last()
187 .map(|s| s.ident == "Vec")
188 .unwrap_or(false)
189 {
190 if let syn::PathArguments::AngleBracketed(ab) =
191 &tp.path.segments.last().unwrap().arguments
192 {
193 if let Some(syn::GenericArgument::Type(syn::Type::Path(inner))) =
194 ab.args.first()
195 {
196 return inner
197 .path
198 .segments
199 .last()
200 .map(|s| s.ident == "Value")
201 .unwrap_or(false);
202 }
203 }
204 }
205 }
206 false
207 })
208 .unwrap_or(false);
209
210 let wrapper_ident = format_ident!("__rt_wrap_{}", ident);
212
213 let conv_stmts: Vec<proc_macro2::TokenStream> = if is_last_variadic && param_len > 0 {
214 let mut stmts = Vec::new();
215 for (i, (ident, ty)) in param_idents
217 .iter()
218 .zip(param_types.iter())
219 .enumerate()
220 .take(param_len - 1)
221 {
222 stmts.push(quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; });
223 }
224 let last_ident = ¶m_idents[param_len - 1];
226 stmts.push(quote! {
227 let #last_ident : Vec<runmat_builtins::Value> = {
228 let mut v = Vec::new();
229 for j in (#param_len-1)..args.len() {
230 let item : runmat_builtins::Value = std::convert::TryInto::try_into(&args[j])?;
231 v.push(item);
232 }
233 v
234 };
235 });
236 stmts
237 } else {
238 param_idents
239 .iter()
240 .zip(param_types.iter())
241 .enumerate()
242 .map(|(i, (ident, ty))| {
243 quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; }
244 })
245 .collect()
246 };
247
248 let call_expr = if is_async {
249 quote! { #ident(#(#param_idents),*).await? }
250 } else {
251 quote! { #ident(#(#param_idents),*)? }
252 };
253
254 let wrapper = quote! {
255 fn #wrapper_ident(args: &[runmat_builtins::Value]) -> runmat_builtins::BuiltinFuture {
256 #![allow(unused_variables)]
257 let args = args.to_vec();
258 Box::pin(async move {
259 if #is_last_variadic {
260 if args.len() < #param_len - 1 {
261 return Err(std::convert::From::from(format!(
262 "expected at least {} args, got {}",
263 #param_len - 1,
264 args.len()
265 )));
266 }
267 } else if args.len() != #param_len {
268 return Err(std::convert::From::from(format!(
269 "expected {} args, got {}",
270 #param_len,
271 args.len()
272 )));
273 }
274 #(#conv_stmts)*
275 let value = #call_expr;
276 Ok(runmat_builtins::Value::from(value))
277 })
278 }
279 };
280
281 let default_category = syn::LitStr::new("general", proc_macro2::Span::call_site());
283 let default_summary =
284 syn::LitStr::new("Runtime builtin function", proc_macro2::Span::call_site());
285
286 let category_tok: proc_macro2::TokenStream = match &category_lit {
287 Some(syn::Lit::Str(ls)) => quote! { #ls },
288 _ => quote! { #default_category },
289 };
290 let summary_tok: proc_macro2::TokenStream = match &summary_lit {
291 Some(syn::Lit::Str(ls)) => quote! { #ls },
292 _ => quote! { #default_summary },
293 };
294
295 fn opt_tok(lit: &Option<syn::Lit>) -> proc_macro2::TokenStream {
296 if let Some(syn::Lit::Str(ls)) = lit {
297 quote! { Some(#ls) }
298 } else {
299 quote! { None }
300 }
301 }
302 let category_opt_tok = opt_tok(&category_lit);
303 let summary_opt_tok = opt_tok(&summary_lit);
304 let keywords_opt_tok = opt_tok(&keywords_lit);
305 let errors_opt_tok = opt_tok(&errors_lit);
306 let related_opt_tok = opt_tok(&related_lit);
307 let introduced_opt_tok = opt_tok(&introduced_lit);
308 let status_opt_tok = opt_tok(&status_lit);
309 let examples_opt_tok = opt_tok(&examples_lit);
310
311 let accel_tokens: Vec<proc_macro2::TokenStream> = accel_values
312 .iter()
313 .map(|mode| match mode.as_str() {
314 "unary" => quote! { runmat_builtins::AccelTag::Unary },
315 "elementwise" => quote! { runmat_builtins::AccelTag::Elementwise },
316 "reduction" => quote! { runmat_builtins::AccelTag::Reduction },
317 "matmul" => quote! { runmat_builtins::AccelTag::MatMul },
318 "transpose" => quote! { runmat_builtins::AccelTag::Transpose },
319 "array_construct" => quote! { runmat_builtins::AccelTag::ArrayConstruct },
320 _ => quote! {},
321 })
322 .filter(|ts| !ts.is_empty())
323 .collect();
324 let accel_slice = if accel_tokens.is_empty() {
325 quote! { &[] as &[runmat_builtins::AccelTag] }
326 } else {
327 quote! { &[#(#accel_tokens),*] }
328 };
329 let type_resolver_expr = if let Some(path) = type_resolver_ctx_path.as_ref() {
330 quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
331 } else if let Some(path) = type_resolver_path.as_ref() {
332 quote! { Some(runmat_builtins::type_resolver_kind_ctx(#path)) }
333 } else {
334 quote! { None }
335 };
336 let sink_bool = sink_flag;
337 let suppress_auto_output_bool = suppress_auto_output_flag;
338
339 let builtin_expr = quote! {
340 runmat_builtins::BuiltinFunction::new(
341 #name_str,
342 #summary_tok,
343 #category_tok,
344 "",
345 "",
346 vec![#(#inferred_param_types),*],
347 #inferred_return_type,
348 #type_resolver_expr,
349 #wrapper_ident,
350 #accel_slice,
351 #sink_bool,
352 #suppress_auto_output_bool,
353 )
354 };
355
356 let doc_expr = quote! {
357 runmat_builtins::BuiltinDoc {
358 name: #name_str,
359 category: #category_opt_tok,
360 summary: #summary_opt_tok,
361 keywords: #keywords_opt_tok,
362 errors: #errors_opt_tok,
363 related: #related_opt_tok,
364 introduced: #introduced_opt_tok,
365 status: #status_opt_tok,
366 examples: #examples_opt_tok,
367 }
368 };
369
370 let builtin_path_lit =
371 builtin_path_lit.expect("runtime_builtin requires `builtin_path = \"...\"`");
372 let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
373 .expect("runtime_builtin `builtin_path` must be a valid path");
374 let helper_ident = format_ident!("__runmat_wasm_register_builtin_{}", ident);
375 let builtin_expr_helper = builtin_expr.clone();
376 let doc_expr_helper = doc_expr.clone();
377 let wasm_helper = quote! {
378 #[cfg(target_arch = "wasm32")]
379 #[allow(non_snake_case)]
380 pub(crate) fn #helper_ident() {
381 runmat_builtins::wasm_registry::submit_builtin_function(#builtin_expr_helper);
382 runmat_builtins::wasm_registry::submit_builtin_doc(#doc_expr_helper);
383 }
384 };
385 let register_native = quote! {
386 #[cfg(not(target_arch = "wasm32"))]
387 runmat_builtins::inventory::submit! { #builtin_expr }
388 #[cfg(not(target_arch = "wasm32"))]
389 runmat_builtins::inventory::submit! { #doc_expr }
390 };
391 append_wasm_block(quote! {
392 #builtin_path::#helper_ident();
393 });
394
395 TokenStream::from(quote! {
396 #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
397 #func
398 #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
399 #wrapper
400 #wasm_helper
401 #register_native
402 })
403}
404
405#[proc_macro_attribute]
419pub fn runtime_constant(args: TokenStream, input: TokenStream) -> TokenStream {
420 let args = parse_macro_input!(args as AttributeArgs);
421 let mut name_lit: Option<Lit> = None;
422 let mut value_expr: Option<Expr> = None;
423 let mut builtin_path_lit: Option<LitStr> = None;
424
425 for arg in args {
426 match arg {
427 NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
428 if path.is_ident("name") {
429 name_lit = Some(lit);
430 } else if path.is_ident("builtin_path") {
431 if let Lit::Str(ls) = lit {
432 builtin_path_lit = Some(ls);
433 } else {
434 panic!("builtin_path must be a string literal");
435 }
436 } else {
437 panic!("Unknown attribute parameter: {}", quote!(#path));
438 }
439 }
440 NestedMeta::Meta(Meta::Path(path)) if path.is_ident("value") => {
441 panic!("value parameter requires assignment: value = expression");
442 }
443 NestedMeta::Lit(lit) => {
444 value_expr = Some(syn::parse_quote!(#lit));
446 }
447 _ => panic!("Invalid attribute syntax"),
448 }
449 }
450
451 let name = match name_lit {
452 Some(Lit::Str(s)) => s.value(),
453 _ => panic!("name parameter must be a string literal"),
454 };
455
456 let value = value_expr.unwrap_or_else(|| {
457 panic!("value parameter is required");
458 });
459
460 let builtin_path_lit =
461 builtin_path_lit.expect("runtime_constant requires `builtin_path = \"...\"` argument");
462 let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
463 .expect("runtime_constant `builtin_path` must be a valid path");
464 let item = parse_macro_input!(input as syn::Item);
465
466 let constant_expr = quote! {
467 runmat_builtins::Constant {
468 name: #name,
469 value: #value,
470 }
471 };
472
473 let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name);
474 let constant_expr_helper = constant_expr.clone();
475 let wasm_helper = quote! {
476 #[cfg(target_arch = "wasm32")]
477 #[allow(non_snake_case)]
478 pub(crate) fn #helper_ident() {
479 runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
480 }
481 };
482 let register_native = quote! {
483 #[cfg(not(target_arch = "wasm32"))]
484 #[allow(non_upper_case_globals)]
485 runmat_builtins::inventory::submit! { #constant_expr }
486 };
487 append_wasm_block(quote! {
488 #builtin_path::#helper_ident();
489 });
490
491 TokenStream::from(quote! {
492 #item
493 #wasm_helper
494 #register_native
495 })
496}
497
498struct RegisterConstantArgs {
499 name: LitStr,
500 value: Expr,
501 builtin_path: LitStr,
502}
503
504impl syn::parse::Parse for RegisterConstantArgs {
505 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
506 let name: LitStr = input.parse()?;
507 input.parse::<syn::Token![,]>()?;
508 let value: Expr = input.parse()?;
509 input.parse::<syn::Token![,]>()?;
510 let builtin_path: LitStr = input.parse()?;
511 if input.peek(syn::Token![,]) {
512 input.parse::<syn::Token![,]>()?;
513 }
514 Ok(RegisterConstantArgs {
515 name,
516 value,
517 builtin_path,
518 })
519 }
520}
521
522#[proc_macro]
523pub fn register_constant(input: TokenStream) -> TokenStream {
524 let RegisterConstantArgs {
525 name,
526 value,
527 builtin_path,
528 } = parse_macro_input!(input as RegisterConstantArgs);
529 let constant_expr = quote! {
530 runmat_builtins::Constant {
531 name: #name,
532 value: #value,
533 }
534 };
535 let helper_ident = helper_ident_from_name("__runmat_wasm_register_const_", &name.value());
536 let builtin_path: syn::Path = syn::parse_str(&builtin_path.value())
537 .expect("register_constant `builtin_path` must be a valid path");
538 let constant_expr_helper = constant_expr.clone();
539 let wasm_helper = quote! {
540 #[cfg(target_arch = "wasm32")]
541 #[allow(non_snake_case)]
542 pub(crate) fn #helper_ident() {
543 runmat_builtins::wasm_registry::submit_constant(#constant_expr_helper);
544 }
545 };
546 append_wasm_block(quote! {
547 #builtin_path::#helper_ident();
548 });
549 TokenStream::from(quote! {
550 #wasm_helper
551 #[cfg(not(target_arch = "wasm32"))]
552 runmat_builtins::inventory::submit! { #constant_expr }
553 })
554}
555
556struct RegisterSpecAttrArgs {
557 spec_expr: Option<Expr>,
558 builtin_path: Option<LitStr>,
559}
560
561impl Parse for RegisterSpecAttrArgs {
562 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
563 let mut spec_expr = None;
564 let mut builtin_path = None;
565 while !input.is_empty() {
566 let ident: syn::Ident = input.parse()?;
567 input.parse::<syn::Token![=]>()?;
568 if ident == "spec" {
569 spec_expr = Some(input.parse()?);
570 } else if ident == "builtin_path" {
571 let lit: LitStr = input.parse()?;
572 builtin_path = Some(lit);
573 } else {
574 return Err(syn::Error::new(ident.span(), "unknown attribute argument"));
575 }
576 if input.peek(syn::Token![,]) {
577 input.parse::<syn::Token![,]>()?;
578 }
579 }
580 Ok(Self {
581 spec_expr,
582 builtin_path,
583 })
584 }
585}
586
587#[proc_macro_attribute]
588pub fn register_gpu_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
589 let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
590 let RegisterSpecAttrArgs {
591 spec_expr,
592 builtin_path,
593 } = args;
594 let item_const = parse_macro_input!(item as ItemConst);
595 let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
596 let ident = &item_const.ident;
597 quote! { #ident }
598 });
599 let spec_for_native = spec_tokens.clone();
600 let builtin_path_lit =
601 builtin_path.expect("register_gpu_spec requires `builtin_path = \"...\"` argument");
602 let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
603 .expect("register_gpu_spec `builtin_path` must be a valid path");
604 let helper_ident = format_ident!(
605 "__runmat_wasm_register_gpu_spec_{}",
606 item_const.ident.to_string()
607 );
608 let spec_tokens_helper = spec_tokens.clone();
609 let wasm_helper = quote! {
610 #[cfg(target_arch = "wasm32")]
611 #[allow(non_snake_case)]
612 pub(crate) fn #helper_ident() {
613 crate::builtins::common::spec::wasm_registry::submit_gpu_spec(&#spec_tokens_helper);
614 }
615 };
616 append_wasm_block(quote! {
617 #builtin_path::#helper_ident();
618 });
619 let expanded = quote! {
620 #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
621 #item_const
622 #wasm_helper
623 #[cfg(not(target_arch = "wasm32"))]
624 inventory::submit! {
625 crate::builtins::common::spec::GpuSpecInventory { spec: &#spec_for_native }
626 }
627 };
628 expanded.into()
629}
630
631#[proc_macro_attribute]
632pub fn register_fusion_spec(attr: TokenStream, item: TokenStream) -> TokenStream {
633 let args = parse_macro_input!(attr as RegisterSpecAttrArgs);
634 let RegisterSpecAttrArgs {
635 spec_expr,
636 builtin_path,
637 } = args;
638 let item_const = parse_macro_input!(item as ItemConst);
639 let spec_tokens = spec_expr.map(|expr| quote! { #expr }).unwrap_or_else(|| {
640 let ident = &item_const.ident;
641 quote! { #ident }
642 });
643 let spec_for_native = spec_tokens.clone();
644 let builtin_path_lit =
645 builtin_path.expect("register_fusion_spec requires `builtin_path = \"...\"` argument");
646 let builtin_path: syn::Path = syn::parse_str(&builtin_path_lit.value())
647 .expect("register_fusion_spec `builtin_path` must be a valid path");
648 let helper_ident = format_ident!(
649 "__runmat_wasm_register_fusion_spec_{}",
650 item_const.ident.to_string()
651 );
652 let spec_tokens_helper = spec_tokens.clone();
653 let wasm_helper = quote! {
654 #[cfg(target_arch = "wasm32")]
655 #[allow(non_snake_case)]
656 pub(crate) fn #helper_ident() {
657 crate::builtins::common::spec::wasm_registry::submit_fusion_spec(&#spec_tokens_helper);
658 }
659 };
660 append_wasm_block(quote! {
661 #builtin_path::#helper_ident();
662 });
663 let expanded = quote! {
664 #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
665 #item_const
666 #wasm_helper
667 #[cfg(not(target_arch = "wasm32"))]
668 inventory::submit! {
669 crate::builtins::common::spec::FusionSpecInventory { spec: &#spec_for_native }
670 }
671 };
672 expanded.into()
673}
674
675fn append_wasm_block(block: proc_macro2::TokenStream) {
676 if !should_generate_wasm_registry() {
677 return;
678 }
679 let path = match wasm_registry_path() {
680 Some(p) => p,
681 None => return,
682 };
683 let _guard = wasm_registry_lock().lock().unwrap();
684 initialize_registry_file(path);
685 let mut contents = fs::read_to_string(path).expect("failed to read wasm registry file");
686 let insertion = format!(" {}\n", block);
687 if let Some(pos) = contents.rfind('}') {
688 contents.insert_str(pos, &insertion);
689 } else {
690 contents.push_str(&insertion);
691 contents.push_str("}\n");
692 }
693 fs::write(path, contents).expect("failed to update wasm registry file");
694}
695
696fn wasm_registry_path() -> Option<&'static PathBuf> {
697 WASM_REGISTRY_PATH
698 .get_or_init(workspace_registry_path)
699 .as_ref()
700}
701
702fn wasm_registry_lock() -> &'static Mutex<()> {
703 WASM_REGISTRY_LOCK.get_or_init(|| Mutex::new(()))
704}
705
706fn initialize_registry_file(path: &Path) {
707 WASM_REGISTRY_INIT.get_or_init(|| {
708 if let Some(parent) = path.parent() {
709 let _ = fs::create_dir_all(parent);
710 }
711 const HEADER: &str = "pub fn register_all() {\n}\n";
712 fs::write(path, HEADER).expect("failed to create wasm registry file");
713 });
714}
715
/// Reports whether the `RUNMAT_GENERATE_WASM_REGISTRY` environment variable is set to
/// exactly `1`. Any other value, an unset variable, or a non-Unicode value yields `false`.
fn should_generate_wasm_registry() -> bool {
    match std::env::var("RUNMAT_GENERATE_WASM_REGISTRY") {
        Ok(flag) => flag == "1",
        Err(_) => false,
    }
}
728
/// Locates the registry file path for the current build.
///
/// Starting at `CARGO_MANIFEST_DIR` and walking towards the filesystem root, the first
/// directory containing a `Cargo.lock` is used; the result is
/// `<that dir>/target/runmat_wasm_registry.rs`. Returns `None` when the variable is unset
/// or no such directory exists.
fn workspace_registry_path() -> Option<PathBuf> {
    let manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").ok()?);
    manifest_dir
        .ancestors()
        .find(|candidate| candidate.join("Cargo.lock").exists())
        .map(|root| root.join("target").join("runmat_wasm_registry.rs"))
}
740
741fn helper_ident_from_name(prefix: &str, name: &str) -> proc_macro2::Ident {
742 let mut sanitized = String::new();
743 for ch in name.chars() {
744 if ch.is_ascii_alphanumeric() || ch == '_' {
745 sanitized.push(ch);
746 } else {
747 sanitized.push('_');
748 }
749 }
750 format_ident!("{}{}", prefix, sanitized)
751}
752
753fn infer_builtin_type(ty: &syn::Type) -> proc_macro2::TokenStream {
755 use syn::Type;
756
757 match ty {
758 Type::Path(type_path) => {
760 if let Some(ident) = type_path.path.get_ident() {
761 match ident.to_string().as_str() {
762 "i32" | "i64" | "isize" => quote! { runmat_builtins::Type::Int },
763 "f32" | "f64" => quote! { runmat_builtins::Type::Num },
764 "bool" => quote! { runmat_builtins::Type::Bool },
765 "String" => quote! { runmat_builtins::Type::String },
766 _ => infer_complex_type(type_path),
767 }
768 } else {
769 infer_complex_type(type_path)
770 }
771 }
772
773 Type::Reference(type_ref) => match type_ref.elem.as_ref() {
775 Type::Path(type_path) => {
776 if let Some(ident) = type_path.path.get_ident() {
777 match ident.to_string().as_str() {
778 "str" => quote! { runmat_builtins::Type::String },
779 _ => infer_builtin_type(&type_ref.elem),
780 }
781 } else {
782 infer_builtin_type(&type_ref.elem)
783 }
784 }
785 _ => infer_builtin_type(&type_ref.elem),
786 },
787
788 Type::Slice(type_slice) => {
790 let element_type = infer_builtin_type(&type_slice.elem);
791 quote! { runmat_builtins::Type::Cell {
792 element_type: Some(Box::new(#element_type)),
793 length: None
794 } }
795 }
796
797 Type::Array(type_array) => {
799 let element_type = infer_builtin_type(&type_array.elem);
800 if let syn::Expr::Lit(expr_lit) = &type_array.len {
802 if let syn::Lit::Int(lit_int) = &expr_lit.lit {
803 if let Ok(length) = lit_int.base10_parse::<usize>() {
804 return quote! { runmat_builtins::Type::Cell {
805 element_type: Some(Box::new(#element_type)),
806 length: Some(#length)
807 } };
808 }
809 }
810 }
811 quote! { runmat_builtins::Type::Cell {
813 element_type: Some(Box::new(#element_type)),
814 length: None
815 } }
816 }
817
818 _ => quote! { runmat_builtins::Type::Unknown },
820 }
821}
822
823fn infer_complex_type(type_path: &syn::TypePath) -> proc_macro2::TokenStream {
825 let path_str = quote! { #type_path }.to_string();
826
827 if path_str.contains("Matrix") || path_str.contains("Tensor") {
829 quote! { runmat_builtins::Type::tensor() }
830 } else if path_str.contains("Value") {
831 quote! { runmat_builtins::Type::Unknown } } else if path_str.starts_with("Result") {
833 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
835 &type_path.path.segments.last().unwrap().arguments
836 {
837 if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
838 return infer_builtin_type(ty);
839 }
840 }
841 quote! { runmat_builtins::Type::Unknown }
842 } else if path_str.starts_with("Option") {
843 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
845 &type_path.path.segments.last().unwrap().arguments
846 {
847 if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
848 return infer_builtin_type(ty);
849 }
850 }
851 quote! { runmat_builtins::Type::Unknown }
852 } else if path_str.starts_with("Vec") {
853 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
855 &type_path.path.segments.last().unwrap().arguments
856 {
857 if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
858 let element_type = infer_builtin_type(ty);
859 return quote! { runmat_builtins::Type::Cell {
860 element_type: Some(Box::new(#element_type)),
861 length: None
862 } };
863 }
864 }
865 quote! { runmat_builtins::Type::cell() }
866 } else {
867 quote! { runmat_builtins::Type::Unknown }
869 }
870}