use scirs2_core::ndarray::{Array2, Array4, ArrayView4, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
#[allow(dead_code)]
/// Rearranges sliding convolution windows of a 4-D image tensor into columns
/// (the classic `im2col` transform used to express convolution as a GEMM).
///
/// # Arguments
/// * `input` - tensor of shape `(batch, channels, height, width)`.
/// * `kernelsize` - `(kernel_h, kernel_w)` window extent.
/// * `stride` - `(stride_h, stride_w)` step between windows.
/// * `padding` - `(padding_h, padding_w)` implicit zero padding per side.
/// * `dilation` - `(dilation_h, dilation_w)` spacing between kernel taps.
///
/// # Returns
/// Matrix of shape `(kernel_h * kernel_w * channels, output_h * output_w * batch)`.
/// The row index encodes `(channel, kernel_row, kernel_col)`; the column index
/// encodes `(batch, output_row, output_col)`, batch-major.
///
/// # Errors
/// `LinalgError::ShapeError` when a kernel/stride dimension is zero or the
/// dilated kernel does not fit inside the padded input.
pub fn im2col<F>(
    input: &ArrayView4<F>,
    kernelsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = input.dim();
    let (kernel_h, kernel_w) = kernelsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;
    // Reject degenerate parameters up front instead of panicking below on a
    // divide-by-zero or a usize underflow.
    if kernel_h == 0 || kernel_w == 0 || stride_h == 0 || stride_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Kernel size {kernelsize:?} and stride {stride:?} must be non-zero"
        )));
    }
    // Extent actually covered by the dilated kernel (e.g. k=3, d=2 spans 5).
    let eff_kernel_h = dilation_h * (kernel_h - 1) + 1;
    let eff_kernel_w = dilation_w * (kernel_w - 1) + 1;
    // checked_sub: the previous unchecked subtraction underflowed (panic in
    // debug builds, bogus huge dimension in release) whenever the dilated
    // kernel was larger than the padded input; the old `output == 0` check
    // was unreachable because of that.
    let (span_h, span_w) = match (
        (height + 2 * padding_h).checked_sub(eff_kernel_h),
        (width + 2 * padding_w).checked_sub(eff_kernel_w),
    ) {
        (Some(h), Some(w)) => (h, w),
        _ => {
            return Err(LinalgError::ShapeError(format!(
                "Kernel {kernelsize:?} with dilation {dilation:?} does not fit padded input ({height}, {width})"
            )))
        }
    };
    let output_h = span_h / stride_h + 1;
    let output_w = span_w / stride_w + 1;
    let mut cols = Array2::<F>::zeros((
        kernel_h * kernel_w * channels,
        output_h * output_w * batchsize,
    ));
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for kernel_row in 0..kernel_h {
                for kernel_col in 0..kernel_w {
                    // Offset of this kernel tap within the padded input.
                    let input_row_offset = kernel_row * dilation_h;
                    let input_col_offset = kernel_col * dilation_w;
                    let cols_idx =
                        channel_idx * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;
                    for output_row in 0..output_h {
                        for output_col in 0..output_w {
                            // Coordinates in the *padded* input frame.
                            let input_row = output_row * stride_h + input_row_offset;
                            let input_col = output_col * stride_w + input_col_offset;
                            let cols_pos = batch_idx * output_h * output_w
                                + output_row * output_w
                                + output_col;
                            if input_row < padding_h
                                || input_row >= height + padding_h
                                || input_col < padding_w
                                || input_col >= width + padding_w
                            {
                                // Tap falls in the zero-padding region.
                                cols[[cols_idx, cols_pos]] = F::zero();
                            } else {
                                cols[[cols_idx, cols_pos]] = input[[
                                    batch_idx,
                                    channel_idx,
                                    input_row - padding_h,
                                    input_col - padding_w,
                                ]];
                            }
                        }
                    }
                }
            }
        }
    }
    Ok(cols)
}
#[allow(dead_code)]
/// Inverse of [`im2col`]: folds a column matrix back into a 4-D image tensor.
///
/// Overlapping window positions are *averaged* (each pixel is divided by the
/// number of windows that touched it), so `col2im(im2col(x)) == x`. Note this
/// differs from the plain summation conventionally used when back-propagating
/// through `im2col`.
///
/// # Arguments
/// * `cols` - matrix produced by [`im2col`], shape
///   `(kernel_h * kernel_w * channels, output_h * output_w * batch)`.
/// * `outputshape` - target `(batch, channels, height, width)`.
/// * `kernelsize`, `stride`, `padding`, `dilation` - must mirror the
///   [`im2col`] call that produced `cols`.
///
/// # Errors
/// `LinalgError::ShapeError` when the computed window-position count is
/// invalid or `cols` does not have the expected shape.
pub fn col2im<F>(
    cols: &scirs2_core::ndarray::ArrayView2<F>,
    outputshape: (usize, usize, usize, usize),
    kernelsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = outputshape;
    let (kernel_h, kernel_w) = kernelsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;
    // Number of window positions implied by the im2col geometry.
    // NOTE(review): unchecked usize arithmetic — underflows (panics in debug)
    // if the dilated kernel exceeds the padded output size; confirm callers
    // validate geometry upstream.
    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }
    if cols.shape()[0] != kernel_h * kernel_w * channels
        || cols.shape()[1] != output_h * output_w * batchsize
    {
        return Err(LinalgError::ShapeError(format!(
            "Invalid cols shape: expected ({}, {}), got ({}, {})",
            kernel_h * kernel_w * channels,
            output_h * output_w * batchsize,
            cols.shape()[0],
            cols.shape()[1]
        )));
    }
    let mut output = Array4::<F>::zeros((batchsize, channels, height, width));
    // counts[p] = how many windows wrote to pixel p; used for averaging below.
    let mut counts = Array4::<usize>::zeros((batchsize, channels, height, width));
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for kernel_row in 0..kernel_h {
                for kernel_col in 0..kernel_w {
                    // Offset of this kernel tap within the padded frame.
                    let input_row_offset = kernel_row * dilation_h;
                    let input_col_offset = kernel_col * dilation_w;
                    let cols_idx =
                        channel_idx * kernel_h * kernel_w + kernel_row * kernel_w + kernel_col;
                    for output_row in 0..output_h {
                        for output_col in 0..output_w {
                            let input_row = output_row * stride_h + input_row_offset;
                            let input_col = output_col * stride_w + input_col_offset;
                            let cols_pos = batch_idx * output_h * output_w
                                + output_row * output_w
                                + output_col;
                            // Only scatter taps that land inside the unpadded
                            // image; padding contributions are discarded.
                            if input_row >= padding_h
                                && input_row < height + padding_h
                                && input_col >= padding_w
                                && input_col < width + padding_w
                            {
                                let output_row_idx = input_row - padding_h;
                                let output_col_idx = input_col - padding_w;
                                output[[batch_idx, channel_idx, output_row_idx, output_col_idx]] +=
                                    cols[[cols_idx, cols_pos]];
                                counts[[batch_idx, channel_idx, output_row_idx, output_col_idx]] +=
                                    1;
                            }
                        }
                    }
                }
            }
        }
    }
    // Average overlapping contributions so col2im inverts im2col exactly.
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    let count = counts[[batch_idx, channel_idx, h, w]];
                    if count > 0 {
                        output[[batch_idx, channel_idx, h, w]] /=
                            F::from(count).expect("Operation failed");
                    }
                }
            }
        }
    }
    Ok(output)
}
#[allow(dead_code)]
/// 2-D max pooling over a `(batch, channels, height, width)` tensor.
///
/// # Returns
/// A pair `(output, indices)` where `output` holds the pooled maxima of shape
/// `(batch, channels, output_h, output_w)` and `indices` stores, per pooled
/// element, the flattened position `row * width + col` (in *unpadded* input
/// coordinates) of the maximum, which [`max_pool2d_backward`] uses to route
/// gradients.
///
/// A window lying entirely in the zero-padding region yields
/// `F::neg_infinity()` with index `0`; this only occurs for degenerate
/// configurations where padding is at least the pool size.
///
/// # Errors
/// `LinalgError::ShapeError` when a pool/stride dimension is zero or the pool
/// window does not fit inside the padded input.
pub fn max_pool2d<F>(
    input: &ArrayView4<F>,
    poolsize: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> LinalgResult<(Array4<F>, Array4<usize>)>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = input.dim();
    let (pool_h, pool_w) = poolsize;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    // Guard degenerate parameters: avoids divide-by-zero and usize underflow.
    if pool_h == 0 || pool_w == 0 || stride_h == 0 || stride_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Pool size {poolsize:?} and stride {stride:?} must be non-zero"
        )));
    }
    // checked_sub: the previous unchecked subtraction underflowed (panic in
    // debug) when the pool window exceeded the padded input, making the old
    // `output == 0` check unreachable.
    let (span_h, span_w) = match (
        (height + 2 * padding_h).checked_sub(pool_h),
        (width + 2 * padding_w).checked_sub(pool_w),
    ) {
        (Some(h), Some(w)) => (h, w),
        _ => {
            return Err(LinalgError::ShapeError(format!(
                "Pool size {poolsize:?} does not fit padded input ({height}, {width})"
            )))
        }
    };
    let output_h = span_h / stride_h + 1;
    let output_w = span_w / stride_w + 1;
    let mut output = Array4::<F>::zeros((batchsize, channels, output_h, output_w));
    let mut indices = Array4::<usize>::zeros((batchsize, channels, output_h, output_w));
    for batch_idx in 0..batchsize {
        for channel_idx in 0..channels {
            for output_row in 0..output_h {
                for output_col in 0..output_w {
                    // Window origin in padded coordinates.
                    let start_h = output_row * stride_h;
                    let start_w = output_col * stride_w;
                    let mut max_val = F::neg_infinity();
                    let mut max_idx = 0;
                    for pool_row in 0..pool_h {
                        for pool_col in 0..pool_w {
                            let input_row = start_h + pool_row;
                            let input_col = start_w + pool_col;
                            // Skip taps that land in the zero padding.
                            if input_row >= padding_h
                                && input_row < height + padding_h
                                && input_col >= padding_w
                                && input_col < width + padding_w
                            {
                                let input_row_idx = input_row - padding_h;
                                let input_col_idx = input_col - padding_w;
                                let val =
                                    input[[batch_idx, channel_idx, input_row_idx, input_col_idx]];
                                if val > max_val {
                                    max_val = val;
                                    // Flattened unpadded position, consumed by
                                    // max_pool2d_backward.
                                    max_idx = input_row_idx * width + input_col_idx;
                                }
                            }
                        }
                    }
                    output[[batch_idx, channel_idx, output_row, output_col]] = max_val;
                    indices[[batch_idx, channel_idx, output_row, output_col]] = max_idx;
                }
            }
        }
    }
    Ok((output, indices))
}
#[allow(dead_code)]
/// Backward pass of 2-D max pooling: scatters upstream gradients back to the
/// input positions that produced each pooled maximum.
///
/// `indices` holds, per pooled element, the flattened `row * width + col`
/// position (within the unpadded input) recorded by [`max_pool2d`]. Gradients
/// are accumulated with `+=`, so overlapping windows add up. Entries whose
/// decoded position falls outside `inputshape` are skipped.
///
/// # Errors
/// `LinalgError::ShapeError` when `grad_output` and `indices` differ in shape.
pub fn max_pool2d_backward<F>(
    grad_output: &ArrayView4<F>,
    indices: &scirs2_core::ndarray::ArrayView4<usize>,
    inputshape: (usize, usize, usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, channels, height, width) = inputshape;
    let grad_dims = grad_output.dim();
    let idx_dims = indices.dim();
    // Both 4-tuples must agree element-wise before we walk them in lockstep.
    if grad_dims != idx_dims {
        return Err(LinalgError::ShapeError(format!(
            "Shape mismatch between grad_output ({}, {}, {}, {}) and indices ({}, {}, {}, {})",
            grad_dims.0,
            grad_dims.1,
            grad_dims.2,
            grad_dims.3,
            idx_dims.0,
            idx_dims.1,
            idx_dims.2,
            idx_dims.3
        )));
    }
    let (n_batch, n_chan, n_rows, n_cols) = grad_dims;
    let mut grad_input = Array4::<F>::zeros((batchsize, channels, height, width));
    for b in 0..n_batch {
        for c in 0..n_chan {
            for row in 0..n_rows {
                for col in 0..n_cols {
                    // Decode the flattened argmax position back to 2-D.
                    let flat = indices[[b, c, row, col]];
                    let src_h = flat / width;
                    let src_w = flat % width;
                    // Guard against indices that fall outside inputshape.
                    if src_h < height && src_w < width {
                        grad_input[[b, c, src_h, src_w]] += grad_output[[b, c, row, col]];
                    }
                }
            }
        }
    }
    Ok(grad_input)
}
#[allow(dead_code)]
/// Precomputes flat index tuples describing every multiply-accumulate of a
/// 2-D convolution (dilation is not supported here; an effective dilation of
/// 1 is assumed by `ih = oh * stride_h + kh`).
///
/// For each contributing (output, input, kernel) triple, five `usize` values
/// are appended, in order:
/// 1. flat output index (`b, oc, oh, ow` row-major),
/// 2. flat input index (`b, ic, ih, iw` row-major),
/// 3. flat kernel index (`oc, ic, kh, kw` row-major),
/// 4. spatial position `oh * output_w + ow`,
/// 5. the output channel `oc`.
///
/// Taps that fall into the zero-padding region contribute nothing and are
/// skipped, so the returned array is truncated to the entries actually
/// written (its length is a multiple of 5).
///
/// # Errors
/// `LinalgError::ShapeError` when the computed output dimensions are invalid.
pub fn compute_conv_indices(
    inputshape: (usize, usize, usize, usize),
    kernelshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> LinalgResult<scirs2_core::ndarray::Array1<usize>> {
    let (batchsize, _in_channels, height, width) = inputshape;
    let (out_channels_, in_channels, kernel_h, kernel_w) = kernelshape;
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    // NOTE(review): unchecked usize arithmetic — underflows if the kernel
    // exceeds the padded input; confirm callers validate geometry upstream.
    let output_h = ((height + 2 * padding_h - kernel_h) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - kernel_w) / stride_w) + 1;
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }
    // Worst-case capacity: every tap in-bounds; trimmed to `idx` at the end.
    let total_elements =
        batchsize * out_channels_ * output_h * output_w * in_channels * kernel_h * kernel_w;
    let mut indices = scirs2_core::ndarray::Array1::<usize>::zeros(total_elements * 5);
    let mut idx = 0;
    for b in 0..batchsize {
        for oc in 0..out_channels_ {
            for oh in 0..output_h {
                for ow in 0..output_w {
                    for ic in 0..in_channels {
                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                // Tap position in the padded input frame.
                                let ih = oh * stride_h + kh;
                                let iw = ow * stride_w + kw;
                                if ih >= padding_h
                                    && ih < height + padding_h
                                    && iw >= padding_w
                                    && iw < width + padding_w
                                {
                                    let real_ih = ih - padding_h;
                                    let real_iw = iw - padding_w;
                                    let out_idx = b * out_channels_ * output_h * output_w
                                        + oc * output_h * output_w
                                        + oh * output_w
                                        + ow;
                                    let in_idx = b * in_channels * height * width
                                        + ic * height * width
                                        + real_ih * width
                                        + real_iw;
                                    let kernel_idx = oc * in_channels * kernel_h * kernel_w
                                        + ic * kernel_h * kernel_w
                                        + kh * kernel_w
                                        + kw;
                                    indices[idx] = out_idx;
                                    indices[idx + 1] = in_idx;
                                    indices[idx + 2] = kernel_idx;
                                    indices[idx + 3] = oh * output_w + ow;
                                    indices[idx + 4] = oc;
                                    idx += 5;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    // Drop the unused tail reserved for padding taps that never materialized.
    let indices = indices.slice(scirs2_core::ndarray::s![0..idx]).to_owned();
    Ok(indices)
}
#[allow(dead_code)]
/// 2-D convolution implemented via [`im2col`] plus a single matrix multiply.
///
/// # Arguments
/// * `input` - `(batch, in_channels, height, width)`.
/// * `kernel` - `(out_channels, in_channels, kernel_h, kernel_w)`.
/// * `bias` - optional per-output-channel bias of length `out_channels`.
/// * `stride`, `padding`, `dilation` - convolution geometry.
///
/// # Returns
/// Output tensor of shape `(batch, out_channels, output_h, output_w)`.
///
/// # Errors
/// `LinalgError::ShapeError` on channel/bias mismatches, invalid output
/// geometry, or a failed reshape.
pub fn conv2d_im2col<F>(
    input: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    bias: Option<scirs2_core::ndarray::ArrayView1<F>>,
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, height, width) = input.dim();
    let (out_channels_, k_in_channels, kernel_h, kernel_w) = kernel.dim();
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels ({in_channels}) must match kernel in_channels ({k_in_channels})"
        )));
    }
    if let Some(b) = bias {
        if b.len() != out_channels_ {
            return Err(LinalgError::ShapeError(format!(
                "Bias length ({}) must match out_channels_ ({})",
                b.len(),
                out_channels_
            )));
        }
    }
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (dilation_h, dilation_w) = dilation;
    // Standard convolution output-size formula with dilation.
    // NOTE(review): unchecked usize arithmetic — underflows if the dilated
    // kernel exceeds the padded input; im2col below hits the same expression.
    let output_h = ((height + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1;
    let output_w = ((width + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1;
    if output_h == 0 || output_w == 0 {
        return Err(LinalgError::ShapeError(format!(
            "Invalid output dimensions: ({output_h}, {output_w})"
        )));
    }
    // Unfold input windows into columns, flatten the kernel, then contract:
    // (OC, C*KH*KW) x (C*KH*KW, B*OH*OW) -> (OC, B*OH*OW).
    let cols = im2col(input, (kernel_h, kernel_w), stride, padding, dilation)?;
    let flat_kernel = (*kernel)
        .into_shape_with_order((out_channels_, in_channels * kernel_h * kernel_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;
    let output_2d = flat_kernel.dot(&cols);
    // im2col columns are batch-major, so reshape channel-first and then swap
    // the batch and channel axes to obtain (B, OC, OH, OW).
    let mut output = output_2d
        .into_shape_with_order((out_channels_, batchsize, output_h, output_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;
    output = output.permuted_axes([1, 0, 2, 3]);
    if let Some(b) = bias {
        // Broadcast the per-channel bias over batch and spatial dimensions.
        for batch_idx in 0..batchsize {
            for oc in 0..out_channels_ {
                for h in 0..output_h {
                    for w in 0..output_w {
                        output[[batch_idx, oc, h, w]] += b[oc];
                    }
                }
            }
        }
    }
    Ok(output)
}
#[allow(dead_code)]
/// Gradient of a 2-D convolution with respect to its input.
///
/// Implemented as a "full" correlation of `grad_output` with the spatially
/// flipped, channel-transposed kernel, using padding `kernel - 1 - padding`.
/// NOTE(review): this identity is exact for `stride == (1, 1)`; for larger
/// strides the gradient map would first need zero-dilation (fractionally
/// strided convolution) — confirm callers only use unit stride here.
///
/// # Errors
/// `LinalgError::ShapeError` on batch/channel mismatches, or when
/// `padding >= kernelsize` (the flipped padding would underflow).
pub fn conv2d_backward_input<F>(
    grad_output: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    inputshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, out_channels, _output_h, _output_w) = grad_output.dim();
    let (k_out_channels, in_channels, kernel_h, kernel_w) = kernel.dim();
    let (i_batchsize, i_in_channels, _height, _width) = inputshape;
    if batchsize != i_batchsize {
        return Err(LinalgError::ShapeError(format!(
            "Batch size mismatch: grad_output ({batchsize}) vs inputshape ({i_batchsize})"
        )));
    }
    if out_channels != k_out_channels {
        return Err(LinalgError::ShapeError(format!(
            "Output channels mismatch: grad_output ({out_channels}) vs kernel ({k_out_channels})"
        )));
    }
    if in_channels != i_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: kernel ({in_channels}) vs inputshape ({i_in_channels})"
        )));
    }
    // Swap the in/out channel axes and flip the kernel spatially (180°
    // rotation): correlating with the flipped kernel realizes the transposed
    // convolution that propagates gradients back to the input.
    let mut kernel_transposed = Array4::<F>::zeros((in_channels, out_channels, kernel_h, kernel_w));
    for oc in 0..out_channels {
        for ic in 0..in_channels {
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    kernel_transposed[[ic, oc, kernel_h - 1 - kh, kernel_w - 1 - kw]] =
                        kernel[[oc, ic, kh, kw]];
                }
            }
        }
    }
    let (padding_h, padding_w) = padding;
    // "Full" correlation padding, checked so that padding >= kernel yields an
    // error instead of a usize underflow panic.
    let (pad_h, pad_w) = match (
        kernel_h.checked_sub(padding_h + 1),
        kernel_w.checked_sub(padding_w + 1),
    ) {
        (Some(h), Some(w)) => (h, w),
        _ => {
            return Err(LinalgError::ShapeError(format!(
                "Padding {padding:?} must be smaller than kernel ({kernel_h}, {kernel_w})"
            )))
        }
    };
    // Bug fix: stride and dilation were previously passed in each other's
    // argument positions (conv2d_im2col expects stride, padding, dilation).
    conv2d_im2col(
        grad_output,
        &kernel_transposed.view(),
        None,
        stride,
        (pad_h, pad_w),
        dilation,
    )
}
#[allow(dead_code)]
/// Gradient of a 2-D convolution with respect to its kernel.
///
/// Computed as `dL/dW = dL/dY · im2col(X)^T` and reshaped to
/// `(out_channels, in_channels, kernel_h, kernel_w)`. Batch and spatial
/// positions are summed over in a single GEMM.
///
/// # Errors
/// `LinalgError::ShapeError` on batch/channel mismatches, invalid geometry,
/// or a failed reshape.
pub fn conv2d_backward_kernel<F>(
    input: &ArrayView4<F>,
    grad_output: &ArrayView4<F>,
    kernelshape: (usize, usize, usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, _height, _width) = input.dim();
    let (go_batchsize, out_channels_, output_h, output_w) = grad_output.dim();
    let (k_out_channels, k_in_channels, kernel_h, kernel_w) = kernelshape;
    if batchsize != go_batchsize {
        return Err(LinalgError::ShapeError(format!(
            "Batch size mismatch: input ({batchsize}) vs grad_output ({go_batchsize})"
        )));
    }
    if out_channels_ != k_out_channels {
        return Err(LinalgError::ShapeError(format!(
            "Output channels mismatch: grad_output ({out_channels_}) vs kernelshape ({k_out_channels})"
        )));
    }
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: input ({in_channels}) vs kernelshape ({k_in_channels})"
        )));
    }
    let cols = im2col(input, (kernel_h, kernel_w), stride, padding, dilation)?;
    // im2col columns are ordered batch-major (batch * OH*OW + oh * OW + ow).
    // Move the channel axis first and flatten (batch, oh, ow) together so the
    // GEMM contracts over both batch and spatial positions. The previous
    // single reshape to (batch * out_channels, OH * OW) only lined up with
    // that column layout when batchsize == 1 and produced a dimension
    // mismatch (OH*OW vs OH*OW*batch) otherwise.
    let grad_2d = grad_output
        .view()
        .permuted_axes([1, 0, 2, 3])
        .to_owned() // materialize in standard layout so the reshape succeeds
        .into_shape_with_order((out_channels_, batchsize * output_h * output_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;
    // (OC, B*OH*OW) x (B*OH*OW, C*KH*KW) -> (OC, C*KH*KW)
    let grad_kernel_flat = grad_2d.dot(&cols.t());
    let grad_kernel = grad_kernel_flat
        .into_shape_with_order((out_channels_, in_channels, kernel_h, kernel_w))
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;
    Ok(grad_kernel)
}
#[allow(dead_code)]
/// Gradient of a 2-D convolution with respect to its bias.
///
/// Sums `grad_output` (shape `(batch, out_channels, out_h, out_w)`) over the
/// batch and spatial axes, yielding one accumulated value per output channel.
pub fn conv2d_backward_bias<F>(
    grad_output: &ArrayView4<F>,
) -> LinalgResult<scirs2_core::ndarray::Array1<F>>
where
    F: Float + NumAssign + Sum + Zero,
{
    let (_, out_channels_, _, _) = grad_output.dim();
    let mut grad_bias = scirs2_core::ndarray::Array1::<F>::zeros(out_channels_);
    // Every element contributes to the bias gradient of its channel (axis 1).
    for ((_, channel, _, _), &grad) in grad_output.indexed_iter() {
        grad_bias[channel] += grad;
    }
    Ok(grad_bias)
}
#[allow(dead_code)]
/// Transposed ("deconvolution") 2-D convolution, the upsampling counterpart
/// of a strided convolution.
///
/// # Arguments
/// * `input` - `(batch, in_channels, height, width)`.
/// * `kernel` - `(in_channels, out_channels, kernel_h, kernel_w)`; note the
///   reversed channel order compared with [`conv2d_im2col`].
/// * `bias` - optional per-output-channel bias of length `out_channels`.
/// * `output_padding` - extra size added to one side of each output dim.
///
/// Output spatial size per dimension is
/// `(in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1`.
/// NOTE(review): that formula uses unchecked usize arithmetic and underflows
/// when `2 * padding` exceeds the remaining terms — callers must pass sane
/// geometry; consider a checked computation.
///
/// # Errors
/// `LinalgError::ShapeError` on channel or bias-length mismatches.
pub fn conv_transpose2d<F>(
    input: &ArrayView4<F>,
    kernel: &ArrayView4<F>,
    bias: Option<scirs2_core::ndarray::ArrayView1<F>>,
    stride: (usize, usize),
    padding: (usize, usize),
    output_padding: (usize, usize),
    dilation: (usize, usize),
) -> LinalgResult<Array4<F>>
where
    F: Float + NumAssign + Sum + Zero + ScalarOperand,
{
    let (batchsize, in_channels, height, width) = input.dim();
    let (k_in_channels, out_channels_, kernel_h, kernel_w) = kernel.dim();
    if in_channels != k_in_channels {
        return Err(LinalgError::ShapeError(format!(
            "Input channels mismatch: input ({in_channels}) vs kernel ({k_in_channels})"
        )));
    }
    if let Some(b) = bias {
        if b.len() != out_channels_ {
            return Err(LinalgError::ShapeError(format!(
                "Bias length ({}) must match out_channels_ ({})",
                b.len(),
                out_channels_
            )));
        }
    }
    let (stride_h, stride_w) = stride;
    let (padding_h, padding_w) = padding;
    let (output_padding_h, output_padding_w) = output_padding;
    let (dilation_h, dilation_w) = dilation;
    let output_h = (height - 1) * stride_h - 2 * padding_h
        + dilation_h * (kernel_h - 1)
        + output_padding_h
        + 1;
    let output_w =
        (width - 1) * stride_w - 2 * padding_w + dilation_w * (kernel_w - 1) + output_padding_w + 1;
    let mut output = Array4::<F>::zeros((batchsize, out_channels_, output_h, output_w));
    // Scatter-style implementation: each input element distributes its value,
    // weighted by the kernel, over a stride-spaced patch of the output.
    for b in 0..batchsize {
        for oc in 0..out_channels_ {
            for ic in 0..in_channels {
                for h in 0..height {
                    for w in 0..width {
                        let input_val = input[[b, ic, h, w]];
                        for kh in 0..kernel_h {
                            for kw in 0..kernel_w {
                                // Signed arithmetic: the padding shift can
                                // push a tap to a negative coordinate, which
                                // simply falls outside the output.
                                let out_h = h as isize * stride_h as isize
                                    + kh as isize * dilation_h as isize
                                    - padding_h as isize;
                                let out_w = w as isize * stride_w as isize
                                    + kw as isize * dilation_w as isize
                                    - padding_w as isize;
                                if out_h >= 0 && out_w >= 0 {
                                    let out_h = out_h as usize;
                                    let out_w = out_w as usize;
                                    if out_h < output_h && out_w < output_w {
                                        output[[b, oc, out_h, out_w]] +=
                                            input_val * kernel[[ic, oc, kh, kw]];
                                    }
                                }
                            }
                        }
                    }
                }
            }
            // Add the per-channel bias once per (batch, channel) plane.
            if let Some(b_val) = bias.map(|b| b[oc]) {
                for h in 0..output_h {
                    for w in 0..output_w {
                        output[[b, oc, h, w]] += b_val;
                    }
                }
            }
        }
    }
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::{Array1, Array4};
    // im2col of a 3x3 ramp (values 0..9) with 2x2 windows, stride 1, no
    // padding: spot-check the first two columns.
    #[test]
    fn test_im2col_basic() {
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }
        let cols = im2col(&input.view(), (2, 2), (1, 1), (0, 0), (1, 1)).expect("Operation failed");
        assert_eq!(cols.shape(), &[4, 4]);
        // Column 0 = window anchored at input (0, 0).
        assert_eq!(cols[[0, 0]], 0.0);
        assert_eq!(cols[[1, 0]], 1.0);
        assert_eq!(cols[[2, 0]], 3.0);
        assert_eq!(cols[[3, 0]], 4.0);
        // Column 1 = window anchored at input (0, 1).
        assert_eq!(cols[[0, 1]], 1.0);
        assert_eq!(cols[[1, 1]], 2.0);
        assert_eq!(cols[[2, 1]], 4.0);
        assert_eq!(cols[[3, 1]], 5.0);
    }
    // 3x3 window with padding 1 on a 2x2 input: taps of the first window that
    // fall in the padding must read as zero; its bottom-right tap sees 3.0.
    #[test]
    fn test_im2col_with_padding() {
        let mut input = Array4::<f32>::zeros((1, 1, 2, 2));
        input[[0, 0, 0, 0]] = 0.0;
        input[[0, 0, 0, 1]] = 1.0;
        input[[0, 0, 1, 0]] = 2.0;
        input[[0, 0, 1, 1]] = 3.0;
        let cols = im2col(&input.view(), (3, 3), (1, 1), (1, 1), (1, 1)).expect("Operation failed");
        assert_eq!(cols.shape(), &[9, 4]);
        assert_eq!(cols[[0, 0]], 0.0);
        assert_eq!(cols[[2, 0]], 0.0);
        assert_eq!(cols[[6, 0]], 0.0);
        assert_eq!(cols[[8, 0]], 3.0);
        // Center tap lands on input (0, 0), whose value happens to be 0.0.
        assert_eq!(cols[[4, 0]], 0.0);
    }
    // col2im must exactly reconstruct the input: overlapping contributions
    // are averaged, making col2im the inverse of im2col.
    #[test]
    fn test_col2im_basic() {
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }
        let cols = im2col(&input.view(), (2, 2), (1, 1), (0, 0), (1, 1)).expect("Operation failed");
        let output = col2im(&cols.view(), (1, 1, 3, 3), (2, 2), (1, 1), (0, 0), (1, 1))
            .expect("Operation failed");
        assert_eq!(output.shape(), input.shape());
        assert_relative_eq!(output[[0, 0, 0, 0]], input[[0, 0, 0, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 0, 1]], input[[0, 0, 0, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 0, 2]], input[[0, 0, 0, 2]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 0]], input[[0, 0, 1, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 1]], input[[0, 0, 1, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 1, 2]], input[[0, 0, 1, 2]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 0]], input[[0, 0, 2, 0]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 1]], input[[0, 0, 2, 1]], epsilon = 1e-5);
        assert_relative_eq!(output[[0, 0, 2, 2]], input[[0, 0, 2, 2]], epsilon = 1e-5);
    }
    // Non-overlapping 2x2 max pooling of a 4x4 ramp picks the bottom-right
    // element of each window; indices store the flattened row*4+col position.
    #[test]
    fn test_max_pool2d() {
        let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
        for h in 0..4 {
            for w in 0..4 {
                input[[0, 0, h, w]] = (h * 4 + w) as f32;
            }
        }
        let (output, indices) =
            max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).expect("Operation failed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        assert_eq!(output[[0, 0, 0, 0]], 5.0);
        assert_eq!(output[[0, 0, 0, 1]], 7.0);
        assert_eq!(output[[0, 0, 1, 0]], 13.0);
        assert_eq!(output[[0, 0, 1, 1]], 15.0);
        assert_eq!(indices[[0, 0, 0, 0]], 5);
        assert_eq!(indices[[0, 0, 0, 1]], 7);
        assert_eq!(indices[[0, 0, 1, 0]], 13);
        assert_eq!(indices[[0, 0, 1, 1]], 15);
    }
    // Gradient of max pooling lands only on the argmax positions (5, 7, 13,
    // 15 in flattened coordinates) and is zero everywhere else.
    #[test]
    fn test_max_pool2d_backward() {
        let mut input = Array4::<f32>::zeros((1, 1, 4, 4));
        for h in 0..4 {
            for w in 0..4 {
                input[[0, 0, h, w]] = (h * 4 + w) as f32;
            }
        }
        let (_output, indices) =
            max_pool2d(&input.view(), (2, 2), (2, 2), (0, 0)).expect("Operation failed");
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));
        let grad_input = max_pool2d_backward(&grad_output.view(), &indices.view(), (1, 1, 4, 4))
            .expect("Operation failed");
        assert_eq!(grad_input.shape(), input.shape());
        for h in 0..4 {
            for w in 0..4 {
                let pos = h * 4 + w;
                let expected = if pos == 5 || pos == 7 || pos == 13 || pos == 15 {
                    1.0
                } else {
                    0.0
                };
                assert_eq!(grad_input[[0, 0, h, w]], expected);
            }
        }
    }
    // Kernel [[1,0],[0,0]] copies each window's top-left element, so the
    // output equals the input's top-left 2x2 corner.
    #[test]
    fn test_conv2d_im2col_basic() {
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 0.0;
        kernel[[0, 0, 1, 0]] = 0.0;
        kernel[[0, 0, 1, 1]] = 0.0;
        let output = conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1))
            .expect("Operation failed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        assert_eq!(output[[0, 0, 0, 0]], 0.0);
        assert_eq!(output[[0, 0, 0, 1]], 1.0);
        assert_eq!(output[[0, 0, 1, 0]], 3.0);
        assert_eq!(output[[0, 0, 1, 1]], 4.0);
    }
    // Same as the basic case plus a constant bias of 10 per output element.
    #[test]
    fn test_conv2d_im2col_with_bias() {
        let mut input = Array4::<f32>::zeros((1, 1, 3, 3));
        for h in 0..3 {
            for w in 0..3 {
                input[[0, 0, h, w]] = (h * 3 + w) as f32;
            }
        }
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 0.0;
        kernel[[0, 0, 1, 0]] = 0.0;
        kernel[[0, 0, 1, 1]] = 0.0;
        let bias = Array1::<f32>::from_elem(1, 10.0);
        let output = conv2d_im2col(
            &input.view(),
            &kernel.view(),
            Some(bias.view()),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .expect("Operation failed");
        assert_eq!(output.shape(), &[1, 1, 2, 2]);
        assert_eq!(output[[0, 0, 0, 0]], 10.0);
        assert_eq!(output[[0, 0, 0, 1]], 11.0);
        assert_eq!(output[[0, 0, 1, 0]], 13.0);
        assert_eq!(output[[0, 0, 1, 1]], 14.0);
    }
    // Hand-checked input gradient for a 2x2 kernel [[1,2],[3,4]] with an
    // all-ones upstream gradient on a 2x2 output.
    #[test]
    fn test_conv2d_backward_input() {
        let input = Array4::<f32>::zeros((1, 1, 3, 3));
        let mut kernel = Array4::<f32>::zeros((1, 1, 2, 2));
        kernel[[0, 0, 0, 0]] = 1.0;
        kernel[[0, 0, 0, 1]] = 2.0;
        kernel[[0, 0, 1, 0]] = 3.0;
        kernel[[0, 0, 1, 1]] = 4.0;
        let _output = conv2d_im2col(&input.view(), &kernel.view(), None, (1, 1), (0, 0), (1, 1))
            .expect("Operation failed");
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));
        let grad_input = conv2d_backward_input(
            &grad_output.view(),
            &kernel.view(),
            (1, 1, 3, 3),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .expect("Operation failed");
        assert_eq!(grad_input.shape(), input.shape());
        assert_eq!(grad_input[[0, 0, 0, 0]], 1.0);
        assert_eq!(grad_input[[0, 0, 0, 1]], 3.0);
        assert_eq!(grad_input[[0, 0, 1, 0]], 4.0);
        assert_eq!(grad_input[[0, 0, 1, 1]], 10.0);
    }
    // All-ones input and gradient: each kernel tap accumulates over all four
    // window positions, so every gradient entry is 4.
    #[test]
    fn test_conv2d_backward_kernel() {
        let input = Array4::<f32>::ones((1, 1, 3, 3));
        let grad_output = Array4::<f32>::ones((1, 1, 2, 2));
        let grad_kernel = conv2d_backward_kernel(
            &input.view(),
            &grad_output.view(),
            (1, 1, 2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .expect("Operation failed");
        assert_eq!(grad_kernel.shape(), &[1, 1, 2, 2]);
        assert_eq!(grad_kernel[[0, 0, 0, 0]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 0, 1]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 1, 0]], 4.0);
        assert_eq!(grad_kernel[[0, 0, 1, 1]], 4.0);
    }
    // Bias gradient sums grad_output over batch and space: 2 * 2 * 2 = 8 per
    // channel.
    #[test]
    fn test_conv2d_backward_bias() {
        let mut grad_output = Array4::<f32>::zeros((2, 3, 2, 2));
        for b in 0..2 {
            for c in 0..3 {
                for h in 0..2 {
                    for w in 0..2 {
                        grad_output[[b, c, h, w]] = 1.0;
                    }
                }
            }
        }
        let grad_bias = conv2d_backward_bias(&grad_output.view()).expect("Operation failed");
        assert_eq!(grad_bias.shape(), &[3]);
        assert_eq!(grad_bias[0], 8.0);
        assert_eq!(grad_bias[1], 8.0);
        assert_eq!(grad_bias[2], 8.0);
    }
    // Stride-2 transposed conv with a single top-left kernel tap and padding
    // 1: only the input element at (1, 1) scatters inside the 3x3 output, at
    // its center.
    #[test]
    fn test_conv_transpose2d() {
        let input = Array4::<f32>::ones((1, 1, 2, 2));
        let mut kernel = Array4::<f32>::zeros((1, 1, 3, 3));
        kernel[[0, 0, 0, 0]] = 1.0;
        let output = conv_transpose2d(
            &input.view(),
            &kernel.view(),
            None,
            (2, 2),
            (1, 1),
            (0, 0),
            (1, 1),
        )
        .expect("Operation failed");
        assert_eq!(output.shape(), &[1, 1, 3, 3]);
        assert_eq!(output[[0, 0, 0, 0]], 0.0);
        assert_eq!(output[[0, 0, 0, 1]], 0.0);
        assert_eq!(output[[0, 0, 0, 2]], 0.0);
        assert_eq!(output[[0, 0, 1, 0]], 0.0);
        assert_eq!(output[[0, 0, 1, 1]], 1.0);
        assert_eq!(output[[0, 0, 1, 2]], 0.0);
        assert_eq!(output[[0, 0, 2, 0]], 0.0);
        assert_eq!(output[[0, 0, 2, 1]], 0.0);
        assert_eq!(output[[0, 0, 2, 2]], 0.0);
    }
}