Skip to main content

ferrotorch_nn/
module.rs

1//! `Module<T>` trait, `Reduction`, and `StateDict<T>` — the Rust analog of
2//! PyTorch's `torch.nn.Module` base class.
3//!
4//! Every neural-network layer in `ferrotorch-nn` implements `Module<T>`. The
5//! trait composes parameter iteration, buffer iteration, train/eval mode,
6//! device transfer, sub-module walks, hook registration, gradient zeroing,
7//! and state-dict load/save.
8//!
9//! ## REQ status (per `.design/ferrotorch-nn/module.md`)
10//!
11//! | REQ | Status | Evidence |
12//! |---|---|---|
13//! | REQ-1 | SHIPPED | `pub type StateDict<T> = HashMap<String, Tensor<T>>` mirrors `torch/nn/modules/module.py:1980-2060`; consumed by `pub use module::{Module, Reduction, StateDict}` at `lib.rs:218` and every layer's `state_dict()` return. |
14//! | REQ-2 | SHIPPED | `pub enum Reduction { Mean, Sum, None }` mirrors PyTorch's loss `reduction` kwarg; consumed by `ferrotorch-nn/src/loss.rs:19` `use crate::module::Reduction` and `ferrotorch-nn/src/functional.rs:1798`. |
15//! | REQ-3 | SHIPPED | `pub trait Module<T: Float>: Send + Sync` mirrors `torch/nn/modules/module.py:407-441`; consumed by every layer file (`linear.rs`, `conv.rs`, `norm.rs`, `embedding.rs`, etc.) via `impl Module<T> for ...`. |
16//! | REQ-4 | SHIPPED | `forward`, `parameters`, `parameters_mut`, `named_parameters` required methods; consumed by `ferrotorch-optim/src/optimizer.rs:5` and every optimizer (`adam.rs:17`, `adadelta.rs:20`, …) reading `.parameters()`. |
17//! | REQ-5 | SHIPPED | `train`, `eval`, `is_training` required methods; consumed by `ferrotorch-nn/src/container.rs` `Sequential::train` / `::eval` propagation. |
18//! | REQ-6 | SHIPPED | Default `to_device` iterates parameters and buffers, mirroring `torch/nn/modules/module.py:1180-1260`; consumed by downstream `model.to_device(Device::Cuda(0))` calls in model composition code. |
19//! | REQ-7 | SHIPPED | Default `state_dict` unions parameters with buffers, mirroring `module.py:1980-2060`; consumed by `ferrotorch-nn/src/hooks.rs` `HookedModule::state_dict` delegation and SafeTensors export. |
20//! | REQ-8 | SHIPPED | `load_state_dict(strict)` two-pass strict + shape-validate, mirroring `module.py:2150-2310`; consumed by SafeTensors / GGUF loaders in downstream crates. |
21//! | REQ-9 | SHIPPED | `buffers` / `buffers_mut` / `named_buffers` default-empty methods mirror `module.py:2430-2490`; consumed by `ferrotorch-nn/src/norm.rs` `BatchNorm*` overrides and `module.rs` default `load_state_dict` (the `*buf = Buffer::new(...)` site). |
22//! | REQ-10 | SHIPPED | `as_any` default returns `None` for the #984 downcast hook; consumed by `ferrotorch-nn/src/norm.rs` `BatchNorm*` overriding it so `ferrotorch-vision`'s state-dict loader can route to BN running stats. |
23//! | REQ-11 | SHIPPED | `children`, `named_children`, `modules` (`Self: Sized`), `descendants_dyn` (object-safe), `named_modules`, `named_descendants_dyn` mirror `module.py:2510-2640`; consumed by `ferrotorch-nn/src/container.rs` containers and downstream state-dict walkers. |
24//! | REQ-12 | SHIPPED | Empty-parent path branch in `named_descendants_dyn` fixes #1142 (DeepLabV3 BN-buffer routing); consumed by `ferrotorch-vision`'s DeepLabV3 backbone load path; pinned by `module_named_descendants_dyn_empty_parent_no_leading_dot`. |
25//! | REQ-13 | SHIPPED | `with_forward_hook` / `with_forward_pre_hook` / `with_backward_hook` `Self: Sized` methods return `(HookedModule<Self, T>, HookHandle)`; consumed by `ferrotorch-nn/src/hooks.rs` (`HookedModule` is instantiated by these) and downstream observability code. |
26//! | REQ-14 | SHIPPED | Default `zero_grad` walks `parameters()` mirroring `module.py:2700-2740`; consumed by `ferrotorch-optim` and `ferrotorch-train/src/grad_utils.rs` calling `model.zero_grad()` per step. |
27//! | REQ-15 | SHIPPED | Default `requires_grad_(bool)` toggles all parameters mirroring `module.py:2680-2700`; consumed by transfer-learning callsites that freeze the backbone via `backbone.requires_grad_(false)`. |
28//! | REQ-16 | SHIPPED | Default `apply_to_parameters(&mut dyn FnMut(...))` mirrors `torch.nn.Module.apply` for the parameter case; consumed by lazy-init paths in downstream lazy layers (`lazy_linear.rs`, `lazy_conv.rs`). |
29
30use std::collections::HashMap;
31
32use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor};
33
34use crate::buffer::Buffer;
35use crate::hooks::{BackwardHook, ForwardHook, ForwardPreHook, HookHandle, HookedModule};
36use crate::parameter::Parameter;
37
38/// A map from parameter names to tensors, used for serialization.
39pub type StateDict<T> = HashMap<String, Tensor<T>>;
40
/// How a loss function collapses its per-element losses.
///
/// Mirrors the `reduction` keyword accepted by PyTorch's loss functions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Reduction {
    /// Average the per-element losses into a single value.
    Mean,
    /// Add the per-element losses into a single value.
    Sum,
    /// Skip reduction and hand back the per-element loss tensor.
    None,
}
51
52/// The trait that all neural network layers implement.
53///
54/// Requires `Send + Sync` to match `Tensor<T>`'s thread-safety guarantees.
55pub trait Module<T: Float>: Send + Sync {
56    /// Forward pass. Takes input tensor, returns output tensor.
57    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
58
59    /// Iterate over all learnable parameters.
60    fn parameters(&self) -> Vec<&Parameter<T>>;
61
62    /// Iterate over all learnable parameters mutably.
63    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
64
65    /// Named parameters for state dict serialization.
66    ///
67    /// Keys use dot-separated paths for nested modules
68    /// (e.g., `"layer1.weight"`, `"layer1.bias"`).
69    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
70
71    /// Set training mode. Affects dropout, batchnorm, etc.
72    fn train(&mut self);
73
74    /// Set evaluation mode.
75    fn eval(&mut self);
76
77    /// Whether the module is in training mode.
78    fn is_training(&self) -> bool;
79
80    /// Move all parameters and buffers to a device.
81    ///
82    /// Default implementation iterates `parameters_mut()` and `buffers_mut()`
83    /// and transfers each.
84    fn to_device(&mut self, device: Device) -> FerrotorchResult<()> {
85        for param in self.parameters_mut() {
86            *param = param.to(device)?;
87        }
88        for buffer in self.buffers_mut() {
89            *buffer = buffer.to(device)?;
90        }
91        Ok(())
92    }
93
94    /// Export parameters and buffers as a state dict (torch parity).
95    ///
96    /// Buffers are included alongside parameters since both are persistent
97    /// module state. Keys are dot-separated paths.
98    fn state_dict(&self) -> StateDict<T> {
99        let mut out: StateDict<T> = self
100            .named_parameters()
101            .into_iter()
102            .map(|(name, param)| (name, param.tensor().clone()))
103            .collect();
104        for (name, buffer) in self.named_buffers() {
105            out.insert(name, buffer.tensor().clone());
106        }
107        out
108    }
109
110    // -----------------------------------------------------------------
111    // Buffers — non-trainable persistent state. (#583)
112    // -----------------------------------------------------------------
113
114    /// Iterate over all non-trainable buffers (e.g. running mean / variance
115    /// in BatchNorm). Default returns empty — concrete modules with buffers
116    /// override.
117    fn buffers(&self) -> Vec<&Buffer<T>> {
118        Vec::new()
119    }
120
121    /// Mutable iteration over all buffers. Default returns empty.
122    fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
123        Vec::new()
124    }
125
126    /// Named buffers (dot-separated paths for nested modules). Default
127    /// returns empty.
128    fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
129        Vec::new()
130    }
131
132    /// Downcast hook for type-erased buffer-loader dispatch. (#984)
133    ///
134    /// Returns `Some(&self as &dyn Any)` for concrete module types whose
135    /// non-`Buffer<T>` persistent state needs to be applied from a state
136    /// dict (currently `BatchNorm1d` / `BatchNorm2d` / `BatchNorm3d`'s
137    /// running mean / variance / `num_batches_tracked` — see Phase 2
138    /// of the value-parity pipeline in `ferrotorch-vision/tests`).
139    ///
140    /// The default returns `None`, so existing modules are unaffected:
141    /// type-erased callers walking `named_modules()` will simply skip
142    /// modules that do not opt in. Implementors MUST return
143    /// `Some(self)`; returning `Some` for an unrelated `Any` would
144    /// violate the contract.
145    ///
146    /// Why a downcast hook instead of a wider trait surface (e.g. a
147    /// dedicated `set_buffer_value(&self, &str, &Tensor<T>)` method on
148    /// `Module`)? Buffers carrying torch-shaped state (running mean /
149    /// variance, `num_batches_tracked: usize`) currently live outside
150    /// the [`Buffer<T>`] abstraction (BN keeps `Mutex<Vec<f64>>` for
151    /// numerical stability and the integer counter has no `Buffer`
152    /// at all), so a single typed setter on `Module` would force a
153    /// premature unification that #984 explicitly defers. The downcast
154    /// hook keeps `Module` free of BN-specific shape and lets concrete
155    /// modules expose their own typed setters at full precision.
156    fn as_any(&self) -> Option<&dyn std::any::Any> {
157        None
158    }
159
160    // -----------------------------------------------------------------
161    // Submodule iteration. (#583)
162    // -----------------------------------------------------------------
163
164    /// Direct child modules. Default returns empty (leaf module).
165    fn children(&self) -> Vec<&dyn Module<T>> {
166        Vec::new()
167    }
168
169    /// Direct child modules with their attribute names. Default returns
170    /// empty.
171    fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
172        Vec::new()
173    }
174
175    /// All modules in this subtree, depth-first (self first, then each
176    /// child's descendants in order).
177    ///
178    /// Requires `Self: Sized` so we can coerce `self` to `&dyn Module<T>`.
179    /// Trait-object callers can use [`Module::descendants_dyn`] (which yields
180    /// descendants only) and prepend their own reference.
181    fn modules(&self) -> Vec<&dyn Module<T>>
182    where
183        Self: Sized,
184    {
185        let mut out: Vec<&dyn Module<T>> = vec![self];
186        out.extend(self.descendants_dyn());
187        out
188    }
189
190    /// All strict descendants of `self` in depth-first order. Object-safe.
191    fn descendants_dyn(&self) -> Vec<&dyn Module<T>> {
192        let mut out: Vec<&dyn Module<T>> = Vec::new();
193        for child in self.children() {
194            out.push(child);
195            out.extend(child.descendants_dyn());
196        }
197        out
198    }
199
200    /// All modules in this subtree with dot-separated path names. The root
201    /// is named `""`; children paths are joined with `.`.
202    fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
203    where
204        Self: Sized,
205    {
206        let mut out: Vec<(String, &dyn Module<T>)> = vec![(String::new(), self)];
207        out.extend(self.named_descendants_dyn());
208        out
209    }
210
211    /// Strict descendants with dot-paths. Object-safe.
212    fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)> {
213        let mut out: Vec<(String, &dyn Module<T>)> = Vec::new();
214        for (name, child) in self.named_children() {
215            out.push((name.clone(), child));
216            for (sub_name, sub_module) in child.named_descendants_dyn() {
217                let full = if sub_name.is_empty() {
218                    name.clone()
219                } else if name.is_empty() {
220                    // Transparent wrapper: parent exposes child under "" so
221                    // the child's own naming (e.g. `("backbone", inner)`)
222                    // becomes the canonical path. Without this branch the
223                    // walker would produce `".backbone.X"` (leading dot),
224                    // mismatching state-dict keys like `"backbone.X"`.
225                    // See #1142 for the DeepLabV3 BN-buffer routing bug
226                    // that this branch fixes.
227                    sub_name
228                } else {
229                    format!("{name}.{sub_name}")
230                };
231                out.push((full, sub_module));
232            }
233        }
234        out
235    }
236
237    // -----------------------------------------------------------------
238    // Helpers. (#583)
239    // -----------------------------------------------------------------
240
241    // -----------------------------------------------------------------
242    // Hooks (#606)
243    //
244    // These consume `self` and return a [`HookedModule<Self, T>`] with the
245    // requested hook already registered. Mirrors `torch.nn.Module
246    // .register_*_hook(...)` ergonomically — callers no longer need to
247    // wrap manually with `HookedModule::new(..)` first. Gated on
248    // `Self: Sized` so the trait stays dyn-compatible.
249    //
250    // Named with the `with_*` prefix (rather than `register_*` directly) to
251    // avoid clashing with `HookedModule`'s own inherent `register_*` methods,
252    // which take `&self` and append a hook to an already-wrapped instance.
253    // The two surfaces compose: `Linear::new(..)?.with_forward_hook(h1).0`
254    // is a `HookedModule` that can `.register_forward_hook(h2)` again.
255    // -----------------------------------------------------------------
256
257    /// Wrap this module in a [`HookedModule`] and register a forward hook.
258    /// Returns the wrapper paired with a [`HookHandle`] that can be used to
259    /// remove the hook later. The wrapper implements `Module<T>` itself, so
260    /// it slots into any place the original module did. Mirrors
261    /// `torch.nn.Module.register_forward_hook`.
262    fn with_forward_hook(self, hook: ForwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
263    where
264        Self: Sized,
265    {
266        let wrapped = HookedModule::new(self);
267        let handle = wrapped.register_forward_hook(hook);
268        (wrapped, handle)
269    }
270
271    /// Wrap this module in a [`HookedModule`] and register a forward
272    /// pre-hook. See [`Self::with_forward_hook`]. Mirrors
273    /// `torch.nn.Module.register_forward_pre_hook`.
274    fn with_forward_pre_hook(self, hook: ForwardPreHook<T>) -> (HookedModule<Self, T>, HookHandle)
275    where
276        Self: Sized,
277    {
278        let wrapped = HookedModule::new(self);
279        let handle = wrapped.register_forward_pre_hook(hook);
280        (wrapped, handle)
281    }
282
283    /// Wrap this module in a [`HookedModule`] and register a backward hook.
284    /// See [`Self::with_forward_hook`]. Mirrors
285    /// `torch.nn.Module.register_backward_hook`.
286    fn with_backward_hook(self, hook: BackwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
287    where
288        Self: Sized,
289    {
290        let wrapped = HookedModule::new(self);
291        let handle = wrapped.register_backward_hook(hook);
292        (wrapped, handle)
293    }
294
295    /// Set the gradient of every parameter to `None`.
296    ///
297    /// Equivalent to calling `tensor.zero_grad()` on each parameter's
298    /// underlying tensor. Mirrors `torch.nn.Module.zero_grad`.
299    fn zero_grad(&self) -> FerrotorchResult<()> {
300        for param in self.parameters() {
301            param.tensor().zero_grad()?;
302        }
303        Ok(())
304    }
305
306    /// Toggle `requires_grad` on every parameter (freeze / unfreeze the
307    /// module). Mirrors `torch.nn.Module.requires_grad_`.
308    fn requires_grad_(&mut self, requires_grad: bool) {
309        for param in self.parameters_mut() {
310            param.set_requires_grad(requires_grad);
311        }
312    }
313
314    /// Apply a function to every parameter in this module. Mirrors
315    /// `torch.nn.Module.apply` for the parameter case (true `apply` recurses
316    /// over all submodules; the recursive form requires `&mut dyn Module`
317    /// which conflicts with this trait's `&mut self` borrow).
318    ///
319    /// Takes `&mut dyn FnMut(...)` (rather than a generic closure) so the
320    /// trait stays dyn-compatible — `Box<dyn Module<T>>` is a common usage
321    /// pattern.
322    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>)) {
323        for param in self.parameters_mut() {
324            f(param);
325        }
326    }
327
328    /// Load parameters from a state dict.
329    ///
330    /// When `strict` is `true` (default), unexpected keys are an error.
331    /// When `false`, unexpected keys are silently ignored and missing
332    /// keys leave existing parameter values unchanged.
333    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
334        // Known keys: union of parameter and buffer paths.
335        let mut known_keys: std::collections::HashSet<String> = self
336            .named_parameters()
337            .iter()
338            .map(|(k, _)| k.clone())
339            .collect();
340        for (k, _) in self.named_buffers() {
341            known_keys.insert(k);
342        }
343
344        if strict {
345            for key in state.keys() {
346                if !known_keys.contains(key) {
347                    return Err(FerrotorchError::InvalidArgument {
348                        message: format!("unexpected key in state_dict: \"{key}\""),
349                    });
350                }
351            }
352        }
353
354        // We need mutable access to parameters. Use named_parameters to get
355        // the mapping, then parameters_mut to actually update.
356        // This two-pass approach avoids borrowing issues.
357        let param_names: Vec<String> = self
358            .named_parameters()
359            .into_iter()
360            .map(|(name, _)| name)
361            .collect();
362
363        let params_mut = self.parameters_mut();
364
365        for (name, param) in param_names.iter().zip(params_mut) {
366            if let Some(tensor) = state.get(name) {
367                if param.shape() != tensor.shape() {
368                    return Err(FerrotorchError::ShapeMismatch {
369                        message: format!(
370                            "state_dict shape mismatch for \"{name}\": expected {:?}, got {:?}",
371                            param.shape(),
372                            tensor.shape()
373                        ),
374                    });
375                }
376                // Replace the parameter data with the loaded tensor.
377                *param = Parameter::new(tensor.clone());
378            } else if strict {
379                return Err(FerrotorchError::InvalidArgument {
380                    message: format!("missing key in state_dict: \"{name}\""),
381                });
382            }
383        }
384
385        // Same dance for buffers.
386        let buffer_names: Vec<String> = self
387            .named_buffers()
388            .into_iter()
389            .map(|(name, _)| name)
390            .collect();
391        let buffers_mut = self.buffers_mut();
392        for (name, buf) in buffer_names.iter().zip(buffers_mut) {
393            if let Some(tensor) = state.get(name) {
394                if buf.shape() != tensor.shape() {
395                    return Err(FerrotorchError::ShapeMismatch {
396                        message: format!(
397                            "state_dict shape mismatch for buffer \"{name}\": expected {:?}, got {:?}",
398                            buf.shape(),
399                            tensor.shape()
400                        ),
401                    });
402                }
403                *buf = Buffer::new(tensor.clone());
404            } else if strict {
405                return Err(FerrotorchError::InvalidArgument {
406                    message: format!("missing buffer key in state_dict: \"{name}\""),
407                });
408            }
409        }
410
411        Ok(())
412    }
413}
414
415#[cfg(test)]
416mod tests {
417    use super::*;
418    /// A minimal test module with one parameter.
419    struct SimpleModule<T: Float> {
420        weight: Parameter<T>,
421        training: bool,
422    }
423
424    impl<T: Float> SimpleModule<T> {
425        fn new(size: usize) -> FerrotorchResult<Self> {
426            Ok(Self {
427                weight: Parameter::zeros(&[size])?,
428                training: true,
429            })
430        }
431    }
432
433    impl<T: Float> Module<T> for SimpleModule<T> {
434        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
435            // Just return input for testing.
436            Ok(input.clone())
437        }
438
439        fn parameters(&self) -> Vec<&Parameter<T>> {
440            vec![&self.weight]
441        }
442
443        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
444            vec![&mut self.weight]
445        }
446
447        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
448            vec![("weight".to_string(), &self.weight)]
449        }
450
451        fn train(&mut self) {
452            self.training = true;
453        }
454
455        fn eval(&mut self) {
456            self.training = false;
457        }
458
459        fn is_training(&self) -> bool {
460            self.training
461        }
462    }
463
464    #[test]
465    fn test_module_parameters() {
466        let m = SimpleModule::<f32>::new(5).unwrap();
467        assert_eq!(m.parameters().len(), 1);
468        assert_eq!(m.parameters()[0].shape(), &[5]);
469    }
470
471    #[test]
472    fn test_module_named_parameters() {
473        let m = SimpleModule::<f32>::new(3).unwrap();
474        let named = m.named_parameters();
475        assert_eq!(named.len(), 1);
476        assert_eq!(named[0].0, "weight");
477    }
478
479    #[test]
480    fn test_module_train_eval() {
481        let mut m = SimpleModule::<f32>::new(2).unwrap();
482        assert!(m.is_training());
483        m.eval();
484        assert!(!m.is_training());
485        m.train();
486        assert!(m.is_training());
487    }
488
489    #[test]
490    fn test_module_state_dict_roundtrip() {
491        let m = SimpleModule::<f32>::new(4).unwrap();
492        let sd = m.state_dict();
493        assert!(sd.contains_key("weight"));
494        assert_eq!(sd["weight"].shape(), &[4]);
495
496        let mut m2 = SimpleModule::<f32>::new(4).unwrap();
497        m2.load_state_dict(&sd, true).unwrap();
498    }
499
500    #[test]
501    fn test_module_state_dict_strict_extra_key() {
502        let mut m = SimpleModule::<f32>::new(3).unwrap();
503        let mut sd = HashMap::new();
504        sd.insert(
505            "weight".to_string(),
506            ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
507        );
508        sd.insert(
509            "extra".to_string(),
510            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
511        );
512
513        assert!(m.load_state_dict(&sd, true).is_err());
514        assert!(m.load_state_dict(&sd, false).is_ok());
515    }
516
517    #[test]
518    fn test_module_state_dict_shape_mismatch() {
519        let mut m = SimpleModule::<f32>::new(3).unwrap();
520        let mut sd = HashMap::new();
521        sd.insert(
522            "weight".to_string(),
523            ferrotorch_core::zeros::<f32>(&[5]).unwrap(),
524        );
525
526        assert!(m.load_state_dict(&sd, true).is_err());
527    }
528
529    #[test]
530    fn test_module_is_send_sync() {
531        fn assert_send_sync<T: Send + Sync>() {}
532        assert_send_sync::<SimpleModule<f32>>();
533    }
534
535    #[test]
536    fn test_reduction_enum() {
537        assert_eq!(Reduction::Mean, Reduction::Mean);
538        assert_ne!(Reduction::Mean, Reduction::Sum);
539    }
540
541    #[test]
542    fn test_to_device_cpu_preserves_weights() {
543        let mut m = SimpleModule::<f32>::new(4).unwrap();
544        m.to_device(ferrotorch_core::Device::Cpu).unwrap();
545        assert_eq!(m.parameters().len(), 1);
546        assert_eq!(m.parameters()[0].shape(), &[4]);
547    }
548
549    #[test]
550    fn test_to_device_cuda_without_backend() {
551        let mut m = SimpleModule::<f32>::new(3).unwrap();
552        let result = m.to_device(ferrotorch_core::Device::Cuda(0));
553        assert!(result.is_err());
554    }
555
556    // -----------------------------------------------------------------------
557    // Module trait additions: buffers / children / zero_grad / requires_grad_ /
558    // apply_to_parameters / modules iteration. (#583)
559    // -----------------------------------------------------------------------
560
561    /// A module with one parameter, one buffer, and a child.
562    struct ParentModule<T: Float> {
563        weight: Parameter<T>,
564        running_mean: Buffer<T>,
565        child: SimpleModule<T>,
566    }
567
568    impl<T: Float> ParentModule<T> {
569        fn new() -> FerrotorchResult<Self> {
570            Ok(Self {
571                weight: Parameter::ones(&[2, 2])?,
572                running_mean: Buffer::zeros(&[2])?,
573                child: SimpleModule::new(3)?,
574            })
575        }
576    }
577
578    impl<T: Float> Module<T> for ParentModule<T> {
579        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
580            self.child.forward(input)
581        }
582
583        fn parameters(&self) -> Vec<&Parameter<T>> {
584            // self.weight + child.parameters()
585            let mut out: Vec<&Parameter<T>> = vec![&self.weight];
586            out.extend(self.child.parameters());
587            out
588        }
589
590        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
591            let mut out: Vec<&mut Parameter<T>> = vec![&mut self.weight];
592            out.extend(self.child.parameters_mut());
593            out
594        }
595
596        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
597            let mut out: Vec<(String, &Parameter<T>)> = vec![("weight".to_string(), &self.weight)];
598            for (n, p) in self.child.named_parameters() {
599                out.push((format!("child.{n}"), p));
600            }
601            out
602        }
603
604        fn buffers(&self) -> Vec<&Buffer<T>> {
605            vec![&self.running_mean]
606        }
607
608        fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
609            vec![&mut self.running_mean]
610        }
611
612        fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
613            vec![("running_mean".to_string(), &self.running_mean)]
614        }
615
616        fn children(&self) -> Vec<&dyn Module<T>> {
617            vec![&self.child]
618        }
619
620        fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
621            vec![("child".to_string(), &self.child)]
622        }
623
624        fn train(&mut self) {
625            self.child.train();
626        }
627
628        fn eval(&mut self) {
629            self.child.eval();
630        }
631
632        fn is_training(&self) -> bool {
633            self.child.is_training()
634        }
635    }
636
637    #[test]
638    fn module_buffers_default_is_empty() {
639        // SimpleModule doesn't override buffers() — default impl returns empty.
640        let m = SimpleModule::<f32>::new(3).unwrap();
641        assert!(m.buffers().is_empty());
642        assert!(m.named_buffers().is_empty());
643    }
644
645    #[test]
646    fn module_buffers_listed_for_overriding_module() {
647        let m = ParentModule::<f32>::new().unwrap();
648        assert_eq!(m.buffers().len(), 1);
649        assert_eq!(m.buffers()[0].shape(), &[2]);
650        let nb = m.named_buffers();
651        assert_eq!(nb.len(), 1);
652        assert_eq!(nb[0].0, "running_mean");
653    }
654
655    #[test]
656    fn module_children_listed_for_parent() {
657        let m = ParentModule::<f32>::new().unwrap();
658        assert_eq!(m.children().len(), 1);
659        assert_eq!(m.named_children().len(), 1);
660        assert_eq!(m.named_children()[0].0, "child");
661    }
662
663    #[test]
664    fn module_named_modules_includes_self_and_descendants() {
665        let m = ParentModule::<f32>::new().unwrap();
666        let nm = m.named_modules();
667        // Root + 1 child = 2 entries.
668        assert_eq!(nm.len(), 2);
669        assert_eq!(nm[0].0, "");
670        assert_eq!(nm[1].0, "child");
671    }
672
673    #[test]
674    fn module_modules_includes_self_and_descendants() {
675        let m = ParentModule::<f32>::new().unwrap();
676        let mods = m.modules();
677        assert_eq!(mods.len(), 2);
678    }
679
680    #[test]
681    fn module_zero_grad_succeeds() {
682        // No grads yet on a fresh module — zero_grad should still succeed.
683        let m = SimpleModule::<f32>::new(3).unwrap();
684        m.zero_grad().unwrap();
685    }
686
687    #[test]
688    fn module_requires_grad_toggles_all_parameters() {
689        let mut m = ParentModule::<f32>::new().unwrap();
690        for p in m.parameters() {
691            assert!(p.requires_grad());
692        }
693        m.requires_grad_(false);
694        for p in m.parameters() {
695            assert!(!p.requires_grad());
696        }
697        m.requires_grad_(true);
698        for p in m.parameters() {
699            assert!(p.requires_grad());
700        }
701    }
702
703    #[test]
704    fn module_apply_to_parameters_visits_all() {
705        let mut m = ParentModule::<f32>::new().unwrap();
706        let n_params = m.parameters().len();
707        let mut count = 0;
708        m.apply_to_parameters(&mut |_p| count += 1);
709        assert_eq!(count, n_params);
710    }
711
712    #[test]
713    fn module_state_dict_includes_buffers() {
714        let m = ParentModule::<f32>::new().unwrap();
715        let sd = m.state_dict();
716        assert!(sd.contains_key("weight"));
717        assert!(sd.contains_key("running_mean"));
718        assert!(sd.contains_key("child.weight"));
719        assert_eq!(sd.len(), 3);
720    }
721
722    #[test]
723    fn module_load_state_dict_with_buffer() {
724        let mut m = ParentModule::<f32>::new().unwrap();
725        let mut sd: StateDict<f32> = HashMap::new();
726        sd.insert(
727            "weight".into(),
728            ferrotorch_core::ones::<f32>(&[2, 2]).unwrap(),
729        );
730        sd.insert(
731            "running_mean".into(),
732            ferrotorch_core::from_slice::<f32>(&[7.0, 9.0], &[2]).unwrap(),
733        );
734        sd.insert(
735            "child.weight".into(),
736            ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
737        );
738        m.load_state_dict(&sd, true).unwrap();
739        assert_eq!(m.buffers()[0].data().unwrap(), &[7.0, 9.0]);
740    }
741
742    #[test]
743    fn module_descendants_dyn_excludes_self() {
744        let m = ParentModule::<f32>::new().unwrap();
745        let d = m.descendants_dyn();
746        assert_eq!(d.len(), 1);
747    }
748
749    #[test]
750    fn module_named_descendants_dyn_paths() {
751        let m = ParentModule::<f32>::new().unwrap();
752        let nd = m.named_descendants_dyn();
753        assert_eq!(nd.len(), 1);
754        assert_eq!(nd[0].0, "child");
755    }
756
757    /// #1142 regression lock: a transparent wrapper module that exposes
758    /// its inner child at path `""` must NOT prepend a leading `.` to
759    /// the child's own descendant paths.
760    ///
761    /// The DeepLabV3 model uses this idiom — `DeepLabV3::named_children`
762    /// returns `("", &backbone)` and `ResNet50Dilated::named_children`
763    /// returns `("backbone", &inner)`. Pre-#1142 the walker produced
764    /// `".backbone"`, mismatching state-dict keys like `"backbone.bn1.X"`
765    /// and silently failing every BN-buffer load on DeepLabV3's backbone.
766    #[test]
767    fn module_named_descendants_dyn_empty_parent_no_leading_dot() {
768        /// Wraps a `ParentModule` at the empty path. The descendant walker
769        /// must reach `ParentModule`'s `("child", _)` entry as the bare
770        /// path `"child"`, not `".child"`.
771        struct TransparentWrapper<T: Float> {
772            inner: ParentModule<T>,
773        }
774        impl<T: Float> Module<T> for TransparentWrapper<T> {
775            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
776                self.inner.forward(input)
777            }
778            fn parameters(&self) -> Vec<&Parameter<T>> {
779                self.inner.parameters()
780            }
781            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
782                self.inner.parameters_mut()
783            }
784            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
785                self.inner.named_parameters()
786            }
787            fn children(&self) -> Vec<&dyn Module<T>> {
788                vec![&self.inner]
789            }
790            fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
791                vec![(String::new(), &self.inner)]
792            }
793            fn train(&mut self) {
794                self.inner.train();
795            }
796            fn eval(&mut self) {
797                self.inner.eval();
798            }
799            fn is_training(&self) -> bool {
800                self.inner.is_training()
801            }
802        }
803        let m = TransparentWrapper::<f32> {
804            inner: ParentModule::new().unwrap(),
805        };
806        let nd: Vec<String> = m
807            .named_descendants_dyn()
808            .into_iter()
809            .map(|(n, _)| n)
810            .collect();
811        // 2 entries: ("" -> inner) and ("child" -> grandchild).
812        assert_eq!(nd, vec![String::new(), "child".to_string()]);
813        for p in &nd {
814            assert!(
815                !p.starts_with('.'),
816                "transparent-wrapper descendant path '{p}' starts with '.'; \
817                 the empty-parent branch in named_descendants_dyn has regressed",
818            );
819        }
820    }
821
822    // -------------------------------------------------------------------
823    // Hook-registration trait methods (#606)
824    // -------------------------------------------------------------------
825
826    #[test]
827    fn with_forward_hook_wraps_and_fires() {
828        use std::sync::atomic::{AtomicUsize, Ordering};
829        let m = SimpleModule::<f32>::new(2).unwrap();
830        let counter = std::sync::Arc::new(AtomicUsize::new(0));
831        let counter_for_hook = std::sync::Arc::clone(&counter);
832
833        let (wrapped, _handle) = m.with_forward_hook(Box::new(move |_input, _output| {
834            counter_for_hook.fetch_add(1, Ordering::SeqCst);
835        }));
836
837        let input = ferrotorch_core::Tensor::from_storage(
838            ferrotorch_core::TensorStorage::cpu(vec![1.0_f32, 2.0]),
839            vec![2],
840            false,
841        )
842        .unwrap();
843        let _ = wrapped.forward(&input).unwrap();
844        assert_eq!(counter.load(Ordering::SeqCst), 1);
845    }
846
847    #[test]
848    fn with_forward_pre_hook_wraps_and_fires() {
849        use std::sync::atomic::{AtomicUsize, Ordering};
850        let m = SimpleModule::<f32>::new(2).unwrap();
851        let counter = std::sync::Arc::new(AtomicUsize::new(0));
852        let counter_for_hook = std::sync::Arc::clone(&counter);
853
854        let (wrapped, _handle) = m.with_forward_pre_hook(Box::new(move |input| {
855            counter_for_hook.fetch_add(1, Ordering::SeqCst);
856            Ok(input.clone())
857        }));
858
859        let input = ferrotorch_core::Tensor::from_storage(
860            ferrotorch_core::TensorStorage::cpu(vec![1.0_f32, 2.0]),
861            vec![2],
862            false,
863        )
864        .unwrap();
865        let _ = wrapped.forward(&input).unwrap();
866        assert_eq!(counter.load(Ordering::SeqCst), 1);
867    }
868
869    #[test]
870    fn with_backward_hook_returns_handle() {
871        // backward hook fires only on the backward pass; just verify the
872        // wrapping API resolves and returns a usable HookedModule + handle.
873        let m = SimpleModule::<f32>::new(2).unwrap();
874        let (wrapped, handle) = m.with_backward_hook(Box::new(|_gi, _go| {}));
875        // Wrapper still implements Module<T> trait — slot it into a forward
876        // call to confirm it round-trips.
877        let input = ferrotorch_core::Tensor::from_storage(
878            ferrotorch_core::TensorStorage::cpu(vec![3.0_f32]),
879            vec![1],
880            false,
881        )
882        .unwrap();
883        let _ = wrapped.forward(&input).unwrap();
884        // Handle is droppable; explicit remove is also fine.
885        handle.remove();
886    }
887}