use crate::neural_network::Tensor;
use crate::neural_network::layer::convolution_layer::PaddingType;
use ndarray::{Array2, Array3, ArrayD, s};
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};
/// Computes the output shape `[batch, channels, out_length]` of an unpadded
/// 1-D pooling operation: `out_length = (length - pool_size) / stride + 1`.
///
/// # Panics
/// Panics if `input_shape` has fewer than 3 dimensions, or if `pool_size`
/// exceeds the input length. Without the latter check the subtraction would
/// wrap around in release builds and silently yield a garbage shape.
pub(super) fn calculate_output_shape_1d_pooling(
    input_shape: &[usize],
    pool_size: usize,
    stride: usize,
) -> Vec<usize> {
    assert!(
        input_shape.len() >= 3,
        "Input shape must have at least 3 dimensions"
    );
    let batch_size = input_shape[0];
    let channels = input_shape[1];
    let length = input_shape[2];
    assert!(
        pool_size <= length,
        "pool_size ({}) must not exceed input length ({})",
        pool_size,
        length
    );
    let output_length = (length - pool_size) / stride + 1;
    vec![batch_size, channels, output_length]
}
/// Computes the output shape `[batch, channels, out_h, out_w]` of an unpadded
/// 2-D pooling operation, with `out = (in - pool) / stride + 1` per spatial dim.
///
/// # Panics
/// Panics if `input_shape` has fewer than 4 dimensions, or if the pool window
/// does not fit inside the input. Without the fit check the subtraction would
/// wrap around in release builds and silently yield a garbage shape.
pub(super) fn calculate_output_shape_2d_pooling(
    input_shape: &[usize],
    pool_size: (usize, usize),
    strides: (usize, usize),
) -> Vec<usize> {
    assert!(
        input_shape.len() >= 4,
        "Input shape must have at least 4 dimensions"
    );
    let batch_size = input_shape[0];
    let channels = input_shape[1];
    let input_height = input_shape[2];
    let input_width = input_shape[3];
    assert!(
        pool_size.0 <= input_height && pool_size.1 <= input_width,
        "pool_size ({:?}) must not exceed input dims ({}, {})",
        pool_size,
        input_height,
        input_width
    );
    let output_height = (input_height - pool_size.0) / strides.0 + 1;
    let output_width = (input_width - pool_size.1) / strides.1 + 1;
    vec![batch_size, channels, output_height, output_width]
}
/// Computes the output shape `[batch, channels, out_d, out_h, out_w]` of an
/// unpadded 3-D pooling operation, with `out = (in - pool) / stride + 1` per
/// spatial dim.
///
/// # Panics
/// Panics if `input_shape` has fewer than 5 dimensions, or if the pool window
/// does not fit inside the input. Without the fit check the subtraction would
/// wrap around in release builds and silently yield a garbage shape.
pub(super) fn calculate_output_shape_3d_pooling(
    input_shape: &[usize],
    pool_size: (usize, usize, usize),
    strides: (usize, usize, usize),
) -> Vec<usize> {
    assert!(
        input_shape.len() >= 5,
        "Input shape must have at least 5 dimensions"
    );
    let batch_size = input_shape[0];
    let channels = input_shape[1];
    let input_depth = input_shape[2];
    let input_height = input_shape[3];
    let input_width = input_shape[4];
    assert!(
        pool_size.0 <= input_depth && pool_size.1 <= input_height && pool_size.2 <= input_width,
        "pool_size ({:?}) must not exceed input dims ({}, {}, {})",
        pool_size,
        input_depth,
        input_height,
        input_width
    );
    let output_depth = (input_depth - pool_size.0) / strides.0 + 1;
    let output_height = (input_height - pool_size.1) / strides.1 + 1;
    let output_width = (input_width - pool_size.2) / strides.2 + 1;
    vec![
        batch_size,
        channels,
        output_depth,
        output_height,
        output_width,
    ]
}
/// Applies one Adam optimization step to `params` in place, in parallel.
///
/// Per element, given gradient `g`:
///   m = beta1 * m + (1 - beta1) * g
///   v = beta2 * v + (1 - beta2) * g^2
///   param -= lr * (m / bias_correction1) / (sqrt(v / bias_correction2) + epsilon)
///
/// `bias_correction1` / `bias_correction2` are the caller-precomputed
/// `1 - beta^t` denominators. All four slices are zipped index-wise, so they
/// must share the same length.
pub(super) fn update_adam_conv(
    params: &mut [f32],
    grads: &[f32],
    m: &mut [f32],
    v: &mut [f32],
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    bias_correction1: f32,
    bias_correction2: f32,
) {
    // Hoist the complements so the closure body reads like the Adam formulas.
    let one_minus_b1 = 1.0 - beta1;
    let one_minus_b2 = 1.0 - beta2;
    params
        .par_iter_mut()
        .zip(grads.par_iter())
        .zip(m.par_iter_mut())
        .zip(v.par_iter_mut())
        .for_each(|(((weight, &g), first_moment), second_moment)| {
            // Exponential moving averages of the gradient and squared gradient.
            let new_m = beta1 * *first_moment + one_minus_b1 * g;
            let new_v = beta2 * *second_moment + one_minus_b2 * g * g;
            *first_moment = new_m;
            *second_moment = new_v;
            // Bias-corrected estimates drive the actual step.
            let m_hat = new_m / bias_correction1;
            let v_hat = new_v / bias_correction2;
            *weight -= lr * m_hat / (v_hat.sqrt() + epsilon);
        });
}
/// Assembles per-batch convolution outputs into one `[batch, filter, h, w]`
/// tensor.
///
/// `results` pairs each batch index `b` with that batch's `[filter, h, w]`
/// output; values are copied into the corresponding `output[b, ..]` slab.
/// Batch indices absent from `results` remain zero.
///
/// Fix: the spatial dims are read out before `output_shape` is moved into
/// `ArrayD::zeros`, removing the previous redundant `output_shape.clone()`
/// (a full Vec allocation per call).
pub(super) fn merge_results(
    output_shape: Vec<usize>,
    results: Vec<(usize, Array3<f32>)>,
    filters: usize,
) -> ArrayD<f32> {
    let (rows, cols) = (output_shape[2], output_shape[3]);
    let mut output: ArrayD<f32> = ArrayD::zeros(output_shape);
    for (b, batch_output) in results {
        for f in 0..filters {
            for i in 0..rows {
                for j in 0..cols {
                    output[[b, f, i, j]] = batch_output[[f, i, j]];
                }
            }
        }
    }
    output
}
/// Accumulates one row's contribution to a convolution weight gradient:
/// the dot product of output-gradient row `(b, f, i, ..)` with the input
/// values at `(b, c, i_pos, j * stride_1 + w)`, skipping columns whose
/// projected input position falls outside the input width.
///
/// `grad_shape[3]` is the gradient's width; `input_shape[3]` the input's.
pub(super) fn compute_row_gradient_sum(
    gradient: &Tensor,
    input: &Tensor,
    b: usize,
    f: usize,
    c: usize,
    i: usize,
    i_pos: usize,
    w: usize,
    grad_shape: &[usize],
    input_shape: &[usize],
    stride_1: usize,
) -> f32 {
    (0..grad_shape[3])
        .map(|j| (j, j * stride_1 + w))
        // Guard against stepping past the right edge of the input row.
        .filter(|&(_, j_pos)| j_pos < input_shape[3])
        .map(|(j, j_pos)| gradient[[b, f, i, j]] * input[[b, c, i_pos, j_pos]])
        .sum()
}
/// Applies one RMSProp optimization step to `params` in place, in parallel.
///
/// Per element, given gradient `g`:
///   cache = rho * cache + (1 - rho) * g^2
///   param -= lr * g / (sqrt(cache) + epsilon)
///
/// The three slices are zipped index-wise, so they must share the same length.
pub(super) fn update_rmsprop(
    params: &mut [f32],
    grads: &[f32],
    cache: &mut [f32],
    rho: f32,
    epsilon: f32,
    lr: f32,
) {
    let one_minus_rho = 1.0 - rho;
    params
        .par_iter_mut()
        .zip(grads.par_iter())
        .zip(cache.par_iter_mut())
        .for_each(|((weight, &g), acc)| {
            // Leaky average of the squared gradient scales the step size.
            let new_acc = rho * *acc + one_minus_rho * g * g;
            *acc = new_acc;
            *weight -= lr * g / (new_acc.sqrt() + epsilon);
        });
}
/// Computes the spatial `(height, width)` of a 2-D convolution output.
///
/// - `Valid`: no padding, `out = (in - kernel) / stride + 1`.
/// - `Same`: output covers the whole input, `out = ceil(in / stride)`.
///
/// Fix: the `Same` branch previously went through `f32` division + `ceil`,
/// which loses precision for dimensions above 2^24; it now uses exact integer
/// ceiling division. Results are identical for all realistic sizes.
///
/// # Panics
/// Panics on a zero stride (integer division by zero); in debug builds the
/// `Valid` branch also panics if the kernel is larger than the input.
pub(super) fn calculate_output_height_and_weight(
    padding_type: PaddingType,
    input_height: usize,
    input_width: usize,
    kernel_size: (usize, usize),
    strides: (usize, usize),
) -> (usize, usize) {
    match padding_type {
        PaddingType::Valid => (
            (input_height - kernel_size.0) / strides.0 + 1,
            (input_width - kernel_size.1) / strides.1 + 1,
        ),
        PaddingType::Same => (
            // ceil(a / b) in pure integer arithmetic: (a + b - 1) / b.
            (input_height + strides.0 - 1) / strides.0,
            (input_width + strides.1 - 1) / strides.1,
        ),
    }
}
/// Zero-pads a 2-D array by `pad_h` extra rows and `pad_w` extra columns in
/// total, splitting each as evenly as possible: `pad / 2` before the data
/// (so an odd pad puts the extra row/column after it).
pub(super) fn pad_tensor_2d(input: &Array2<f32>, pad_h: usize, pad_w: usize) -> Array2<f32> {
    let (rows, cols) = input.dim();
    let top = pad_h / 2;
    let left = pad_w / 2;
    // Start from an all-zero canvas, then copy the input into its interior.
    let mut padded = Array2::zeros((rows + pad_h, cols + pad_w));
    padded
        .slice_mut(s![top..top + rows, left..left + cols])
        .assign(input);
    padded
}
/// Computes the full `[batch, channels, out_h, out_w]` output shape of a 2-D
/// convolution, delegating the spatial dims to
/// [`calculate_output_height_and_weight`] for the given padding mode.
///
/// # Panics
/// Panics if `input_shape` has fewer than 4 dimensions.
pub(super) fn calculate_output_shape_2d(
    input_shape: &[usize],
    kernel_size: (usize, usize),
    strides: (usize, usize),
    padding: &PaddingType,
) -> Vec<usize> {
    assert!(
        input_shape.len() >= 4,
        "Input shape must have at least 4 dimensions"
    );
    let (out_height, out_width) = calculate_output_height_and_weight(
        *padding,
        input_shape[2],
        input_shape[3],
        kernel_size,
        strides,
    );
    // Batch and channel counts pass through unchanged.
    vec![input_shape[0], input_shape[1], out_height, out_width]
}