use alloc::vec::Vec;
use burn_std::Shape;
use crate::{FlexTensor, Layout};
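
/// Number of windows of length `window_size` that fit in a dimension of
/// `dim_size` elements when sliding by `step`; 0 if the window never fits.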
#[inline]
fn calculate_windows(dim_size: usize, window_size: usize, step: usize) -> usize {
assert!(step > 0, "step must be positive");
    if dim_size + step < window_size {
        // The window can never fit; this branch also guards the unsigned
        // subtraction below against underflow.
        0
    } else {
        // Equivalent to (dim_size - window_size) / step + 1 whenever the
        // window fits at least once.
        (dim_size + step - window_size) / step
    }
}
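
/// Returns a view of `tensor` containing all windows of length `size` taken
/// from dimension `dim`, with `step` elements between the starts of
/// consecutive windows.
///
/// Dimension `dim` shrinks to the number of windows and a new trailing
/// dimension of length `size` is appended, mirroring PyTorch's
/// `Tensor::unfold`. No data is copied: the result shares the input buffer
/// and only the layout (shape, strides, offset) changes.
///
/// For example, unfolding shape `[5]` with `size = 3`, `step = 1` yields
/// shape `[3, 3]` over the same buffer.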
pub fn unfold(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
let input_layout = tensor.layout();
let shape = input_layout.shape();
let input_strides = input_layout.strides();
let start_offset = input_layout.start_offset();
let ndims = shape.num_dims();
let dtype = tensor.dtype();
assert!(
dim < ndims,
"dim {} out of bounds for {} dimensions",
dim,
ndims
);
assert!(size > 0, "window size must be positive");
assert!(step > 0, "step must be positive");
    assert!(
        shape[dim] >= size,
        "dimension {} has size {}, which is smaller than window size {}",
        dim,
        shape[dim],
        size
    );
let dim_size = shape[dim];
let windows = calculate_windows(dim_size, size, step);
    // Output shape: the unfolded dimension is replaced by the number of
    // windows, and a trailing dimension of length `size` is appended.
    let mut output_dims: Vec<usize> = Vec::with_capacity(ndims + 1);
    for (d, &s) in shape.iter().enumerate() {
        if d == dim {
            output_dims.push(windows);
        } else {
            output_dims.push(s);
        }
    }
    output_dims.push(size);
    // Output strides: moving to the next window advances `step` elements along
    // the original dimension, while elements within a window keep the original
    // dimension's stride.
    let mut output_strides: Vec<isize> = Vec::with_capacity(ndims + 1);
    for (d, &s) in input_strides.iter().enumerate() {
        if d == dim {
            output_strides.push(s * step as isize);
        } else {
            output_strides.push(s);
        }
    }
    output_strides.push(input_strides[dim]);
let output_shape = Shape::from(output_dims);
let output_layout = Layout::new(output_shape, output_strides, start_offset);
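    // No data is copied: the view shares the input's buffer under the new layout.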
FlexTensor::from_arc(tensor.data_arc(), output_layout, dtype)
}
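
// Dtype-specific entry points. Unfold only rewrites the layout, never the
// element data, so every variant delegates to the generic implementation.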
pub fn unfold_f32(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
pub fn unfold_f64(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
pub fn unfold_bool(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
pub fn unfold_int(tensor: FlexTensor, dim: usize, size: usize, step: usize) -> FlexTensor {
unfold(tensor, dim, size, step)
}
#[cfg(test)]
mod tests {
use super::*;
use burn_backend::TensorData;
#[test]
fn test_unfold_1d() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], [5]));
let result = unfold_f32(tensor, 0, 3, 1);
assert_eq!(result.layout().shape().to_vec(), vec![3, 3]);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_unfold_1d_step2() {
let tensor =
FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [6]));
let result = unfold_f32(tensor, 0, 3, 2);
assert_eq!(result.layout().shape().to_vec(), vec![2, 3]);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_unfold_2d_dim1() {
let tensor = FlexTensor::from_data(TensorData::new(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
[2, 4],
));
let result = unfold_f32(tensor, 1, 2, 1);
assert_eq!(result.layout().shape().to_vec(), vec![2, 3, 2]);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(
data,
vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 5.0, 6.0, 6.0, 7.0, 7.0, 8.0]
);
}
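
    // A sketch of an extra check for unfolding along dimension 0. Like the
    // tests above, it assumes `into_data` materializes the strided view in
    // row-major order.
    #[test]
    fn test_unfold_2d_dim0() {
        let tensor = FlexTensor::from_data(TensorData::new(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            [2, 4],
        ));
        let result = unfold_f32(tensor, 0, 2, 1);
        // A single window of two rows: [windows, cols, size] = [1, 4, 2].
        assert_eq!(result.layout().shape().to_vec(), vec![1, 4, 2]);
        let data: Vec<f32> = result.into_data().to_vec().unwrap();
        assert_eq!(data, vec![1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0]);
    }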
}