ferrotorch_jit_script/lib.rs
1//! `#[ferrotorch_jit_script::script]` — declarative graph construction. (#625)
2//!
//! Annotate a Rust function with `#[script]` and the macro replaces it with a
//! wrapper that takes the same parameters but returns
//! `FerrotorchResult<TracedModule<T>>`. Calling the wrapper runs the original
//! body once under `ferrotorch_jit::trace` with the supplied tensors and
//! wraps the captured graph in a `TracedModule`.
7//!
8//! # What's supported in the body
9//!
10//! The script body is restricted to a small recognized subset:
11//!
12//! - `let x = …;` bindings (no shadowing across statements is required)
13//! - Function calls on tensors: `add(&a, &b)`, `mul(&a, &b)`, `sum(&t)`,
14//! `mean(&t)`, `relu(&t)`, `sigmoid(&t)`, `tanh(&t)`, `mm(&a, &b)` ...
15//! (anything in `ferrotorch_core::grad_fns::{arithmetic, reduction,
16//! activation, linalg}` that takes `&Tensor` args and returns `Tensor`)
17//! - A trailing expression as the function's return value
18//!
//! Anything else (control flow, struct fields, non-tensor values) is left
//! untouched — the macro emits the body verbatim inside the tracing closure,
//! and `trace` captures whatever it executes at the autograd level.
22//!
23//! # Why this is a shim over `trace`
24//!
25//! The macro doesn't reimplement the IR builder. Instead, it ensures the
26//! function's inputs are leaf tensors with `requires_grad=true` (so the
27//! autograd graph is built), then calls the existing `trace` to capture
28//! the IR. That keeps op coverage in lockstep with trace and avoids a
29//! second source of truth for op-name → IR mapping.
30
31#![warn(clippy::all, clippy::pedantic)]
32#![deny(rust_2018_idioms, missing_debug_implementations)]
33#![allow(missing_docs)] // tracked workspace-wide in the rustdoc pass
34
35use proc_macro::TokenStream;
36use proc_macro2::TokenStream as TokenStream2;
37use quote::quote;
38use syn::{FnArg, ItemFn};
39
40/// Apply `#[script]` to a function `fn(args...) -> Tensor<T>` to compile
41/// it into a `ferrotorch_jit::TracedModule<T>`-returning function.
42///
43/// The example below is marked `ignore` because this is a `proc-macro`
44/// crate: it cannot itself import `ferrotorch_jit::TracedModule` or
45/// `ferrotorch_core::Tensor` at doctest-compile time (proc-macro crates
46/// can only export proc-macro items and pull procedural-macro deps; they
47/// cannot depend on consumer crates). The example is exercised
48/// end-to-end by the integration tests in
49/// `ferrotorch-jit-script/tests/script_macro.rs`.
50///
51/// ```ignore
52/// use ferrotorch_jit_script::script;
53/// use ferrotorch_core::Tensor;
54/// use ferrotorch_core::grad_fns::arithmetic::{mul, add};
55/// use ferrotorch_core::grad_fns::reduction::sum;
56///
57/// #[script]
58/// fn weighted_sum(a: Tensor<f32>, w: Tensor<f32>) -> Tensor<f32> {
59/// let prod = mul(&a, &w)?;
60/// sum(&prod)
61/// }
62///
63/// // `weighted_sum(...)` now returns `FerrotorchResult<TracedModule<f32>>`
64/// // built by tracing the body once with the supplied tensors.
65/// ```
66///
67/// # Errors
68///
69/// Emits a `compile_error!` (via `syn::Error`) if the annotated function's
70/// return type isn't one of the recognized shapes:
71///
72/// - `Tensor<T>`
73/// - `FerrotorchResult<Tensor<T>>`
74/// - `Result<Tensor<T>, _>`
75///
76/// Previously, an unrecognized return type silently fell back to
77/// `TracedModule<f32>`, producing a wrong-dtype wrapper for e.g.
78/// `Tensor<f64>` callers. The macro now refuses the input instead.
79#[proc_macro_attribute]
80pub fn script(attr: TokenStream, item: TokenStream) -> TokenStream {
81 match script_impl(attr, item) {
82 Ok(ts) => ts.into(),
83 Err(err) => err.to_compile_error().into(),
84 }
85}
86
87fn script_impl(_attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream2> {
88 let input: ItemFn = syn::parse(item)?;
89 let vis = &input.vis;
90 let sig = &input.sig;
91 let ident = &sig.ident;
92 let block = &input.block;
93 let inputs = &sig.inputs;
94 let output = &sig.output;
95
96 // Walk the function's typed args once. `arg_clones` is the
97 // example-input slice handed to `trace`; `arg_unpacks` are the
98 // index-based bindings inside the closure body. Both projections
99 // share the same filter (skip `Self` receivers) and ordering, so a
100 // single pass keeps them in lockstep.
101 let typed_args: Vec<&syn::PatType> = inputs
102 .iter()
103 .filter_map(|a| match a {
104 FnArg::Typed(pt) => Some(pt),
105 FnArg::Receiver(_) => None,
106 })
107 .collect();
108 let arg_clones: Vec<TokenStream2> = typed_args
109 .iter()
110 .map(|pt| {
111 let pat = &pt.pat;
112 quote! { #pat .clone() }
113 })
114 .collect();
115 let arg_unpacks: Vec<TokenStream2> = typed_args
116 .iter()
117 .enumerate()
118 .map(|(i, pt)| {
119 let pat = &pt.pat;
120 quote! { let #pat = inputs[#i].clone(); }
121 })
122 .collect();
123
124 // Determine T from the return type. Unrecognized return shapes are
125 // a hard error: silently defaulting to f32 (the historic behaviour)
126 // produced a `TracedModule<f32>` wrapper for callers that returned
127 // e.g. `Tensor<f64>`, with no diagnostic. Failing here surfaces the
128 // mistake at macro-expansion time as a clean `compile_error!`.
129 let scalar_ty = match output {
130 syn::ReturnType::Type(_, ty) => extract_tensor_param(ty.as_ref()).ok_or_else(|| {
131 syn::Error::new_spanned(
132 ty,
133 "ferrotorch-jit-script: function must return Tensor<T>, \
134 FerrotorchResult<Tensor<T>>, or Result<Tensor<T>, _>",
135 )
136 })?,
137 syn::ReturnType::Default => {
138 return Err(syn::Error::new_spanned(
139 sig,
140 "ferrotorch-jit-script: function must declare a return type \
141 of Tensor<T>, FerrotorchResult<Tensor<T>>, or Result<Tensor<T>, _>",
142 ));
143 }
144 };
145
146 // Capture the user's return type tokens so the expansion mentions
147 // any names they imported (e.g. `FerrotorchResult`) — otherwise
148 // those imports look unused after macro expansion.
149 let user_return_ty: TokenStream2 = match output {
150 syn::ReturnType::Type(_, ty) => quote! { #ty },
151 syn::ReturnType::Default => {
152 quote! { ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>> }
153 }
154 };
155
156 // Generated wrapper:
157 // - Builds the example-input slice from the caller's args.
158 // - Defines an inner closure with the user's body (so all the let
159 // bindings and ops just compile as normal Rust).
160 // - Calls trace(closure, &[args...]) to capture the IR graph.
161 // - Wraps the graph in a TracedModule and returns it.
162 let expanded = quote! {
163 #vis fn #ident ( #inputs ) -> ::ferrotorch_core::FerrotorchResult<
164 ::ferrotorch_jit::TracedModule<#scalar_ty>
165 > {
166 let __script_inputs: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
167 vec![ #( #arg_clones ),* ];
168 let __script_inputs_for_trace: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
169 __script_inputs
170 .iter()
171 .map(|t| t.clone().requires_grad_(true))
172 .collect();
173 let __graph = ::ferrotorch_jit::trace(
174 |inputs: &[::ferrotorch_core::Tensor<#scalar_ty>]|
175 -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>>
176 {
177 #( #arg_unpacks )*
178 let __script_result: #user_return_ty = (|| #block)();
179 __script_result
180 },
181 &__script_inputs_for_trace,
182 )?;
183 Ok(::ferrotorch_jit::TracedModule::<#scalar_ty>::new(__graph))
184 }
185 };
186 Ok(expanded)
187}
188
189/// Recognized return-type entry point for the `#[script]` macro.
190///
191/// Returns the dtype `TokenStream` extracted from the annotated function's
192/// return type. The recognized shapes are:
193///
194/// - `Tensor<T>` — yields `T`
195/// - `FerrotorchResult<Tensor<T>>` — yields `T` (recurses through the
196/// `FerrotorchResult` wrapper)
197/// - `Result<Tensor<T>, _>` — yields `T` (recurses through the `Result`
198/// wrapper, ignoring the error type)
199///
200/// Returns `None` when the type doesn't match any of those shapes; the
201/// caller turns that into a `compile_error!` diagnostic.
202///
203/// Recursion through `Result` / `FerrotorchResult` wrappers is bounded
204/// (see `extract_tensor_param_inner`'s depth cap) so a malformed input
205/// like `Result<Result<Result<...>, _>, _>` cannot blow the macro's
206/// stack — it returns `None` once the cap is reached, falling through
207/// to the same `compile_error!` path as any other unrecognized shape.
208fn extract_tensor_param(ty: &syn::Type) -> Option<TokenStream2> {
209 extract_tensor_param_inner(ty, 0)
210}
211
212/// Maximum nesting depth for the recognized `Result<...>` /
213/// `FerrotorchResult<...>` wrappers around `Tensor<T>`.
214///
215/// `Result<Result<FerrotorchResult<Tensor<T>>, _>, _>` is depth 3, which
216/// is already pathological; anything deeper is malformed. The cap exists
217/// to make the recursion total — a hand-crafted, syntactically valid but
218/// nonsense input (an arbitrarily nested `Result<...>`) cannot blow the
219/// macro's stack during expansion.
220const MAX_RETURN_TYPE_DEPTH: u8 = 4;
221
222fn extract_tensor_param_inner(ty: &syn::Type, depth: u8) -> Option<TokenStream2> {
223 if depth > MAX_RETURN_TYPE_DEPTH {
224 return None;
225 }
226 let syn::Type::Path(p) = ty else {
227 return None;
228 };
229 let path = &p.path;
230 let last = path.segments.last()?;
231 let ident_str = last.ident.to_string();
232 let syn::PathArguments::AngleBracketed(args) = &last.arguments else {
233 return None;
234 };
235 if ident_str == "Tensor" {
236 // First generic arg is the scalar type.
237 if let Some(syn::GenericArgument::Type(t)) = args.args.first() {
238 let ts = quote! { #t };
239 return Some(ts);
240 }
241 }
242 if ident_str == "FerrotorchResult" || ident_str == "Result" {
243 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
244 return extract_tensor_param_inner(inner, depth + 1);
245 }
246 }
247 None
248}