burn-ndarray 0.18.0

Ndarray backend for the Burn framework
Documentation
use crate::{element::FloatNdArrayElement, sharing::UnsafeSharedRef, tensor::NdArrayTensor};
use burn_common::{iter_range_par, run_par};

use burn_tensor::{ElementConversion, TensorMetadata};
use ndarray::Array4;

/// 2D average pooling over a 4-D tensor indexed as `[batch, channel, height, width]`.
///
/// Each output element `[b, c, oh, ow]` averages the input values covered by a
/// `kernel_size` window whose top-left corner sits at `(oh * stride_h, ow * stride_w)`
/// in the implicitly zero-padded input. Padding positions never add to the sum.
///
/// The divisor depends on `count_include_pad`:
/// * `true`: always `kernel_h * kernel_w`.
/// * `false`: the number of real (non-padding) input elements under the window.
///   If a window covers padding only, this is `0 / 0` (NaN for IEEE float types).
///
/// The output spatial size per axis is `(in + 2 * padding - kernel) / stride + 1`.
/// NOTE: kernel/stride/padding are not validated here. A kernel larger than the padded
/// input underflows that `usize` arithmetic, and a zero stride divides by zero.
pub(crate) fn avg_pool2d<E: FloatNdArrayElement>(
    x: NdArrayTensor<E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
) -> NdArrayTensor<E> {
    let [kernel_height, kernel_width] = kernel_size;
    let [padding_height, padding_width] = padding;
    let [stride_height, stride_width] = stride;
    let [batch_size, channels, x_height, x_width] = x.shape().dims();

    let out_height = ((x_height + 2 * padding_height - kernel_height) / stride_height) + 1;
    let out_width = ((x_width + 2 * padding_width - kernel_width) / stride_width) + 1;

    let x = x.array;

    let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), 0.elem());
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // SAFETY: each task `k` maps to a unique `(b, c)` pair (`k / channels`,
        // `k % channels`) and writes only to `output[[b, c, .., ..]]`, so the mutable
        // accesses of concurrent tasks never overlap. (Assumes `UnsafeSharedRef::get`
        // hands every task a `&mut` to the same `output` array.)
        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
            let b = k / channels;
            let c = k % channels;

            let output = unsafe_shared_out.get();

            for oh in 0..out_height {
                // Window rows in padded coordinates, clipped to the real input rows
                // `[padding_height, x_height + padding_height)`. The clipping depends only
                // on `oh`, so it is computed once per output row.
                let ih_start = usize::max(oh * stride_height, padding_height);
                let ih_end =
                    usize::min(oh * stride_height + kernel_height, x_height + padding_height);
                // `saturating_sub`: a window lying entirely in the padding covers 0 rows.
                let rows = ih_end.saturating_sub(ih_start);

                for ow in 0..out_width {
                    let iw_start = usize::max(ow * stride_width, padding_width);
                    let iw_end =
                        usize::min(ow * stride_width + kernel_width, x_width + padding_width);
                    let cols = iw_end.saturating_sub(iw_start);

                    // Visits the covered input elements in row-major order, the same order
                    // as a full kernel scan that skips padding positions.
                    let mut sum_val: E = 0.elem();
                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            sum_val += x[[b, c, ih - padding_height, iw - padding_width]];
                        }
                    }

                    // Count in `usize` and convert once, so the divisor is exact regardless
                    // of the precision of `E` (matches `avg_pool2d_backward`).
                    let count = if count_include_pad {
                        kernel_height * kernel_width
                    } else {
                        rows * cols
                    };
                    let count: E = (count as i32).elem();

                    output[[b, c, oh, ow]] = sum_val / count;
                }
            }
        })
    });

    NdArrayTensor::new(output.into_dyn().into_shared())
}

/// Gradient of [`avg_pool2d`] with respect to its input.
///
/// `x` is only used for its shape. The returned gradient has dims
/// `[batch, channels, x_height, x_width]`. `grad` is the upstream gradient; its last two
/// dims give the pooled output size. Its batch/channel dims are assumed to match `x`,
/// but this is not checked here.
///
/// For every output position, `grad[[b, c, oh, ow]]` is divided by the forward pass's
/// divisor: `kernel_h * kernel_w` when `count_include_pad` is true, otherwise the number
/// of real input elements covered. That share is then added to every real (non-padding)
/// input element under the window. Overlapping windows (stride < kernel) accumulate.
/// Windows that cover padding only contribute nothing.
pub(crate) fn avg_pool2d_backward<E: FloatNdArrayElement>(
    x: NdArrayTensor<E>,
    grad: NdArrayTensor<E>,
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    count_include_pad: bool,
) -> NdArrayTensor<E> {
    let [kernel_height, kernel_width] = kernel_size;
    let [stride_height, stride_width] = stride;
    let [padding_height, padding_width] = padding;
    let [batch_size, channels, x_height, x_width] = x.shape().dims();
    let [_batch_size, _channels, out_height, out_width] = grad.shape().dims();

    let grad = grad.array;

    let mut output_grad = Array4::from_elem((batch_size, channels, x_height, x_width), 0.elem());
    let unsafe_shared_grad = UnsafeSharedRef::new(&mut output_grad);

    run_par!(|| {
        // SAFETY: each task `k` maps to a unique `(b, c)` pair (`k / channels`,
        // `k % channels`) and only touches `output_grad[[b, c, .., ..]]`, so concurrent
        // tasks never alias the same elements. (Assumes `UnsafeSharedRef::get` hands
        // every task a `&mut` to the same `output_grad` array.)
        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
            let b = k / channels;
            let c = k % channels;

            let output_grad = unsafe_shared_grad.get();

            for oh in 0..out_height {
                // Window rows in padded coordinates, clipped to the real input rows
                // `[padding_height, x_height + padding_height)`. This depends only on `oh`.
                let ih_start = usize::max(oh * stride_height, padding_height);
                let ih_end =
                    usize::min(oh * stride_height + kernel_height, x_height + padding_height);

                for ow in 0..out_width {
                    let iw_start = usize::max(ow * stride_width, padding_width);
                    let iw_end =
                        usize::min(ow * stride_width + kernel_width, x_width + padding_width);

                    // A window lying entirely inside the padding covers no input element,
                    // so there is nothing to propagate. Skipping it also avoids the `usize`
                    // underflow of `end - start` when start > end, which panics in debug.
                    if ih_start >= ih_end || iw_start >= iw_end {
                        continue;
                    }

                    let count = if count_include_pad {
                        kernel_height * kernel_width
                    } else {
                        (ih_end - ih_start) * (iw_end - iw_start)
                    };
                    // Every covered input element receives the same share, so divide once.
                    let divisor: E = (count as i32).elem();
                    let share = grad[[b, c, oh, ow]] / divisor;

                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            output_grad[[b, c, ih - padding_height, iw - padding_width]] += share;
                        }
                    }
                }
            }
        })
    });

    NdArrayTensor::new(output_grad.into_dyn().into_shared())
}