use super::{ConvOptions, UnfoldOptions};
use crate::tensor::FloatTensor;
use crate::{Backend, TensorData, TensorMetadata, element::ElementConversion};
use alloc::vec;
use alloc::vec::Vec;
use burn_std::Shape;
/// Builds the one-hot weight tensor used to express unfolding as a 2D convolution.
///
/// The returned tensor has shape
/// `[in_channels * kernel_h * kernel_w, in_channels, kernel_h, kernel_w]`. Output channel
/// `k * kernel_h * kernel_w + i * kernel_w + j` holds a single `1` at position `(k, i, j)`
/// and zeros everywhere else.
///
/// # Arguments
///
/// * `in_channels` - Number of input channels.
/// * `kernel_size` - Kernel height and width, in that order.
/// * `device` - Device on which the weight tensor is created.
pub(crate) fn create_unfolding_weight<B: Backend>(
    in_channels: usize,
    kernel_size: [usize; 2],
    device: &B::Device,
) -> FloatTensor<B> {
    let [kernel_h, kernel_w] = kernel_size;
    let patch_len = in_channels * kernel_h * kernel_w;
    let shape = Shape::new([patch_len, in_channels, kernel_h, kernel_w]);

    let mut weight: Vec<B::FloatElem> = vec![0.0.elem(); shape.num_elements()];

    // In row-major order, output channel `oc` owns the flat range
    // `[oc * patch_len, (oc + 1) * patch_len)`, and the position `(k, i, j)` it selects has
    // offset `k * kernel_h * kernel_w + i * kernel_w + j == oc` inside that range. The ones
    // therefore sit at flat indices `oc * (patch_len + 1)`, i.e. on the diagonal of the
    // `patch_len x patch_len` view of the buffer.
    for slot in weight.iter_mut().step_by(patch_len + 1) {
        *slot = 1.elem();
    }

    B::float_from_data(TensorData::new(weight, shape), device)
}
/// Unfolds a 4D tensor by running a 2D convolution with a one-hot weight.
///
/// The input is convolved (no bias, a single group) with the weight produced by
/// [`create_unfolding_weight`], and the two spatial dimensions of the convolution output
/// are then flattened into one.
///
/// # Arguments
///
/// * `x` - Input tensor; its shape is read as four dimensions, the second being the
///   channel count.
/// * `kernel_size` - Kernel height and width.
/// * `options` - Stride, padding and dilation forwarded to the convolution.
///
/// # Returns
///
/// A tensor of shape `[batch_size, channels_out, out_height * out_width]`, where the
/// individual sizes are taken from the convolution output.
pub(crate) fn unfold4d_using_conv2d<B: Backend>(
    x: FloatTensor<B>,
    kernel_size: [usize; 2],
    options: UnfoldOptions,
) -> FloatTensor<B> {
    // Only the channel count of the input is needed to size the weight.
    let [_, in_channels, _, _] = x.shape().dims();
    let device = B::float_device(&x);
    let weight = create_unfolding_weight::<B>(in_channels, kernel_size, &device);

    let conv_options = ConvOptions::new(options.stride, options.padding, options.dilation, 1);
    let patches = B::conv2d(x, weight, None, conv_options);

    // Collapse the spatial grid of the convolution output into a single trailing dimension.
    let [batch_size, channels_out, out_height, out_width] = patches.shape().dims();
    let flattened = Shape::new([batch_size, channels_out, out_height * out_width]);
    B::float_reshape(patches, flattened)
}
/// Calculates how many complete windows fit along a dimension when unfolding it.
///
/// Windows start at offsets `0, step_size, 2 * step_size, ...` and only windows that fit
/// entirely within `dim_size` are counted. The result is
/// `(dim_size - window_size) / step_size + 1` when `window_size <= dim_size`, and `0`
/// otherwise.
///
/// # Arguments
///
/// * `dim_size` - Size of the dimension being unfolded.
/// * `window_size` - Size of each window.
/// * `step_size` - Distance between the starts of consecutive windows.
///
/// # Panics
///
/// Panics if `step_size` is `0`.
pub fn calculate_unfold_windows(dim_size: usize, window_size: usize, step_size: usize) -> usize {
    assert!(step_size > 0);
    // Subtract before dividing instead of computing `dim_size + step_size` up front: that
    // sum can overflow for very large steps even though the window count is representable.
    match dim_size.checked_sub(window_size) {
        Some(slack) => slack / step_size + 1,
        None => 0,
    }
}
/// Calculates the shape that results from unfolding one dimension into windows.
///
/// The entry at `dim` is replaced by the number of windows reported by
/// [`calculate_unfold_windows`], and `size` is appended as a new trailing dimension.
///
/// # Arguments
///
/// * `shape` - Shape of the tensor before unfolding.
/// * `dim` - Index of the dimension to unfold.
/// * `size` - Size of each window.
/// * `step` - Distance between the starts of consecutive windows.
///
/// # Panics
///
/// Panics if `step` is `0`.
pub fn calculate_unfold_shape<S: Into<Shape>>(
    shape: S,
    dim: usize,
    size: usize,
    step: usize,
) -> Shape {
    let mut unfolded: Shape = shape.into();
    let window_count = calculate_unfold_windows(unfolded[dim], size, step);
    // The unfolded dimension now counts windows; the window contents become the last axis.
    unfolded[dim] = window_count;
    unfolded.push(size);
    unfolded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Window counts for inputs around the point where the window just fits, for steps
    /// of one and two.
    #[test]
    fn test_calculate_unfold_windows() {
        // (dim_size, window_size, step_size, expected window count)
        let cases = [
            (2, 5, 1, 0),
            (2, 3, 1, 0),
            (3, 3, 1, 1),
            (4, 3, 1, 2),
            (5, 3, 1, 3),
            (2, 3, 2, 0),
            (3, 3, 2, 1),
            (4, 3, 2, 1),
            (5, 3, 2, 2),
        ];

        for (dim_size, window_size, step_size, expected) in cases {
            assert_eq!(
                calculate_unfold_windows(dim_size, window_size, step_size),
                expected
            );
        }
    }

    /// Unfolding dimension 1 of `[2, 6, 6]` with windows of 3 and a step of 2 yields two
    /// windows and a trailing dimension of size 3.
    #[test]
    fn test_calculate_unfold_shape() {
        let unfolded = calculate_unfold_shape([2, 6, 6], 1, 3, 2);
        let expected = Shape::new([2, 2, 6, 3]);
        assert_eq!(unfolded, expected);
    }
}