pub mod cache;
pub mod cost_model;
pub mod obligation;
pub mod plan;
pub mod tuning;
use crate::Result;
use crate::backend::BackendCapabilities;
use crate::domain::{Condition, ContractId};
use crate::domain::{Domain, PadicDomain};
use crate::ir::RewriteJustification;
use crate::ir::SemanticGraph;
use crate::object::Dim;
use crate::object::Representation;
use crate::object::sheaf::Cover;
use crate::op::LayerBehavior;
use crate::op::{
DeviceRequirement, LayoutRequirement, LocalityRequirement, LoweringCapability,
LoweringDomainRequirement, LoweringRule, OperatorRegistry, PrecisionRequirement,
UnsupportedBackendReason, ValuationRequirement,
};
use crate::theory::CertifiedTheoryRequirement;
pub use cache::PlanCacheKey;
pub use cost_model::{
BackendCostModel, CostModelWeights, CostReason, PlanCandidate, PlanSelection,
};
pub use obligation::{DischargeStatus, Obligation, ObligationSeverity, ObligationSource};
pub use plan::{
EvidenceLog, ExecutionPlan, Fallback, FallbackCost, PadicPlanningResource, PlanCost,
PlanResourceProfile, PlanStep, PlanStepKind, PrecisionCost, PrecisionSchedule, ProofKind,
ProofObject, ProofStatus, SemanticCost, ValuationCost, VerificationCost,
};
pub use tuning::{BenchmarkRecord, TuningRecord, TuningTraceSummary};
/// Planner that builds [`ExecutionPlan`]s for a single backend.
///
/// The planner carries the capabilities of the backend it targets together
/// with the optimization switches consulted while planning.
#[derive(Debug, Clone)]
pub struct HeuristicPlanner {
/// Capabilities of the backend; its `name` is stamped on every plan and step.
pub backend: BackendCapabilities,
/// Switches controlling the optional optimizations applied during planning.
pub optimization_policy: OptimizationPolicy,
}
/// Toggles for the optional optimizations the planner may apply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizationPolicy {
/// When `true`, `HeuristicPlanner::plan` appends a fused step for every
/// pointwise fusion group with at least two nodes; when `false` it only
/// records that fusion was disabled.
pub enable_pointwise_fusion: bool,
/// When `true`, `HeuristicPlanner::plan_with_registry` attaches p-adic matmul
/// valuation-skip steps; when `false` it only records that this lowering
/// was disabled.
pub enable_padic_matmul_valuation_skip: bool,
}
impl Default for OptimizationPolicy {
fn default() -> Self {
Self {
enable_pointwise_fusion: true,
enable_padic_matmul_valuation_skip: true,
}
}
}
impl HeuristicPlanner {
/// Creates a planner for `backend` with the default [`OptimizationPolicy`].
pub fn new(backend: BackendCapabilities) -> Self {
    let optimization_policy = OptimizationPolicy::default();
    Self {
        backend,
        optimization_policy,
    }
}
pub fn with_optimization_policy(
backend: BackendCapabilities,
optimization_policy: OptimizationPolicy,
) -> Self {
Self {
backend,
optimization_policy,
}
}
/// Creates a planner for `backend` whose optimization policy is the one
/// reported by `tuning_record.optimization_policy()`.
pub fn with_tuning_record(backend: BackendCapabilities, tuning_record: &TuningRecord) -> Self {
    Self {
        optimization_policy: tuning_record.optimization_policy(),
        backend,
    }
}
/// Builds a registry-independent execution plan for `graph`.
///
/// For every node this records one `Required`/`Unresolved` obligation per
/// required contract, one `PlanStepKind::Single` step, and evidence entries,
/// then delegates to `attach_shape_proof_obligations` and
/// `attach_padic_matmul_valuation_cost`. Afterwards, if pointwise fusion is
/// enabled, a rewrite justification, its soundness proof, and an additional
/// `PlanStepKind::Fused` step are appended for each fusion group of at least
/// two nodes; the per-node `Single` steps are kept as well.
///
/// The body contains no fallible call, so the visible code always returns `Ok`.
pub fn plan(&self, graph: &SemanticGraph) -> Result<ExecutionPlan> {
let mut plan = ExecutionPlan::new(self.backend.name.clone());
plan.cache_key = Some(PlanCacheKey::from_graph(graph, &self.backend));
for node in graph.nodes() {
// Every contract the node requires becomes an unresolved, required obligation.
let obligations = node.required_contracts.iter().map(|contract| Obligation {
source: ObligationSource::Operator(node.op_name.clone()),
condition: format!("requires {:?}", contract.claim),
severity: ObligationSeverity::Required,
status: DischargeStatus::Unresolved,
});
plan.obligations.extend(obligations);
// Domain and representation are taken from the first output; nodes without
// outputs are labelled "unknown".
plan.steps.push(PlanStep {
node_id: node.id,
op_name: node.op_name.clone(),
backend: self.backend.name.clone(),
domain: node
.outputs
.first()
.map(|meta| meta.domain.0.clone())
.unwrap_or_else(|| "unknown".to_string()),
representation: node
.outputs
.first()
.map(|meta| meta.representation.id().0)
.unwrap_or_else(|| "unknown".to_string()),
lowering_rule_id: None,
kind: PlanStepKind::Single,
});
plan.evidence_log.entries.push(format!(
"node {} ({}) planned without semantic rewrite",
node.id, node.op_name
));
// `mul` nodes whose first output domain id starts with "Q_" get an extra
// evidence entry; indexing `outputs[0]` is safe because `first()` matched.
if node.op_name == "mul"
&& node
.outputs
.first()
.is_some_and(|meta| meta.domain.0.starts_with("Q_"))
{
plan.evidence_log.entries.push(format!(
"node {} (mul) is valuation-filterable for p-adic domain {}",
node.id, node.outputs[0].domain.0
));
}
if node.layer_behavior.contains(&LayerBehavior::CoverLocal) {
plan.evidence_log.entries.push(format!(
"node {} ({}) is cover-local and eligible for local-to-global planning",
node.id, node.op_name
));
}
attach_shape_proof_obligations(&mut plan, graph, node);
attach_padic_matmul_valuation_cost(&mut plan, graph, node);
}
// Groups are computed even when fusion is disabled so the disabled case can
// report that a fusion opportunity was skipped.
let fusion_groups = pointwise_fusion_groups(graph);
if self.optimization_policy.enable_pointwise_fusion {
for group in fusion_groups {
if group.len() < 2 {
continue;
}
// NOTE(review): group entries are used as indices into `graph.nodes()`;
// an out-of-range entry would panic here — confirm that
// `pointwise_fusion_groups` only yields valid indices.
let producer = &graph.nodes()[group[0]];
let rule = "pointwise_metadata_preserving_fusion".to_string();
plan.evidence_log.entries.push(format!(
"fusion candidate: nodes {:?} justified by pointwise layer behavior and matching metadata",
group
));
let rewrite = RewriteJustification {
source: ContractId(200),
rule: rule.clone(),
discharged_conditions: vec![
Condition::new("all fused nodes are pointwise"),
Condition::new("each fused node consumes the previous fused output"),
Condition::new(
"domain, shape, and representation match across fused outputs",
),
],
certified_requirements: Vec::new(),
};
attach_rewrite_soundness_proof(&mut plan, &rewrite);
plan.rewrites.push(rewrite);
// The fused step reuses the first node's id, domain, and representation;
// its name is "fused:" followed by the member op names joined with '+'.
plan.steps.push(PlanStep {
node_id: producer.id,
op_name: format!(
"fused:{}",
group
.iter()
.map(|id| graph.nodes()[*id].op_name.as_str())
.collect::<Vec<_>>()
.join("+")
),
backend: self.backend.name.clone(),
domain: producer
.outputs
.first()
.map(|meta| meta.domain.0.clone())
.unwrap_or_else(|| "unknown".to_string()),
representation: producer
.outputs
.first()
.map(|meta| meta.representation.id().0)
.unwrap_or_else(|| "unknown".to_string()),
lowering_rule_id: None,
kind: PlanStepKind::Fused {
node_ids: group,
rule,
},
});
}
} else if fusion_groups.iter().any(|group| group.len() >= 2) {
plan.evidence_log
.entries
.push("optimization policy disabled pointwise fusion lowering".to_string());
}
Ok(plan)
}
/// Plans `graph` with the optimization policy taken from `tuning_record`.
///
/// A temporary planner sharing this planner's backend produces the plan via
/// `plan`; the tuning record's selection summary and, if one exists, its
/// fastest strategy are then appended to the evidence log.
///
/// # Errors
///
/// Propagates any error returned by `plan`.
pub fn plan_with_tuning_record(
    &self,
    graph: &SemanticGraph,
    tuning_record: &TuningRecord,
) -> Result<ExecutionPlan> {
    let backend = self.backend.clone();
    let policy = tuning_record.optimization_policy();
    let tuned_planner = Self::with_optimization_policy(backend, policy);
    let mut plan = tuned_planner.plan(graph)?;

    let entries = &mut plan.evidence_log.entries;
    entries.push(tuning_record.selection_summary());
    entries.extend(tuning_record.fastest_strategy().map(|strategy| {
        format!(
            "benchmark-backed strategy {} selected for backend {}",
            strategy, tuning_record.backend
        )
    }));
    Ok(plan)
}
/// Plans `graph` against `registry` with the optimization policy taken from
/// `tuning_record`.
///
/// A temporary planner sharing this planner's backend produces the plan via
/// `plan_with_registry`; the tuning record's selection summary and, if one
/// exists, its fastest strategy are then appended to the evidence log.
///
/// # Errors
///
/// Propagates any error returned by `plan_with_registry`.
pub fn plan_with_registry_and_tuning_record(
    &self,
    graph: &SemanticGraph,
    registry: &OperatorRegistry,
    tuning_record: &TuningRecord,
) -> Result<ExecutionPlan> {
    let backend = self.backend.clone();
    let policy = tuning_record.optimization_policy();
    let tuned_planner = Self::with_optimization_policy(backend, policy);
    let mut plan = tuned_planner.plan_with_registry(graph, registry)?;

    let entries = &mut plan.evidence_log.entries;
    entries.push(tuning_record.selection_summary());
    entries.extend(tuning_record.fastest_strategy().map(|strategy| {
        format!(
            "benchmark-backed strategy {} selected for backend {}",
            strategy, tuning_record.backend
        )
    }));
    Ok(plan)
}
/// Builds the plan produced by `plan` and then resolves backend lowerings
/// from `registry`.
///
/// Three passes run over the base plan:
/// 1. Per graph node: an unregistered operator or a missing lowering adds a
///    `Required`/`Unresolved` obligation; a lowering whose capabilities are all
///    satisfied is recorded on the node's `Single` step; a lowering with
///    unsupported capabilities is reported via `attach_unsupported_lowering`.
/// 2. Per fused step: the same lookup is done using the fused step's own
///    `op_name` and representation.
/// 3. p-adic matmul valuation-skip steps are attached when the optimization
///    policy enables them; otherwise an evidence entry records that they are
///    disabled.
///
/// Finally the cache key is recomputed so that it covers the chosen lowerings.
///
/// # Errors
///
/// Propagates any error returned by `plan`.
pub fn plan_with_registry(
&self,
graph: &SemanticGraph,
registry: &OperatorRegistry,
) -> Result<ExecutionPlan> {
let mut plan = self.plan(graph)?;
for node in graph.nodes() {
// Same "first output or unknown" convention as the steps created by `plan`.
let representation = node
.outputs
.first()
.map(|meta| meta.representation.id().0)
.unwrap_or_else(|| "unknown".to_string());
let domain = node
.outputs
.first()
.map(|meta| meta.domain.0.as_str())
.unwrap_or("unknown");
// An unregistered operator gets its own obligation and no lowering lookup.
if registry.operator(&node.op_name).is_none() {
plan.obligations.push(Obligation {
source: ObligationSource::Planner("operator_registry".to_string()),
condition: format!(
"operator {} must be registered before backend lowering",
node.op_name
),
severity: ObligationSeverity::Required,
status: DischargeStatus::Unresolved,
});
continue;
}
match registry.lowering_for_domain(
&node.op_name,
&self.backend.name,
&representation,
domain,
) {
Some(rule) => {
let unsupported = unsupported_lowering_reasons(
plan.resources.padic.as_ref(),
&self.backend,
rule,
domain,
&representation,
);
if unsupported.is_empty() {
// Only the node's `Single` step is annotated; fused steps sharing the
// same node id are handled in the second pass below.
if let Some(step) = plan.steps.iter_mut().find(|step| {
step.node_id == node.id && step.kind == PlanStepKind::Single
}) {
step.lowering_rule_id = Some(rule.id.0.clone());
}
plan.evidence_log.entries.push(format!(
"node {} ({}) lowered by registry rule {} for backend {}",
node.id, node.op_name, rule.id.0, self.backend.name
));
attach_lowering_metadata(&mut plan, rule);
} else {
attach_unsupported_lowering(
&mut plan,
&self.backend.name,
&format!("node {} ({})", node.id, node.op_name),
rule,
unsupported,
);
}
}
None => {
plan.obligations.push(Obligation {
source: ObligationSource::Backend(self.backend.name.clone()),
condition: format!(
"no lowering for operator {} on backend {} with representation {}",
node.op_name, self.backend.name, representation
),
severity: ObligationSeverity::Required,
status: DischargeStatus::Unresolved,
});
}
}
}
// `plan.steps` is mutably borrowed while fused steps are visited, so evidence
// and obligations are buffered here and applied to the plan after the loop.
let mut fused_metadata = Vec::new();
let mut fused_missing_obligations = Vec::new();
for step in &mut plan.steps {
if !matches!(step.kind, PlanStepKind::Fused { .. }) {
continue;
}
// The fused step's domain is re-read from the graph node whose index equals
// the step's `node_id`; "unknown" is used when that lookup fails.
let domain = graph
.nodes()
.get(step.node_id)
.and_then(|node| node.outputs.first())
.map(|meta| meta.domain.0.as_str())
.unwrap_or("unknown");
match registry.lowering_for_domain(
&step.op_name,
&self.backend.name,
&step.representation,
domain,
) {
Some(rule) => {
let unsupported = unsupported_lowering_reasons(
plan.resources.padic.as_ref(),
&self.backend,
rule,
domain,
&step.representation,
);
if unsupported.is_empty() {
step.lowering_rule_id = Some(rule.id.0.clone());
fused_metadata.push((
rule.clone(),
format!(
"fused step {} ({}) lowered by registry rule {} for backend {}",
step.node_id, step.op_name, rule.id.0, self.backend.name
),
));
} else {
// Unlike the per-node pass, a rejected fused lowering only adds an
// obligation; `attach_unsupported_lowering` is not called here.
fused_missing_obligations.push(Obligation {
source: ObligationSource::Backend(self.backend.name.clone()),
condition: format!(
"lowering rule {} rejected for fused operator {}: {}",
rule.id.0,
step.op_name,
unsupported
.iter()
.map(UnsupportedBackendReason::message)
.collect::<Vec<_>>()
.join("; ")
),
severity: ObligationSeverity::Required,
status: DischargeStatus::Unresolved,
});
}
}
None => {
fused_missing_obligations.push(Obligation {
source: ObligationSource::Backend(self.backend.name.clone()),
condition: format!(
"no lowering for fused operator {} on backend {} with representation {}",
step.op_name, self.backend.name, step.representation
),
severity: ObligationSeverity::Required,
status: DischargeStatus::Unresolved,
});
}
}
}
for obligation in fused_missing_obligations {
plan.obligations.push(obligation);
}
for (rule, evidence) in fused_metadata {
plan.evidence_log.entries.push(evidence);
attach_lowering_metadata(&mut plan, &rule);
}
if self.optimization_policy.enable_padic_matmul_valuation_skip {
attach_padic_matmul_valuation_skip_steps(&mut plan, graph, registry, &self.backend);
} else {
plan.evidence_log.entries.push(
"optimization policy disabled p-adic matmul valuation-skip lowering".to_string(),
);
}
// The key set by `plan` predates lowering selection; recompute it last.
refresh_cache_key_with_lowerings(&mut plan, graph, &self.backend, registry);
Ok(plan)
}
/// Builds a standalone plan containing a single
/// `PlanStepKind::PadicValuationSkip` step over `domain`.
///
/// The plan records the p-adic planning resource for the backend, a cache key
/// without lowering information, precision/valuation/verification costs, the
/// rewrite justification with its soundness proof, evidence entries, and a
/// runtime-checked `ProofKind::ValuationCutoff` proof object.
///
/// `lhs_id`, `rhs_id`, and `output_id` are copied into the step and the cache
/// key without being interpreted here.
pub fn plan_padic_sum_products_skip(
&self,
domain: &PadicDomain,
lhs_id: usize,
rhs_id: usize,
output_id: usize,
) -> ExecutionPlan {
let mut plan = ExecutionPlan::new(self.backend.name.clone());
// NOTE(review): the literal 4 is passed straight to the estimator; its meaning
// (presumably an assumed term count) is not visible here — confirm.
let valuation_cost = estimate_padic_sum_products_valuation_cost(domain, 4);
attach_padic_planning_resource(&mut plan, &self.backend, domain, &valuation_cost);
// No lowering rule has been chosen yet, hence `None`.
refresh_specialized_padic_valuation_skip_cache_key(
&mut plan,
&self.backend,
domain,
lhs_id,
rhs_id,
output_id,
None,
);
// Input and output digits are both the domain precision, so the recorded
// precision loss is zero.
plan.cost.precision_loss = Some(0);
plan.cost.semantic.precision = Some(PrecisionCost {
input_digits: domain.meta.precision,
output_digits: domain.meta.precision,
precision_loss: 0,
});
plan.cost.semantic.valuation = Some(valuation_cost.clone());
plan.cost.semantic.verification = Some(VerificationCost {
required_checks: 2,
estimated_overhead_ns: None,
});
// The specialized plan has no graph node; the step uses node id 0.
plan.steps.push(PlanStep {
node_id: 0,
op_name: "padic_sum_products_valuation_skip".to_string(),
backend: self.backend.name.clone(),
domain: domain.id().0,
representation: Representation::PadicScalar.id().0,
lowering_rule_id: None,
kind: PlanStepKind::PadicValuationSkip {
lhs_id,
rhs_id,
output_id,
prime: domain.meta.prime,
precision: domain.meta.precision,
},
});
let rewrite = RewriteJustification {
source: ContractId(300),
rule: "padic_sum_products_valuation_skip".to_string(),
discharged_conditions: vec![
Condition::new("terms with valuation >= precision vanish modulo p^N"),
Condition::new("all values share the same p-adic domain"),
],
certified_requirements: vec![CertifiedTheoryRequirement::padic_valuation_skip()],
};
attach_rewrite_soundness_proof(&mut plan, &rewrite);
plan.rewrites.push(rewrite);
plan.evidence_log.entries.push(format!(
"p-adic valuation skip planned for sum-products over p={} precision={}",
domain.meta.prime, domain.meta.precision
));
// Echo the resource recorded by `attach_padic_planning_resource`, including
// the fallback reason when the backend lacks the capability.
if let Some(resource) = &plan.resources.padic {
plan.evidence_log.entries.push(format!(
"p-adic planner resource: prime={}, precision={}, cutoff={}, equality_digits={}, backend_capability={}, fallback={}",
resource.prime,
resource.precision,
resource.valuation_cutoff,
resource.equality_digits,
resource.backend_capability,
resource
.fallback_reason
.as_deref()
.unwrap_or("none")
));
}
plan.evidence_log
.entries
.push(format!("cost rationale: {}", valuation_cost.rationale));
plan.proof_objects.push(
ProofObject::new(
"proof:padic_sum_products_valuation_skip:0",
ProofKind::ValuationCutoff,
format!(
"terms with valuation >= precision {} vanish modulo {}^{}",
domain.meta.precision, domain.meta.prime, domain.meta.precision
),
ProofStatus::RuntimeChecked,
)
.with_obligation("all inputs share the same fixed-precision p-adic domain")
.with_obligation("runtime valuation checks use the active precision cutoff")
.with_evidence(format!(
"valuation cutoff={}, estimated_terms={}, estimated_skipped_terms={}",
valuation_cost.cutoff,
valuation_cost.estimated_terms,
valuation_cost.estimated_skipped_terms
)),
);
plan
}
/// Builds the specialized p-adic valuation-skip plan and resolves a lowering
/// for its single step from `registry`.
///
/// * No matching lowering: a `Required`/`Unresolved` backend obligation is added.
/// * Lowering with unsupported capabilities: reported through
///   `attach_unsupported_lowering`.
/// * Supported lowering: the step's `lowering_rule_id` is set, evidence and
///   lowering metadata are attached, and the cache key is refreshed with the
///   rule and the `runtime_valuation_vector` input policy.
pub fn plan_padic_sum_products_skip_with_registry(
    &self,
    domain: &PadicDomain,
    lhs_id: usize,
    rhs_id: usize,
    output_id: usize,
    registry: &OperatorRegistry,
) -> ExecutionPlan {
    let mut plan = self.plan_padic_sum_products_skip(domain, lhs_id, rhs_id, output_id);
    if plan.steps.is_empty() {
        return plan;
    }

    // Copy what is needed from the first step so `plan` can be mutated below.
    let (node_id, op_name, representation, step_domain) = {
        let head = &plan.steps[0];
        (
            head.node_id,
            head.op_name.clone(),
            head.representation.clone(),
            head.domain.clone(),
        )
    };
    let backend_name = &self.backend.name;

    let Some(rule) =
        registry.lowering_for_domain(&op_name, backend_name, &representation, &step_domain)
    else {
        plan.obligations.push(Obligation {
            source: ObligationSource::Backend(backend_name.clone()),
            condition: format!(
                "no lowering for specialized operator {} on backend {} with representation {}",
                op_name, backend_name, representation
            ),
            severity: ObligationSeverity::Required,
            status: DischargeStatus::Unresolved,
        });
        return plan;
    };

    let rejections = unsupported_lowering_reasons(
        plan.resources.padic.as_ref(),
        &self.backend,
        rule,
        &step_domain,
        &representation,
    );
    if !rejections.is_empty() {
        attach_unsupported_lowering(
            &mut plan,
            backend_name,
            &format!("specialized step {} ({})", node_id, op_name),
            rule,
            rejections,
        );
        return plan;
    }

    if let Some(head) = plan.steps.first_mut() {
        head.lowering_rule_id = Some(rule.id.0.clone());
    }
    plan.evidence_log.entries.push(format!(
        "specialized step {} ({}) lowered by registry rule {} for backend {}",
        node_id, op_name, rule.id.0, backend_name
    ));
    attach_lowering_metadata(&mut plan, rule);
    refresh_specialized_padic_valuation_skip_cache_key(
        &mut plan,
        &self.backend,
        domain,
        lhs_id,
        rhs_id,
        output_id,
        Some((rule, "runtime_valuation_vector")),
    );
    plan
}
/// Builds a standalone plan containing a single `PlanStepKind::CoverGlueCheck`
/// step for `cover`.
///
/// The step uses node id 0, the domain string `cover:<target>`, and a
/// cover-indexed-section representation on the CPU device. The
/// local-to-global rewrite justification, its soundness proof, and one
/// evidence entry are attached as well.
pub fn plan_cover_glue_check(&self, cover: &Cover) -> ExecutionPlan {
    let backend_name = &self.backend.name;
    let mut plan = ExecutionPlan::new(backend_name.clone());

    let section = Representation::CoverIndexedSection {
        device: crate::object::DeviceId::cpu(),
    };
    let glue_check = PlanStepKind::CoverGlueCheck {
        target: cover.target.0.clone(),
        opens: cover.opens.iter().map(|open| open.0.clone()).collect(),
    };
    plan.steps.push(PlanStep {
        node_id: 0,
        op_name: "cover_glue_check".to_string(),
        backend: backend_name.clone(),
        domain: format!("cover:{}", cover.target.0),
        representation: section.id().0,
        lowering_rule_id: None,
        kind: glue_check,
    });

    let discharged_conditions = vec![
        Condition::new("operator is cover-local"),
        Condition::new("sections agree on declared overlaps"),
        Condition::new("cover is valid in finite site"),
    ];
    let rewrite = RewriteJustification {
        source: ContractId(400),
        rule: "cover_local_to_global_glue".to_string(),
        discharged_conditions,
        certified_requirements: vec![CertifiedTheoryRequirement::finite_sheaf_gluing()],
    };
    attach_rewrite_soundness_proof(&mut plan, &rewrite);
    plan.rewrites.push(rewrite);

    let open_count = cover.opens.len();
    plan.evidence_log.entries.push(format!(
        "cover glue check planned for target {:?} with {} opens",
        cover.target, open_count
    ));
    plan
}
/// Builds the cover glue-check plan and resolves a lowering for its single
/// step from `registry`.
///
/// * No matching lowering: a `Required`/`Unresolved` backend obligation is added.
/// * Lowering with unsupported capabilities: reported through
///   `attach_unsupported_lowering`.
/// * Supported lowering: the step's `lowering_rule_id` is set and evidence plus
///   lowering metadata are attached.
pub fn plan_cover_glue_check_with_registry(
    &self,
    cover: &Cover,
    registry: &OperatorRegistry,
) -> ExecutionPlan {
    let mut plan = self.plan_cover_glue_check(cover);
    if plan.steps.is_empty() {
        return plan;
    }

    // Copy what is needed from the first step so `plan` can be mutated below.
    let (op_name, representation, domain) = {
        let head = &plan.steps[0];
        (
            head.op_name.clone(),
            head.representation.clone(),
            head.domain.clone(),
        )
    };
    let backend_name = &self.backend.name;

    let Some(rule) = registry.lowering_for_domain(&op_name, backend_name, &representation, &domain)
    else {
        plan.obligations.push(Obligation {
            source: ObligationSource::Backend(backend_name.clone()),
            condition: format!(
                "no lowering for cover glue check on backend {} with representation {}",
                backend_name, representation
            ),
            severity: ObligationSeverity::Required,
            status: DischargeStatus::Unresolved,
        });
        return plan;
    };

    let rejections = unsupported_lowering_reasons(
        plan.resources.padic.as_ref(),
        &self.backend,
        rule,
        &domain,
        &representation,
    );
    if !rejections.is_empty() {
        attach_unsupported_lowering(&mut plan, backend_name, "cover glue check", rule, rejections);
        return plan;
    }

    if let Some(head) = plan.steps.first_mut() {
        head.lowering_rule_id = Some(rule.id.0.clone());
    }
    plan.evidence_log.entries.push(format!(
        "cover glue check lowered by registry rule {} for backend {}",
        rule.id.0, backend_name
    ));
    attach_lowering_metadata(&mut plan, rule);
    plan
}
}
/// Records the p-adic planning resource for `domain` on `plan`.
///
/// The backend is considered capable when `padic:fixed_precision` appears in
/// its `supported_domains`. The resource and a matching evidence entry are
/// always written. Without the capability, a `cpu_scalar` fallback, a fallback
/// cost, and a `Required`/`Unresolved` obligation are added; with it, an
/// already discharged `AuditOnly` obligation is added instead.
fn attach_padic_planning_resource(
    plan: &mut ExecutionPlan,
    backend: &BackendCapabilities,
    domain: &PadicDomain,
    valuation_cost: &ValuationCost,
) {
    let prime = domain.meta.prime;
    let precision = domain.meta.precision;
    let cutoff = valuation_cost.cutoff;

    let advertised = backend
        .supported_domains
        .iter()
        .any(|candidate| candidate == "padic:fixed_precision");
    let (backend_capability, fallback_reason) = if advertised {
        ("padic:fixed_precision".to_string(), None)
    } else {
        (
            "missing:padic:fixed_precision".to_string(),
            Some(format!(
                "backend {} does not advertise padic:fixed_precision capability",
                backend.name
            )),
        )
    };

    // Render the evidence text first so the capability string can be moved
    // into the resource afterwards.
    let requirement = format!(
        "p-adic resource requirement: prime={}, precision={}, valuation_cutoff={}, equality_digits={}, backend_capability={}",
        prime, precision, cutoff, precision, backend_capability
    );
    plan.resources.padic = Some(PadicPlanningResource {
        prime,
        precision,
        valuation_cutoff: cutoff,
        equality_digits: precision,
        backend_capability,
        fallback_reason: fallback_reason.clone(),
    });
    plan.evidence_log.entries.push(requirement);

    match fallback_reason {
        Some(reason) => {
            plan.fallback = Some(Fallback {
                backend: "cpu_scalar".to_string(),
                reason: reason.clone(),
            });
            plan.cost.semantic.fallback = Some(FallbackCost {
                penalty_ns: None,
                reason,
            });
            plan.obligations.push(Obligation {
                source: ObligationSource::Backend(backend.name.clone()),
                condition:
                    "p-adic valuation-skip lowering requires padic:fixed_precision capability"
                        .to_string(),
                severity: ObligationSeverity::Required,
                status: DischargeStatus::Unresolved,
            });
        }
        None => {
            plan.obligations.push(Obligation {
                source: ObligationSource::Planner("padic_resource".to_string()),
                condition: format!(
                    "backend {} advertises padic:fixed_precision for valuation cutoff {} and precision-bounded equality at {} digits",
                    backend.name, cutoff, precision
                ),
                severity: ObligationSeverity::AuditOnly,
                status: DischargeStatus::Discharged(
                    "p-adic planner resource is backend-visible".to_string(),
                ),
            });
        }
    }
}
/// Recomputes `plan.cache_key` so that it reflects the lowerings chosen for
/// the plan's steps.
///
/// The key is built from the ids of all steps carrying a lowering rule, the
/// cache fingerprint of each such rule (or `missing:<id>` when the registry no
/// longer knows the id), and one certificate fingerprint per
/// `PadicMatmulValuationSkip` step, appended after the rule fingerprints.
fn refresh_cache_key_with_lowerings(
    plan: &mut ExecutionPlan,
    graph: &SemanticGraph,
    backend: &BackendCapabilities,
    registry: &OperatorRegistry,
) {
    let mut lowering_rule_ids = Vec::new();
    let mut fingerprints = Vec::new();

    // First pass: rule ids and their fingerprints, in step order.
    for step in &plan.steps {
        let Some(id) = &step.lowering_rule_id else {
            continue;
        };
        let fingerprint = match registry.lowering_by_id(id) {
            Some(rule) => rule.cache_fingerprint(),
            None => format!("missing:{id}"),
        };
        lowering_rule_ids.push(id.clone());
        fingerprints.push(fingerprint);
    }

    // Second pass: certificate fingerprints follow all rule fingerprints.
    for step in &plan.steps {
        if let PlanStepKind::PadicMatmulValuationSkip {
            certificate_policy,
            valuation_bucket_fingerprint,
            ..
        } = &step.kind
        {
            fingerprints.push(format!(
                "padic-matmul-certificate:node={}:policy={}:bucket={}",
                step.node_id, certificate_policy, valuation_bucket_fingerprint
            ));
        }
    }

    let key = PlanCacheKey::from_graph_with_lowering_contracts(
        graph,
        backend,
        lowering_rule_ids,
        fingerprints,
    );
    plan.cache_key = Some(key);
}
/// Recomputes `plan.cache_key` for the specialized p-adic valuation-skip plan.
///
/// With `lowering == None` the key carries no rule ids or fingerprints. With
/// `Some((rule, input_policy))` it carries the rule id, the rule's cache
/// fingerprint, and a valuation-vector fingerprint built from the domain's
/// prime and precision, the three value ids, and `input_policy`.
fn refresh_specialized_padic_valuation_skip_cache_key(
    plan: &mut ExecutionPlan,
    backend: &BackendCapabilities,
    domain: &PadicDomain,
    lhs_id: usize,
    rhs_id: usize,
    output_id: usize,
    lowering: Option<(&LoweringRule, &str)>,
) {
    let (lowering_rule_ids, lowering_contract_fingerprints) = match lowering {
        Some((rule, input_policy)) => {
            let vector_fingerprint = format!(
                "padic-valuation-vector:prime={}:precision={}:lhs={}:rhs={}:out={}:input_policy={}",
                domain.meta.prime, domain.meta.precision, lhs_id, rhs_id, output_id, input_policy
            );
            (
                vec![rule.id.0.clone()],
                vec![rule.cache_fingerprint(), vector_fingerprint],
            )
        }
        None => (Vec::new(), Vec::new()),
    };

    let key = PlanCacheKey::from_specialized_padic_valuation_skip(
        backend,
        domain.id().0,
        lhs_id,
        rhs_id,
        output_id,
        lowering_rule_ids,
        lowering_contract_fingerprints,
    );
    plan.cache_key = Some(key);
}
/// Collects, in capability order, every reason why `rule` cannot be used on
/// `backend` for the given `domain` and `representation`.
///
/// An empty result means all of the rule's capabilities are satisfied.
fn unsupported_lowering_reasons(
    padic_resource: Option<&PadicPlanningResource>,
    backend: &BackendCapabilities,
    rule: &LoweringRule,
    domain: &str,
    representation: &str,
) -> Vec<UnsupportedBackendReason> {
    rule.capabilities
        .iter()
        .flat_map(|capability| {
            unsupported_capability_reasons(
                padic_resource,
                backend,
                capability,
                domain,
                representation,
            )
        })
        .collect()
}
/// Checks one lowering `capability` against `backend` and the planning context
/// and returns every unmet requirement.
///
/// The requirement groups are checked in a fixed order — domain, layout,
/// precision, valuation, locality, device — and reasons are appended in that
/// order. An empty result means the capability is satisfied.
///
/// `domain` is the domain id string of the value being lowered,
/// `representation` its representation id string, and `padic_resource` the
/// plan's p-adic planning resource if one has been recorded.
fn unsupported_capability_reasons(
padic_resource: Option<&PadicPlanningResource>,
backend: &BackendCapabilities,
capability: &LoweringCapability,
domain: &str,
representation: &str,
) -> Vec<UnsupportedBackendReason> {
let mut reasons = Vec::new();
// Domain: a required backend domain must be listed in `supported_domains`.
match &capability.domain {
LoweringDomainRequirement::Any => {}
LoweringDomainRequirement::BackendDomain(required) => {
if !backend
.supported_domains
.iter()
.any(|domain| domain == required)
{
reasons.push(UnsupportedBackendReason::MissingBackendDomain {
required: required.clone(),
backend: backend.name.clone(),
});
}
}
}
// Layout: a required representation must equal the actual representation id.
match &capability.layout {
LayoutRequirement::Any => {}
LayoutRequirement::Representation(required) => {
if representation != required {
reasons.push(UnsupportedBackendReason::UnsupportedRepresentation {
required: required.clone(),
actual: representation.to_string(),
});
}
}
}
// Precision: fixed-precision p-adic rules need a parseable domain id with at
// least `min_digits` digits.
match capability.precision {
PrecisionRequirement::Any | PrecisionRequirement::Exact => {}
PrecisionRequirement::PadicFixedPrecision {
min_digits,
equality_digits_match_precision,
} => match parse_padic_domain_id(domain) {
Some((_prime, precision)) => {
if precision < min_digits {
reasons.push(UnsupportedBackendReason::PadicPrecisionTooLow {
required_min_digits: min_digits,
actual_digits: precision,
});
}
// Only a resource that is present and disagrees with the domain precision
// is reported; when no resource has been recorded nothing is pushed.
if equality_digits_match_precision
&& padic_resource.is_some_and(|resource| {
resource.precision != precision || resource.equality_digits != precision
})
{
reasons.push(UnsupportedBackendReason::MissingPadicResource);
}
}
None => reasons.push(UnsupportedBackendReason::PadicDomainParseFailed {
domain: domain.to_string(),
}),
},
}
// Valuation: the recorded cutoff must equal the domain precision. The check is
// skipped when the domain id does not parse or no resource is recorded.
match capability.valuation {
ValuationRequirement::None => {}
ValuationRequirement::PadicCutoffAtPrecision => {
match (parse_padic_domain_id(domain), padic_resource) {
(Some((_prime, precision)), Some(resource))
if resource.valuation_cutoff == precision => {}
(Some((_prime, precision)), Some(resource)) => {
reasons.push(UnsupportedBackendReason::MissingPadicValuationCutoff {
expected_cutoff: precision,
actual_cutoff: resource.valuation_cutoff,
});
}
(Some(_), None) => {}
(None, _) => {}
}
}
}
// Locality: finite-site rules need the "sheaf:finite_site" backend domain.
match capability.locality {
LocalityRequirement::Global => {}
LocalityRequirement::FiniteSite { .. } => {
if !backend
.supported_domains
.iter()
.any(|domain| domain == "sheaf:finite_site")
{
reasons.push(UnsupportedBackendReason::MissingSheafFiniteSiteCapability {
backend: backend.name.clone(),
});
}
}
}
// Device: ROCm/HIP rules require a set of tagged entries in
// `supported_domains`, one per runtime, target, and fingerprint.
match &capability.device {
DeviceRequirement::Any => {}
DeviceRequirement::RocmHip {
required_gfx,
device_capability_fingerprint,
kernel_source_fingerprint,
compiler_fingerprint,
cpu_oracle_required,
transfer_obligations,
} => {
require_backend_domain(
backend,
"rocm:hip",
"ROCm/HIP runtime availability",
&mut reasons,
);
require_backend_domain(
backend,
&format!("gfx:{required_gfx}"),
"ROCm/HIP gfx target",
&mut reasons,
);
require_backend_domain(
backend,
&format!("device_capability:{device_capability_fingerprint}"),
"ROCm/HIP device capability fingerprint",
&mut reasons,
);
// An empty fingerprint on the rule itself is reported directly, without
// consulting the backend.
if kernel_source_fingerprint.is_empty() {
reasons.push(UnsupportedBackendReason::MissingRocmHipCapability {
required: "HIP kernel source fingerprint".to_string(),
backend: backend.name.clone(),
});
} else {
require_backend_domain(
backend,
&format!("hip_kernel_source:{kernel_source_fingerprint}"),
"HIP kernel source fingerprint",
&mut reasons,
);
}
if compiler_fingerprint.is_empty() {
reasons.push(UnsupportedBackendReason::MissingRocmHipCapability {
required: "hipcc compiler fingerprint".to_string(),
backend: backend.name.clone(),
});
} else {
require_backend_domain(
backend,
&format!("hip_compiler:{compiler_fingerprint}"),
"hipcc compiler fingerprint",
&mut reasons,
);
}
if *cpu_oracle_required {
require_backend_domain(
backend,
"cpu_oracle:required",
"CPU oracle comparison",
&mut reasons,
);
}
// NOTE(review): transfer obligations are looked up in `semantic_degradations`,
// not `supported_domains` like every other check here — confirm this is intended.
for obligation in transfer_obligations {
if !backend
.semantic_degradations
.iter()
.any(|item| item == &format!("transfer_obligation:{obligation}"))
{
reasons.push(UnsupportedBackendReason::MissingRocmHipCapability {
required: format!("transfer obligation {obligation}"),
backend: backend.name.clone(),
});
}
}
}
}
reasons
}
/// Appends a `MissingRocmHipCapability` reason to `reasons` unless
/// `required_domain` is listed in `backend.supported_domains`.
///
/// The reason's `required` text is `"<label> (<required_domain>)"`.
fn require_backend_domain(
    backend: &BackendCapabilities,
    required_domain: &str,
    label: &str,
    reasons: &mut Vec<UnsupportedBackendReason>,
) {
    let advertised = backend
        .supported_domains
        .iter()
        .any(|domain| domain == required_domain);
    if advertised {
        return;
    }
    reasons.push(UnsupportedBackendReason::MissingRocmHipCapability {
        required: format!("{label} ({required_domain})"),
        backend: backend.name.clone(),
    });
}
/// Records on `plan` that lowering `rule` was rejected for `context`.
///
/// The reason messages are joined with `"; "` and used for an evidence entry,
/// a `cpu_scalar` fallback (replacing any previous fallback), the semantic
/// fallback cost, and a `Required`/`Unresolved` obligation attributed to
/// `backend`.
fn attach_unsupported_lowering(
    plan: &mut ExecutionPlan,
    backend: &str,
    context: &str,
    rule: &LoweringRule,
    reasons: Vec<UnsupportedBackendReason>,
) {
    let messages = reasons
        .iter()
        .map(UnsupportedBackendReason::message)
        .collect::<Vec<_>>();
    let reason = messages.join("; ");

    // Render both texts up front so `reason` can be moved into the fallback cost.
    let evidence = format!(
        "{context} rejected lowering rule {} by mathematical capability contract: {reason}",
        rule.id.0
    );
    let condition = format!(
        "lowering rule {} is unsupported for {context} by mathematical capability contract: {reason}",
        rule.id.0
    );

    plan.evidence_log.entries.push(evidence);
    plan.fallback = Some(Fallback {
        backend: "cpu_scalar".to_string(),
        reason: reason.clone(),
    });
    plan.cost.semantic.fallback = Some(FallbackCost {
        penalty_ns: None,
        reason,
    });
    plan.obligations.push(Obligation {
        source: ObligationSource::Backend(backend.to_string()),
        condition,
        severity: ObligationSeverity::Required,
        status: DischargeStatus::Unresolved,
    });
}
/// Appends a valuation-stratified step for every eligible p-adic `matmul` node.
///
/// A node is eligible when it is named `matmul`, has exactly two inputs, its
/// first output's domain id starts with `Q_` and parses via
/// `parse_padic_domain_id`, and the registry rule
/// `cpu.matmul.padic_valuation_stratified` exists, targets `matmul` on
/// `plan.backend`, and supports the output's representation and domain. Nodes
/// failing any of these checks are skipped without touching the plan.
///
/// For an eligible node whose rule capabilities are not met, the rejection is
/// recorded through `attach_unsupported_lowering`. A node without any output id
/// is skipped with an evidence entry instead of panicking. Otherwise a
/// `PlanStepKind::PadicMatmulValuationSkip` step, two evidence entries, the
/// lowering metadata, and a runtime-checked `ProofKind::ValuationCutoff` proof
/// object are added.
fn attach_padic_matmul_valuation_skip_steps(
    plan: &mut ExecutionPlan,
    graph: &SemanticGraph,
    registry: &OperatorRegistry,
    backend: &BackendCapabilities,
) {
    const RULE_ID: &str = "cpu.matmul.padic_valuation_stratified";
    const CERTIFICATE_POLICY: &str = "per_output_min_skipped_valuation_and_dense_oracle";

    // The rule lookup does not depend on the node, so do it once; without the
    // rule no node can be lowered this way.
    let Some(rule) = registry.lowering_by_id(RULE_ID) else {
        return;
    };

    for node in graph.nodes() {
        if node.op_name != "matmul" || node.inputs.len() != 2 {
            continue;
        }
        let Some(output) = node.outputs.first() else {
            continue;
        };
        if !output.domain.0.starts_with("Q_") {
            continue;
        }
        let representation = output.representation.id().0;
        if !rule.supports_representation(&representation)
            || !rule.supports_domain(&output.domain.0)
            || rule.op_name != "matmul"
            || rule.backend != plan.backend
        {
            continue;
        }
        let Some((prime, precision)) = parse_padic_domain_id(&output.domain.0) else {
            continue;
        };
        let unsupported = unsupported_lowering_reasons(
            plan.resources.padic.as_ref(),
            backend,
            rule,
            &output.domain.0,
            &representation,
        );
        if !unsupported.is_empty() {
            attach_unsupported_lowering(
                plan,
                &rule.backend,
                &format!("valuation-stratified matmul node {}", node.id),
                rule,
                unsupported,
            );
            continue;
        }
        // `outputs` being non-empty does not guarantee `output_ids` is; guard the
        // lookup rather than indexing, which would panic on an empty list.
        let Some(&output_id) = node.output_ids.first() else {
            plan.evidence_log.entries.push(format!(
                "p-adic matmul node {} has no output id; valuation-stratified lowering skipped",
                node.id
            ));
            continue;
        };
        // Safe to index: the input count was checked to be exactly two above.
        let (lhs_id, rhs_id) = (node.inputs[0], node.inputs[1]);
        // One fingerprint string shared by the step, the evidence log, and the proof.
        let bucket_fingerprint = format!(
            "runtime-required:node={}:lhs={}:rhs={}:precision={}",
            node.id, lhs_id, rhs_id, precision
        );

        plan.steps.push(PlanStep {
            node_id: node.id,
            op_name: "matmul".to_string(),
            backend: plan.backend.clone(),
            domain: output.domain.0.clone(),
            representation,
            lowering_rule_id: Some(rule.id.0.clone()),
            kind: PlanStepKind::PadicMatmulValuationSkip {
                lhs_id,
                rhs_id,
                output_id,
                prime,
                precision,
                certificate_policy: CERTIFICATE_POLICY.to_string(),
                valuation_bucket_fingerprint: bucket_fingerprint.clone(),
            },
        });
        plan.evidence_log.entries.push(format!(
            "p-adic matmul node {} planned with valuation-stratified certificate-aware lowering {}",
            node.id, rule.id.0
        ));
        plan.evidence_log.entries.push(format!(
            "p-adic matmul certificate policy: node={}, policy={}, valuation_bucket_fingerprint={}",
            node.id, CERTIFICATE_POLICY, bucket_fingerprint
        ));
        attach_lowering_metadata(plan, rule);
        plan.proof_objects.push(
            ProofObject::new(
                format!("proof:padic_matmul_valuation_skip:{}", node.id),
                ProofKind::ValuationCutoff,
                format!(
                    "p-adic matmul node {} may skip products whose operand valuations sum to at least precision {}",
                    node.id, precision
                ),
                ProofStatus::RuntimeChecked,
            )
            .with_obligation("all matrix entries share the same fixed-precision p-adic domain")
            .with_obligation("each skipped product is checked against the active precision cutoff")
            .with_obligation("per-output certificates record skipped/evaluated counts and precision margin")
            .with_evidence(format!("lowering_rule={}", rule.id.0))
            .with_evidence(format!("prime={}, precision={}", prime, precision))
            .with_evidence(format!("certificate_policy={}", CERTIFICATE_POLICY))
            .with_evidence(format!(
                "valuation_bucket_fingerprint={}",
                bucket_fingerprint
            )),
        );
    }
}
/// Records a conservative valuation cost estimate for a p-adic `matmul` node.
///
/// Nothing happens unless `node` is named `matmul`, has exactly two inputs, its
/// first output's domain id starts with `Q_`, and both input values can be
/// found in `graph`. When any of the three relevant dimensions (lhs dims 0 and
/// 1, rhs dim 1) is not static, only an evidence entry is added.
///
/// Otherwise the term count `m * n * k` is written to
/// `plan.cost.semantic.valuation` with zero statically skipped terms,
/// `plan.cost.semantic.verification` is set, and an evidence entry, an
/// already discharged `AuditOnly` obligation, and a pending
/// `ProofKind::ValuationCutoff` proof object are appended. Both cost fields are
/// overwritten on every call that reaches this point.
fn attach_padic_matmul_valuation_cost(
    plan: &mut ExecutionPlan,
    graph: &SemanticGraph,
    node: &crate::ir::SemanticNode,
) {
    if node.op_name != "matmul" || node.inputs.len() != 2 {
        return;
    }
    let Some(output) = node.outputs.first() else {
        return;
    };
    if !output.domain.0.starts_with("Q_") {
        return;
    }
    let (Some(lhs), Some(rhs)) = (graph.value(node.inputs[0]), graph.value(node.inputs[1])) else {
        return;
    };
    let (Some(m), Some(k), Some(n)) = (
        static_dim(lhs.shape.dims.first()),
        static_dim(lhs.shape.dims.get(1)),
        static_dim(rhs.shape.dims.get(1)),
    ) else {
        plan.evidence_log.entries.push(format!(
            "p-adic matmul node {} has non-static dimensions; valuation cost requires runtime shape",
            node.id
        ));
        return;
    };
    // Saturate rather than overflow: a plain `m * n * k` panics in debug builds
    // and silently wraps in release builds for very large static shapes, which
    // would turn this upper bound into an arbitrary small number.
    let estimated_terms = m.saturating_mul(n).saturating_mul(k);
    // A domain id whose precision cannot be parsed is recorded with cutoff 0.
    let precision = parse_padic_precision(&output.domain.0).unwrap_or(0);
    let valuation = ValuationCost {
        cutoff: precision,
        estimated_terms,
        estimated_skipped_terms: 0,
        estimated_evaluated_terms: estimated_terms,
        certificate_policy: Some("per_output_min_skipped_valuation_and_dense_oracle".to_string()),
        valuation_bucket_fingerprint_policy: Some(
            "runtime_lhs_rhs_bucket_fingerprints_required".to_string(),
        ),
        precision_margin_floor: Some(0),
        rationale: format!(
            "p-adic matmul has {estimated_terms} multiplicative terms; terms with lhs valuation plus rhs valuation >= precision cutoff {precision} can be skipped after runtime valuation checks"
        ),
    };
    plan.cost.semantic.valuation = Some(valuation.clone());
    plan.cost.semantic.verification = Some(VerificationCost {
        required_checks: 2,
        estimated_overhead_ns: None,
    });
    plan.evidence_log.entries.push(format!(
        "p-adic matmul valuation cost: node={}, terms={}, cutoff={}, evaluated_upper_bound={}",
        node.id, valuation.estimated_terms, valuation.cutoff, valuation.estimated_evaluated_terms
    ));
    plan.obligations.push(Obligation {
        source: ObligationSource::Planner("valuation_cost".to_string()),
        condition: format!(
            "p-adic matmul node {} valuation pruning requires runtime checks against cutoff {} over {} terms",
            node.id, valuation.cutoff, valuation.estimated_terms
        ),
        severity: ObligationSeverity::AuditOnly,
        status: DischargeStatus::Discharged(
            "recorded as conservative valuation upper bound".to_string(),
        ),
    });
    plan.proof_objects.push(
        ProofObject::new(
            format!("proof:padic_matmul_valuation_bound:{}", node.id),
            ProofKind::ValuationCutoff,
            format!(
                "p-adic matmul node {} has {} candidate products and cutoff {}",
                node.id, valuation.estimated_terms, valuation.cutoff
            ),
            ProofStatus::Pending,
        )
        .with_obligation("runtime executor must check operand valuation sums before pruning")
        .with_evidence("conservative planner estimate records zero static skips")
        // Last use of `valuation`, so its rationale can be moved into the proof.
        .with_evidence(valuation.rationale),
    );
}
/// Returns the extent of `dim` when it is present and is `Dim::Static`;
/// a missing dimension or any other variant yields `None`.
fn static_dim(dim: Option<&Dim>) -> Option<usize> {
    if let Some(Dim::Static(extent)) = dim {
        Some(*extent)
    } else {
        None
    }
}
/// Extracts the precision from the trailing `[...]` group of a domain id.
///
/// Only the text after the last `[` is considered; it must end with `]` and
/// the enclosed text must parse as a `u32`. Anything else yields `None`.
/// The part of the id before the bracket is not inspected.
fn parse_padic_precision(domain: &str) -> Option<u32> {
    let (_, tail) = domain.rsplit_once('[')?;
    let digits = tail.strip_suffix(']')?;
    digits.parse().ok()
}
/// Parses a domain id of the form `Q_<prime>[<precision>]`.
///
/// Returns `(prime, precision)` when the id starts with `Q_`, the text up to
/// the first `[` parses as a `u64`, and the remainder ends with `]` and
/// parses as a `u32`. Any other input yields `None`.
fn parse_padic_domain_id(domain: &str) -> Option<(u64, u32)> {
    domain
        .strip_prefix("Q_")
        .and_then(|body| body.split_once('['))
        .and_then(|(prime_text, rest)| {
            let precision_text = rest.strip_suffix(']')?;
            let prime = prime_text.parse::<u64>().ok()?;
            let precision = precision_text.parse::<u32>().ok()?;
            Some((prime, precision))
        })
}
/// Attaches a shape-equality proof for a two-input `matmul` node.
///
/// Compares `lhs.dims[1]` with `rhs.dims[0]`. Nothing is recorded when the
/// node is not a two-input `matmul`, when an operand or the first output is
/// missing, when either dimension is absent, or when `proves_equal` does not
/// hold. On success the plan gains an evidence-log entry, an audit-only
/// obligation marked as discharged, and a `ShapeEquality` proof object with
/// status `Proven`.
fn attach_shape_proof_obligations(
    plan: &mut ExecutionPlan,
    graph: &SemanticGraph,
    node: &crate::ir::SemanticNode,
) {
    let is_binary_matmul = node.op_name == "matmul" && node.inputs.len() == 2;
    if !is_binary_matmul {
        return;
    }
    let Some(lhs) = graph.value(node.inputs[0]) else {
        return;
    };
    let Some(rhs) = graph.value(node.inputs[1]) else {
        return;
    };
    let Some(output) = node.outputs.first() else {
        return;
    };
    let Some(lhs_inner) = lhs.shape.dims.get(1) else {
        return;
    };
    let Some(rhs_inner) = rhs.shape.dims.first() else {
        return;
    };
    if !lhs_inner.proves_equal(rhs_inner) {
        return;
    }

    // Render each piece once; the same text feeds the log, the obligation
    // and the proof object.
    let lhs_text = format_dim(lhs_inner);
    let rhs_text = format_dim(rhs_inner);
    let shape_text = format_shape(&output.shape);
    let summary = format!(
        "matmul node {} shape proof: lhs inner {} equals rhs inner {}; output shape {}",
        node.id, lhs_text, rhs_text, shape_text
    );

    plan.evidence_log.entries.push(summary.clone());
    plan.obligations.push(Obligation {
        source: ObligationSource::Planner("shape_proof".to_string()),
        condition: summary,
        severity: ObligationSeverity::AuditOnly,
        status: DischargeStatus::Discharged("proven by shape inference".to_string()),
    });

    let shape_proof = ProofObject::new(
        format!("proof:matmul_shape:{}", node.id),
        ProofKind::ShapeEquality,
        format!(
            "matmul node {} has provably equal inner dimensions {} and {}",
            node.id, lhs_text, rhs_text
        ),
        ProofStatus::Proven,
    )
    .with_obligation("lhs and rhs are rank-2 tensors")
    .with_obligation("inner dimensions are provably equal before lowering")
    .with_evidence(format!("output_shape={}", shape_text));
    plan.proof_objects.push(shape_proof);
}
/// Renders a shape as its dimensions, each formatted by `format_dim`,
/// joined with `,` and wrapped in square brackets.
fn format_shape(shape: &crate::object::Shape) -> String {
    let rendered: Vec<String> = shape.dims.iter().map(format_dim).collect();
    format!("[{}]", rendered.join(","))
}
/// Renders one dimension for evidence and proof text.
///
/// Static extents print as the bare number; every other variant is prefixed
/// with a tag (`symbolic:`, `bounded:`, `data:`) followed by its fields.
fn format_dim(dim: &Dim) -> String {
    match dim {
        Dim::DataDependent(name) => format!("data:{name}"),
        Dim::Bounded { name, min, max } => format!("bounded:{name}:{min}:{max}"),
        Dim::Symbolic(name) => format!("symbolic:{name}"),
        Dim::Static(extent) => extent.to_string(),
    }
}
/// Copies a lowering rule's declared metadata into the plan.
///
/// Builds one `LoweringSemanticPreservation` proof object (status `Proven`)
/// for `rule` and, in order, folds in the rule's required evidence, its
/// obligations, its capabilities (including each capability's proof
/// obligations) and its theory requirements. Along the way matching entries
/// are appended to the plan's evidence log, and every obligation is pushed as
/// audit-only and already discharged, sourced to the rule id.
fn attach_lowering_metadata(plan: &mut ExecutionPlan, rule: &LoweringRule) {
    let rule_id = &rule.id.0;
    // Every obligation recorded here has the same source, severity and
    // status shape; only the condition text and the discharge note differ.
    let backend_obligation = |condition: String, note: &str| Obligation {
        source: ObligationSource::Backend(rule_id.clone()),
        condition,
        severity: ObligationSeverity::AuditOnly,
        status: DischargeStatus::Discharged(note.to_string()),
    };

    let mut proof = ProofObject::new(
        format!("proof:lowering:{}", rule_id),
        ProofKind::LoweringSemanticPreservation,
        format!(
            "lowering {} preserves semantic operator {} on backend {}",
            rule_id, rule.op_name, rule.backend
        ),
        ProofStatus::Proven,
    );

    for evidence in &rule.required_evidence {
        plan.evidence_log.entries.push(format!(
            "lowering evidence: rule={}, kind={:?}, evidence={}",
            rule_id, evidence.kind, evidence.description
        ));
        proof = proof.with_evidence(format!("{:?}: {}", evidence.kind, evidence.description));
    }

    for obligation in &rule.obligations {
        let condition = format!("{} ({})", obligation.condition, obligation.rationale);
        plan.obligations.push(backend_obligation(
            condition.clone(),
            "declared by lowering rule",
        ));
        proof = proof.with_obligation(condition);
    }

    for capability in &rule.capabilities {
        plan.evidence_log.entries.push(format!(
            "lowering capability: rule={}, capability={:?}",
            rule_id, capability
        ));
        proof = proof.with_evidence(format!("capability={capability:?}"));
        for obligation in &capability.proof_obligations {
            plan.obligations.push(backend_obligation(
                obligation.clone(),
                "declared by lowering capability contract",
            ));
            proof = proof.with_obligation(obligation.clone());
        }
    }

    for requirement in &rule.theory_requirements {
        let theory = requirement.theory.as_str();
        let law = &requirement.law_id;
        let theorem = requirement.theorem_id.as_deref().unwrap_or("none");
        plan.evidence_log.entries.push(format!(
            "certified lowering theory requirement: rule={}, theory={}, law={}, theorem={}, required_evidence={}",
            rule_id,
            theory,
            law,
            theorem,
            requirement.required_evidence.join("|")
        ));
        plan.obligations.push(backend_obligation(
            format!(
                "theory law {}:{} must be evidenced before lowering",
                theory, law
            ),
            "declared by certified lowering contract",
        ));
        proof = proof.with_evidence(format!(
            "theory_law={}:{} theorem_binding={}",
            theory, law, theorem
        ));
    }

    plan.proof_objects.push(proof);
}
/// Adds a `RewriteSoundness` proof object (status `Proven`) for `rewrite`.
///
/// The proof carries the source contract and the number of discharged
/// conditions as evidence, one obligation per discharged condition, and one
/// evidence line per certified theory requirement. Each certified requirement
/// is also written to the plan's evidence log.
fn attach_rewrite_soundness_proof(plan: &mut ExecutionPlan, rewrite: &RewriteJustification) {
    let discharged_count = rewrite.discharged_conditions.len();
    let base = ProofObject::new(
        format!("proof:rewrite:{}", rewrite.rule),
        ProofKind::RewriteSoundness,
        format!(
            "rewrite {} is sound under {} discharged conditions",
            rewrite.rule, discharged_count
        ),
        ProofStatus::Proven,
    )
    .with_evidence(format!("source_contract={}", rewrite.source.0))
    .with_evidence(format!("discharged_conditions={}", discharged_count));

    // Each discharged condition becomes an obligation on the proof.
    let mut proof = rewrite
        .discharged_conditions
        .iter()
        .fold(base, |acc, condition| {
            acc.with_obligation(condition.description.clone())
        });

    for requirement in &rewrite.certified_requirements {
        let theory = requirement.theory.as_str();
        let law = &requirement.law_id;
        let theorem = requirement.theorem_id.as_deref().unwrap_or("none");
        plan.evidence_log.entries.push(format!(
            "certified rewrite theory requirement: rewrite={}, theory={}, law={}, theorem={}, required_evidence={}",
            rewrite.rule,
            theory,
            law,
            theorem,
            requirement.required_evidence.join("|")
        ));
        proof = proof.with_evidence(format!(
            "theory_law={}:{} theorem_binding={}",
            theory, law, theorem
        ));
    }

    plan.proof_objects.push(proof);
}
/// Builds a [`ValuationCost`] estimate for a sum of `estimated_terms` products.
///
/// The cutoff is the domain's precision. Half of the terms (rounded down) are
/// estimated as skipped and the remainder as evaluated.
fn estimate_padic_sum_products_valuation_cost(
    domain: &PadicDomain,
    estimated_terms: usize,
) -> ValuationCost {
    let cutoff = domain.meta.precision;
    let estimated_skipped_terms = estimated_terms / 2;
    let rationale = format!(
        "skip terms whose lhs valuation plus rhs valuation is >= precision cutoff {cutoff}; estimated {estimated_skipped_terms}/{estimated_terms} terms vanish modulo p^{cutoff}"
    );
    ValuationCost {
        cutoff,
        estimated_terms,
        estimated_skipped_terms,
        // Whatever is not estimated as skipped must be evaluated.
        estimated_evaluated_terms: estimated_terms - estimated_skipped_terms,
        certificate_policy: Some(String::from("scalar_valuation_cutoff_runtime_checked")),
        valuation_bucket_fingerprint_policy: None,
        precision_margin_floor: Some(0),
        rationale,
    }
}
/// Decides whether `consumer` may be fused onto `producer` as a pointwise pair.
///
/// All of the following must hold: both nodes list `LayerBehavior::Pointwise`,
/// the producer has exactly one output id and the consumer takes it as an
/// input, and the first outputs of both nodes agree on domain, shape and
/// representation.
fn can_fuse_pointwise(
    producer: &crate::ir::SemanticNode,
    consumer: &crate::ir::SemanticNode,
) -> bool {
    let both_pointwise = [producer, consumer]
        .iter()
        .all(|node| node.layer_behavior.contains(&LayerBehavior::Pointwise));
    if !both_pointwise {
        return false;
    }
    if producer.output_ids.len() != 1 {
        return false;
    }
    if !consumer.inputs.contains(&producer.output_ids[0]) {
        return false;
    }
    match (producer.outputs.first(), consumer.outputs.first()) {
        (Some(produced), Some(consumed)) => {
            produced.domain == consumed.domain
                && produced.shape == consumed.shape
                && produced.representation == consumed.representation
        }
        _ => false,
    }
}
/// Collects chains of node indices that can be fused pointwise.
///
/// Starting from each candidate index, the chain is extended by following the
/// node's first output id to its single consumer for as long as
/// `can_fuse_pointwise` accepts the pair. Chains with more than one node are
/// returned; the scan then resumes at the index just after the chain's last
/// node, or at the next index when no fusion happened.
fn pointwise_fusion_groups(graph: &SemanticGraph) -> Vec<Vec<usize>> {
    let mut groups = Vec::new();
    let mut start = 0;
    while start < graph.nodes().len() {
        let mut chain = vec![start];
        let mut tail = start;
        loop {
            let first_output = graph
                .output_ids_of_node(tail)
                .and_then(|ids| ids.first())
                .copied();
            let Some(output_id) = first_output else {
                break;
            };
            let Some(next) = graph.single_consumer_of_value(output_id) else {
                break;
            };
            let nodes = graph.nodes();
            if !can_fuse_pointwise(&nodes[tail], &nodes[next]) {
                break;
            }
            chain.push(next);
            tail = next;
        }
        // `tail` is always the last element of `chain`, so `tail + 1` is the
        // resume point both for a fused chain and for a lone node.
        start = tail + 1;
        if chain.len() > 1 {
            groups.push(chain);
        }
    }
    groups
}