Skip to main content

ferrotorch_nn/
module.rs

1use std::collections::HashMap;
2
3use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor};
4
5use crate::buffer::Buffer;
6use crate::hooks::{BackwardHook, ForwardHook, ForwardPreHook, HookHandle, HookedModule};
7use crate::parameter::Parameter;
8
9/// A map from parameter names to tensors, used for serialization.
10pub type StateDict<T> = HashMap<String, Tensor<T>>;
11
/// How a loss function collapses its per-element losses.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Reduction {
    /// Average all losses into a single value.
    Mean,
    /// Add all losses into a single value.
    Sum,
    /// Perform no reduction; the loss tensor is returned as-is.
    None,
}
22
23/// The trait that all neural network layers implement.
24///
25/// Requires `Send + Sync` to match `Tensor<T>`'s thread-safety guarantees.
26pub trait Module<T: Float>: Send + Sync {
27    /// Forward pass. Takes input tensor, returns output tensor.
28    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
29
30    /// Iterate over all learnable parameters.
31    fn parameters(&self) -> Vec<&Parameter<T>>;
32
33    /// Iterate over all learnable parameters mutably.
34    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
35
36    /// Named parameters for state dict serialization.
37    ///
38    /// Keys use dot-separated paths for nested modules
39    /// (e.g., `"layer1.weight"`, `"layer1.bias"`).
40    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
41
42    /// Set training mode. Affects dropout, batchnorm, etc.
43    fn train(&mut self);
44
45    /// Set evaluation mode.
46    fn eval(&mut self);
47
48    /// Whether the module is in training mode.
49    fn is_training(&self) -> bool;
50
51    /// Move all parameters and buffers to a device.
52    ///
53    /// Default implementation iterates `parameters_mut()` and `buffers_mut()`
54    /// and transfers each.
55    fn to_device(&mut self, device: Device) -> FerrotorchResult<()> {
56        for param in self.parameters_mut() {
57            *param = param.to(device)?;
58        }
59        for buffer in self.buffers_mut() {
60            *buffer = buffer.to(device)?;
61        }
62        Ok(())
63    }
64
65    /// Export parameters and buffers as a state dict (torch parity).
66    ///
67    /// Buffers are included alongside parameters since both are persistent
68    /// module state. Keys are dot-separated paths.
69    fn state_dict(&self) -> StateDict<T> {
70        let mut out: StateDict<T> = self
71            .named_parameters()
72            .into_iter()
73            .map(|(name, param)| (name, param.tensor().clone()))
74            .collect();
75        for (name, buffer) in self.named_buffers() {
76            out.insert(name, buffer.tensor().clone());
77        }
78        out
79    }
80
81    // -----------------------------------------------------------------
82    // Buffers — non-trainable persistent state. (#583)
83    // -----------------------------------------------------------------
84
85    /// Iterate over all non-trainable buffers (e.g. running mean / variance
86    /// in BatchNorm). Default returns empty — concrete modules with buffers
87    /// override.
88    fn buffers(&self) -> Vec<&Buffer<T>> {
89        Vec::new()
90    }
91
92    /// Mutable iteration over all buffers. Default returns empty.
93    fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
94        Vec::new()
95    }
96
97    /// Named buffers (dot-separated paths for nested modules). Default
98    /// returns empty.
99    fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
100        Vec::new()
101    }
102
103    /// Downcast hook for type-erased buffer-loader dispatch. (#984)
104    ///
105    /// Returns `Some(&self as &dyn Any)` for concrete module types whose
106    /// non-`Buffer<T>` persistent state needs to be applied from a state
107    /// dict (currently `BatchNorm1d` / `BatchNorm2d` / `BatchNorm3d`'s
108    /// running mean / variance / `num_batches_tracked` — see Phase 2
109    /// of the value-parity pipeline in `ferrotorch-vision/tests`).
110    ///
111    /// The default returns `None`, so existing modules are unaffected:
112    /// type-erased callers walking `named_modules()` will simply skip
113    /// modules that do not opt in. Implementors MUST return
114    /// `Some(self)`; returning `Some` for an unrelated `Any` would
115    /// violate the contract.
116    ///
117    /// Why a downcast hook instead of a wider trait surface (e.g. a
118    /// dedicated `set_buffer_value(&self, &str, &Tensor<T>)` method on
119    /// `Module`)? Buffers carrying torch-shaped state (running mean /
120    /// variance, `num_batches_tracked: usize`) currently live outside
121    /// the [`Buffer<T>`] abstraction (BN keeps `Mutex<Vec<f64>>` for
122    /// numerical stability and the integer counter has no `Buffer`
123    /// at all), so a single typed setter on `Module` would force a
124    /// premature unification that #984 explicitly defers. The downcast
125    /// hook keeps `Module` free of BN-specific shape and lets concrete
126    /// modules expose their own typed setters at full precision.
127    fn as_any(&self) -> Option<&dyn std::any::Any> {
128        None
129    }
130
131    // -----------------------------------------------------------------
132    // Submodule iteration. (#583)
133    // -----------------------------------------------------------------
134
135    /// Direct child modules. Default returns empty (leaf module).
136    fn children(&self) -> Vec<&dyn Module<T>> {
137        Vec::new()
138    }
139
140    /// Direct child modules with their attribute names. Default returns
141    /// empty.
142    fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
143        Vec::new()
144    }
145
146    /// All modules in this subtree, depth-first (self first, then each
147    /// child's descendants in order).
148    ///
149    /// Requires `Self: Sized` so we can coerce `self` to `&dyn Module<T>`.
150    /// Trait-object callers can use [`Module::descendants_dyn`] (which yields
151    /// descendants only) and prepend their own reference.
152    fn modules(&self) -> Vec<&dyn Module<T>>
153    where
154        Self: Sized,
155    {
156        let mut out: Vec<&dyn Module<T>> = vec![self];
157        out.extend(self.descendants_dyn());
158        out
159    }
160
161    /// All strict descendants of `self` in depth-first order. Object-safe.
162    fn descendants_dyn(&self) -> Vec<&dyn Module<T>> {
163        let mut out: Vec<&dyn Module<T>> = Vec::new();
164        for child in self.children() {
165            out.push(child);
166            out.extend(child.descendants_dyn());
167        }
168        out
169    }
170
171    /// All modules in this subtree with dot-separated path names. The root
172    /// is named `""`; children paths are joined with `.`.
173    fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
174    where
175        Self: Sized,
176    {
177        let mut out: Vec<(String, &dyn Module<T>)> = vec![(String::new(), self)];
178        out.extend(self.named_descendants_dyn());
179        out
180    }
181
182    /// Strict descendants with dot-paths. Object-safe.
183    fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)> {
184        let mut out: Vec<(String, &dyn Module<T>)> = Vec::new();
185        for (name, child) in self.named_children() {
186            out.push((name.clone(), child));
187            for (sub_name, sub_module) in child.named_descendants_dyn() {
188                let full = if sub_name.is_empty() {
189                    name.clone()
190                } else if name.is_empty() {
191                    // Transparent wrapper: parent exposes child under "" so
192                    // the child's own naming (e.g. `("backbone", inner)`)
193                    // becomes the canonical path. Without this branch the
194                    // walker would produce `".backbone.X"` (leading dot),
195                    // mismatching state-dict keys like `"backbone.X"`.
196                    // See #1142 for the DeepLabV3 BN-buffer routing bug
197                    // that this branch fixes.
198                    sub_name
199                } else {
200                    format!("{name}.{sub_name}")
201                };
202                out.push((full, sub_module));
203            }
204        }
205        out
206    }
207
208    // -----------------------------------------------------------------
209    // Helpers. (#583)
210    // -----------------------------------------------------------------
211
212    // -----------------------------------------------------------------
213    // Hooks (#606)
214    //
215    // These consume `self` and return a [`HookedModule<Self, T>`] with the
216    // requested hook already registered. Mirrors `torch.nn.Module
217    // .register_*_hook(...)` ergonomically — callers no longer need to
218    // wrap manually with `HookedModule::new(..)` first. Gated on
219    // `Self: Sized` so the trait stays dyn-compatible.
220    //
221    // Named with the `with_*` prefix (rather than `register_*` directly) to
222    // avoid clashing with `HookedModule`'s own inherent `register_*` methods,
223    // which take `&self` and append a hook to an already-wrapped instance.
224    // The two surfaces compose: `Linear::new(..)?.with_forward_hook(h1).0`
225    // is a `HookedModule` that can `.register_forward_hook(h2)` again.
226    // -----------------------------------------------------------------
227
228    /// Wrap this module in a [`HookedModule`] and register a forward hook.
229    /// Returns the wrapper paired with a [`HookHandle`] that can be used to
230    /// remove the hook later. The wrapper implements `Module<T>` itself, so
231    /// it slots into any place the original module did. Mirrors
232    /// `torch.nn.Module.register_forward_hook`.
233    fn with_forward_hook(self, hook: ForwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
234    where
235        Self: Sized,
236    {
237        let wrapped = HookedModule::new(self);
238        let handle = wrapped.register_forward_hook(hook);
239        (wrapped, handle)
240    }
241
242    /// Wrap this module in a [`HookedModule`] and register a forward
243    /// pre-hook. See [`Self::with_forward_hook`]. Mirrors
244    /// `torch.nn.Module.register_forward_pre_hook`.
245    fn with_forward_pre_hook(self, hook: ForwardPreHook<T>) -> (HookedModule<Self, T>, HookHandle)
246    where
247        Self: Sized,
248    {
249        let wrapped = HookedModule::new(self);
250        let handle = wrapped.register_forward_pre_hook(hook);
251        (wrapped, handle)
252    }
253
254    /// Wrap this module in a [`HookedModule`] and register a backward hook.
255    /// See [`Self::with_forward_hook`]. Mirrors
256    /// `torch.nn.Module.register_backward_hook`.
257    fn with_backward_hook(self, hook: BackwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
258    where
259        Self: Sized,
260    {
261        let wrapped = HookedModule::new(self);
262        let handle = wrapped.register_backward_hook(hook);
263        (wrapped, handle)
264    }
265
266    /// Set the gradient of every parameter to `None`.
267    ///
268    /// Equivalent to calling `tensor.zero_grad()` on each parameter's
269    /// underlying tensor. Mirrors `torch.nn.Module.zero_grad`.
270    fn zero_grad(&self) -> FerrotorchResult<()> {
271        for param in self.parameters() {
272            param.tensor().zero_grad()?;
273        }
274        Ok(())
275    }
276
277    /// Toggle `requires_grad` on every parameter (freeze / unfreeze the
278    /// module). Mirrors `torch.nn.Module.requires_grad_`.
279    fn requires_grad_(&mut self, requires_grad: bool) {
280        for param in self.parameters_mut() {
281            param.set_requires_grad(requires_grad);
282        }
283    }
284
285    /// Apply a function to every parameter in this module. Mirrors
286    /// `torch.nn.Module.apply` for the parameter case (true `apply` recurses
287    /// over all submodules; the recursive form requires `&mut dyn Module`
288    /// which conflicts with this trait's `&mut self` borrow).
289    ///
290    /// Takes `&mut dyn FnMut(...)` (rather than a generic closure) so the
291    /// trait stays dyn-compatible — `Box<dyn Module<T>>` is a common usage
292    /// pattern.
293    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>)) {
294        for param in self.parameters_mut() {
295            f(param);
296        }
297    }
298
299    /// Load parameters from a state dict.
300    ///
301    /// When `strict` is `true` (default), unexpected keys are an error.
302    /// When `false`, unexpected keys are silently ignored and missing
303    /// keys leave existing parameter values unchanged.
304    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
305        // Known keys: union of parameter and buffer paths.
306        let mut known_keys: std::collections::HashSet<String> = self
307            .named_parameters()
308            .iter()
309            .map(|(k, _)| k.clone())
310            .collect();
311        for (k, _) in self.named_buffers() {
312            known_keys.insert(k);
313        }
314
315        if strict {
316            for key in state.keys() {
317                if !known_keys.contains(key) {
318                    return Err(FerrotorchError::InvalidArgument {
319                        message: format!("unexpected key in state_dict: \"{key}\""),
320                    });
321                }
322            }
323        }
324
325        // We need mutable access to parameters. Use named_parameters to get
326        // the mapping, then parameters_mut to actually update.
327        // This two-pass approach avoids borrowing issues.
328        let param_names: Vec<String> = self
329            .named_parameters()
330            .into_iter()
331            .map(|(name, _)| name)
332            .collect();
333
334        let params_mut = self.parameters_mut();
335
336        for (name, param) in param_names.iter().zip(params_mut) {
337            if let Some(tensor) = state.get(name) {
338                if param.shape() != tensor.shape() {
339                    return Err(FerrotorchError::ShapeMismatch {
340                        message: format!(
341                            "state_dict shape mismatch for \"{name}\": expected {:?}, got {:?}",
342                            param.shape(),
343                            tensor.shape()
344                        ),
345                    });
346                }
347                // Replace the parameter data with the loaded tensor.
348                *param = Parameter::new(tensor.clone());
349            } else if strict {
350                return Err(FerrotorchError::InvalidArgument {
351                    message: format!("missing key in state_dict: \"{name}\""),
352                });
353            }
354        }
355
356        // Same dance for buffers.
357        let buffer_names: Vec<String> = self
358            .named_buffers()
359            .into_iter()
360            .map(|(name, _)| name)
361            .collect();
362        let buffers_mut = self.buffers_mut();
363        for (name, buf) in buffer_names.iter().zip(buffers_mut) {
364            if let Some(tensor) = state.get(name) {
365                if buf.shape() != tensor.shape() {
366                    return Err(FerrotorchError::ShapeMismatch {
367                        message: format!(
368                            "state_dict shape mismatch for buffer \"{name}\": expected {:?}, got {:?}",
369                            buf.shape(),
370                            tensor.shape()
371                        ),
372                    });
373                }
374                *buf = Buffer::new(tensor.clone());
375            } else if strict {
376                return Err(FerrotorchError::InvalidArgument {
377                    message: format!("missing buffer key in state_dict: \"{name}\""),
378                });
379            }
380        }
381
382        Ok(())
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaf fixture: a single 1-D parameter plus a training flag.
    struct SimpleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> SimpleModule<T> {
        /// Builds the fixture with a zero-filled weight of length `size`,
        /// starting in training mode.
        fn new(size: usize) -> FerrotorchResult<Self> {
            let weight = Parameter::zeros(&[size])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for SimpleModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            // Identity: good enough for exercising the trait plumbing.
            Ok(input.clone())
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".to_string(), &self.weight)]
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }

        fn is_training(&self) -> bool {
            self.training
        }
    }

    /// Builds a 1-D, non-grad CPU tensor holding `values`.
    fn cpu_input(values: Vec<f32>) -> Tensor<f32> {
        let len = values.len();
        ferrotorch_core::Tensor::from_storage(
            ferrotorch_core::TensorStorage::cpu(values),
            vec![len],
            false,
        )
        .unwrap()
    }

    #[test]
    fn test_module_parameters() {
        let module = SimpleModule::<f32>::new(5).unwrap();
        let params = module.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].shape(), &[5]);
    }

    #[test]
    fn test_module_named_parameters() {
        let module = SimpleModule::<f32>::new(3).unwrap();
        let entries = module.named_parameters();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "weight");
    }

    #[test]
    fn test_module_train_eval() {
        let mut module = SimpleModule::<f32>::new(2).unwrap();
        assert!(module.is_training());

        module.eval();
        assert!(!module.is_training());

        module.train();
        assert!(module.is_training());
    }

    #[test]
    fn test_module_state_dict_roundtrip() {
        let source = SimpleModule::<f32>::new(4).unwrap();
        let exported = source.state_dict();
        assert!(exported.contains_key("weight"));
        assert_eq!(exported["weight"].shape(), &[4]);

        let mut target = SimpleModule::<f32>::new(4).unwrap();
        target.load_state_dict(&exported, true).unwrap();
    }

    #[test]
    fn test_module_state_dict_strict_extra_key() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let dict: StateDict<f32> = [
            (
                "weight".to_string(),
                ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
            ),
            (
                "extra".to_string(),
                ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
            ),
        ]
        .into_iter()
        .collect();

        // Strict rejects the stray key; lenient mode tolerates it.
        assert!(module.load_state_dict(&dict, true).is_err());
        assert!(module.load_state_dict(&dict, false).is_ok());
    }

    #[test]
    fn test_module_state_dict_shape_mismatch() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let dict: StateDict<f32> = std::iter::once((
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[5]).unwrap(),
        ))
        .collect();

        assert!(module.load_state_dict(&dict, true).is_err());
    }

    #[test]
    fn test_module_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SimpleModule<f32>>();
    }

    #[test]
    fn test_reduction_enum() {
        assert_eq!(Reduction::Mean, Reduction::Mean);
        assert_ne!(Reduction::Mean, Reduction::Sum);
    }

    #[test]
    fn test_to_device_cpu_preserves_weights() {
        let mut module = SimpleModule::<f32>::new(4).unwrap();
        module.to_device(ferrotorch_core::Device::Cpu).unwrap();

        let params = module.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].shape(), &[4]);
    }

    #[test]
    fn test_to_device_cuda_without_backend() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let outcome = module.to_device(ferrotorch_core::Device::Cuda(0));
        assert!(outcome.is_err());
    }

    // -----------------------------------------------------------------------
    // Module trait additions: buffers / children / zero_grad / requires_grad_ /
    // apply_to_parameters / modules iteration. (#583)
    // -----------------------------------------------------------------------

    /// Composite fixture: its own parameter and buffer plus one
    /// [`SimpleModule`] child.
    struct ParentModule<T: Float> {
        weight: Parameter<T>,
        running_mean: Buffer<T>,
        child: SimpleModule<T>,
    }

    impl<T: Float> ParentModule<T> {
        /// 2x2 ones weight, length-2 zero buffer, child of size 3.
        fn new() -> FerrotorchResult<Self> {
            let weight = Parameter::ones(&[2, 2])?;
            let running_mean = Buffer::zeros(&[2])?;
            let child = SimpleModule::new(3)?;
            Ok(Self {
                weight,
                running_mean,
                child,
            })
        }
    }

    impl<T: Float> Module<T> for ParentModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            self.child.forward(input)
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            // Own weight first, then whatever the child exposes.
            std::iter::once(&self.weight)
                .chain(self.child.parameters())
                .collect()
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            // Same order as `parameters()`.
            std::iter::once(&mut self.weight)
                .chain(self.child.parameters_mut())
                .collect()
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            let own = std::iter::once(("weight".to_string(), &self.weight));
            let nested = self
                .child
                .named_parameters()
                .into_iter()
                .map(|(n, p)| (format!("child.{n}"), p));
            own.chain(nested).collect()
        }

        fn buffers(&self) -> Vec<&Buffer<T>> {
            vec![&self.running_mean]
        }

        fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
            vec![&mut self.running_mean]
        }

        fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
            vec![("running_mean".to_string(), &self.running_mean)]
        }

        fn children(&self) -> Vec<&dyn Module<T>> {
            vec![&self.child]
        }

        fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
            vec![("child".to_string(), &self.child)]
        }

        fn train(&mut self) {
            self.child.train();
        }

        fn eval(&mut self) {
            self.child.eval();
        }

        fn is_training(&self) -> bool {
            self.child.is_training()
        }
    }

    #[test]
    fn module_buffers_default_is_empty() {
        // SimpleModule relies on the trait defaults, which yield nothing.
        let leaf = SimpleModule::<f32>::new(3).unwrap();
        assert!(leaf.buffers().is_empty());
        assert!(leaf.named_buffers().is_empty());
    }

    #[test]
    fn module_buffers_listed_for_overriding_module() {
        let parent = ParentModule::<f32>::new().unwrap();

        let bufs = parent.buffers();
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0].shape(), &[2]);

        let named = parent.named_buffers();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "running_mean");
    }

    #[test]
    fn module_children_listed_for_parent() {
        let parent = ParentModule::<f32>::new().unwrap();
        assert_eq!(parent.children().len(), 1);

        let named = parent.named_children();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "child");
    }

    #[test]
    fn module_named_modules_includes_self_and_descendants() {
        let parent = ParentModule::<f32>::new().unwrap();
        let listing = parent.named_modules();
        // Root followed by its single child.
        assert_eq!(listing.len(), 2);
        assert_eq!(listing[0].0, "");
        assert_eq!(listing[1].0, "child");
    }

    #[test]
    fn module_modules_includes_self_and_descendants() {
        let parent = ParentModule::<f32>::new().unwrap();
        assert_eq!(parent.modules().len(), 2);
    }

    #[test]
    fn module_zero_grad_succeeds() {
        // A freshly built module has no gradients; the call must still be Ok.
        let leaf = SimpleModule::<f32>::new(3).unwrap();
        leaf.zero_grad().unwrap();
    }

    #[test]
    fn module_requires_grad_toggles_all_parameters() {
        let mut parent = ParentModule::<f32>::new().unwrap();
        assert!(parent.parameters().iter().all(|p| p.requires_grad()));

        parent.requires_grad_(false);
        assert!(parent.parameters().iter().all(|p| !p.requires_grad()));

        parent.requires_grad_(true);
        assert!(parent.parameters().iter().all(|p| p.requires_grad()));
    }

    #[test]
    fn module_apply_to_parameters_visits_all() {
        let mut parent = ParentModule::<f32>::new().unwrap();
        let expected = parent.parameters().len();

        let mut visited = 0;
        parent.apply_to_parameters(&mut |_param| visited += 1);

        assert_eq!(visited, expected);
    }

    #[test]
    fn module_state_dict_includes_buffers() {
        let parent = ParentModule::<f32>::new().unwrap();
        let exported = parent.state_dict();
        for key in ["weight", "running_mean", "child.weight"] {
            assert!(exported.contains_key(key));
        }
        assert_eq!(exported.len(), 3);
    }

    #[test]
    fn module_load_state_dict_with_buffer() {
        let mut parent = ParentModule::<f32>::new().unwrap();
        let dict: StateDict<f32> = [
            (
                "weight".into(),
                ferrotorch_core::ones::<f32>(&[2, 2]).unwrap(),
            ),
            (
                "running_mean".into(),
                ferrotorch_core::from_slice::<f32>(&[7.0, 9.0], &[2]).unwrap(),
            ),
            (
                "child.weight".into(),
                ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
            ),
        ]
        .into_iter()
        .collect();

        parent.load_state_dict(&dict, true).unwrap();
        assert_eq!(parent.buffers()[0].data().unwrap(), &[7.0, 9.0]);
    }

    #[test]
    fn module_descendants_dyn_excludes_self() {
        let parent = ParentModule::<f32>::new().unwrap();
        assert_eq!(parent.descendants_dyn().len(), 1);
    }

    #[test]
    fn module_named_descendants_dyn_paths() {
        let parent = ParentModule::<f32>::new().unwrap();
        let listing = parent.named_descendants_dyn();
        assert_eq!(listing.len(), 1);
        assert_eq!(listing[0].0, "child");
    }

    /// #1142 regression lock: a transparent wrapper module that exposes
    /// its inner child at path `""` must NOT prepend a leading `.` to
    /// the child's own descendant paths.
    ///
    /// The DeepLabV3 model uses this idiom — `DeepLabV3::named_children`
    /// returns `("", &backbone)` and `ResNet50Dilated::named_children`
    /// returns `("backbone", &inner)`. Pre-#1142 the walker produced
    /// `".backbone"`, mismatching state-dict keys like `"backbone.bn1.X"`
    /// and silently failing every BN-buffer load on DeepLabV3's backbone.
    #[test]
    fn module_named_descendants_dyn_empty_parent_no_leading_dot() {
        /// Exposes a `ParentModule` under the empty name, so the walker has
        /// to surface the grandchild as plain `"child"` rather than
        /// `".child"`.
        struct TransparentWrapper<T: Float> {
            inner: ParentModule<T>,
        }

        impl<T: Float> Module<T> for TransparentWrapper<T> {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                self.inner.forward(input)
            }

            fn parameters(&self) -> Vec<&Parameter<T>> {
                self.inner.parameters()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                self.inner.parameters_mut()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                self.inner.named_parameters()
            }

            fn children(&self) -> Vec<&dyn Module<T>> {
                vec![&self.inner]
            }

            fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
                vec![(String::new(), &self.inner)]
            }

            fn train(&mut self) {
                self.inner.train();
            }

            fn eval(&mut self) {
                self.inner.eval();
            }

            fn is_training(&self) -> bool {
                self.inner.is_training()
            }
        }

        let wrapper = TransparentWrapper::<f32> {
            inner: ParentModule::new().unwrap(),
        };
        let nd: Vec<String> = wrapper
            .named_descendants_dyn()
            .into_iter()
            .map(|(path, _)| path)
            .collect();

        // Expect exactly: "" (the inner module) then "child" (the grandchild).
        assert_eq!(nd, vec![String::new(), "child".to_string()]);
        for p in &nd {
            assert!(
                !p.starts_with('.'),
                "transparent-wrapper descendant path '{p}' starts with '.'; \
                 the empty-parent branch in named_descendants_dyn has regressed",
            );
        }
    }

    // -------------------------------------------------------------------
    // Hook-registration trait methods (#606)
    // -------------------------------------------------------------------

    #[test]
    fn with_forward_hook_wraps_and_fires() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let module = SimpleModule::<f32>::new(2).unwrap();
        let calls = std::sync::Arc::new(AtomicUsize::new(0));
        let calls_in_hook = std::sync::Arc::clone(&calls);

        let (wrapped, _handle) = module.with_forward_hook(Box::new(move |_input, _output| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
        }));

        let input = cpu_input(vec![1.0_f32, 2.0]);
        let _ = wrapped.forward(&input).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn with_forward_pre_hook_wraps_and_fires() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let module = SimpleModule::<f32>::new(2).unwrap();
        let calls = std::sync::Arc::new(AtomicUsize::new(0));
        let calls_in_hook = std::sync::Arc::clone(&calls);

        let (wrapped, _handle) = module.with_forward_pre_hook(Box::new(move |input| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
            Ok(input.clone())
        }));

        let input = cpu_input(vec![1.0_f32, 2.0]);
        let _ = wrapped.forward(&input).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn with_backward_hook_returns_handle() {
        // A backward hook only fires during a backward pass, so this test
        // just checks that wrapping yields a working module and a handle.
        let module = SimpleModule::<f32>::new(2).unwrap();
        let (wrapped, handle) = module.with_backward_hook(Box::new(|_gi, _go| {}));

        // The wrapper is itself a `Module<T>`: run a forward pass through it.
        let input = cpu_input(vec![3.0_f32]);
        let _ = wrapped.forward(&input).unwrap();

        // Dropping the handle would be fine too; remove it explicitly here.
        handle.remove();
    }
}