use crate::Tensor;
use ndarray::array;
// Tests for `Tensor::triu` / `Tensor::tril`.
//
// Each case does the following:
// - builds a small f32 matrix;
// - applies the triangular mask with diagonal offset `k`;
// - realizes the result with the `config` handed in by `codegen_tests!`
//   (presumably one run per codegen configuration — confirm against the macro);
// - compares the flattened, row-major output.
//
// Offset convention pinned by the expected values below:
//   triu(k) keeps element (i, j) when j - i >= k,
//   tril(k) keeps element (i, j) when j - i <= k,
// and every other element becomes 0.
//
// Only plain `//` comments are used inside the macro so the item shape the
// macro matches (`fn name(config) { ... }`) is left untouched.
crate::codegen_tests! {
    // triu, k = 0: main diagonal and everything above it.
    fn test_triu_basic(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let mut result = x.triu(0).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0, 9.0]);
    }

    // triu, k = 1: the main diagonal itself is zeroed as well.
    fn test_triu_diagonal_positive(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let mut result = x.triu(1).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [0.0, 2.0, 3.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0]);
    }

    // triu, k = -1: the first subdiagonal is kept too.
    fn test_triu_diagonal_negative(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let mut result = x.triu(-1).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 8.0, 9.0]);
    }

    // tril, k = 0: main diagonal and everything below it.
    fn test_tril_basic(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let mut result = x.tril(0).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 0.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 9.0]);
    }

    // tril, k = 1: the first superdiagonal is kept too (mirror of triu(-1)).
    fn test_tril_diagonal_positive(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let mut result = x.tril(1).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 2.0, 0.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    // tril, k = -1: the main diagonal itself is zeroed as well.
    fn test_tril_diagonal_negative(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]);
        let mut result = x.tril(-1).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 7.0, 8.0, 0.0]);
    }

    // triu on a wide (2x4) matrix: the mask must follow the column index,
    // not assume a square shape.
    fn test_triu_non_square(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
        let mut result = x.triu(0).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 2.0, 3.0, 4.0, 0.0, 6.0, 7.0, 8.0]);
    }

    // tril on the same wide (2x4) matrix, mirroring test_triu_non_square.
    fn test_tril_non_square(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);
        let mut result = x.tril(0).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0]);
    }

    // triu on a tall (4x2) matrix: rows past the column count must be fully
    // zeroed, which a square- or wide-only implementation would not exercise.
    fn test_triu_tall(config) {
        let x = Tensor::from_ndarray(&array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]);
        let mut result = x.triu(0).unwrap();
        result.realize_with(&config).unwrap();
        assert_eq!(result.as_vec::<f32>().unwrap(), [1.0, 2.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0]);
    }
}