// jiro_nn 0.8.1
//
// Neural Networks framework with model building & data preprocessing features.
// Documentation
use crate::{
    layer::LearnableLayer,
    linalg::{Matrix, MatrixTrait, Scalar},
    vision::{
        conv_initializers::ConvInitializers, conv_optimizer::ConvOptimizers, image::Image,
        image::ImageTrait,
    },
};

use crate::vision::image_layer::ImageLayer;

use super::ConvLayer;

/// Convolution layer that cross-correlates every input channel with its own
/// kernel channel, so the output has as many channels as the input.
///
/// Kernels and biases are the learnable parameters; each has its own optimizer
/// that is applied during `backward`.
#[derive(Debug)]
pub struct DirectConvLayer {
    /// Kernel weights, generated by `new` as a `krows` x `kcols` image with one
    /// channel per input channel.
    pub kernels: Image,
    /// Additive biases, generated by `new` as a 1x1 image per channel and tiled
    /// to the output size inside `forward`.
    biases: Image,
    /// Input of the most recent `forward` call, read by `backward`;
    /// `None` until `forward` has been called.
    input: Option<Image>,
    /// Optimizer that produces the updated `kernels` in `backward`.
    kernels_optimizer: ConvOptimizers,
    /// Optimizer that produces the updated `biases` in `backward`.
    biases_optimizer: ConvOptimizers,
}

impl DirectConvLayer {
    /// Creates a layer for inputs with `in_chans` channels.
    ///
    /// The kernels are generated by `kernels_initializer` as a `krows` x `kcols`
    /// image with `in_chans` channels, and the biases by `biases_initializer` as
    /// a 1 x 1 image with `in_chans` channels. The trailing `1` passed to
    /// `gen_image` is presumably the sample count (TODO confirm against
    /// `ConvInitializers::gen_image`).
    ///
    /// The two optimizers are stored and later used by `backward` to update the
    /// kernels and the biases respectively.
    pub fn new(
        krows: usize,
        kcols: usize,
        in_chans: usize,
        kernels_initializer: ConvInitializers,
        biases_initializer: ConvInitializers,
        kernels_optimizer: ConvOptimizers,
        biases_optimizer: ConvOptimizers,
    ) -> Self {
        Self {
            kernels: kernels_initializer.gen_image(krows, kcols, in_chans, 1),
            biases: biases_initializer.gen_image(1, 1, in_chans, 1),
            // Nothing to back-propagate through until `forward` has run.
            input: None,
            kernels_optimizer,
            biases_optimizer,
        }
    }

    /// Computes the output shape `(rows, cols, channels)` for an input of
    /// `in_rows` x `in_cols` with `in_chans` channels and a `krows` x `kcols`
    /// kernel.
    ///
    /// Rows and columns shrink to `in - k + 1`; the channel count is passed
    /// through unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the kernel is larger than the input in either dimension. The
    /// unchecked `usize` subtraction would otherwise overflow: an opaque
    /// arithmetic panic in debug builds and a silently wrapped, meaningless
    /// dimension in release builds.
    pub fn out_img_dims_and_channels(
        in_rows: usize,
        in_cols: usize,
        in_chans: usize,
        krows: usize,
        kcols: usize,
    ) -> (usize, usize, usize) {
        assert!(
            krows <= in_rows && kcols <= in_cols,
            "kernel ({}x{}) must not be larger than the input ({}x{})",
            krows,
            kcols,
            in_rows,
            in_cols
        );
        let out_rows = in_rows - krows + 1;
        let out_cols = in_cols - kcols + 1;
        (out_rows, out_cols, in_chans)
    }
}

impl ImageLayer for DirectConvLayer {
    /// Cross-correlates each input channel with the kernel channel of the same
    /// index, adds that channel's bias and joins the per-channel results into
    /// the output image.
    ///
    /// The input is stored in `self.input` afterwards so that `backward` can
    /// compute the kernel gradient.
    fn forward(&mut self, input: Image) -> Image {
        let mut channels = vec![];
        for c in 0..input.channels() {
            let channel = input.get_channel_across_samples(c);
            let kernel = self.kernels.get_channel(c);
            let correlated = channel.cross_correlate(&kernel);

            // The biases start out as a 1x1 image per channel and are expanded to
            // the output size on first use. Rows AND columns are compared:
            // checking only the rows would skip the tiling for an output that is
            // one row high but several columns wide.
            // Assumes `tile(r, c, 1, 1)` repeats the image r x c times without
            // touching channels/samples — TODO confirm against `ImageTrait::tile`.
            let out_dims = correlated.image_dims();
            let bias_dims = self.biases.image_dims();
            if bias_dims.0 != out_dims.0 || bias_dims.1 != out_dims.1 {
                self.biases = self.biases.tile(out_dims.0, out_dims.1, 1, 1);
            }

            let bias = self.biases.get_channel(c);
            channels.push(correlated.component_add(&bias));
        }
        let res = Image::join_channels(channels);
        self.input = Some(input);
        res
    }

