use std::collections::{BTreeMap, HashMap};
use tokitai_operator::backend::BackendCapabilities;
use tokitai_operator::domain::{Domain, DomainId, PadicDomain};
use tokitai_operator::ir::SemanticGraph;
use tokitai_operator::object::{ObjectMeta, Representation, Shape};
use tokitai_operator::op::{
AddOp, BackendLoweringRegistry, LoweringRule, MatmulOp, OperatorRegistry, ReduceOp,
};
use tokitai_operator::planner::{HeuristicPlanner, PlanCacheKey, PlanStepKind};
/// A `BackendLoweringRegistry` is usable on its own.
///
/// A rule can be registered (unchecked) and looked up again without any semantic
/// `OperatorRegistry`. Assembling an `OperatorRegistry` from an empty `BTreeMap` plus
/// those lowerings is rejected. The map is presumably the semantic operator table.
#[test]
fn p186_backend_lowering_registry_can_be_used_without_semantic_operator_registry() {
    let dense_cpu_id = || Representation::dense_cpu().id().0;

    let mut backend_lowerings = BackendLoweringRegistry::new();
    backend_lowerings
        .register_lowering_unchecked(LoweringRule::new(
            "gpu.add.experimental",
            "add",
            "gpu_scaffold",
            vec![dense_cpu_id()],
        ))
        .unwrap();

    // Look the rule back up by (operator, backend, representation).
    let resolved = backend_lowerings
        .lowering_for("add", "gpu_scaffold", &dense_cpu_id())
        .unwrap();
    assert_eq!(resolved.id.0, "gpu.add.experimental");

    // `resolved` is no longer used, so the registry can be moved into `from_parts`.
    let assembled = OperatorRegistry::from_parts(BTreeMap::new(), backend_lowerings);
    assert!(assembled.is_err());
}
/// `OperatorRegistry::register_lowering` must refuse a rule whose operator is unknown.
///
/// The error message must name that operator.
#[test]
fn p186_operator_registry_still_guards_unknown_semantic_operators() {
    let orphan_rule = LoweringRule::new(
        "cpu.unknown",
        "unknown_op",
        "cpu_scalar",
        vec![Representation::dense_cpu().id().0],
    );

    let mut registry = OperatorRegistry::new();
    let message = registry
        .register_lowering(orphan_rule)
        .unwrap_err()
        .to_string();

    assert!(message.contains("unknown operator unknown_op"));
}
/// Planning through an `OperatorRegistry` must record the selected lowering rule ids.
///
/// The ids go in the plan cache key. Keys for the same graph planned with and without
/// a registry must therefore differ.
#[test]
fn p186_plan_cache_key_records_selected_lowering_rule_ids() {
    let graph = dense_add_reduce_graph();
    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();

    let plain_plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan(&graph)
        .unwrap();
    let lowered_plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan_with_registry(&graph, &registry)
        .unwrap();

    let plain_key = plain_plan
        .cache_key
        .expect("plan without registry should carry a cache key");
    let lowered_key = lowered_plan
        .cache_key
        .expect("plan with registry should carry a cache key");

    // Without a registry no lowering rule is selected.
    assert!(plain_key.lowering_rule_ids.is_empty());

    // With the CPU builtins, both the add and the reduce are lowered.
    for expected in ["cpu.add.dense", "cpu.reduce.dense"] {
        assert!(
            lowered_key.lowering_rule_ids.contains(&expected.to_string()),
            "lowered cache key is missing lowering rule id `{}`",
            expected
        );
    }

    assert_ne!(plain_key, lowered_key);
}
/// A Q_5 matmul planned with the CPU builtins must use the valuation-stratified lowering.
///
/// The plan cache key must contain that rule id and a contract fingerprint describing its
/// certificate. The emitted `PadicMatmulValuationSkip` step must carry the same rule id.
#[test]
fn p218_plan_cache_key_records_padic_matmul_certificate_identity() {
    const RULE_ID: &str = "cpu.matmul.padic_valuation_stratified";

    let graph = padic_matmul_graph(2, 3, 2);
    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan_with_registry(&graph, &registry)
        .unwrap();

    let key = plan
        .cache_key
        .as_ref()
        .expect("registry-backed plan should carry a cache key");
    assert!(key.lowering_rule_ids.contains(&RULE_ID.to_string()));

    // One fingerprint must mention all of the following:
    // - the certificate name;
    // - the `per_output_min_skipped_valuation_and_dense_oracle` marker;
    // - `runtime-required`.
    assert!(key.lowering_contract_fingerprints.iter().any(|fingerprint| {
        fingerprint.contains("padic-matmul-certificate")
            && fingerprint.contains("per_output_min_skipped_valuation_and_dense_oracle")
            && fingerprint.contains("runtime-required")
    }));

    let step = plan
        .steps
        .iter()
        .find(|step| matches!(step.kind, PlanStepKind::PadicMatmulValuationSkip { .. }))
        .expect("plan should contain a p-adic matmul valuation-skip step");
    assert_eq!(step.lowering_rule_id.as_deref(), Some(RULE_ID));
}
/// The specialized p-adic sum-of-products skip plan must record its lowering identity.
///
/// The cache key must contain the rule id, a `loweringfp-` fingerprint, and a fingerprint
/// naming the prime, the precision and the three numeric arguments. The key's semantic
/// label and domain list are checked as well.
#[test]
fn p250_specialized_padic_valuation_skip_cache_key_records_lowering_identity() {
    let q5 = PadicDomain::new(5, 3).unwrap();
    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan_padic_sum_products_skip_with_registry(&q5, 7, 8, 9, &registry);

    let key = plan
        .cache_key
        .as_ref()
        .expect("specialized p-adic plan should carry a cache key");

    assert!(
        key.lowering_rule_ids
            .contains(&"cpu.padic_sum_products_valuation_skip".to_string())
    );
    assert!(
        key.lowering_contract_fingerprints
            .iter()
            .any(|fingerprint| fingerprint.starts_with("loweringfp-"))
    );

    // The arguments 7, 8, 9 are expected to appear as `lhs=`, `rhs=`, `out=`.
    assert!(key.lowering_contract_fingerprints.iter().any(|fingerprint| {
        fingerprint.contains("padic-valuation-vector")
            && fingerprint.contains("prime=5")
            && fingerprint.contains("precision=3")
            && fingerprint.contains("lhs=7")
            && fingerprint.contains("rhs=8")
            && fingerprint.contains("out=9")
    }));

    assert!(key.semantic.contains("padic_sum_products_valuation_skip"));
    assert_eq!(key.domains, vec!["Q_5[3]".to_string()]);
}
/// A p-adic domain of precision 0 must not admit the valuation-stratified matmul lowering.
///
/// No `PadicMatmulValuationSkip` step may be planned. Instead, an obligation must explain
/// the precision shortfall and name the rejected rule.
#[test]
fn p218_padic_matmul_admission_rejects_unsupported_precision() {
    let q5 = PadicDomain::new(5, 0).unwrap();
    let meta = ObjectMeta::tensor(
        q5.id(),
        Shape::from(vec![1, 1]),
        Representation::dense_cpu(),
    );

    let mut graph = SemanticGraph::new();
    let lhs = graph.add_input(meta.clone());
    let rhs = graph.add_input(meta);
    graph.add_op(MatmulOp, &[lhs, rhs]).unwrap();

    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan_with_registry(&graph, &registry)
        .unwrap();

    assert!(
        !plan
            .steps
            .iter()
            .any(|step| matches!(step.kind, PlanStepKind::PadicMatmulValuationSkip { .. }))
    );
    assert!(plan.obligations.iter().any(|obligation| {
        obligation
            .condition
            .contains("p-adic precision 0 is below required minimum 1")
            && obligation
                .condition
                .contains("cpu.matmul.padic_valuation_stratified")
    }));
}
/// Changing only the backend capabilities must invalidate the plan cache key.
///
/// Here the change is one extra semantic degradation entry. The graph digest and the
/// mathematical-constraint fingerprint must stay equal, while the backend fingerprint and
/// the capability version must differ.
#[test]
fn p186_plan_cache_key_invalidates_on_backend_capability_change() {
    let graph = dense_add_reduce_graph();
    let key_for = |backend: BackendCapabilities| {
        HeuristicPlanner::new(backend)
            .plan(&graph)
            .unwrap()
            .cache_key
            .unwrap()
    };

    let mut changed_backend = BackendCapabilities::cpu_scalar();
    changed_backend
        .semantic_degradations
        .push("test:changed-capability-version".to_string());

    let base_key = key_for(BackendCapabilities::cpu_scalar());
    let changed_key = key_for(changed_backend);

    // Graph-derived parts are unaffected by the backend.
    assert_eq!(base_key.graph_digest, changed_key.graph_digest);
    assert_eq!(
        base_key.mathematical_constraint_fingerprint,
        changed_key.mathematical_constraint_fingerprint
    );

    // Backend-derived parts must reflect the change.
    assert_ne!(base_key.backend_fingerprint, changed_key.backend_fingerprint);
    assert_ne!(base_key.capability_version, changed_key.capability_version);
}
/// Planning an integer graph and a Q_5 graph on the same backend must yield different keys.
///
/// Both the domain lists and the mathematical-constraint fingerprints must differ.
#[test]
fn p186_plan_cache_key_invalidates_on_mathematical_constraint_change() {
    fn cache_key_of(graph: &SemanticGraph) -> PlanCacheKey {
        HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
            .plan(graph)
            .unwrap()
            .cache_key
            .unwrap()
    }

    let integer_key = cache_key_of(&dense_add_reduce_graph());
    let padic_key = cache_key_of(&padic_add_graph());

    assert_ne!(integer_key.domains, padic_key.domains);
    assert_ne!(
        integer_key.mathematical_constraint_fingerprint,
        padic_key.mathematical_constraint_fingerprint
    );
}
/// `PlanCacheKey` works as a `HashMap` key.
///
/// The planner's own key hits a cached entry. A key built for the same graph and backend,
/// but with a synthetic lowering rule id, misses.
#[test]
fn p186_plan_cache_key_supports_explicit_cache_hit_and_miss_paths() {
    let hit_key = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan(&dense_add_reduce_graph())
        .unwrap()
        .cache_key
        .unwrap();

    // Same graph and backend, different lowering identity: the key must not match.
    let miss_key = PlanCacheKey::from_graph_with_lowerings(
        &dense_add_reduce_graph(),
        &BackendCapabilities::cpu_scalar(),
        vec!["synthetic.lowering.rule".to_string()],
    );

    let mut plan_cache = HashMap::new();
    plan_cache.insert(hit_key.clone(), "cached-plan");

    assert_eq!(plan_cache.get(&hit_key).copied(), Some("cached-plan"));
    assert!(plan_cache.get(&miss_key).is_none());
}
/// Builds `reduce(add(lhs, rhs))` over two inputs.
///
/// Each input is a dense-CPU tensor of shape `[4]` in the `"integer"` domain.
fn dense_add_reduce_graph() -> SemanticGraph {
    let operand_meta = ObjectMeta::tensor(
        DomainId::new("integer"),
        Shape::from(vec![4]),
        Representation::dense_cpu(),
    );

    let mut graph = SemanticGraph::new();
    let operands = [
        graph.add_input(operand_meta.clone()),
        graph.add_input(operand_meta),
    ];
    let sum = graph.add_op(AddOp, &operands).unwrap()[0];
    graph.add_op(ReduceOp, &[sum]).unwrap();
    graph
}
/// Builds `add(lhs, rhs)` over two inputs.
///
/// Each input is a dense-CPU tensor of shape `[4]` in the p-adic domain created by
/// `PadicDomain::new(5, 3)`.
fn padic_add_graph() -> SemanticGraph {
    let domain = PadicDomain::new(5, 3).unwrap();
    let operand_meta = ObjectMeta::tensor(
        domain.id(),
        Shape::from(vec![4]),
        Representation::dense_cpu(),
    );

    let mut graph = SemanticGraph::new();
    let operands = [
        graph.add_input(operand_meta.clone()),
        graph.add_input(operand_meta),
    ];
    graph.add_op(AddOp, &operands).unwrap();
    graph
}
/// Builds `matmul(lhs, rhs)` with `lhs` of shape `[m, k]` and `rhs` of shape `[k, n]`.
///
/// Both inputs are dense-CPU tensors in the p-adic domain from `PadicDomain::new(5, 3)`.
fn padic_matmul_graph(m: usize, k: usize, n: usize) -> SemanticGraph {
    let domain = PadicDomain::new(5, 3).unwrap();
    let matrix_meta = |rows: usize, cols: usize| {
        ObjectMeta::tensor(
            domain.id(),
            Shape::from(vec![rows, cols]),
            Representation::dense_cpu(),
        )
    };

    let mut graph = SemanticGraph::new();
    let lhs = graph.add_input(matrix_meta(m, k));
    let rhs = graph.add_input(matrix_meta(k, n));
    graph.add_op(MatmulOp, &[lhs, rhs]).unwrap();
    graph
}