use crate::backend::BackendCapabilities;
use crate::domain::ContractSet;
use crate::ir::SemanticGraph;
use crate::object::{Dim, ObjectMeta};
/// Cache key assembled from a graph description, a backend description, and
/// the lowering rules that were applied.
///
/// All fields are plain strings, so equality and hashing are structural. The
/// list-valued fields are sorted by the constructors in this file, which makes
/// the key independent of the order in which entries were collected.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PlanCacheKey {
/// Fingerprint over the literal tag `"semantic"`, [`Self::semantic`] and
/// [`Self::shape_class`].
pub graph_digest: String,
/// Textual description of the computation: for graph-derived keys, one
/// `op(inputs)->output_ids` entry per node joined with `|`.
pub semantic: String,
/// Shape description: for graph-derived keys, one comma-joined list of
/// dimension classes per node output, joined with `|`.
pub shape_class: String,
/// Domain names of the outputs, sorted and deduplicated for graph-derived keys.
pub domains: Vec<String>,
/// Representation ids of the outputs, sorted and deduplicated for
/// graph-derived keys.
pub representations: Vec<String>,
/// `scope:claim` debug renderings of the required and provided contracts,
/// sorted and deduplicated.
pub contracts: Vec<String>,
/// Name of the backend the key was built for.
pub backend: String,
/// Fingerprint over `backend=<name>` followed by [`Self::backend_capabilities`].
pub backend_fingerprint: String,
/// Sorted `exact=`, `deterministic=`, `repr=`, `domain=` and `degradation=`
/// entries describing the backend.
pub backend_capabilities: Vec<String>,
/// Fingerprint over `domain=<..>` entries followed by `contract=<..>` entries.
pub mathematical_constraint_fingerprint: String,
/// Fingerprint over [`Self::backend_capabilities`] alone (the backend name is
/// not included).
pub capability_version: String,
/// Lowering rule ids, sorted and deduplicated.
pub lowering_rule_ids: Vec<String>,
/// Lowering contract fingerprints, sorted and deduplicated; entries derived
/// from bare rule ids have the form `legacy:<id>`.
pub lowering_contract_fingerprints: Vec<String>,
}
impl PlanCacheKey {
    /// Builds the key for `graph` on `backend` with no lowering rules recorded.
    ///
    /// Equivalent to [`PlanCacheKey::from_graph_with_lowerings`] with an empty
    /// rule-id list.
    pub fn from_graph(graph: &SemanticGraph, backend: &BackendCapabilities) -> Self {
        Self::from_graph_with_lowerings(graph, backend, Vec::new())
    }

    /// Builds the key for `graph` on `backend`, recording `lowering_rule_ids`.
    ///
    /// No contract fingerprints are supplied on this path, so one placeholder
    /// of the form `legacy:<id>` is derived per rule id.
    pub fn from_graph_with_lowerings(
        graph: &SemanticGraph,
        backend: &BackendCapabilities,
        lowering_rule_ids: Vec<String>,
    ) -> Self {
        let lowering_contract_fingerprints = lowering_rule_ids
            .iter()
            .map(|id| format!("legacy:{id}"))
            .collect::<Vec<_>>();
        Self::from_graph_with_lowering_contracts(
            graph,
            backend,
            lowering_rule_ids,
            lowering_contract_fingerprints,
        )
    }

