ferrotorch_nn/module.rs
1//! `Module<T>` trait, `Reduction`, and `StateDict<T>` — the Rust analog of
2//! PyTorch's `torch.nn.Module` base class.
3//!
4//! Every neural-network layer in `ferrotorch-nn` implements `Module<T>`. The
5//! trait composes parameter iteration, buffer iteration, train/eval mode,
6//! device transfer, sub-module walks, hook registration, gradient zeroing,
7//! and state-dict load/save.
8//!
9//! ## REQ status (per `.design/ferrotorch-nn/module.md`)
10//!
11//! | REQ | Status | Evidence |
12//! |---|---|---|
13//! | REQ-1 | SHIPPED | `pub type StateDict<T> = HashMap<String, Tensor<T>>` mirrors `torch/nn/modules/module.py:1980-2060`; consumed by `pub use module::{Module, Reduction, StateDict}` at `lib.rs:218` and every layer's `state_dict()` return. |
14//! | REQ-2 | SHIPPED | `pub enum Reduction { Mean, Sum, None }` mirrors PyTorch's loss `reduction` kwarg; consumed by `ferrotorch-nn/src/loss.rs:19` `use crate::module::Reduction` and `ferrotorch-nn/src/functional.rs:1798`. |
15//! | REQ-3 | SHIPPED | `pub trait Module<T: Float>: Send + Sync` mirrors `torch/nn/modules/module.py:407-441`; consumed by every layer file (`linear.rs`, `conv.rs`, `norm.rs`, `embedding.rs`, etc.) via `impl Module<T> for ...`. |
16//! | REQ-4 | SHIPPED | `forward`, `parameters`, `parameters_mut`, `named_parameters` required methods; consumed by `ferrotorch-optim/src/optimizer.rs:5` and every optimizer (`adam.rs:17`, `adadelta.rs:20`, …) reading `.parameters()`. |
17//! | REQ-5 | SHIPPED | `train`, `eval`, `is_training` required methods; consumed by `ferrotorch-nn/src/container.rs` `Sequential::train` / `::eval` propagation. |
18//! | REQ-6 | SHIPPED | Default `to_device` iterates parameters and buffers, mirroring `torch/nn/modules/module.py:1180-1260`; consumed by downstream `model.to_device(Device::Cuda(0))` calls in model composition code. |
19//! | REQ-7 | SHIPPED | Default `state_dict` unions parameters with buffers, mirroring `module.py:1980-2060`; consumed by `ferrotorch-nn/src/hooks.rs` `HookedModule::state_dict` delegation and SafeTensors export. |
20//! | REQ-8 | SHIPPED | `load_state_dict(strict)` two-pass strict + shape-validate, mirroring `module.py:2150-2310`; consumed by SafeTensors / GGUF loaders in downstream crates. |
21//! | REQ-9 | SHIPPED | `buffers` / `buffers_mut` / `named_buffers` default-empty methods mirror `module.py:2430-2490`; consumed by `ferrotorch-nn/src/norm.rs` `BatchNorm*` overrides and `module.rs` default `load_state_dict` (the `*buf = Buffer::new(...)` site). |
22//! | REQ-10 | SHIPPED | `as_any` default returns `None` for the #984 downcast hook; consumed by `ferrotorch-nn/src/norm.rs` `BatchNorm*` overriding it so `ferrotorch-vision`'s state-dict loader can route to BN running stats. |
23//! | REQ-11 | SHIPPED | `children`, `named_children`, `modules` (`Self: Sized`), `descendants_dyn` (object-safe), `named_modules`, `named_descendants_dyn` mirror `module.py:2510-2640`; consumed by `ferrotorch-nn/src/container.rs` containers and downstream state-dict walkers. |
24//! | REQ-12 | SHIPPED | Empty-parent path branch in `named_descendants_dyn` fixes #1142 (DeepLabV3 BN-buffer routing); consumed by `ferrotorch-vision`'s DeepLabV3 backbone load path; pinned by `module_named_descendants_dyn_empty_parent_no_leading_dot`. |
25//! | REQ-13 | SHIPPED | `with_forward_hook` / `with_forward_pre_hook` / `with_backward_hook` `Self: Sized` methods return `(HookedModule<Self, T>, HookHandle)`; consumed by `ferrotorch-nn/src/hooks.rs` (`HookedModule` is instantiated by these) and downstream observability code. |
26//! | REQ-14 | SHIPPED | Default `zero_grad` walks `parameters()` mirroring `module.py:2700-2740`; consumed by `ferrotorch-optim` and `ferrotorch-train/src/grad_utils.rs` calling `model.zero_grad()` per step. |
27//! | REQ-15 | SHIPPED | Default `requires_grad_(bool)` toggles all parameters mirroring `module.py:2680-2700`; consumed by transfer-learning callsites that freeze the backbone via `backbone.requires_grad_(false)`. |
28//! | REQ-16 | SHIPPED | Default `apply_to_parameters(&mut dyn FnMut(...))` mirrors `torch.nn.Module.apply` for the parameter case; consumed by lazy-init paths in downstream lazy layers (`lazy_linear.rs`, `lazy_conv.rs`). |
29
30use std::collections::HashMap;
31
32use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor};
33
34use crate::buffer::Buffer;
35use crate::hooks::{BackwardHook, ForwardHook, ForwardPreHook, HookHandle, HookedModule};
36use crate::parameter::Parameter;
37
38/// A map from parameter names to tensors, used for serialization.
39pub type StateDict<T> = HashMap<String, Tensor<T>>;
40
/// Reduction mode for loss functions.
///
/// Selects how a loss collapses its per-element values into the result it
/// returns. The type is a plain field-less enum, so it is `Copy`, comparable
/// and hashable (usable as a `HashMap` / `HashSet` key, e.g. for caching or
/// grouping by reduction mode).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reduction {
    /// Return the mean of all losses.
    Mean,
    /// Return the sum of all losses.
    Sum,
    /// Return the unreduced loss tensor.
    None,
}
51
52/// The trait that all neural network layers implement.
53///
54/// Requires `Send + Sync` to match `Tensor<T>`'s thread-safety guarantees.
55pub trait Module<T: Float>: Send + Sync {
56 /// Forward pass. Takes input tensor, returns output tensor.
57 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
58
59 /// Iterate over all learnable parameters.
60 fn parameters(&self) -> Vec<&Parameter<T>>;
61
62 /// Iterate over all learnable parameters mutably.
63 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
64
65 /// Named parameters for state dict serialization.
66 ///
67 /// Keys use dot-separated paths for nested modules
68 /// (e.g., `"layer1.weight"`, `"layer1.bias"`).
69 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
70
71 /// Set training mode. Affects dropout, batchnorm, etc.
72 fn train(&mut self);
73
74 /// Set evaluation mode.
75 fn eval(&mut self);
76
77 /// Whether the module is in training mode.
78 fn is_training(&self) -> bool;
79
80 /// Move all parameters and buffers to a device.
81 ///
82 /// Default implementation iterates `parameters_mut()` and `buffers_mut()`
83 /// and transfers each.
84 fn to_device(&mut self, device: Device) -> FerrotorchResult<()> {
85 for param in self.parameters_mut() {
86 *param = param.to(device)?;
87 }
88 for buffer in self.buffers_mut() {
89 *buffer = buffer.to(device)?;
90 }
91 Ok(())
92 }
93
94 /// Export parameters and buffers as a state dict (torch parity).
95 ///
96 /// Buffers are included alongside parameters since both are persistent
97 /// module state. Keys are dot-separated paths.
98 fn state_dict(&self) -> StateDict<T> {
99 let mut out: StateDict<T> = self
100 .named_parameters()
101 .into_iter()
102 .map(|(name, param)| (name, param.tensor().clone()))
103 .collect();
104 for (name, buffer) in self.named_buffers() {
105 out.insert(name, buffer.tensor().clone());
106 }
107 out
108 }
109
110 // -----------------------------------------------------------------
111 // Buffers — non-trainable persistent state. (#583)
112 // -----------------------------------------------------------------
113
114 /// Iterate over all non-trainable buffers (e.g. running mean / variance
115 /// in BatchNorm). Default returns empty — concrete modules with buffers
116 /// override.
117 fn buffers(&self) -> Vec<&Buffer<T>> {
118 Vec::new()
119 }
120
121 /// Mutable iteration over all buffers. Default returns empty.
122 fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
123 Vec::new()
124 }
125
126 /// Named buffers (dot-separated paths for nested modules). Default
127 /// returns empty.
128 fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
129 Vec::new()
130 }
131
132 /// Downcast hook for type-erased buffer-loader dispatch. (#984)
133 ///
134 /// Returns `Some(&self as &dyn Any)` for concrete module types whose
135 /// non-`Buffer<T>` persistent state needs to be applied from a state
136 /// dict (currently `BatchNorm1d` / `BatchNorm2d` / `BatchNorm3d`'s
137 /// running mean / variance / `num_batches_tracked` — see Phase 2
138 /// of the value-parity pipeline in `ferrotorch-vision/tests`).
139 ///
140 /// The default returns `None`, so existing modules are unaffected:
141 /// type-erased callers walking `named_modules()` will simply skip
142 /// modules that do not opt in. Implementors MUST return
143 /// `Some(self)`; returning `Some` for an unrelated `Any` would
144 /// violate the contract.
145 ///
146 /// Why a downcast hook instead of a wider trait surface (e.g. a
147 /// dedicated `set_buffer_value(&self, &str, &Tensor<T>)` method on
148 /// `Module`)? Buffers carrying torch-shaped state (running mean /
149 /// variance, `num_batches_tracked: usize`) currently live outside
150 /// the [`Buffer<T>`] abstraction (BN keeps `Mutex<Vec<f64>>` for
151 /// numerical stability and the integer counter has no `Buffer`
152 /// at all), so a single typed setter on `Module` would force a
153 /// premature unification that #984 explicitly defers. The downcast
154 /// hook keeps `Module` free of BN-specific shape and lets concrete
155 /// modules expose their own typed setters at full precision.
156 fn as_any(&self) -> Option<&dyn std::any::Any> {
157 None
158 }
159
160 // -----------------------------------------------------------------
161 // Submodule iteration. (#583)
162 // -----------------------------------------------------------------
163
164 /// Direct child modules. Default returns empty (leaf module).
165 fn children(&self) -> Vec<&dyn Module<T>> {
166 Vec::new()
167 }
168
169 /// Direct child modules with their attribute names. Default returns
170 /// empty.
171 fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
172 Vec::new()
173 }
174
175 /// All modules in this subtree, depth-first (self first, then each
176 /// child's descendants in order).
177 ///
178 /// Requires `Self: Sized` so we can coerce `self` to `&dyn Module<T>`.
179 /// Trait-object callers can use [`Module::descendants_dyn`] (which yields
180 /// descendants only) and prepend their own reference.
181 fn modules(&self) -> Vec<&dyn Module<T>>
182 where
183 Self: Sized,
184 {
185 let mut out: Vec<&dyn Module<T>> = vec![self];
186 out.extend(self.descendants_dyn());
187 out
188 }
189
190 /// All strict descendants of `self` in depth-first order. Object-safe.
191 fn descendants_dyn(&self) -> Vec<&dyn Module<T>> {
192 let mut out: Vec<&dyn Module<T>> = Vec::new();
193 for child in self.children() {
194 out.push(child);
195 out.extend(child.descendants_dyn());
196 }
197 out
198 }
199
200 /// All modules in this subtree with dot-separated path names. The root
201 /// is named `""`; children paths are joined with `.`.
202 fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
203 where
204 Self: Sized,
205 {
206 let mut out: Vec<(String, &dyn Module<T>)> = vec![(String::new(), self)];
207 out.extend(self.named_descendants_dyn());
208 out
209 }
210
211 /// Strict descendants with dot-paths. Object-safe.
212 fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)> {
213 let mut out: Vec<(String, &dyn Module<T>)> = Vec::new();
214 for (name, child) in self.named_children() {
215 out.push((name.clone(), child));
216 for (sub_name, sub_module) in child.named_descendants_dyn() {
217 let full = if sub_name.is_empty() {
218 name.clone()
219 } else if name.is_empty() {
220 // Transparent wrapper: parent exposes child under "" so
221 // the child's own naming (e.g. `("backbone", inner)`)
222 // becomes the canonical path. Without this branch the
223 // walker would produce `".backbone.X"` (leading dot),
224 // mismatching state-dict keys like `"backbone.X"`.
225 // See #1142 for the DeepLabV3 BN-buffer routing bug
226 // that this branch fixes.
227 sub_name
228 } else {
229 format!("{name}.{sub_name}")
230 };
231 out.push((full, sub_module));
232 }
233 }
234 out
235 }
236
237 // -----------------------------------------------------------------
238 // Helpers. (#583)
239 // -----------------------------------------------------------------
240
241 // -----------------------------------------------------------------
242 // Hooks (#606)
243 //
244 // These consume `self` and return a [`HookedModule<Self, T>`] with the
245 // requested hook already registered. Mirrors `torch.nn.Module
246 // .register_*_hook(...)` ergonomically — callers no longer need to
247 // wrap manually with `HookedModule::new(..)` first. Gated on
248 // `Self: Sized` so the trait stays dyn-compatible.
249 //
250 // Named with the `with_*` prefix (rather than `register_*` directly) to
251 // avoid clashing with `HookedModule`'s own inherent `register_*` methods,
252 // which take `&self` and append a hook to an already-wrapped instance.
253 // The two surfaces compose: `Linear::new(..)?.with_forward_hook(h1).0`
254 // is a `HookedModule` that can `.register_forward_hook(h2)` again.
255 // -----------------------------------------------------------------
256
257 /// Wrap this module in a [`HookedModule`] and register a forward hook.
258 /// Returns the wrapper paired with a [`HookHandle`] that can be used to
259 /// remove the hook later. The wrapper implements `Module<T>` itself, so
260 /// it slots into any place the original module did. Mirrors
261 /// `torch.nn.Module.register_forward_hook`.
262 fn with_forward_hook(self, hook: ForwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
263 where
264 Self: Sized,
265 {
266 let wrapped = HookedModule::new(self);
267 let handle = wrapped.register_forward_hook(hook);
268 (wrapped, handle)
269 }
270
271 /// Wrap this module in a [`HookedModule`] and register a forward
272 /// pre-hook. See [`Self::with_forward_hook`]. Mirrors
273 /// `torch.nn.Module.register_forward_pre_hook`.
274 fn with_forward_pre_hook(self, hook: ForwardPreHook<T>) -> (HookedModule<Self, T>, HookHandle)
275 where
276 Self: Sized,
277 {
278 let wrapped = HookedModule::new(self);
279 let handle = wrapped.register_forward_pre_hook(hook);
280 (wrapped, handle)
281 }
282
283 /// Wrap this module in a [`HookedModule`] and register a backward hook.
284 /// See [`Self::with_forward_hook`]. Mirrors
285 /// `torch.nn.Module.register_backward_hook`.
286 fn with_backward_hook(self, hook: BackwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
287 where
288 Self: Sized,
289 {
290 let wrapped = HookedModule::new(self);
291 let handle = wrapped.register_backward_hook(hook);
292 (wrapped, handle)
293 }
294
295 /// Set the gradient of every parameter to `None`.
296 ///
297 /// Equivalent to calling `tensor.zero_grad()` on each parameter's
298 /// underlying tensor. Mirrors `torch.nn.Module.zero_grad`.
299 fn zero_grad(&self) -> FerrotorchResult<()> {
300 for param in self.parameters() {
301 param.tensor().zero_grad()?;
302 }
303 Ok(())
304 }
305
306 /// Toggle `requires_grad` on every parameter (freeze / unfreeze the
307 /// module). Mirrors `torch.nn.Module.requires_grad_`.
308 fn requires_grad_(&mut self, requires_grad: bool) {
309 for param in self.parameters_mut() {
310 param.set_requires_grad(requires_grad);
311 }
312 }
313
314 /// Apply a function to every parameter in this module. Mirrors
315 /// `torch.nn.Module.apply` for the parameter case (true `apply` recurses
316 /// over all submodules; the recursive form requires `&mut dyn Module`
317 /// which conflicts with this trait's `&mut self` borrow).
318 ///
319 /// Takes `&mut dyn FnMut(...)` (rather than a generic closure) so the
320 /// trait stays dyn-compatible — `Box<dyn Module<T>>` is a common usage
321 /// pattern.
322 fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>)) {
323 for param in self.parameters_mut() {
324 f(param);
325 }
326 }
327
328 /// Load parameters from a state dict.
329 ///
330 /// When `strict` is `true` (default), unexpected keys are an error.
331 /// When `false`, unexpected keys are silently ignored and missing
332 /// keys leave existing parameter values unchanged.
333 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
334 // Known keys: union of parameter and buffer paths.
335 let mut known_keys: std::collections::HashSet<String> = self
336 .named_parameters()
337 .iter()
338 .map(|(k, _)| k.clone())
339 .collect();
340 for (k, _) in self.named_buffers() {
341 known_keys.insert(k);
342 }
343
344 if strict {
345 for key in state.keys() {
346 if !known_keys.contains(key) {
347 return Err(FerrotorchError::InvalidArgument {
348 message: format!("unexpected key in state_dict: \"{key}\""),
349 });
350 }
351 }
352 }
353
354 // We need mutable access to parameters. Use named_parameters to get
355 // the mapping, then parameters_mut to actually update.
356 // This two-pass approach avoids borrowing issues.
357 let param_names: Vec<String> = self
358 .named_parameters()
359 .into_iter()
360 .map(|(name, _)| name)
361 .collect();
362
363 let params_mut = self.parameters_mut();
364
365 for (name, param) in param_names.iter().zip(params_mut) {
366 if let Some(tensor) = state.get(name) {
367 if param.shape() != tensor.shape() {
368 return Err(FerrotorchError::ShapeMismatch {
369 message: format!(
370 "state_dict shape mismatch for \"{name}\": expected {:?}, got {:?}",
371 param.shape(),
372 tensor.shape()
373 ),
374 });
375 }
376 // Replace the parameter data with the loaded tensor.
377 *param = Parameter::new(tensor.clone());
378 } else if strict {
379 return Err(FerrotorchError::InvalidArgument {
380 message: format!("missing key in state_dict: \"{name}\""),
381 });
382 }
383 }
384
385 // Same dance for buffers.
386 let buffer_names: Vec<String> = self
387 .named_buffers()
388 .into_iter()
389 .map(|(name, _)| name)
390 .collect();
391 let buffers_mut = self.buffers_mut();
392 for (name, buf) in buffer_names.iter().zip(buffers_mut) {
393 if let Some(tensor) = state.get(name) {
394 if buf.shape() != tensor.shape() {
395 return Err(FerrotorchError::ShapeMismatch {
396 message: format!(
397 "state_dict shape mismatch for buffer \"{name}\": expected {:?}, got {:?}",
398 buf.shape(),
399 tensor.shape()
400 ),
401 });
402 }
403 *buf = Buffer::new(tensor.clone());
404 } else if strict {
405 return Err(FerrotorchError::InvalidArgument {
406 message: format!("missing buffer key in state_dict: \"{name}\""),
407 });
408 }
409 }
410
411 Ok(())
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rank-1 CPU tensor (no grad tracking) that holds `values`.
    fn cpu_vector(values: Vec<f32>) -> Tensor<f32> {
        let len = values.len();
        ferrotorch_core::Tensor::from_storage(
            ferrotorch_core::TensorStorage::cpu(values),
            vec![len],
            false,
        )
        .unwrap()
    }

    /// Leaf fixture: a single `weight` parameter and a training flag.
    struct SimpleModule<T: Float> {
        weight: Parameter<T>,
        training: bool,
    }

    impl<T: Float> SimpleModule<T> {
        fn new(size: usize) -> FerrotorchResult<Self> {
            let weight = Parameter::zeros(&[size])?;
            Ok(Self {
                weight,
                training: true,
            })
        }
    }

    impl<T: Float> Module<T> for SimpleModule<T> {
        // Identity forward: the tests only need something callable.
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            Ok(input.clone())
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            vec![("weight".to_string(), &self.weight)]
        }

        fn parameters(&self) -> Vec<&Parameter<T>> {
            vec![&self.weight]
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            vec![&mut self.weight]
        }

        fn is_training(&self) -> bool {
            self.training
        }

        fn train(&mut self) {
            self.training = true;
        }

        fn eval(&mut self) {
            self.training = false;
        }
    }

    #[test]
    fn test_module_parameters() {
        let module = SimpleModule::<f32>::new(5).unwrap();
        assert_eq!(module.parameters().len(), 1);
        assert_eq!(module.parameters()[0].shape(), &[5]);
    }

    #[test]
    fn test_module_named_parameters() {
        let module = SimpleModule::<f32>::new(3).unwrap();
        let entries = module.named_parameters();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].0, "weight");
    }

    #[test]
    fn test_module_train_eval() {
        let mut module = SimpleModule::<f32>::new(2).unwrap();
        assert!(module.is_training());
        module.eval();
        assert!(!module.is_training());
        module.train();
        assert!(module.is_training());
    }

    #[test]
    fn test_module_state_dict_roundtrip() {
        let source = SimpleModule::<f32>::new(4).unwrap();
        let snapshot = source.state_dict();
        assert!(snapshot.contains_key("weight"));
        assert_eq!(snapshot["weight"].shape(), &[4]);

        let mut target = SimpleModule::<f32>::new(4).unwrap();
        target.load_state_dict(&snapshot, true).unwrap();
    }

    #[test]
    fn test_module_state_dict_strict_extra_key() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let state: StateDict<f32> = HashMap::from([
            (
                "weight".to_string(),
                ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
            ),
            (
                "extra".to_string(),
                ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
            ),
        ]);

        // Strict mode rejects the stray key; lenient mode accepts the dict.
        assert!(module.load_state_dict(&state, true).is_err());
        assert!(module.load_state_dict(&state, false).is_ok());
    }

    #[test]
    fn test_module_state_dict_shape_mismatch() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let state: StateDict<f32> = HashMap::from([(
            "weight".to_string(),
            ferrotorch_core::zeros::<f32>(&[5]).unwrap(),
        )]);

        assert!(module.load_state_dict(&state, true).is_err());
    }

    #[test]
    fn test_module_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SimpleModule<f32>>();
    }

    #[test]
    fn test_reduction_enum() {
        assert_eq!(Reduction::Mean, Reduction::Mean);
        assert_ne!(Reduction::Mean, Reduction::Sum);
    }

    #[test]
    fn test_to_device_cpu_preserves_weights() {
        let mut module = SimpleModule::<f32>::new(4).unwrap();
        module.to_device(ferrotorch_core::Device::Cpu).unwrap();
        assert_eq!(module.parameters().len(), 1);
        assert_eq!(module.parameters()[0].shape(), &[4]);
    }

    #[test]
    fn test_to_device_cuda_without_backend() {
        let mut module = SimpleModule::<f32>::new(3).unwrap();
        let outcome = module.to_device(ferrotorch_core::Device::Cuda(0));
        assert!(outcome.is_err());
    }

    // -----------------------------------------------------------------------
    // Trait surface added in #583: buffers, children, zero_grad,
    // requires_grad_, apply_to_parameters and the module walkers.
    // -----------------------------------------------------------------------

    /// Composite fixture: own parameter + own buffer + one `SimpleModule`
    /// child.
    struct ParentModule<T: Float> {
        weight: Parameter<T>,
        running_mean: Buffer<T>,
        child: SimpleModule<T>,
    }

    impl<T: Float> ParentModule<T> {
        fn new() -> FerrotorchResult<Self> {
            let weight = Parameter::ones(&[2, 2])?;
            let running_mean = Buffer::zeros(&[2])?;
            let child = SimpleModule::new(3)?;
            Ok(Self {
                weight,
                running_mean,
                child,
            })
        }
    }

    impl<T: Float> Module<T> for ParentModule<T> {
        fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
            self.child.forward(input)
        }

        // Own weight first, then whatever the child reports.
        fn parameters(&self) -> Vec<&Parameter<T>> {
            std::iter::once(&self.weight)
                .chain(self.child.parameters())
                .collect()
        }

        fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
            std::iter::once(&mut self.weight)
                .chain(self.child.parameters_mut())
                .collect()
        }

        fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
            let nested = self
                .child
                .named_parameters()
                .into_iter()
                .map(|(n, p)| (format!("child.{n}"), p));
            std::iter::once(("weight".to_string(), &self.weight))
                .chain(nested)
                .collect()
        }

        fn buffers(&self) -> Vec<&Buffer<T>> {
            vec![&self.running_mean]
        }

        fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> {
            vec![&mut self.running_mean]
        }

        fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> {
            vec![("running_mean".to_string(), &self.running_mean)]
        }

        fn children(&self) -> Vec<&dyn Module<T>> {
            vec![&self.child]
        }

        fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
            vec![("child".to_string(), &self.child)]
        }

        // Mode is tracked by the child alone.
        fn is_training(&self) -> bool {
            self.child.is_training()
        }

        fn train(&mut self) {
            self.child.train();
        }

        fn eval(&mut self) {
            self.child.eval();
        }
    }

    #[test]
    fn module_buffers_default_is_empty() {
        // `SimpleModule` relies on the trait defaults, which report nothing.
        let module = SimpleModule::<f32>::new(3).unwrap();
        assert!(module.buffers().is_empty());
        assert!(module.named_buffers().is_empty());
    }

    #[test]
    fn module_buffers_listed_for_overriding_module() {
        let module = ParentModule::<f32>::new().unwrap();
        assert_eq!(module.buffers().len(), 1);
        assert_eq!(module.buffers()[0].shape(), &[2]);
        let named = module.named_buffers();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "running_mean");
    }

    #[test]
    fn module_children_listed_for_parent() {
        let module = ParentModule::<f32>::new().unwrap();
        assert_eq!(module.children().len(), 1);
        assert_eq!(module.named_children().len(), 1);
        assert_eq!(module.named_children()[0].0, "child");
    }

    #[test]
    fn module_named_modules_includes_self_and_descendants() {
        let module = ParentModule::<f32>::new().unwrap();
        let walked = module.named_modules();
        // The root (empty path) followed by its single child.
        assert_eq!(walked.len(), 2);
        assert_eq!(walked[0].0, "");
        assert_eq!(walked[1].0, "child");
    }

    #[test]
    fn module_modules_includes_self_and_descendants() {
        let module = ParentModule::<f32>::new().unwrap();
        let walked = module.modules();
        assert_eq!(walked.len(), 2);
    }

    #[test]
    fn module_zero_grad_succeeds() {
        // A fresh module has no gradients yet; the call must still be Ok.
        let module = SimpleModule::<f32>::new(3).unwrap();
        module.zero_grad().unwrap();
    }

    #[test]
    fn module_requires_grad_toggles_all_parameters() {
        let mut module = ParentModule::<f32>::new().unwrap();
        assert!(module.parameters().iter().all(|p| p.requires_grad()));

        module.requires_grad_(false);
        assert!(module.parameters().iter().all(|p| !p.requires_grad()));

        module.requires_grad_(true);
        assert!(module.parameters().iter().all(|p| p.requires_grad()));
    }

    #[test]
    fn module_apply_to_parameters_visits_all() {
        let mut module = ParentModule::<f32>::new().unwrap();
        let expected = module.parameters().len();
        let mut visited = 0;
        module.apply_to_parameters(&mut |_p| visited += 1);
        assert_eq!(visited, expected);
    }

    #[test]
    fn module_state_dict_includes_buffers() {
        let module = ParentModule::<f32>::new().unwrap();
        let snapshot = module.state_dict();
        for key in ["weight", "running_mean", "child.weight"] {
            assert!(snapshot.contains_key(key));
        }
        assert_eq!(snapshot.len(), 3);
    }

    #[test]
    fn module_load_state_dict_with_buffer() {
        let mut module = ParentModule::<f32>::new().unwrap();
        let state: StateDict<f32> = HashMap::from([
            (
                "weight".into(),
                ferrotorch_core::ones::<f32>(&[2, 2]).unwrap(),
            ),
            (
                "running_mean".into(),
                ferrotorch_core::from_slice::<f32>(&[7.0, 9.0], &[2]).unwrap(),
            ),
            (
                "child.weight".into(),
                ferrotorch_core::zeros::<f32>(&[3]).unwrap(),
            ),
        ]);
        module.load_state_dict(&state, true).unwrap();
        assert_eq!(module.buffers()[0].data().unwrap(), &[7.0, 9.0]);
    }

    #[test]
    fn module_descendants_dyn_excludes_self() {
        let module = ParentModule::<f32>::new().unwrap();
        let below = module.descendants_dyn();
        assert_eq!(below.len(), 1);
    }

    #[test]
    fn module_named_descendants_dyn_paths() {
        let module = ParentModule::<f32>::new().unwrap();
        let below = module.named_descendants_dyn();
        assert_eq!(below.len(), 1);
        assert_eq!(below[0].0, "child");
    }

    /// Regression lock for #1142: when a wrapper exposes its inner module
    /// under the empty name `""`, the inner module's descendant paths must
    /// come through without a leading `.`.
    ///
    /// DeepLabV3 is built this way — `DeepLabV3::named_children` yields
    /// `("", &backbone)` and `ResNet50Dilated::named_children` yields
    /// `("backbone", &inner)`. Before #1142 the walker emitted `".backbone"`,
    /// which never matched state-dict keys such as `"backbone.bn1.X"`, so
    /// every BN-buffer load on the DeepLabV3 backbone silently failed.
    #[test]
    fn module_named_descendants_dyn_empty_parent_no_leading_dot() {
        /// Exposes a `ParentModule` under the empty path. Walking it must
        /// surface `ParentModule`'s `("child", _)` entry as plain `"child"`,
        /// never `".child"`.
        struct TransparentWrapper<T: Float> {
            inner: ParentModule<T>,
        }
        impl<T: Float> Module<T> for TransparentWrapper<T> {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                self.inner.forward(input)
            }
            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                self.inner.named_parameters()
            }
            fn parameters(&self) -> Vec<&Parameter<T>> {
                self.inner.parameters()
            }
            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                self.inner.parameters_mut()
            }
            fn named_children(&self) -> Vec<(String, &dyn Module<T>)> {
                vec![(String::new(), &self.inner)]
            }
            fn children(&self) -> Vec<&dyn Module<T>> {
                vec![&self.inner]
            }
            fn is_training(&self) -> bool {
                self.inner.is_training()
            }
            fn train(&mut self) {
                self.inner.train();
            }
            fn eval(&mut self) {
                self.inner.eval();
            }
        }

        let wrapper = TransparentWrapper::<f32> {
            inner: ParentModule::new().unwrap(),
        };
        let mut paths: Vec<String> = Vec::new();
        for (path, _) in wrapper.named_descendants_dyn() {
            paths.push(path);
        }
        // Expected walk: "" (the wrapped module), then "child" (its child).
        assert_eq!(paths, vec![String::new(), "child".to_string()]);
        for p in &paths {
            assert!(
                !p.starts_with('.'),
                "transparent-wrapper descendant path '{p}' starts with '.'; \
                 the empty-parent branch in named_descendants_dyn has regressed",
            );
        }
    }

    // -------------------------------------------------------------------
    // `with_*_hook` convenience wrappers (#606)
    // -------------------------------------------------------------------

    #[test]
    fn with_forward_hook_wraps_and_fires() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);
        let module = SimpleModule::<f32>::new(2).unwrap();

        // Keep the handle bound for the whole test so the hook stays live.
        let (hooked, _keep_alive) = module.with_forward_hook(Box::new(move |_input, _output| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
        }));

        let input = cpu_vector(vec![1.0_f32, 2.0]);
        let _ = hooked.forward(&input).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn with_forward_pre_hook_wraps_and_fires() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);
        let module = SimpleModule::<f32>::new(2).unwrap();

        // Keep the handle bound for the whole test so the hook stays live.
        let (hooked, _keep_alive) = module.with_forward_pre_hook(Box::new(move |input| {
            calls_in_hook.fetch_add(1, Ordering::SeqCst);
            Ok(input.clone())
        }));

        let input = cpu_vector(vec![1.0_f32, 2.0]);
        let _ = hooked.forward(&input).unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn with_backward_hook_returns_handle() {
        // No backward pass is run here, so the hook itself is never expected
        // to fire; this only checks that the wrapper and handle are usable.
        let module = SimpleModule::<f32>::new(2).unwrap();
        let (hooked, handle) = module.with_backward_hook(Box::new(|_gi, _go| {}));

        // The wrapper is itself a `Module<T>`, so a forward call must work.
        let input = cpu_vector(vec![3.0_f32]);
        let _ = hooked.forward(&input).unwrap();

        // Explicit removal consumes the handle.
        handle.remove();
    }
}