#[burn_tensor_testgen::testgen(flatten)]
mod tests {
    //! Tests for `Tensor::flatten(start_dim, end_dim)`, which merges the inclusive
    //! dimension range `start_dim..=end_dim` into a single dimension.

    use super::*;
    use burn_tensor::{Shape, Tensor, TensorData};

    /// Merging every dimension of a `[2, 3, 4, 5]` tensor yields one dim of 120.
    #[test]
    fn should_flatten_to_1d() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let flattened_tensor: Tensor<TestBackend, 1> = tensor.flatten(0, 3);
        let expected_shape = Shape::new([120]);
        assert_eq!(flattened_tensor.shape(), expected_shape);
    }

    /// Merging interior dims 1..=2 keeps the outer dims untouched: 3 * 4 = 12.
    #[test]
    fn should_flatten_middle() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let flattened_tensor: Tensor<TestBackend, 3> = tensor.flatten(1, 2);
        let expected_shape = Shape::new([2, 12, 5]);
        assert_eq!(flattened_tensor.shape(), expected_shape);
    }

    /// Merging the leading dims 0..=2: 2 * 3 * 4 = 24, trailing dim kept.
    #[test]
    fn should_flatten_begin() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(0, 2);
        let expected_shape = Shape::new([24, 5]);
        assert_eq!(flattened_tensor.shape(), expected_shape);
    }

    /// Merging the trailing dims 1..=3: 3 * 4 * 5 = 60, leading dim kept.
    #[test]
    fn should_flatten_end() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(1, 3);
        let expected_shape = Shape::new([2, 60]);
        assert_eq!(flattened_tensor.shape(), expected_shape);
    }

    /// Flattening must keep the elements in their original (row-major) order,
    /// not just produce the right shape.
    #[test]
    fn should_flatten_preserve_element_order() {
        let tensor = TestTensor::<3>::from_data(
            TensorData::from([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]]),
            &Default::default(),
        );
        let flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(1, 2);
        let expected = TensorData::from([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]);
        // Non-strict comparison: the literal's element type need not match the
        // backend's float element type.
        flattened_tensor.into_data().assert_eq(&expected, false);
    }

    /// Expects a panic when `start_dim` (2) is greater than `end_dim` (0).
    #[test]
    #[should_panic]
    fn should_flatten_panic() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        // Underscore prefix: the result is never inspected, the call itself must panic.
        let _flattened_tensor: Tensor<TestBackend, 2> = tensor.flatten(2, 0);
    }

    /// Merging dims 1..=2 of a 3-D tensor leaves 2 dims (`[1, 75]`), so asking
    /// for a 1-D result must panic.
    #[test]
    #[should_panic]
    fn not_enough_destination_dimension() {
        let tensor = TestTensor::<3>::ones(Shape::new([1, 5, 15]), &Default::default());
        // No assertion after the call: an assertion that fails would also panic
        // and mask a `flatten` that wrongly returned instead of panicking.
        let _flattened_tensor: Tensor<TestBackend, 1> = tensor.flatten(1, 2);
    }
}