Crate ferrotorch_jit_script


#[ferrotorch_jit_script::script] — declarative graph construction. (#625)

Annotate a Rust function with #[script] and the macro rewrites the body to build an IrGraph instead of running eagerly. The wrapped function returns a closure with the same signature that, when called, executes the captured graph via ferrotorch_jit::trace.

§What’s supported in the body

The script body is restricted to a small recognized subset:

  • let x = …; bindings (no shadowing across statements is required)
  • Function calls on tensors: add(&a, &b), mul(&a, &b), sum(&t), mean(&t), relu(&t), sigmoid(&t), tanh(&t), mm(&a, &b) … (anything in ferrotorch_core::grad_fns::{arithmetic, reduction, activation, linalg} that takes &Tensor args and returns Tensor)
  • A trailing expression as the function’s return value

Anything else (control flow, struct fields, non-tensor values) is left untouched — the macro just emits the rewritten body as-is and trace captures it at the autograd level.
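As a rough illustration of a body that stays inside the recognized subset, the sketch below uses hypothetical stand-ins: the `Tensor<T>` struct and the `mul` / `sum` functions here are local stubs, not the ferrotorch types, and the `weighted_sum` signature is an assumption. In a real crate the function would carry `#[script]`, and its generated wrapper would return a traced module rather than a plain tensor.

```rust
/// Hypothetical stand-in for a tensor; real code would use the ferrotorch type.
#[derive(Debug, Clone, PartialEq)]
struct Tensor<T> {
    data: Vec<T>,
}

/// Stub element-wise multiply with the `&Tensor -> Tensor` shape the subset expects.
fn mul(a: &Tensor<f32>, b: &Tensor<f32>) -> Tensor<f32> {
    Tensor {
        data: a.data.iter().zip(&b.data).map(|(x, y)| x * y).collect(),
    }
}

/// Stub reduction that returns a one-element tensor.
fn sum(t: &Tensor<f32>) -> Tensor<f32> {
    Tensor {
        data: vec![t.data.iter().sum()],
    }
}

// In a real crate this function would be annotated with `#[script]`.
fn weighted_sum(a: Tensor<f32>, w: Tensor<f32>) -> Tensor<f32> {
    // A `let` binding whose right-hand side is a tensor function call.
    let prod = mul(&a, &w);
    // The trailing expression is the function's return value.
    sum(&prod)
}

fn main() {
    let a = Tensor { data: vec![1.0, 2.0, 3.0] };
    let w = Tensor { data: vec![0.5, 0.5, 2.0] };
    println!("{:?}", weighted_sum(a, w));
}
```

The body uses only `let` bindings, calls that take `&Tensor` arguments, and a trailing expression, which are the three forms listed above.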

§Why this is a shim over trace

The macro doesn’t reimplement the IR builder. Instead, it ensures the function’s inputs are leaf tensors with requires_grad=true (so the autograd graph is built), then calls the existing trace to capture the IR. That keeps op coverage in lockstep with trace and avoids a second source of truth for op-name → IR mapping.
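The idea of capturing a graph by running the function once on recording inputs can be sketched with a toy tape. Everything in this block (`Value`, `Node`, `Tape`, and this `trace`) is a hypothetical stand-in written for illustration, not the ferrotorch_jit API. It only shows why a capture-by-execution design needs no separate op-name table: each op records itself when it runs.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// One recorded operation: its name plus the tape ids of its inputs.
#[derive(Debug, Clone, PartialEq)]
struct Node {
    op: &'static str,
    inputs: Vec<usize>,
}

/// Shared tape that every op appends to as it executes.
type Tape = Rc<RefCell<Vec<Node>>>;

/// Toy stand-in for a tensor: a scalar plus its position on the tape.
#[derive(Clone)]
struct Value {
    data: f64,
    id: usize,
    tape: Tape,
}

/// Append a node to the tape and return its id.
fn push(tape: &Tape, op: &'static str, inputs: Vec<usize>) -> usize {
    let mut t = tape.borrow_mut();
    t.push(Node { op, inputs });
    t.len() - 1
}

fn leaf(data: f64, tape: &Tape) -> Value {
    let id = push(tape, "input", vec![]);
    Value { data, id, tape: tape.clone() }
}

fn add(a: &Value, b: &Value) -> Value {
    let id = push(&a.tape, "add", vec![a.id, b.id]);
    Value { data: a.data + b.data, id, tape: a.tape.clone() }
}

fn mul(a: &Value, b: &Value) -> Value {
    let id = push(&a.tape, "mul", vec![a.id, b.id]);
    Value { data: a.data * b.data, id, tape: a.tape.clone() }
}

/// Toy `trace`: run `f` once on fresh leaf inputs and return the
/// computed output together with the graph recorded along the way.
fn trace<F>(f: F, inputs: &[f64]) -> (f64, Vec<Node>)
where
    F: Fn(&[Value]) -> Value,
{
    let tape: Tape = Rc::new(RefCell::new(Vec::new()));
    let leaves: Vec<Value> = inputs.iter().map(|&x| leaf(x, &tape)).collect();
    let out = f(&leaves);
    let graph = tape.borrow().clone();
    (out.data, graph)
}

fn main() {
    let (y, graph) = trace(|v| add(&mul(&v[0], &v[1]), &v[0]), &[2.0, 3.0]);
    println!("y = {y}");
    for (i, n) in graph.iter().enumerate() {
        println!("%{i} = {}({:?})", n.op, n.inputs);
    }
}
```

In this toy, adding a new op means writing one function that pushes its own node, and the tracer picks it up automatically. That is the same property the paragraph above attributes to delegating to `trace`.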

§REQ status (per .design/ferrotorch-jit-script/lib.md)

| REQ | Status | Evidence |
|---|---|---|
| REQ-1 | NOT-STARTED | open prereq blocker #1482: the #[proc_macro_attribute] pub script entry below is fully implemented and re-exported via ferrotorch/src/lib.rs pub use ferrotorch_jit_script::*;, but no in-tree non-test code applies #[script] to a function. Test-only callers (ferrotorch-jit-script/tests/script_macro.rs, tests/conformance_jit_script.rs) don’t count per goal.md R-DEFER-1. Consumer wiring lands in #1482. |
| REQ-2 | SHIPPED | impl: the match output block in script_impl rejects unrecognized return types with a syn::Error::new_spanned(...) carrying the message naming the three accepted shapes (Tensor<T>, FerrotorchResult<Tensor<T>>, Result<Tensor<T>, _>); non-test consumer: every test-driven invocation in ferrotorch-jit-script/tests/script_macro.rs compiles because the return type IS recognized; the diagnostic path is fronted by extract_tensor_param’s None return which gates every recognized-form path. |
| REQ-3 | SHIPPED | impl: const MAX_RETURN_TYPE_DEPTH: u8 = 4 plus the if depth > MAX_RETURN_TYPE_DEPTH { return None; } bounds check in extract_tensor_param_inner; non-test consumer: extract_tensor_param_inner recurses with depth + 1 through FerrotorchResult / Result wrappers, and extract_tensor_param is the public entry the script_impl rewriter calls into; the cap binds on every #[script] expansion in the crate. |
| REQ-4 | SHIPPED | impl: the generated wrapper in script_impl emits -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_jit::TracedModule<#scalar_ty>> where #scalar_ty is derived from the user’s return type; non-test consumer: the dtype-roundtrip is what ferrotorch-jit-script/tests/script_macro.rs let module: TracedModule<f32> = weighted_sum(a, w).unwrap(); verifies; the test relies on the macro emitting the correct generic param. |
| REQ-5 | SHIPPED | impl: the `filter_map( |
| REQ-6 | SHIPPED | impl: `__script_inputs.iter().map( |
| REQ-7 | SHIPPED | impl: `let __script_result: #user_return_ty = ( |
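The depth cap described for REQ-3 can be sketched with a toy type model. The names `MAX_RETURN_TYPE_DEPTH`, `extract_tensor_param`, and `extract_tensor_param_inner` come from the table above. The `Ty` enum and `wrap` helper are hypothetical, since the real macro presumably inspects `syn` types, and starting the recursion at depth 0 is an assumption.

```rust
/// Toy model of a return type (hypothetical; a proc macro would walk `syn` types).
#[derive(Debug)]
enum Ty {
    /// `Tensor<T>`, carrying the scalar type name.
    Tensor(String),
    /// `FerrotorchResult<Inner>`.
    FerrotorchResult(Box<Ty>),
    /// `Result<Inner, _>`; the error type is ignored in this model.
    Result(Box<Ty>),
    /// Any other type.
    Other,
}

const MAX_RETURN_TYPE_DEPTH: u8 = 4;

/// Entry point: returns the tensor's scalar type name, or `None`
/// if the type is unrecognized or nested too deeply.
fn extract_tensor_param(ty: &Ty) -> Option<&str> {
    extract_tensor_param_inner(ty, 0) // assumption: depth starts at 0
}

fn extract_tensor_param_inner(ty: &Ty, depth: u8) -> Option<&str> {
    // Bounds check: refuse to recurse past the cap.
    if depth > MAX_RETURN_TYPE_DEPTH {
        return None;
    }
    match ty {
        Ty::Tensor(scalar) => Some(scalar.as_str()),
        // Peel one wrapper layer and recurse one level deeper.
        Ty::FerrotorchResult(inner) | Ty::Result(inner) => {
            extract_tensor_param_inner(inner, depth + 1)
        }
        Ty::Other => None,
    }
}

/// Hypothetical helper: wrap `ty` in `layers` nested `Result` wrappers.
fn wrap(mut ty: Ty, layers: usize) -> Ty {
    for _ in 0..layers {
        ty = Ty::Result(Box::new(ty));
    }
    ty
}

fn main() {
    let ok = Ty::FerrotorchResult(Box::new(Ty::Tensor("f32".to_string())));
    println!("{:?}", extract_tensor_param(&ok));
    let too_deep = wrap(Ty::Tensor("f32".to_string()), 5);
    println!("{:?}", extract_tensor_param(&too_deep));
}
```

Under these assumptions a tensor behind up to four wrappers is still recognized and a fifth wrapper is rejected. The exact boundary in the real crate depends on where its recursion starts counting.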

§Attribute Macros

script
Apply #[script] to a function fn(args...) -> Tensor<T> to compile it into a ferrotorch_jit::TracedModule<T>-returning function.