#[burn_tensor_testgen::testgen(squeeze)]
mod tests {
    use super::*;
    use burn_tensor::{Shape, Tensor};

    /// `squeeze` removes a size-1 dimension from the middle of the shape.
    #[test]
    fn should_squeeze() {
        let tensor = TestTensor::<3>::ones(Shape::new([2, 1, 4]), &Default::default());
        let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze(1);
        let expected_shape = Shape::new([2, 4]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }

    /// `squeeze` removes a size-1 leading dimension.
    #[test]
    fn should_squeeze_first() {
        let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 4, 5]), &Default::default());
        let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(0);
        let expected_shape = Shape::new([3, 4, 5]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }

    /// `squeeze` removes a size-1 trailing dimension.
    #[test]
    fn should_squeeze_last() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 1]), &Default::default());
        let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(3);
        let expected_shape = Shape::new([2, 3, 4]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }

    /// `squeeze` must panic when the requested dimension (size 4 here) is not of size 1.
    #[test]
    #[should_panic]
    fn should_squeeze_panic() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let _: Tensor<TestBackend, 3> = tensor.squeeze(2);
    }

    /// With an empty slice, `squeeze_dims` removes every size-1 dimension.
    #[test]
    fn should_squeeze_dims_with_empty_slice() {
        let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 3]), &Default::default());
        let squeezed_tensor: Tensor<TestBackend, 1> = tensor.squeeze_dims(&[]);
        let expected_shape = Shape::new([3]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }

    /// `squeeze_dims` removes the listed (non-negative) size-1 dimensions.
    #[test]
    fn should_squeeze_dims_with_positive_indices() {
        let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default());
        let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze_dims(&[0, 2]);
        let expected_shape = Shape::new([3, 5]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }

    /// Negative indices count from the back: on a rank-4 input, -3 is dim 1 and -1 is dim 3.
    #[test]
    fn should_squeeze_dims_with_negative_indices() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 1, 3, 1]), &Default::default());
        let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze_dims(&[-3, -1]);
        let expected_shape = Shape::new([2, 3]);
        assert_eq!(squeezed_tensor.shape(), expected_shape);
    }

    /// `squeeze_dims` must panic when asked to remove a dimension whose size is not 1.
    ///
    /// Despite its name, this test expects a panic (the name is kept for test-filter
    /// stability). The output rank is deliberately left at 3, the input rank. An
    /// implementation that silently skipped the non-singleton dimension would then produce
    /// a valid rank-3 tensor and fail this test, instead of passing via an unrelated
    /// rank-mismatch panic.
    #[test]
    #[should_panic]
    fn should_squeeze_dims_work_if_non_singleton() {
        let tensor = TestTensor::<3>::ones(Shape::new([2, 3, 4]), &Default::default());
        // No assertion follows on purpose: inside a `should_panic` test, a failing
        // assertion would itself satisfy the expected panic and mask a non-panicking,
        // wrong-shape result.
        let _: Tensor<TestBackend, 3> = tensor.squeeze_dims(&[1]);
    }

    /// Squeezing all three dimensions of a rank-3 tensor would leave zero dimensions,
    /// so this must panic.
    #[test]
    #[should_panic]
    fn should_squeeze_dims_panic_on_too_many_dimensions() {
        let tensor = TestTensor::<3>::ones(Shape::new([1, 1, 1]), &Default::default());
        let _: Tensor<TestBackend, 1> = tensor.squeeze_dims(&[0, 1, 2]);
    }

    /// Removing two dimensions from a rank-4 tensor yields rank 2. Requesting rank 3
    /// must therefore panic.
    #[test]
    #[should_panic]
    fn should_squeeze_dims_dimension_mismatch_panic() {
        let tensor = TestTensor::<4>::ones(Shape::new([1, 3, 1, 5]), &Default::default());
        let _: Tensor<TestBackend, 3> = tensor.squeeze_dims(&[0, 2]);
    }

    /// `unsqueeze_dim` inserts a size-1 dimension at an interior position.
    #[test]
    fn should_unsqueeze_dim() {
        let tensor = TestTensor::<3>::ones(Shape::new([2, 4, 1]), &Default::default());
        let unsqueezed_tensor: Tensor<TestBackend, 4> = tensor.unsqueeze_dim(1);
        let expected_shape = Shape::new([2, 1, 4, 1]);
        assert_eq!(unsqueezed_tensor.shape(), expected_shape);
    }

    /// `unsqueeze_dim` inserts a size-1 dimension at the front.
    #[test]
    fn should_unsqueeze_dim_first() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(0);
        let expected_shape = Shape::new([1, 2, 3, 4, 5]);
        assert_eq!(unsqueezed_tensor.shape(), expected_shape);
    }

    /// `unsqueeze_dim` with index equal to the input rank appends a size-1 dimension.
    #[test]
    fn should_unsqueeze_dim_last() {
        let tensor = TestTensor::<4>::ones(Shape::new([5, 4, 3, 2]), &Default::default());
        let unsqueezed_tensor: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(4);
        let expected_shape = Shape::new([5, 4, 3, 2, 1]);
        assert_eq!(unsqueezed_tensor.shape(), expected_shape);
    }

    /// Index 5 is past the last valid insertion point (4) for a rank-4 input, so this
    /// must panic.
    #[test]
    #[should_panic]
    fn should_unsqueeze_dim_panic() {
        let tensor = TestTensor::<4>::ones(Shape::new([2, 3, 4, 5]), &Default::default());
        let _: Tensor<TestBackend, 5> = tensor.unsqueeze_dim(5);
    }

    /// Mixed positive/negative indices, with the output rank given via turbofish.
    /// Per the expected shape, -2 places the new axis at position 3 of the rank-5 output.
    #[test]
    fn should_unsqueeze_dims_support_dim_inference() {
        let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());
        let output_tensor = input_tensor.unsqueeze_dims::<5>(&[1, -2]);
        let expected_shape = Shape::new([3, 1, 4, 1, 5]);
        assert_eq!(output_tensor.shape(), expected_shape);
    }

    /// New axes can be inserted at both the first and the last output position.
    #[test]
    fn should_unsqueeze_dims_handle_first_last() {
        let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());
        let output_tensor = input_tensor.unsqueeze_dims::<5>(&[0, 4]);
        let expected_shape = Shape::new([1, 3, 4, 5, 1]);
        assert_eq!(output_tensor.shape(), expected_shape);
    }

    /// A single index behaves like `unsqueeze_dim`.
    #[test]
    fn should_unsqueeze_dims_work_with_single_dim() {
        let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());
        let output_tensor: Tensor<TestBackend, 4> = input_tensor.unsqueeze_dims(&[1]);
        let expected_shape = Shape::new([3, 1, 4, 5]);
        assert_eq!(output_tensor.shape(), expected_shape);
    }

    /// A repeated -1 appends one new trailing axis per occurrence.
    #[test]
    fn should_unsqueeze_dims_multiple_trailing_negatives() {
        let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());
        let output_tensor: Tensor<TestBackend, 6> = input_tensor.unsqueeze_dims(&[0, -1, -1]);
        let expected_shape = Shape::new([1, 3, 4, 5, 1, 1]);
        assert_eq!(output_tensor.shape(), expected_shape);
    }

    /// -6 lies outside the valid index range for a rank-5 output, so this must panic.
    #[test]
    #[should_panic]
    fn should_unsqueeze_dims_panic() {
        let input_tensor = TestTensor::<3>::ones(Shape::new([3, 4, 5]), &Default::default());
        let _: Tensor<TestBackend, 5> = input_tensor.unsqueeze_dims(&[0, -6]);
    }
}