tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Per-backend plan cache.
//!
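//! `PlanCache` is the per-(graph-fingerprint, backend) cache that
//! memoizes `ExecutionPlan`s. Entries are keyed by [`PlanCacheKey`],
//! which records the graph's semantic and shape descriptions, the
//! backend name and capabilities, the domain/contract constraints and
//! the lowering rules applied, together with stable fingerprints of
//! these. Any change to the graph or to the backend capabilities
//! produces a different key, so plans cached under the old key are not
//! reused.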
//!
use crate::backend::BackendCapabilities;
use crate::domain::ContractSet;
use crate::ir::SemanticGraph;
use crate::object::{Dim, ObjectMeta};

/// Key under which an execution plan is memoized.
///
/// Two keys are equal only when every field below is equal (derived
/// `PartialEq`/`Eq`/`Hash`), so the key changes whenever the graph
/// description, the backend, or the applied lowering rules change.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PlanCacheKey {
    /// `stable_fingerprint` of the `semantic` and `shape_class` strings.
    pub graph_digest: String,
    /// Per-node `op(inputs)->outputs` descriptions joined by `|`, in node
    /// order (or a fixed descriptor for specialized kernels).
    pub semantic: String,
    /// Per-output shape classes (dims joined by `,`) joined by `|`.
    pub shape_class: String,
    /// Sorted, deduplicated domain names of the graph's outputs.
    pub domains: Vec<String>,
    /// Sorted, deduplicated representation ids of the graph's outputs.
    pub representations: Vec<String>,
    /// Sorted, deduplicated `scope:claim` strings of the nodes' contracts.
    pub contracts: Vec<String>,
    /// Backend name.
    pub backend: String,
    /// Fingerprint of the backend name together with `backend_capabilities`.
    pub backend_fingerprint: String,
    /// Sorted `key=value` capability strings of the backend.
    pub backend_capabilities: Vec<String>,
    /// Fingerprint of the `domain=` and `contract=` constraint strings.
    pub mathematical_constraint_fingerprint: String,
    /// Fingerprint of `backend_capabilities` alone (without the backend name).
    pub capability_version: String,
    /// Sorted, deduplicated ids of the lowering rules applied.
    pub lowering_rule_ids: Vec<String>,
    /// Sorted, deduplicated fingerprints of the lowering rules' contracts.
    pub lowering_contract_fingerprints: Vec<String>,
}

impl PlanCacheKey {
    /// Builds the key for `graph` on `backend` with no lowering rules recorded.
    ///
    /// Equivalent to [`PlanCacheKey::from_graph_with_lowerings`] with an
    /// empty rule list.
    pub fn from_graph(graph: &SemanticGraph, backend: &BackendCapabilities) -> Self {
        Self::from_graph_with_lowerings(graph, backend, Vec::new())
    }

    /// Builds the key for `graph` on `backend`, recording the applied
    /// lowering rules.
    ///
    /// No per-rule contract fingerprints are available on this path, so every
    /// rule id `id` gets the placeholder fingerprint `legacy:{id}`.
    pub fn from_graph_with_lowerings(
        graph: &SemanticGraph,
        backend: &BackendCapabilities,
        lowering_rule_ids: Vec<String>,
    ) -> Self {
        let lowering_contract_fingerprints = lowering_rule_ids
            .iter()
            .map(|id| format!("legacy:{id}"))
            .collect();
        Self::from_graph_with_lowering_contracts(
            graph,
            backend,
            lowering_rule_ids,
            lowering_contract_fingerprints,
        )
    }

    /// Builds the key for `graph` on `backend` from explicit lowering rule ids
    /// and their contract fingerprints.
    ///
    /// The graph is walked in `graph.nodes()` order. The `semantic` and
    /// `shape_class` strings therefore depend on node order. Domains,
    /// representations, contracts, rule ids and rule fingerprints are sorted
    /// and deduplicated, so their encounter order does not affect the key.
    pub fn from_graph_with_lowering_contracts(
        graph: &SemanticGraph,
        backend: &BackendCapabilities,
        lowering_rule_ids: Vec<String>,
        lowering_contract_fingerprints: Vec<String>,
    ) -> Self {
        let mut domains = Vec::new();
        let mut representations = Vec::new();
        let mut shape_parts = Vec::new();
        let mut contracts = Vec::new();
        let mut semantic_parts = Vec::new();

        for node in graph.nodes() {
            semantic_parts.push(format!(
                "{}({:?})->{:?}",
                node.op_name, node.inputs, node.output_ids
            ));
            for output in &node.outputs {
                collect_meta(output, &mut domains, &mut representations, &mut shape_parts);
            }
            // NOTE(review): required and provided contracts land in one
            // merged set, so the key does not record which role a contract
            // played or which node carried it. Confirm this is intended.
            collect_contracts(&node.required_contracts, &mut contracts);
            collect_contracts(&node.provided_contracts, &mut contracts);
        }

        let domains = Self::sort_dedup(domains);
        let representations = Self::sort_dedup(representations);
        let contracts = Self::sort_dedup(contracts);

        let semantic = semantic_parts.join("|");
        let shape_class = shape_parts.join("|");
        let graph_digest = Self::digest_graph(&semantic, &shape_class);
        let mathematical_constraint_fingerprint =
            Self::constraint_fingerprint(&domains, &contracts);
        let (backend_capabilities, backend_fingerprint, capability_version) =
            Self::describe_backend(backend);

        Self {
            graph_digest,
            semantic,
            shape_class,
            domains,
            representations,
            contracts,
            backend: backend.name.clone(),
            backend_fingerprint,
            backend_capabilities,
            mathematical_constraint_fingerprint,
            capability_version,
            lowering_rule_ids: Self::sort_dedup(lowering_rule_ids),
            lowering_contract_fingerprints: Self::sort_dedup(lowering_contract_fingerprints),
        }
    }

    /// Builds the key for the specialized p-adic "sum of products with
    /// valuation skip" kernel, which has no `SemanticGraph`.
    ///
    /// The kernel is described by a fixed semantic string over the operand
    /// ids, the shape class `specialized:sum_products`, a single `domain`,
    /// the `padic_scalar` representation and no contracts. Backend and
    /// lowering fields are encoded exactly as on the general path.
    pub fn from_specialized_padic_valuation_skip(
        backend: &BackendCapabilities,
        domain: impl Into<String>,
        lhs_id: usize,
        rhs_id: usize,
        output_id: usize,
        lowering_rule_ids: Vec<String>,
        lowering_contract_fingerprints: Vec<String>,
    ) -> Self {
        let domains = vec![domain.into()];
        let semantic =
            format!("padic_sum_products_valuation_skip(lhs={lhs_id},rhs={rhs_id})->{output_id}");
        let shape_class = "specialized:sum_products".to_string();
        let graph_digest = Self::digest_graph(&semantic, &shape_class);
        // Same encoding as the general path: one domain, no contracts.
        let mathematical_constraint_fingerprint = Self::constraint_fingerprint(&domains, &[]);
        let (backend_capabilities, backend_fingerprint, capability_version) =
            Self::describe_backend(backend);

        Self {
            graph_digest,
            semantic,
            shape_class,
            domains,
            representations: vec!["padic_scalar".to_string()],
            contracts: Vec::new(),
            backend: backend.name.clone(),
            backend_fingerprint,
            backend_capabilities,
            mathematical_constraint_fingerprint,
            capability_version,
            lowering_rule_ids: Self::sort_dedup(lowering_rule_ids),
            lowering_contract_fingerprints: Self::sort_dedup(lowering_contract_fingerprints),
        }
    }

    /// Renders `backend` as sorted `key=value` capability strings and
    /// fingerprints them.
    ///
    /// Returns `(backend_capabilities, backend_fingerprint, capability_version)`:
    /// - `backend_fingerprint` hashes `backend=<name>` followed by the
    ///   capabilities.
    /// - `capability_version` hashes the capabilities alone.
    ///
    /// The capability list is sorted but, as before, not deduplicated.
    fn describe_backend(backend: &BackendCapabilities) -> (Vec<String>, String, String) {
        let mut capabilities = vec![
            format!("exact={}", backend.exact),
            format!("deterministic={}", backend.deterministic),
        ];
        capabilities.extend(
            backend
                .supported_representations
                .iter()
                .map(|item| format!("repr={item}")),
        );
        capabilities.extend(
            backend
                .supported_domains
                .iter()
                .map(|item| format!("domain={item}")),
        );
        capabilities.extend(
            backend
                .semantic_degradations
                .iter()
                .map(|item| format!("degradation={item}")),
        );
        capabilities.sort();

        let backend_fingerprint = stable_fingerprint(
            &std::iter::once(format!("backend={}", backend.name))
                .chain(capabilities.iter().cloned())
                .collect::<Vec<_>>(),
        );
        let capability_version = stable_fingerprint(&capabilities);
        (capabilities, backend_fingerprint, capability_version)
    }

    /// Fingerprint of a graph description.
    ///
    /// The fingerprint is taken over the literal tag `"semantic"`, then the
    /// semantic string, then the shape class.
    fn digest_graph(semantic: &str, shape_class: &str) -> String {
        stable_fingerprint(&[
            "semantic".to_string(),
            semantic.to_string(),
            shape_class.to_string(),
        ])
    }

    /// Fingerprint of the mathematical constraints.
    ///
    /// Hashes every domain as `domain=<d>`, followed by every contract as
    /// `contract=<c>`.
    fn constraint_fingerprint(domains: &[String], contracts: &[String]) -> String {
        let parts: Vec<String> = domains
            .iter()
            .map(|domain| format!("domain={domain}"))
            .chain(
                contracts
                    .iter()
                    .map(|contract| format!("contract={contract}")),
            )
            .collect();
        stable_fingerprint(&parts)
    }

    /// Sorts `items` and drops duplicates, so the key ignores insertion order.
    fn sort_dedup(mut items: Vec<String>) -> Vec<String> {
        items.sort();
        items.dedup();
        items
    }
}

/// Hashes `parts` into a stable, process-independent fingerprint string.
///
/// The hash is 64-bit FNV-1a over the bytes of each part, with one extra
/// `0xff` byte mixed in after every part. `0xff` never occurs in UTF-8, so
/// the separator keeps e.g. `["ab"]` and `["a", "b"]` distinct. The result
/// is `"cachefp-"` followed by the hash as 16 lowercase hex digits.
fn stable_fingerprint(parts: &[String]) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    const PART_SEPARATOR: u8 = 0xff;

    /// One FNV-1a step: xor the byte in, then multiply by the FNV prime.
    fn absorb(state: u64, byte: u8) -> u64 {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    }

    let digest = parts.iter().fold(FNV_OFFSET_BASIS, |state, part| {
        let state = part.bytes().fold(state, absorb);
        absorb(state, PART_SEPARATOR)
    });
    format!("cachefp-{digest:016x}")
}

/// Appends one output's metadata to the key-building accumulators.
///
/// Pushes `meta`'s domain name onto `domains` and its representation id
/// onto `representations`. Also pushes its shape class onto `shape_parts`:
/// every dim rendered by `dim_class`, joined with `,`.
fn collect_meta(
    meta: &ObjectMeta,
    domains: &mut Vec<String>,
    representations: &mut Vec<String>,
    shape_parts: &mut Vec<String>,
) {
    let rendered_dims: Vec<String> = meta.shape.dims.iter().map(dim_class).collect();
    let shape = rendered_dims.join(",");

    domains.push(meta.domain.0.clone());
    representations.push(meta.representation.id().0);
    shape_parts.push(shape);
}

/// Appends one `"{scope:?}:{claim:?}"` string per contract in `contracts`
/// to `output`.
///
/// The rendered strings carry no marker of which contract set they came
/// from.
fn collect_contracts(contracts: &ContractSet, output: &mut Vec<String>) {
    for entry in contracts.iter() {
        let rendered = format!("{:?}:{:?}", entry.scope, entry.claim);
        output.push(rendered);
    }
}

/// Renders one dimension as a `kind:payload` string for the shape class.
///
/// The rendering for each variant is:
/// - static dims: `static:<size>`
/// - symbolic dims: `symbolic:<name>`
/// - bounded dims: `bounded:<name>:<min>:<max>`
/// - data-dependent dims: `data:<name>`
fn dim_class(dim: &Dim) -> String {
    match dim {
        Dim::Static(extent) => format!("static:{extent}"),
        Dim::Symbolic(label) => format!("symbolic:{label}"),
        Dim::Bounded {
            name: label,
            min: lower,
            max: upper,
        } => format!("bounded:{label}:{lower}:{upper}"),
        Dim::DataDependent(label) => format!("data:{label}"),
    }
}