    /// Builds the key for `graph` on `backend` with explicit lowering rule ids
    /// and lowering contract fingerprints.
    ///
    /// Every node contributes an `op(inputs)->output_ids` entry to `semantic`,
    /// the metadata of each of its outputs, and both its required and provided
    /// contracts. Domains, representations, contracts, rule ids and contract
    /// fingerprints are sorted and deduplicated; `semantic` and `shape_class`
    /// keep node order.
    pub fn from_graph_with_lowering_contracts(
        graph: &SemanticGraph,
        backend: &BackendCapabilities,
        mut lowering_rule_ids: Vec<String>,
        mut lowering_contract_fingerprints: Vec<String>,
    ) -> Self {
        let mut domains = Vec::new();
        let mut representations = Vec::new();
        let mut shape_parts = Vec::new();
        let mut contracts = Vec::new();
        let mut semantic_parts = Vec::new();
        for node in graph.nodes() {
            semantic_parts.push(format!(
                "{}({:?})->{:?}",
                node.op_name, node.inputs, node.output_ids
            ));
            for output in &node.outputs {
                collect_meta(output, &mut domains, &mut representations, &mut shape_parts);
            }
            collect_contracts(&node.required_contracts, &mut contracts);
            collect_contracts(&node.provided_contracts, &mut contracts);
        }
        Self::sort_dedup(&mut domains);
        Self::sort_dedup(&mut representations);
        Self::sort_dedup(&mut contracts);
        Self::sort_dedup(&mut lowering_rule_ids);
        Self::sort_dedup(&mut lowering_contract_fingerprints);

        let semantic = semantic_parts.join("|");
        let shape_class = shape_parts.join("|");
        let backend_capabilities = Self::capability_entries(backend);
        let graph_digest = Self::digest_graph(&semantic, &shape_class);
        let backend_fingerprint = Self::fingerprint_backend(backend, &backend_capabilities);
        let mathematical_constraint_fingerprint =
            Self::fingerprint_constraints(&domains, &contracts);
        let capability_version = stable_fingerprint(&backend_capabilities);
        Self {
            graph_digest,
            semantic,
            shape_class,
            domains,
            representations,
            contracts,
            backend: backend.name.clone(),
            backend_fingerprint,
            backend_capabilities,
            mathematical_constraint_fingerprint,
            capability_version,
            lowering_rule_ids,
            lowering_contract_fingerprints,
        }
    }

    /// Builds the key for the specialized p-adic "sum of products with
    /// valuation skip" plan without going through a [`SemanticGraph`].
    ///
    /// The semantic string is
    /// `padic_sum_products_valuation_skip(lhs=<lhs_id>,rhs=<rhs_id>)-><output_id>`,
    /// the shape class is the fixed tag `specialized:sum_products`, the only
    /// representation is `padic_scalar`, and no contracts are recorded. Backend
    /// and lowering fields are derived exactly as for graph-based keys.
    pub fn from_specialized_padic_valuation_skip(
        backend: &BackendCapabilities,
        domain: impl Into<String>,
        lhs_id: usize,
        rhs_id: usize,
        output_id: usize,
        mut lowering_rule_ids: Vec<String>,
        mut lowering_contract_fingerprints: Vec<String>,
    ) -> Self {
        let domains = vec![domain.into()];
        let semantic =
            format!("padic_sum_products_valuation_skip(lhs={lhs_id},rhs={rhs_id})->{output_id}");
        let shape_class = "specialized:sum_products".to_string();
        Self::sort_dedup(&mut lowering_rule_ids);
        Self::sort_dedup(&mut lowering_contract_fingerprints);

        let backend_capabilities = Self::capability_entries(backend);
        let graph_digest = Self::digest_graph(&semantic, &shape_class);
        let backend_fingerprint = Self::fingerprint_backend(backend, &backend_capabilities);
        // With no contracts this hashes just the single `domain=<domain>` entry.
        let mathematical_constraint_fingerprint = Self::fingerprint_constraints(&domains, &[]);
        let capability_version = stable_fingerprint(&backend_capabilities);
        Self {
            graph_digest,
            semantic,
            shape_class,
            domains,
            representations: vec!["padic_scalar".to_string()],
            contracts: Vec::new(),
            backend: backend.name.clone(),
            backend_fingerprint,
            backend_capabilities,
            mathematical_constraint_fingerprint,
            capability_version,
            lowering_rule_ids,
            lowering_contract_fingerprints,
        }
    }

    /// Sorts `items` and removes duplicates so the key ignores collection order.
    fn sort_dedup(items: &mut Vec<String>) {
        items.sort();
        items.dedup();
    }

