pub trait Module<T: Float>: Send + Sync {
    // Required methods
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
    fn parameters(&self) -> Vec<&Parameter<T>>;
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
    fn train(&mut self);
    fn eval(&mut self);
    fn is_training(&self) -> bool;

    // Provided methods
    fn to_device(&mut self, device: Device) -> FerrotorchResult<()> { ... }
    fn state_dict(&self) -> StateDict<T> { ... }
    fn buffers(&self) -> Vec<&Buffer<T>> { ... }
    fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> { ... }
    fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> { ... }
    fn as_any(&self) -> Option<&dyn Any> { ... }
    fn children(&self) -> Vec<&dyn Module<T>> { ... }
    fn named_children(&self) -> Vec<(String, &dyn Module<T>)> { ... }
    fn modules(&self) -> Vec<&dyn Module<T>>
       where Self: Sized { ... }
    fn descendants_dyn(&self) -> Vec<&dyn Module<T>> { ... }
    fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
       where Self: Sized { ... }
    fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)> { ... }
    fn with_forward_hook(
        self,
        hook: ForwardHook<T>,
    ) -> (HookedModule<Self, T>, HookHandle)
       where Self: Sized { ... }
    fn with_forward_pre_hook(
        self,
        hook: ForwardPreHook<T>,
    ) -> (HookedModule<Self, T>, HookHandle)
       where Self: Sized { ... }
    fn with_backward_hook(
        self,
        hook: BackwardHook<T>,
    ) -> (HookedModule<Self, T>, HookHandle)
       where Self: Sized { ... }
    fn zero_grad(&self) -> FerrotorchResult<()> { ... }
    fn requires_grad_(&mut self, requires_grad: bool) { ... }
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>)) { ... }
    fn load_state_dict(
        &mut self,
        state: &StateDict<T>,
        strict: bool,
    ) -> FerrotorchResult<()> { ... }
}
The trait that all neural network layers implement.
Requires Send + Sync to match Tensor<T>’s thread-safety guarantees.
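The sketch below shows the shape of implementing the required methods for a small layer. It is a minimal, self-contained illustration under stated assumptions: the ModuleSketch trait and the Param and Affine types are hypothetical stand-ins, Vec<f32> takes the place of Tensor<T>, and Result<_, String> takes the place of FerrotorchResult. None of these are ferrotorch types.

```rust
// Hypothetical stand-ins: Vec<f32> for Tensor<T>, Param for Parameter<T>,
// Result<_, String> for FerrotorchResult. Not ferrotorch types.
#[derive(Debug, Clone)]
pub struct Param {
    pub data: Vec<f32>,
    pub requires_grad: bool,
}

// Same shape as the required methods of Module<T>, including Send + Sync.
pub trait ModuleSketch: Send + Sync {
    fn forward(&self, input: &[f32]) -> Result<Vec<f32>, String>;
    fn parameters(&self) -> Vec<&Param>;
    fn parameters_mut(&mut self) -> Vec<&mut Param>;
    fn named_parameters(&self) -> Vec<(String, &Param)>;
    fn train(&mut self);
    fn eval(&mut self);
    fn is_training(&self) -> bool;
}

// Element-wise affine layer: y[i] = weight[i] * x[i] + bias[i].
pub struct Affine {
    weight: Param,
    bias: Param,
    training: bool,
}

impl Affine {
    pub fn new(n: usize) -> Self {
        Affine {
            weight: Param { data: vec![1.0; n], requires_grad: true },
            bias: Param { data: vec![0.0; n], requires_grad: true },
            training: true, // start in training mode, as PyTorch modules do
        }
    }
}

impl ModuleSketch for Affine {
    fn forward(&self, input: &[f32]) -> Result<Vec<f32>, String> {
        let n = self.weight.data.len();
        if input.len() != n {
            return Err(format!("expected {n} features, got {}", input.len()));
        }
        Ok(input
            .iter()
            .zip(&self.weight.data)
            .zip(&self.bias.data)
            .map(|((x, w), b)| w * x + b)
            .collect())
    }
    fn parameters(&self) -> Vec<&Param> {
        vec![&self.weight, &self.bias]
    }
    fn parameters_mut(&mut self) -> Vec<&mut Param> {
        vec![&mut self.weight, &mut self.bias]
    }
    fn named_parameters(&self) -> Vec<(String, &Param)> {
        vec![("weight".to_string(), &self.weight), ("bias".to_string(), &self.bias)]
    }
    fn train(&mut self) {
        self.training = true;
    }
    fn eval(&mut self) {
        self.training = false;
    }
    fn is_training(&self) -> bool {
        self.training
    }
}
```

Because ModuleSketch has Send + Sync as supertraits, a Box<dyn ModuleSketch> is itself Send + Sync and can be moved to or shared across threads, which is the property the bound on Module<T> is meant to provide.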
Required Methods
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>
Run the forward pass: take an input tensor and return the output tensor, or an error if the computation fails.
fn parameters(&self) -> Vec<&Parameter<T>>
Return references to all learnable parameters.
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>
Return mutable references to all learnable parameters.
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>
Return every learnable parameter paired with its name, as used for state dict serialization.
Names use dot-separated paths for nested modules
(e.g., "layer1.weight", "layer1.bias").
fn is_training(&self) -> bool
Whether the module is in training mode.
Provided Methods
fn to_device(&mut self, device: Device) -> FerrotorchResult<()>
Move all parameters and buffers to device.
The default implementation iterates over parameters_mut() and buffers_mut()
and transfers each one.
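As a rough illustration of that default, here is a minimal sketch using hypothetical stand-ins (Device, Storage, DeviceModule, Norm). The "transfer" only records the target device and, purely for demonstration, rejects any CUDA ordinal other than 0. Whether ferrotorch stops at the first failed transfer, as this sketch does with ?, is an assumption.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

// Hypothetical stand-in for a parameter or buffer: values plus their device.
pub struct Storage {
    pub data: Vec<f32>,
    pub device: Device,
}

impl Storage {
    pub fn new(data: Vec<f32>) -> Self {
        Storage { data, device: Device::Cpu }
    }

    // Simulated transfer that only records the new device. For demonstration,
    // pretend CUDA ordinal 0 is the only device that exists.
    fn transfer(&mut self, device: Device) -> Result<(), String> {
        match device {
            Device::Cuda(i) if i > 0 => Err(format!("no CUDA device {i}")),
            _ => {
                self.device = device;
                Ok(())
            }
        }
    }
}

pub trait DeviceModule {
    fn parameters_mut(&mut self) -> Vec<&mut Storage>;

    // Like Module::buffers_mut, the default reports no buffers.
    fn buffers_mut(&mut self) -> Vec<&mut Storage> {
        Vec::new()
    }

    // Provided method: parameters first, then buffers; `?` stops at the first failure.
    fn to_device(&mut self, device: Device) -> Result<(), String> {
        for p in self.parameters_mut() {
            p.transfer(device)?;
        }
        for b in self.buffers_mut() {
            b.transfer(device)?;
        }
        Ok(())
    }
}

// A BatchNorm-like module with two parameters and one buffer.
pub struct Norm {
    pub weight: Storage,
    pub bias: Storage,
    pub running_mean: Storage,
}

impl Norm {
    pub fn new(n: usize) -> Self {
        Norm {
            weight: Storage::new(vec![1.0; n]),
            bias: Storage::new(vec![0.0; n]),
            running_mean: Storage::new(vec![0.0; n]),
        }
    }
}

impl DeviceModule for Norm {
    fn parameters_mut(&mut self) -> Vec<&mut Storage> {
        vec![&mut self.weight, &mut self.bias]
    }
    fn buffers_mut(&mut self) -> Vec<&mut Storage> {
        vec![&mut self.running_mean]
    }
}
```

Only modules that own buffers need to override buffers_mut; to_device itself never needs overriding in this pattern.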
fn state_dict(&self) -> StateDict<T>
Export parameters and buffers as a state dict, for parity with PyTorch's state_dict().
Buffers are included alongside parameters because both are persistent module state. Keys are dot-separated paths.
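A minimal sketch of how parameters and buffers could be merged into one dot-keyed map, assuming hypothetical stand-ins (Stateful, Linear, BatchNorm, Net, prefixed) and a BTreeMap<String, Vec<f32>> in place of StateDict<T>. The BTreeMap sorts its keys, which keeps the example deterministic; this says nothing about the ordering of the real StateDict<T>.

```rust
use std::collections::BTreeMap;

// Hypothetical stand-in for StateDict<T>: dot-separated key -> values.
pub type StateDictSketch = BTreeMap<String, Vec<f32>>;

pub trait Stateful {
    fn named_parameters(&self) -> Vec<(String, &[f32])>;

    // Modules without buffers keep the empty default.
    fn named_buffers(&self) -> Vec<(String, &[f32])> {
        Vec::new()
    }

    // Parameters and buffers both land in the same map.
    fn state_dict(&self) -> StateDictSketch {
        self.named_parameters()
            .into_iter()
            .chain(self.named_buffers())
            .map(|(k, v)| (k, v.to_vec()))
            .collect()
    }
}

// Prefix each child key with the child's attribute name: "weight" -> "fc.weight".
fn prefixed<'a>(prefix: &str, items: Vec<(String, &'a [f32])>) -> Vec<(String, &'a [f32])> {
    items.into_iter().map(|(k, v)| (format!("{prefix}.{k}"), v)).collect()
}

pub struct Linear {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
}

impl Stateful for Linear {
    fn named_parameters(&self) -> Vec<(String, &[f32])> {
        vec![("weight".to_string(), &self.weight[..]), ("bias".to_string(), &self.bias[..])]
    }
}

pub struct BatchNorm {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
    pub running_mean: Vec<f32>,
    pub running_var: Vec<f32>,
}

impl Stateful for BatchNorm {
    fn named_parameters(&self) -> Vec<(String, &[f32])> {
        vec![("weight".to_string(), &self.weight[..]), ("bias".to_string(), &self.bias[..])]
    }
    fn named_buffers(&self) -> Vec<(String, &[f32])> {
        vec![
            ("running_mean".to_string(), &self.running_mean[..]),
            ("running_var".to_string(), &self.running_var[..]),
        ]
    }
}

pub struct Net {
    pub fc: Linear,
    pub bn: BatchNorm,
}

impl Net {
    pub fn new(n: usize) -> Self {
        Net {
            fc: Linear { weight: vec![0.5; n], bias: vec![0.0; n] },
            bn: BatchNorm {
                weight: vec![1.0; n],
                bias: vec![0.0; n],
                running_mean: vec![0.0; n],
                running_var: vec![1.0; n],
            },
        }
    }
}

impl Stateful for Net {
    fn named_parameters(&self) -> Vec<(String, &[f32])> {
        let mut out = prefixed("fc", self.fc.named_parameters());
        out.extend(prefixed("bn", self.bn.named_parameters()));
        out
    }
    fn named_buffers(&self) -> Vec<(String, &[f32])> {
        prefixed("bn", self.bn.named_buffers())
    }
}
```

For Net::new(2), state_dict() yields the keys bn.bias, bn.running_mean, bn.running_var, bn.weight, fc.bias, and fc.weight, so the BatchNorm running statistics travel with the learnable weights.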
fn buffers(&self) -> Vec<&Buffer<T>>
Return all non-trainable buffers (e.g., the running mean and variance in BatchNorm). The default returns an empty Vec; concrete modules that own buffers override it.
fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>
Return mutable references to all buffers. The default returns an empty Vec.
fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>
Return all buffers paired with their names (dot-separated paths for nested modules). The default returns an empty Vec.
fn as_any(&self) -> Option<&dyn Any>
Downcast hook for type-erased buffer-loader dispatch (#984).
Returns Some(self as &dyn Any) for concrete module types whose
non-Buffer<T> persistent state must be applied from a state
dict (currently the running mean, running variance, and
num_batches_tracked of BatchNorm1d / BatchNorm2d / BatchNorm3d; see Phase 2
of the value-parity pipeline in ferrotorch-vision/tests).
The default returns None, so existing modules are unaffected:
type-erased callers walking named_modules() simply skip
modules that do not opt in. Implementors that opt in MUST return
Some(self); returning Some for any other Any value would
violate the contract.
Why a downcast hook instead of a wider trait surface (e.g., a
dedicated set_buffer_value(&self, &str, &Tensor<T>) method on
Module)? Buffers carrying torch-shaped state (running mean and
variance, num_batches_tracked: usize) currently live outside
the Buffer<T> abstraction: BatchNorm keeps its running statistics in a
Mutex<Vec<f64>> for numerical stability, and the integer counter has no
Buffer at all. A single typed setter on Module would therefore force a
premature unification that #984 explicitly defers. The downcast
hook keeps BatchNorm-specific shapes out of Module and lets concrete
modules expose their own typed setters at full precision.
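A minimal sketch of the opt-in downcast pattern, assuming hypothetical stand-ins (Layer, Relu, BatchNorm, Seq, apply_bn_state). BatchNorm keeps its running mean in a Mutex<Vec<f64>> outside any buffer type, opts in by returning Some(self as &dyn Any), and exposes its own typed setter; the type-erased loader walks the tree and skips every module whose as_any returns None. Only the running mean is handled, to keep the example short.

```rust
use std::any::Any;
use std::collections::HashMap;
use std::sync::Mutex;

// Hypothetical stand-in for Module<T> with just the pieces this example needs.
pub trait Layer: Send + Sync {
    fn named_children(&self) -> Vec<(String, &dyn Layer)> {
        Vec::new()
    }
    // Opt-in downcast hook; the default opts out.
    fn as_any(&self) -> Option<&dyn Any> {
        None
    }
}

pub struct Relu;
impl Layer for Relu {}

// Running statistics live outside any buffer abstraction, at f64 precision.
pub struct BatchNorm {
    running_mean: Mutex<Vec<f64>>,
}

impl BatchNorm {
    pub fn new(n: usize) -> Self {
        BatchNorm { running_mean: Mutex::new(vec![0.0; n]) }
    }
    // Typed setter exposed by the concrete type, not by the trait.
    pub fn set_running_mean(&self, values: &[f64]) {
        *self.running_mean.lock().unwrap() = values.to_vec();
    }
    pub fn running_mean(&self) -> Vec<f64> {
        self.running_mean.lock().unwrap().clone()
    }
}

impl Layer for BatchNorm {
    // Opting in: must return `self`, never some other value.
    fn as_any(&self) -> Option<&dyn Any> {
        Some(self as &dyn Any)
    }
}

pub struct Seq {
    pub items: Vec<(String, Box<dyn Layer>)>,
}

impl Layer for Seq {
    fn named_children(&self) -> Vec<(String, &dyn Layer)> {
        self.items.iter().map(|(n, m)| (n.clone(), m.as_ref())).collect()
    }
}

// Type-erased loader: applies "<path>.running_mean" to every module that
// downcasts to BatchNorm and skips everything else.
// Returns the number of BatchNorm layers updated.
pub fn apply_bn_state(module: &dyn Layer, path: &str, state: &HashMap<String, Vec<f64>>) -> usize {
    let mut applied = 0;
    if let Some(bn) = module.as_any().and_then(|a| a.downcast_ref::<BatchNorm>()) {
        let key = if path.is_empty() {
            "running_mean".to_string()
        } else {
            format!("{path}.running_mean")
        };
        if let Some(values) = state.get(&key) {
            bn.set_running_mean(values);
            applied += 1;
        }
    }
    for (name, child) in module.named_children() {
        let child_path = if path.is_empty() { name } else { format!("{path}.{name}") };
        applied += apply_bn_state(child, &child_path, state);
    }
    applied
}
```

The trait only has to carry one object-safe method returning Option<&dyn Any>; everything BatchNorm-specific stays on the concrete type.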
fn children(&self) -> Vec<&dyn Module<T>>
Direct child modules. Default returns empty (leaf module).
fn named_children(&self) -> Vec<(String, &dyn Module<T>)>
Direct child modules with their attribute names. Default returns empty.
fn modules(&self) -> Vec<&dyn Module<T>>
where
    Self: Sized,
All modules in this subtree in depth-first order: self first, then each child followed by its own descendants.
Requires Self: Sized so that self can be coerced to &dyn Module<T>.
Trait-object callers can use Module::descendants_dyn (which yields
descendants only) and prepend their own reference.
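The Sized restriction can be seen in a small stand-alone trait. This sketch assumes a hypothetical Node trait in place of Module<T>, with hypothetical Leaf, Seq, boxed, names, and all_modules_dyn helpers; the default bodies are one plausible way to implement the two traversals, not ferrotorch's source.

```rust
// Hypothetical stand-in for Module<T>: same split between an object-safe
// traversal and a Sized-only convenience.
pub trait Node {
    fn name(&self) -> &str;

    // Direct children; leaves keep the empty default.
    fn children(&self) -> Vec<&dyn Node> {
        Vec::new()
    }

    // Object-safe: strict descendants, depth-first (each child, then its subtree).
    fn descendants_dyn(&self) -> Vec<&dyn Node> {
        let mut out = Vec::new();
        for child in self.children() {
            out.push(child);
            out.extend(child.descendants_dyn());
        }
        out
    }

    // `self` is coerced to &dyn Node here, which needs a Sized Self. The bound
    // also makes this method unavailable on `dyn Node`.
    fn modules(&self) -> Vec<&dyn Node>
    where
        Self: Sized,
    {
        let mut out: Vec<&dyn Node> = Vec::new();
        out.push(self);
        out.extend(self.descendants_dyn());
        out
    }
}

// Workaround for trait-object callers: prepend your own reference.
pub fn all_modules_dyn(m: &dyn Node) -> Vec<&dyn Node> {
    let mut out = vec![m];
    out.extend(m.descendants_dyn());
    out
}

pub struct Leaf(pub &'static str);

impl Node for Leaf {
    fn name(&self) -> &str {
        self.0
    }
}

pub struct Seq {
    pub name: &'static str,
    pub items: Vec<Box<dyn Node>>,
}

impl Node for Seq {
    fn name(&self) -> &str {
        self.name
    }
    fn children(&self) -> Vec<&dyn Node> {
        self.items.iter().map(|m| m.as_ref()).collect()
    }
}

pub fn boxed(n: impl Node + 'static) -> Box<dyn Node> {
    Box::new(n)
}

pub fn names(ms: &[&dyn Node]) -> Vec<String> {
    ms.iter().map(|m| m.name().to_string()).collect()
}
```

For a root with children a, inner (containing b), and c, both root.modules() and all_modules_dyn(&root) visit root, a, inner, b, c; calling modules() directly on a &dyn Node would not compile.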
fn descendants_dyn(&self) -> Vec<&dyn Module<T>>
All strict descendants of self in depth-first order. Object-safe.
fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where
    Self: Sized,
All modules in this subtree, paired with dot-separated path names. The root
is named "" (the empty string); child paths are joined with ".".
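A minimal sketch of that naming scheme, assuming a hypothetical Named trait and hypothetical Block, Leaf, and child helpers in place of Module<T>: the root gets the empty string, and each nested name is the parent path and the child name joined with ".".

```rust
// Hypothetical stand-in for Module<T>, reduced to the naming methods.
pub trait Named {
    fn named_children(&self) -> Vec<(String, &dyn Named)> {
        Vec::new()
    }

    // Strict descendants with dot-paths relative to self.
    fn named_descendants_dyn(&self) -> Vec<(String, &dyn Named)> {
        let mut out = Vec::new();
        for (name, child) in self.named_children() {
            out.push((name.clone(), child));
            for (sub, m) in child.named_descendants_dyn() {
                out.push((format!("{name}.{sub}"), m));
            }
        }
        out
    }

    // The root itself is named "" (empty string).
    fn named_modules(&self) -> Vec<(String, &dyn Named)>
    where
        Self: Sized,
    {
        let root: &dyn Named = self;
        let mut out = vec![(String::new(), root)];
        out.extend(self.named_descendants_dyn());
        out
    }
}

pub struct Leaf;
impl Named for Leaf {}

pub struct Block {
    pub layers: Vec<(String, Box<dyn Named>)>,
}

impl Named for Block {
    fn named_children(&self) -> Vec<(String, &dyn Named)> {
        self.layers.iter().map(|(n, m)| (n.clone(), m.as_ref())).collect()
    }
}

// Helper to build a named, boxed child entry.
pub fn child(name: &str, module: impl Named + 'static) -> (String, Box<dyn Named>) {
    let boxed: Box<dyn Named> = Box::new(module);
    (name.to_string(), boxed)
}
```

With an encoder block holding fc1 and act, plus a head leaf, named_modules() on the root yields "", "encoder", "encoder.fc1", "encoder.act", and "head".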
fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>
Strict descendants with dot-paths. Object-safe.
fn with_forward_hook(
    self,
    hook: ForwardHook<T>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap this module in a HookedModule and register a forward hook.
Returns the wrapper paired with a HookHandle that can be used to
remove the hook later. The wrapper implements Module<T> itself, so
it slots into any place the original module did. Mirrors
torch.nn.Module.register_forward_hook.
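A minimal sketch of the wrap-and-return-a-handle pattern, assuming hypothetical stand-ins: Forward plays the role of Module<T>, ForwardHookSketch of ForwardHook<T> (here a boxed closure that observes input and output), Hooked of HookedModule, HandleSketch of HookHandle, and Double is a toy module. It is not ferrotorch's implementation, and the real hook signature may differ.

```rust
use std::sync::{Arc, Mutex};

// Hypothetical hook type: observes (input, output) after the wrapped forward.
pub type ForwardHookSketch = Box<dyn Fn(&[f32], &[f32]) + Send + Sync>;

// Shared slot: the wrapper reads it, the handle can clear it.
type HookSlot = Arc<Mutex<Option<ForwardHookSketch>>>;

pub trait Forward: Send + Sync {
    fn forward(&self, input: &[f32]) -> Vec<f32>;

    // Consumes the module; returns the wrapper plus a removal handle.
    fn with_forward_hook(self, hook: ForwardHookSketch) -> (Hooked<Self>, HandleSketch)
    where
        Self: Sized,
    {
        let slot: HookSlot = Arc::new(Mutex::new(Some(hook)));
        let handle = HandleSketch { slot: Arc::clone(&slot) };
        (Hooked { inner: self, slot }, handle)
    }
}

pub struct Hooked<M> {
    inner: M,
    slot: HookSlot,
}

pub struct HandleSketch {
    slot: HookSlot,
}

impl HandleSketch {
    pub fn remove(&self) {
        *self.slot.lock().unwrap() = None;
    }
}

// The wrapper is itself a Forward, so it can stand in for the original module.
impl<M: Forward> Forward for Hooked<M> {
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        let out = self.inner.forward(input);
        // The lock is held while the hook runs, so a hook must not call
        // remove() on its own handle (that would deadlock).
        if let Some(hook) = self.slot.lock().unwrap().as_ref() {
            hook(input, &out[..]);
        }
        out
    }
}

pub struct Double;

impl Forward for Double {
    fn forward(&self, input: &[f32]) -> Vec<f32> {
        input.iter().map(|v| v * 2.0).collect()
    }
}
```

Sharing the hook slot through Arc<Mutex<...>> is what lets the handle remove the hook after the wrapper has been moved elsewhere.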
fn with_forward_pre_hook(
    self,
    hook: ForwardPreHook<T>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap this module in a HookedModule and register a forward
pre-hook. See Self::with_forward_hook. Mirrors
torch.nn.Module.register_forward_pre_hook.
fn with_backward_hook(
    self,
    hook: BackwardHook<T>,
) -> (HookedModule<Self, T>, HookHandle)
where
    Self: Sized,
Wrap this module in a HookedModule and register a backward hook.
See Self::with_forward_hook. Mirrors
torch.nn.Module.register_backward_hook.
fn zero_grad(&self) -> FerrotorchResult<()>
Set the gradient of every parameter to None.
Equivalent to calling tensor.zero_grad() on each parameter’s
underlying tensor. Mirrors torch.nn.Module.zero_grad.
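Because zero_grad takes &self, the gradient slot has to use interior mutability. A minimal sketch, assuming a hypothetical Param whose gradient lives in a Mutex<Option<Vec<f32>>> and hypothetical GradModule and Linear types; clearing sets the slot to None rather than filling it with zeros, as described above.

```rust
use std::sync::Mutex;

// Hypothetical stand-in for Parameter<T>.
pub struct Param {
    pub data: Vec<f32>,
    grad: Mutex<Option<Vec<f32>>>,
}

impl Param {
    pub fn new(data: Vec<f32>) -> Self {
        Param { data, grad: Mutex::new(None) }
    }

    // Add `g` into the stored gradient, creating it on first use.
    pub fn accumulate_grad(&self, g: &[f32]) {
        let mut slot = self.grad.lock().unwrap();
        let acc = slot.get_or_insert_with(|| vec![0.0; g.len()]);
        for (a, b) in acc.iter_mut().zip(g) {
            *a += b;
        }
    }

    pub fn grad(&self) -> Option<Vec<f32>> {
        self.grad.lock().unwrap().clone()
    }

    // Reset the gradient to None through a shared reference.
    pub fn clear_grad(&self) -> Result<(), String> {
        let mut slot = self.grad.lock().map_err(|_| "gradient lock poisoned".to_string())?;
        *slot = None;
        Ok(())
    }
}

pub trait GradModule {
    fn parameters(&self) -> Vec<&Param>;

    // Provided method: clear every parameter's gradient; only &self is needed.
    fn zero_grad(&self) -> Result<(), String> {
        for p in self.parameters() {
            p.clear_grad()?;
        }
        Ok(())
    }
}

pub struct Linear {
    pub weight: Param,
    pub bias: Param,
}

impl GradModule for Linear {
    fn parameters(&self) -> Vec<&Param> {
        vec![&self.weight, &self.bias]
    }
}
```

Parameter values are untouched by zero_grad; only the gradient slots are reset.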
fn requires_grad_(&mut self, requires_grad: bool)
Toggle requires_grad on every parameter (freeze / unfreeze the
module). Mirrors torch.nn.Module.requires_grad_.
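A typical use is freezing a pretrained backbone while fine-tuning a new head. A minimal sketch, assuming hypothetical Param, Freezable, Backbone, and Classifier types, with requires_grad_ written as a default method over parameters_mut():

```rust
// Hypothetical stand-in for Parameter<T>.
pub struct Param {
    pub data: Vec<f32>,
    pub requires_grad: bool,
}

impl Param {
    pub fn new(data: Vec<f32>) -> Self {
        Param { data, requires_grad: true }
    }
}

pub trait Freezable {
    fn parameters(&self) -> Vec<&Param>;
    fn parameters_mut(&mut self) -> Vec<&mut Param>;

    // Provided method: flip the flag on every parameter in this subtree.
    fn requires_grad_(&mut self, requires_grad: bool) {
        for p in self.parameters_mut() {
            p.requires_grad = requires_grad;
        }
    }

    // Helper for the example: how many parameters would still be trained.
    fn num_trainable(&self) -> usize {
        self.parameters().iter().filter(|p| p.requires_grad).count()
    }
}

pub struct Backbone {
    pub conv: Param,
    pub norm: Param,
}

impl Freezable for Backbone {
    fn parameters(&self) -> Vec<&Param> {
        vec![&self.conv, &self.norm]
    }
    fn parameters_mut(&mut self) -> Vec<&mut Param> {
        vec![&mut self.conv, &mut self.norm]
    }
}

pub struct Classifier {
    pub backbone: Backbone,
    pub head: Param,
}

impl Freezable for Classifier {
    fn parameters(&self) -> Vec<&Param> {
        let mut ps = self.backbone.parameters();
        ps.push(&self.head);
        ps
    }
    fn parameters_mut(&mut self) -> Vec<&mut Param> {
        let mut ps = self.backbone.parameters_mut();
        ps.push(&mut self.head);
        ps
    }
}
```

Calling requires_grad_(false) on model.backbone freezes only the backbone's two parameters; calling requires_grad_(true) on the whole model unfreezes everything again.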
fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))
Apply a function to every parameter in this module. This mirrors
torch.nn.Module.apply restricted to parameters: PyTorch's apply recurses
over all submodules, but the recursive form would need a &mut dyn Module,
which conflicts with this trait's &mut self borrow.
Takes &mut dyn FnMut(...) rather than a generic closure so the
trait stays dyn-compatible, since Box<dyn Module<T>> is a common usage
pattern.
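A minimal sketch of why the &mut dyn FnMut parameter matters, assuming hypothetical DynModule, Param, Linear, and Scale types and a hypothetical halve_all helper: the same closure is applied through Box<dyn DynModule> to a heterogeneous collection of modules.

```rust
// Hypothetical stand-in for Parameter<T>.
pub struct Param {
    pub data: Vec<f32>,
}

pub trait DynModule {
    fn parameters_mut(&mut self) -> Vec<&mut Param>;

    // Taking `&mut dyn FnMut` (not a generic F) keeps this method callable
    // through Box<dyn DynModule>.
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Param)) {
        for p in self.parameters_mut() {
            f(p);
        }
    }
}

pub struct Linear {
    pub w: Param,
    pub b: Param,
}

impl DynModule for Linear {
    fn parameters_mut(&mut self) -> Vec<&mut Param> {
        vec![&mut self.w, &mut self.b]
    }
}

pub struct Scale {
    pub s: Param,
}

impl DynModule for Scale {
    fn parameters_mut(&mut self) -> Vec<&mut Param> {
        vec![&mut self.s]
    }
}

// Halve every value in every parameter of a heterogeneous model list;
// returns how many parameters were visited.
pub fn halve_all(models: &mut [Box<dyn DynModule>]) -> usize {
    let mut visited = 0;
    for m in models.iter_mut() {
        m.apply_to_parameters(&mut |p: &mut Param| {
            for x in p.data.iter_mut() {
                *x *= 0.5;
            }
            visited += 1;
        });
    }
    visited
}
```

A generic signature such as fn apply_to_parameters<F: FnMut(&mut Param)>(&mut self, f: F) would make the trait dyn-incompatible unless the method were restricted with where Self: Sized, and in that case it could not be called through Box<dyn DynModule>.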
fn load_state_dict(
    &mut self,
    state: &StateDict<T>,
    strict: bool,
) -> FerrotorchResult<()>
Load parameters from a state dict.
When strict is true (PyTorch's default; here the flag is always passed explicitly), unexpected keys are an error.
When strict is false, unexpected keys are silently ignored and missing
keys leave the existing parameter values unchanged.
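A minimal sketch of the documented strict and non-strict behavior for a single layer, assuming a hypothetical TinyLinear type and a BTreeMap<String, Vec<f32>> in place of StateDict<T>. It checks for unexpected keys before writing anything, so a strict-mode error leaves the module untouched. That ordering, and not treating missing keys as an error in strict mode (the text above does not say), are assumptions of the sketch.

```rust
use std::collections::BTreeMap;

// Hypothetical stand-in for StateDict<T>.
pub type StateDictSketch = BTreeMap<String, Vec<f32>>;

pub struct TinyLinear {
    pub weight: Vec<f32>,
    pub bias: Vec<f32>,
}

impl TinyLinear {
    pub fn load_state_dict(&mut self, state: &StateDictSketch, strict: bool) -> Result<(), String> {
        // Validate before mutating so a strict-mode error leaves `self` untouched.
        if strict {
            if let Some(extra) = state.keys().find(|k| !["weight", "bias"].contains(&k.as_str())) {
                return Err(format!("unexpected key in state dict: {extra}"));
            }
        }
        // Missing keys are skipped, leaving the current values in place.
        if let Some(w) = state.get("weight") {
            self.weight = w.clone();
        }
        if let Some(b) = state.get("bias") {
            self.bias = b.clone();
        }
        Ok(())
    }
}
```

With a state dict holding "weight" plus an unrelated key, the strict call fails and changes nothing, while the non-strict call loads the weight and keeps the existing bias.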
Dyn Compatibility
This trait is dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".