use proc_macro::TokenStream;
use quote::quote;
use syn::{FnArg, ItemFn, parse_macro_input};
/// Attribute macro that turns a tensor function into a builder for a traced
/// JIT module.
///
/// The annotated function keeps its attributes, visibility, name, generics,
/// where-clause and parameter list. Its return type is replaced by
/// `FerrotorchResult<ferrotorch_jit::TracedModule<T>>`.
///
/// When called, the generated function:
/// 1. clones every typed (non-receiver) argument into a `Vec<Tensor<T>>`;
/// 2. clones those again, calling `requires_grad_(true)` on each copy;
/// 3. passes the copies to `ferrotorch_jit::trace`, together with a closure
///    that rebinds each parameter pattern to the matching slice element and
///    then evaluates the original body;
/// 4. wraps the resulting graph in `TracedModule::new`.
///
/// The scalar type `T` is taken from the declared return type (`Tensor<T>`,
/// `FerrotorchResult<Tensor<T>>` or `Result<Tensor<T>, _>`). It falls back to
/// `f32` when it cannot be determined. The macro's own arguments (`_attr`) are
/// ignored.
///
/// NOTE(review): every typed argument must `.clone()` into a `Tensor<T>`;
/// non-tensor parameters are not supported by the generated code.
#[proc_macro_attribute]
pub fn script(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as ItemFn);
    let attrs = &input.attrs;
    let vis = &input.vis;
    let sig = &input.sig;
    let ident = &sig.ident;
    let generics = &sig.generics;
    // `Generics`' token output omits the where-clause, so emit it separately.
    let where_clause = &sig.generics.where_clause;
    let block = &input.block;
    let inputs = &sig.inputs;
    let output = &sig.output;

    // Typed (non-receiver) parameter patterns, in declaration order. A
    // pattern's position here is its index into the traced input slice.
    let arg_pats: Vec<&syn::Pat> = inputs
        .iter()
        .filter_map(|arg| match arg {
            FnArg::Typed(pt) => Some(&*pt.pat),
            _ => None,
        })
        .collect();

    // Expressions that read each argument in the outer function. For `x`,
    // `mut x` and `ref x` only the bare identifier is a valid expression
    // (`mut x .clone()` would not parse). Other patterns keep the tokens
    // they produced before.
    let arg_clones: Vec<proc_macro2::TokenStream> = arg_pats
        .iter()
        .map(|pat| match pat {
            syn::Pat::Ident(pat_ident) => {
                let name = &pat_ident.ident;
                quote! { #name.clone() }
            }
            other => quote! { #other .clone() },
        })
        .collect();

    // Inside the trace closure, rebind each parameter pattern to its traced
    // counterpart, shadowing the outer (untraced) argument. The slice uses a
    // reserved name so a user parameter called e.g. `inputs` cannot shadow it.
    let arg_unpacks: Vec<proc_macro2::TokenStream> = arg_pats
        .iter()
        .enumerate()
        .map(|(i, pat)| quote! { let #pat = __script_trace_inputs[#i].clone(); })
        .collect();

    let scalar_ty = match output {
        syn::ReturnType::Type(_, ty) => extract_tensor_param(ty).unwrap_or_else(|| quote! { f32 }),
        _ => quote! { f32 },
    };
    let user_return_ty: proc_macro2::TokenStream = match output {
        syn::ReturnType::Type(_, ty) => quote! { #ty },
        _ => quote! { ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>> },
    };

    // All std items are fully qualified so user-side shadowing of `vec!` or
    // `Ok` cannot change the meaning of the generated code.
    let expanded = quote! {
        #( #attrs )*
        #vis fn #ident #generics ( #inputs ) -> ::ferrotorch_core::FerrotorchResult<
            ::ferrotorch_jit::TracedModule<#scalar_ty>
        > #where_clause {
            let __script_inputs: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
                ::std::vec![ #( #arg_clones ),* ];
            let __script_inputs_for_trace: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
                __script_inputs
                    .iter()
                    .map(|t| t.clone().requires_grad_(true))
                    .collect();
            let __graph = ::ferrotorch_jit::trace(
                |__script_trace_inputs: &[::ferrotorch_core::Tensor<#scalar_ty>]|
                    -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>>
                {
                    #( #arg_unpacks )*
                    // The body runs in its own closure so that `return` / `?`
                    // inside it target a closure whose result is typed as the
                    // user's declared return type.
                    let __script_result: #user_return_ty = (|| #block)();
                    __script_result
                },
                &__script_inputs_for_trace,
            )?;
            ::std::result::Result::Ok(::ferrotorch_jit::TracedModule::<#scalar_ty>::new(__graph))
        }
    };
    expanded.into()
}
/// Extracts the scalar type `T` from a type shaped like `Tensor<T>`,
/// `FerrotorchResult<Tensor<T>>` or `Result<Tensor<T>, E>`.
///
/// Only the last path segment's identifier is compared, so qualified paths
/// such as `::ferrotorch_core::Tensor<f64>` are recognised too. The function
/// looks through invisible groups, which appear when a type is forwarded
/// through a `macro_rules!` `$t:ty` fragment, and through parentheses.
/// Lifetime and const arguments are skipped when locating the first type
/// argument.
///
/// Returns the tokens of `T`, or `None` if the type matches none of these
/// shapes.
fn extract_tensor_param(ty: &syn::Type) -> Option<proc_macro2::TokenStream> {
    let path = match ty {
        syn::Type::Path(p) => &p.path,
        // Transparent wrappers: recurse into the wrapped type.
        syn::Type::Group(g) => return extract_tensor_param(&g.elem),
        syn::Type::Paren(p) => return extract_tensor_param(&p.elem),
        _ => return None,
    };
    let last = path.segments.last()?;
    let args = match &last.arguments {
        syn::PathArguments::AngleBracketed(a) => a,
        _ => return None,
    };
    // First *type* argument, e.g. `f32` in `Tensor<'a, f32>`.
    let first_ty = args.args.iter().find_map(|arg| match arg {
        syn::GenericArgument::Type(t) => Some(t),
        _ => None,
    })?;
    if last.ident == "Tensor" {
        Some(quote! { #first_ty })
    } else if last.ident == "FerrotorchResult" || last.ident == "Result" {
        extract_tensor_param(first_ty)
    } else {
        None
    }
}