use tokitai_operator::domain::DomainId;
use tokitai_operator::object::{Dim, Shape, Tensor};
/// Returns the `DomainId` named `"integer"`, used as the domain for the
/// `Tensor<i64>` values constructed in these tests.
fn int_domain() -> DomainId {
    const INTEGER_DOMAIN: &str = "integer";
    DomainId::new(INTEGER_DOMAIN)
}
/// `try_from_vec` accepts a buffer whose length equals the product of the
/// shape's dimensions (2 * 3 = 6) and stores the elements unchanged.
#[test]
fn try_from_vec_succeeds_when_length_matches_shape_product() {
    let shape = Shape::from(vec![2, 3]);
    let values = vec![1, 2, 3, 4, 5, 6];
    let tensor: Tensor<i64> =
        Tensor::try_from_vec(int_domain(), shape, values).expect("2*3=6");

    // A two-entry shape yields a rank-2 tensor.
    assert_eq!(tensor.rank(), 2);
    assert_eq!(tensor.data, vec![1, 2, 3, 4, 5, 6]);
}
/// A data vector shorter than the shape's element count must be rejected.
///
/// Shape `[3]` needs 3 elements but only 2 are supplied. The error message
/// is expected to report both the supplied length and the required count.
#[test]
fn try_from_vec_fails_when_length_too_small() {
    let result: Result<Tensor<i64>, _> =
        Tensor::try_from_vec(int_domain(), Shape::from(vec![3]), vec![1, 2]);
    let err = result.expect_err("length 2 != 3");
    let msg = format!("{err}");
    // Require both numbers. Matching on either one alone would accept a
    // message that omits the offending length (any text containing "3").
    assert!(msg.contains('2') && msg.contains('3'), "got: {msg}");
}
/// A data vector longer than the shape's element count must be rejected.
///
/// Shape `[2]` needs 2 elements but 3 are supplied. The error message is
/// expected to report both the supplied length and the required count.
#[test]
fn try_from_vec_fails_when_length_too_large() {
    let result: Result<Tensor<i64>, _> =
        Tensor::try_from_vec(int_domain(), Shape::from(vec![2]), vec![1, 2, 3]);
    let err = result.expect_err("length 3 != 2");
    let msg = format!("{err}");
    // Require both numbers. Matching on either one alone would accept a
    // message that omits the offending length (any text containing "2").
    assert!(msg.contains('3') && msg.contains('2'), "got: {msg}");
}
/// A shape containing a symbolic dimension (`N`) must make `try_from_vec`
/// return an error, even though data is supplied.
#[test]
fn try_from_vec_fails_for_non_static_dim() {
    let symbolic_shape = Shape::new(vec![Dim::Symbolic("N".to_string())]);
    let outcome: Result<Tensor<i64>, _> =
        Tensor::try_from_vec(int_domain(), symbolic_shape, vec![1, 2, 3]);
    let text = outcome.expect_err("Symbolic dim must fail").to_string();

    // Accept either wording: a complaint about needing an all-static shape,
    // or one naming the symbolic dimension directly.
    let mentions_static = text.contains("all-static");
    let mentions_symbolic = text.contains("Symbolic");
    assert!(
        mentions_static || mentions_symbolic,
        "expected all-static-shape error, got: {text}"
    );
}