use tokitai_operator::domain::DomainId;
use tokitai_operator::object::{Shape, Tensor};
/// Builds a dense CPU tensor of `i64` values tagged with the `"integer"` domain.
///
/// `shape` and `data` are forwarded unchanged to `Tensor::dense_cpu`; this
/// helper does not itself check that `data.len()` matches the shape.
fn int_tensor(shape: Vec<usize>, data: Vec<i64>) -> Tensor<i64> {
    let domain = DomainId::new("integer");
    let dims = Shape::from(shape);
    Tensor::dense_cpu(domain, dims, data)
}
/// Two 2x3 tensors built from the same values must compare equal at every
/// in-bounds index.
#[test]
fn element_equals_returns_true_for_matching_in_bounds_element() {
    let values = vec![1, 2, 3, 4, 5, 6];
    let lhs = int_tensor(vec![2, 3], values.clone());
    let rhs = int_tensor(vec![2, 3], values);
    (0..6).for_each(|idx| {
        assert!(
            lhs.element_equals(&rhs, idx),
            "in-bounds equal at idx={idx} must be true"
        );
    });
}
/// Two same-shape tensors that differ only in their third value (index 2)
/// must compare unequal at that index and equal at every other in-bounds index.
#[test]
fn element_equals_returns_false_for_in_bounds_unequal_element() {
    let a = int_tensor(vec![2, 3], vec![1, 2, 3, 4, 5, 6]);
    let b = int_tensor(vec![2, 3], vec![1, 2, 99, 4, 5, 6]);
    assert!(!a.element_equals(&b, 2), "idx=2 differs, must be false");
    // Check every other index, not just the endpoints. The single mismatch at
    // idx=2 must not affect its neighbours (1 and 3) or any other position.
    for idx in (0..6).filter(|&idx| idx != 2) {
        assert!(
            a.element_equals(&b, idx),
            "idx={idx} holds equal values, must be true"
        );
    }
}
/// Indices at or past the element count (3 here), up to `usize::MAX`, must
/// make `element_equals` return `false`. A panic would fail this test.
#[test]
fn element_equals_returns_false_for_out_of_bounds_index() {
    let lhs = int_tensor(vec![3], vec![10, 20, 30]);
    let rhs = int_tensor(vec![3], vec![10, 20, 30]);
    let cases: [(usize, &str); 3] = [
        (3, "idx=3 is past end, must be false"),
        (100, "way past end must be false"),
        (usize::MAX, "huge idx must be false"),
    ];
    for (idx, why) in cases {
        assert!(!lhs.element_equals(&rhs, idx), "{why}");
    }
}
/// Tensors with different shapes must never compare equal element-wise, even
/// when the values at the probed index are the same.
#[test]
fn element_equals_returns_false_for_shape_mismatch() {
    let a = int_tensor(vec![2, 3], vec![1, 2, 3, 4, 5, 6]);

    // Same element count and identical data, with only the dimensions swapped
    // (3x2 vs 2x3). A `false` result must therefore come from the shape
    // difference, not from the values.
    let b = int_tensor(vec![3, 2], vec![1, 2, 3, 4, 5, 6]);
    for idx in 0..6 {
        assert!(
            !a.element_equals(&b, idx),
            "shape mismatch must always return false (idx={idx})"
        );
    }

    // Different element count as well. Indices 0..4 exist in both tensors and
    // hold equal values there, so again only the shape can make them unequal.
    let c = int_tensor(vec![2, 2], vec![1, 2, 3, 4]);
    for idx in 0..4 {
        assert!(
            !a.element_equals(&c, idx),
            "different shape (2x3 vs 2x2) must always return false (idx={idx})"
        );
    }
}