runmat_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse_macro_input, AttributeArgs, Expr, FnArg, ItemFn, Lit, Meta, MetaNameValue, NestedMeta,
5    Pat,
6};
7
8/// Attribute used to mark functions as implementing a runtime builtin.
9///
10/// Example:
11/// ```rust,ignore
12/// use runmat_macros::runtime_builtin;
13///
14/// #[runtime_builtin(name = "plot")]
15/// pub fn plot_line(xs: &[f64]) {
16///     /* implementation */
17/// }
18/// ```
19///
20/// This registers the function with the `runmat-builtins` inventory
21/// so the runtime can discover it at start-up.
22#[proc_macro_attribute]
23pub fn runtime_builtin(args: TokenStream, input: TokenStream) -> TokenStream {
24    // Parse attribute arguments as `name = "..."`
25    let args = parse_macro_input!(args as AttributeArgs);
26    let mut name_lit: Option<Lit> = None;
27    let mut category_lit: Option<Lit> = None;
28    let mut summary_lit: Option<Lit> = None;
29    let mut keywords_lit: Option<Lit> = None;
30    let mut errors_lit: Option<Lit> = None;
31    let mut related_lit: Option<Lit> = None;
32    let mut introduced_lit: Option<Lit> = None;
33    let mut status_lit: Option<Lit> = None;
34    let mut examples_lit: Option<Lit> = None;
35    let mut accel_values: Vec<String> = Vec::new();
36    let mut sink_flag = false;
37    for arg in args {
38        if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) = arg {
39            if path.is_ident("name") {
40                name_lit = Some(lit);
41            } else if path.is_ident("category") {
42                category_lit = Some(lit);
43            } else if path.is_ident("summary") {
44                summary_lit = Some(lit);
45            } else if path.is_ident("keywords") {
46                keywords_lit = Some(lit);
47            } else if path.is_ident("errors") {
48                errors_lit = Some(lit);
49            } else if path.is_ident("related") {
50                related_lit = Some(lit);
51            } else if path.is_ident("introduced") {
52                introduced_lit = Some(lit);
53            } else if path.is_ident("status") {
54                status_lit = Some(lit);
55            } else if path.is_ident("examples") {
56                examples_lit = Some(lit);
57            } else if path.is_ident("accel") {
58                if let Lit::Str(ls) = lit {
59                    accel_values.extend(
60                        ls.value()
61                            .split(|c: char| c == ',' || c == '|' || c.is_ascii_whitespace())
62                            .filter(|s| !s.is_empty())
63                            .map(|s| s.to_ascii_lowercase()),
64                    );
65                }
66            } else if path.is_ident("sink") {
67                if let Lit::Bool(lb) = lit {
68                    sink_flag = lb.value;
69                }
70            } else {
71                // Gracefully ignore unknown parameters for better IDE experience
72            }
73        }
74    }
75    let name_lit = name_lit.expect("expected `name = \"...\"` argument");
76    let name_str = if let Lit::Str(ref s) = name_lit {
77        s.value()
78    } else {
79        panic!("name must be a string literal");
80    };
81
82    let func: ItemFn = parse_macro_input!(input as ItemFn);
83    let ident = &func.sig.ident;
84
85    // Extract param idents and types
86    let mut param_idents = Vec::new();
87    let mut param_types = Vec::new();
88    for arg in &func.sig.inputs {
89        match arg {
90            FnArg::Typed(pt) => {
91                // pattern must be ident
92                if let Pat::Ident(pi) = pt.pat.as_ref() {
93                    param_idents.push(pi.ident.clone());
94                } else {
95                    panic!("parameters must be simple identifiers");
96                }
97                param_types.push((*pt.ty).clone());
98            }
99            _ => panic!("self parameter not allowed"),
100        }
101    }
102    let param_len = param_idents.len();
103
104    // Infer parameter types for BuiltinFunction
105    let inferred_param_types: Vec<proc_macro2::TokenStream> =
106        param_types.iter().map(infer_builtin_type).collect();
107
108    // Infer return type for BuiltinFunction
109    let inferred_return_type = match &func.sig.output {
110        syn::ReturnType::Default => quote! { runmat_builtins::Type::Void },
111        syn::ReturnType::Type(_, ty) => infer_builtin_type(ty),
112    };
113
114    // Detect if last parameter is variadic Vec<Value>
115    let is_last_variadic = param_types
116        .last()
117        .map(|ty| {
118            // crude detection: type path starts with Vec and inner type is runmat_builtins::Value or Value
119            if let syn::Type::Path(tp) = ty {
120                if tp
121                    .path
122                    .segments
123                    .last()
124                    .map(|s| s.ident == "Vec")
125                    .unwrap_or(false)
126                {
127                    if let syn::PathArguments::AngleBracketed(ab) =
128                        &tp.path.segments.last().unwrap().arguments
129                    {
130                        if let Some(syn::GenericArgument::Type(syn::Type::Path(inner))) =
131                            ab.args.first()
132                        {
133                            return inner
134                                .path
135                                .segments
136                                .last()
137                                .map(|s| s.ident == "Value")
138                                .unwrap_or(false);
139                        }
140                    }
141                }
142            }
143            false
144        })
145        .unwrap_or(false);
146
147    // Generate wrapper ident
148    let wrapper_ident = format_ident!("__rt_wrap_{}", ident);
149
150    let conv_stmts: Vec<proc_macro2::TokenStream> = if is_last_variadic && param_len > 0 {
151        let mut stmts = Vec::new();
152        // Convert fixed params (all but last)
153        for (i, (ident, ty)) in param_idents
154            .iter()
155            .zip(param_types.iter())
156            .enumerate()
157            .take(param_len - 1)
158        {
159            stmts.push(quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; });
160        }
161        // Collect the rest into Vec<Value>
162        let last_ident = &param_idents[param_len - 1];
163        stmts.push(quote! {
164            let #last_ident : Vec<runmat_builtins::Value> = {
165                let mut v = Vec::new();
166                for j in (#param_len-1)..args.len() {
167                    let item : runmat_builtins::Value = std::convert::TryInto::try_into(&args[j])?;
168                    v.push(item);
169                }
170                v
171            };
172        });
173        stmts
174    } else {
175        param_idents
176            .iter()
177            .zip(param_types.iter())
178            .enumerate()
179            .map(|(i, (ident, ty))| {
180                quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; }
181            })
182            .collect()
183    };
184
185    let wrapper = quote! {
186        fn #wrapper_ident(args: &[runmat_builtins::Value]) -> Result<runmat_builtins::Value, String> {
187            #![allow(unused_variables)]
188            if #is_last_variadic {
189                if args.len() < #param_len - 1 { return Err(format!("expected at least {} args, got {}", #param_len - 1, args.len())); }
190            } else {
191                if args.len() != #param_len { return Err(format!("expected {} args, got {}", #param_len, args.len())); }
192            }
193            #(#conv_stmts)*
194            let res = #ident(#(#param_idents),*)?;
195            Ok(runmat_builtins::Value::from(res))
196        }
197    };
198
199    // Prepare tokens for defaults and options
200    let default_category = syn::LitStr::new("general", proc_macro2::Span::call_site());
201    let default_summary =
202        syn::LitStr::new("Runtime builtin function", proc_macro2::Span::call_site());
203
204    let category_tok: proc_macro2::TokenStream = match &category_lit {
205        Some(syn::Lit::Str(ls)) => quote! { #ls },
206        _ => quote! { #default_category },
207    };
208    let summary_tok: proc_macro2::TokenStream = match &summary_lit {
209        Some(syn::Lit::Str(ls)) => quote! { #ls },
210        _ => quote! { #default_summary },
211    };
212
213    fn opt_tok(lit: &Option<syn::Lit>) -> proc_macro2::TokenStream {
214        if let Some(syn::Lit::Str(ls)) = lit {
215            quote! { Some(#ls) }
216        } else {
217            quote! { None }
218        }
219    }
220    let category_opt_tok = opt_tok(&category_lit);
221    let summary_opt_tok = opt_tok(&summary_lit);
222    let keywords_opt_tok = opt_tok(&keywords_lit);
223    let errors_opt_tok = opt_tok(&errors_lit);
224    let related_opt_tok = opt_tok(&related_lit);
225    let introduced_opt_tok = opt_tok(&introduced_lit);
226    let status_opt_tok = opt_tok(&status_lit);
227    let examples_opt_tok = opt_tok(&examples_lit);
228
229    let accel_tokens: Vec<proc_macro2::TokenStream> = accel_values
230        .iter()
231        .map(|mode| match mode.as_str() {
232            "unary" => quote! { runmat_builtins::AccelTag::Unary },
233            "elementwise" => quote! { runmat_builtins::AccelTag::Elementwise },
234            "reduction" => quote! { runmat_builtins::AccelTag::Reduction },
235            "matmul" => quote! { runmat_builtins::AccelTag::MatMul },
236            "transpose" => quote! { runmat_builtins::AccelTag::Transpose },
237            "array_construct" => quote! { runmat_builtins::AccelTag::ArrayConstruct },
238            _ => quote! {},
239        })
240        .filter(|ts| !ts.is_empty())
241        .collect();
242    let accel_slice = if accel_tokens.is_empty() {
243        quote! { &[] as &[runmat_builtins::AccelTag] }
244    } else {
245        quote! { &[#(#accel_tokens),*] }
246    };
247    let sink_bool = sink_flag;
248
249    let register = quote! {
250        runmat_builtins::inventory::submit! {
251            runmat_builtins::BuiltinFunction::new(
252                #name_str,
253                #summary_tok,
254                #category_tok,
255                "",
256                "",
257                vec![#(#inferred_param_types),*],
258                #inferred_return_type,
259                #wrapper_ident,
260                #accel_slice,
261                #sink_bool,
262            )
263        }
264        runmat_builtins::inventory::submit! {
265            runmat_builtins::BuiltinDoc {
266                name: #name_str,
267                category: #category_opt_tok,
268                summary: #summary_opt_tok,
269                keywords: #keywords_opt_tok,
270                errors: #errors_opt_tok,
271                related: #related_opt_tok,
272                introduced: #introduced_opt_tok,
273                status: #status_opt_tok,
274                examples: #examples_opt_tok,
275            }
276        }
277    };
278
279    TokenStream::from(quote! {
280        #func
281        #wrapper
282        #register
283    })
284}
285
286/// Attribute used to declare a runtime constant.
287///
288/// Example:
289/// ```rust,ignore
290/// use runmat_macros::runtime_constant;
291/// use runmat_builtins::Value;
292///
293/// #[runtime_constant(name = "pi", value = std::f64::consts::PI)]
294/// const PI_CONSTANT: ();
295/// ```
296///
297/// This registers the constant with the `runmat-builtins` inventory
298/// so the runtime can discover it at start-up.
299#[proc_macro_attribute]
300pub fn runtime_constant(args: TokenStream, input: TokenStream) -> TokenStream {
301    let args = parse_macro_input!(args as AttributeArgs);
302    let mut name_lit: Option<Lit> = None;
303    let mut value_expr: Option<Expr> = None;
304
305    for arg in args {
306        match arg {
307            NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
308                if path.is_ident("name") {
309                    name_lit = Some(lit);
310                } else {
311                    panic!("Unknown attribute parameter: {}", quote!(#path));
312                }
313            }
314            NestedMeta::Meta(Meta::Path(path)) if path.is_ident("value") => {
315                panic!("value parameter requires assignment: value = expression");
316            }
317            NestedMeta::Lit(lit) => {
318                // This handles the case where value is provided as a literal
319                value_expr = Some(syn::parse_quote!(#lit));
320            }
321            _ => panic!("Invalid attribute syntax"),
322        }
323    }
324
325    let name = match name_lit {
326        Some(Lit::Str(s)) => s.value(),
327        _ => panic!("name parameter must be a string literal"),
328    };
329
330    let value = value_expr.unwrap_or_else(|| {
331        panic!("value parameter is required");
332    });
333
334    let item = parse_macro_input!(input as syn::Item);
335
336    let register = {
337        quote! {
338            #[allow(non_upper_case_globals)]
339            runmat_builtins::inventory::submit! {
340                runmat_builtins::Constant {
341                    name: #name,
342                    value: runmat_builtins::Value::Num(#value),
343                }
344            }
345        }
346    };
347
348    TokenStream::from(quote! {
349        #item
350        #register
351    })
352}
353
354/// Smart type inference from Rust types to our enhanced Type enum
355fn infer_builtin_type(ty: &syn::Type) -> proc_macro2::TokenStream {
356    use syn::Type;
357
358    match ty {
359        // Basic primitive types
360        Type::Path(type_path) => {
361            if let Some(ident) = type_path.path.get_ident() {
362                match ident.to_string().as_str() {
363                    "i32" | "i64" | "isize" => quote! { runmat_builtins::Type::Int },
364                    "f32" | "f64" => quote! { runmat_builtins::Type::Num },
365                    "bool" => quote! { runmat_builtins::Type::Bool },
366                    "String" => quote! { runmat_builtins::Type::String },
367                    _ => infer_complex_type(type_path),
368                }
369            } else {
370                infer_complex_type(type_path)
371            }
372        }
373
374        // Reference types like &str, &Value, &Matrix
375        Type::Reference(type_ref) => match type_ref.elem.as_ref() {
376            Type::Path(type_path) => {
377                if let Some(ident) = type_path.path.get_ident() {
378                    match ident.to_string().as_str() {
379                        "str" => quote! { runmat_builtins::Type::String },
380                        _ => infer_builtin_type(&type_ref.elem),
381                    }
382                } else {
383                    infer_builtin_type(&type_ref.elem)
384                }
385            }
386            _ => infer_builtin_type(&type_ref.elem),
387        },
388
389        // Slice types like &[Value], &[f64]
390        Type::Slice(type_slice) => {
391            let element_type = infer_builtin_type(&type_slice.elem);
392            quote! { runmat_builtins::Type::Cell {
393                element_type: Some(Box::new(#element_type)),
394                length: None
395            } }
396        }
397
398        // Array types like [f64; N]
399        Type::Array(type_array) => {
400            let element_type = infer_builtin_type(&type_array.elem);
401            // Try to extract length if it's a literal
402            if let syn::Expr::Lit(expr_lit) = &type_array.len {
403                if let syn::Lit::Int(lit_int) = &expr_lit.lit {
404                    if let Ok(length) = lit_int.base10_parse::<usize>() {
405                        return quote! { runmat_builtins::Type::Cell {
406                            element_type: Some(Box::new(#element_type)),
407                            length: Some(#length)
408                        } };
409                    }
410                }
411            }
412            // Fallback to unknown length
413            quote! { runmat_builtins::Type::Cell {
414                element_type: Some(Box::new(#element_type)),
415                length: None
416            } }
417        }
418
419        // Generic or complex types
420        _ => quote! { runmat_builtins::Type::Unknown },
421    }
422}
423
424/// Infer types for complex path types like Result<T, E>, Option<T>, Matrix, Value
425fn infer_complex_type(type_path: &syn::TypePath) -> proc_macro2::TokenStream {
426    let path_str = quote! { #type_path }.to_string();
427
428    // Handle common patterns
429    if path_str.contains("Matrix") || path_str.contains("Tensor") {
430        quote! { runmat_builtins::Type::tensor() }
431    } else if path_str.contains("Value") {
432        quote! { runmat_builtins::Type::Unknown } // Value can be anything
433    } else if path_str.starts_with("Result") {
434        // Extract the Ok type from Result<T, E>
435        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
436            &type_path.path.segments.last().unwrap().arguments
437        {
438            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
439                return infer_builtin_type(ty);
440            }
441        }
442        quote! { runmat_builtins::Type::Unknown }
443    } else if path_str.starts_with("Option") {
444        // Extract the Some type from Option<T>
445        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
446            &type_path.path.segments.last().unwrap().arguments
447        {
448            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
449                return infer_builtin_type(ty);
450            }
451        }
452        quote! { runmat_builtins::Type::Unknown }
453    } else if path_str.starts_with("Vec") {
454        // Extract element type from Vec<T>
455        if let syn::PathArguments::AngleBracketed(angle_bracketed) =
456            &type_path.path.segments.last().unwrap().arguments
457        {
458            if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
459                let element_type = infer_builtin_type(ty);
460                return quote! { runmat_builtins::Type::Cell {
461                    element_type: Some(Box::new(#element_type)),
462                    length: None
463                } };
464            }
465        }
466        quote! { runmat_builtins::Type::cell() }
467    } else {
468        // Unknown type
469        quote! { runmat_builtins::Type::Unknown }
470    }
471}