ferrotorch_nn_derive/lib.rs
1//! Derive macro for the `Module<T>` trait in `ferrotorch-nn`.
2//!
3//! Generates the boilerplate methods (`parameters`, `parameters_mut`,
4//! `named_parameters`, `train`, `eval`, `is_training`) so the user only
5//! needs to write `forward()`.
6//!
7//! # Field attributes
8//!
9//! | Attribute | Meaning |
10//! |-----------------|--------------------------------------------------------|
11//! | `#[param]` | This field is a `Parameter<T>` — registered directly. |
12//! | `#[submodule]` | This field implements `Module<T>` — recurse into it. |
13//! | `#[skip]` | Ignore this field entirely. |
14//! | *(none)* | Ignored (same as `#[skip]`), except for `training: bool` which is managed automatically. |
15//!
16//! The struct **must** contain a `training: bool` field. The derive will
17//! generate `train()`, `eval()`, and `is_training()` using it, and will
18//! propagate train/eval to all `#[submodule]` fields.
19//!
20//! # Example
21//!
22//! The example below is marked `ignore` because this is a `proc-macro` crate:
23//! it cannot itself import `ferrotorch_nn::Module` or `ferrotorch_core::Tensor`
24//! at doctest-compile time (proc-macro crates can only export proc-macro items
25//! and pull procedural-macro deps; they cannot depend on consumer crates).
26//! The example is exercised end-to-end by the integration tests in
27//! `ferrotorch-nn/tests/derive_module.rs`.
28//!
29//! ```ignore
30//! use ferrotorch_nn::{Module, Parameter, Linear};
31//! use ferrotorch_nn_derive::Module;
32//!
33//! #[derive(Module)]
34//! struct MyModel<T: Float> {
35//! #[param] weight: Parameter<T>,
36//! #[param] bias: Parameter<T>,
37//! #[submodule] layer1: Linear<T>,
38//! #[submodule] layer2: Linear<T>,
39//! #[skip] hidden_size: usize,
40//! training: bool,
41//! }
42//! ```
43
44#![warn(clippy::all, clippy::pedantic)]
45#![deny(rust_2018_idioms, missing_debug_implementations)]
46#![allow(missing_docs)] // tracked workspace-wide in the rustdoc pass
47
48use proc_macro::TokenStream;
49use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
50use quote::quote;
51use syn::{Data, DeriveInput, Fields, GenericParam, Generics, TypeParam, parse_macro_input};
52
53/// Derive the `Module<T>` trait for a struct.
54///
55/// See the [crate-level documentation](crate) for attribute usage.
56#[proc_macro_derive(Module, attributes(param, submodule, skip))]
57pub fn derive_module(input: TokenStream) -> TokenStream {
58 let input = parse_macro_input!(input as DeriveInput);
59 match derive_module_impl(input) {
60 Ok(tokens) => tokens.into(),
61 Err(err) => err.to_compile_error().into(),
62 }
63}
64
65// ---------------------------------------------------------------------------
66// Internal implementation
67// ---------------------------------------------------------------------------
68
/// The role a struct field plays in the generated `Module` implementation.
#[derive(Debug)]
enum FieldKind {
    /// Field carries `#[param]`: it is registered directly as a parameter.
    Param,
    /// Field carries `#[submodule]`: the generated methods recurse into it.
    Submodule,
    /// The unannotated field named `training`, read and written by the
    /// generated `train()`, `eval()` and `is_training()`.
    Training,
    /// Field carries `#[skip]`, or has no recognised attribute and is not
    /// `training`: the derive generates nothing for it.
    Skip,
}
81
82#[derive(Debug)]
83struct ClassifiedField {
84 ident: Ident,
85 kind: FieldKind,
86}
87
88// `derive_module_impl` exceeds clippy's 100-line threshold because it is the
89// single code-generation entry point: classify fields, validate, find float
90// param, build six method bodies, then assemble one `quote!` block.
91// Splitting purely to satisfy the lint would scatter the `quote!` template
92// across helpers that share a dozen captured locals — net readability loss.
93#[allow(clippy::too_many_lines)]
94// `parse_macro_input!` (in `derive_module`) yields an owned `DeriveInput`,
95// which is the standard proc-macro shape; taking by reference here would
96// force callers to bind a temporary first. This is a proc-macro convention,
97// not a hot-path concern.
98#[allow(clippy::needless_pass_by_value)]
99fn derive_module_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
100 let name = &input.ident;
101 let generics = &input.generics;
102
103 // --- Extract fields (named structs only) --------------------------------
104
105 let fields = match &input.data {
106 Data::Struct(data) => match &data.fields {
107 Fields::Named(fields) => &fields.named,
108 _ => {
109 return Err(syn::Error::new_spanned(
110 name,
111 "#[derive(Module)] only supports structs with named fields",
112 ));
113 }
114 },
115 _ => {
116 return Err(syn::Error::new_spanned(
117 name,
118 "#[derive(Module)] only supports structs",
119 ));
120 }
121 };
122
123 // --- Classify each field ------------------------------------------------
124
125 let mut classified: Vec<ClassifiedField> = Vec::new();
126 let mut has_training = false;
127
128 for field in fields {
129 // We've already matched `Fields::Named(...)` above, so every field
130 // here has an ident. Defending in depth against a future refactor
131 // that broadens the match: surface a `compile_error!` rather than
132 // an `unwrap`-ICE if this invariant is ever violated.
133 let ident = field
134 .ident
135 .as_ref()
136 .ok_or_else(|| {
137 syn::Error::new_spanned(
138 field,
139 "ferrotorch-nn-derive: expected named field (this is a bug — \
140 please report at https://github.com/ferrotorch/ferrotorch/issues)",
141 )
142 })?
143 .clone();
144
145 let has_param = field.attrs.iter().any(|a| a.path().is_ident("param"));
146 let has_submodule = field.attrs.iter().any(|a| a.path().is_ident("submodule"));
147 let has_skip = field.attrs.iter().any(|a| a.path().is_ident("skip"));
148
149 // Validate: at most one of #[param], #[submodule], #[skip].
150 let attr_count = u8::from(has_param) + u8::from(has_submodule) + u8::from(has_skip);
151 if attr_count > 1 {
152 return Err(syn::Error::new_spanned(
153 field,
154 "field cannot have more than one of #[param], #[submodule], #[skip]",
155 ));
156 }
157
158 let kind = if has_param {
159 FieldKind::Param
160 } else if has_submodule {
161 FieldKind::Submodule
162 } else if has_skip {
163 FieldKind::Skip
164 } else if ident == "training" {
165 has_training = true;
166 FieldKind::Training
167 } else {
168 // Unannotated and not `training` — skip by default.
169 FieldKind::Skip
170 };
171
172 classified.push(ClassifiedField { ident, kind });
173 }
174
175 if !has_training {
176 return Err(syn::Error::new(
177 Span::call_site(),
178 "#[derive(Module)] requires a `training: bool` field",
179 ));
180 }
181
182 // --- Find the Float type parameter --------------------------------------
183 // We look for a type parameter that has a `Float` bound.
184 // If none is found, we fall back to the first type parameter.
185
186 let float_param = find_float_param(generics)?;
187
188 // --- Generate method bodies ---------------------------------------------
189
190 let params: Vec<&ClassifiedField> = classified
191 .iter()
192 .filter(|f| matches!(f.kind, FieldKind::Param))
193 .collect();
194 let submodules: Vec<&ClassifiedField> = classified
195 .iter()
196 .filter(|f| matches!(f.kind, FieldKind::Submodule))
197 .collect();
198
199 // parameters(&self) -> Vec<&Parameter<T>>
200 let parameters_body = {
201 let param_pushes = params.iter().map(|f| {
202 let id = &f.ident;
203 quote! { params.push(&self.#id); }
204 });
205 let submod_extends = submodules.iter().map(|f| {
206 let id = &f.ident;
207 quote! { params.extend(self.#id.parameters()); }
208 });
209 quote! {
210 let mut params = ::std::vec::Vec::new();
211 #(#param_pushes)*
212 #(#submod_extends)*
213 params
214 }
215 };
216
217 // parameters_mut(&mut self) -> Vec<&mut Parameter<T>>
218 let parameters_mut_body = {
219 let param_pushes = params.iter().map(|f| {
220 let id = &f.ident;
221 quote! { params.push(&mut self.#id); }
222 });
223 let submod_extends = submodules.iter().map(|f| {
224 let id = &f.ident;
225 quote! { params.extend(self.#id.parameters_mut()); }
226 });
227 quote! {
228 let mut params = ::std::vec::Vec::new();
229 #(#param_pushes)*
230 #(#submod_extends)*
231 params
232 }
233 };
234
235 // named_parameters(&self) -> Vec<(String, &Parameter<T>)>
236 let named_parameters_body = {
237 let param_pushes = params.iter().map(|f| {
238 let id = &f.ident;
239 let name_str = id.to_string();
240 quote! { params.push((#name_str.to_string(), &self.#id)); }
241 });
242 let submod_extends = submodules.iter().map(|f| {
243 let id = &f.ident;
244 let prefix = id.to_string();
245 quote! {
246 for (name, p) in self.#id.named_parameters() {
247 params.push((::std::format!("{}.{}", #prefix, name), p));
248 }
249 }
250 });
251 quote! {
252 let mut params = ::std::vec::Vec::new();
253 #(#param_pushes)*
254 #(#submod_extends)*
255 params
256 }
257 };
258
259 // train(&mut self)
260 let train_body = {
261 let submod_trains = submodules.iter().map(|f| {
262 let id = &f.ident;
263 quote! { self.#id.train(); }
264 });
265 quote! {
266 self.training = true;
267 #(#submod_trains)*
268 }
269 };
270
271 // eval(&mut self)
272 let eval_body = {
273 let submod_evals = submodules.iter().map(|f| {
274 let id = &f.ident;
275 quote! { self.#id.eval(); }
276 });
277 quote! {
278 self.training = false;
279 #(#submod_evals)*
280 }
281 };
282
283 // --- Assemble the impl block --------------------------------------------
284
285 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
286
287 let expanded = quote! {
288 impl #impl_generics ::ferrotorch_nn::Module<#float_param> for #name #ty_generics #where_clause {
289 /// Delegates to the inherent `forward()` method that the user must
290 /// define on this struct. Forgetting to define it produces a
291 /// compile-time error instead of a runtime panic.
292 fn forward(&self, input: &::ferrotorch_core::Tensor<#float_param>) -> ::ferrotorch_core::FerrotorchResult<::ferrotorch_core::Tensor<#float_param>> {
293 self.forward(input)
294 }
295
296 fn parameters(&self) -> ::std::vec::Vec<&::ferrotorch_nn::Parameter<#float_param>> {
297 #parameters_body
298 }
299
300 fn parameters_mut(&mut self) -> ::std::vec::Vec<&mut ::ferrotorch_nn::Parameter<#float_param>> {
301 #parameters_mut_body
302 }
303
304 fn named_parameters(&self) -> ::std::vec::Vec<(::std::string::String, &::ferrotorch_nn::Parameter<#float_param>)> {
305 #named_parameters_body
306 }
307
308 fn train(&mut self) {
309 #train_body
310 }
311
312 fn eval(&mut self) {
313 #eval_body
314 }
315
316 fn is_training(&self) -> bool {
317 self.training
318 }
319 }
320 };
321
322 Ok(expanded)
323}
324
325/// Collect the idents of every type parameter declared on `generics`.
326fn type_param_idents(generics: &Generics) -> Vec<&Ident> {
327 generics
328 .params
329 .iter()
330 .filter_map(|p| match p {
331 GenericParam::Type(TypeParam { ident, .. }) => Some(ident),
332 _ => None,
333 })
334 .collect()
335}
336
337/// True if `path` is a single-segment path with no qualifier or generic args
338/// — i.e. a plain type-parameter reference like `T`, not `Self::Item`,
339/// `<Self as Foo>::T`, or `Vec<T>`.
340fn path_is_plain_type_param(path: &syn::Path) -> bool {
341 path.segments.len() == 1 && matches!(path.segments[0].arguments, syn::PathArguments::None)
342}
343
344/// Find the type parameter with a `Float` bound, or fall back to the first
345/// type parameter. Returns an error if the struct has no type parameters.
346fn find_float_param(generics: &Generics) -> syn::Result<Ident> {
347 let declared = type_param_idents(generics);
348
349 // First pass: look for a parameter declaration with an explicit `Float`
350 // bound (e.g. `T: Float`).
351 for param in &generics.params {
352 if let GenericParam::Type(TypeParam { ident, bounds, .. }) = param {
353 for bound in bounds {
354 if let syn::TypeParamBound::Trait(tb) = bound {
355 if tb
356 .path
357 .segments
358 .last()
359 .is_some_and(|seg| seg.ident == "Float")
360 {
361 return Ok(ident.clone());
362 }
363 }
364 }
365 }
366 }
367
368 // Second pass: where-clause predicates (e.g. `where T: Float`).
369 //
370 // We only accept predicates whose bounded type is a plain single-segment
371 // path that names one of the struct's declared type parameters. Anything
372 // else — `Self::Item: Float`, `<Self as Foo>::T: Float`, `Vec<T>: Float`,
373 // or a path naming an undeclared identifier — is not a generic-parameter
374 // bound and must not be picked as the float type. (Earlier versions of
375 // this function used `path.segments.first()`, which silently returned
376 // `Self` for `Self::Item: Float` — the wrong qualifier rather than the
377 // intended type parameter.)
378 if let Some(where_clause) = &generics.where_clause {
379 for predicate in &where_clause.predicates {
380 if let syn::WherePredicate::Type(pt) = predicate {
381 let bounds_float = pt.bounds.iter().any(|bound| {
382 matches!(
383 bound,
384 syn::TypeParamBound::Trait(tb)
385 if tb.path.segments.last().is_some_and(|seg| seg.ident == "Float")
386 )
387 });
388 if !bounds_float {
389 continue;
390 }
391 let syn::Type::Path(tp) = &pt.bounded_ty else {
392 continue;
393 };
394 if tp.qself.is_some() || !path_is_plain_type_param(&tp.path) {
395 continue;
396 }
397 let candidate = &tp.path.segments[0].ident;
398 if declared.contains(&candidate) {
399 return Ok(candidate.clone());
400 }
401 }
402 }
403 }
404
405 // Fallback: use the first type parameter.
406 if let Some(first) = declared.first() {
407 return Ok((*first).clone());
408 }
409
410 Err(syn::Error::new(
411 Span::call_site(),
412 "#[derive(Module)] requires at least one type parameter (e.g., `T: Float`)",
413 ))
414}
415
#[cfg(test)]
mod tests {
    //! Unit tests for float-parameter selection and for the error paths of
    //! `derive_module_impl`. They run on `proc_macro2` tokens, so no compiler
    //! proc-macro context is needed.

    use super::*;
    use syn::parse_quote;

    /// Run `find_float_param` on `input`'s generics and return the chosen
    /// identifier as a string. Panics if no parameter can be determined.
    fn float_ident_for(input: &syn::DeriveInput) -> String {
        find_float_param(&input.generics).unwrap().to_string()
    }

    #[test]
    fn picks_inline_bound_param() {
        // `struct S<T: Float> { ... }` — first pass.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Float> { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn picks_where_clause_param() {
        // `where T: Float` — second pass.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T> where T: Float { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn inline_float_bound_beats_declaration_order() {
        // The `Float`-bounded parameter wins even when it is not declared
        // first; the fallback would have returned `A`.
        let di: syn::DeriveInput = parse_quote! {
            struct S<A: Clone, T: Float> { a: A, x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn where_clause_float_bound_beats_declaration_order() {
        // Same as above, with the bound expressed in the where clause.
        let di: syn::DeriveInput = parse_quote! {
            struct S<A, T> where T: Float { a: A, x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn accepts_path_qualified_float_bound() {
        // Only the last segment of the bound's path has to be `Float`.
        let di: syn::DeriveInput = parse_quote! {
            struct S<A, T: num_traits::Float> { a: A, x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    // Regression test for the audit finding: previously, a where-clause with
    // a multi-segment bounded type like `Self::Item: Float` would match and
    // return the *first* segment (`Self`) — which is not a generic parameter
    // at all. The current implementation skips such predicates entirely and
    // falls back to the first declared type parameter.
    #[test]
    fn ignores_associated_type_in_where_clause() {
        // `where Self::Item: Float` — must NOT be picked. The fallback (first
        // declared type param, `T`) is the correct answer.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T> where Self::Item: Float { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn ignores_qself_path_in_where_clause() {
        // `where <Self as Foo>::T: Float` — must NOT be picked.
        let di: syn::DeriveInput = parse_quote! {
            struct S<U> where <Self as Foo>::T: Float { x: U }
        };
        assert_eq!(float_ident_for(&di), "U");
    }

    #[test]
    fn ignores_generic_args_in_where_clause() {
        // `where Vec<B>: Float` bounds `Vec<B>`, not `B`; neither `Vec` nor
        // `B` may be picked, so the fallback (`A`) applies.
        let di: syn::DeriveInput = parse_quote! {
            struct S<A, B> where Vec<B>: Float { a: A, b: B }
        };
        assert_eq!(float_ident_for(&di), "A");
    }

    #[test]
    fn ignores_undeclared_ident_in_where_clause() {
        // `X` is not a declared type parameter of `S`.
        let di: syn::DeriveInput = parse_quote! {
            struct S<A> where X: Float { a: A }
        };
        assert_eq!(float_ident_for(&di), "A");
    }

    #[test]
    fn fallback_when_no_float_bound() {
        // No `Float` bound anywhere — return the first type parameter.
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Clone> { x: T }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn fallback_skips_lifetime_and_const_params() {
        // Lifetime and const parameters are not type parameters and must not
        // be considered by the fallback.
        let di: syn::DeriveInput = parse_quote! {
            struct S<'a, const N: usize, T> { x: &'a [T; N] }
        };
        assert_eq!(float_ident_for(&di), "T");
    }

    #[test]
    fn errors_when_no_type_params() {
        let di: syn::DeriveInput = parse_quote! {
            struct S { x: u32 }
        };
        assert!(find_float_param(&di.generics).is_err());
    }

    #[test]
    fn derive_rejects_enum() {
        let di: syn::DeriveInput = parse_quote! {
            enum E<T: Float> { A(T) }
        };
        assert!(derive_module_impl(di).is_err());
    }

    #[test]
    fn derive_rejects_tuple_struct() {
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Float>(T, bool);
        };
        assert!(derive_module_impl(di).is_err());
    }

    #[test]
    fn derive_rejects_missing_training_field() {
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Float> { #[skip] x: T }
        };
        assert!(derive_module_impl(di).is_err());
    }

    #[test]
    fn derive_rejects_conflicting_attributes() {
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Float> {
                #[param]
                #[skip]
                w: T,
                training: bool,
            }
        };
        assert!(derive_module_impl(di).is_err());
    }

    #[test]
    fn derive_accepts_minimal_struct() {
        let di: syn::DeriveInput = parse_quote! {
            struct S<T: Float> {
                #[skip]
                marker: ::std::marker::PhantomData<T>,
                training: bool,
            }
        };
        assert!(derive_module_impl(di).is_ok());
    }
}