use tokitai_operator::backend::TensorStore;
use tokitai_operator::domain::{Padic, PadicDomain};
use tokitai_operator::facade::Tokitai;
use tokitai_operator::object::{ObjectMeta, Representation, Shape, Tensor};
use tokitai_operator::planner::PlanStepKind;
/// Builds tensor metadata for a dense CPU tensor of the given `shape` whose
/// domain id is `"integer"`.
fn integer_meta(shape: Vec<usize>) -> ObjectMeta {
    let domain = tokitai_operator::domain::DomainId::new("integer");
    let dims = Shape::from(shape);
    ObjectMeta::tensor(domain, dims, Representation::dense_cpu())
}
/// Builds tensor metadata for a dense CPU tensor of the given `shape` whose
/// domain id is the string `Q_{prime}[{precision}]` (e.g. `Q_5[3]`).
fn padic_tensor_meta(prime: u64, precision: u8, shape: Vec<usize>) -> ObjectMeta {
    let domain_name = format!("Q_{prime}[{precision}]");
    let domain = tokitai_operator::domain::DomainId::new(domain_name);
    let dims = Shape::from(shape);
    ObjectMeta::tensor(domain, dims, Representation::dense_cpu())
}
/// Plans a three-input integer `fma` graph and checks that:
/// - it yields exactly one `Single` step named `fma`, lowered by `cpu.fma.dense`;
/// - the plan has no blocking obligations and no "no lowering for operator fma"
///   obligation;
/// - executing it computes `a * b + c` element-wise.
#[test]
fn fma_planner_lowers_integer_three_input_graph_to_single_step() {
    let api = Tokitai::cpu_only();
    let mut builder = tokitai_operator::facade::FacadeGraphBuilder::new();
    let a = builder.input(integer_meta(vec![3]));
    let b = builder.input(integer_meta(vec![3]));
    let c = builder.input(integer_meta(vec![3]));
    let fma_outputs = builder
        .add_op(api.fma(), &[a, b, c])
        .expect("fma infer should accept 3 same-shape integer inputs");
    let out = fma_outputs[0];
    let graph = builder.build();

    // Planning: exactly one single-op fma step, lowered by the dense CPU rule.
    let planned = api.plan_public(&graph).expect("plan should succeed");
    let fma_steps: Vec<_> = planned
        .plan
        .steps
        .iter()
        .filter(|step| step.op_name == "fma" && step.kind == PlanStepKind::Single)
        .collect();
    assert_eq!(fma_steps.len(), 1);
    let rule_id = fma_steps[0]
        .lowering_rule_id
        .as_ref()
        .expect("fma integer step must be lowered by cpu.fma.dense");
    assert_eq!(rule_id, "cpu.fma.dense");
    assert!(!planned.plan.has_blocking_obligations());
    let reports_missing_lowering = planned
        .plan
        .obligations
        .iter()
        .any(|obligation| obligation.condition.contains("no lowering for operator fma"));
    assert!(!reports_missing_lowering);

    // Execution: [2*5+1, 3*6+1, 4*7+1] == [11, 19, 29].
    let int_domain = tokitai_operator::domain::DomainId::new("integer");
    let mut store = TensorStore::<i64>::new();
    for (id, values) in [(a, vec![2, 3, 4]), (b, vec![5, 6, 7]), (c, vec![1, 1, 1])] {
        store.insert(
            id,
            Tensor::dense_cpu(int_domain.clone(), Shape::from(vec![3]), values),
        );
    }
    let (executed, _report) = api
        .execute_i64(&graph, &planned.plan, store, &[out])
        .expect("fma integer execution should succeed");
    let result = executed.get(out).expect("fma output must be present");
    assert_eq!(result.data, vec![11, 19, 29]);
}
/// Executes a 2x2 integer `fma` graph and checks that the output equals
/// `a * b + c` element-wise. The expected values are computed directly in Rust,
/// independently of the planner.
#[test]
fn fma_integer_matches_explicit_mul_then_add_in_canonical_oracle() {
    let api = Tokitai::cpu_only();
    let mut builder = tokitai_operator::facade::FacadeGraphBuilder::new();
    let a = builder.input(integer_meta(vec![2, 2]));
    let b = builder.input(integer_meta(vec![2, 2]));
    let c = builder.input(integer_meta(vec![2, 2]));
    let fma_out = builder.add_op(api.fma(), &[a, b, c]).unwrap()[0];
    let graph = builder.build();
    let plan = api.plan_public(&graph).unwrap();

    let a_data: Vec<i64> = vec![1, 2, 3, 4];
    let b_data: Vec<i64> = vec![5, 6, 7, 8];
    let c_data: Vec<i64> = vec![10, 20, 30, 40];

    // Reference result, computed before the inputs are moved into the store.
    let oracle: Vec<i64> = a_data
        .iter()
        .zip(&b_data)
        .zip(&c_data)
        .map(|((x, y), z)| x * y + z)
        .collect();

    let int_domain = tokitai_operator::domain::DomainId::new("integer");
    let mut store = TensorStore::<i64>::new();
    for (id, values) in [(a, a_data), (b, b_data), (c, c_data)] {
        store.insert(
            id,
            Tensor::dense_cpu(int_domain.clone(), Shape::from(vec![2, 2]), values),
        );
    }
    let (executed, _report) = api
        .execute_i64(&graph, &plan.plan, store, &[fma_out])
        .unwrap();
    assert_eq!(executed.get(fma_out).unwrap().data, oracle);
}
/// Checks that adding an `fma` op fails when the inputs disagree in shape, and
/// that the error's `Display` text mentions "shape".
#[test]
fn fma_planner_rejects_inconsistent_shapes() {
    let api = Tokitai::cpu_only();
    let mut builder = tokitai_operator::facade::FacadeGraphBuilder::new();
    // `b` has length 2, while `a` and `c` have length 3.
    let a = builder.input(integer_meta(vec![3]));
    let b = builder.input(integer_meta(vec![2]));
    let c = builder.input(integer_meta(vec![3]));
    let message = builder
        .add_op(api.fma(), &[a, b, c])
        .expect_err("fma should reject mismatched shapes")
        .to_string();
    assert!(
        message.contains("shape"),
        "expected shape error, got: {message}"
    );
}
/// Plans and executes an `fma` graph over the `Q_5[3]` domain. It checks that:
/// - the planner emits exactly one `Single` step named `fma`, lowered by
///   `cpu.fma.padic_dense`;
/// - running that plan through `CpuScalarBackend::execute_padic` produces the
///   same residues as `PadicDomain::mul` followed by `PadicDomain::add`.
#[test]
fn fma_padic_lowering_executes_modulo_p_power() {
    let q5 = PadicDomain::new(5, 3).expect("5 must be a valid prime");
    let api = Tokitai::cpu_only();
    let mut builder = tokitai_operator::facade::FacadeGraphBuilder::new();
    let a = builder.input(padic_tensor_meta(5, 3, vec![2]));
    let b = builder.input(padic_tensor_meta(5, 3, vec![2]));
    let c = builder.input(padic_tensor_meta(5, 3, vec![2]));
    let out = builder.add_op(api.fma(), &[a, b, c]).unwrap()[0];
    let graph = builder.build();
    let plan = api.plan_public(&graph).unwrap();

    let fma_steps: Vec<_> = plan
        .plan
        .steps
        .iter()
        .filter(|step| step.kind == PlanStepKind::Single && step.op_name == "fma")
        .collect();
    assert_eq!(fma_steps.len(), 1);
    assert_eq!(
        fma_steps[0].lowering_rule_id.as_deref(),
        Some("cpu.fma.padic_dense")
    );

    // Single source of truth for the tensor domain. It must name the same
    // domain that `padic_tensor_meta(5, 3, ..)` declares for the graph inputs.
    let padic_domain = tokitai_operator::domain::DomainId::new("Q_5[3]");
    let to_tensor = |values: Vec<Padic>| {
        Tensor::dense_cpu(padic_domain.clone(), Shape::from(vec![2]), values)
    };
    let mut store: TensorStore<Padic> = TensorStore::new();
    store.insert(a, to_tensor(vec![q5.element(2), q5.element(3)]));
    store.insert(b, to_tensor(vec![q5.element(4), q5.element(5)]));
    store.insert(c, to_tensor(vec![q5.element(1), q5.element(1)]));

    tokitai_operator::backend::cpu::CpuScalarBackend
        .execute_padic(&graph, &plan.plan, &mut store)
        .expect("fma p-adic execution should succeed");
    let result = store.get(out).expect("fma p-adic output must be present");

    // Reference semantics: an explicit mul followed by an add in the same
    // domain, applied to the same raw inputs that were stored above.
    let mul_then_add_residue = |x, y, z| {
        let product = q5.mul(&q5.element(x), &q5.element(y)).unwrap();
        q5.add(&product, &q5.element(z)).unwrap().residue
    };
    let expected_residues = [mul_then_add_residue(2, 4, 1), mul_then_add_residue(3, 5, 1)];

    let actual: Vec<u128> = result.data.iter().map(|p| p.residue).collect();
    assert_eq!(actual, expected_residues);
}