Trait Module 

pub trait Module<T: Float>: Send + Sync {
    // Required methods
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>;
    fn parameters(&self) -> Vec<&Parameter<T>>;
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>;
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>;
    fn train(&mut self);
    fn eval(&mut self);
    fn is_training(&self) -> bool;

    // Provided methods
    fn to_device(&mut self, device: Device) -> FerrotorchResult<()> { ... }
    fn state_dict(&self) -> StateDict<T> { ... }
    fn buffers(&self) -> Vec<&Buffer<T>> { ... }
    fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>> { ... }
    fn named_buffers(&self) -> Vec<(String, &Buffer<T>)> { ... }
    fn as_any(&self) -> Option<&dyn Any> { ... }
    fn children(&self) -> Vec<&dyn Module<T>> { ... }
    fn named_children(&self) -> Vec<(String, &dyn Module<T>)> { ... }
    fn modules(&self) -> Vec<&dyn Module<T>>
       where Self: Sized { ... }
    fn descendants_dyn(&self) -> Vec<&dyn Module<T>> { ... }
    fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
       where Self: Sized { ... }
    fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)> { ... }
    fn with_forward_hook(
        self,
        hook: ForwardHook<T>,
    ) -> (HookedModule<Self, T>, HookHandle)
       where Self: Sized { ... }
    fn with_forward_pre_hook(
        self,
        hook: ForwardPreHook<T>,
    ) -> (HookedModule<Self, T>, HookHandle)
       where Self: Sized { ... }
    fn with_backward_hook(
        self,
        hook: BackwardHook<T>,
    ) -> (HookedModule<Self, T>, HookHandle)
       where Self: Sized { ... }
    fn zero_grad(&self) -> FerrotorchResult<()> { ... }
    fn requires_grad_(&mut self, requires_grad: bool) { ... }
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>)) { ... }
    fn load_state_dict(
        &mut self,
        state: &StateDict<T>,
        strict: bool,
    ) -> FerrotorchResult<()> { ... }
}

The trait that all neural network layers implement.

Requires Send + Sync to match Tensor<T>’s thread-safety guarantees.
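
To illustrate the shape of the seven required methods, here is a minimal sketch of one layer implementing a cut-down stand-in trait. Everything in it is hypothetical: `MiniModule`, `Parameter`, `Affine`, and the `Vec<f64>` tensor alias are simplifications, not ferrotorch types, and `Result<_, String>` stands in for `FerrotorchResult`.

```rust
// Hypothetical stand-ins, NOT ferrotorch types: a tensor is a Vec<f64>,
// and Result<_, String> plays the role of FerrotorchResult.
pub type Tensor = Vec<f64>;

#[derive(Debug, Clone)]
pub struct Parameter {
    pub data: Tensor,
}

/// Cut-down trait mirroring the required methods of `Module<T>`.
pub trait MiniModule: Send + Sync {
    fn forward(&self, input: &Tensor) -> Result<Tensor, String>;
    fn parameters(&self) -> Vec<&Parameter>;
    fn parameters_mut(&mut self) -> Vec<&mut Parameter>;
    fn named_parameters(&self) -> Vec<(String, &Parameter)>;
    fn train(&mut self);
    fn eval(&mut self);
    fn is_training(&self) -> bool;
}

/// Elementwise affine layer: y[i] = weight[i] * x[i] + bias[i].
pub struct Affine {
    weight: Parameter,
    bias: Parameter,
    training: bool,
}

impl Affine {
    pub fn new(weight: Vec<f64>, bias: Vec<f64>) -> Self {
        Affine {
            weight: Parameter { data: weight },
            bias: Parameter { data: bias },
            training: true, // start in training mode
        }
    }
}

impl MiniModule for Affine {
    fn forward(&self, input: &Tensor) -> Result<Tensor, String> {
        if input.len() != self.weight.data.len() {
            return Err(format!(
                "expected {} features, got {}",
                self.weight.data.len(),
                input.len()
            ));
        }
        Ok(input
            .iter()
            .zip(&self.weight.data)
            .zip(&self.bias.data)
            .map(|((x, w), b)| w * x + b)
            .collect())
    }
    fn parameters(&self) -> Vec<&Parameter> {
        vec![&self.weight, &self.bias]
    }
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        vec![&mut self.weight, &mut self.bias]
    }
    fn named_parameters(&self) -> Vec<(String, &Parameter)> {
        vec![("weight".to_string(), &self.weight), ("bias".to_string(), &self.bias)]
    }
    fn train(&mut self) {
        self.training = true;
    }
    fn eval(&mut self) {
        self.training = false;
    }
    fn is_training(&self) -> bool {
        self.training
    }
}
```

Shape mismatches surface as an `Err` from `forward` instead of a panic, which mirrors the fallible `FerrotorchResult` return type.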

Required Methods

fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>>

Forward pass. Takes input tensor, returns output tensor.

fn parameters(&self) -> Vec<&Parameter<T>>

Iterate over all learnable parameters.

fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>>

Iterate over all learnable parameters mutably.

fn named_parameters(&self) -> Vec<(String, &Parameter<T>)>

Named parameters for state dict serialization.

Keys use dot-separated paths for nested modules (e.g., "layer1.weight", "layer1.bias").
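
These keys can be built recursively. The sketch below uses a hypothetical `Node` type with `f64` parameters instead of ferrotorch types. A node's own names come first, followed by each child's names, prefixed with the child's attribute name and a dot.

```rust
// Hypothetical stand-in, not a ferrotorch type: each parameter is a single
// f64, and a node owns its local parameters plus named child nodes.
pub struct Node {
    pub params: Vec<(String, f64)>,
    pub children: Vec<(String, Node)>,
}

impl Node {
    /// Returns (dotted_name, value) pairs: this node's own parameters first,
    /// then each child's parameters prefixed with "<child_name>.".
    pub fn named_parameters(&self) -> Vec<(String, f64)> {
        let mut out: Vec<(String, f64)> =
            self.params.iter().map(|(n, v)| (n.clone(), *v)).collect();
        for (child_name, child) in &self.children {
            for (name, value) in child.named_parameters() {
                out.push((format!("{child_name}.{name}"), value));
            }
        }
        out
    }
}
```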

fn train(&mut self)

Set training mode. Affects dropout, batchnorm, etc.

fn eval(&mut self)

Set evaluation mode.

fn is_training(&self) -> bool

Whether the module is in training mode.

Provided Methods

fn to_device(&mut self, device: Device) -> FerrotorchResult<()>

Move all parameters and buffers to a device.

Default implementation iterates parameters_mut() and buffers_mut() and transfers each.
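
Here is a sketch of that default built on assumed stand-in types. `Device`, `Tensor::to`, and `NormToy` are hypothetical. The rule that only `Cuda(0)` exists is made up so the example can show a failed transfer short-circuiting through `?`.

```rust
// Hypothetical stand-ins, not ferrotorch types. The "only cuda:0 exists"
// rule is invented so the example can show an error propagating.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    Cpu,
    Cuda(usize),
}

pub struct Tensor {
    pub device: Device,
    pub data: Vec<f64>,
}

impl Tensor {
    /// Toy transfer: re-tags the tensor; a real backend would copy the data.
    pub fn to(&mut self, device: Device) -> Result<(), String> {
        if let Device::Cuda(i) = device {
            if i != 0 {
                return Err(format!("no such device: cuda:{i}"));
            }
        }
        self.device = device;
        Ok(())
    }
}

pub trait Module {
    fn parameters_mut(&mut self) -> Vec<&mut Tensor>;
    fn buffers_mut(&mut self) -> Vec<&mut Tensor> {
        Vec::new() // default: no buffers
    }
    /// Default transfer: every parameter, then every buffer; `?` stops at the first error.
    fn to_device(&mut self, device: Device) -> Result<(), String> {
        for p in self.parameters_mut() {
            p.to(device)?;
        }
        for b in self.buffers_mut() {
            b.to(device)?;
        }
        Ok(())
    }
}

pub struct NormToy {
    pub weight: Tensor,
    pub running_mean: Tensor,
}

impl Module for NormToy {
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight]
    }
    fn buffers_mut(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.running_mean]
    }
}
```

In this sketch, an error partway through leaves the tensors that were already transferred on the new device. Callers that need all-or-nothing behavior would have to handle that themselves.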

fn state_dict(&self) -> StateDict<T>

Export parameters and buffers as a state dict (torch parity).

Buffers are included alongside parameters since both are persistent module state. Keys are dot-separated paths.
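
The sketch below shows parameters and buffers sharing one dot-keyed map. `BatchNormLike`, the `f64` values, and the `prefix` argument are illustrative assumptions, not the ferrotorch `StateDict<T>` API.

```rust
use std::collections::BTreeMap;

// Hypothetical stand-ins: values are f64 and a state dict is an ordered map
// from dotted key to value. This is not ferrotorch's StateDict<T>.
pub type StateDict = BTreeMap<String, f64>;

pub struct BatchNormLike {
    pub weight: f64,       // parameter (trainable)
    pub bias: f64,         // parameter (trainable)
    pub running_mean: f64, // buffer (persistent, not trainable)
    pub running_var: f64,  // buffer (persistent, not trainable)
}

impl BatchNormLike {
    pub fn named_parameters(&self) -> Vec<(String, f64)> {
        vec![("weight".to_string(), self.weight), ("bias".to_string(), self.bias)]
    }

    pub fn named_buffers(&self) -> Vec<(String, f64)> {
        vec![
            ("running_mean".to_string(), self.running_mean),
            ("running_var".to_string(), self.running_var),
        ]
    }

    /// Parameters and buffers land in one map; `prefix` is the module's own
    /// path inside its parent ("" for the root).
    pub fn state_dict(&self, prefix: &str) -> StateDict {
        self.named_parameters()
            .into_iter()
            .chain(self.named_buffers())
            .map(|(k, v)| {
                let key = if prefix.is_empty() { k } else { format!("{prefix}.{k}") };
                (key, v)
            })
            .collect()
    }
}
```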

fn buffers(&self) -> Vec<&Buffer<T>>

Iterate over all non-trainable buffers (e.g. running mean / variance in BatchNorm). The default returns an empty Vec; concrete modules that hold buffers override it.

fn buffers_mut(&mut self) -> Vec<&mut Buffer<T>>

Mutable iteration over all buffers. Default returns empty.

fn named_buffers(&self) -> Vec<(String, &Buffer<T>)>

Named buffers (dot-separated paths for nested modules). Default returns empty.

fn as_any(&self) -> Option<&dyn Any>

Downcast hook for type-erased buffer-loader dispatch. (#984)

Returns Some(self as &dyn Any) for concrete module types whose non-Buffer<T> persistent state needs to be applied from a state dict. Currently this covers the running mean / variance and num_batches_tracked of BatchNorm1d / BatchNorm2d / BatchNorm3d (see Phase 2 of the value-parity pipeline in ferrotorch-vision/tests).

The default returns None, so existing modules are unaffected: type-erased callers walking named_modules() will simply skip modules that do not opt in. Implementors MUST return Some(self); returning Some for an unrelated Any would violate the contract.

Why a downcast hook instead of a wider trait surface (e.g. a dedicated set_buffer_value(&self, &str, &Tensor<T>) method on Module)? Buffers carrying torch-shaped state (running mean / variance, num_batches_tracked: usize) currently live outside the Buffer<T> abstraction (BN keeps Mutex<Vec<f64>> for numerical stability and the integer counter has no Buffer at all), so a single typed setter on Module would force a premature unification that #984 explicitly defers. The downcast hook keeps Module free of BN-specific shape and lets concrete modules expose their own typed setters at full precision.
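
Below is a minimal sketch of the opt-in downcast pattern using only `std::any::Any`. `Layer`, `ReluLike`, `BatchNormLike`, `set_running_mean`, and `load_running_means` are hypothetical names. The `Mutex<Vec<f64>>` mirrors the BN storage described above, and it is what allows a typed setter to work through the `&self` that the downcast provides.

```rust
use std::any::Any;
use std::sync::Mutex;

// Hypothetical stand-in trait carrying only the downcast hook.
pub trait Layer {
    /// Default: opt out, so type-erased callers simply skip this layer.
    fn as_any(&self) -> Option<&dyn Any> {
        None
    }
}

pub struct ReluLike;
impl Layer for ReluLike {}

/// Keeps its running mean outside any generic buffer abstraction.
pub struct BatchNormLike {
    running_mean: Mutex<Vec<f64>>,
}

impl BatchNormLike {
    pub fn new(features: usize) -> Self {
        BatchNormLike { running_mean: Mutex::new(vec![0.0; features]) }
    }
    /// Typed setter at full f64 precision; works through `&self` thanks to the Mutex.
    pub fn set_running_mean(&self, values: &[f64]) {
        *self.running_mean.lock().unwrap() = values.to_vec();
    }
    pub fn running_mean(&self) -> Vec<f64> {
        self.running_mean.lock().unwrap().clone()
    }
}

impl Layer for BatchNormLike {
    // Opt in by returning `self`, never some unrelated value.
    fn as_any(&self) -> Option<&dyn Any> {
        Some(self as &dyn Any)
    }
}

/// Type-erased loader: applies `mean` to every layer that downcasts to BatchNormLike.
pub fn load_running_means(layers: &[Box<dyn Layer>], mean: &[f64]) -> usize {
    let mut applied = 0;
    for layer in layers {
        if let Some(bn) = layer.as_any().and_then(|a| a.downcast_ref::<BatchNormLike>()) {
            bn.set_running_mean(mean);
            applied += 1;
        }
    }
    applied
}
```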

fn children(&self) -> Vec<&dyn Module<T>>

Direct child modules. Default returns empty (leaf module).

fn named_children(&self) -> Vec<(String, &dyn Module<T>)>

Direct child modules with their attribute names. Default returns empty.

fn modules(&self) -> Vec<&dyn Module<T>>
where Self: Sized,

All modules in this subtree, depth-first (self first, then each child’s descendants in order).

Requires Self: Sized so we can coerce self to &dyn Module<T>. Trait-object callers can use Module::descendants_dyn (which yields descendants only) and prepend their own reference.

fn descendants_dyn(&self) -> Vec<&dyn Module<T>>

All strict descendants of self in depth-first order. Object-safe.
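
The split between `modules` and `descendants_dyn` can be sketched with a stand-in trait (`Node`, `Leaf`, and `Seq` are hypothetical). The default `descendants_dyn` needs no `Self: Sized` bound because it only collects children, which are already trait objects. `modules` must push `self` itself, and that coercion requires the bound.

```rust
// Hypothetical stand-in trait modeling only the traversal methods.
pub trait Node {
    fn label(&self) -> &str;

    /// Direct children; the default is a leaf.
    fn children(&self) -> Vec<&dyn Node> {
        Vec::new()
    }

    /// Strict descendants in depth-first order. No `Self: Sized` bound is
    /// needed: only children, which are already `&dyn Node`, are collected.
    fn descendants_dyn(&self) -> Vec<&dyn Node> {
        let mut out = Vec::new();
        for child in self.children() {
            out.push(child);
            out.extend(child.descendants_dyn());
        }
        out
    }

    /// Self first, then descendants. Coercing `self` to `&dyn Node`
    /// requires `Self: Sized`, so this method is not callable on trait objects.
    fn modules(&self) -> Vec<&dyn Node>
    where
        Self: Sized,
    {
        let mut out: Vec<&dyn Node> = Vec::new();
        out.push(self);
        out.extend(self.descendants_dyn());
        out
    }
}

pub struct Leaf(pub &'static str);

impl Node for Leaf {
    fn label(&self) -> &str {
        self.0
    }
}

pub struct Seq {
    pub name: &'static str,
    pub items: Vec<Box<dyn Node>>,
}

impl Node for Seq {
    fn label(&self) -> &str {
        self.name
    }
    fn children(&self) -> Vec<&dyn Node> {
        let mut out: Vec<&dyn Node> = Vec::new();
        for item in &self.items {
            out.push(&**item);
        }
        out
    }
}
```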

fn named_modules(&self) -> Vec<(String, &dyn Module<T>)>
where Self: Sized,

All modules in this subtree with dot-separated path names. The root is named "" and child path segments are joined with ".".
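
Here is a sketch of the naming rule on a bare tree of hypothetical `Tree` nodes (not ferrotorch types). It assumes that direct children of the root are named without a leading dot.

```rust
/// Joins a parent path and a child name; children of the root ("") get no leading dot.
pub fn join_path(parent: &str, child: &str) -> String {
    if parent.is_empty() {
        child.to_string()
    } else {
        format!("{parent}.{child}")
    }
}

// Hypothetical stand-in tree (not ferrotorch): only the named children matter here.
pub struct Tree {
    pub children: Vec<(String, Tree)>,
}

impl Tree {
    /// Depth-first list of paths, starting with the root's empty name.
    pub fn named_modules(&self) -> Vec<String> {
        let mut out = vec![String::new()];
        self.walk("", &mut out);
        out
    }

    fn walk(&self, prefix: &str, out: &mut Vec<String>) {
        for (name, child) in &self.children {
            let path = join_path(prefix, name);
            out.push(path.clone());
            child.walk(&path, out);
        }
    }
}
```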

fn named_descendants_dyn(&self) -> Vec<(String, &dyn Module<T>)>

Strict descendants with dot-paths. Object-safe.

fn with_forward_hook(self, hook: ForwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward hook. Returns the wrapper paired with a HookHandle that can be used to remove the hook later. The wrapper implements Module<T> itself, so it slots into any place the original module did. Mirrors torch.nn.Module.register_forward_hook.
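
The wrap-and-return-handle pattern can be sketched with stand-in types. `Layer`, `Double`, `Hooked`, `ForwardHook`, and `HookHandle` here are hypothetical simplifications. In particular, modeling hook removal as a shared `AtomicBool` flag is an assumption and not necessarily how ferrotorch's `HookHandle` works. As in PyTorch, the hook may return a replacement output.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical stand-in trait: a scalar-in, scalar-out layer.
pub trait Layer {
    fn forward(&self, x: f64) -> f64;
}

pub struct Double;
impl Layer for Double {
    fn forward(&self, x: f64) -> f64 {
        2.0 * x
    }
}

/// Hypothetical hook signature: sees (input, output), may replace the output.
pub type ForwardHook = Box<dyn Fn(f64, f64) -> Option<f64> + Send + Sync>;

/// Handle sharing an "active" flag with the wrapper so the hook can be removed later.
pub struct HookHandle(Arc<AtomicBool>);
impl HookHandle {
    pub fn remove(&self) {
        self.0.store(false, Ordering::SeqCst);
    }
}

pub struct Hooked<M: Layer> {
    inner: M,
    hook: ForwardHook,
    active: Arc<AtomicBool>,
}

/// Consumes the module (like `with_forward_hook(self, ...)`) and returns wrapper + handle.
pub fn with_forward_hook<M: Layer>(inner: M, hook: ForwardHook) -> (Hooked<M>, HookHandle) {
    let active = Arc::new(AtomicBool::new(true));
    let wrapper = Hooked { inner, hook, active: Arc::clone(&active) };
    (wrapper, HookHandle(active))
}

// The wrapper is itself a Layer, so it can stand in for the original anywhere.
impl<M: Layer> Layer for Hooked<M> {
    fn forward(&self, x: f64) -> f64 {
        let y = self.inner.forward(x);
        if self.active.load(Ordering::SeqCst) {
            (self.hook)(x, y).unwrap_or(y)
        } else {
            y
        }
    }
}
```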

fn with_forward_pre_hook(self, hook: ForwardPreHook<T>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a forward pre-hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_forward_pre_hook.

fn with_backward_hook(self, hook: BackwardHook<T>) -> (HookedModule<Self, T>, HookHandle)
where Self: Sized,

Wrap this module in a HookedModule and register a backward hook. See Self::with_forward_hook. Mirrors torch.nn.Module.register_backward_hook.

fn zero_grad(&self) -> FerrotorchResult<()>

Set the gradient of every parameter to None.

Equivalent to calling tensor.zero_grad() on each parameter’s underlying tensor. Mirrors torch.nn.Module.zero_grad.
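
Here is a sketch of the `None` semantics. It assumes, hypothetically, that each parameter keeps its gradient behind a `Mutex`. Some form of interior mutability is needed because `zero_grad` takes `&self`. `Param` and `zero_grad_all` are illustrative names.

```rust
use std::sync::Mutex;

// Hypothetical stand-in: the gradient sits behind a Mutex so it can be
// cleared through `&self`, matching zero_grad's `&self` receiver.
pub struct Param {
    pub data: Vec<f64>,
    grad: Mutex<Option<Vec<f64>>>,
}

impl Param {
    pub fn new(data: Vec<f64>) -> Self {
        Param { data, grad: Mutex::new(None) }
    }
    pub fn set_grad(&self, g: Vec<f64>) {
        *self.grad.lock().unwrap() = Some(g);
    }
    pub fn grad(&self) -> Option<Vec<f64>> {
        self.grad.lock().unwrap().clone()
    }
    /// Drops the gradient entirely (None), rather than filling it with zeros.
    pub fn zero_grad(&self) {
        *self.grad.lock().unwrap() = None;
    }
}

/// Module-level zero_grad: clear every parameter's gradient; data is untouched.
pub fn zero_grad_all(params: &[&Param]) {
    for p in params {
        p.zero_grad();
    }
}
```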

fn requires_grad_(&mut self, requires_grad: bool)

Toggle requires_grad on every parameter (freeze / unfreeze the module). Mirrors torch.nn.Module.requires_grad_.

fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter<T>))

Apply a function to every parameter in this module. Mirrors torch.nn.Module.apply, restricted to parameters: PyTorch's apply calls its function on every submodule, but a recursive form here would need &mut dyn Module, which conflicts with this trait’s &mut self borrow.

Takes &mut dyn FnMut(...) rather than a generic closure so the trait stays dyn-compatible, which matters because Box<dyn Module<T>> is a common usage pattern.
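
The sketch below shows why the `&mut dyn FnMut` parameter matters, using a hypothetical stand-in `Module` trait, `Parameter` struct, and `TwoParams` layer. It also expresses `requires_grad_` through `apply_to_parameters`. That is one plausible way to implement it, not necessarily ferrotorch's.

```rust
// Hypothetical stand-ins, not ferrotorch types.
pub struct Parameter {
    pub value: f64,
    pub requires_grad: bool,
}

pub trait Module {
    fn parameters_mut(&mut self) -> Vec<&mut Parameter>;

    /// Taking `&mut dyn FnMut` keeps this method callable through
    /// `Box<dyn Module>`. A generic `F: FnMut(..)` parameter would make the
    /// trait dyn-incompatible unless the method were bounded by `Self: Sized`.
    fn apply_to_parameters(&mut self, f: &mut dyn FnMut(&mut Parameter)) {
        for p in self.parameters_mut() {
            f(p);
        }
    }

    /// Freeze (false) or unfreeze (true), expressed via apply_to_parameters.
    fn requires_grad_(&mut self, requires_grad: bool) {
        self.apply_to_parameters(&mut |p: &mut Parameter| p.requires_grad = requires_grad);
    }
}

pub struct TwoParams {
    pub w: Parameter,
    pub b: Parameter,
}

impl Module for TwoParams {
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        vec![&mut self.w, &mut self.b]
    }
}
```

Both calls go through a `Box<dyn Module>`, which is only possible because neither method is generic.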

fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()>

Load parameters from a state dict.

When strict is true (the usual choice, and PyTorch's default), unexpected keys are an error. When false, unexpected keys are silently ignored and missing keys leave existing parameter values unchanged.
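
Here is a sketch of these semantics on a hypothetical `Model` holding named `f64` parameters. The text above only says that strict mode rejects unexpected keys, so the sketch treats missing keys the same way in both modes. PyTorch's strict mode also rejects missing keys, and ferrotorch may do the same. Checking every key before writing anything ensures that a failed strict load does not partially update the model.

```rust
use std::collections::BTreeMap;

// Hypothetical stand-in: parameters are named f64 slots; not ferrotorch types.
pub struct Model {
    pub params: BTreeMap<String, f64>,
}

impl Model {
    /// With `strict`, a key in `state` that the model does not own is an error;
    /// otherwise it is skipped. Keys the model owns but `state` lacks keep
    /// their current values.
    pub fn load_state_dict(
        &mut self,
        state: &BTreeMap<String, f64>,
        strict: bool,
    ) -> Result<(), String> {
        if strict {
            // Validate before mutating so an error leaves the model untouched.
            if let Some(k) = state.keys().find(|k| !self.params.contains_key(*k)) {
                return Err(format!("unexpected key: {k}"));
            }
        }
        for (k, v) in state {
            if let Some(slot) = self.params.get_mut(k) {
                *slot = *v;
            }
        }
        Ok(())
    }
}
```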

Dyn Compatibility

This trait is dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors

impl<M: Module<T>, T: Float> Module<T> for HookedModule<M, T>

impl<T: Float> Module<T> for AdaptiveAvgPool1d

impl<T: Float> Module<T> for AdaptiveAvgPool2d

impl<T: Float> Module<T> for AdaptiveAvgPool3d

impl<T: Float> Module<T> for AdaptiveMaxPool1d

impl<T: Float> Module<T> for AdaptiveMaxPool2d

impl<T: Float> Module<T> for AdaptiveMaxPool3d

impl<T: Float> Module<T> for AlphaDropout<T>

impl<T: Float> Module<T> for AvgPool1d

impl<T: Float> Module<T> for AvgPool2d

impl<T: Float> Module<T> for AvgPool3d

impl<T: Float> Module<T> for BatchNorm1d<T>

impl<T: Float> Module<T> for BatchNorm2d<T>

impl<T: Float> Module<T> for BatchNorm3d<T>

impl<T: Float> Module<T> for Bilinear<T>

impl<T: Float> Module<T> for CELU

impl<T: Float> Module<T> for ChannelShuffle

impl<T: Float> Module<T> for CircularPad1d<T>

impl<T: Float> Module<T> for CircularPad2d<T>

impl<T: Float> Module<T> for CircularPad3d<T>

impl<T: Float> Module<T> for ConstantPad1d<T>

impl<T: Float> Module<T> for ConstantPad2d<T>

impl<T: Float> Module<T> for ConstantPad3d<T>

impl<T: Float> Module<T> for Conv1d<T>

impl<T: Float> Module<T> for Conv2d<T>

impl<T: Float> Module<T> for Conv3d<T>

impl<T: Float> Module<T> for ConvTranspose1d<T>

impl<T: Float> Module<T> for ConvTranspose2d<T>

impl<T: Float> Module<T> for ConvTranspose3d<T>

impl<T: Float> Module<T> for Dropout1d<T>

impl<T: Float> Module<T> for Dropout2d<T>

impl<T: Float> Module<T> for Dropout3d<T>

impl<T: Float> Module<T> for Dropout<T>

impl<T: Float> Module<T> for ELU

impl<T: Float> Module<T> for Embedding<T>

impl<T: Float> Module<T> for EmbeddingBag<T>

impl<T: Float> Module<T> for FeatureAlphaDropout<T>

impl<T: Float> Module<T> for Flatten

impl<T: Float> Module<T> for Fold

impl<T: Float> Module<T> for FractionalMaxPool2d

impl<T: Float> Module<T> for GELU

impl<T: Float> Module<T> for GLU

impl<T: Float> Module<T> for GRU<T>

impl<T: Float> Module<T> for GRUCell<T>

impl<T: Float> Module<T> for GroupNorm<T>

impl<T: Float> Module<T> for HardSigmoid

impl<T: Float> Module<T> for HardSwish

impl<T: Float> Module<T> for Hardshrink

impl<T: Float> Module<T> for Hardtanh

impl<T: Float> Module<T> for Identity

impl<T: Float> Module<T> for InstanceNorm1d<T>

impl<T: Float> Module<T> for InstanceNorm2d<T>

impl<T: Float> Module<T> for InstanceNorm3d<T>

impl<T: Float> Module<T> for LPPool1d

impl<T: Float> Module<T> for LPPool2d

impl<T: Float> Module<T> for LSTM<T>

impl<T: Float> Module<T> for LSTMCell<T>

impl<T: Float> Module<T> for LayerNorm<T>

impl<T: Float> Module<T> for LazyBatchNorm1d<T>

impl<T: Float> Module<T> for LazyBatchNorm2d<T>

impl<T: Float> Module<T> for LazyBatchNorm3d<T>

impl<T: Float> Module<T> for LazyConv1d<T>

impl<T: Float> Module<T> for LazyConv2d<T>

impl<T: Float> Module<T> for LazyConv3d<T>

impl<T: Float> Module<T> for LazyConvTranspose1d<T>

impl<T: Float> Module<T> for LazyConvTranspose2d<T>

impl<T: Float> Module<T> for LazyConvTranspose3d<T>

impl<T: Float> Module<T> for LazyInstanceNorm1d<T>

impl<T: Float> Module<T> for LazyInstanceNorm2d<T>

impl<T: Float> Module<T> for LazyInstanceNorm3d<T>

impl<T: Float> Module<T> for LazyLinear<T>

impl<T: Float> Module<T> for LeakyReLU

impl<T: Float> Module<T> for Linear<T>

impl<T: Float> Module<T> for LoRALinear<T>

impl<T: Float> Module<T> for LocalResponseNorm

impl<T: Float> Module<T> for LogSigmoid

impl<T: Float> Module<T> for LogSoftmax

impl<T: Float> Module<T> for MaxPool1d

impl<T: Float> Module<T> for MaxPool2d

impl<T: Float> Module<T> for MaxPool3d

impl<T: Float> Module<T> for Mish

impl<T: Float> Module<T> for ModuleDict<T>

impl<T: Float> Module<T> for ModuleList<T>

impl<T: Float> Module<T> for MultiheadAttention<T>

impl<T: Float> Module<T> for PReLU<T>

impl<T: Float> Module<T> for PixelShuffle

impl<T: Float> Module<T> for PixelUnshuffle

impl<T: Float> Module<T> for RMSNorm<T>

impl<T: Float> Module<T> for RNN<T>

impl<T: Float> Module<T> for RNNCell<T>

impl<T: Float> Module<T> for RReLU

impl<T: Float> Module<T> for ReLU

impl<T: Float> Module<T> for ReLU6

impl<T: Float> Module<T> for ReflectionPad1d<T>

impl<T: Float> Module<T> for ReflectionPad2d<T>

impl<T: Float> Module<T> for ReflectionPad3d<T>

impl<T: Float> Module<T> for ReplicationPad1d<T>

impl<T: Float> Module<T> for ReplicationPad2d<T>

impl<T: Float> Module<T> for ReplicationPad3d<T>

impl<T: Float> Module<T> for SELU

impl<T: Float> Module<T> for Sequential<T>

impl<T: Float> Module<T> for SiLU

impl<T: Float> Module<T> for Sigmoid

impl<T: Float> Module<T> for Softmax

impl<T: Float> Module<T> for Softmax2d

impl<T: Float> Module<T> for Softmin

impl<T: Float> Module<T> for Softplus

impl<T: Float> Module<T> for Softshrink

impl<T: Float> Module<T> for Softsign

impl<T: Float> Module<T> for SqueezeExcitation<T>

impl<T: Float> Module<T> for SwiGLU<T>

impl<T: Float> Module<T> for Tanh

impl<T: Float> Module<T> for Tanhshrink

impl<T: Float> Module<T> for Threshold

impl<T: Float> Module<T> for Transformer<T>

impl<T: Float> Module<T> for TransformerDecoder<T>

impl<T: Float> Module<T> for TransformerDecoderLayer<T>

impl<T: Float> Module<T> for TransformerEncoder<T>

impl<T: Float> Module<T> for TransformerEncoderLayer<T>

impl<T: Float> Module<T> for Unflatten

impl<T: Float> Module<T> for Unfold

impl<T: Float> Module<T> for Upsample

impl<T: Float> Module<T> for ZeroPad1d<T>

impl<T: Float> Module<T> for ZeroPad2d<T>

impl<T: Float> Module<T> for ZeroPad3d<T>