Crate ferrotorch_nn

§ferrotorch-nn — crate root

Declares every per-module file, re-exports the canonical public surface (layer types, the Module trait, Parameter, Buffer, container types, gradient-clipping helpers, and the Module derive macro), and provides the prelude module, which mirrors the ergonomics of PyTorch's "from torch import nn".
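The layout described above (per-file modules, flat root re-exports, and a prelude) can be sketched in a few lines. This is a minimal illustration only: the Linear struct, the clip_grad_value_ signature, and the training_script module are hypothetical placeholders, not the crate's real definitions.

```rust
// Stand-in for a per-layer module file; the real crate declares one `pub mod` per file.
pub mod linear {
    /// Hypothetical placeholder, not the real `Linear` layer.
    #[derive(Debug, Default, PartialEq)]
    pub struct Linear {
        pub in_features: usize,
        pub out_features: usize,
    }
}

pub mod utils {
    /// Hypothetical placeholder with an assumed signature: clamps each value to [-clip, clip].
    pub fn clip_grad_value_(grads: &mut [f32], clip: f32) {
        for g in grads.iter_mut() {
            *g = g.clamp(-clip, clip);
        }
    }
}

// Flat re-exports: callers can name `Linear` without the `linear::` prefix.
pub use self::linear::Linear;
pub use self::utils::clip_grad_value_;

// The prelude gathers commonly used names so that one glob import is enough.
pub mod prelude {
    pub use super::clip_grad_value_;
    pub use super::Linear;
}

/// Stand-in for a downstream training script that needs only the prelude.
pub mod training_script {
    use super::prelude::*;

    pub fn run() -> (Linear, Vec<f32>) {
        let layer = Linear { in_features: 4, out_features: 2 };
        let mut grads = vec![-3.0, 0.5, 7.0];
        clip_grad_value_(&mut grads, 1.0);
        (layer, grads)
    }
}
```

With the real crate, the equivalent downstream line would be use ferrotorch_nn::prelude::*, as REQ-5 notes.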

§REQ status (per .design/ferrotorch-nn/lib.md)

| REQ | Status | Evidence |
| --- | --- | --- |
| REQ-1 | SHIPPED | Crate-wide #![warn(clippy::all, clippy::pedantic)] + #![deny(rust_2018_idioms)] baseline at the top of lib.rs; cargo clippy -p ferrotorch-nn --lib -- -D warnings enforces on every build. |
| REQ-2 | SHIPPED | 31 pub mod declarations cover every per-layer file; cargo check -p ferrotorch-nn fails if any module file is missing. |
| REQ-3 | SHIPPED | Flat pub use re-exports surface every layer + utility name at crate root, mirroring torch/nn/__init__.py:11-50; consumed by ferrotorch-optim/src/optimizer.rs (line 5) use ferrotorch_nn::Parameter and every model crate. |
| REQ-4 | SHIPPED | pub use ferrotorch_nn_derive::Module republishes the derive macro under the trait’s name (separate namespaces); consumed by every #[derive(Module)] site in downstream layer code. |
| REQ-5 | SHIPPED | pub mod prelude collects core abstractions + standard layers + canonical losses + gradient-clipping helpers; consumed by downstream training scripts writing use ferrotorch_nn::prelude::*. |
| REQ-6 | SHIPPED | #[allow(unused_extern_crates)] extern crate self as ferrotorch_nn; enables the derive macro’s ::ferrotorch_nn::Module hygienic path; consumed implicitly by every #[derive(Module)] macro expansion inside this crate. |
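REQ-4 depends on Rust keeping macros and types in separate namespaces, so a macro and a trait can share one name without conflict. The sketch below illustrates this with a declarative macro_rules! macro standing in for the procedural derive macro, since a real derive needs its own proc-macro crate. The simplified non-generic trait, the type_name method, and TinyLayer are hypothetical.

```rust
// Trait in the type namespace (simplified stand-in; the real trait is `Module<T>`).
pub trait Module {
    fn type_name(&self) -> &'static str;
}

// Macro in the macro namespace, deliberately given the same name as the trait.
// A real derive macro would generate a similar `impl` block from the struct definition.
macro_rules! Module {
    ($ty:ident) => {
        impl Module for $ty {
            fn type_name(&self) -> &'static str {
                stringify!($ty)
            }
        }
    };
}

// Hypothetical layer type used only for this illustration.
pub struct TinyLayer;

// `Module!` resolves to the macro; `impl Module` inside it resolves to the trait.
Module!(TinyLayer);
```

The same mechanism lets a single use ferrotorch_nn::Module bring both the trait and the derive macro into scope.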

Re-exports§

pub use activation::CELU;
pub use activation::ELU;
pub use activation::GELU;
pub use activation::GLU;
pub use activation::HardSigmoid;
pub use activation::HardSwish;
pub use activation::Hardshrink;
pub use activation::Hardtanh;
pub use activation::LeakyReLU;
pub use activation::LogSigmoid;
pub use activation::LogSoftmax;
pub use activation::Mish;
pub use activation::PReLU;
pub use activation::RReLU;
pub use activation::ReLU;
pub use activation::ReLU6;
pub use activation::SELU;
pub use activation::SiLU;
pub use activation::Sigmoid;
pub use activation::Softmax;
pub use activation::Softmax2d;
pub use activation::Softmin;
pub use activation::Softplus;
pub use activation::Softshrink;
pub use activation::Softsign;
pub use activation::Tanh;
pub use activation::Tanhshrink;
pub use activation::Threshold;
pub use attention::MultiheadAttention;
pub use attention::repeat_kv;
pub use attention::reshape_to_heads;
pub use attention::transpose_heads_to_2d;
pub use container::ModuleDict;
pub use container::ModuleList;
pub use container::Sequential;
pub use conv::Conv1d;
pub use conv::Conv2d;
pub use conv::Conv3d;
pub use conv::ConvTranspose1d;
pub use conv::ConvTranspose2d;
pub use conv::ConvTranspose3d;
pub use conv::StringPadding;
pub use dropout::AlphaDropout;
pub use dropout::Dropout;
pub use dropout::Dropout1d;
pub use dropout::Dropout2d;
pub use dropout::Dropout3d;
pub use dropout::FeatureAlphaDropout;
pub use embedding::Embedding;
pub use embedding::EmbeddingBag;
pub use embedding::EmbeddingBagMode;
pub use flash_attention::flash_attention;
pub use flash_attention::standard_attention;
pub use flex_attention::BlockMask;
pub use flex_attention::alibi_score_mod;
pub use flex_attention::causal_score_mod;
pub use flex_attention::flex_attention;
pub use flex_attention::relative_position_bias_score_mod;
pub use hooks::BackwardHook;
pub use hooks::ForwardHook;
pub use hooks::ForwardPreHook;
pub use hooks::HookHandle;
pub use hooks::HookedModule;
pub use identity::ChannelShuffle;
pub use identity::CosineSimilarity;
pub use identity::Flatten;
pub use identity::Identity;
pub use identity::PairwiseDistance;
pub use identity::Unflatten;
pub use init::FanMode;
pub use init::NonLinearity;
pub use lazy_conv::LazyConv1d;
pub use lazy_conv::LazyConv2d;
pub use lazy_conv::LazyConv3d;
pub use lazy_linear::LazyLinear;
pub use linear::Bilinear;
pub use linear::Linear;
pub use lora::LoRALinear;
pub use loss::BCELoss;
pub use loss::BCEWithLogitsLoss;
pub use loss::CTCLoss;
pub use loss::CosineEmbeddingLoss;
pub use loss::CrossEntropyLoss;
pub use loss::GaussianNLLLoss;
pub use loss::HingeEmbeddingLoss;
pub use loss::HuberLoss;
pub use loss::KLDivLoss;
pub use loss::L1Loss;
pub use loss::MSELoss;
pub use loss::MarginRankingLoss;
pub use loss::MultiLabelSoftMarginLoss;
pub use loss::MultiMarginLoss;
pub use loss::NLLLoss;
pub use loss::PoissonNLLLoss;
pub use loss::SmoothL1Loss;
pub use loss::TripletMarginLoss;
pub use module::Module;
pub use module::Reduction;
pub use module::StateDict;
pub use buffer::Buffer;
pub use norm::BatchNorm1d;
pub use norm::BatchNorm2d;
pub use norm::BatchNorm3d;
pub use norm::GroupNorm;
pub use norm::InstanceNorm1d;
pub use norm::InstanceNorm2d;
pub use norm::InstanceNorm3d;
pub use norm::LayerNorm;
pub use norm::LocalResponseNorm;
pub use norm::RMSNorm;
pub use padding::CircularPad1d;
pub use padding::CircularPad2d;
pub use padding::CircularPad3d;
pub use padding::ConstantPad1d;
pub use padding::ConstantPad2d;
pub use padding::ConstantPad3d;
pub use padding::PaddingMode;
pub use padding::ReflectionPad1d;
pub use padding::ReflectionPad2d;
pub use padding::ReflectionPad3d;
pub use padding::ReplicationPad1d;
pub use padding::ReplicationPad2d;
pub use padding::ReplicationPad3d;
pub use padding::ZeroPad1d;
pub use padding::ZeroPad2d;
pub use padding::ZeroPad3d;
pub use paged_attention::KVPage;
pub use paged_attention::PagePool;
pub use paged_attention::PagedAttentionManager;
pub use paged_attention::PagedKVCache;
pub use parameter::Parameter;
pub use parameter_container::ParameterDict;
pub use parameter_container::ParameterList;
pub use pooling::AdaptiveAvgPool1d;
pub use pooling::AdaptiveAvgPool2d;
pub use pooling::AdaptiveAvgPool3d;
pub use pooling::AdaptiveMaxPool1d;
pub use pooling::AdaptiveMaxPool2d;
pub use pooling::AdaptiveMaxPool3d;
pub use pooling::AvgPool1d;
pub use pooling::AvgPool2d;
pub use pooling::AvgPool3d;
pub use pooling::FractionalMaxPool2d;
pub use pooling::LPPool1d;
pub use pooling::LPPool2d;
pub use pooling::MaxPool1d;
pub use pooling::MaxPool2d;
pub use pooling::MaxPool3d;
pub use pooling::MaxUnpool2d;
pub use pooling::adaptive_avg_pool1d;
pub use pooling::adaptive_avg_pool2d;
pub use pooling::adaptive_avg_pool3d;
pub use pooling::adaptive_max_pool1d;
pub use pooling::adaptive_max_pool2d;
pub use pooling::adaptive_max_pool3d;
pub use pooling::avg_pool1d;
pub use pooling::avg_pool2d;
pub use pooling::avg_pool3d;
pub use pooling::lp_pool1d;
pub use pooling::lp_pool2d;
pub use pooling::max_pool1d;
pub use pooling::max_pool2d;
pub use pooling::max_pool3d;
pub use pooling::max_unpool2d;
pub use qat::ObserverType;
pub use qat::QatConfig;
pub use qat::QuantizedModel;
pub use qat::prepare_qat;
pub use rnn::GRU;
pub use rnn::GRUCell;
pub use rnn::LSTM;
pub use rnn::LSTMCell;
pub use rnn::RNN;
pub use rnn::RNNCell;
pub use rnn::RNNNonlinearity;
pub use rnn_utils::PackedSequence;
pub use rnn_utils::pack_padded_sequence;
pub use rnn_utils::pad_packed_sequence;
pub use rnn_utils::pad_sequence;
pub use se::SqueezeExcitation;
pub use transformer::KVCache;
pub use transformer::RoPEConvention;
pub use transformer::RoPEScaling;
pub use transformer::RotaryPositionEmbedding;
pub use transformer::SwiGLU;
pub use transformer::Transformer;
pub use transformer::TransformerDecoder;
pub use transformer::TransformerDecoderLayer;
pub use transformer::TransformerEncoder;
pub use transformer::TransformerEncoderLayer;
pub use upsample::Fold;
pub use upsample::GridSampleMode;
pub use upsample::GridSamplePaddingMode;
pub use upsample::InterpolateMode;
pub use upsample::PixelShuffle;
pub use upsample::PixelUnshuffle;
pub use upsample::Unfold;
pub use upsample::Upsample;
pub use upsample::affine_grid;
pub use upsample::fold;
pub use upsample::grid_sample;
pub use upsample::interpolate;
pub use upsample::pixel_shuffle;
pub use upsample::pixel_unshuffle;
pub use upsample::unfold;
pub use utils::clip_grad_norm_;
pub use utils::clip_grad_value_;

Modules§

activation
Activation function wrapper modules.
attention
Multi-head attention layer.
buffer
Buffer<T> — non-trainable persistent module state. (#583)
container
Container modules: Sequential, ModuleList, and ModuleDict.
conv
Convolution layers: 1-D, 2-D, 3-D and their transposed variants.
dropout
Dropout regularization layers.
embedding
Embedding layer: a lookup table of fixed-size vectors.
flash_attention
Memory-efficient FlashAttention (CPU tiled version).
flex_attention
Flexible attention with composable score modification.
functional
Stateless functional API for common neural network operations.
hooks
Forward/backward hooks for Module instances.
identity
Identity and Flatten modules + small shape/distance modules.
init
Weight initialization functions.
lazy_conv
Lazy variants of Conv1d, Conv2d, and Conv3d.
lazy_conv_transpose
Lazy variants of ConvTranspose1d, ConvTranspose2d, and ConvTranspose3d. (#622)
lazy_linear
Lazy variants of Linear and convolution layers.
lazy_norm
Lazy normalization modules. (#622)
linear
Fully connected (dense) linear layer: y = input @ weight^T + bias.
lora
Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning.
loss
Loss functions for training neural networks.
module
Module<T> trait, Reduction, and StateDict<T> — the Rust analog of PyTorch’s torch.nn.Module base class.
norm
Normalization layers: LayerNorm, GroupNorm, RMSNorm, BatchNorm1d/2d/3d, InstanceNorm1d/2d/3d, LocalResponseNorm.
padding
Padding layers: constant, reflection, replication, and zero padding in 1-D, 2-D, 3-D.
paged_attention
PagedAttention — efficient KV cache management for LLM serving.
parameter
Parameter<T> — a trainable tensor wrapper, the Rust analog of torch.nn.Parameter.
parameter_container
Parameter containers: ParameterList and ParameterDict.
pooling
Pooling layers: MaxPool1d/2d/3d, AvgPool1d/2d/3d, AdaptiveAvgPool1d/2d/3d, AdaptiveMaxPool1d/2d/3d, FractionalMaxPool2d, LPPool1d/2d, MaxUnpool2d.
prelude
Glob-import-friendly re-exports of the most commonly used items.
qat
Quantization-aware training (QAT) for ferrotorch nn modules.
rnn
Recurrent neural network modules.
rnn_utils
Utilities for packing and unpacking variable-length sequences.
se
Squeeze-and-Excitation (SE) block — Hu et al. 2018, Squeeze-and-Excitation Networks.
transformer
LLM-critical transformer building blocks.
upsample
Upsample, interpolation, and vision ops.
utils
Gradient clipping utilities.
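
This page does not show the signatures of clip_grad_norm_ and clip_grad_value_. As a rough sketch of what global-norm clipping usually means, the function below follows the scheme of PyTorch's torch.nn.utils.clip_grad_norm_, applied to plain Vec<f32> buffers. The function name, the argument types, and the return value are assumptions made for illustration, not the crate's API.

```rust
/// Scales all gradients by one common factor so that their global L2 norm
/// is at most `max_norm`. Returns the norm measured before clipping.
/// Hypothetical signature: the real helper operates on the crate's parameters.
pub fn clip_grad_norm_sketch(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    // Global L2 norm over every element of every gradient buffer.
    let total_norm = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|v| v * v)
        .sum::<f32>()
        .sqrt();

    // The small epsilon guards against division by zero; capping the
    // coefficient at 1.0 means gradients are only ever scaled down.
    let coef = (max_norm / (total_norm + 1e-6)).min(1.0);

    for g in grads.iter_mut() {
        for v in g.iter_mut() {
            *v *= coef;
        }
    }
    total_norm
}
```

Because every buffer is scaled by the same coefficient, the direction of the overall gradient is preserved and only its magnitude changes. Value clipping, by contrast, clamps each element independently.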

Structs§

QatModel
Wraps a collection of named weight tensors for quantization-aware training.

Enums§

GeluApproximate
Selects the GELU approximation method.
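
The variants of GeluApproximate are not listed on this page. Assuming the enum follows PyTorch's choice between the exact erf form and the tanh approximation, the difference can be sketched as follows. The Approx enum and its variant names are hypothetical, and erf is computed with a standard polynomial approximation because the Rust standard library does not provide one.

```rust
/// Hypothetical stand-in for `GeluApproximate`; the variant names are assumed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Approx {
    /// Exact form: 0.5 * x * (1 + erf(x / sqrt(2))).
    None,
    /// Tanh form: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    Tanh,
}

/// Error function via Abramowitz and Stegun formula 7.1.26
/// (absolute error of roughly 1.5e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736
                + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Scalar GELU under the selected approximation mode.
pub fn gelu(x: f64, approx: Approx) -> f64 {
    match approx {
        Approx::None => 0.5 * x * (1.0 + erf(x / std::f64::consts::SQRT_2)),
        Approx::Tanh => {
            let c = (2.0 / std::f64::consts::PI).sqrt();
            0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
        }
    }
}
```

The two modes agree closely for typical activations. The tanh form avoids evaluating erf, which is the usual reason for offering it as an option.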

Derive Macros§

Module
Derive the Module<T> trait for a struct.