use alloc::vec::Vec;
use burn_std::Shape;
use crate::{FlexTensor, Layout};
/// Returns how many complete windows of length `window_size`, whose start
/// positions are `step` elements apart, fit in a dimension of length `dim_size`.
///
/// Returns 0 when `dim_size < window_size`; otherwise
/// `(dim_size - window_size) / step + 1`.
///
/// # Panics
///
/// Panics if `step == 0`.
#[inline]
fn calculate_windows(dim_size: usize, window_size: usize, step: usize) -> usize {
    assert!(step > 0, "step must be positive");
    // Subtract first (checked) instead of computing `dim_size + step`, which
    // would overflow `usize` when `step` is very large.
    dim_size
        .checked_sub(window_size)
        .map_or(0, |slack| slack / step + 1)
}
/// Extracts sliding windows of length `size` along `dim`, with consecutive
/// windows starting `step` elements apart, without copying element data.
///
/// The result has one more dimension than the input:
/// - dimension `dim` is replaced by the number of windows;
/// - a new trailing dimension of length `size` walks the elements of a window.
///
/// The returned tensor reuses the input's storage (`data_arc`) with a new
/// layout built from the input strides and start offset.
///
/// # Panics
///
/// Panics if `dim` is out of bounds, if `size == 0`, if `step == 0`, or if
/// the length of `dim` is smaller than `size`.
pub fn unfold(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
    let input_layout = tensor.layout();
    let shape = input_layout.shape();
    let input_strides = input_layout.strides();
    let start_offset = input_layout.start_offset();
    let ndims = shape.num_dims();
    let dtype = tensor.dtype();

    assert!(
        dim < ndims,
        "dim {} out of bounds for {} dimensions",
        dim,
        ndims
    );
    assert!(size > 0, "window size must be positive");
    assert!(step > 0, "step must be positive");
    let dim_size = shape[dim];
    assert!(
        dim_size >= size,
        "dimension {} has size {} which is smaller than window size {}",
        dim,
        dim_size,
        size
    );

    let windows = calculate_windows(dim_size, size, step);
    let dim_stride = input_strides[dim];

    // Consecutive windows start `step` elements apart along `dim`. A `step`
    // large enough to overflow this product always yields a single window
    // (the buffer can't address that far). In that case the stride is never
    // used to reach memory, so fall back to the input stride instead of
    // truncating or overflowing.
    let scaled = if step <= isize::MAX as usize {
        dim_stride.checked_mul(step as isize)
    } else {
        None
    };
    let window_stride = match scaled {
        Some(stride) => stride,
        None => {
            assert_eq!(
                windows, 1,
                "stride along dim {} overflows for step {}",
                dim, step
            );
            dim_stride
        }
    };

    // Shape: `dim` counts windows; the appended axis indexes within a window.
    let output_dims: Vec<usize> = shape
        .iter()
        .enumerate()
        .map(|(d, &s)| if d == dim { windows } else { s })
        .chain(core::iter::once(size))
        .collect();

    // Strides: `dim` jumps between windows; the appended axis steps through
    // adjacent elements of the original `dim`.
    let output_strides: Vec<isize> = input_strides
        .iter()
        .enumerate()
        .map(|(d, &s)| if d == dim { window_stride } else { s })
        .chain(core::iter::once(dim_stride))
        .collect();

    let output_layout = Layout::new(Shape::from(output_dims), output_strides, start_offset);
    FlexTensor::from_arc(tensor.data_arc(), output_layout, dtype)
}
pub fn unfold_f32(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
pub fn unfold_f64(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
pub fn unfold_bool(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
pub fn unfold_int(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}