use ndarray::{Array2, Array3, Array4};
use ndarray_rand::{rand::distributions::Uniform, RandomExt};
use serde::{Deserialize, Serialize};
use crate::core::NNResult;
/// Border-handling mode stored on a [`Conv`] layer and consulted by its
/// cross-correlation routine.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) enum Padding {
/// The kernel is only applied where it fits entirely inside the input, so each
/// output dimension is `input - kernel + 1`.
Valid,
/// Full-padding mode. Conventionally the input is zero-padded so that every
/// partial overlap with the kernel is produced (`input + kernel - 1` per
/// dimension).
Full,
}
/// A 2-D convolutional layer: its shapes, learnable parameters and padding mode.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct Conv {
// Shape of one input sample; `Conv::new` reads it as `[depth, height, width]`.
input_shape: [usize; 3],
// `[nkernels, height - kernel_size + 1, width - kernel_size + 1]` as computed in `Conv::new`.
output_shape: [usize; 3],
// `[nkernels, depth, kernel_size, kernel_size]` as computed in `Conv::new`.
kernel_shape: [usize; 4],
// Kernel weights with shape `kernel_shape`, randomly initialised in `Conv::new`.
kernels: Array4<f32>,
// Bias values with shape `output_shape`, randomly initialised in `Conv::new`.
biases: Array3<f32>,
// Border-handling mode matched on by `_cross_correlation`.
padding: Padding,
}
impl Conv {
pub fn new(
input_shape: [usize; 3],
kernel_size: usize,
nkernels: usize,
padding: Padding,
) -> Self {
let (input_depth, input_height, input_width) =
(input_shape[0], input_shape[1], input_shape[2]);
let kernel_shape = [nkernels, input_depth, kernel_size, kernel_size];
let output_shape = [
nkernels,
input_height - kernel_size + 1,
input_width - kernel_size + 1,
];
let kernels = Array4::random(kernel_shape, Uniform::new(-1.0, 1.0));
let biases = Array3::random(output_shape, Uniform::new(-1.0, 1.0));
println!("kernels:\n{}", kernels);
Self {
input_shape,
output_shape,
kernel_shape,
kernels,
biases,
padding,
}
}
fn _cross_correlation(
&self,
input: &Array2<f32>,
kernel: &Array2<f32>,
) -> NNResult<Array2<f32>> {
let mut sums = Vec::new();
match self.padding {
Padding::Valid => {
let (input_rows, input_cols) = input.dim();
let (kernel_rows, kernel_cols) = kernel.dim();
let valid_rows = input_rows - kernel_rows + 1;
let valid_cols = input_cols - kernel_cols + 1;
for i in 0..valid_rows {
for j in 0..valid_cols {
let mut sum = 0.0;
for ki in 0..kernel_rows {
for kj in 0..kernel_cols {
sum += input[[i + ki, j + kj]] * kernel[[ki, kj]];
}
}
sums.push(sum);
}
}
Ok(Array2::from_shape_vec((valid_rows, valid_cols), sums)?)
}
Padding::Full => Ok(Array2::zeros((0, 0))),
}
}
}
impl Default for Conv {
    /// Builds a small valid-padding layer: input shape `[1, 2, 3]`, four kernels of
    /// side length 2.
    fn default() -> Self {
        let input_shape = [1, 2, 3];
        let kernel_size = 2;
        let nkernels = 4;
        Self::new(input_shape, kernel_size, nkernels, Padding::Valid)
    }
}
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    /// Builds the small valid-padding layer used throughout these tests.
    fn valid_layer() -> Conv {
        Conv::new([1, 2, 3], 2, 4, Padding::Valid)
    }

    /// Construction records the requested shapes and allocates matching parameters.
    #[test]
    fn test_conv_creation() {
        let layer = valid_layer();
        let expected_output: [usize; 3] = [4, 1, 2];
        let expected_kernels: [usize; 4] = [4, 1, 2, 2];

        assert_eq!(layer.input_shape, [1, 2, 3]);
        assert_eq!(layer.output_shape, expected_output);
        assert_eq!(layer.kernel_shape, expected_kernels);
        assert_eq!(layer.kernels.shape(), expected_kernels);
        assert_eq!(layer.biases.shape(), expected_output);
    }

    /// `Conv::default()` yields the same shapes as the equivalent explicit call.
    #[test]
    fn test_conv_default() {
        let explicit = valid_layer();
        let defaulted = Conv::default();

        assert_eq!(explicit.input_shape, defaulted.input_shape);
        assert_eq!(explicit.output_shape, defaulted.output_shape);
        assert_eq!(explicit.kernel_shape, defaulted.kernel_shape);
    }

    /// Valid cross-correlation of a 3x3 input with a 2x2 kernel gives a 2x2 result.
    #[test]
    fn test_cross_correlation_valid() {
        let input = array![[1., 6., 2.], [5., 3., 1.], [7., 0., 4.]];
        let kernel = array![[1., 2.], [-1., 0.]];
        let expected = array![[8., 7.], [4., 5.]];

        let result = valid_layer()._cross_correlation(&input, &kernel);

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected);
    }
}