burn-ndarray 0.21.0-pre.3

Ndarray backend for the Burn framework
Documentation
use crate::{
    SharedArray, element::FloatNdArrayElement, iter_range_par, run_par, sharing::UnsafeSharedRef,
};
use burn_backend::ElementConversion;
use ndarray::Array4;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

/// Adaptive 2-D average pooling.
///
/// `x` is destructured as `[batch_size, channels, input_height, input_width]` and
/// `output_size` as `[output_height, output_width]`. Output cell `(h, w)` is the mean of the
/// input window with rows `start_index(h, ..)..end_index(h, ..)` and columns
/// `start_index(w, ..)..end_index(w, ..)`. Adjacent windows may overlap when the input size
/// is not a multiple of the output size.
///
/// # Panics
///
/// Panics if `x` does not have exactly four dimensions.
pub(crate) fn adaptive_avg_pool2d<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    output_size: [usize; 2],
) -> SharedArray<E> {
    let [batch_size, channels, input_height, input_width] = x.shape().try_into().unwrap();
    let [output_height, output_width] = output_size;

    let mut output = Array4::from_elem(
        (batch_size, channels, output_height, output_width),
        0.elem(),
    );
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);

    run_par!(|| {
        // One task per (batch, channel) pair. Each task writes only elements whose first two
        // indices are its own `(b, c)`, so concurrent writes touch disjoint elements.
        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
            let b = k / channels;
            let c = k % channels;

            let output = unsafe_shared_out.get();
            for h in 0..output_height {
                // The row window depends only on `h`; compute it once per output row.
                let ih_start = start_index(h, output_height, input_height);
                let ih_end = end_index(h, output_height, input_height);

                for w in 0..output_width {
                    let iw_start = start_index(w, output_width, input_width);
                    let iw_end = end_index(w, output_width, input_width);

                    let mut sum_val: E = 0.elem();
                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            sum_val += x[[b, c, ih, iw]];
                        }
                    }

                    let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
                    output[[b, c, h, w]] = sum_val / count;
                }
            }
        })
    });

    output.into_dyn().into_shared()
}

/// Backward pass of [`adaptive_avg_pool2d`].
///
/// `x` is the forward input and is used only for its spatial size (`input_height`,
/// `input_width`). `grad` is destructured as
/// `[batch_size, channels, output_height, output_width]`. Each output-gradient value is
/// divided by the size of its pooling window and added to every input position in that
/// window. Overlapping windows therefore accumulate into the same input gradient.
///
/// # Panics
///
/// Panics if `x` or `grad` does not have exactly four dimensions.
pub(crate) fn adaptive_avg_pool2d_backward<E: FloatNdArrayElement>(
    x: SharedArray<E>,
    grad: SharedArray<E>,
) -> SharedArray<E> {
    let [_, _, input_height, input_width] = x.shape().try_into().unwrap();
    let [batch_size, channels, output_height, output_width] = grad.shape().try_into().unwrap();

    let mut output_grad =
        Array4::from_elem((batch_size, channels, input_height, input_width), 0.elem());
    let unsafe_shared_out = UnsafeSharedRef::new(&mut output_grad);

    run_par!(|| {
        // One task per (batch, channel) pair. Each task accumulates only into elements whose
        // first two indices are its own `(b, c)`, so concurrent writes touch disjoint elements.
        iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
            let b = k / channels;
            let c = k % channels;

            let output_grad = unsafe_shared_out.get();
            for oh in 0..output_height {
                // The row window depends only on `oh`; compute it once per output row.
                let ih_start = start_index(oh, output_height, input_height);
                let ih_end = end_index(oh, output_height, input_height);

                for ow in 0..output_width {
                    let iw_start = start_index(ow, output_width, input_width);
                    let iw_end = end_index(ow, output_width, input_width);

                    let count: E = (((ih_end - ih_start) * (iw_end - iw_start)) as i32).elem();
                    // Every input in the window receives the same share of this gradient,
                    // so compute it once rather than per input element.
                    let share = grad[[b, c, oh, ow]] / count;

                    for ih in ih_start..ih_end {
                        for iw in iw_start..iw_end {
                            output_grad[[b, c, ih, iw]] += share;
                        }
                    }
                }
            }
        })
    });

    output_grad.into_dyn().into_shared()
}

/// Inclusive start of the input window for output position `output_size_index`.
///
/// Returns `floor(output_size_index * input_size / output_size)`, computed in exact integer
/// arithmetic. Routing this through `f32` loses precision once the product exceeds 2^24.
/// Splitting the numerator into quotient and remainder keeps intermediate values small.
///
/// `output_size` must be non-zero; this panics on division by zero otherwise.
fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    (output_size_index / output_size) * input_size
        + (output_size_index % output_size) * input_size / output_size
}

/// Exclusive end of the input window for output position `output_size_index`.
///
/// Returns `ceil((output_size_index + 1) * input_size / output_size)`, clamped to
/// `input_size`. It is computed in exact integer arithmetic: an `f32` round-trip would drop
/// trailing input elements for sizes beyond 2^24. Splitting the numerator into quotient and
/// remainder keeps intermediate values small.
///
/// `output_size` must be non-zero; this panics on division by zero otherwise.
fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
    let next = output_size_index + 1;
    // (next / output_size) * input_size is an integer, so only the remainder part needs
    // rounding up.
    let index = (next / output_size) * input_size
        + ((next % output_size) * input_size).div_ceil(output_size);

    usize::min(index, input_size)
}