#![warn(clippy::all, clippy::pedantic)]
#![deny(rust_2018_idioms, missing_debug_implementations)]
#![allow(missing_docs)]
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{FnArg, ItemFn};
#[proc_macro_attribute]
pub fn script(attr: TokenStream, item: TokenStream) -> TokenStream {
match script_impl(attr, item) {
Ok(ts) => ts.into(),
Err(err) => err.to_compile_error().into(),
}
}
fn script_impl(_attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream2> {
let input: ItemFn = syn::parse(item)?;
let vis = &input.vis;
let sig = &input.sig;
let ident = &sig.ident;
let block = &input.block;
let inputs = &sig.inputs;
let output = &sig.output;
let typed_args: Vec<&syn::PatType> = inputs
.iter()
.filter_map(|a| match a {
FnArg::Typed(pt) => Some(pt),
FnArg::Receiver(_) => None,
})
.collect();
let arg_clones: Vec<TokenStream2> = typed_args
.iter()
.map(|pt| {
let pat = &pt.pat;
quote! { #pat .clone() }
})
.collect();
let arg_unpacks: Vec<TokenStream2> = typed_args
.iter()
.enumerate()
.map(|(i, pt)| {
let pat = &pt.pat;
quote! { let #pat = inputs[#i].clone(); }
})
.collect();
let scalar_ty = match output {
syn::ReturnType::Type(_, ty) => extract_tensor_param(ty.as_ref()).ok_or_else(|| {
syn::Error::new_spanned(
ty,
"ferrotorch-jit-script: function must return Tensor<T>, \
FerrotorchResult<Tensor<T>>, or Result<Tensor<T>, _>",
)
})?,
syn::ReturnType::Default => {
return Err(syn::Error::new_spanned(
sig,
"ferrotorch-jit-script: function must declare a return type \
of Tensor<T>, FerrotorchResult<Tensor<T>>, or Result<Tensor<T>, _>",
));
}
};
let user_return_ty: TokenStream2 = match output {
syn::ReturnType::Type(_, ty) => quote! { #ty },
syn::ReturnType::Default => {
quote! { ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>> }
}
};
let expanded = quote! {
#vis fn #ident ( #inputs ) -> ::ferrotorch_core::FerrotorchResult<
::ferrotorch_jit::TracedModule<#scalar_ty>
> {
let __script_inputs: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
vec![ #( #arg_clones ),* ];
let __script_inputs_for_trace: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
__script_inputs
.iter()
.map(|t| t.clone().requires_grad_(true))
.collect();
let __graph = ::ferrotorch_jit::trace(
|inputs: &[::ferrotorch_core::Tensor<#scalar_ty>]|
-> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>>
{
#( #arg_unpacks )*
let __script_result: #user_return_ty = (|| #block)();
__script_result
},
&__script_inputs_for_trace,
)?;
Ok(::ferrotorch_jit::TracedModule::<#scalar_ty>::new(__graph))
}
};
Ok(expanded)
}
/// Extracts the scalar type `T` from a declared return type such as `Tensor<T>`,
/// `FerrotorchResult<Tensor<T>>` or `Result<Tensor<T>, E>`.
///
/// Returns `None` when the type does not have one of the recognised shapes. The search
/// starts at nesting depth zero.
fn extract_tensor_param(ty: &syn::Type) -> Option<TokenStream2> {
    let start_depth: u8 = 0;
    extract_tensor_param_inner(ty, start_depth)
}
/// Maximum number of `Result`/`FerrotorchResult` layers `extract_tensor_param_inner`
/// unwraps before giving up.
///
/// One layer matches exactly the forms the macro documents and generates code for:
/// `Tensor<T>`, `FerrotorchResult<Tensor<T>>` and `Result<Tensor<T>, _>`. Deeper nesting
/// such as `Result<Result<Tensor<T>, E>, E>` can never type-check against the generated
/// trace closure. Rejecting it here yields the macro's clear diagnostic instead of an
/// obscure type error in expanded code.
const MAX_RETURN_TYPE_DEPTH: u8 = 1;
/// Recursive worker for `extract_tensor_param`.
///
/// `depth` counts how many `Result`/`FerrotorchResult` layers have been unwrapped so far.
/// Once it exceeds `MAX_RETURN_TYPE_DEPTH`, the function returns `None`. Invisible
/// `Type::Group` wrappers (produced when the function comes from a `macro_rules!` `$t:ty`
/// fragment) and parentheses are looked through without consuming depth.
///
/// Only the last path segment is inspected, so `Tensor<T>`, `ferrotorch_core::Tensor<T>`
/// and any other path ending in `Tensor<..>` all match. The first generic argument must
/// be a type; it is returned as tokens for `Tensor`, or recursed into for `Result` and
/// `FerrotorchResult`.
fn extract_tensor_param_inner(ty: &syn::Type, depth: u8) -> Option<TokenStream2> {
    if depth > MAX_RETURN_TYPE_DEPTH {
        return None;
    }
    let type_path = match ty {
        syn::Type::Group(group) => return extract_tensor_param_inner(&group.elem, depth),
        syn::Type::Paren(paren) => return extract_tensor_param_inner(&paren.elem, depth),
        syn::Type::Path(type_path) => type_path,
        _ => return None,
    };
    let last = type_path.path.segments.last()?;
    let syn::PathArguments::AngleBracketed(args) = &last.arguments else {
        return None;
    };
    // Both `Tensor<T>` and the `Result` wrappers carry the interesting type first.
    let Some(syn::GenericArgument::Type(first)) = args.args.first() else {
        return None;
    };
    if last.ident == "Tensor" {
        Some(quote! { #first })
    } else if last.ident == "FerrotorchResult" || last.ident == "Result" {
        extract_tensor_param_inner(first, depth + 1)
    } else {
        None
    }
}