use burn::{
prelude::{
Backend,
Tensor,
},
tensor::Numeric,
};
use crate::utility::zspace::expect_point_bounds_check;
/// Builds a `shape`-sized 0/1 mask marking every position where a `kernel`-sized
/// window can be anchored while staying entirely inside the tensor.
///
/// Along each axis with extent `s` and kernel size `k`, the anchor sits at offset
/// `k / 2` within the window. The valid anchor indices are therefore
/// `k / 2 .. s - (k - 1) / 2`. Positions in that range along both axes are `1`;
/// every other position is `0`.
///
/// If the kernel is larger than the tensor along either axis, no window fits and
/// the returned mask is all zeros.
///
/// # Panics
///
/// Panics if either kernel dimension is zero.
#[inline]
pub fn conv2d_kernel_midpoint_filter<B: Backend, K>(
    shape: [usize; 2],
    kernel: [usize; 2],
    device: &B::Device,
) -> Tensor<B, 2, K>
where
    K: Numeric<B>,
{
    let mask = Tensor::zeros(shape, device);
    match (
        conv_midpoint_axis_range(shape[0], kernel[0]),
        conv_midpoint_axis_range(shape[1], kernel[1]),
    ) {
        (Some(rows), Some(cols)) => mask.slice_fill([rows, cols], 1),
        // The kernel does not fit along at least one axis: no valid anchors.
        _ => mask,
    }
}

/// Range of valid kernel-anchor indices along one axis of extent `size`.
///
/// Returns `None` when `kernel > size`, because no window fits. Otherwise the
/// range is non-empty, since its length is `size - (kernel - 1) >= 1`.
///
/// Panics if `kernel == 0`. Without this check, `kernel - 1` would underflow.
fn conv_midpoint_axis_range(size: usize, kernel: usize) -> Option<std::ops::Range<usize>> {
    assert!(kernel > 0, "conv2d kernel dimensions must be non-zero");
    if kernel > size {
        return None;
    }
    Some(kernel / 2..size - (kernel - 1) / 2)
}
#[cfg(test)]
mod tests {
    use burn::{
        backend::NdArray,
        prelude::TensorData,
    };

    use super::*;

    type B = NdArray;

    /// Builds the mask for `shape`/`kernel` and compares it (dtype-lenient)
    /// against `expected`.
    fn assert_mask<const R: usize, const C: usize>(
        shape: [usize; 2],
        kernel: [usize; 2],
        expected: [[f64; C]; R],
    ) {
        let device = Default::default();
        let mask: Tensor<B, 2> = conv2d_kernel_midpoint_filter(shape, kernel, &device);
        mask.to_data()
            .assert_eq(&TensorData::from(expected), false);
    }

    #[test]
    fn test_conv2d_kernel_midpoint_filter() {
        assert_mask(
            [7, 9],
            [2, 3],
            [
                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
                [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
            ],
        );
    }

    #[test]
    fn test_conv2d_kernel_midpoint_filter_unit_kernel() {
        // A 1x1 kernel fits everywhere, so every position is a valid anchor.
        assert_mask([2, 3], [1, 1], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
    }

    #[test]
    fn test_conv2d_kernel_midpoint_filter_kernel_equals_shape() {
        // A kernel exactly the size of the tensor fits in exactly one place.
        assert_mask(
            [3, 4],
            [3, 4],
            [
                [0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ],
        );
    }
}