// tokitai-operator 0.1.0
//
// Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops.
// Paper-artifact grade.
// Documentation
use std::collections::{BTreeMap, HashMap};

use tokitai_operator::backend::BackendCapabilities;
use tokitai_operator::domain::{Domain, DomainId, PadicDomain};
use tokitai_operator::ir::SemanticGraph;
use tokitai_operator::object::{ObjectMeta, Representation, Shape};
use tokitai_operator::op::{
    AddOp, BackendLoweringRegistry, LoweringRule, MatmulOp, OperatorRegistry, ReduceOp,
};
use tokitai_operator::planner::{HeuristicPlanner, PlanCacheKey, PlanStepKind};

/// A `BackendLoweringRegistry` works on its own: an experimental GPU `add` rule registered through
/// `register_lowering_unchecked` can be looked up by (operator, backend, representation).
/// Handing that registry to `OperatorRegistry::from_parts` with an empty semantic-operator map,
/// however, must be rejected.
#[test]
fn p186_backend_lowering_registry_can_be_used_without_semantic_operator_registry() {
    let mut lowerings = BackendLoweringRegistry::new();
    let experimental_rule = LoweringRule::new(
        "gpu.add.experimental",
        "add",
        "gpu_scaffold",
        vec![Representation::dense_cpu().id().0],
    );
    lowerings
        .register_lowering_unchecked(experimental_rule)
        .unwrap();

    let dense_repr = Representation::dense_cpu().id().0;
    let found = lowerings
        .lowering_for("add", "gpu_scaffold", &dense_repr)
        .unwrap();
    assert_eq!(found.id.0, "gpu.add.experimental");

    // No semantic operators are supplied, so adopting the lowerings must fail.
    let adopted = OperatorRegistry::from_parts(BTreeMap::new(), lowerings);
    assert!(adopted.is_err());
}

/// Registering a lowering for an operator the `OperatorRegistry` does not know must be rejected,
/// and the error message must name the offending operator.
#[test]
fn p186_operator_registry_still_guards_unknown_semantic_operators() {
    let mut registry = OperatorRegistry::new();
    let err = registry
        .register_lowering(LoweringRule::new(
            "cpu.unknown",
            "unknown_op",
            "cpu_scalar",
            vec![Representation::dense_cpu().id().0],
        ))
        .unwrap_err()
        .to_string();

    // Show the actual message on failure so a wording change is diagnosable from the test log.
    assert!(
        err.contains("unknown operator unknown_op"),
        "unexpected error message: {err}"
    );
}

/// Planning without a registry yields a cache key with no lowering rule ids. Planning with the
/// CPU-scalar builtin registry records the dense add and reduce rules it selected, so the two
/// keys must differ.
#[test]
fn p186_plan_cache_key_records_selected_lowering_rule_ids() {
    let graph = dense_add_reduce_graph();
    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let cpu_planner = || HeuristicPlanner::new(BackendCapabilities::cpu_scalar());

    let plain_plan = cpu_planner().plan(&graph).unwrap();
    let lowered_plan = cpu_planner()
        .plan_with_registry(&graph, &registry)
        .unwrap();

    let plain_key = plain_plan.cache_key.unwrap();
    let lowered_key = lowered_plan.cache_key.unwrap();

    assert!(plain_key.lowering_rule_ids.is_empty());
    for expected_rule in ["cpu.add.dense", "cpu.reduce.dense"] {
        assert!(lowered_key
            .lowering_rule_ids
            .contains(&expected_rule.to_string()));
    }
    assert_ne!(plain_key, lowered_key);
}

/// Planning a `Q_5` (precision 3) matmul with the builtin registry must select the
/// valuation-stratified p-adic matmul lowering. The cache key must carry that rule id and a
/// contract fingerprint describing the certificate. The `PadicMatmulValuationSkip` plan step
/// must point back at the same rule.
#[test]
fn p218_plan_cache_key_records_padic_matmul_certificate_identity() {
    const RULE_ID: &str = "cpu.matmul.padic_valuation_stratified";

    let graph = padic_matmul_graph(2, 3, 2);
    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan_with_registry(&graph, &registry)
        .unwrap();
    let key = plan.cache_key.as_ref().unwrap();

    assert!(key.lowering_rule_ids.contains(&RULE_ID.to_string()));

    // Some single fingerprint must mention every one of these markers.
    let certificate_markers = [
        "padic-matmul-certificate",
        "per_output_min_skipped_valuation_and_dense_oracle",
        "runtime-required",
    ];
    assert!(key
        .lowering_contract_fingerprints
        .iter()
        .any(|fingerprint| certificate_markers
            .iter()
            .all(|&marker| fingerprint.contains(marker))));

    let skip_step = plan
        .steps
        .iter()
        .find(|candidate| {
            matches!(
                candidate.kind,
                PlanStepKind::PadicMatmulValuationSkip { .. }
            )
        })
        .unwrap();
    assert_eq!(skip_step.lowering_rule_id.as_deref(), Some(RULE_ID));
}

