#[burn_tensor_testgen::testgen(flip)]
mod tests {
use super::*;
use burn_tensor::{Device, Int, Shape, Tensor, TensorData};
// Flipping an integer tensor along axes 0 and 2, then along no axes at all.
#[test]
fn flip_int() {
    let device = Default::default();
    // Values 0..24 arranged as a [2, 3, 4] tensor.
    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);

    let flipped = tensor.clone().flip([0, 2]);

    // Axis 0 swaps the two outer blocks; axis 2 reverses every innermost row.
    let expected = TensorData::from([
        [[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]],
        [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]],
    ]);
    // NOTE(review): the `false` flag presumably relaxes the dtype comparison, since the
    // integer literals above need not match the backend's int element type — confirm.
    flipped.into_data().assert_eq(&expected, false);

    // An empty axis list must leave the data untouched. Compared the same way as the
    // float and bool tests in this module (flag set to `true`) for consistency.
    let flipped = tensor.clone().flip([]);
    tensor.into_data().assert_eq(&flipped.into_data(), true);
}
// Float counterpart of the integer flip test: same layout, values converted with `float()`.
#[test]
fn flip_float() {
    let device = Default::default();
    // Build 0..24 as integers first, then reshape to [2, 3, 4] and convert to float.
    let ints = TestTensorInt::<1>::arange(0..24, &device);
    let input = ints.reshape([2, 3, 4]).float();

    // Outer blocks swapped (axis 0) and each innermost row reversed (axis 2).
    let want = TensorData::from([
        [
            [15., 14., 13., 12.],
            [19., 18., 17., 16.],
            [23., 22., 21., 20.],
        ],
        [[3., 2., 1., 0.], [7., 6., 5., 4.], [11., 10., 9., 8.]],
    ]);
    let reversed = input.clone().flip([0, 2]);
    reversed.into_data().assert_eq(&want, false);

    // Flipping along no axes must give back the same data.
    let untouched = input.clone().flip([]);
    input.into_data().assert_eq(&untouched.into_data(), true);
}
// Boolean flip test: the mask `value > 10` over 0..24 is flipped along axes 0 and 2.
#[test]
fn flip_bool() {
    let device = Default::default();
    // Before flipping, only the entries holding 11..=23 are `true`.
    let ints = TestTensorInt::<1>::arange(0..24, &device);
    let mask = ints.reshape([2, 3, 4]).greater_elem(10);

    // After the flip the all-true block comes first; in the second block the single
    // `true` (originally the value 11) moves to the front of the last row.
    let want = TensorData::from([
        [[true; 4]; 3],
        [[false; 4], [false; 4], [true, false, false, false]],
    ]);
    let reversed = mask.clone().flip([0, 2]);
    reversed.into_data().assert_eq(&want, true);

    // Flipping along no axes must give back the same data.
    let untouched = mask.clone().flip([]);
    mask.into_data().assert_eq(&untouched.into_data(), true);
}
// Passing the same axis more than once is expected to make `flip` panic.
#[test]
#[should_panic]
fn flip_duplicated_axes() {
    let device = Default::default();
    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);

    // Axis 0 is listed twice. The tensor is moved into `flip` because it is not
    // used afterwards, so no clone is needed.
    let _ = tensor.flip([0, 0, 1]);
}
// Passing an axis index beyond the tensor's rank is expected to make `flip` panic.
#[test]
#[should_panic]
fn flip_out_of_bound_axis() {
    let device = Default::default();
    let tensor = TestTensorInt::<1>::arange(0..24, &device).reshape([2, 3, 4]);

    // The tensor has three dimensions (valid axes 0..=2), so axis 3 is out of range.
    // The tensor is moved into `flip` because it is not used afterwards.
    let _ = tensor.flip([3, 0, 1]);
}
}