use crate::backend::Backend;
use crate::ops::FloatTensor;
use crate::{ElementConversion, Shape, TensorData, TensorMetadata};
use alloc::vec;
use alloc::vec::Vec;
use super::{ConvOptions, UnfoldOptions};
/// Builds the conv2d weight that makes convolution act as `unfold`.
///
/// The weight has shape `[in_channels * kh * kw, in_channels, kh, kw]` and
/// contains a single `1` per output channel: output channel `(c, i, j)`
/// selects input channel `c` at kernel position `(i, j)`, so convolving
/// with it copies each input patch element into its own output channel.
pub(crate) fn create_unfolding_weight<B: Backend>(
    in_channels: usize,
    kernel_size: [usize; 2],
    device: &B::Device,
) -> FloatTensor<B> {
    let [kh, kw] = kernel_size;
    let shape = Shape::new([in_channels * kh * kw, in_channels, kh, kw]);

    // Contiguous (row-major) strides for the 4-D buffer above.
    let strides = [in_channels * kh * kw, kh * kw, kw, 1];

    // Start from an all-zero buffer, then set one 1 per output channel.
    let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); shape.num_elements()];

    for c in 0..in_channels {
        for i in 0..kh {
            for j in 0..kw {
                // Output channel index for the (channel, row, col) triple.
                let out_channel = c * kh * kw + i * kw + j;
                let offset = out_channel * strides[0]
                    + c * strides[1]
                    + i * strides[2]
                    + j * strides[3];
                weight[offset] = 1.elem();
            }
        }
    }

    B::float_from_data(TensorData::new(weight, shape), device)
}
/// Implements `unfold4d` by convolving `x` with a one-hot identity weight.
///
/// Each sliding `kernel_size` patch of `x` is extracted into its own set of
/// output channels, then the spatial dimensions are flattened, yielding a
/// tensor of shape `[batch_size, in_channels * kh * kw, out_height * out_width]`.
pub(crate) fn unfold4d_using_conv2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: [usize; 2],
    options: UnfoldOptions,
) -> FloatTensor<B> {
    let device = B::float_device(&x);
    let [_, in_channels, _, _] = x.shape().dims();

    // One-hot weight: conv2d with it copies patch elements into channels.
    let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &device);

    // Plain (groups = 1) convolution with the caller's stride/padding/dilation.
    let conv_options = ConvOptions::new(options.stride, options.padding, options.dilation, 1);
    let unfolded = B::conv2d(x, weight, None, conv_options);

    // Collapse the two spatial dims into the "number of patches" dim.
    let [batch_size, channels_out, out_height, out_width] = unfolded.shape().dims();
    B::float_reshape(
        unfolded,
        Shape::new([batch_size, channels_out, out_height * out_width]),
    )
}