#![warn(clippy::all, clippy::pedantic)]
#![deny(rust_2018_idioms)]
#![allow(missing_docs, missing_debug_implementations)]
#![allow(
// The crate is laid out so submodule names (`module::Module`,
// `parameter::Parameter`, `loss::MSELoss`) match the public type they
// export; renaming would force ergonomic breakage.
clippy::module_name_repetitions,
// # Errors / # Panics sections are added as part of focused passes
// (this audit's Finding #5 covers the high-leverage NotImplementedOnCuda
// sites in loss.rs); a blanket sweep is tracked separately.
clippy::missing_errors_doc,
clippy::missing_panics_doc,
// NN code casts pervasively between `usize` (shape, indices) and
// floating-point (norms, scales) and between `f32`/`f64` (mixed
// precision). The explicit cast is more readable than a `cast()` call
// through num-traits in arithmetic-heavy kernels.
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_precision_loss,
clippy::cast_lossless,
// `#[must_use]` on every getter is churn for marginal value, so the blanket
// pedantic `must_use` lints are silenced crate-wide.
clippy::must_use_candidate,
clippy::return_self_not_must_use,
clippy::doc_markdown,
clippy::items_after_statements,
clippy::too_many_lines,
clippy::needless_pass_by_value,
clippy::option_if_let_else,
clippy::trivially_copy_pass_by_ref,
clippy::unreadable_literal,
clippy::single_match_else,
clippy::similar_names,
clippy::many_single_char_names,
clippy::manual_let_else,
clippy::too_many_arguments,
clippy::redundant_closure_for_method_calls,
clippy::needless_range_loop,
clippy::float_cmp,
clippy::uninlined_format_args,
clippy::implicit_clone,
clippy::manual_midpoint,
clippy::option_map_unit_fn,
clippy::map_unwrap_or,
clippy::ptr_as_ptr,
clippy::no_effect_underscore_binding,
clippy::iter_without_into_iter,
clippy::missing_fields_in_debug,
clippy::single_match,
clippy::if_not_else,
clippy::unsafe_derive_deserialize,
clippy::explicit_iter_loop,
clippy::redundant_else,
)]
// Bind this crate to its own external name so that paths written as
// `ferrotorch_nn::...` also resolve from inside the crate itself.
// NOTE(review): presumably required so that code emitted by the
// `ferrotorch_nn_derive` macros compiles when they are used within this
// crate — confirm. The `allow` silences the lint that would otherwise flag
// the `extern crate` item as unnecessary.
#[allow(unused_extern_crates)]
extern crate self as ferrotorch_nn;
pub mod activation;
pub mod attention;
pub mod buffer;
pub mod container;
pub mod conv;
pub mod dropout;
pub mod embedding;
pub mod flash_attention;
pub mod flex_attention;
pub mod functional;
pub mod hooks;
pub mod identity;
pub mod init;
pub mod lazy_conv;
pub mod lazy_conv_transpose;
pub mod lazy_linear;
pub mod lazy_norm;
pub mod linear;
pub mod lora;
pub mod loss;
pub mod module;
pub mod norm;
pub mod padding;
pub mod paged_attention;
pub mod parameter;
pub mod parameter_container;
pub mod pooling;
pub mod qat;
pub mod rnn;
pub mod rnn_utils;
pub mod transformer;
pub mod upsample;
pub mod utils;
pub use activation::{
CELU, ELU, GELU, GLU, GeluApproximate, HardSigmoid, HardSwish, Hardshrink, Hardtanh, LeakyReLU,
LogSigmoid, LogSoftmax, Mish, PReLU, RReLU, ReLU, ReLU6, SELU, SiLU, Sigmoid, Softmax,
Softmax2d, Softmin, Softplus, Softshrink, Softsign, Tanh, Tanhshrink, Threshold,
};
pub use attention::{MultiheadAttention, repeat_kv, reshape_to_heads, transpose_heads_to_2d};
pub use container::{ModuleDict, ModuleList, Sequential};
pub use conv::{Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d};
pub use dropout::{AlphaDropout, Dropout, Dropout1d, Dropout2d, Dropout3d};
pub use embedding::{Embedding, EmbeddingBag, EmbeddingBagMode};
pub use flash_attention::{flash_attention, standard_attention};
pub use flex_attention::{
BlockMask, alibi_score_mod, causal_score_mod, flex_attention, relative_position_bias_score_mod,
};
pub use hooks::{BackwardHook, ForwardHook, ForwardPreHook, HookHandle, HookedModule};
pub use identity::{
ChannelShuffle, CosineSimilarity, Flatten, Identity, PairwiseDistance, Unflatten,
};
pub use init::NonLinearity;
pub use lazy_conv::{LazyConv1d, LazyConv2d, LazyConv3d};
pub use lazy_linear::LazyLinear;
pub use linear::Linear;
pub use lora::LoRALinear;
pub use loss::{
BCELoss, BCEWithLogitsLoss, CTCLoss, CosineEmbeddingLoss, CrossEntropyLoss, GaussianNLLLoss,
HingeEmbeddingLoss, HuberLoss, KLDivLoss, L1Loss, MSELoss, MarginRankingLoss,
MultiLabelSoftMarginLoss, MultiMarginLoss, NLLLoss, PoissonNLLLoss, SmoothL1Loss,
TripletMarginLoss,
};
pub use module::{Module, Reduction, StateDict};
pub use buffer::Buffer;
// Exported under the same name as `module::Module`; the two `use`s can only
// coexist because the items live in different namespaces.
// NOTE(review): presumably this is the `#[derive(Module)]` proc macro — confirm.
pub use ferrotorch_nn_derive::Module;
pub use norm::{
BatchNorm1d, BatchNorm2d, BatchNorm3d, GroupNorm, InstanceNorm1d, InstanceNorm2d,
InstanceNorm3d, LayerNorm, LocalResponseNorm, RMSNorm,
};
pub use padding::{
CircularPad1d, CircularPad2d, CircularPad3d, ConstantPad1d, ConstantPad2d, ConstantPad3d,
PaddingMode, ReflectionPad1d, ReflectionPad2d, ReflectionPad3d, ReplicationPad1d,
ReplicationPad2d, ReplicationPad3d, ZeroPad1d, ZeroPad2d, ZeroPad3d,
};
pub use paged_attention::{KVPage, PagePool, PagedAttentionManager, PagedKVCache};
pub use parameter::Parameter;
pub use parameter_container::{ParameterDict, ParameterList};
pub use pooling::{
AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d,
AdaptiveMaxPool3d, AvgPool1d, AvgPool2d, AvgPool3d, FractionalMaxPool2d, LPPool1d, LPPool2d,
MaxPool1d, MaxPool2d, MaxPool3d, MaxUnpool2d, adaptive_avg_pool1d, adaptive_avg_pool2d,
adaptive_avg_pool3d, adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d, avg_pool1d,
avg_pool2d, avg_pool3d, lp_pool1d, lp_pool2d, max_pool1d, max_pool2d, max_pool3d, max_unpool2d,
};
pub use qat::{ObserverType, QatConfig, QatModel, QuantizedModel, prepare_qat};
pub use rnn::{GRU, GRUCell, LSTM, LSTMCell, RNN, RNNCell, RNNNonlinearity};
pub use rnn_utils::{PackedSequence, pack_padded_sequence, pad_packed_sequence};
pub use transformer::{
KVCache, RoPEConvention, RoPEScaling, RotaryPositionEmbedding, SwiGLU, Transformer,
TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer,
};
pub use upsample::{
Fold, GridSampleMode, GridSamplePaddingMode, InterpolateMode, PixelShuffle, PixelUnshuffle,
Unfold, Upsample, affine_grid, fold, grid_sample, interpolate, pixel_shuffle, pixel_unshuffle,
unfold,
};
pub use utils::{clip_grad_norm_, clip_grad_value_};
pub mod prelude {
pub use crate::buffer::Buffer;
pub use crate::module::{Module, Reduction, StateDict};
pub use crate::parameter::Parameter;
pub use ferrotorch_nn_derive::Module as DeriveModule;
pub use crate::activation::{GELU, ReLU, Sigmoid, Softmax, Tanh};
pub use crate::container::{ModuleDict, ModuleList, Sequential};
pub use crate::conv::{Conv1d, Conv2d, Conv3d};
pub use crate::dropout::Dropout;
pub use crate::embedding::Embedding;
pub use crate::linear::Linear;
pub use crate::norm::{BatchNorm1d, BatchNorm2d, GroupNorm, LayerNorm, RMSNorm};
pub use crate::pooling::{AdaptiveAvgPool2d, MaxPool2d};
pub use crate::loss::{BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, L1Loss, MSELoss, NLLLoss};
pub use crate::utils::{clip_grad_norm_, clip_grad_value_};
}