ferrotorch_jit_script/lib.rs
//! `#[ferrotorch_jit_script::script]` — declarative graph construction. (#625)
//!
//! Annotate a Rust function with `#[script]` and the macro turns it into a
//! graph builder. Instead of returning the body's tensor, the rewritten
//! function runs the body once on (clones of) its arguments under
//! `ferrotorch_jit::trace` and returns the captured `IrGraph` wrapped in a
//! `ferrotorch_jit::TracedModule`, i.e. `FerrotorchResult<TracedModule<T>>`.
//!
//! # What's supported in the body
//!
//! The body itself is not rewritten. It is emitted verbatim inside a closure
//! and executed once by `trace`, so the graph contains only what `trace` can
//! record. The intended subset is:
//!
//! - `let x = …;` bindings (the body is ordinary Rust, so the usual scoping
//!   and shadowing rules apply)
//! - Function calls on tensors: `add(&a, &b)`, `mul(&a, &b)`, `sum(&t)`,
//!   `mean(&t)`, `relu(&t)`, `sigmoid(&t)`, `tanh(&t)`, `mm(&a, &b)`, ...
//!   (anything in `ferrotorch_core::grad_fns::{arithmetic, reduction,
//!   activation, linalg}` that takes `&Tensor` args and returns `Tensor`)
//! - A trailing expression as the function's return value
//!
//! Anything else (control flow, struct fields, non-tensor values) is not
//! rejected by the macro. It runs as normal Rust during tracing, and `trace`
//! captures whatever it does at the autograd level.
//!
//! # Why this is a shim over `trace`
//!
//! The macro doesn't reimplement the IR builder. Instead, it hands `trace`
//! clones of the function's inputs marked with `requires_grad = true` (so
//! the autograd graph is built), and lets the existing `trace` capture the
//! IR. That keeps op coverage in lockstep with `trace` and avoids a second
//! source of truth for the op-name → IR mapping.
30
31use proc_macro::TokenStream;
32use quote::quote;
33use syn::{FnArg, ItemFn, parse_macro_input};
34
35/// Apply `#[script]` to a function `fn(args...) -> Tensor<T>` to compile
36/// it into a `ferrotorch_jit::TracedModule<T>`-returning function.
37///
38/// Example:
39///
40/// ```ignore
41/// use ferrotorch_jit_script::script;
42/// use ferrotorch_core::Tensor;
43/// use ferrotorch_core::grad_fns::arithmetic::{mul, add};
44/// use ferrotorch_core::grad_fns::reduction::sum;
45///
46/// #[script]
47/// fn weighted_sum(a: Tensor<f32>, w: Tensor<f32>) -> Tensor<f32> {
48/// let prod = mul(&a, &w)?;
49/// sum(&prod)
50/// }
51///
52/// // `weighted_sum(...)` now returns `FerrotorchResult<TracedModule<f32>>`
53/// // built by tracing the body once with the supplied tensors.
54/// ```
55#[proc_macro_attribute]
56pub fn script(_attr: TokenStream, item: TokenStream) -> TokenStream {
57 let input = parse_macro_input!(item as ItemFn);
58 let vis = &input.vis;
59 let sig = &input.sig;
60 let ident = &sig.ident;
61 let block = &input.block;
62 let inputs = &sig.inputs;
63 let output = &sig.output;
64
65 // Collect simple `name: Tensor<T>` arg idents so we can pass them
66 // to `trace` as the example-input slice.
67 let mut arg_idents: Vec<proc_macro2::TokenStream> = Vec::new();
68 for arg in inputs {
69 if let FnArg::Typed(pat_ty) = arg {
70 let pat = &pat_ty.pat;
71 arg_idents.push(quote! { #pat .clone() });
72 }
73 }
74
75 let body_fn_inputs = inputs.clone();
76
77 // Generated wrapper:
78 // - Defines an inner closure with the user's body (so all the let
79 // bindings and ops just compile as normal Rust).
80 // - Calls trace(closure, &[args...]) to capture the IR graph.
81 // - Wraps the graph in a TracedModule and returns it.
82 //
83 // The closure takes `&[Tensor<T>]` (matching trace's signature) and
84 // unpacks them by index into the same arg names the user wrote.
85 let arg_names: Vec<&syn::Pat> = inputs
86 .iter()
87 .filter_map(|a| match a {
88 FnArg::Typed(pt) => Some(&*pt.pat),
89 _ => None,
90 })
91 .collect();
92 let arg_unpacks: Vec<proc_macro2::TokenStream> = arg_names
93 .iter()
94 .enumerate()
95 .map(|(i, name)| quote! { let #name = inputs[#i].clone(); })
96 .collect();
97
98 // Determine T from the return type Tensor<T>: walk the syn Type.
99 // For simplicity we pin T = f32 in the generated code unless the
100 // return is Tensor<f64>. Most callers use f32; f64 falls back to
101 // explicit annotation.
102 let scalar_ty = match output {
103 syn::ReturnType::Type(_, ty) => extract_tensor_param(ty).unwrap_or_else(|| quote! { f32 }),
104 _ => quote! { f32 },
105 };
106
107 // Capture the user's return type tokens so the expansion mentions
108 // any names they imported (e.g. `FerrotorchResult`) — otherwise
109 // those imports look unused after macro expansion.
110 let user_return_ty: proc_macro2::TokenStream = match output {
111 syn::ReturnType::Type(_, ty) => quote! { #ty },
112 _ => quote! { ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>> },
113 };
114
115 let expanded = quote! {
116 #vis fn #ident ( #body_fn_inputs ) -> ::ferrotorch_core::FerrotorchResult<
117 ::ferrotorch_jit::TracedModule<#scalar_ty>
118 > {
119 let __script_inputs: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
120 vec![ #( #arg_idents ),* ];
121 let __script_inputs_for_trace: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
122 __script_inputs
123 .iter()
124 .map(|t| t.clone().requires_grad_(true))
125 .collect();
126 let __graph = ::ferrotorch_jit::trace(
127 |inputs: &[::ferrotorch_core::Tensor<#scalar_ty>]|
128 -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>>
129 {
130 #( #arg_unpacks )*
131 let __script_result: #user_return_ty = (|| #block)();
132 __script_result
133 },
134 &__script_inputs_for_trace,
135 )?;
136 Ok(::ferrotorch_jit::TracedModule::<#scalar_ty>::new(__graph))
137 }
138 };
139 expanded.into()
140}
141
142/// Extract the `T` from `Tensor<T>` (or `FerrotorchResult<Tensor<T>>`,
143/// or `Result<Tensor<T>, _>`). Returns `None` if the shape doesn't match.
144fn extract_tensor_param(ty: &syn::Type) -> Option<proc_macro2::TokenStream> {
145 let path = if let syn::Type::Path(p) = ty {
146 &p.path
147 } else {
148 return None;
149 };
150 let last = path.segments.last()?;
151 let ident_str = last.ident.to_string();
152 let args = match &last.arguments {
153 syn::PathArguments::AngleBracketed(a) => a,
154 _ => return None,
155 };
156 if ident_str == "Tensor" {
157 // First generic arg is the scalar type.
158 if let Some(syn::GenericArgument::Type(t)) = args.args.first() {
159 let ts = quote! { #t };
160 return Some(ts);
161 }
162 }
163 if ident_str == "FerrotorchResult" || ident_str == "Result" {
164 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
165 return extract_tensor_param(inner);
166 }
167 }
168 None
169}