//! Layer abstractions: configuration ([`LayerConfig`]), weight initialization
//! ([`WeightInit`]), the [`Layer`] trait, and a [`create_layer`] factory, plus
//! the concrete layer submodules re-exported below.

use crate::activations::Activation;
use crate::device::Device;
use crate::error::{NnlError, Result};
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::fmt;
pub mod conv;
pub mod dense;
pub mod dropout;
pub mod normalization;
pub mod pooling;
pub use conv::Conv2DLayer;
pub use dense::DenseLayer;
pub use dropout::DropoutLayer;
pub use normalization::{BatchNormLayer, LayerNormLayer};
pub use pooling::{AvgPool2DLayer, FlattenLayer, MaxPool2DLayer, ReshapeLayer};
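/// Serializable, device-independent description of a single layer.
///
/// A `LayerConfig` carries only hyperparameters; [`create_layer`] turns it
/// into a live [`Layer`] with parameters allocated on a concrete device.
/// A minimal sketch (imports elided; paths depend on the consuming crate):
///
/// ```ignore
/// let cfg = LayerConfig::Dense {
///     input_size: 784,
///     output_size: 128,
///     activation: Activation::ReLU,
///     use_bias: true,
///     weight_init: WeightInit::HeNormal,
/// };
/// ```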
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LayerConfig {
Dense {
input_size: usize,
output_size: usize,
activation: Activation,
use_bias: bool,
weight_init: WeightInit,
},
Conv2D {
in_channels: usize,
out_channels: usize,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: (usize, usize),
dilation: (usize, usize),
activation: Activation,
use_bias: bool,
weight_init: WeightInit,
},
Dropout {
dropout_rate: f32,
},
BatchNorm {
num_features: usize,
eps: f32,
momentum: f32,
affine: bool,
},
LayerNorm {
normalized_shape: Vec<usize>,
eps: f32,
elementwise_affine: bool,
},
MaxPool2D {
kernel_size: (usize, usize),
stride: Option<(usize, usize)>,
padding: (usize, usize),
},
AvgPool2D {
kernel_size: (usize, usize),
stride: Option<(usize, usize)>,
padding: (usize, usize),
},
Flatten {
start_dim: usize,
end_dim: Option<usize>,
},
Reshape {
target_shape: Vec<usize>,
},
}
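/// Weight initialization schemes. The fan-based variants (`Xavier*`, `He*`)
/// scale by `fan_in`/`fan_out`; `Uniform`, `Normal`, and the constant fills
/// use their payload directly. Exact bounds and standard deviations are
/// documented on [`WeightInit::initialize`].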
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum WeightInit {
Xavier,
XavierNormal,
He,
HeNormal,
Uniform(f32),
Normal(f32),
Zeros,
Ones,
Constant(f32),
}
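/// Forward-pass mode: layers such as dropout and batch norm behave
/// differently in `Training` (dropout active, running statistics updated)
/// than in `Inference`.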
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrainingMode {
Training,
Inference,
}
/// Common interface implemented by every layer. Parameters and gradients
/// are exposed in matching order so callers can apply updates.
pub trait Layer: Send + Sync + fmt::Debug {
    /// Run the forward pass; `training` lets layers such as dropout and
    /// batch norm switch behavior between training and inference.
    fn forward(&mut self, input: &Tensor, training: TrainingMode) -> Result<Tensor>;
    /// Back-propagate `grad_output`, accumulating parameter gradients and
    /// returning the gradient with respect to the input.
    fn backward(&mut self, grad_output: &Tensor) -> Result<Tensor>;
    /// Trainable parameters (weights, biases, affine terms, ...).
    fn parameters(&self) -> Vec<&Tensor>;
    /// Mutable access to the trainable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor>;
    /// Accumulated gradients, in the same order as `parameters`.
    fn gradients(&self) -> Vec<&Tensor>;
    /// Mutable access to the accumulated gradients.
    fn gradients_mut(&mut self) -> Vec<&mut Tensor>;
    /// Reset all accumulated gradients to zero.
    fn zero_grad(&mut self);
    /// Layer name for summaries and error messages.
    fn name(&self) -> &str;
    /// Output shape produced for the given input shape.
    fn output_shape(&self, input_shape: &[usize]) -> Result<Vec<usize>>;
    /// Toggle training mode.
    fn set_training(&mut self, training: bool);
    /// Whether the layer is currently in training mode.
    fn training(&self) -> bool;
    /// Total number of trainable scalars across all parameters.
    fn num_parameters(&self) -> usize {
        self.parameters().iter().map(|p| p.size()).sum()
    }
    /// Move parameters and gradients to `device`.
    fn to_device(&mut self, device: Device) -> Result<()>;
    /// Deep-copy the layer behind a fresh `Box`.
    fn clone_layer(&self) -> Result<Box<dyn Layer>>;
}
impl WeightInit {
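    /// Fill `tensor` in place according to this scheme. `fan_in` and
    /// `fan_out` only affect the fan-based variants; the constant fills and
    /// `Uniform`/`Normal` ignore them. A minimal sketch mirroring the unit
    /// test below (imports elided):
    ///
    /// ```ignore
    /// let mut weights = Tensor::zeros(&[10, 10])?;
    /// WeightInit::Xavier.initialize(&mut weights, 10, 10)?;
    /// ```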
    pub fn initialize(&self, tensor: &mut Tensor, fan_in: usize, fan_out: usize) -> Result<()> {
        use rand::prelude::*;
        use rand_distr::StandardNormal;

        // Constant fills need no RNG.
        match self {
            WeightInit::Zeros => return tensor.fill(0.0),
            WeightInit::Ones => return tensor.fill(1.0),
            WeightInit::Constant(value) => return tensor.fill(*value),
            _ => {}
        }

        // Create the RNG once instead of once per element, and use the element
        // count directly rather than round-tripping the data through `to_vec`.
        let mut rng = thread_rng();
        let n = tensor.size();
        let new_data: Vec<f32> = match self {
            // Xavier/Glorot uniform: U(-b, b) with b = sqrt(6 / (fan_in + fan_out)).
            WeightInit::Xavier => {
                let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
                (0..n).map(|_| rng.gen_range(-bound..bound)).collect()
            }
            // Xavier/Glorot normal: N(0, 2 / (fan_in + fan_out)).
            WeightInit::XavierNormal => {
                let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
                (0..n)
                    .map(|_| rng.sample::<f32, _>(StandardNormal) * std)
                    .collect()
            }
            // He/Kaiming uniform: U(-b, b) with b = sqrt(6 / fan_in).
            WeightInit::He => {
                let bound = (6.0 / fan_in as f32).sqrt();
                (0..n).map(|_| rng.gen_range(-bound..bound)).collect()
            }
            // He/Kaiming normal: N(0, 2 / fan_in).
            WeightInit::HeNormal => {
                let std = (2.0 / fan_in as f32).sqrt();
                (0..n)
                    .map(|_| rng.sample::<f32, _>(StandardNormal) * std)
                    .collect()
            }
            WeightInit::Uniform(bound) => {
                (0..n).map(|_| rng.gen_range(-*bound..*bound)).collect()
            }
            WeightInit::Normal(std) => (0..n)
                .map(|_| rng.sample::<f32, _>(StandardNormal) * std)
                .collect(),
            // Handled by the early returns above.
            WeightInit::Zeros | WeightInit::Ones | WeightInit::Constant(_) => unreachable!(),
        };
        tensor.copy_from_slice(&new_data)?;
        Ok(())
    }
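    /// Sensible default for a given activation: He-normal for ReLU-family
    /// activations, Xavier-normal otherwise.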
pub fn default_for_activation(activation: &Activation) -> Self {
match activation {
Activation::ReLU | Activation::LeakyReLU(_) => WeightInit::HeNormal,
Activation::Sigmoid | Activation::Tanh => WeightInit::XavierNormal,
_ => WeightInit::XavierNormal,
}
}
}
impl fmt::Display for WeightInit {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
WeightInit::Xavier => write!(f, "Xavier"),
WeightInit::XavierNormal => write!(f, "XavierNormal"),
WeightInit::He => write!(f, "He"),
WeightInit::HeNormal => write!(f, "HeNormal"),
WeightInit::Uniform(bound) => write!(f, "Uniform(±{})", bound),
WeightInit::Normal(std) => write!(f, "Normal(σ={})", std),
WeightInit::Zeros => write!(f, "Zeros"),
WeightInit::Ones => write!(f, "Ones"),
WeightInit::Constant(value) => write!(f, "Constant({})", value),
}
}
}
impl Default for WeightInit {
fn default() -> Self {
WeightInit::XavierNormal
}
}
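/// Build a boxed [`Layer`] from a [`LayerConfig`], allocating its parameters
/// on `device`. A sketch of typical use; how a `Device` is obtained depends
/// on the `device` module (`Device::cpu()` below is a hypothetical
/// constructor used for illustration only):
///
/// ```ignore
/// let device = Device::cpu(); // hypothetical constructor
/// let mut dense = create_layer(LayerConfig::dense_relu(784, 128), device)?;
/// assert_eq!(dense.num_parameters(), 784 * 128 + 128);
/// ```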
pub fn create_layer(config: LayerConfig, device: Device) -> Result<Box<dyn Layer>> {
match config {
LayerConfig::Dense {
input_size,
output_size,
activation,
use_bias,
weight_init,
} => Ok(Box::new(DenseLayer::new_on_device(
input_size,
output_size,
activation,
use_bias,
weight_init,
device,
)?)),
LayerConfig::Conv2D {
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
activation,
use_bias,
weight_init,
} => Ok(Box::new(Conv2DLayer::new_on_device(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
activation,
use_bias,
weight_init,
device,
)?)),
LayerConfig::Dropout { dropout_rate } => Ok(Box::new(DropoutLayer::new(dropout_rate)?)),
LayerConfig::BatchNorm {
num_features,
eps,
momentum,
affine,
} => Ok(Box::new(BatchNormLayer::new_on_device(
num_features,
eps,
momentum,
affine,
device,
)?)),
LayerConfig::LayerNorm {
normalized_shape,
eps,
elementwise_affine,
} => Ok(Box::new(LayerNormLayer::new_on_device(
normalized_shape,
eps,
elementwise_affine,
device,
)?)),
LayerConfig::MaxPool2D {
kernel_size,
stride,
padding,
} => Ok(Box::new(MaxPool2DLayer::new(kernel_size, stride, padding)?)),
LayerConfig::AvgPool2D {
kernel_size,
stride,
padding,
} => Ok(Box::new(AvgPool2DLayer::new(kernel_size, stride, padding)?)),
LayerConfig::Flatten { start_dim, end_dim } => {
Ok(Box::new(FlattenLayer::new(start_dim, end_dim)?))
}
        LayerConfig::Reshape { target_shape } => Ok(Box::new(ReshapeLayer::new(target_shape)?)),
}
}
impl fmt::Display for LayerConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LayerConfig::Dense {
input_size,
output_size,
activation,
..
} => write!(f, "Dense({} → {}, {})", input_size, output_size, activation),
LayerConfig::Conv2D {
in_channels,
out_channels,
kernel_size,
..
} => write!(
f,
"Conv2D({} → {}, kernel={}×{})",
in_channels, out_channels, kernel_size.0, kernel_size.1
),
LayerConfig::Dropout { dropout_rate } => write!(f, "Dropout(p={})", dropout_rate),
LayerConfig::BatchNorm { num_features, .. } => {
write!(f, "BatchNorm({})", num_features)
}
LayerConfig::LayerNorm {
normalized_shape, ..
} => write!(f, "LayerNorm({:?})", normalized_shape),
LayerConfig::MaxPool2D { kernel_size, .. } => {
write!(f, "MaxPool2D({}×{})", kernel_size.0, kernel_size.1)
}
LayerConfig::AvgPool2D { kernel_size, .. } => {
write!(f, "AvgPool2D({}×{})", kernel_size.0, kernel_size.1)
}
LayerConfig::Flatten { .. } => write!(f, "Flatten"),
LayerConfig::Reshape { target_shape } => write!(f, "Reshape({:?})", target_shape),
}
}
}
impl LayerConfig {
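    /// Fully connected layer with ReLU activation and matching He-normal init.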
pub fn dense_relu(input_size: usize, output_size: usize) -> Self {
LayerConfig::Dense {
input_size,
output_size,
activation: Activation::ReLU,
use_bias: true,
weight_init: WeightInit::HeNormal,
}
}
pub fn dense_sigmoid(input_size: usize, output_size: usize) -> Self {
LayerConfig::Dense {
input_size,
output_size,
activation: Activation::Sigmoid,
use_bias: true,
weight_init: WeightInit::XavierNormal,
}
}
pub fn dense_linear(input_size: usize, output_size: usize) -> Self {
LayerConfig::Dense {
input_size,
output_size,
activation: Activation::Linear,
use_bias: true,
weight_init: WeightInit::XavierNormal,
}
}
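    /// 3×3 convolution with stride 1 and padding 1, which preserves the
    /// spatial size (`same_padding(3, 1) == 1`).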
pub fn conv2d_3x3(in_channels: usize, out_channels: usize) -> Self {
LayerConfig::Conv2D {
in_channels,
out_channels,
kernel_size: (3, 3),
stride: (1, 1),
padding: (1, 1),
dilation: (1, 1),
activation: Activation::ReLU,
use_bias: true,
weight_init: WeightInit::HeNormal,
}
}
pub fn conv2d_5x5(in_channels: usize, out_channels: usize) -> Self {
LayerConfig::Conv2D {
in_channels,
out_channels,
kernel_size: (5, 5),
stride: (1, 1),
padding: (2, 2),
dilation: (1, 1),
activation: Activation::ReLU,
use_bias: true,
weight_init: WeightInit::HeNormal,
}
}
pub fn dropout(dropout_rate: f32) -> Self {
LayerConfig::Dropout { dropout_rate }
}
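    /// Batch normalization with common defaults (`eps = 1e-5`,
    /// `momentum = 0.1`, learnable affine parameters).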
pub fn batch_norm(num_features: usize) -> Self {
LayerConfig::BatchNorm {
num_features,
eps: 1e-5,
momentum: 0.1,
affine: true,
}
}
pub fn layer_norm(normalized_shape: Vec<usize>) -> Self {
LayerConfig::LayerNorm {
normalized_shape,
eps: 1e-5,
elementwise_affine: true,
}
}
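    /// 2×2 max pooling; `stride: None` defers to the layer's default
    /// (conventionally the kernel size).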
pub fn max_pool2d() -> Self {
LayerConfig::MaxPool2D {
kernel_size: (2, 2),
            stride: None,
            padding: (0, 0),
}
}
pub fn flatten() -> Self {
LayerConfig::Flatten {
start_dim: 1,
end_dim: None,
}
}
}
pub mod utils {
use super::*;
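    /// Output size along one spatial dimension of a convolution:
    /// `floor((input + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1`.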
pub fn conv_output_size(
input_size: usize,
kernel_size: usize,
stride: usize,
padding: usize,
dilation: usize,
) -> usize {
(input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}
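    /// Output size along one spatial dimension of a pooling window:
    /// `floor((input + 2*padding - kernel) / stride) + 1`.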
pub fn pool_output_size(
input_size: usize,
kernel_size: usize,
stride: usize,
padding: usize,
) -> usize {
(input_size + 2 * padding - kernel_size) / stride + 1
}
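    /// Padding that keeps the spatial size unchanged for stride-1
    /// convolutions: `dilation * (kernel - 1) / 2` (exact when
    /// `dilation * (kernel - 1)` is even).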
pub fn same_padding(kernel_size: usize, dilation: usize) -> usize {
(dilation * (kernel_size - 1)) / 2
}
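    /// Return a shape-mismatch error unless `expected == actual`.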
pub fn validate_shapes(expected: &[usize], actual: &[usize]) -> Result<()> {
if expected != actual {
return Err(NnlError::shape_mismatch(expected, actual));
}
Ok(())
}
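    /// Trainable-parameter count: product of `weights_shape`, plus
    /// `bias_size` when `has_bias` is set.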
pub fn count_parameters(weights_shape: &[usize], has_bias: bool, bias_size: usize) -> usize {
let weight_params: usize = weights_shape.iter().product();
let bias_params = if has_bias { bias_size } else { 0 };
weight_params + bias_params
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
#[test]
fn test_weight_initialization() {
let mut tensor = Tensor::zeros(&[10, 10]).unwrap();
let init = WeightInit::Xavier;
init.initialize(&mut tensor, 10, 10).unwrap();
let data = tensor.to_vec().unwrap();
assert!(data.iter().any(|&x| x != 0.0));
}
#[test]
fn test_layer_config_display() {
let dense = LayerConfig::dense_relu(784, 128);
let display = format!("{}", dense);
assert!(display.contains("Dense"));
assert!(display.contains("784"));
assert!(display.contains("128"));
let conv = LayerConfig::conv2d_3x3(3, 64);
let display = format!("{}", conv);
assert!(display.contains("Conv2D"));
assert!(display.contains("3"));
assert!(display.contains("64"));
}
#[test]
fn test_weight_init_display() {
assert_eq!(format!("{}", WeightInit::Xavier), "Xavier");
assert_eq!(format!("{}", WeightInit::HeNormal), "HeNormal");
assert_eq!(format!("{}", WeightInit::Uniform(0.5)), "Uniform(±0.5)");
assert_eq!(format!("{}", WeightInit::Normal(0.1)), "Normal(σ=0.1)");
}
#[test]
fn test_conv_output_size_calculation() {
let output_size = utils::conv_output_size(32, 3, 1, 1, 1);
assert_eq!(output_size, 32);
let output_size = utils::conv_output_size(32, 3, 2, 1, 1);
        assert_eq!(output_size, 16);
    }
#[test]
fn test_pool_output_size_calculation() {
let output_size = utils::pool_output_size(32, 2, 2, 0);
assert_eq!(output_size, 16);
let output_size = utils::pool_output_size(32, 3, 1, 1);
assert_eq!(output_size, 32);
}
#[test]
fn test_same_padding_calculation() {
assert_eq!(utils::same_padding(3, 1), 1);
assert_eq!(utils::same_padding(5, 1), 2);
assert_eq!(utils::same_padding(3, 2), 2);
}
#[test]
fn test_weight_init_defaults() {
let relu_init = WeightInit::default_for_activation(&Activation::ReLU);
assert_eq!(relu_init, WeightInit::HeNormal);
let sigmoid_init = WeightInit::default_for_activation(&Activation::Sigmoid);
assert_eq!(sigmoid_init, WeightInit::XavierNormal);
}
#[test]
fn test_parameter_counting() {
let weight_params = utils::count_parameters(&[784, 128], true, 128);
assert_eq!(weight_params, 784 * 128 + 128);
let conv_params = utils::count_parameters(&[64, 3, 3, 3], true, 64);
assert_eq!(conv_params, 64 * 3 * 3 * 3 + 64);
}
#[test]
fn test_shape_validation() {
let result = utils::validate_shapes(&[2, 3, 4], &[2, 3, 4]);
assert!(result.is_ok());
let result = utils::validate_shapes(&[2, 3, 4], &[2, 3, 5]);
assert!(result.is_err());
}
}