    /// Renders the backend description as sorted `key=value` entries.
    fn capability_entries(backend: &BackendCapabilities) -> Vec<String> {
        let mut entries = vec![
            format!("exact={}", backend.exact),
            format!("deterministic={}", backend.deterministic),
        ];
        entries.extend(
            backend
                .supported_representations
                .iter()
                .map(|item| format!("repr={item}")),
        );
        entries.extend(
            backend
                .supported_domains
                .iter()
                .map(|item| format!("domain={item}")),
        );
        entries.extend(
            backend
                .semantic_degradations
                .iter()
                .map(|item| format!("degradation={item}")),
        );
        // NOTE(review): sorted but deliberately not deduplicated, unlike the
        // other list fields — kept as-is so existing key values do not change;
        // confirm whether duplicates in the backend description can occur.
        entries.sort();
        entries
    }

    /// Fingerprints the `"semantic"` tag together with the semantic and shape strings.
    fn digest_graph(semantic: &str, shape_class: &str) -> String {
        stable_fingerprint(&[
            "semantic".to_string(),
            semantic.to_owned(),
            shape_class.to_owned(),
        ])
    }

    /// Fingerprints `backend=<name>` followed by the capability entries.
    fn fingerprint_backend(backend: &BackendCapabilities, capabilities: &[String]) -> String {
        let parts = std::iter::once(format!("backend={}", backend.name))
            .chain(capabilities.iter().cloned())
            .collect::<Vec<_>>();
        stable_fingerprint(&parts)
    }

    /// Fingerprints all `domain=<..>` entries followed by all `contract=<..>` entries.
    fn fingerprint_constraints(domains: &[String], contracts: &[String]) -> String {
        let parts = domains
            .iter()
            .map(|domain| format!("domain={domain}"))
            .chain(
                contracts
                    .iter()
                    .map(|contract| format!("contract={contract}")),
            )
            .collect::<Vec<_>>();
        stable_fingerprint(&parts)
    }
}
/// Hashes `parts` into a `cachefp-` string followed by 16 lowercase hex digits.
///
/// This is a 64-bit FNV-1a fold over the bytes of every part. After each part
/// an extra `0xff` byte is mixed in as a terminator, so `["ab"]` and
/// `["a", "b"]` produce different fingerprints. An empty slice yields the
/// untouched offset basis.
fn stable_fingerprint(parts: &[String]) -> String {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = parts
        .iter()
        .flat_map(|part| part.bytes().chain(std::iter::once(0xff_u8)))
        .fold(OFFSET_BASIS, |state, octet| {
            (state ^ u64::from(octet)).wrapping_mul(PRIME)
        });
    format!("cachefp-{digest:016x}")
}
fn collect_meta(
meta: &ObjectMeta,
domains: &mut Vec<String>,
representations: &mut Vec<String>,
shape_parts: &mut Vec<String>,
) {
domains.push(meta.domain.0.clone());
representations.push(meta.representation.id().0);
shape_parts.push(
meta.shape
.dims
.iter()
.map(dim_class)
.collect::<Vec<_>>()
.join(","),
);
}
/// Appends one `scope:claim` entry per contract in `contracts` to `output`.
///
/// Both halves use the `Debug` rendering of the respective field.
fn collect_contracts(contracts: &ContractSet, output: &mut Vec<String>) {
    for entry in contracts.iter() {
        let rendered = format!("{:?}:{:?}", entry.scope, entry.claim);
        output.push(rendered);
    }
}
/// Renders a dimension as `<tag>:<detail>` for use in the shape class.
///
/// Tags are `static`, `symbolic`, `bounded` (detail is `name:min:max`) and
/// `data` for data-dependent dimensions.
fn dim_class(dim: &Dim) -> String {
    let (tag, detail) = match dim {
        Dim::Static(extent) => ("static", extent.to_string()),
        Dim::Symbolic(symbol) => ("symbolic", symbol.to_string()),
        Dim::Bounded { name, min, max } => ("bounded", format!("{name}:{min}:{max}")),
        Dim::DataDependent(source) => ("data", source.to_string()),
    };
    format!("{tag}:{detail}")
}