Skip to main content

ferrotorch_jit_script/
lib.rs

1//! `#[ferrotorch_jit_script::script]` — declarative graph construction. (#625)
2//!
//! Annotate a Rust function with `#[script]` and the macro replaces it with a
//! function of the same name and arguments. When called, that function runs
//! the original body once under `ferrotorch_jit::trace` and returns the
//! captured graph wrapped in a `ferrotorch_jit::TracedModule`.
7//!
8//! # What's supported in the body
9//!
10//! The script body is restricted to a small recognized subset:
11//!
12//! - `let x = …;` bindings (no shadowing across statements is required)
13//! - Function calls on tensors: `add(&a, &b)`, `mul(&a, &b)`, `sum(&t)`,
14//!   `mean(&t)`, `relu(&t)`, `sigmoid(&t)`, `tanh(&t)`, `mm(&a, &b)` ...
15//!   (anything in `ferrotorch_core::grad_fns::{arithmetic, reduction,
16//!   activation, linalg}` that takes `&Tensor` args and returns `Tensor`)
17//! - A trailing expression as the function's return value
18//!
//! The macro does not inspect the body itself. Anything else (control flow,
//! struct fields, non-tensor values) is emitted verbatim inside a closure,
//! and `trace` captures whatever that closure does at the autograd level.
22//!
23//! # Why this is a shim over `trace`
24//!
25//! The macro doesn't reimplement the IR builder. Instead, it ensures the
26//! function's inputs are leaf tensors with `requires_grad=true` (so the
27//! autograd graph is built), then calls the existing `trace` to capture
28//! the IR. That keeps op coverage in lockstep with trace and avoids a
29//! second source of truth for op-name → IR mapping.
30
31#![warn(clippy::all, clippy::pedantic)]
32#![deny(rust_2018_idioms, missing_debug_implementations)]
33#![allow(missing_docs)] // tracked workspace-wide in the rustdoc pass
34
35use proc_macro::TokenStream;
36use proc_macro2::TokenStream as TokenStream2;
37use quote::quote;
38use syn::{FnArg, ItemFn};
39
40/// Apply `#[script]` to a function `fn(args...) -> Tensor<T>` to compile
41/// it into a `ferrotorch_jit::TracedModule<T>`-returning function.
42///
43/// The example below is marked `ignore` because this is a `proc-macro`
44/// crate: it cannot itself import `ferrotorch_jit::TracedModule` or
45/// `ferrotorch_core::Tensor` at doctest-compile time (proc-macro crates
46/// can only export proc-macro items and pull procedural-macro deps; they
47/// cannot depend on consumer crates). The example is exercised
48/// end-to-end by the integration tests in
49/// `ferrotorch-jit-script/tests/script_macro.rs`.
50///
51/// ```ignore
52/// use ferrotorch_jit_script::script;
53/// use ferrotorch_core::Tensor;
54/// use ferrotorch_core::grad_fns::arithmetic::{mul, add};
55/// use ferrotorch_core::grad_fns::reduction::sum;
56///
57/// #[script]
58/// fn weighted_sum(a: Tensor<f32>, w: Tensor<f32>) -> Tensor<f32> {
59///     let prod = mul(&a, &w)?;
60///     sum(&prod)
61/// }
62///
63/// // `weighted_sum(...)` now returns `FerrotorchResult<TracedModule<f32>>`
64/// // built by tracing the body once with the supplied tensors.
65/// ```
66///
67/// # Errors
68///
69/// Emits a `compile_error!` (via `syn::Error`) if the annotated function's
70/// return type isn't one of the recognized shapes:
71///
72/// - `Tensor<T>`
73/// - `FerrotorchResult<Tensor<T>>`
74/// - `Result<Tensor<T>, _>`
75///
76/// Previously, an unrecognized return type silently fell back to
77/// `TracedModule<f32>`, producing a wrong-dtype wrapper for e.g.
78/// `Tensor<f64>` callers. The macro now refuses the input instead.
79#[proc_macro_attribute]
80pub fn script(attr: TokenStream, item: TokenStream) -> TokenStream {
81    match script_impl(attr, item) {
82        Ok(ts) => ts.into(),
83        Err(err) => err.to_compile_error().into(),
84    }
85}
86
87fn script_impl(_attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream2> {
88    let input: ItemFn = syn::parse(item)?;
89    let vis = &input.vis;
90    let sig = &input.sig;
91    let ident = &sig.ident;
92    let block = &input.block;
93    let inputs = &sig.inputs;
94    let output = &sig.output;
95
96    // Walk the function's typed args once. `arg_clones` is the
97    // example-input slice handed to `trace`; `arg_unpacks` are the
98    // index-based bindings inside the closure body. Both projections
99    // share the same filter (skip `Self` receivers) and ordering, so a
100    // single pass keeps them in lockstep.
101    let typed_args: Vec<&syn::PatType> = inputs
102        .iter()
103        .filter_map(|a| match a {
104            FnArg::Typed(pt) => Some(pt),
105            FnArg::Receiver(_) => None,
106        })
107        .collect();
108    let arg_clones: Vec<TokenStream2> = typed_args
109        .iter()
110        .map(|pt| {
111            let pat = &pt.pat;
112            quote! { #pat .clone() }
113        })
114        .collect();
115    let arg_unpacks: Vec<TokenStream2> = typed_args
116        .iter()
117        .enumerate()
118        .map(|(i, pt)| {
119            let pat = &pt.pat;
120            quote! { let #pat = inputs[#i].clone(); }
121        })
122        .collect();
123
124    // Determine T from the return type. Unrecognized return shapes are
125    // a hard error: silently defaulting to f32 (the historic behaviour)
126    // produced a `TracedModule<f32>` wrapper for callers that returned
127    // e.g. `Tensor<f64>`, with no diagnostic. Failing here surfaces the
128    // mistake at macro-expansion time as a clean `compile_error!`.
129    let scalar_ty = match output {
130        syn::ReturnType::Type(_, ty) => extract_tensor_param(ty.as_ref()).ok_or_else(|| {
131            syn::Error::new_spanned(
132                ty,
133                "ferrotorch-jit-script: function must return Tensor<T>, \
134                 FerrotorchResult<Tensor<T>>, or Result<Tensor<T>, _>",
135            )
136        })?,
137        syn::ReturnType::Default => {
138            return Err(syn::Error::new_spanned(
139                sig,
140                "ferrotorch-jit-script: function must declare a return type \
141                 of Tensor<T>, FerrotorchResult<Tensor<T>>, or Result<Tensor<T>, _>",
142            ));
143        }
144    };
145
146    // Capture the user's return type tokens so the expansion mentions
147    // any names they imported (e.g. `FerrotorchResult`) — otherwise
148    // those imports look unused after macro expansion.
149    let user_return_ty: TokenStream2 = match output {
150        syn::ReturnType::Type(_, ty) => quote! { #ty },
151        syn::ReturnType::Default => {
152            quote! { ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>> }
153        }
154    };
155
156    // Generated wrapper:
157    // - Builds the example-input slice from the caller's args.
158    // - Defines an inner closure with the user's body (so all the let
159    //   bindings and ops just compile as normal Rust).
160    // - Calls trace(closure, &[args...]) to capture the IR graph.
161    // - Wraps the graph in a TracedModule and returns it.
162    let expanded = quote! {
163        #vis fn #ident ( #inputs ) -> ::ferrotorch_core::FerrotorchResult<
164            ::ferrotorch_jit::TracedModule<#scalar_ty>
165        > {
166            let __script_inputs: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
167                vec![ #( #arg_clones ),* ];
168            let __script_inputs_for_trace: ::std::vec::Vec<::ferrotorch_core::Tensor<#scalar_ty>> =
169                __script_inputs
170                    .iter()
171                    .map(|t| t.clone().requires_grad_(true))
172                    .collect();
173            let __graph = ::ferrotorch_jit::trace(
174                |inputs: &[::ferrotorch_core::Tensor<#scalar_ty>]|
175                    -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#scalar_ty>>
176                {
177                    #( #arg_unpacks )*
178                    let __script_result: #user_return_ty = (|| #block)();
179                    __script_result
180                },
181                &__script_inputs_for_trace,
182            )?;
183            Ok(::ferrotorch_jit::TracedModule::<#scalar_ty>::new(__graph))
184        }
185    };
186    Ok(expanded)
187}
188
189/// Recognized return-type entry point for the `#[script]` macro.
190///
191/// Returns the dtype `TokenStream` extracted from the annotated function's
192/// return type. The recognized shapes are:
193///
194/// - `Tensor<T>` — yields `T`
195/// - `FerrotorchResult<Tensor<T>>` — yields `T` (recurses through the
196///   `FerrotorchResult` wrapper)
197/// - `Result<Tensor<T>, _>` — yields `T` (recurses through the `Result`
198///   wrapper, ignoring the error type)
199///
200/// Returns `None` when the type doesn't match any of those shapes; the
201/// caller turns that into a `compile_error!` diagnostic.
202///
203/// Recursion through `Result` / `FerrotorchResult` wrappers is bounded
204/// (see `extract_tensor_param_inner`'s depth cap) so a malformed input
205/// like `Result<Result<Result<...>, _>, _>` cannot blow the macro's
206/// stack — it returns `None` once the cap is reached, falling through
207/// to the same `compile_error!` path as any other unrecognized shape.
208fn extract_tensor_param(ty: &syn::Type) -> Option<TokenStream2> {
209    extract_tensor_param_inner(ty, 0)
210}
211
/// Maximum number of `Result<...>` / `FerrotorchResult<...>` wrappers
/// recognized around `Tensor<T>`.
///
/// `extract_tensor_param_inner` gives up once the number of wrappers already
/// peeled exceeds this value. Up to this many wrappers are accepted, and one
/// more is rejected.
///
/// The supported shapes need at most one wrapper, so the cap is generous.
/// It exists to make the recursion total: a hand-crafted, syntactically
/// valid but nonsensical input (an arbitrarily nested `Result<...>`) cannot
/// blow the macro's stack during expansion.
const MAX_RETURN_TYPE_DEPTH: u8 = 4;
221
222fn extract_tensor_param_inner(ty: &syn::Type, depth: u8) -> Option<TokenStream2> {
223    if depth > MAX_RETURN_TYPE_DEPTH {
224        return None;
225    }
226    let syn::Type::Path(p) = ty else {
227        return None;
228    };
229    let path = &p.path;
230    let last = path.segments.last()?;
231    let ident_str = last.ident.to_string();
232    let syn::PathArguments::AngleBracketed(args) = &last.arguments else {
233        return None;
234    };
235    if ident_str == "Tensor" {
236        // First generic arg is the scalar type.
237        if let Some(syn::GenericArgument::Type(t)) = args.args.first() {
238            let ts = quote! { #t };
239            return Some(ts);
240        }
241    }
242    if ident_str == "FerrotorchResult" || ident_str == "Result" {
243        if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
244            return extract_tensor_param_inner(inner, depth + 1);
245        }
246    }
247    None
248}