/// Plans through `plan_padic_sum_products_skip_with_registry` over `Q_5` at precision 3 with size
/// arguments 7, 8, 9. The resulting cache key must contain:
/// - the lowering rule id;
/// - a `loweringfp-`-prefixed contract fingerprint;
/// - a fingerprint mentioning the prime, the precision and `lhs=7` / `rhs=8` / `out=9`;
/// - the semantic tag;
/// - the single domain `Q_5[3]`.
#[test]
fn p250_specialized_padic_valuation_skip_cache_key_records_lowering_identity() {
    let q5 = PadicDomain::new(5, 3).unwrap();
    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let planner = HeuristicPlanner::new(BackendCapabilities::cpu_scalar());
    let plan = planner.plan_padic_sum_products_skip_with_registry(&q5, 7, 8, 9, &registry);

    let key = plan.cache_key.as_ref().unwrap();
    let fingerprints = &key.lowering_contract_fingerprints;

    assert!(key
        .lowering_rule_ids
        .contains(&"cpu.padic_sum_products_valuation_skip".to_string()));
    assert!(fingerprints
        .iter()
        .any(|fingerprint| fingerprint.starts_with("loweringfp-")));

    // Some single fingerprint must mention every one of these markers.
    let vector_markers = [
        "padic-valuation-vector",
        "prime=5",
        "precision=3",
        "lhs=7",
        "rhs=8",
        "out=9",
    ];
    assert!(fingerprints.iter().any(|fingerprint| vector_markers
        .iter()
        .all(|&marker| fingerprint.contains(marker))));

    assert!(key.semantic.contains("padic_sum_products_valuation_skip"));
    assert_eq!(key.domains, vec!["Q_5[3]".to_string()]);
}

/// A p-adic matmul over a domain with precision 0 must not be lowered to the valuation-skip
/// kernel. Instead the plan must carry an obligation that names both the precision violation and
/// the rejected lowering rule.
#[test]
fn p218_padic_matmul_admission_rejects_unsupported_precision() {
    let q5 = PadicDomain::new(5, 0).unwrap();
    let meta = ObjectMeta::tensor(
        q5.id(),
        Shape::from(vec![1, 1]),
        Representation::dense_cpu(),
    );
    let mut graph = SemanticGraph::new();
    let lhs = graph.add_input(meta.clone());
    let rhs = graph.add_input(meta);
    graph.add_op(MatmulOp, &[lhs, rhs]).unwrap();

    let registry = OperatorRegistry::cpu_scalar_builtins().unwrap();
    let plan = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan_with_registry(&graph, &registry)
        .unwrap();

    assert!(
        !plan
            .steps
            .iter()
            .any(|step| matches!(step.kind, PlanStepKind::PadicMatmulValuationSkip { .. })),
        "precision-0 matmul must not be lowered to the valuation-skip kernel"
    );

    // Collect the conditions up front so a failed search prints what was actually emitted.
    let conditions: Vec<String> = plan
        .obligations
        .iter()
        .map(|obligation| obligation.condition.to_string())
        .collect();
    assert!(
        conditions.iter().any(|condition| {
            condition.contains("p-adic precision 0 is below required minimum 1")
                && condition.contains("cpu.matmul.padic_valuation_stratified")
        }),
        "missing precision-admission obligation for cpu.matmul.padic_valuation_stratified; \
         obligations were: {conditions:?}"
    );
}

/// Changing only the backend capabilities (here by appending a semantic degradation) must change
/// the backend fingerprint and capability version. The graph digest and the mathematical
/// constraint fingerprint must stay the same.
#[test]
fn p186_plan_cache_key_invalidates_on_backend_capability_change() {
    let graph = dense_add_reduce_graph();
    let cache_key_with = |backend: BackendCapabilities| {
        HeuristicPlanner::new(backend)
            .plan(&graph)
            .unwrap()
            .cache_key
            .unwrap()
    };

    let mut degraded = BackendCapabilities::cpu_scalar();
    degraded
        .semantic_degradations
        .push("test:changed-capability-version".to_string());

    let baseline = cache_key_with(BackendCapabilities::cpu_scalar());
    let altered = cache_key_with(degraded);

    // Graph-side identity is unaffected by the backend change...
    assert_eq!(baseline.graph_digest, altered.graph_digest);
    assert_eq!(
        baseline.mathematical_constraint_fingerprint,
        altered.mathematical_constraint_fingerprint
    );
    // ...while the backend-side identity must move.
    assert_ne!(baseline.backend_fingerprint, altered.backend_fingerprint);
    assert_ne!(baseline.capability_version, altered.capability_version);
}

