//! zenu-layer 0.1.1
//!
//! A simple neural network layer library.
//!
//! Documentation
use std::collections::HashMap;

use zenu_autograd::{
    creator::{ones::ones, zeros::zeros},
    nn::batch_norm::{batch_norm_2d, BatchNorm2dAutoGradConfig},
    Variable,
};
use zenu_matrix::{device::Device, dim::DimTrait, num::Num};

use crate::{Module, Parameters};

/// A 2-D batch-normalization layer.
///
/// Bundles the per-channel affine parameters (`scale`, `bias`), the running
/// statistics (`mean`, `variance`), the momentum forwarded to
/// `batch_norm_2d`, and a cached autograd configuration.
pub struct BatchNorm2d<T, D>
where
    T: Num,
    D: Device,
{
    /// Autograd configuration handed to `batch_norm_2d`; its shape is
    /// refreshed in `call` whenever the input shape changes.
    config: BatchNorm2dAutoGradConfig<T>,
    /// Momentum forwarded unchanged to `batch_norm_2d` (presumably the
    /// running-statistics update factor — NOTE(review): confirm in zenu_autograd).
    momentum: f64,
    /// Per-channel scale; `new` creates it with shape `[channels]`, filled with ones.
    pub scale: Variable<T, D>,
    /// Per-channel bias; `new` creates it with shape `[channels]`, filled with zeros.
    pub bias: Variable<T, D>,
    /// Running mean; `new` creates it with shape `[channels]`, filled with zeros.
    pub mean: Variable<T, D>,
    /// Running variance; `new` creates it with shape `[channels]`, filled with ones.
    pub variance: Variable<T, D>,
}

impl<T, D> Module<T, D> for BatchNorm2d<T, D>
where
    T: Num,
    D: Device,
{
    type Input = Variable<T, D>;
    type Output = Variable<T, D>;

    /// Applies 2-D batch normalization to `input` using this layer's scale,
    /// bias, running mean/variance and momentum.
    ///
    /// Before dispatching, the cached autograd config is re-shaped if the
    /// incoming variable's shape differs from the one the config holds.
    /// `update_shape` is invoked through `&self`, so the config presumably
    /// relies on interior mutability — NOTE(review): confirm in zenu_autograd.
    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
        // Scoped so the shape value is gone before `input` is moved below.
        {
            let shape = input.get_shape();
            let shape_changed = shape != self.config.get_shape();
            if shape_changed {
                self.config.update_shape(shape.slice());
            }
        }

        let Self {
            config,
            momentum,
            scale,
            bias,
            mean,
            variance,
        } = self;

        batch_norm_2d(
            input,
            scale.clone(),
            bias.clone(),
            mean.clone(),
            variance.clone(),
            *momentum,
            config.clone(),
        )
    }
}

impl<T: Num, D: Device> Parameters<T, D> for BatchNorm2d<T, D> {
    /// Returns the per-channel scale, keyed as `"batch_norm_2d.scale"`.
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        HashMap::from([("batch_norm_2d.scale".to_string(), self.scale.clone())])
    }

    /// Returns the per-channel bias, keyed as `"batch_norm_2d.bias"`.
    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        HashMap::from([("batch_norm_2d.bias".to_string(), self.bias.clone())])
    }

    /// Returns every variable owned by the layer: the entries of `weights()`
    /// and `biases()`, plus the running `mean` and `variance`, keyed as
    /// `"batch_norm_2d.mean"` and `"batch_norm_2d.variance"`.
    fn parameters(&self) -> HashMap<String, Variable<T, D>> {
        // `weights()` already returns an owned map, so start from it instead of
        // re-cloning each key/value out of a temporary.
        let mut parameters = self.weights();
        parameters.extend(self.biases());
        parameters.extend([
            ("batch_norm_2d.mean".to_string(), self.mean.clone()),
            ("batch_norm_2d.variance".to_string(), self.variance.clone()),
        ]);
        parameters
    }
}

impl<T: Num, D: Device> BatchNorm2d<T, D> {
    /// Creates a batch-norm layer for `channels` feature channels.
    ///
    /// - `scale` starts as ones and `bias` as zeros, both with shape
    ///   `[channels]`; both are flagged with `set_is_train(true)`.
    /// - The running `mean` starts as zeros and `variance` as ones, both with
    ///   shape `[channels]`.
    /// - Each variable is named `batch_norm_2d.<field>`, which matches the keys
    ///   the `Parameters` impl returns.
    /// - `momentum` is stored and forwarded to `batch_norm_2d` on every call.
    /// - The autograd config starts from its `Default`; its shape is filled in
    ///   lazily on the first forward pass.
    #[must_use]
    pub fn new(channels: usize, momentum: f64) -> Self {
        let scale = ones([channels]);
        let bias = zeros([channels]);
        let mean = zeros([channels]);
        let variance = ones([channels]);

        scale.set_is_train(true);
        bias.set_is_train(true);

        scale.set_name("batch_norm_2d.scale");
        bias.set_name("batch_norm_2d.bias");
        mean.set_name("batch_norm_2d.mean");
        variance.set_name("batch_norm_2d.variance");

        let config = BatchNorm2dAutoGradConfig::default();
        Self {
            config,
            momentum,
            scale,
            bias,
            mean,
            variance,
        }
    }

    /// Converts the layer to device `Dout`, consuming `self`.
    ///
    /// `scale`, `bias`, `mean` and `variance` are each converted with
    /// `Variable::to`. The autograd config and momentum carry over unchanged,
    /// because the config type does not depend on the device.
    #[must_use = "`to` consumes the layer; dropping the result discards it entirely"]
    pub fn to<Dout: Device>(self) -> BatchNorm2d<T, Dout> {
        BatchNorm2d {
            config: self.config,
            momentum: self.momentum,
            scale: self.scale.to(),
            bias: self.bias.to(),
            mean: self.mean.to(),
            variance: self.variance.to(),
        }
    }
}