use tokitai_operator::Error;
use tokitai_operator::backend::TensorStore;
use tokitai_operator::domain::PadicDomain;
use tokitai_operator::facade::{FacadeGraphBuilder, Tokitai};
use tokitai_operator::object::sheaf::{Cover, FiniteSite, OpenId};
use tokitai_operator::object::{ObjectMeta, Representation, Shape, Tensor};
use tokitai_operator::op::arithmetic::DivOp;
/// `PadicDomain::new(1, 3)` must be rejected with an `Error::Domain` whose
/// message states the lower bound on the prime.
#[test]
fn error_domain_variant_is_triggered_by_non_prime_padic_domain() {
    match PadicDomain::new(1, 3) {
        Err(Error::Domain(message)) => {
            assert!(message.contains("p-adic prime must be at least 2"));
        }
        unexpected => panic!("expected Error::Domain, got {unexpected:?}"),
    }
}
/// `Shape::ensure_rank(3)` on the two-dimensional shape `[2, 3]` must fail
/// with the `Error::Shape` variant.
#[test]
fn error_shape_variant_is_triggered_by_shape_rank_mismatch() {
    let two_dims = Shape::from(vec![2, 3]);
    let outcome = two_dims.ensure_rank(3);
    let shape_err = outcome.expect_err("rank mismatch should fail");
    assert!(matches!(shape_err, Error::Shape(_)));
}
/// `Shape::ensure_same` must reject `[2, 3]` vs `[3, 2]` with an `Error::Shape`
/// whose message mentions "shape mismatch".
#[test]
fn error_shape_variant_ensure_same_is_variant_asserted() {
    let (left, right) = (Shape::from(vec![2, 3]), Shape::from(vec![3, 2]));
    let mismatch = left
        .ensure_same(&right)
        .expect_err("shape mismatch should fail");
    let is_expected_shape_error =
        matches!(mismatch, Error::Shape(ref m) if m.contains("shape mismatch"));
    assert!(is_expected_shape_error);
}
/// Passing two inputs to the facade's softmax operator must be rejected by
/// `add_op` (at graph-construction time) with `Error::Operator`.
#[test]
fn error_operator_variant_is_triggered_by_softmax_wrong_input_count() {
    let api = Tokitai::new();
    let mut graph_builder = FacadeGraphBuilder::new();
    let integer = tokitai_operator::domain::DomainId::new("integer");
    let vector_meta = ObjectMeta::tensor(
        integer,
        Shape::from(vec![4]),
        Representation::dense_cpu(),
    );
    let first = graph_builder.input(vector_meta.clone());
    let second = graph_builder.input(vector_meta);

    // Softmax is wired to two inputs here; the builder is expected to refuse it.
    let outcome = graph_builder.add_op(api.softmax(), &[first, second]);
    let op_err = outcome.expect_err("softmax on 2 inputs should fail");
    assert!(matches!(op_err, Error::Operator(_)));
}
/// `Error::ir` must produce the `Error::Ir` variant carrying the given message
/// unchanged.
#[test]
fn error_ir_variant_constructor_round_trips() {
    const MESSAGE: &str = "ir violation example";
    let built: Error = Error::ir(MESSAGE);
    assert!(matches!(built, Error::Ir(ref m) if m == MESSAGE));
}
/// Building and planning an integer `DivOp` graph succeeds, but executing it
/// with an all-zero divisor must fail with an `Error::Domain` whose message
/// mentions "div by zero".
#[test]
fn error_domain_variant_is_triggered_by_div_by_zero() {
    let api = Tokitai::new();
    let mut builder = FacadeGraphBuilder::new();
    let integer = tokitai_operator::domain::DomainId::new("integer");
    let pair_shape = || Shape::from(vec![2]);

    // Graph: quotient = numerator / denominator, both length-2 integer tensors.
    let meta = ObjectMeta::tensor(integer.clone(), pair_shape(), Representation::dense_cpu());
    let numerator = builder.input(meta.clone());
    let denominator = builder.input(meta);
    let quotient = builder.add_op(DivOp, &[numerator, denominator]).unwrap()[0];

    let graph = builder.build();
    let plan = api.plan(&graph).unwrap();

    let mut feeds = TensorStore::new();
    feeds.insert(
        numerator,
        Tensor::dense_cpu(integer.clone(), pair_shape(), vec![10, 20]),
    );
    // Every divisor element is zero, so execution is expected to fail.
    feeds.insert(
        denominator,
        Tensor::dense_cpu(integer, pair_shape(), vec![0, 0]),
    );

    let err = api
        .execute_i64(&graph, &plan, feeds, &[quotient])
        .expect_err("div by zero should fail");
    assert!(matches!(err, Error::Domain(ref m) if m.contains("div by zero")));
}
/// `FiniteSite::validate_cover` must reject a cover whose target (`"not_in_site"`)
/// is not among the site's listed opens, with `Error::Verification`.
#[test]
fn error_verification_variant_is_triggered_by_cover_target_not_in_site() {
    let only_open = OpenId("a".into());
    let site = FiniteSite::new(vec![only_open], vec![]);
    let stray_cover = Cover::new("not_in_site", vec!["a"]);

    let result = site.validate_cover(&stray_cover);
    assert!(matches!(
        result.expect_err("cover target not in site should fail"),
        Error::Verification(_)
    ));
}
/// Gluing a freshly constructed `SectionTable<i64>` over a cover with no member
/// opens must fail with `Error::Verification`.
#[test]
fn error_verification_variant_is_triggered_by_empty_cover_glue() {
    use tokitai_operator::object::sheaf::SectionTable;

    let site = FiniteSite::new(vec![OpenId("a".into())], vec![]);
    let empty_cover = Cover::new("a", Vec::<&str>::new());
    let sections: SectionTable<i64> = SectionTable::new();

    let glue_err = sections
        .glue_from_cover(&site, &empty_cover)
        .expect_err("empty cover glue should fail");
    assert!(matches!(glue_err, Error::Verification(_)));
}
/// Softmax over a rank-0 input (empty shape) passes graph construction and
/// planning, but execution must fail with an `Error::Shape` whose message
/// mentions "softmax: rank-0 input".
#[test]
fn p441_error_shape_variant_is_triggered_by_softmax_rank0_input() {
    use tokitai_operator::op::SoftmaxOp;

    let api = Tokitai::new();
    let mut builder = FacadeGraphBuilder::new();
    let integer = tokitai_operator::domain::DomainId::new("integer");

    let scalar_meta = ObjectMeta::tensor(
        integer.clone(),
        Shape::from(vec![]),
        Representation::dense_cpu(),
    );
    let scalar = builder.input(scalar_meta);
    let normalized = builder.add_op(SoftmaxOp, &[scalar]).unwrap()[0];

    let graph = builder.build();
    let plan = api.plan(&graph).unwrap();

    let mut feeds = TensorStore::new();
    feeds.insert(
        scalar,
        Tensor::dense_cpu(integer, Shape::from(vec![]), vec![42]),
    );

    let err = api
        .execute_i64(&graph, &plan, feeds, &[normalized])
        .expect_err("softmax on rank-0 should fail");
    assert!(matches!(err, Error::Shape(ref m) if m.contains("softmax: rank-0 input")));
}