1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse_macro_input, AttributeArgs, Expr, FnArg, ItemFn, Lit, Meta, MetaNameValue, NestedMeta,
5 Pat,
6};
7
8#[proc_macro_attribute]
23pub fn runtime_builtin(args: TokenStream, input: TokenStream) -> TokenStream {
24 let args = parse_macro_input!(args as AttributeArgs);
26 let mut name_lit: Option<Lit> = None;
27 let mut category_lit: Option<Lit> = None;
28 let mut summary_lit: Option<Lit> = None;
29 let mut keywords_lit: Option<Lit> = None;
30 let mut errors_lit: Option<Lit> = None;
31 let mut related_lit: Option<Lit> = None;
32 let mut introduced_lit: Option<Lit> = None;
33 let mut status_lit: Option<Lit> = None;
34 let mut examples_lit: Option<Lit> = None;
35 let mut accel_values: Vec<String> = Vec::new();
36 let mut sink_flag = false;
37 for arg in args {
38 if let NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) = arg {
39 if path.is_ident("name") {
40 name_lit = Some(lit);
41 } else if path.is_ident("category") {
42 category_lit = Some(lit);
43 } else if path.is_ident("summary") {
44 summary_lit = Some(lit);
45 } else if path.is_ident("keywords") {
46 keywords_lit = Some(lit);
47 } else if path.is_ident("errors") {
48 errors_lit = Some(lit);
49 } else if path.is_ident("related") {
50 related_lit = Some(lit);
51 } else if path.is_ident("introduced") {
52 introduced_lit = Some(lit);
53 } else if path.is_ident("status") {
54 status_lit = Some(lit);
55 } else if path.is_ident("examples") {
56 examples_lit = Some(lit);
57 } else if path.is_ident("accel") {
58 if let Lit::Str(ls) = lit {
59 accel_values.extend(
60 ls.value()
61 .split(|c: char| c == ',' || c == '|' || c.is_ascii_whitespace())
62 .filter(|s| !s.is_empty())
63 .map(|s| s.to_ascii_lowercase()),
64 );
65 }
66 } else if path.is_ident("sink") {
67 if let Lit::Bool(lb) = lit {
68 sink_flag = lb.value;
69 }
70 } else {
71 }
73 }
74 }
75 let name_lit = name_lit.expect("expected `name = \"...\"` argument");
76 let name_str = if let Lit::Str(ref s) = name_lit {
77 s.value()
78 } else {
79 panic!("name must be a string literal");
80 };
81
82 let func: ItemFn = parse_macro_input!(input as ItemFn);
83 let ident = &func.sig.ident;
84
85 let mut param_idents = Vec::new();
87 let mut param_types = Vec::new();
88 for arg in &func.sig.inputs {
89 match arg {
90 FnArg::Typed(pt) => {
91 if let Pat::Ident(pi) = pt.pat.as_ref() {
93 param_idents.push(pi.ident.clone());
94 } else {
95 panic!("parameters must be simple identifiers");
96 }
97 param_types.push((*pt.ty).clone());
98 }
99 _ => panic!("self parameter not allowed"),
100 }
101 }
102 let param_len = param_idents.len();
103
104 let inferred_param_types: Vec<proc_macro2::TokenStream> =
106 param_types.iter().map(infer_builtin_type).collect();
107
108 let inferred_return_type = match &func.sig.output {
110 syn::ReturnType::Default => quote! { runmat_builtins::Type::Void },
111 syn::ReturnType::Type(_, ty) => infer_builtin_type(ty),
112 };
113
114 let is_last_variadic = param_types
116 .last()
117 .map(|ty| {
118 if let syn::Type::Path(tp) = ty {
120 if tp
121 .path
122 .segments
123 .last()
124 .map(|s| s.ident == "Vec")
125 .unwrap_or(false)
126 {
127 if let syn::PathArguments::AngleBracketed(ab) =
128 &tp.path.segments.last().unwrap().arguments
129 {
130 if let Some(syn::GenericArgument::Type(syn::Type::Path(inner))) =
131 ab.args.first()
132 {
133 return inner
134 .path
135 .segments
136 .last()
137 .map(|s| s.ident == "Value")
138 .unwrap_or(false);
139 }
140 }
141 }
142 }
143 false
144 })
145 .unwrap_or(false);
146
147 let wrapper_ident = format_ident!("__rt_wrap_{}", ident);
149
150 let conv_stmts: Vec<proc_macro2::TokenStream> = if is_last_variadic && param_len > 0 {
151 let mut stmts = Vec::new();
152 for (i, (ident, ty)) in param_idents
154 .iter()
155 .zip(param_types.iter())
156 .enumerate()
157 .take(param_len - 1)
158 {
159 stmts.push(quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; });
160 }
161 let last_ident = ¶m_idents[param_len - 1];
163 stmts.push(quote! {
164 let #last_ident : Vec<runmat_builtins::Value> = {
165 let mut v = Vec::new();
166 for j in (#param_len-1)..args.len() {
167 let item : runmat_builtins::Value = std::convert::TryInto::try_into(&args[j])?;
168 v.push(item);
169 }
170 v
171 };
172 });
173 stmts
174 } else {
175 param_idents
176 .iter()
177 .zip(param_types.iter())
178 .enumerate()
179 .map(|(i, (ident, ty))| {
180 quote! { let #ident : #ty = std::convert::TryInto::try_into(&args[#i])?; }
181 })
182 .collect()
183 };
184
185 let wrapper = quote! {
186 fn #wrapper_ident(args: &[runmat_builtins::Value]) -> Result<runmat_builtins::Value, String> {
187 #![allow(unused_variables)]
188 if #is_last_variadic {
189 if args.len() < #param_len - 1 { return Err(format!("expected at least {} args, got {}", #param_len - 1, args.len())); }
190 } else {
191 if args.len() != #param_len { return Err(format!("expected {} args, got {}", #param_len, args.len())); }
192 }
193 #(#conv_stmts)*
194 let res = #ident(#(#param_idents),*)?;
195 Ok(runmat_builtins::Value::from(res))
196 }
197 };
198
199 let default_category = syn::LitStr::new("general", proc_macro2::Span::call_site());
201 let default_summary =
202 syn::LitStr::new("Runtime builtin function", proc_macro2::Span::call_site());
203
204 let category_tok: proc_macro2::TokenStream = match &category_lit {
205 Some(syn::Lit::Str(ls)) => quote! { #ls },
206 _ => quote! { #default_category },
207 };
208 let summary_tok: proc_macro2::TokenStream = match &summary_lit {
209 Some(syn::Lit::Str(ls)) => quote! { #ls },
210 _ => quote! { #default_summary },
211 };
212
213 fn opt_tok(lit: &Option<syn::Lit>) -> proc_macro2::TokenStream {
214 if let Some(syn::Lit::Str(ls)) = lit {
215 quote! { Some(#ls) }
216 } else {
217 quote! { None }
218 }
219 }
220 let category_opt_tok = opt_tok(&category_lit);
221 let summary_opt_tok = opt_tok(&summary_lit);
222 let keywords_opt_tok = opt_tok(&keywords_lit);
223 let errors_opt_tok = opt_tok(&errors_lit);
224 let related_opt_tok = opt_tok(&related_lit);
225 let introduced_opt_tok = opt_tok(&introduced_lit);
226 let status_opt_tok = opt_tok(&status_lit);
227 let examples_opt_tok = opt_tok(&examples_lit);
228
229 let accel_tokens: Vec<proc_macro2::TokenStream> = accel_values
230 .iter()
231 .map(|mode| match mode.as_str() {
232 "unary" => quote! { runmat_builtins::AccelTag::Unary },
233 "elementwise" => quote! { runmat_builtins::AccelTag::Elementwise },
234 "reduction" => quote! { runmat_builtins::AccelTag::Reduction },
235 "matmul" => quote! { runmat_builtins::AccelTag::MatMul },
236 "transpose" => quote! { runmat_builtins::AccelTag::Transpose },
237 "array_construct" => quote! { runmat_builtins::AccelTag::ArrayConstruct },
238 _ => quote! {},
239 })
240 .filter(|ts| !ts.is_empty())
241 .collect();
242 let accel_slice = if accel_tokens.is_empty() {
243 quote! { &[] as &[runmat_builtins::AccelTag] }
244 } else {
245 quote! { &[#(#accel_tokens),*] }
246 };
247 let sink_bool = sink_flag;
248
249 let register = quote! {
250 runmat_builtins::inventory::submit! {
251 runmat_builtins::BuiltinFunction::new(
252 #name_str,
253 #summary_tok,
254 #category_tok,
255 "",
256 "",
257 vec![#(#inferred_param_types),*],
258 #inferred_return_type,
259 #wrapper_ident,
260 #accel_slice,
261 #sink_bool,
262 )
263 }
264 runmat_builtins::inventory::submit! {
265 runmat_builtins::BuiltinDoc {
266 name: #name_str,
267 category: #category_opt_tok,
268 summary: #summary_opt_tok,
269 keywords: #keywords_opt_tok,
270 errors: #errors_opt_tok,
271 related: #related_opt_tok,
272 introduced: #introduced_opt_tok,
273 status: #status_opt_tok,
274 examples: #examples_opt_tok,
275 }
276 }
277 };
278
279 TokenStream::from(quote! {
280 #func
281 #wrapper
282 #register
283 })
284}
285
286#[proc_macro_attribute]
300pub fn runtime_constant(args: TokenStream, input: TokenStream) -> TokenStream {
301 let args = parse_macro_input!(args as AttributeArgs);
302 let mut name_lit: Option<Lit> = None;
303 let mut value_expr: Option<Expr> = None;
304
305 for arg in args {
306 match arg {
307 NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
308 if path.is_ident("name") {
309 name_lit = Some(lit);
310 } else {
311 panic!("Unknown attribute parameter: {}", quote!(#path));
312 }
313 }
314 NestedMeta::Meta(Meta::Path(path)) if path.is_ident("value") => {
315 panic!("value parameter requires assignment: value = expression");
316 }
317 NestedMeta::Lit(lit) => {
318 value_expr = Some(syn::parse_quote!(#lit));
320 }
321 _ => panic!("Invalid attribute syntax"),
322 }
323 }
324
325 let name = match name_lit {
326 Some(Lit::Str(s)) => s.value(),
327 _ => panic!("name parameter must be a string literal"),
328 };
329
330 let value = value_expr.unwrap_or_else(|| {
331 panic!("value parameter is required");
332 });
333
334 let item = parse_macro_input!(input as syn::Item);
335
336 let register = {
337 quote! {
338 #[allow(non_upper_case_globals)]
339 runmat_builtins::inventory::submit! {
340 runmat_builtins::Constant {
341 name: #name,
342 value: runmat_builtins::Value::Num(#value),
343 }
344 }
345 }
346 };
347
348 TokenStream::from(quote! {
349 #item
350 #register
351 })
352}
353
354fn infer_builtin_type(ty: &syn::Type) -> proc_macro2::TokenStream {
356 use syn::Type;
357
358 match ty {
359 Type::Path(type_path) => {
361 if let Some(ident) = type_path.path.get_ident() {
362 match ident.to_string().as_str() {
363 "i32" | "i64" | "isize" => quote! { runmat_builtins::Type::Int },
364 "f32" | "f64" => quote! { runmat_builtins::Type::Num },
365 "bool" => quote! { runmat_builtins::Type::Bool },
366 "String" => quote! { runmat_builtins::Type::String },
367 _ => infer_complex_type(type_path),
368 }
369 } else {
370 infer_complex_type(type_path)
371 }
372 }
373
374 Type::Reference(type_ref) => match type_ref.elem.as_ref() {
376 Type::Path(type_path) => {
377 if let Some(ident) = type_path.path.get_ident() {
378 match ident.to_string().as_str() {
379 "str" => quote! { runmat_builtins::Type::String },
380 _ => infer_builtin_type(&type_ref.elem),
381 }
382 } else {
383 infer_builtin_type(&type_ref.elem)
384 }
385 }
386 _ => infer_builtin_type(&type_ref.elem),
387 },
388
389 Type::Slice(type_slice) => {
391 let element_type = infer_builtin_type(&type_slice.elem);
392 quote! { runmat_builtins::Type::Cell {
393 element_type: Some(Box::new(#element_type)),
394 length: None
395 } }
396 }
397
398 Type::Array(type_array) => {
400 let element_type = infer_builtin_type(&type_array.elem);
401 if let syn::Expr::Lit(expr_lit) = &type_array.len {
403 if let syn::Lit::Int(lit_int) = &expr_lit.lit {
404 if let Ok(length) = lit_int.base10_parse::<usize>() {
405 return quote! { runmat_builtins::Type::Cell {
406 element_type: Some(Box::new(#element_type)),
407 length: Some(#length)
408 } };
409 }
410 }
411 }
412 quote! { runmat_builtins::Type::Cell {
414 element_type: Some(Box::new(#element_type)),
415 length: None
416 } }
417 }
418
419 _ => quote! { runmat_builtins::Type::Unknown },
421 }
422}
423
424fn infer_complex_type(type_path: &syn::TypePath) -> proc_macro2::TokenStream {
426 let path_str = quote! { #type_path }.to_string();
427
428 if path_str.contains("Matrix") || path_str.contains("Tensor") {
430 quote! { runmat_builtins::Type::tensor() }
431 } else if path_str.contains("Value") {
432 quote! { runmat_builtins::Type::Unknown } } else if path_str.starts_with("Result") {
434 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
436 &type_path.path.segments.last().unwrap().arguments
437 {
438 if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
439 return infer_builtin_type(ty);
440 }
441 }
442 quote! { runmat_builtins::Type::Unknown }
443 } else if path_str.starts_with("Option") {
444 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
446 &type_path.path.segments.last().unwrap().arguments
447 {
448 if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
449 return infer_builtin_type(ty);
450 }
451 }
452 quote! { runmat_builtins::Type::Unknown }
453 } else if path_str.starts_with("Vec") {
454 if let syn::PathArguments::AngleBracketed(angle_bracketed) =
456 &type_path.path.segments.last().unwrap().arguments
457 {
458 if let Some(syn::GenericArgument::Type(ty)) = angle_bracketed.args.first() {
459 let element_type = infer_builtin_type(ty);
460 return quote! { runmat_builtins::Type::Cell {
461 element_type: Some(Box::new(#element_type)),
462 length: None
463 } };
464 }
465 }
466 quote! { runmat_builtins::Type::cell() }
467 } else {
468 quote! { runmat_builtins::Type::Unknown }
470 }
471}