pub mod activation;
pub mod adaptive_pool;
pub mod attention;
pub mod batchnorm;
pub mod conv1d;
pub mod conv2d;
pub mod conv3d;
pub mod conv_base;
pub mod conv_transpose;
pub mod conv_transpose_1d;
pub mod conv_transpose_3d;
pub mod conv_transpose_common;
pub mod dropout;
pub mod embedding;
pub mod gru;
pub mod gru_cell;
pub mod gru_layer;
pub mod instance_norm;
pub mod linear;
pub mod loss;
pub mod lstm;
pub mod lstm_cell;
pub mod lstm_layer;
pub mod normalization;
pub mod pool2d;
pub mod pruning;
pub mod quantization;
pub mod recurrent_common;
pub mod rnn;
pub mod safe_ops;
pub mod shared_activation;
pub mod shared_loss;
pub mod shared_normalization;
pub mod transformer;
pub mod transformer_phase6;
use crate::autograd::Variable;
use crate::serialization::core::{Loadable, Saveable, SerializationError, SerializationResult};
use num_traits::Float;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::{Send, Sync};
/// Common interface implemented by every neural-network building block in this module.
///
/// Implementors must be `Send + Sync + Debug`, so boxed modules can be shared
/// across threads and printed for diagnostics.
pub trait Module<T>: Send + Sync + Debug
where
    T: Float + 'static + Send + Sync + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Computes this module's output for `input`.
    fn forward(&self, input: &Variable<T>) -> Variable<T>;

    /// Returns the variables this module exposes as its parameters.
    fn parameters(&self) -> Vec<Variable<T>>;

    /// Returns `self` as [`Any`], which lets callers downcast a `dyn Module<T>`
    /// back to its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Hook for switching the module into training mode.
    ///
    /// The default implementation does nothing; modules whose behaviour
    /// depends on the mode are expected to override it.
    fn train(&mut self) {}

    /// Hook for switching the module into evaluation mode.
    ///
    /// The default implementation does nothing; modules whose behaviour
    /// depends on the mode are expected to override it.
    fn eval(&mut self) {}
}
/// A container that owns an ordered list of boxed [`Module`]s.
#[derive(Debug)]
pub struct Sequential<T> {
    /// Contained modules, stored in the order they were added.
    modules: Vec<Box<dyn Module<T>>>,
}

// `Default` is written by hand instead of derived: the derive would add a
// `T: Default` bound, which is unnecessary because an empty `Vec` can be
// built for any element type.
impl<T> Default for Sequential<T> {
    /// Creates a container with no modules.
    fn default() -> Self {
        Self {
            modules: Vec::new(),
        }
    }
}
impl<T> Sequential<T>
where
    T: Float + 'static + Send + Sync + Debug + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Creates an empty container.
    pub fn new() -> Self {
        Self {
            modules: Vec::new(),
        }
    }

    /// Appends `module` to the end of the container.
    ///
    /// Returns `&mut Self` so calls can be chained.
    pub fn add_module<M: Module<T> + 'static>(&mut self, module: M) -> &mut Self {
        let boxed: Box<dyn Module<T>> = Box::new(module);
        self.modules.push(boxed);
        self
    }

    /// Returns the module stored at `index`, or `None` if `index` is out of bounds.
    pub fn get_module(&self, index: usize) -> Option<&dyn Module<T>> {
        let boxed = self.modules.get(index)?;
        // Reborrow through the `Box` to hand out a plain trait-object reference.
        Some(&**boxed)
    }

    /// Returns the number of contained modules.
    pub fn len(&self) -> usize {
        self.modules.len()
    }

    /// Returns `true` when the container holds no modules.
    pub fn is_empty(&self) -> bool {
        self.modules.is_empty()
    }
}
/// Lets a [`Sequential`] be used wherever a single [`Module`] is expected.
impl<T> Module<T> for Sequential<T>
where
    T: Float + 'static + Send + Sync + Debug + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Feeds `input` through every contained module in order, passing each
    /// module's output to the next one.
    ///
    /// An empty container returns a clone of `input`.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        self.modules
            .iter()
            .fold(input.clone(), |current, module| module.forward(&current))
    }

    /// Collects the parameters of all contained modules, in module order.
    fn parameters(&self) -> Vec<Variable<T>> {
        self.modules
            .iter()
            .flat_map(|module| module.parameters())
            .collect()
    }

    /// Returns `self` as [`Any`] so the container can be downcast.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Forwards `train()` to every contained module.
    ///
    /// Without this override the trait's no-op default would be used and a
    /// mode switch on the container would never reach its children.
    fn train(&mut self) {
        for module in &mut self.modules {
            module.train();
        }
    }

    /// Forwards `eval()` to every contained module.
    ///
    /// Without this override the trait's no-op default would be used and a
    /// mode switch on the container would never reach its children.
    fn eval(&mut self) {
        for module in &mut self.modules {
            module.eval();
        }
    }
}
pub use activation::{geglu, glu, reglu, swiglu, ReLU, Softmax, Tanh, GELU, GLU};
pub use adaptive_pool::{AdaptiveAvgPool2d, AdaptiveMaxPool2d};
pub use attention::{CrossAttention, SelfAttention};
pub use batchnorm::{BatchNorm1d, BatchNorm2d};
pub use conv1d::Conv1d;
pub use conv2d::Conv2d;
pub use conv3d::Conv3d;
pub use conv_transpose::ConvTranspose2d;
pub use conv_transpose_1d::ConvTranspose1d;
pub use conv_transpose_3d::ConvTranspose3d;
pub use dropout::{dropout, AlphaDropout, Dropout};
pub use embedding::{Embedding, PositionalEmbedding, SinusoidalPositionalEncoding};
pub use gru::{GRUCell, GRU};
pub use instance_norm::{InstanceNorm1d, InstanceNorm2d, InstanceNorm3d};
pub use linear::Linear;
pub use loss::{
cross_entropy_loss, focal_loss, kl_div_loss, mse_loss, triplet_loss, CrossEntropyLoss,
FocalLoss, KLDivLoss, Loss, MSELoss, TripletLoss,
};
pub use lstm::{LSTMCell, LSTM};
pub use normalization::{GroupNorm, LayerNorm, RMSNorm};
pub use pool2d::{AvgPool2d, MaxPool2d};
pub use pruning::{
Pruner, PruningAwareModule, PruningMask, PruningMethod, PruningSchedule, PruningStructure,
};
pub use quantization::{
CalibrationMode, QuantizationAwareModule, QuantizationParams, QuantizationType,
QuantizedTensor, Quantizer,
};
pub use rnn::{RNNCell, RNN};
pub use transformer_phase6::{
MultiheadAttention, PositionalEncoding, Transformer, TransformerDecoderLayer,
TransformerEncoderLayer,
};