/// A change of mathematical domain must change the cache key's domain list and its mathematical
/// constraint fingerprint.
///
/// The original pairing (integer add+reduce vs p-adic add) differs in both domain and graph
/// structure, so on its own it cannot show that the domain is what moved the fingerprint. A
/// controlled pairing therefore compares graphs with the same shape: one dense add over `[4]`
/// that differs only in the element domain.
#[test]
fn p186_plan_cache_key_invalidates_on_mathematical_constraint_change() {
    /// Integer counterpart of `padic_add_graph`: the same single dense add over shape `[4]`.
    fn integer_add_graph() -> SemanticGraph {
        let meta = ObjectMeta::tensor(
            DomainId::new("integer"),
            Shape::from(vec![4]),
            Representation::dense_cpu(),
        );
        let mut graph = SemanticGraph::new();
        let lhs = graph.add_input(meta.clone());
        let rhs = graph.add_input(meta);
        graph.add_op(AddOp, &[lhs, rhs]).unwrap();
        graph
    }

    let cache_key_for = |graph: &SemanticGraph| {
        HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
            .plan(graph)
            .unwrap()
            .cache_key
            .unwrap()
    };

    // Original pairing: differs in domain AND structure.
    let integer_key = cache_key_for(&dense_add_reduce_graph());
    let padic_key = cache_key_for(&padic_add_graph());
    assert_ne!(integer_key.domains, padic_key.domains);
    assert_ne!(
        integer_key.mathematical_constraint_fingerprint,
        padic_key.mathematical_constraint_fingerprint
    );

    // Controlled pairing: identical structure, only the domain changes.
    let integer_add_key = cache_key_for(&integer_add_graph());
    assert_ne!(integer_add_key.domains, padic_key.domains);
    assert_ne!(
        integer_add_key.mathematical_constraint_fingerprint,
        padic_key.mathematical_constraint_fingerprint,
        "a domain change alone must invalidate the mathematical constraint fingerprint"
    );
}

/// `PlanCacheKey` values are usable as `HashMap` keys. The key produced by planning hits its own
/// entry. A key built for the same graph and backend but with an extra synthetic lowering rule
/// id misses.
#[test]
fn p186_plan_cache_key_supports_explicit_cache_hit_and_miss_paths() {
    let hit_key = HeuristicPlanner::new(BackendCapabilities::cpu_scalar())
        .plan(&dense_add_reduce_graph())
        .unwrap()
        .cache_key
        .unwrap();
    let synthetic_rules = vec!["synthetic.lowering.rule".to_string()];
    let miss_key = PlanCacheKey::from_graph_with_lowerings(
        &dense_add_reduce_graph(),
        &BackendCapabilities::cpu_scalar(),
        synthetic_rules,
    );

    let cache = HashMap::from([(hit_key.clone(), "cached-plan")]);

    assert_eq!(cache.get(&hit_key), Some(&"cached-plan"));
    assert_eq!(cache.get(&miss_key), None);
}

/// Builds a graph with:
/// - two dense-CPU `integer` tensor inputs of shape `[4]`;
/// - an `AddOp` over both inputs;
/// - a `ReduceOp` applied to the add's first output.
fn dense_add_reduce_graph() -> SemanticGraph {
    let input_meta = ObjectMeta::tensor(
        DomainId::new("integer"),
        Shape::from(vec![4]),
        Representation::dense_cpu(),
    );
    let mut graph = SemanticGraph::new();
    let operands = [
        graph.add_input(input_meta.clone()),
        graph.add_input(input_meta),
    ];
    let sum = graph.add_op(AddOp, &operands).unwrap()[0];
    graph.add_op(ReduceOp, &[sum]).unwrap();
    graph
}

/// Builds a graph with two dense-CPU tensor inputs of shape `[4]` over `PadicDomain::new(5, 3)`
/// and a single `AddOp` over them.
fn padic_add_graph() -> SemanticGraph {
    let domain = PadicDomain::new(5, 3).unwrap();
    let input_meta = ObjectMeta::tensor(
        domain.id(),
        Shape::from(vec![4]),
        Representation::dense_cpu(),
    );
    let mut graph = SemanticGraph::new();
    let operands = [
        graph.add_input(input_meta.clone()),
        graph.add_input(input_meta),
    ];
    graph.add_op(AddOp, &operands).unwrap();
    graph
}

/// Builds a graph containing one `MatmulOp` over a dense-CPU `[m, k]` left operand and a
/// dense-CPU `[k, n]` right operand. Both operands live in `PadicDomain::new(5, 3)`.
fn padic_matmul_graph(m: usize, k: usize, n: usize) -> SemanticGraph {
    let q5 = PadicDomain::new(5, 3).unwrap();
    let dense_matrix = |rows: usize, cols: usize| {
        ObjectMeta::tensor(
            q5.id(),
            Shape::from(vec![rows, cols]),
            Representation::dense_cpu(),
        )
    };
    let (lhs_meta, rhs_meta) = (dense_matrix(m, k), dense_matrix(k, n));

    let mut graph = SemanticGraph::new();
    let lhs = graph.add_input(lhs_meta);
    let rhs = graph.add_input(rhs_meta);
    graph.add_op(MatmulOp, &[lhs, rhs]).unwrap();
    graph
}