use scirs2_core::ndarray::{Array1, Array2, Array4, ArrayView1, ArrayView2, ArrayView4};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::{NeuralError, Result};
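/// Performs a 2D convolution over a batch of images.
///
/// `input` has shape `[batch, in_channels, height, width]` and `weight` has
/// shape `[out_channels, in_channels, kernel_height, kernel_width]`. An
/// optional `bias` of length `out_channels` is added to every output position.
/// Returns an output of shape `[batch, out_channels, out_height, out_width]`,
/// where `out_height = (height + 2 * padding - kernel_height) / stride + 1`
/// (and analogously for the width).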
pub fn conv2d<F>(
input: &ArrayView4<F>,
weight: &ArrayView4<F>,
bias: Option<ArrayView1<F>>,
stride: usize,
padding: usize,
) -> Result<Array4<F>>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let in_channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
let out_channels = weight.shape()[0];
let weight_in_channels = weight.shape()[1];
let kernel_height = weight.shape()[2];
let kernel_width = weight.shape()[3];
if in_channels != weight_in_channels {
return Err(NeuralError::ShapeMismatch(
format!("Input and weight channel mismatch in conv2d: input has {} channels, weight expects {} channels",
in_channels, weight_in_channels)
));
}
if let Some(b) = bias {
if b.shape()[0] != out_channels {
return Err(NeuralError::ShapeMismatch(format!(
"Bias shape mismatch in conv2d: bias has {} channels, expected {}",
b.shape()[0],
out_channels
)));
}
}
let out_height = (in_height + 2 * padding - kernel_height) / stride + 1;
let out_width = (in_width + 2 * padding - kernel_width) / stride + 1;
let mut output = Array4::<F>::zeros((batch_size, out_channels, out_height, out_width));
let mut input_padded = Array4::<F>::zeros((
batch_size,
in_channels,
in_height + 2 * padding,
in_width + 2 * padding,
));
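// Copy the input into the zero-padded buffer (the zeros act as padding values).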
for b in 0..batch_size {
for c in 0..in_channels {
for h in 0..in_height {
for w in 0..in_width {
input_padded[[b, c, h + padding, w + padding]] = input[[b, c, h, w]];
}
}
}
}
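// Naive direct convolution: slide the kernel over every output position.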
for b in 0..batch_size {
for oc in 0..out_channels {
for oh in 0..out_height {
for ow in 0..out_width {
let h_start = oh * stride;
let w_start = ow * stride;
let mut sum = F::zero();
for ic in 0..in_channels {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
sum = sum
+ input_padded[[b, ic, h_start + kh, w_start + kw]]
* weight[[oc, ic, kh, kw]];
}
}
}
// Add the per-channel bias; bind it under a distinct name so it does not
// shadow the batch index `b` from the outer loop.
if let Some(bias_vec) = bias {
sum = sum + bias_vec[oc];
}
output[[b, oc, oh, ow]] = sum;
}
}
}
}
Ok(output)
}
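/// Applies 2D max pooling with a square `kernel_size` window.
///
/// Returns the pooled output together with the argmax index of each window,
/// flattened as `kh * kernel_size + kw`; the indices are consumed by
/// `max_pool2d_backward` to route gradients back to the winning positions.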
pub fn max_pool2d<F>(
input: &ArrayView4<F>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> Result<(Array4<F>, Array4<usize>)>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
let out_height = (in_height + 2 * padding - kernel_size) / stride + 1;
let out_width = (in_width + 2 * padding - kernel_size) / stride + 1;
let mut output = Array4::<F>::zeros((batch_size, channels, out_height, out_width));
let mut indices = Array4::<usize>::zeros((batch_size, channels, out_height, out_width));
let mut input_padded = Array4::<F>::zeros((
batch_size,
channels,
in_height + 2 * padding,
in_width + 2 * padding,
));
// Pre-fill the padded buffer with the smallest finite value so that padded
// positions can never win the max comparison against real inputs.
let min_val = F::min_value();
input_padded.fill(min_val);
for b in 0..batch_size {
for c in 0..channels {
for h in 0..in_height {
for w in 0..in_width {
input_padded[[b, c, h + padding, w + padding]] = input[[b, c, h, w]];
}
}
}
}
for b in 0..batch_size {
for c in 0..channels {
for oh in 0..out_height {
for ow in 0..out_width {
let h_start = oh * stride;
let w_start = ow * stride;
let mut max_val = min_val;
let mut max_idx = 0;
for kh in 0..kernel_size {
for kw in 0..kernel_size {
let h_idx = h_start + kh;
let w_idx = w_start + kw;
let val = input_padded[[b, c, h_idx, w_idx]];
if val > max_val {
max_val = val;
max_idx = kh * kernel_size + kw;
}
}
}
output[[b, c, oh, ow]] = max_val;
indices[[b, c, oh, ow]] = max_idx;
}
}
}
}
Ok((output, indices))
}
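/// Rearranges image patches into columns (im2col).
///
/// Every `kernel_height x kernel_width` patch of the (zero-padded) input is
/// unrolled into one column, producing a matrix of shape
/// `[channels * kernel_height * kernel_width, batch * out_height * out_width]`
/// that turns convolution into a single matrix multiplication.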
pub fn im2col<F>(
input: &ArrayView4<F>,
kernel_height: usize,
kernel_width: usize,
stride: usize,
padding: usize,
) -> Result<Array2<F>>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
let out_height = (in_height + 2 * padding - kernel_height) / stride + 1;
let out_width = (in_width + 2 * padding - kernel_width) / stride + 1;
let mut input_padded = Array4::<F>::zeros((
batch_size,
channels,
in_height + 2 * padding,
in_width + 2 * padding,
));
for b in 0..batch_size {
for c in 0..channels {
for h in 0..in_height {
for w in 0..in_width {
input_padded[[b, c, h + padding, w + padding]] = input[[b, c, h, w]];
}
}
}
}
let col_height = channels * kernel_height * kernel_width;
let col_width = batch_size * out_height * out_width;
let mut cols = Array2::<F>::zeros((col_height, col_width));
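// Unroll each receptive field into one column; `row_idx` walks channels,
// then kernel rows, then kernel columns.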
for b in 0..batch_size {
for oh in 0..out_height {
for ow in 0..out_width {
let col_idx = b * (out_height * out_width) + oh * out_width + ow;
let h_start = oh * stride;
let w_start = ow * stride;
let mut row_idx = 0;
for c in 0..channels {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
cols[[row_idx, col_idx]] =
input_padded[[b, c, h_start + kh, w_start + kw]];
row_idx += 1;
}
}
}
}
}
}
Ok(cols)
}
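/// Folds columns produced by `im2col` back into an image tensor (col2im).
///
/// Entries coming from overlapping patches are accumulated by addition, which
/// is the operation needed when back-propagating through `im2col`. Note that
/// this implementation assumes a stride of 1: the `_stride` argument is
/// accepted for API symmetry but is not used.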
#[allow(clippy::too_many_arguments)]
pub fn col2im<F>(
cols: &ArrayView2<F>,
batch_size: usize,
channels: usize,
height: usize,
width: usize,
kernel_height: usize,
kernel_width: usize,
_stride: usize,
padding: usize,
) -> Result<Array4<F>>
where
F: Float + Debug,
{
let out_height = height + 2 * padding - kernel_height + 1;
let out_width = width + 2 * padding - kernel_width + 1;
let expected_col_height = channels * kernel_height * kernel_width;
let expected_col_width = batch_size * out_height * out_width;
if cols.shape()[0] != expected_col_height || cols.shape()[1] != expected_col_width {
return Err(NeuralError::ShapeMismatch(format!(
"Column shape mismatch in col2im: expected [{}, {}], got [{}, {}]",
expected_col_height,
expected_col_width,
cols.shape()[0],
cols.shape()[1]
)));
}
let mut output_padded = Array4::<F>::zeros((
batch_size,
channels,
height + 2 * padding,
width + 2 * padding,
));
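// Scatter-add every column back into the padded output; entries from
// overlapping patches accumulate.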
for b in 0..batch_size {
for oh in 0..out_height {
for ow in 0..out_width {
let col_idx = b * (out_height * out_width) + oh * out_width + ow;
let h_start = oh;
let w_start = ow;
let mut row_idx = 0;
for c in 0..channels {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
output_padded[[b, c, h_start + kh, w_start + kw]] = output_padded
[[b, c, h_start + kh, w_start + kw]]
+ cols[[row_idx, col_idx]];
row_idx += 1;
}
}
}
}
}
}
let mut output = Array4::<F>::zeros((batch_size, channels, height, width));
for b in 0..batch_size {
for c in 0..channels {
for h in 0..height {
for w in 0..width {
output[[b, c, h, w]] = output_padded[[b, c, h + padding, w + padding]];
}
}
}
}
Ok(output)
}
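/// Applies adaptive average pooling so the spatial output size is exactly
/// `output_height x output_width`.
///
/// Each output cell averages the input region
/// `[oh * H / OH, (oh + 1) * H / OH)` by `[ow * W / OW, (ow + 1) * W / OW)`,
/// so neighbouring cells may cover regions of slightly different sizes. The
/// output dimensions must not exceed the input dimensions.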
pub fn adaptive_avg_pool2d<F>(
input: &ArrayView4<F>,
output_height: usize,
output_width: usize,
) -> Result<Array4<F>>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
if output_height > in_height || output_width > in_width {
return Err(NeuralError::InvalidArgument(
"Output dimensions must be less than or equal to input dimensions in adaptive_avg_pool2d".to_string()
));
}
let mut output = Array4::<F>::zeros((batch_size, channels, output_height, output_width));
for b in 0..batch_size {
for c in 0..channels {
for oh in 0..output_height {
for ow in 0..output_width {
let h_start = (oh * in_height) / output_height;
let h_end = ((oh + 1) * in_height) / output_height;
let w_start = (ow * in_width) / output_width;
let w_end = ((ow + 1) * in_width) / output_width;
let kernel_h = h_end - h_start;
let kernel_w = w_end - w_start;
let mut sum = F::zero();
for h in h_start..h_end {
for w in w_start..w_end {
sum = sum + input[[b, c, h, w]];
}
}
output[[b, c, oh, ow]] =
sum / F::from(kernel_h * kernel_w).expect("Failed to convert to float");
}
}
}
}
Ok(output)
}
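/// Computes the backward pass of `conv2d`.
///
/// Given the upstream gradient `dout` of shape
/// `[batch, out_channels, out_height, out_width]`, returns the gradients with
/// respect to the input, the weights, and the bias, in that order.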
pub fn conv2d_backward<F>(
dout: &ArrayView4<F>,
input: &ArrayView4<F>,
weight: &ArrayView4<F>,
stride: usize,
padding: usize,
) -> Result<(Array4<F>, Array4<F>, Array1<F>)>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let in_channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
let out_channels = dout.shape()[1];
let out_height = dout.shape()[2];
let out_width = dout.shape()[3];
let kernel_height = weight.shape()[2];
let kernel_width = weight.shape()[3];
let mut d_input = Array4::<F>::zeros(input.raw_dim());
let mut d_weight = Array4::<F>::zeros(weight.raw_dim());
let mut d_bias = Array1::<F>::zeros(out_channels);
let mut input_padded = Array4::<F>::zeros((
batch_size,
in_channels,
in_height + 2 * padding,
in_width + 2 * padding,
));
for b in 0..batch_size {
for c in 0..in_channels {
for h in 0..in_height {
for w in 0..in_width {
input_padded[[b, c, h + padding, w + padding]] = input[[b, c, h, w]];
}
}
}
}
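// Bias gradient: sum the upstream gradient over the batch and spatial dimensions.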
for oc in 0..out_channels {
for b in 0..batch_size {
for oh in 0..out_height {
for ow in 0..out_width {
d_bias[oc] = d_bias[oc] + dout[[b, oc, oh, ow]];
}
}
}
}
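// Weight gradient: correlate the upstream gradient with the padded input.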
for oc in 0..out_channels {
for ic in 0..in_channels {
for kh in 0..kernel_height {
for kw in 0..kernel_width {
let mut grad = F::zero();
for b in 0..batch_size {
for oh in 0..out_height {
for ow in 0..out_width {
let h_in = oh * stride + kh;
let w_in = ow * stride + kw;
grad = grad
+ dout[[b, oc, oh, ow]] * input_padded[[b, ic, h_in, w_in]];
}
}
}
d_weight[[oc, ic, kh, kw]] = grad;
}
}
}
}
let mut d_input_padded = Array4::<F>::zeros((
batch_size,
in_channels,
in_height + 2 * padding,
in_width + 2 * padding,
));
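// Input gradient: for every padded input position, accumulate contributions
// from every output position whose receptive field covers it.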
for b in 0..batch_size {
for ic in 0..in_channels {
for h_in in 0..(in_height + 2 * padding) {
for w_in in 0..(in_width + 2 * padding) {
let mut grad = F::zero();
for oc in 0..out_channels {
for kh in 0..kernel_height {
if h_in < kh {
continue;
}
let oh = (h_in - kh) / stride;
if oh >= out_height || (h_in - kh) % stride != 0 {
continue;
}
for kw in 0..kernel_width {
if w_in < kw {
continue;
}
let ow = (w_in - kw) / stride;
if ow >= out_width || (w_in - kw) % stride != 0 {
continue;
}
grad = grad + dout[[b, oc, oh, ow]] * weight[[oc, ic, kh, kw]];
}
}
}
d_input_padded[[b, ic, h_in, w_in]] = grad;
}
}
}
}
for b in 0..batch_size {
for c in 0..in_channels {
for h in 0..in_height {
for w in 0..in_width {
d_input[[b, c, h, w]] = d_input_padded[[b, c, h + padding, w + padding]];
}
}
}
}
Ok((d_input, d_weight, d_bias))
}
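/// Computes the backward pass of `max_pool2d`.
///
/// `indices` must be the argmax indices returned by the forward pass; the
/// upstream gradient in `dout` is routed back to the position that produced
/// each window's maximum.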
pub fn max_pool2d_backward<F>(
dout: &ArrayView4<F>,
input: &ArrayView4<F>,
indices: &ArrayView4<usize>,
kernel_size: usize,
stride: usize,
padding: usize,
) -> Result<Array4<F>>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
let out_height = dout.shape()[2];
let out_width = dout.shape()[3];
if dout.shape()[0] != batch_size || dout.shape()[1] != channels {
return Err(NeuralError::ShapeMismatch(format!(
"Gradient shape mismatch in max_pool2d_backward: dout shape {:?}, input shape {:?}",
dout.shape(),
input.shape()
)));
}
if indices.shape() != dout.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"Indices shape mismatch in max_pool2d_backward: indices shape {:?}, dout shape {:?}",
indices.shape(),
dout.shape()
)));
}
let mut d_input = Array4::<F>::zeros(input.raw_dim());
for b in 0..batch_size {
for c in 0..channels {
for oh in 0..out_height {
for ow in 0..out_width {
let h_start = oh * stride;
let w_start = ow * stride;
// Recover the within-window offsets of the maximum from the flattened index.
let idx = indices[[b, c, oh, ow]];
let kh = idx / kernel_size;
let kw = idx % kernel_size;
// Translate padded coordinates back to input coordinates, skipping positions
// that fall inside the padding region (and guarding against unsigned
// underflow when padding > 0).
if h_start + kh >= padding && w_start + kw >= padding {
let h_idx = h_start + kh - padding;
let w_idx = w_start + kw - padding;
if h_idx < in_height && w_idx < in_width {
d_input[[b, c, h_idx, w_idx]] =
d_input[[b, c, h_idx, w_idx]] + dout[[b, c, oh, ow]];
}
}
}
}
}
}
Ok(d_input)
}
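/// Computes the backward pass of `adaptive_avg_pool2d`.
///
/// Each upstream gradient value is divided by the size of its pooling region
/// and distributed uniformly over the input positions that were averaged.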
pub fn adaptive_avg_pool2d_backward<F>(
dout: &ArrayView4<F>,
input: &ArrayView4<F>,
output_height: usize,
output_width: usize,
) -> Result<Array4<F>>
where
F: Float + Debug,
{
let batch_size = input.shape()[0];
let channels = input.shape()[1];
let in_height = input.shape()[2];
let in_width = input.shape()[3];
if dout.shape() != [batch_size, channels, output_height, output_width] {
return Err(NeuralError::ShapeMismatch(
format!("Gradient shape mismatch in adaptive_avg_pool2d_backward: dout shape {:?}, expected [{}, {}, {}, {}]",
dout.shape(), batch_size, channels, output_height, output_width)
));
}
let mut d_input = Array4::<F>::zeros(input.raw_dim());
for b in 0..batch_size {
for c in 0..channels {
for oh in 0..output_height {
for ow in 0..output_width {
let h_start = (oh * in_height) / output_height;
let h_end = ((oh + 1) * in_height) / output_height;
let w_start = (ow * in_width) / output_width;
let w_end = ((ow + 1) * in_width) / output_width;
let kernel_h = h_end - h_start;
let kernel_w = w_end - w_start;
let grad_factor = dout[[b, c, oh, ow]]
/ F::from(kernel_h * kernel_w).expect("Failed to convert to float");
for h in h_start..h_end {
for w in w_start..w_end {
d_input[[b, c, h, w]] = d_input[[b, c, h, w]] + grad_factor;
}
}
}
}
}
}
Ok(d_input)
}