    /// Back-propagates `output_gradient` through the layer, lets the optimizers
    /// update the kernels and biases for the given `epoch`, and returns the
    /// gradient with respect to the layer input.
    ///
    /// # Panics
    ///
    /// Panics if `forward` has not been called before, because the cached input
    /// is required to compute the kernel gradient.
    fn backward(&mut self, epoch: usize, output_gradient: Image) -> Image {
        let input = self
            .input
            .as_ref()
            .expect("DirectConvLayer::backward called before forward: no cached input");

        // Input gradient: full convolution of every output-gradient channel with
        // the kernel channel of the same index.
        let input_grad_channels = (0..input.channels())
            .map(|i| {
                let kernel = self.kernels.get_channel(i);
                let output_grad_i = output_gradient.get_channel_across_samples(i);
                let convolved = output_grad_i.convolve_full(&kernel);
                convolved
            })
            .collect::<Vec<_>>();
        let input_grad = Image::join_channels(input_grad_channels);

        // Kernel gradient: cross-correlation of every input channel with the
        // matching output-gradient channel.
        // NOTE(review): the output gradient is summed over samples *before* it is
        // correlated with the input samples, and the result is summed again —
        // confirm this is the intended batch reduction for batch sizes > 1.
        let kern_grad_channels = (0..input.channels())
            .map(|i| {
                let output_grad_i = output_gradient.get_channel_across_samples(i).sum_samples();
                let input_i = input.get_channel_across_samples(i);
                let correlated = input_i.cross_correlate(&output_grad_i).sum_samples();
                correlated
            })
            .collect::<Vec<_>>();
        let kern_grad = Image::join_channels(kern_grad_channels);

        // Bias gradient: the output gradient of each channel summed over samples.
        let biases_grad_channels = (0..self.biases.channels())
            .map(|c| {
                let summed = output_gradient.get_channel_across_samples(c).sum_samples();
                summed
            })
            .collect::<Vec<_>>();
        let biases_grad = Image::join_channels(biases_grad_channels);

        // Parameters are only replaced after all gradients have been computed
        // from the pre-update kernels.
        self.kernels = self
            .kernels_optimizer
            .update_parameters(epoch, &self.kernels, &kern_grad);
        self.biases = self
            .biases_optimizer
            .update_parameters(epoch, &self.biases, &biases_grad);
        input_grad
    }
}

impl LearnableLayer for DirectConvLayer {
    /// Exports the learnable parameters as a list of columns: the columns of the
    /// flattened kernels first, followed by one final column holding the first
    /// column of the flattened biases.
    fn get_learnable_parameters(&self) -> Vec<Vec<Scalar>> {
        let mut columns = self.kernels.flatten().get_data_col_leading();
        let bias_column = self.biases.flatten().get_column(0);
        columns.push(bias_column);
        columns
    }

    /// Restores the parameters from the layout produced by
    /// `get_learnable_parameters`: the last column is taken as the biases, all
    /// preceding columns as the kernels. The channel counts of the current
    /// kernels and biases are reused for the rebuilt images.
    ///
    /// Panics if `params_matrix` is empty.
    fn set_learnable_parameters(&mut self, params_matrix: &Vec<Vec<Scalar>>) {
        let (bias_column, kernel_columns) = params_matrix.split_last().unwrap();
        let kernel_columns = kernel_columns.to_vec();

        let kernel_matrix = Matrix::from_column_leading_vector2(&kernel_columns);
        self.kernels = Image::from_samples(&kernel_matrix, self.kernels.channels());

        let bias_matrix = Matrix::from_column_vector(bias_column);
        self.biases = Image::from_samples(&bias_matrix, self.biases.channels());
    }
}

impl ConvLayer for DirectConvLayer {
    /// Replaces the kernels with the result of multiplying them by `scale`.
    fn scale_kernels(&mut self, scale: Scalar) {
        let scaled = self.kernels.scalar_mul(scale);
        self.kernels = scaled;
    }
}