#[burn_tensor_testgen::testgen(roll)]
mod tests {
    use super::*;
    use burn_tensor::{Int, Slice, Tensor, TensorData, as_type, might_panic, s};

    /// Rolling a tensor that has a zero-sized dimension must preserve its shape.
    #[ignore = "0 size resources are not yet supported"]
    #[cfg(feature = "std")]
    #[test]
    fn test_roll_empty() {
        let device = Default::default();
        let input = TestTensorInt::<2>::zeros([12, 0], &device);
        // `input` is not used afterwards, so it is moved into `roll` instead of cloned.
        let result = input.roll(&[1, 2], &[0, 1]);
        assert_eq!(result.shape().dims, &[12, 0]);
    }

    /// Multi-dimension `roll`.
    ///
    /// Checks that zero shifts are the identity, that the order in which
    /// `(shift, dim)` pairs are listed does not matter, and that shifts outside
    /// `[0, size)` wrap modulo the dimension size.
    #[test]
    fn test_roll() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        // Expected result of shifting dim 0 by 1 and dim 1 by -1.
        let expected = TensorData::from([[5, 3, 4], [2, 0, 1]]);

        // Zero shifts leave the tensor unchanged. `to_data` only borrows, so no clone is needed.
        input
            .clone()
            .roll(&[0, 0], &[0, 1])
            .to_data()
            .assert_eq(&input.to_data(), false);

        input
            .clone()
            .roll(&[1, -1], &[0, 1])
            .to_data()
            .assert_eq(&expected, false);

        // Same roll, with the (shift, dim) pairs listed in the opposite order.
        input
            .clone()
            .roll(&[-1, 1], &[1, 0])
            .to_data()
            .assert_eq(&expected, false);

        // 2 * 32 + 1 = 65 ≡ 1 (mod 2) and 3 * (-400) - 1 = -1201 ≡ -1 (mod 3),
        // so after wrapping this is the same roll as above. Last use of `input`: move it.
        input
            .roll(&[2 * 32 + 1, 3 * (-400) - 1], &[0, 1])
            .to_data()
            .assert_eq(&expected, false);
    }

    /// `dims` entries must be valid axes: 2 is out of range for a rank-2 tensor.
    #[should_panic]
    #[test]
    fn test_roll_dim_too_big() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        let _d = input.roll(&[1], &[2]);
    }

    /// Negative `dims` entries must also be in range: -3 is invalid for a rank-2 tensor.
    #[should_panic]
    #[test]
    fn test_roll_dim_too_small() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        let _d = input.roll(&[1], &[-3]);
    }

    /// `shifts` and `dims` must have the same length.
    #[should_panic]
    #[test]
    fn test_roll_shift_size_mismatch() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        let _d = input.roll(&[1, 2], &[0]);
    }

    /// Single-dimension `roll_dim`, including wrapped shifts and a negative axis.
    #[test]
    fn test_roll_dim() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        let rolled_cols = TensorData::from([[2, 0, 1], [5, 3, 4]]);

        input
            .clone()
            .roll_dim(1, 0)
            .to_data()
            .assert_eq(&TensorData::from([[3, 4, 5], [0, 1, 2]]), false);

        input
            .clone()
            .roll_dim(-1, 1)
            .to_data()
            .assert_eq(&rolled_cols, false);

        // A shift equal to the dimension size is a full rotation, i.e. the identity.
        input
            .clone()
            .roll_dim(3, 1)
            .to_data()
            .assert_eq(&input.to_data(), false);

        // -4 ≡ -1 (mod 3): out-of-range shifts wrap just like in `roll`.
        input
            .clone()
            .roll_dim(-4, 1)
            .to_data()
            .assert_eq(&rolled_cols, false);

        // dim -1 addresses the last axis (dim 1) of a rank-2 tensor.
        input
            .roll_dim(-1, -1)
            .to_data()
            .assert_eq(&rolled_cols, false);
    }

    /// `roll_dim` rejects an axis past the tensor's rank.
    #[should_panic]
    #[test]
    fn test_roll_dim_dim_too_big() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        let _d = input.roll_dim(1, 2);
    }

    /// `roll_dim` rejects a negative axis below `-rank`.
    #[should_panic]
    #[test]
    fn test_roll_dim_dim_too_small() {
        let input = TestTensorInt::<2>::from([[0, 1, 2], [3, 4, 5]]);
        let _d = input.roll_dim(1, -3);
    }
}