Skip to main content

ferrotorch_nn_derive/
lib.rs

1//! Derive macro for the `Module<T>` trait in `ferrotorch-nn`.
2//!
3//! Generates the boilerplate methods (`parameters`, `parameters_mut`,
4//! `named_parameters`, `train`, `eval`, `is_training`) so the user only
5//! needs to write `forward()`.
6//!
7//! # Field attributes
8//!
9//! | Attribute       | Meaning                                                |
10//! |-----------------|--------------------------------------------------------|
11//! | `#[param]`      | This field is a `Parameter<T>` — registered directly.  |
12//! | `#[submodule]`  | This field implements `Module<T>` — recurse into it.   |
13//! | `#[skip]`       | Ignore this field entirely.                            |
14//! | *(none)*        | Ignored (same as `#[skip]`), except for `training: bool` which is managed automatically. |
15//!
//! The struct **must** contain an unannotated `training: bool` field (a
//! `training` field tagged `#[param]`, `#[submodule]`, or `#[skip]` does not
//! count). The derive generates `train()`, `eval()`, and `is_training()` from
//! it, and propagates train/eval to every `#[submodule]` field.
19//!
20//! # Example
21//!
//! The example below is marked `ignore` because this is a `proc-macro` crate:
//! its doctests cannot import `ferrotorch_nn::Module` or
//! `ferrotorch_core::Tensor`. Neither crate is a dependency of this one, and
//! `ferrotorch-nn`, being a consumer of this derive, cannot be added as one
//! without creating a dependency cycle.
//! The example is exercised end-to-end by the integration tests in
//! `ferrotorch-nn/tests/derive_module.rs`.
28//!
29//! ```ignore
30//! use ferrotorch_nn::{Module, Parameter, Linear};
31//! use ferrotorch_nn_derive::Module;
32//!
33//! #[derive(Module)]
34//! struct MyModel<T: Float> {
35//!     #[param]     weight: Parameter<T>,
36//!     #[param]     bias: Parameter<T>,
37//!     #[submodule] layer1: Linear<T>,
38//!     #[submodule] layer2: Linear<T>,
39//!     #[skip]      hidden_size: usize,
40//!     training: bool,
41//! }
42//! ```
43
44#![warn(clippy::all, clippy::pedantic)]
45#![deny(rust_2018_idioms, missing_debug_implementations)]
46#![allow(missing_docs)] // tracked workspace-wide in the rustdoc pass
47
48use proc_macro::TokenStream;
49use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
50use quote::quote;
51use syn::{Data, DeriveInput, Fields, GenericParam, Generics, TypeParam, parse_macro_input};
52
53/// Derive the `Module<T>` trait for a struct.
54///
55/// See the [crate-level documentation](crate) for attribute usage.
56#[proc_macro_derive(Module, attributes(param, submodule, skip))]
57pub fn derive_module(input: TokenStream) -> TokenStream {
58    let input = parse_macro_input!(input as DeriveInput);
59    match derive_module_impl(input) {
60        Ok(tokens) => tokens.into(),
61        Err(err) => err.to_compile_error().into(),
62    }
63}
64
65// ---------------------------------------------------------------------------
66// Internal implementation
67// ---------------------------------------------------------------------------
68
/// The role the derive assigns to a single struct field.
#[derive(Debug)]
enum FieldKind {
    /// Tagged `#[param]`: a `Parameter<T>` registered directly by
    /// `parameters`, `parameters_mut`, and `named_parameters`.
    Param,
    /// Tagged `#[submodule]`: a `Module<T>` whose parameters are collected
    /// and to which `train`/`eval` are forwarded.
    Submodule,
    /// The unannotated `training` field backing `train`, `eval`, and
    /// `is_training`.
    Training,
    /// Tagged `#[skip]`, or any other unannotated field: never touched by
    /// the generated code.
    Skip,
}
81
82#[derive(Debug)]
83struct ClassifiedField {
84    ident: Ident,
85    kind: FieldKind,
86}
87
88// `derive_module_impl` exceeds clippy's 100-line threshold because it is the
89// single code-generation entry point: classify fields, validate, find float
90// param, build six method bodies, then assemble one `quote!` block.
91// Splitting purely to satisfy the lint would scatter the `quote!` template
92// across helpers that share a dozen captured locals — net readability loss.
93#[allow(clippy::too_many_lines)]
94// `parse_macro_input!` (in `derive_module`) yields an owned `DeriveInput`,
95// which is the standard proc-macro shape; taking by reference here would
96// force callers to bind a temporary first. This is a proc-macro convention,
97// not a hot-path concern.
98#[allow(clippy::needless_pass_by_value)]
99fn derive_module_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
100    let name = &input.ident;
101    let generics = &input.generics;
102
103    // --- Extract fields (named structs only) --------------------------------
104
105    let fields = match &input.data {
106        Data::Struct(data) => match &data.fields {
107            Fields::Named(fields) => &fields.named,
108            _ => {
109                return Err(syn::Error::new_spanned(
110                    name,
111                    "#[derive(Module)] only supports structs with named fields",
112                ));
113            }
114        },
115        _ => {
116            return Err(syn::Error::new_spanned(
117                name,
118                "#[derive(Module)] only supports structs",
119            ));
120        }
121    };
122
123    // --- Classify each field ------------------------------------------------
124
125    let mut classified: Vec<ClassifiedField> = Vec::new();
126    let mut has_training = false;
127
128    for field in fields {
129        // We've already matched `Fields::Named(...)` above, so every field
130        // here has an ident. Defending in depth against a future refactor
131        // that broadens the match: surface a `compile_error!` rather than
132        // an `unwrap`-ICE if this invariant is ever violated.
133        let ident = field
134            .ident
135            .as_ref()
136            .ok_or_else(|| {
137                syn::Error::new_spanned(
138                    field,
139                    "ferrotorch-nn-derive: expected named field (this is a bug — \
140                     please report at https://github.com/ferrotorch/ferrotorch/issues)",
141                )
142            })?
143            .clone();
144
145        let has_param = field.attrs.iter().any(|a| a.path().is_ident("param"));
146        let has_submodule = field.attrs.iter().any(|a| a.path().is_ident("submodule"));
147        let has_skip = field.attrs.iter().any(|a| a.path().is_ident("skip"));
148
149        // Validate: at most one of #[param], #[submodule], #[skip].
150        let attr_count = u8::from(has_param) + u8::from(has_submodule) + u8::from(has_skip);
151        if attr_count > 1 {
152            return Err(syn::Error::new_spanned(
153                field,
154                "field cannot have more than one of #[param], #[submodule], #[skip]",
155            ));
156        }
157
158        let kind = if has_param {
159            FieldKind::Param
160        } else if has_submodule {
161            FieldKind::Submodule
162        } else if has_skip {
163            FieldKind::Skip
164        } else if ident == "training" {
165            has_training = true;
166            FieldKind::Training
167        } else {
168            // Unannotated and not `training` — skip by default.
169            FieldKind::Skip
170        };
171
172        classified.push(ClassifiedField { ident, kind });
173    }
174
175    if !has_training {
176        return Err(syn::Error::new(
177            Span::call_site(),
178            "#[derive(Module)] requires a `training: bool` field",
179        ));
180    }
181
182    // --- Find the Float type parameter --------------------------------------
183    // We look for a type parameter that has a `Float` bound.
184    // If none is found, we fall back to the first type parameter.
185
186    let float_param = find_float_param(generics)?;
187
188    // --- Generate method bodies ---------------------------------------------
189
190    let params: Vec<&ClassifiedField> = classified
191        .iter()
192        .filter(|f| matches!(f.kind, FieldKind::Param))
193        .collect();
194    let submodules: Vec<&ClassifiedField> = classified
195        .iter()
196        .filter(|f| matches!(f.kind, FieldKind::Submodule))
197        .collect();
198
199    // parameters(&self) -> Vec<&Parameter<T>>
200    let parameters_body = {
201        let param_pushes = params.iter().map(|f| {
202            let id = &f.ident;
203            quote! { params.push(&self.#id); }
204        });
205        let submod_extends = submodules.iter().map(|f| {
206            let id = &f.ident;
207            quote! { params.extend(self.#id.parameters()); }
208        });
209        quote! {
210            let mut params = ::std::vec::Vec::new();
211            #(#param_pushes)*
212            #(#submod_extends)*
213            params
214        }
215    };
216
217    // parameters_mut(&mut self) -> Vec<&mut Parameter<T>>
218    let parameters_mut_body = {
219        let param_pushes = params.iter().map(|f| {
220            let id = &f.ident;
221            quote! { params.push(&mut self.#id); }
222        });
223        let submod_extends = submodules.iter().map(|f| {
224            let id = &f.ident;
225            quote! { params.extend(self.#id.parameters_mut()); }
226        });
227        quote! {
228            let mut params = ::std::vec::Vec::new();
229            #(#param_pushes)*
230            #(#submod_extends)*
231            params
232        }
233    };
234
235    // named_parameters(&self) -> Vec<(String, &Parameter<T>)>
236    let named_parameters_body = {
237        let param_pushes = params.iter().map(|f| {
238            let id = &f.ident;
239            let name_str = id.to_string();
240            quote! { params.push((#name_str.to_string(), &self.#id)); }
241        });
242        let submod_extends = submodules.iter().map(|f| {
243            let id = &f.ident;
244            let prefix = id.to_string();
245            quote! {
246                for (name, p) in self.#id.named_parameters() {
247                    params.push((::std::format!("{}.{}", #prefix, name), p));
248                }
249            }
250        });
251        quote! {
252            let mut params = ::std::vec::Vec::new();
253            #(#param_pushes)*
254            #(#submod_extends)*
255            params
256        }
257    };
258
259    // train(&mut self)
260    let train_body = {
261        let submod_trains = submodules.iter().map(|f| {
262            let id = &f.ident;
263            quote! { self.#id.train(); }
264        });
265        quote! {
266            self.training = true;
267            #(#submod_trains)*
268        }
269    };
270
271    // eval(&mut self)
272    let eval_body = {
273        let submod_evals = submodules.iter().map(|f| {
274            let id = &f.ident;
275            quote! { self.#id.eval(); }
276        });
277        quote! {
278            self.training = false;
279            #(#submod_evals)*
280        }
281    };
282
283    // --- Assemble the impl block --------------------------------------------
284
285    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
286
287    let expanded = quote! {
288        impl #impl_generics ::ferrotorch_nn::Module<#float_param> for #name #ty_generics #where_clause {
289            /// Delegates to the inherent `forward()` method that the user must
290            /// define on this struct. Forgetting to define it produces a
291            /// compile-time error instead of a runtime panic.
292            fn forward(&self, input: &::ferrotorch_core::Tensor<#float_param>) -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#float_param>> {
293                self.forward(input)
294            }
295
296            fn parameters(&self) -> ::std::vec::Vec<&::ferrotorch_nn::Parameter<#float_param>> {
297                #parameters_body
298            }
299
300            fn parameters_mut(&mut self) -> ::std::vec::Vec<&mut ::ferrotorch_nn::Parameter<#float_param>> {
301                #parameters_mut_body
302            }
303
304            fn named_parameters(&self) -> ::std::vec::Vec<(::std::string::String, &::ferrotorch_nn::Parameter<#float_param>)> {
305                #named_parameters_body
306            }
307
308            fn train(&mut self) {
309                #train_body
310            }
311
312            fn eval(&mut self) {
313                #eval_body
314            }
315
316            fn is_training(&self) -> bool {
317                self.training
318            }
319        }
320    };
321
322    Ok(expanded)
323}
324
325/// Collect the idents of every type parameter declared on `generics`.
326fn type_param_idents(generics: &Generics) -> Vec<&Ident> {
327    generics
328        .params
329        .iter()
330        .filter_map(|p| match p {
331            GenericParam::Type(TypeParam { ident, .. }) => Some(ident),
332            _ => None,
333        })
334        .collect()
335}
336
337/// True if `path` is a single-segment path with no qualifier or generic args
338/// — i.e. a plain type-parameter reference like `T`, not `Self::Item`,
339/// `<Self as Foo>::T`, or `Vec<T>`.
340fn path_is_plain_type_param(path: &syn::Path) -> bool {
341    path.segments.len() == 1 && matches!(path.segments[0].arguments, syn::PathArguments::None)
342}
343
344/// Find the type parameter with a `Float` bound, or fall back to the first
345/// type parameter. Returns an error if the struct has no type parameters.
346fn find_float_param(generics: &Generics) -> syn::Result<Ident> {
347    let declared = type_param_idents(generics);
348
349    // First pass: look for a parameter declaration with an explicit `Float`
350    // bound (e.g. `T: Float`).
351    for param in &generics.params {
352        if let GenericParam::Type(TypeParam { ident, bounds, .. }) = param {
353            for bound in bounds {
354                if let syn::TypeParamBound::Trait(tb) = bound {
355                    if tb
356                        .path
357                        .segments
358                        .last()
359                        .is_some_and(|seg| seg.ident == "Float")
360                    {
361                        return Ok(ident.clone());
362                    }
363                }
364            }
365        }
366    }
367
368    // Second pass: where-clause predicates (e.g. `where T: Float`).
369    //
370    // We only accept predicates whose bounded type is a plain single-segment
371    // path that names one of the struct's declared type parameters. Anything
372    // else — `Self::Item: Float`, `<Self as Foo>::T: Float`, `Vec<T>: Float`,
373    // or a path naming an undeclared identifier — is not a generic-parameter
374    // bound and must not be picked as the float type. (Earlier versions of
375    // this function used `path.segments.first()`, which silently returned
376    // `Self` for `Self::Item: Float` — the wrong qualifier rather than the
377    // intended type parameter.)
378    if let Some(where_clause) = &generics.where_clause {
379        for predicate in &where_clause.predicates {
380            if let syn::WherePredicate::Type(pt) = predicate {
381                let bounds_float = pt.bounds.iter().any(|bound| {
382                    matches!(
383                        bound,
384                        syn::TypeParamBound::Trait(tb)
385                            if tb.path.segments.last().is_some_and(|seg| seg.ident == "Float")
386                    )
387                });
388                if !bounds_float {
389                    continue;
390                }
391                let syn::Type::Path(tp) = &pt.bounded_ty else {
392                    continue;
393                };
394                if tp.qself.is_some() || !path_is_plain_type_param(&tp.path) {
395                    continue;
396                }
397                let candidate = &tp.path.segments[0].ident;
398                if declared.contains(&candidate) {
399                    return Ok(candidate.clone());
400                }
401            }
402        }
403    }
404
405    // Fallback: use the first type parameter.
406    if let Some(first) = declared.first() {
407        return Ok((*first).clone());
408    }
409
410    Err(syn::Error::new(
411        Span::call_site(),
412        "#[derive(Module)] requires at least one type parameter (e.g., `T: Float`)",
413    ))
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Run `find_float_param` on a parsed input and return the chosen ident
    /// as a string (panics if it errors; tests expecting an error call
    /// `find_float_param` directly).
    fn float_ident_for(input: &syn::DeriveInput) -> String {
        find_float_param(&input.generics).unwrap().to_string()
    }

    #[test]
    fn picks_inline_bound_param() {
        // `struct S<T: Float> { ... }` — first pass.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Float> { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn picks_path_qualified_inline_bound() {
        // Only the last segment of the bound's path is compared, so a
        // crate-qualified `Float` is recognised.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: num_traits::Float> { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn picks_bounded_param_that_is_not_first() {
        // Lifetimes and unbounded type params before the `Float` one must
        // not win over it.
        let di: syn::DeriveInput = parse_quote! {
            struct S<'a, B: Clone, T: Float> { b: &'a B, x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn picks_where_clause_param() {
        // `where T: Float` — second pass.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T> where T: Float { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn picks_where_clause_param_that_is_not_first() {
        let di: syn::DeriveInput = parse_quote! {
            struct S<U, T> where T: Float { x: U, y: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn inline_bound_wins_over_where_clause() {
        // The first pass (inline bounds) runs before the where-clause pass.
        let di: syn::DeriveInput = parse_quote! {
            struct S<U, T: Float> where U: Float { x: U, y: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    // Regression test for the audit finding: previously, a where-clause with
    // a multi-segment bounded type like `Self::Item: Float` would match and
    // return the *first* segment (`Self`) — which is not a generic parameter
    // at all. The current implementation skips such predicates entirely and
    // falls back to the first declared type parameter.
    #[test]
    fn ignores_associated_type_in_where_clause() {
        // `where Self::Item: Float` — must NOT be picked. The fallback (first
        // declared type param, `T`) is the correct answer.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T> where Self::Item: Float { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn ignores_qself_path_in_where_clause() {
        // `where <Self as Foo>::T: Float` — must NOT be picked.
        let di: syn::DeriveInput = parse_quote! {
            struct S<U> where <Self as Foo>::T: Float { x: U }
        };
        assert_eq!(float_ident_for(&di), "U");
    }

    #[test]
    fn ignores_generic_type_in_where_clause() {
        // `where Vec<T>: Float` — a type with generic arguments, not a bare
        // type parameter, so it must NOT be picked; fall back to `U`.
        let di: syn::DeriveInput = parse_quote! {
            struct S<U, T> where Vec<T>: Float { x: U, y: T }
        };
        assert_eq!(float_ident_for(&di), "U");
    }

    #[test]
    fn ignores_undeclared_ident_in_where_clause() {
        // `where X: Float` — `X` is not a declared type parameter.
        let di: syn::DeriveInput = parse_quote! {
            struct S<U> where X: Float { x: U }
        };
        assert_eq!(float_ident_for(&di), "U");
    }

    #[test]
    fn fallback_when_no_float_bound() {
        // No `Float` bound anywhere — return the first type parameter.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Clone> { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn errors_when_no_type_params() {
        let di: syn::DeriveInput = parse_quote! {
            struct S { x: u32 }
        };
        assert!(find_float_param(&di.generics).is_err());
    }

    #[test]
    fn errors_when_only_lifetime_params() {
        // Lifetimes are not type parameters, so there is nothing to pick.
        let di: syn::DeriveInput = parse_quote! {
            struct S<'a> { x: &'a u32 }
        };
        assert!(find_float_param(&di.generics).is_err());
    }
}