//! rustyml 0.11.0
//!
//! A high-performance machine learning & deep learning library in pure Rust, offering ML algorithms and neural network support.
//!
//! Documentation
use ndarray::{Array1, Array2, Array3, Array4, Array5, ArrayD};

/// Container for the weights of the different neural network layer types
///
/// This enum serves as a polymorphic container for the weights of various
/// neural network layer types. Each variant corresponds to a specific layer
/// type and carries the weight structure appropriate for that layer. Every
/// weight structure holds shared borrows (`&'a`) of the layer's parameters,
/// so a `LayerWeight<'a>` is a read-only view tied to the lifetime `'a`.
///
/// # Variants
///
/// - `Dense` - Weights for dense (fully connected) layers, see [`DenseLayerWeight`]
/// - `SimpleRNN` - Weights for simple recurrent neural network layers, see [`SimpleRNNLayerWeight`]
/// - `LSTM` - Weights for long short-term memory layers, see [`LSTMLayerWeight`]
/// - `GRU` - Weights for gated recurrent unit layers, see [`GRULayerWeight`]
/// - `Conv1D` - Weights for 1D convolutional layers, see [`Conv1DLayerWeight`]
/// - `Conv2D` - Weights for 2D convolutional layers, see [`Conv2DLayerWeight`]
/// - `SeparableConv2DLayer` - Weights for 2D separable convolutional layers, see [`SeparableConv2DLayerWeight`]
/// - `DepthwiseConv2DLayer` - Weights for 2D depthwise convolutional layers, see [`DepthwiseConv2DLayerWeight`]
/// - `Conv3D` - Weights for 3D convolutional layers, see [`Conv3DLayerWeight`]
/// - `BatchNormalization` - Weights for batch normalization layers, see [`BatchNormalizationLayerWeight`]
/// - `LayerNormalizationLayer` - Weights for layer normalization layers, see [`LayerNormalizationLayerWeight`]
/// - `InstanceNormalizationLayer` - Weights for instance normalization layers, see [`InstanceNormalizationLayerWeight`]
/// - `GroupNormalizationLayer` - Weights for group normalization layers, see [`GroupNormalizationLayerWeight`]
/// - `Empty` - Represents a layer with no trainable parameters
pub enum LayerWeight<'a> {
    Dense(DenseLayerWeight<'a>),
    SimpleRNN(SimpleRNNLayerWeight<'a>),
    LSTM(LSTMLayerWeight<'a>),
    GRU(GRULayerWeight<'a>),
    Conv1D(Conv1DLayerWeight<'a>),
    Conv2D(Conv2DLayerWeight<'a>),
    SeparableConv2DLayer(SeparableConv2DLayerWeight<'a>),
    DepthwiseConv2DLayer(DepthwiseConv2DLayerWeight<'a>),
    Conv3D(Conv3DLayerWeight<'a>),
    BatchNormalization(BatchNormalizationLayerWeight<'a>),
    LayerNormalizationLayer(LayerNormalizationLayerWeight<'a>),
    InstanceNormalizationLayer(InstanceNormalizationLayerWeight<'a>),
    GroupNormalizationLayer(GroupNormalizationLayerWeight<'a>),
    Empty,
}

/// Weights for a dense (fully connected) neural network layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `weight` - Weight matrix with shape (input_features, output_features)
/// - `bias` - Bias vector with shape (1, output_features)
#[derive(Debug, Clone, Copy)]
pub struct DenseLayerWeight<'a> {
    pub weight: &'a Array2<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a simple recurrent neural network layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `kernel` - Weight matrix applied to the input features
/// - `recurrent_kernel` - Weight matrix applied to the recurrent (hidden state) connections
/// - `bias` - Bias vector, stored as a 2D array
#[derive(Debug, Clone, Copy)]
pub struct SimpleRNNLayerWeight<'a> {
    pub kernel: &'a Array2<f32>,
    pub recurrent_kernel: &'a Array2<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a single gate in an LSTM layer
///
/// Every field is a shared borrow of the gate's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `kernel` - Weight matrix applied to the input features
/// - `recurrent_kernel` - Weight matrix applied to the recurrent connections
/// - `bias` - Bias vector for the gate, stored as a 2D array
#[derive(Debug, Clone, Copy)]
pub struct LSTMGateWeight<'a> {
    pub kernel: &'a Array2<f32>,
    pub recurrent_kernel: &'a Array2<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a Long Short-Term Memory (LSTM) layer
///
/// Contains one [`LSTMGateWeight`] for each of the four gates that control
/// information flow in an LSTM cell: input gate, forget gate, cell gate, and
/// output gate.
///
/// # Fields
///
/// - `input` - Weights for the input gate, which controls what new information to store
/// - `forget` - Weights for the forget gate, which controls what information to discard
/// - `cell` - Weights for the cell gate, which proposes new cell state values
/// - `output` - Weights for the output gate, which controls what to output
#[derive(Debug, Clone, Copy)]
pub struct LSTMLayerWeight<'a> {
    pub input: LSTMGateWeight<'a>,
    pub forget: LSTMGateWeight<'a>,
    pub cell: LSTMGateWeight<'a>,
    pub output: LSTMGateWeight<'a>,
}

/// Weights for a single gate in a GRU layer
///
/// Every field is a shared borrow of the gate's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `kernel` - Weight matrix applied to the input features
/// - `recurrent_kernel` - Weight matrix applied to the recurrent connections
/// - `bias` - Bias vector for the gate, stored as a 2D array
#[derive(Debug, Clone, Copy)]
pub struct GRUGateWeight<'a> {
    pub kernel: &'a Array2<f32>,
    pub recurrent_kernel: &'a Array2<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a Gated Recurrent Unit (GRU) layer
///
/// Contains one [`GRUGateWeight`] for each of the three gates that control
/// information flow in a GRU cell: reset gate, update gate, and candidate gate.
///
/// # Fields
///
/// - `reset` - Weights for the reset gate, which controls what information to forget
/// - `update` - Weights for the update gate, which controls how much to update the hidden state
/// - `candidate` - Weights for the candidate gate, which proposes new hidden state values
#[derive(Debug, Clone, Copy)]
pub struct GRULayerWeight<'a> {
    pub reset: GRUGateWeight<'a>,
    pub update: GRUGateWeight<'a>,
    pub candidate: GRUGateWeight<'a>,
}

/// Weights for a 1D convolutional layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `weight` - 3D convolution kernel with shape (output_channels, input_channels, kernel_size)
/// - `bias` - Bias vector with shape (1, output_channels)
#[derive(Debug, Clone, Copy)]
pub struct Conv1DLayerWeight<'a> {
    pub weight: &'a Array3<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a 2D convolutional layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `weight` - 4D convolution kernel with shape (output_channels, input_channels, kernel_height, kernel_width)
/// - `bias` - Bias vector with shape (1, output_channels)
#[derive(Debug, Clone, Copy)]
pub struct Conv2DLayerWeight<'a> {
    pub weight: &'a Array4<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a 3D convolutional layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `weight` - 5D convolution kernel with shape (output_channels, input_channels, kernel_depth, kernel_height, kernel_width)
/// - `bias` - Bias vector with shape (1, output_channels)
#[derive(Debug, Clone, Copy)]
pub struct Conv3DLayerWeight<'a> {
    pub weight: &'a Array5<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a 2D separable convolutional layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `depthwise_weight` - 4D weight tensor for the depthwise convolution filters with shape (depth_multiplier, input_channels, kernel_height, kernel_width)
/// - `pointwise_weight` - 4D weight tensor for the pointwise (1x1) convolution filters with shape (output_filters, input_channels * depth_multiplier, 1, 1)
/// - `bias` - Bias vector with shape (1, output_filters)
#[derive(Debug, Clone, Copy)]
pub struct SeparableConv2DLayerWeight<'a> {
    pub depthwise_weight: &'a Array4<f32>,
    pub pointwise_weight: &'a Array4<f32>,
    pub bias: &'a Array2<f32>,
}

/// Weights for a 2D depthwise convolutional layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `weight` - 4D weight tensor for the depthwise filters with shape (depth_multiplier, input_channels, kernel_height, kernel_width)
/// - `bias` - 1D bias vector holding one bias per input channel
///
/// NOTE(review): unlike the other convolution weights, `bias` is a 1D array.
/// If `depth_multiplier > 1` produces `input_channels * depth_multiplier`
/// output channels, a per-output bias would need that many entries. Confirm
/// the expected length against the layer implementation.
#[derive(Debug, Clone, Copy)]
pub struct DepthwiseConv2DLayerWeight<'a> {
    pub weight: &'a Array4<f32>,
    pub bias: &'a Array1<f32>,
}

/// Weights for a batch normalization layer
///
/// Every field is a shared borrow of the layer's parameters and statistics,
/// so this is a cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `gamma` - Scale parameter (learned during training) that controls the variance of the normalized values
/// - `beta` - Shift parameter (learned during training) that controls the mean of the normalized values
/// - `running_mean` - Exponentially weighted moving average of batch means (updated during training, used during inference)
/// - `running_var` - Exponentially weighted moving average of batch variances (updated during training, used during inference)
#[derive(Debug, Clone, Copy)]
pub struct BatchNormalizationLayerWeight<'a> {
    pub gamma: &'a ArrayD<f32>,
    pub beta: &'a ArrayD<f32>,
    pub running_mean: &'a ArrayD<f32>,
    pub running_var: &'a ArrayD<f32>,
}

/// Weights for a layer normalization layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `gamma` - Scale parameter (learned during training) that controls the variance of the normalized values
/// - `beta` - Shift parameter (learned during training) that controls the mean of the normalized values
#[derive(Debug, Clone, Copy)]
pub struct LayerNormalizationLayerWeight<'a> {
    pub gamma: &'a ArrayD<f32>,
    pub beta: &'a ArrayD<f32>,
}

/// Weights for an instance normalization layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `gamma` - Scale parameter (learned during training) that controls the variance of the normalized values
/// - `beta` - Shift parameter (learned during training) that controls the mean of the normalized values
#[derive(Debug, Clone, Copy)]
pub struct InstanceNormalizationLayerWeight<'a> {
    pub gamma: &'a ArrayD<f32>,
    pub beta: &'a ArrayD<f32>,
}

/// Weights for a group normalization layer
///
/// Every field is a shared borrow of the layer's parameters, so this is a
/// cheap read-only view that can be copied freely.
///
/// # Fields
///
/// - `gamma` - Scale parameter (learned during training) that controls the variance of the normalized values
/// - `beta` - Shift parameter (learned during training) that controls the mean of the normalized values
#[derive(Debug, Clone, Copy)]
pub struct GroupNormalizationLayerWeight<'a> {
    pub gamma: &'a ArrayD<f32>,
    pub beta: &'a ArrayD<f32>,
}