use crate::backend::hardware::{ComputeHardware, DeviceCapabilities, HardwareTarget};
use crate::backend::{
Backend, BackendCapabilities, Executable, GraphExecutor, ObjectRef, SpecializedPlanExecutor,
TensorStore,
};
use crate::domain::{
Domain, Padic, PadicDomain, PadicMatrixMetadata, PadicOutputCertificate, ValuationSkipReport,
};
use crate::ir::SemanticGraph;
use crate::object::{Dim, ObjectMeta, Tensor};
use crate::op::OperatorRegistry;
use crate::planner::{ExecutionPlan, PlanStepKind};
use crate::{Error, Result};
use std::collections::BTreeSet;
/// Reference CPU backend, registered under the name `"cpu_scalar"`.
///
/// The type carries no state. Its inherent `execute_*` entry points reject any
/// [`ExecutionPlan`] whose `backend` field is not `"cpu_scalar"`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuScalarBackend;
/// Aggregate report returned by
/// `CpuScalarBackend::execute_padic_matmul_with_valuation_skip`.
///
/// That executor merges the per-matmul report produced by
/// `dense_matmul_padic_with_valuation_skip` into one instance of this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PadicMatmulValuationSkipReport {
/// Number of terms reported as skipped, summed over all merged matmuls.
pub skipped_terms: usize,
/// Number of terms reported as evaluated, summed over all merged matmuls.
pub evaluated_terms: usize,
/// Left-operand metadata from every merged matmul, appended in merge order.
pub lhs_metadata: Vec<PadicMatrixMetadata>,
/// Right-operand metadata from every merged matmul, appended in merge order.
pub rhs_metadata: Vec<PadicMatrixMetadata>,
/// Output certificates from every merged matmul, appended in merge order.
pub output_certificates: Vec<PadicOutputCertificate>,
/// Logical AND of every merged report's flag. Presumably this records agreement
/// with a dense reference computation (TODO confirm in the matmul kernel).
pub dense_oracle_matches: bool,
}
impl PadicMatmulValuationSkipReport {
    /// Folds `other` into `self`.
    ///
    /// Counters are added, the metadata and certificate lists are appended in
    /// order, and the oracle flag is AND-ed. A single mismatching report
    /// therefore marks the whole aggregate as a mismatch.
    fn merge(&mut self, other: Self) {
        let Self {
            skipped_terms,
            evaluated_terms,
            lhs_metadata,
            rhs_metadata,
            output_certificates,
            dense_oracle_matches,
        } = other;

        self.skipped_terms += skipped_terms;
        self.evaluated_terms += evaluated_terms;
        self.lhs_metadata.extend(lhs_metadata);
        self.rhs_metadata.extend(rhs_metadata);
        self.output_certificates.extend(output_certificates);
        self.dense_oracle_matches = self.dense_oracle_matches && dense_oracle_matches;
    }
}
impl Backend for CpuScalarBackend {
    /// Identifier that an [`ExecutionPlan`] must name to run on this backend.
    fn name(&self) -> &'static str {
        "cpu_scalar"
    }

    /// Capability set advertised for the scalar CPU backend.
    fn capabilities(&self) -> BackendCapabilities {
        BackendCapabilities::cpu_scalar()
    }

    /// Produces an [`Executable`] that records the plan's target backend.
    ///
    /// Nothing is generated. Unlike the inherent `execute_*` methods, this does
    /// not check that `plan.backend` equals [`Backend::name`].
    fn compile(&self, plan: &ExecutionPlan) -> Result<Executable> {
        let backend = plan.backend.clone();
        Ok(Executable { backend })
    }

    /// No-op that always succeeds and ignores `args`.
    ///
    /// Real evaluation on this backend goes through the typed `execute_*`
    /// methods and the `GraphExecutor` / `SpecializedPlanExecutor` impls.
    fn execute(&self, _executable: &Executable, _args: &[ObjectRef]) -> Result<()> {
        Ok(())
    }
}
impl ComputeHardware for CpuScalarBackend {
    /// Device-level capabilities of the scalar CPU target.
    fn device_capabilities(&self) -> DeviceCapabilities {
        DeviceCapabilities::cpu_scalar()
    }

    /// Hardware target descriptor for the scalar CPU.
    fn target(&self) -> HardwareTarget {
        HardwareTarget::cpu_scalar()
    }
}
impl CpuScalarBackend {
pub fn execute_i64(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
store: &mut TensorStore<i64>,
) -> Result<()> {
if plan.backend != self.name() {
return Err(Error::backend(format!(
"plan targets backend {}, but executor is {}",
plan.backend,
self.name()
)));
}
for node in graph.nodes() {
match node.op_name.as_str() {
"add" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_binary_i64(&lhs, &rhs, |a, b| a + b)?;
store.insert(node.output_ids[0], output);
}
"mul" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_binary_i64(&lhs, &rhs, |a, b| a * b)?;
store.insert(node.output_ids[0], output);
}
"matmul" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_matmul_i64(&lhs, &rhs, node.outputs[0].clone())?;
store.insert(node.output_ids[0], output);
}
"map" => {
let input = store.get(node.inputs[0])?.clone();
store.insert(node.output_ids[0], input);
}
"reduce" => {
let input = store.get(node.inputs[0])?;
let sum = input.data.iter().copied().sum::<i64>();
let output = Tensor {
meta: node.outputs[0].clone(),
data: vec![sum],
};
store.insert(node.output_ids[0], output);
}
"fma" => {
let a = store.get(node.inputs[0])?.clone();
let b = store.get(node.inputs[1])?.clone();
let c = store.get(node.inputs[2])?.clone();
let output = dense_ternary_i64(&a, &b, &c, |x, y, z| x * y + z)?;
store.insert(node.output_ids[0], output);
}
"clamp" => {
let input = store.get(node.inputs[0])?.clone();
let lo_tensor = store.get(node.inputs[1])?;
let hi_tensor = store.get(node.inputs[2])?;
if lo_tensor.data.is_empty() || hi_tensor.data.is_empty() {
return Err(Error::backend(format!(
"clamp expects lo and hi to be scalar tensors with at least one element (got lo_len={}, hi_len={}, input shape={:?})",
lo_tensor.data.len(),
hi_tensor.data.len(),
input.meta.shape.dims,
)));
}
let lo = lo_tensor.data[0];
let hi = hi_tensor.data[0];
let output = clamp_i64(&input, lo, hi)?;
store.insert(node.output_ids[0], output);
}
"neg" => {
let input = store.get(node.inputs[0])?.clone();
let output = neg_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"abs" => {
let input = store.get(node.inputs[0])?.clone();
let output = abs_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"square" => {
let input = store.get(node.inputs[0])?.clone();
let output = square_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"mul_by_two" => {
let input = store.get(node.inputs[0])?.clone();
let output = mul_by_two_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"sub" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_sub_i64(&lhs, &rhs)?;
store.insert(node.output_ids[0], output);
}
"div" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_div_i64(&lhs, &rhs)?;
store.insert(node.output_ids[0], output);
}
"scalar_add" => {
let input = store.get(node.inputs[0])?.clone();
let scalar_t = store.get(node.inputs[1])?;
if scalar_t.data.is_empty() {
return Err(Error::backend(
"scalar_add scalar input must be a 1-element tensor",
));
}
let output = scalar_add_i64(&input, scalar_t.data[0])?;
store.insert(node.output_ids[0], output);
}
"scalar_mul" => {
let input = store.get(node.inputs[0])?.clone();
let scalar_t = store.get(node.inputs[1])?;
if scalar_t.data.is_empty() {
return Err(Error::backend(
"scalar_mul scalar input must be a 1-element tensor",
));
}
let output = scalar_mul_i64(&input, scalar_t.data[0])?;
store.insert(node.output_ids[0], output);
}
"pow" => {
let input = store.get(node.inputs[0])?.clone();
let exp_t = store.get(node.inputs[1])?;
if exp_t.data.is_empty() {
return Err(Error::backend("pow exp input must be a 1-element tensor"));
}
let output = pow_i64(&input, exp_t.data[0] as i32)?;
store.insert(node.output_ids[0], output);
}
"sqrt" => {
let input = store.get(node.inputs[0])?.clone();
let output = sqrt_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"exp2" => {
let input = store.get(node.inputs[0])?.clone();
let output = exp2_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"log2" => {
let input = store.get(node.inputs[0])?.clone();
let output = log2_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"reshape" => {
let input = store.get(node.inputs[0])?.clone();
let target = store.get(node.inputs[1])?.clone();
let target_dims: Vec<usize> = target
.meta
.shape
.dims
.iter()
.enumerate()
.map(|(idx, d)| {
d.value().ok_or_else(|| {
Error::backend(format!(
"reshape target shape contains non-static dim at index {idx} (input shape={:?}, target shape={:?})",
input.meta.shape.dims, target.meta.shape.dims
))
})
})
.collect::<Result<Vec<_>>>()?;
let output = shape_reshape_i64(&input, &target_dims)?;
store.insert(node.output_ids[0], output);
}
"transpose" => {
let input = store.get(node.inputs[0])?.clone();
let axes = store.get(node.inputs[1])?.clone();
if axes.data.len() != 2 {
return Err(Error::backend(format!(
"transpose expects a 2-element axes tensor, got {}",
axes.data.len()
)));
}
let output = shape_transpose_i64(&input, axes.data[0], axes.data[1])?;
store.insert(node.output_ids[0], output);
}
"permute" => {
let input = store.get(node.inputs[0])?.clone();
let perm = store.get(node.inputs[1])?.clone();
let permutation: Vec<i64> = perm.data.clone();
let output = shape_permute_i64(&input, &permutation)?;
store.insert(node.output_ids[0], output);
}
"slice" => {
let input = store.get(node.inputs[0])?.clone();
let bounds = store.get(node.inputs[1])?.clone();
if bounds.data.len() != 3 {
return Err(Error::backend(format!(
"slice expects a 3-element bounds tensor [axis, start, end], got {}",
bounds.data.len()
)));
}
let axis = bounds.data[0];
let start = bounds.data[1] as usize;
let end = bounds.data[2] as usize;
let output = shape_slice_i64(&input, axis, start, end)?;
store.insert(node.output_ids[0], output);
}
"concat" => {
if node.inputs.is_empty() {
return Err(Error::shape("concat: no input tensors"));
}
let last = node.inputs[node.inputs.len() - 1];
let axis_t = store.get(last)?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(
"concat: axis tensor must have at least 1 element",
));
}
let axis = axis_t.data[0];
let mut tensors = Vec::with_capacity(node.inputs.len() - 1);
for &inp in &node.inputs[..node.inputs.len() - 1] {
tensors.push(store.get(inp)?.clone());
}
let output = shape_concat_i64(&tensors, axis)?;
store.insert(node.output_ids[0], output);
}
"broadcast" => {
let input = store.get(node.inputs[0])?.clone();
let target = store.get(node.inputs[1])?.clone();
let target_dims: Vec<usize> = target
.meta
.shape
.dims
.iter()
.map(|d| {
d.value().ok_or_else(|| {
Error::shape("broadcast target shape contains non-static dim")
})
})
.collect::<Result<Vec<_>>>()?;
let output = shape_broadcast_i64(&input, &target_dims)?;
store.insert(node.output_ids[0], output);
}
"flatten" => {
let input = store.get(node.inputs[0])?.clone();
let output = shape_flatten_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"squeeze" => {
let input = store.get(node.inputs[0])?.clone();
let output = shape_squeeze_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"unsqueeze" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(
"unsqueeze: axis tensor must have at least 1 element",
));
}
let axis = axis_t.data[0];
let output = shape_unsqueeze_i64(&input, axis)?;
store.insert(node.output_ids[0], output);
}
"relu" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_relu_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"sigmoid" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_sigmoid_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"tanh" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_tanh_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"gelu" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_gelu_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"softmax" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_softmax_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"layer_norm" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"layer_norm: expects 3 inputs (input, gamma, beta)",
));
}
let input = store.get(node.inputs[0])?.clone();
let gamma = store.get(node.inputs[1])?.clone();
let beta = store.get(node.inputs[2])?.clone();
let output = nn_layer_norm_i64(&input, &gamma, &beta)?;
store.insert(node.output_ids[0], output);
}
"gather" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"gather: expects 3 inputs (input, indices, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let axis_t = store.get(node.inputs[2])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"gather: axis tensor must be non-empty (input shape={:?}, indices shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_gather_i64(&input, &indices, axis)?;
store.insert(node.output_ids[0], output);
}
"scatter" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"scatter: expects 3 inputs (input, indices, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let axis_t = store.get(node.inputs[2])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"scatter: axis tensor must be non-empty (input shape={:?}, indices shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_scatter_i64(&input, &indices, axis)?;
store.insert(node.output_ids[0], output);
}
"index_select" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"index_select: expects 3 inputs (input, indices, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let axis_t = store.get(node.inputs[2])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"index_select: axis tensor must be non-empty (input shape={:?}, indices shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_select_i64(&input, &indices, axis)?;
store.insert(node.output_ids[0], output);
}
"index_add" => {
if node.inputs.len() != 4 {
return Err(Error::backend(
"index_add: expects 4 inputs (input, indices, source, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let source = store.get(node.inputs[2])?.clone();
let axis_t = store.get(node.inputs[3])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"index_add: axis tensor must be non-empty (input shape={:?}, indices shape={:?}, source shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims, source.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_add_i64(&input, &indices, &source, axis)?;
store.insert(node.output_ids[0], output);
}
"nonzero" => {
let input = store.get(node.inputs[0])?.clone();
let output = index_nonzero_i64(&input)?;
store.insert(node.output_ids[0], output);
}
"sum" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"sum: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Sum)?;
store.insert(node.output_ids[0], output);
}
"mean" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"mean: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Mean)?;
store.insert(node.output_ids[0], output);
}
"max" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"max: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Max)?;
store.insert(node.output_ids[0], output);
}
"min" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"min: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Min)?;
store.insert(node.output_ids[0], output);
}
"argmax" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"argmax: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::ArgMax)?;
store.insert(node.output_ids[0], output);
}
"argmin" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"argmin: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::ArgMin)?;
store.insert(node.output_ids[0], output);
}
"prod" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"prod: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Prod)?;
store.insert(node.output_ids[0], output);
}
"any" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"any: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Any)?;
store.insert(node.output_ids[0], output);
}
"all" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"all: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::All)?;
store.insert(node.output_ids[0], output);
}
other => {
return Err(Error::backend(format!(
"cpu scalar reference backend does not support op {other}"
)));
}
}
}
Ok(())
}
pub fn execute_i64_plan(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
store: &mut TensorStore<i64>,
) -> Result<()> {
if plan.backend != self.name() {
return Err(Error::backend(format!(
"plan targets backend {}, but executor is {}",
plan.backend,
self.name()
)));
}
let fused_nodes = fused_node_ids(plan);
for step in &plan.steps {
match &step.kind {
PlanStepKind::Single => {
if fused_nodes.contains(&step.node_id) {
continue;
}
let node = graph
.nodes()
.get(step.node_id)
.ok_or_else(|| Error::backend(format!("unknown node {}", step.node_id)))?;
if node
.output_ids
.iter()
.all(|output_id| store.contains(*output_id))
{
continue;
}
execute_i64_node(node, store)?;
}
PlanStepKind::Fused { node_ids, rule } => {
execute_i64_fused(graph, node_ids, rule, store)?;
}
PlanStepKind::PadicValuationSkip { .. } => {
return Err(Error::backend(
"p-adic valuation skip step cannot be executed by i64 plan executor",
));
}
PlanStepKind::PadicMatmulValuationSkip { .. } => {
return Err(Error::backend(
"p-adic matmul valuation skip step cannot be executed by i64 plan executor",
));
}
PlanStepKind::CoverGlueCheck { .. } => {
return Err(Error::backend(
"cover glue check step cannot be executed by i64 plan executor",
));
}
}
}
Ok(())
}
pub fn execute_i64_plan_with_registry(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
registry: &OperatorRegistry,
store: &mut TensorStore<i64>,
) -> Result<()> {
self.ensure_step_lowerings(plan, registry, |kind| {
matches!(kind, PlanStepKind::Single | PlanStepKind::Fused { .. })
})?;
self.execute_i64_plan(graph, plan, store)
}
/// Checks that every plan step selected by `required` has a compatible lowering rule.
///
/// The rule is looked up in `registry`. It is compatible when its op name equals
/// the step's op name, its backend equals [`Backend::name`], and it supports both
/// the step's representation and the step's domain.
///
/// # Errors
/// Returns a backend error naming the offending step when:
/// - the step has no `lowering_rule_id`,
/// - `registry` does not know that id, or
/// - the rule is incompatible.
///
/// In the incompatible case the message also reports the op and backend the rule
/// actually targets.
fn ensure_step_lowerings(
    &self,
    plan: &ExecutionPlan,
    registry: &OperatorRegistry,
    required: impl Fn(&PlanStepKind) -> bool,
) -> Result<()> {
    for step in plan.steps.iter().filter(|step| required(&step.kind)) {
        let Some(lowering_rule_id) = &step.lowering_rule_id else {
            return Err(Error::backend(format!(
                "step {} ({}) has no registry lowering rule",
                step.node_id, step.op_name
            )));
        };
        let lowering = registry.lowering_by_id(lowering_rule_id).ok_or_else(|| {
            Error::backend(format!(
                "missing registry lowering rule {} for step {} ({})",
                lowering_rule_id, step.node_id, step.op_name
            ))
        })?;
        if lowering.op_name != step.op_name
            || lowering.backend != self.name()
            || !lowering.supports_representation(&step.representation)
            || !lowering.supports_domain(&step.domain)
        {
            // The original prefix is kept for existing matchers. The rule's own
            // op and backend are appended because the `backend=` field only echoes
            // this executor's name, which hid the actual cause of a mismatch.
            return Err(Error::backend(format!(
                "lowering rule mismatch for step {} ({}): rule={}, backend={}, domain={}, representation={}; rule targets op={} on backend={}",
                step.node_id,
                step.op_name,
                lowering_rule_id,
                self.name(),
                step.domain,
                step.representation,
                lowering.op_name,
                lowering.backend
            )));
        }
    }
    Ok(())
}
/// Evaluates every node of `graph`, in graph order, on p-adic tensors held in `store`.
///
/// `plan` is only used to verify that it targets this backend; its steps are not
/// consulted. The supported ops are `add`, `mul`, `matmul`, `fma`, `p_pad_fma`
/// and `p_dot`. Operands are read from `node.inputs`, and the single result is
/// written to `node.output_ids[0]`, overwriting any existing tensor.
///
/// # Errors
/// Returns a backend error on a backend mismatch or an unsupported op. Errors from
/// `store.get` (for example a missing operand) and from the p-adic kernels are
/// propagated.
pub fn execute_padic(
    &self,
    graph: &SemanticGraph,
    plan: &ExecutionPlan,
    store: &mut TensorStore<Padic>,
) -> Result<()> {
    if plan.backend != self.name() {
        return Err(Error::backend(format!(
            "plan targets backend {}, but executor is {}",
            plan.backend,
            self.name()
        )));
    }
    for node in graph.nodes() {
        // Operands are borrowed from the store instead of being deep-cloned.
        // The shared borrows end when the kernel returns, before the insert below.
        let output = match node.op_name.as_str() {
            "add" => {
                let lhs = store.get(node.inputs[0])?;
                let rhs = store.get(node.inputs[1])?;
                dense_binary_padic(lhs, rhs, |domain, a, b| domain.add(a, b))?
            }
            "mul" => {
                let lhs = store.get(node.inputs[0])?;
                let rhs = store.get(node.inputs[1])?;
                dense_binary_padic(lhs, rhs, |domain, a, b| domain.mul(a, b))?
            }
            "matmul" => {
                let lhs = store.get(node.inputs[0])?;
                let rhs = store.get(node.inputs[1])?;
                dense_matmul_padic(lhs, rhs, node.outputs[0].clone())?
            }
            "fma" => {
                let a = store.get(node.inputs[0])?;
                let b = store.get(node.inputs[1])?;
                let c = store.get(node.inputs[2])?;
                dense_ternary_padic(a, b, c, |domain, x, y, z| {
                    let prod = domain.mul(x, y)?;
                    domain.add(&prod, z)
                })?
            }
            "p_pad_fma" => {
                let a = store.get(node.inputs[0])?;
                let b = store.get(node.inputs[1])?;
                let c = store.get(node.inputs[2])?;
                dense_matmul_fma_padic(a, b, c, node.outputs[0].clone())?
            }
            "p_dot" => {
                let a = store.get(node.inputs[0])?;
                let b = store.get(node.inputs[1])?;
                dense_dot_padic(a, b, node.outputs[0].clone())?
            }
            other => {
                return Err(Error::backend(format!(
                    "cpu scalar p-adic backend does not support op {other}"
                )));
            }
        };
        store.insert(node.output_ids[0], output);
    }
    Ok(())
}
pub fn execute_padic_plan_with_registry(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
registry: &OperatorRegistry,
store: &mut TensorStore<Padic>,
) -> Result<()> {
self.ensure_step_lowerings(plan, registry, |kind| matches!(kind, PlanStepKind::Single))?;
self.execute_padic(graph, plan, store)
}
/// Executes `plan` on p-adic tensors, routing every matmul through the valuation-skip kernel.
///
/// - A `Single` step whose node also has a `PadicMatmulValuationSkip` step is skipped,
///   because the dedicated step handles that node.
/// - Any other `Single` matmul runs through `dense_matmul_padic_with_valuation_skip`.
///   Non-matmul `Single` nodes go to `execute_padic_node`.
/// - A `PadicMatmulValuationSkip` step reads its explicit `lhs_id`/`rhs_id` operands
///   and writes `output_id`.
/// - `Fused`, `PadicValuationSkip` and `CoverGlueCheck` steps are rejected.
///
/// Per-matmul reports are merged into the returned aggregate, in step order.
///
/// # Errors
/// Returns a backend error on a backend mismatch, an unknown node, or an
/// unsupported step kind. Store and kernel errors are propagated.
pub fn execute_padic_matmul_with_valuation_skip(
    &self,
    graph: &SemanticGraph,
    plan: &ExecutionPlan,
    store: &mut TensorStore<Padic>,
) -> Result<PadicMatmulValuationSkipReport> {
    if plan.backend != self.name() {
        return Err(Error::backend(format!(
            "plan targets backend {}, but executor is {}",
            plan.backend,
            self.name()
        )));
    }
    // Identity element for `merge`: zero counters, empty lists, and `true` for
    // the AND-combined oracle flag.
    let mut report = PadicMatmulValuationSkipReport {
        skipped_terms: 0,
        evaluated_terms: 0,
        lhs_metadata: Vec::new(),
        rhs_metadata: Vec::new(),
        output_certificates: Vec::new(),
        dense_oracle_matches: true,
    };
    let optimized_nodes = padic_matmul_valuation_skip_node_ids(plan);
    for step in &plan.steps {
        match &step.kind {
            PlanStepKind::Single => {
                if optimized_nodes.contains(&step.node_id) {
                    continue;
                }
                let node = graph
                    .nodes()
                    .get(step.node_id)
                    .ok_or_else(|| Error::backend(format!("unknown node {}", step.node_id)))?;
                if node.op_name != "matmul" {
                    execute_padic_node(node, store)?;
                    continue;
                }
                // Operands are borrowed, not cloned. The borrows end before the insert.
                let lhs = store.get(node.inputs[0])?;
                let rhs = store.get(node.inputs[1])?;
                let (output, node_report) =
                    dense_matmul_padic_with_valuation_skip(lhs, rhs, node.outputs[0].clone())?;
                report.merge(node_report);
                store.insert(node.output_ids[0], output);
            }
            PlanStepKind::PadicMatmulValuationSkip {
                lhs_id,
                rhs_id,
                output_id,
                ..
            } => {
                let lhs = store.get(*lhs_id)?;
                let rhs = store.get(*rhs_id)?;
                let node = graph
                    .nodes()
                    .get(step.node_id)
                    .ok_or_else(|| Error::backend(format!("unknown node {}", step.node_id)))?;
                let (output, node_report) =
                    dense_matmul_padic_with_valuation_skip(lhs, rhs, node.outputs[0].clone())?;
                report.merge(node_report);
                store.insert(*output_id, output);
            }
            PlanStepKind::Fused { .. } => {
                return Err(Error::backend(
                    "fused p-adic valuation-skip matmul execution is not supported",
                ));
            }
            PlanStepKind::PadicValuationSkip { .. } => {
                return Err(Error::backend(
                    "specialized p-adic sum-products skip step cannot be mixed with graph matmul skip executor",
                ));
            }
            PlanStepKind::CoverGlueCheck { .. } => {
                return Err(Error::backend(
                    "cover glue check step cannot be executed by p-adic matmul skip executor",
                ));
            }
        }
    }
    Ok(report)
}
pub fn execute_padic_matmul_with_valuation_skip_and_registry(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
registry: &OperatorRegistry,
store: &mut TensorStore<Padic>,
) -> Result<PadicMatmulValuationSkipReport> {
self.ensure_step_lowerings(plan, registry, |kind| {
matches!(
kind,
PlanStepKind::Single | PlanStepKind::PadicMatmulValuationSkip { .. }
)
})?;
self.execute_padic_matmul_with_valuation_skip(graph, plan, store)
}
/// Executes the first `PadicValuationSkip` step of `plan`.
///
/// The step computes a p-adic sum of products of its `lhs_id` and `rhs_id`
/// tensors, using `PadicDomain::sum_products_with_valuation_skip`. The result is
/// stored as a scalar tensor under `output_id`. All other steps of the plan are
/// ignored.
///
/// # Errors
/// Returns a backend error on a backend mismatch or when the plan has no
/// `PadicValuationSkip` step. Errors from the store, from `PadicDomain::new` and
/// from the kernel are propagated.
pub fn execute_padic_sum_products_plan(
    &self,
    plan: &ExecutionPlan,
    store: &mut TensorStore<Padic>,
) -> Result<ValuationSkipReport> {
    if plan.backend != self.name() {
        return Err(Error::backend(format!(
            "plan targets backend {}, but executor is {}",
            plan.backend,
            self.name()
        )));
    }
    // Locate and destructure the step in one pass. This avoids matching the
    // kind twice and the `unreachable!` panic path that re-check required.
    let (lhs_id, rhs_id, output_id, prime, precision) = plan
        .steps
        .iter()
        .find_map(|step| match step.kind {
            PlanStepKind::PadicValuationSkip {
                lhs_id,
                rhs_id,
                output_id,
                prime,
                precision,
            } => Some((lhs_id, rhs_id, output_id, prime, precision)),
            _ => None,
        })
        .ok_or_else(|| Error::backend("missing p-adic valuation skip step"))?;
    // Operands are borrowed rather than deep-cloned. The borrows end before the insert.
    let lhs = store.get(lhs_id)?;
    let rhs = store.get(rhs_id)?;
    let domain = PadicDomain::new(prime, precision)?;
    let report = domain.sum_products_with_valuation_skip(&lhs.data, &rhs.data)?;
    store.insert(
        output_id,
        Tensor::dense_cpu(
            domain.id(),
            crate::object::Shape::scalar(),
            vec![report.result.clone()],
        ),
    );
    Ok(report)
}
pub fn execute_padic_sum_products_plan_with_registry(
&self,
plan: &ExecutionPlan,
registry: &OperatorRegistry,
store: &mut TensorStore<Padic>,
) -> Result<ValuationSkipReport> {
self.ensure_step_lowerings(plan, registry, |kind| {
matches!(kind, PlanStepKind::PadicValuationSkip { .. })
})?;
self.execute_padic_sum_products_plan(plan, store)
}
}
impl GraphExecutor<i64> for CpuScalarBackend {
fn execute_graph(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
store: &mut TensorStore<i64>,
) -> Result<()> {
self.execute_i64_plan(graph, plan, store)
}
}
impl GraphExecutor<Padic> for CpuScalarBackend {
fn execute_graph(
&self,
graph: &SemanticGraph,
plan: &ExecutionPlan,
store: &mut TensorStore<Padic>,
) -> Result<()> {
self.execute_padic(graph, plan, store)
}
}
impl SpecializedPlanExecutor<Padic, ValuationSkipReport> for CpuScalarBackend {
fn execute_specialized_plan(
&self,
plan: &ExecutionPlan,
store: &mut TensorStore<Padic>,
) -> Result<ValuationSkipReport> {
self.execute_padic_sum_products_plan(plan, store)
}
}
/// Collects the id of every node that belongs to some `Fused` step of `plan`.
///
/// `execute_i64_plan` uses this set to skip the individual `Single` steps of
/// nodes that a fused step already covers.
fn fused_node_ids(plan: &ExecutionPlan) -> BTreeSet<usize> {
    plan.steps
        .iter()
        .filter_map(|step| match &step.kind {
            PlanStepKind::Fused { node_ids, .. } => Some(node_ids.iter().copied()),
            _ => None,
        })
        .flatten()
        .collect()
}
/// Collects the `node_id` of every `PadicMatmulValuationSkip` step in `plan`.
///
/// The matmul skip executor uses this set to skip the `Single` steps of nodes
/// that a dedicated valuation-skip step already handles.
fn padic_matmul_valuation_skip_node_ids(plan: &ExecutionPlan) -> BTreeSet<usize> {
    plan.steps
        .iter()
        .filter(|step| matches!(step.kind, PlanStepKind::PadicMatmulValuationSkip { .. }))
        .map(|step| step.node_id)
        .collect()
}
fn execute_i64_node(node: &crate::ir::SemanticNode, store: &mut TensorStore<i64>) -> Result<()> {
match node.op_name.as_str() {
"add" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_binary_i64(&lhs, &rhs, |a, b| a + b)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"mul" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_binary_i64(&lhs, &rhs, |a, b| a * b)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"matmul" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_matmul_i64(&lhs, &rhs, node.outputs[0].clone())?;
store.insert(node.output_ids[0], output);
Ok(())
}
"map" => {
let input = store.get(node.inputs[0])?.clone();
store.insert(node.output_ids[0], input);
Ok(())
}
"reduce" => {
let input = store.get(node.inputs[0])?;
let sum = input.data.iter().copied().sum::<i64>();
let output = Tensor {
meta: node.outputs[0].clone(),
data: vec![sum],
};
store.insert(node.output_ids[0], output);
Ok(())
}
"fma" => {
let a = store.get(node.inputs[0])?.clone();
let b = store.get(node.inputs[1])?.clone();
let c = store.get(node.inputs[2])?.clone();
let output = dense_ternary_i64(&a, &b, &c, |x, y, z| x * y + z)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"clamp" => {
let input = store.get(node.inputs[0])?.clone();
let lo_tensor = store.get(node.inputs[1])?.clone();
let hi_tensor = store.get(node.inputs[2])?.clone();
if lo_tensor.data.is_empty() || hi_tensor.data.is_empty() {
return Err(Error::backend(format!(
"clamp expects lo and hi to be scalar tensors with at least one element (got lo_len={}, hi_len={}, input shape={:?})",
lo_tensor.data.len(),
hi_tensor.data.len(),
input.meta.shape.dims,
)));
}
let lo = lo_tensor.data[0];
let hi = hi_tensor.data[0];
let output = clamp_i64(&input, lo, hi)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"neg" => {
let input = store.get(node.inputs[0])?.clone();
let output = neg_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"abs" => {
let input = store.get(node.inputs[0])?.clone();
let output = abs_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"square" => {
let input = store.get(node.inputs[0])?.clone();
let output = square_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"mul_by_two" => {
let input = store.get(node.inputs[0])?.clone();
let output = mul_by_two_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"sub" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_sub_i64(&lhs, &rhs)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"div" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_div_i64(&lhs, &rhs)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"scalar_add" => {
let input = store.get(node.inputs[0])?.clone();
let scalar_t = store.get(node.inputs[1])?.clone();
if scalar_t.data.is_empty() {
return Err(Error::backend(
"scalar_add scalar input must be a 1-element tensor",
));
}
let output = scalar_add_i64(&input, scalar_t.data[0])?;
store.insert(node.output_ids[0], output);
Ok(())
}
"scalar_mul" => {
let input = store.get(node.inputs[0])?.clone();
let scalar_t = store.get(node.inputs[1])?.clone();
if scalar_t.data.is_empty() {
return Err(Error::backend(
"scalar_mul scalar input must be a 1-element tensor",
));
}
let output = scalar_mul_i64(&input, scalar_t.data[0])?;
store.insert(node.output_ids[0], output);
Ok(())
}
"pow" => {
let input = store.get(node.inputs[0])?.clone();
let exp_t = store.get(node.inputs[1])?.clone();
if exp_t.data.is_empty() {
return Err(Error::backend("pow exp input must be a 1-element tensor"));
}
let output = pow_i64(&input, exp_t.data[0] as i32)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"sqrt" => {
let input = store.get(node.inputs[0])?.clone();
let output = sqrt_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"exp2" => {
let input = store.get(node.inputs[0])?.clone();
let output = exp2_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"log2" => {
let input = store.get(node.inputs[0])?.clone();
let output = log2_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"reshape" => {
let input = store.get(node.inputs[0])?.clone();
let target = store.get(node.inputs[1])?.clone();
let target_dims: Vec<usize> = target
.meta
.shape
.dims
.iter()
.enumerate()
.map(|(idx, d)| {
d.value().ok_or_else(|| {
Error::backend(format!(
"reshape target shape contains non-static dim at index {idx} (input shape={:?}, target shape={:?})",
input.meta.shape.dims, target.meta.shape.dims
))
})
})
.collect::<Result<Vec<_>>>()?;
let output = shape_reshape_i64(&input, &target_dims)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"transpose" => {
let input = store.get(node.inputs[0])?.clone();
let axes = store.get(node.inputs[1])?.clone();
if axes.data.len() != 2 {
return Err(Error::backend(format!(
"transpose expects a 2-element axes tensor, got {}",
axes.data.len()
)));
}
let output = shape_transpose_i64(&input, axes.data[0], axes.data[1])?;
store.insert(node.output_ids[0], output);
Ok(())
}
"permute" => {
let input = store.get(node.inputs[0])?.clone();
let perm = store.get(node.inputs[1])?.clone();
let permutation: Vec<i64> = perm.data.clone();
let output = shape_permute_i64(&input, &permutation)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"slice" => {
let input = store.get(node.inputs[0])?.clone();
let bounds = store.get(node.inputs[1])?.clone();
if bounds.data.len() != 3 {
return Err(Error::backend(format!(
"slice expects a 3-element bounds tensor [axis, start, end], got {}",
bounds.data.len()
)));
}
let axis = bounds.data[0];
let start = bounds.data[1] as usize;
let end = bounds.data[2] as usize;
let output = shape_slice_i64(&input, axis, start, end)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"concat" => {
if node.inputs.is_empty() {
return Err(Error::shape("concat: no input tensors"));
}
let last = node.inputs[node.inputs.len() - 1];
let axis_t = store.get(last)?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(
"concat: axis tensor must have at least 1 element",
));
}
let axis = axis_t.data[0];
let mut tensors = Vec::with_capacity(node.inputs.len() - 1);
for &inp in &node.inputs[..node.inputs.len() - 1] {
tensors.push(store.get(inp)?.clone());
}
let output = shape_concat_i64(&tensors, axis)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"broadcast" => {
let input = store.get(node.inputs[0])?.clone();
let target = store.get(node.inputs[1])?.clone();
let target_dims: Vec<usize> = target
.meta
.shape
.dims
.iter()
.map(|d| {
d.value().ok_or_else(|| {
Error::shape("broadcast target shape contains non-static dim")
})
})
.collect::<Result<Vec<_>>>()?;
let output = shape_broadcast_i64(&input, &target_dims)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"flatten" => {
let input = store.get(node.inputs[0])?.clone();
let output = shape_flatten_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"squeeze" => {
let input = store.get(node.inputs[0])?.clone();
let output = shape_squeeze_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"unsqueeze" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(
"unsqueeze: axis tensor must have at least 1 element",
));
}
let axis = axis_t.data[0];
let output = shape_unsqueeze_i64(&input, axis)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"relu" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_relu_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"sigmoid" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_sigmoid_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"tanh" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_tanh_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"gelu" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_gelu_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"softmax" => {
let input = store.get(node.inputs[0])?.clone();
let output = nn_softmax_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"layer_norm" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"layer_norm: expects 3 inputs (input, gamma, beta)",
));
}
let input = store.get(node.inputs[0])?.clone();
let gamma = store.get(node.inputs[1])?.clone();
let beta = store.get(node.inputs[2])?.clone();
let output = nn_layer_norm_i64(&input, &gamma, &beta)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"gather" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"gather: expects 3 inputs (input, indices, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let axis_t = store.get(node.inputs[2])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"gather: axis tensor must be non-empty (input shape={:?}, indices shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_gather_i64(&input, &indices, axis)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"scatter" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"scatter: expects 3 inputs (input, indices, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let axis_t = store.get(node.inputs[2])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"scatter: axis tensor must be non-empty (input shape={:?}, indices shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_scatter_i64(&input, &indices, axis)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"index_select" => {
if node.inputs.len() != 3 {
return Err(Error::backend(
"index_select: expects 3 inputs (input, indices, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let axis_t = store.get(node.inputs[2])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"index_select: axis tensor must be non-empty (input shape={:?}, indices shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_select_i64(&input, &indices, axis)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"index_add" => {
if node.inputs.len() != 4 {
return Err(Error::backend(
"index_add: expects 4 inputs (input, indices, source, axis)",
));
}
let input = store.get(node.inputs[0])?.clone();
let indices = store.get(node.inputs[1])?.clone();
let source = store.get(node.inputs[2])?.clone();
let axis_t = store.get(node.inputs[3])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"index_add: axis tensor must be non-empty (input shape={:?}, indices shape={:?}, source shape={:?})",
input.meta.shape.dims, indices.meta.shape.dims, source.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = index_add_i64(&input, &indices, &source, axis)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"nonzero" => {
let input = store.get(node.inputs[0])?.clone();
let output = index_nonzero_i64(&input)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"sum" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"sum: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Sum)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"mean" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"mean: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Mean)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"max" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"max: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Max)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"min" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"min: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Min)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"argmax" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"argmax: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::ArgMax)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"argmin" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"argmin: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::ArgMin)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"prod" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"prod: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Prod)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"any" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"any: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::Any)?;
store.insert(node.output_ids[0], output);
Ok(())
}
"all" => {
let input = store.get(node.inputs[0])?.clone();
let axis_t = store.get(node.inputs[1])?.clone();
if axis_t.data.is_empty() {
return Err(Error::backend(format!(
"all: axis tensor must be non-empty (input shape={:?})",
input.meta.shape.dims
)));
}
let axis = axis_t.data[0];
let output = reduce_i64(&input, axis, ReductionKind::All)?;
store.insert(node.output_ids[0], output);
Ok(())
}
other => Err(Error::backend(format!(
"cpu scalar reference backend does not support op {other}"
))),
}
}
fn execute_padic_node(
node: &crate::ir::SemanticNode,
store: &mut TensorStore<Padic>,
) -> Result<()> {
match node.op_name.as_str() {
"add" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_binary_padic(&lhs, &rhs, |domain, a, b| domain.add(a, b))?;
store.insert(node.output_ids[0], output);
Ok(())
}
"mul" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_binary_padic(&lhs, &rhs, |domain, a, b| domain.mul(a, b))?;
store.insert(node.output_ids[0], output);
Ok(())
}
"matmul" => {
let lhs = store.get(node.inputs[0])?.clone();
let rhs = store.get(node.inputs[1])?.clone();
let output = dense_matmul_padic(&lhs, &rhs, node.outputs[0].clone())?;
store.insert(node.output_ids[0], output);
Ok(())
}
"fma" => {
let a = store.get(node.inputs[0])?.clone();
let b = store.get(node.inputs[1])?.clone();
let c = store.get(node.inputs[2])?.clone();
let output = dense_ternary_padic(&a, &b, &c, |domain, x, y, z| {
let prod = domain.mul(x, y)?;
domain.add(&prod, z)
})?;
store.insert(node.output_ids[0], output);
Ok(())
}
other => Err(Error::backend(format!(
"cpu scalar p-adic backend does not support op {other}"
))),
}
}
/// Executes a fused chain of pointwise i64 nodes and stores only the final result.
///
/// `node_ids` must name at least two nodes of `graph`, and `rule` must be
/// `pointwise_metadata_preserving_fusion`. Each node's result is passed to the
/// next node in-memory (via `fused_input`) instead of being written to
/// `store`. Only the last node's first output is inserted, carrying that
/// node's declared output metadata.
///
/// Supported ops:
/// - `add` and `mul`, using two's-complement wrapping like the other i64
///   kernels here (`dense_sub_i64`, `scalar_mul_i64`, ...).
/// - `map`, which passes its first input through unchanged.
///
/// # Errors
/// Backend error for an unsupported rule or op, an unknown node id, or a
/// missing input value. Shape errors from `dense_binary_i64` are propagated.
fn execute_i64_fused(
    graph: &SemanticGraph,
    node_ids: &[usize],
    rule: &str,
    store: &mut TensorStore<i64>,
) -> Result<()> {
    if rule != "pointwise_metadata_preserving_fusion" || node_ids.len() < 2 {
        return Err(Error::backend(format!(
            "unsupported fused rule {rule} for nodes {node_ids:?}"
        )));
    }
    // The tensor produced by the previous node in the chain, and the value id
    // it would have been stored under if the group were not fused.
    let mut current: Option<Tensor<i64>> = None;
    let mut current_output_id: Option<usize> = None;
    for node_id in node_ids {
        let node = graph
            .nodes()
            .get(*node_id)
            .ok_or_else(|| Error::backend(format!("unknown node {node_id}")))?;
        current = Some(match node.op_name.as_str() {
            "add" => {
                let lhs = fused_input(node.inputs[0], current_output_id, current.as_ref(), store)?;
                let rhs = fused_input(node.inputs[1], current_output_id, current.as_ref(), store)?;
                // Wrapping: plain `+` panics on overflow in debug builds but
                // wraps in release, so results would depend on the build profile.
                dense_binary_i64(&lhs, &rhs, i64::wrapping_add)?
            }
            "mul" => {
                let lhs = fused_input(node.inputs[0], current_output_id, current.as_ref(), store)?;
                let rhs = fused_input(node.inputs[1], current_output_id, current.as_ref(), store)?;
                dense_binary_i64(&lhs, &rhs, i64::wrapping_mul)?
            }
            "map" => fused_input(node.inputs[0], current_output_id, current.as_ref(), store)?,
            other => {
                return Err(Error::backend(format!(
                    "unsupported fused pointwise op {other}"
                )));
            }
        });
        current_output_id = node.output_ids.first().copied();
    }
    let last_node_id = *node_ids
        .last()
        .ok_or_else(|| Error::backend("empty fused group"))?;
    let last_node = graph
        .nodes()
        .get(last_node_id)
        .ok_or_else(|| Error::backend(format!("unknown node {last_node_id}")))?;
    let mut output = current.ok_or_else(|| Error::backend("fused group produced no output"))?;
    output.meta = last_node.outputs[0].clone();
    store.insert(last_node.output_ids[0], output);
    Ok(())
}
fn fused_input(
value_id: usize,
current_output_id: Option<usize>,
current: Option<&Tensor<i64>>,
store: &TensorStore<i64>,
) -> Result<Tensor<i64>> {
if Some(value_id) == current_output_id {
if let Some(value) = current {
Ok(value.clone())
} else {
Err(Error::backend("missing current fused tensor"))
}
} else if let Some(value) = store.get_optional(value_id) {
Ok(value.clone())
} else {
Err(Error::backend(format!(
"missing fused input value {value_id}"
)))
}
}
/// Applies `op` elementwise to two i64 tensors of identical shape.
///
/// The result reuses `lhs`'s metadata.
///
/// # Errors
/// Propagates the error from `Shape::ensure_same`. Returns a shape error if
/// the data buffers have different lengths.
fn dense_binary_i64(
    lhs: &Tensor<i64>,
    rhs: &Tensor<i64>,
    op: impl Fn(i64, i64) -> i64,
) -> Result<Tensor<i64>> {
    lhs.meta.shape.ensure_same(&rhs.meta.shape)?;
    let (left_len, right_len) = (lhs.data.len(), rhs.data.len());
    if left_len != right_len {
        return Err(Error::shape(format!(
            "tensor data length mismatch: left={left_len}, right={right_len}"
        )));
    }
    let mut data = Vec::with_capacity(left_len);
    for (&a, &b) in lhs.data.iter().zip(rhs.data.iter()) {
        data.push(op(a, b));
    }
    Ok(Tensor {
        meta: lhs.meta.clone(),
        data,
    })
}
/// Dense i64 matrix product of two rank-2, statically shaped tensors.
///
/// Accumulation uses two's-complement wrapping, matching the other i64
/// kernels in this backend. The output carries `output_meta` as given; it is
/// not checked against the computed `[m, n]` extent.
///
/// # Errors
/// - Backend error if either input is not rank-2 with static dims, or if a
///   data length disagrees with its shape.
/// - Shape error if the inner dimensions differ.
fn dense_matmul_i64(
    lhs: &Tensor<i64>,
    rhs: &Tensor<i64>,
    output_meta: crate::object::ObjectMeta,
) -> Result<Tensor<i64>> {
    let [m, k] = static_matrix_dims(&lhs.meta.shape, "left")?;
    let [rhs_k, n] = static_matrix_dims(&rhs.meta.shape, "right")?;
    if k != rhs_k {
        return Err(Error::shape(format!(
            "matmul inner dimension mismatch: left={k}, right={rhs_k}"
        )));
    }
    if lhs.data.len() != m * k {
        return Err(Error::backend(format!(
            "left matmul data length mismatch: expected {}, got {}",
            m * k,
            lhs.data.len()
        )));
    }
    if rhs.data.len() != rhs_k * n {
        return Err(Error::backend(format!(
            "right matmul data length mismatch: expected {}, got {}",
            rhs_k * n,
            rhs.data.len()
        )));
    }
    let mut data = Vec::with_capacity(m * n);
    for row in 0..m {
        let lhs_row = &lhs.data[row * k..(row + 1) * k];
        for col in 0..n {
            // Wrapping ops: plain `+=`/`*` panic on overflow in debug builds
            // but wrap in release, making results depend on the build profile.
            let cell = lhs_row
                .iter()
                .enumerate()
                .fold(0_i64, |acc, (inner, &a)| {
                    acc.wrapping_add(a.wrapping_mul(rhs.data[inner * n + col]))
                });
            data.push(cell);
        }
    }
    Ok(Tensor {
        meta: output_meta,
        data,
    })
}
fn static_matrix_dims(shape: &crate::object::Shape, side: &str) -> Result<[usize; 2]> {
if shape.dims.len() != 2 {
return Err(Error::backend(format!(
"{side} matmul input must be rank-2, got rank {}",
shape.dims.len()
)));
}
let dim = |index: usize| match &shape.dims[index] {
crate::object::Dim::Static(value) => Ok(*value),
other => Err(Error::backend(format!(
"{side} matmul input requires static dimensions, got {other:?}"
))),
};
Ok([dim(0)?, dim(1)?])
}
fn dense_binary_padic(
lhs: &Tensor<Padic>,
rhs: &Tensor<Padic>,
op: impl Fn(&PadicDomain, &Padic, &Padic) -> Result<Padic>,
) -> Result<Tensor<Padic>> {
lhs.meta.shape.ensure_same(&rhs.meta.shape)?;
if lhs.data.len() != rhs.data.len() {
return Err(Error::shape(format!(
"tensor data length mismatch: left={}, right={}",
lhs.data.len(),
rhs.data.len()
)));
}
let first = lhs
.data
.first()
.or_else(|| rhs.data.first())
.ok_or_else(|| Error::backend("cannot infer p-adic domain from empty tensors"))?;
let domain = PadicDomain {
meta: first.meta.clone(),
};
let data = lhs
.data
.iter()
.zip(rhs.data.iter())
.map(|(a, b)| op(&domain, a, b))
.collect::<Result<Vec<_>>>()?;
Ok(Tensor {
meta: lhs.meta.clone(),
data,
})
}
/// Doubles every element with two's-complement wrapping; metadata is copied from `input`.
fn mul_by_two_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let mut data = Vec::with_capacity(input.data.len());
    data.extend(input.data.iter().map(|&value| value.wrapping_mul(2)));
    Ok(Tensor {
        meta: input.meta.clone(),
        data,
    })
}
/// Squares every element with two's-complement wrapping; metadata is copied from `input`.
fn square_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let mut data = Vec::with_capacity(input.data.len());
    for &value in input.data.iter() {
        data.push(value.wrapping_mul(value));
    }
    Ok(Tensor {
        meta: input.meta.clone(),
        data,
    })
}
/// Elementwise absolute value; `i64::MIN` saturates to `i64::MAX`.
fn abs_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let magnitudes = input.data.iter().copied().map(i64::saturating_abs);
    Ok(Tensor {
        meta: input.meta.clone(),
        data: magnitudes.collect(),
    })
}
/// Elementwise negation with wrapping (`-i64::MIN` stays `i64::MIN`).
fn neg_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let negated = input.data.iter().copied().map(i64::wrapping_neg);
    Ok(Tensor {
        meta: input.meta.clone(),
        data: negated.collect(),
    })
}
/// Clamps every element into the inclusive range `[lo, hi]`.
///
/// # Errors
/// Backend error when `lo > hi`.
fn clamp_i64(input: &Tensor<i64>, lo: i64, hi: i64) -> Result<Tensor<i64>> {
    if hi < lo {
        return Err(Error::backend(format!("clamp lo={lo} must be <= hi={hi}")));
    }
    // With lo <= hi guaranteed above, max-then-min is exactly `clamp`.
    let data = input.data.iter().map(|&v| v.max(lo).min(hi)).collect();
    Ok(Tensor {
        meta: input.meta.clone(),
        data,
    })
}
/// Applies `op` elementwise to three i64 tensors of identical shape.
///
/// The result reuses `a`'s metadata.
///
/// # Errors
/// Propagates the `Shape::ensure_same` errors. Returns a shape error if the
/// data buffers differ in length.
fn dense_ternary_i64(
    a: &Tensor<i64>,
    b: &Tensor<i64>,
    c: &Tensor<i64>,
    op: impl Fn(i64, i64, i64) -> i64,
) -> Result<Tensor<i64>> {
    a.meta.shape.ensure_same(&b.meta.shape)?;
    b.meta.shape.ensure_same(&c.meta.shape)?;
    let (len_a, len_b, len_c) = (a.data.len(), b.data.len(), c.data.len());
    if len_a != len_b || len_b != len_c {
        return Err(Error::shape(format!(
            "tensor data length mismatch: a={len_a}, b={len_b}, c={len_c}"
        )));
    }
    let mut data = Vec::with_capacity(len_a);
    for ((&x, &y), &z) in a.data.iter().zip(b.data.iter()).zip(c.data.iter()) {
        data.push(op(x, y, z));
    }
    Ok(Tensor {
        meta: a.meta.clone(),
        data,
    })
}
fn dense_ternary_padic(
a: &Tensor<Padic>,
b: &Tensor<Padic>,
c: &Tensor<Padic>,
op: impl Fn(&PadicDomain, &Padic, &Padic, &Padic) -> Result<Padic>,
) -> Result<Tensor<Padic>> {
a.meta.shape.ensure_same(&b.meta.shape)?;
b.meta.shape.ensure_same(&c.meta.shape)?;
if a.data.len() != b.data.len() || b.data.len() != c.data.len() {
return Err(Error::shape(format!(
"tensor data length mismatch: a={}, b={}, c={}",
a.data.len(),
b.data.len(),
c.data.len()
)));
}
let first = a
.data
.first()
.or_else(|| b.data.first())
.or_else(|| c.data.first())
.ok_or_else(|| Error::backend("cannot infer p-adic domain from empty tensors"))?;
let domain = PadicDomain {
meta: first.meta.clone(),
};
let data = a
.data
.iter()
.zip(b.data.iter())
.zip(c.data.iter())
.map(|((x, y), z)| op(&domain, x, y, z))
.collect::<Result<Vec<_>>>()?;
Ok(Tensor {
meta: a.meta.clone(),
data,
})
}
/// Dot product of two p-adic tensors of identical shape.
///
/// Returns a single-element tensor carrying `output_meta`. The domain, and
/// hence the modulus, comes from the first available element. Products and
/// sums are reduced modulo that modulus directly on the residues, using
/// helpers that never overflow `u128` even when the modulus exceeds 2^64.
///
/// # Errors
/// - Propagates the `Shape::ensure_same` error.
/// - Backend error on a data length mismatch.
/// - Backend error when both inputs are empty, since there is no element to
///   infer the domain from.
fn dense_dot_padic(
    a: &Tensor<Padic>,
    b: &Tensor<Padic>,
    output_meta: crate::object::ObjectMeta,
) -> Result<Tensor<Padic>> {
    // `(x + y) mod m` for `x, y < m`, without forming a sum that can overflow.
    fn add_mod(x: u128, y: u128, m: u128) -> u128 {
        if x >= m - y {
            x - (m - y)
        } else {
            x + y
        }
    }
    // `(x * y) mod m` without overflow: use the native product when it fits,
    // otherwise fall back to double-and-add.
    fn mul_mod(x: u128, y: u128, m: u128) -> u128 {
        let (mut x, mut y) = (x % m, y % m);
        if let Some(product) = x.checked_mul(y) {
            return product % m;
        }
        let mut result = 0;
        while y > 0 {
            if y & 1 == 1 {
                result = add_mod(result, x, m);
            }
            x = add_mod(x, x, m);
            y >>= 1;
        }
        result
    }

    a.meta.shape.ensure_same(&b.meta.shape)?;
    if a.data.len() != b.data.len() {
        return Err(Error::backend(format!(
            "p_dot data length mismatch: a={}, b={}",
            a.data.len(),
            b.data.len()
        )));
    }
    let first = a
        .data
        .first()
        .or_else(|| b.data.first())
        .ok_or_else(|| Error::backend("cannot infer p-adic domain from empty vectors"))?;
    let domain = PadicDomain {
        meta: first.meta.clone(),
    };
    let modulus = domain.modulus();
    // The original `(acc + av * bv) % modulus` overflowed u128 once the
    // modulus exceeded ~2^64.
    let acc = a
        .data
        .iter()
        .zip(b.data.iter())
        .fold(0_u128, |acc, (av, bv)| {
            add_mod(acc, mul_mod(av.residue, bv.residue, modulus), modulus)
        });
    Ok(Tensor {
        meta: output_meta,
        data: vec![domain.element(acc)],
    })
}
/// Fused p-adic `a @ b + c` for rank-2, statically shaped operands.
///
/// `a` is `[m, k]`, `b` is `[k, n]`, and `c` must hold `m * n` elements in
/// row-major order. The domain (and modulus) comes from the first available
/// element. All arithmetic is on residues modulo that modulus, using helpers
/// that never overflow `u128` even when the modulus exceeds 2^64. The result
/// carries `output_meta`.
///
/// # Errors
/// - Backend error for non-rank-2/non-static inputs, mismatched data lengths,
///   or all-empty operands.
/// - Shape error if the inner dimensions differ.
fn dense_matmul_fma_padic(
    a: &Tensor<Padic>,
    b: &Tensor<Padic>,
    c: &Tensor<Padic>,
    output_meta: crate::object::ObjectMeta,
) -> Result<Tensor<Padic>> {
    // `(x + y) mod m` for `x, y < m`, without forming a sum that can overflow.
    fn add_mod(x: u128, y: u128, m: u128) -> u128 {
        if x >= m - y {
            x - (m - y)
        } else {
            x + y
        }
    }
    // `(x * y) mod m` without overflow: use the native product when it fits,
    // otherwise fall back to double-and-add.
    fn mul_mod(x: u128, y: u128, m: u128) -> u128 {
        let (mut x, mut y) = (x % m, y % m);
        if let Some(product) = x.checked_mul(y) {
            return product % m;
        }
        let mut result = 0;
        while y > 0 {
            if y & 1 == 1 {
                result = add_mod(result, x, m);
            }
            x = add_mod(x, x, m);
            y >>= 1;
        }
        result
    }

    let [m, k] = static_matrix_dims(&a.meta.shape, "left")?;
    let [b_k, n] = static_matrix_dims(&b.meta.shape, "right")?;
    if k != b_k {
        return Err(Error::shape(format!(
            "p_pad_fma inner dimension mismatch: a_k={k}, b_k={b_k}"
        )));
    }
    if c.data.len() != m * n {
        return Err(Error::backend(format!(
            "p_pad_fma bias data length mismatch: expected {}, got {}",
            m * n,
            c.data.len()
        )));
    }
    if a.data.len() != m * k {
        return Err(Error::backend(format!(
            "p_pad_fma left data length mismatch: expected {}, got {}",
            m * k,
            a.data.len()
        )));
    }
    if b.data.len() != b_k * n {
        return Err(Error::backend(format!(
            "p_pad_fma right data length mismatch: expected {}, got {}",
            b_k * n,
            b.data.len()
        )));
    }
    let first = a
        .data
        .first()
        .or_else(|| b.data.first())
        .or_else(|| c.data.first())
        .ok_or_else(|| Error::backend("cannot infer p-adic domain from empty matrices"))?;
    let domain = PadicDomain {
        meta: first.meta.clone(),
    };
    let modulus = domain.modulus();
    let mut out: Vec<Padic> = Vec::with_capacity(m * n);
    for i in 0..m {
        for j in 0..n {
            let mut acc: u128 = 0;
            for kk in 0..k {
                let product = mul_mod(a.data[i * k + kk].residue, b.data[kk * n + j].residue, modulus);
                acc = add_mod(acc, product, modulus);
            }
            acc = add_mod(acc, c.data[i * n + j].residue % modulus, modulus);
            out.push(domain.element(acc));
        }
    }
    Ok(Tensor {
        meta: output_meta,
        data: out,
    })
}
fn dense_matmul_padic(
lhs: &Tensor<Padic>,
rhs: &Tensor<Padic>,
output_meta: crate::object::ObjectMeta,
) -> Result<Tensor<Padic>> {
let [m, k] = static_matrix_dims(&lhs.meta.shape, "left")?;
let [rhs_k, n] = static_matrix_dims(&rhs.meta.shape, "right")?;
if k != rhs_k {
return Err(Error::shape(format!(
"matmul inner dimension mismatch: left={k}, right={rhs_k}"
)));
}
if lhs.data.len() != m * k {
return Err(Error::backend(format!(
"left matmul data length mismatch: expected {}, got {}",
m * k,
lhs.data.len()
)));
}
if rhs.data.len() != rhs_k * n {
return Err(Error::backend(format!(
"right matmul data length mismatch: expected {}, got {}",
rhs_k * n,
rhs.data.len()
)));
}
let first = lhs
.data
.first()
.or_else(|| rhs.data.first())
.ok_or_else(|| Error::backend("cannot infer p-adic domain from empty matrices"))?;
let domain = PadicDomain {
meta: first.meta.clone(),
};
let mut data = Vec::with_capacity(m * n);
for row in 0..m {
for col in 0..n {
let mut acc = domain.element(0);
for inner in 0..k {
let product = domain.mul(&lhs.data[row * k + inner], &rhs.data[inner * n + col])?;
acc = domain.add(&acc, &product)?;
}
data.push(acc);
}
}
Ok(Tensor {
meta: output_meta,
data,
})
}
fn dense_matmul_padic_with_valuation_skip(
lhs: &Tensor<Padic>,
rhs: &Tensor<Padic>,
output_meta: crate::object::ObjectMeta,
) -> Result<(Tensor<Padic>, PadicMatmulValuationSkipReport)> {
let [m, k] = static_matrix_dims(&lhs.meta.shape, "left")?;
let [rhs_k, n] = static_matrix_dims(&rhs.meta.shape, "right")?;
if k != rhs_k {
return Err(Error::shape(format!(
"matmul inner dimension mismatch: left={k}, right={rhs_k}"
)));
}
if lhs.data.len() != m * k {
return Err(Error::backend(format!(
"left matmul data length mismatch: expected {}, got {}",
m * k,
lhs.data.len()
)));
}
if rhs.data.len() != rhs_k * n {
return Err(Error::backend(format!(
"right matmul data length mismatch: expected {}, got {}",
rhs_k * n,
rhs.data.len()
)));
}
let first = lhs
.data
.first()
.or_else(|| rhs.data.first())
.ok_or_else(|| Error::backend("cannot infer p-adic domain from empty matrices"))?;
let domain = PadicDomain {
meta: first.meta.clone(),
};
let lhs_matrix = domain.matrix(m, k, lhs.data.clone())?;
let rhs_matrix = domain.matrix(rhs_k, n, rhs.data.clone())?;
let certified = domain.certified_valuation_sparse_matrix_mul(&lhs_matrix, &rhs_matrix)?;
let report = PadicMatmulValuationSkipReport {
skipped_terms: certified.skipped_products,
evaluated_terms: certified.evaluated_products,
lhs_metadata: vec![certified.lhs_metadata],
rhs_metadata: vec![certified.rhs_metadata],
output_certificates: certified.output_certificates,
dense_oracle_matches: certified.dense_oracle_matches,
};
Ok((
Tensor {
meta: output_meta,
data: certified.output.data,
},
report,
))
}
/// Elementwise `lhs - rhs` with two's-complement wrapping; shapes must match.
pub fn dense_sub_i64(lhs: &Tensor<i64>, rhs: &Tensor<i64>) -> Result<Tensor<i64>> {
    dense_binary_i64(lhs, rhs, i64::wrapping_sub)
}
pub fn dense_div_i64(lhs: &Tensor<i64>, rhs: &Tensor<i64>) -> Result<Tensor<i64>> {
if let Some((idx, _)) = rhs.data.iter().enumerate().find(|(_, v)| **v == 0) {
return Err(Error::domain(format!(
"div by zero: rhs.data[{idx}] = 0 (lhs shape={:?}, rhs shape={:?})",
lhs.meta.shape.dims, rhs.meta.shape.dims
)));
}
dense_binary_i64(lhs, rhs, |a, b| a.wrapping_div(b))
}
/// Adds `scalar` to every element with two's-complement wrapping.
pub fn scalar_add_i64(lhs: &Tensor<i64>, scalar: i64) -> Result<Tensor<i64>> {
    let data = lhs.data.iter().map(|&v| v.wrapping_add(scalar)).collect();
    Ok(Tensor {
        meta: lhs.meta.clone(),
        data,
    })
}
/// Multiplies every element by `scalar` with two's-complement wrapping.
pub fn scalar_mul_i64(lhs: &Tensor<i64>, scalar: i64) -> Result<Tensor<i64>> {
    let data = lhs.data.iter().map(|&v| v.wrapping_mul(scalar)).collect();
    Ok(Tensor {
        meta: lhs.meta.clone(),
        data,
    })
}
pub fn pow_i64(lhs: &Tensor<i64>, exp: i32) -> Result<Tensor<i64>> {
if exp < 0 {
return Err(Error::backend(format!(
"pow negative exponent {exp} not supported"
)));
}
let mut data = Vec::with_capacity(lhs.data.len());
for v in &lhs.data {
let mut acc: i64 = 1;
for _ in 0..exp {
acc = acc.wrapping_mul(*v);
}
data.push(acc);
}
Ok(Tensor {
meta: lhs.meta.clone(),
data,
})
}
/// Elementwise integer square root: each output is `floor(sqrt(v))`.
///
/// # Errors
/// Backend error if any element is negative.
pub fn sqrt_i64(lhs: &Tensor<i64>) -> Result<Tensor<i64>> {
    if lhs.data.iter().any(|v| *v < 0) {
        return Err(Error::backend("sqrt of negative value"));
    }
    // floor(sqrt(i64::MAX)); no non-negative i64 has a larger integer root.
    const MAX_ROOT: i64 = 3_037_000_499;
    let floor_sqrt = |value: i64| -> i64 {
        let (mut lo, mut hi, mut root) = (0_i64, value.min(MAX_ROOT), 0_i64);
        while lo <= hi {
            let mid = lo + (hi - lo) / 2;
            // Square in i128. The previous upper bound of 2^32 let `mid * mid`
            // overflow i64 for inputs >= 2^62.
            if i128::from(mid) * i128::from(mid) <= i128::from(value) {
                root = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        root
    };
    let data = lhs.data.iter().map(|&v| floor_sqrt(v)).collect();
    Ok(Tensor {
        meta: lhs.meta.clone(),
        data,
    })
}
pub fn exp2_i64(lhs: &Tensor<i64>) -> Result<Tensor<i64>> {
let mut data = Vec::with_capacity(lhs.data.len());
for v in &lhs.data {
if *v < 0 {
return Err(Error::backend("exp2 negative exponent"));
}
if *v >= 63 {
data.push(i64::MAX);
} else {
data.push(1_i64 << v);
}
}
Ok(Tensor {
meta: lhs.meta.clone(),
data,
})
}
/// Elementwise `floor(log2(v))` for `v > 1`.
///
/// Every element `<= 1`, including zero and negatives, maps to 0.
pub fn log2_i64(lhs: &Tensor<i64>) -> Result<Tensor<i64>> {
    let data = lhs
        .data
        .iter()
        .map(|&value| {
            if value > 1 {
                63 - i64::from((value as u64).leading_zeros())
            } else {
                0
            }
        })
        .collect();
    Ok(Tensor {
        meta: lhs.meta.clone(),
        data,
    })
}
/// Reinterprets `input`'s data under a new static shape `new_dims`.
///
/// Data is copied unchanged. Metadata other than the shape is kept from `input`.
///
/// # Errors
/// Backend error if the product of `new_dims` differs from the data length.
pub fn shape_reshape_i64(input: &Tensor<i64>, new_dims: &[usize]) -> Result<Tensor<i64>> {
    let element_count = input.data.len();
    let requested: usize = new_dims.iter().product();
    if requested != element_count {
        return Err(Error::backend(format!(
            "reshape: target dims product {requested} does not match data length {element_count}"
        )));
    }
    let dims: Vec<Dim> = new_dims.iter().copied().map(Dim::Static).collect();
    let mut meta = input.meta.clone();
    meta.shape = crate::object::Shape::new(dims);
    Ok(Tensor {
        meta,
        data: input.data.clone(),
    })
}
/// Swaps two axes of a statically shaped i64 tensor.
///
/// Returns a new row-major tensor. Negative axes count from the end. When
/// both axes resolve to the same position, a clone of `input` is returned.
/// Metadata other than the shape is kept from `input`.
///
/// # Errors
/// - Backend error if either axis is out of range.
/// - Shape error if any input dimension is not static.
pub fn shape_transpose_i64(input: &Tensor<i64>, axis0: i64, axis1: i64) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank() as i64;
    // Negative axes wrap once. Anything still negative becomes a huge usize,
    // which the range check below rejects.
    let resolve = |axis: i64| (if axis < 0 { rank + axis } else { axis }) as usize;
    let (first, second) = (resolve(axis0), resolve(axis1));
    let limit = rank as usize;
    if first >= limit || second >= limit {
        return Err(Error::backend(format!(
            "transpose axis out of range: rank={rank}, got axis0={axis0}, axis1={axis1}"
        )));
    }
    let in_dims = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::shape("transpose requires all-static input dims"))
        })
        .collect::<Result<Vec<usize>>>()?;
    if first == second {
        return Ok(input.clone());
    }
    // Row-major strides for the given extents.
    let row_major = |extents: &[usize]| -> Vec<usize> {
        let mut strides = vec![1usize; extents.len()];
        for axis in (1..extents.len()).rev() {
            strides[axis - 1] = strides[axis] * extents[axis];
        }
        strides
    };
    let mut out_dims = in_dims.clone();
    out_dims.swap(first, second);
    let out_strides = row_major(&out_dims[..]);
    // The input stride to step along each *output* axis.
    let mut src_strides = row_major(&in_dims[..]);
    src_strides.swap(first, second);
    let total: usize = out_dims.iter().product();
    let data: Vec<i64> = (0..total)
        .map(|flat| {
            let mut rem = flat;
            let mut src = 0usize;
            for (&stride, &src_stride) in out_strides.iter().zip(src_strides.iter()) {
                src += (rem / stride) * src_stride;
                rem %= stride;
            }
            input.data[src]
        })
        .collect();
    let shape = crate::object::Shape::new(out_dims.into_iter().map(Dim::Static).collect::<Vec<Dim>>());
    Ok(Tensor {
        meta: ObjectMeta {
            shape,
            ..input.meta.clone()
        },
        data,
    })
}
/// Reorders the axes of a statically shaped i64 tensor.
///
/// Output axis `i` is input axis `permutation[i]`. Negative entries count
/// from the end. Metadata other than the shape is kept from `input`.
///
/// # Errors
/// - Backend error if `permutation` has the wrong length, contains an
///   out-of-range axis, or repeats an axis.
/// - Shape error if any input dimension is not static.
pub fn shape_permute_i64(input: &Tensor<i64>, permutation: &[i64]) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank();
    if permutation.len() != rank {
        return Err(Error::backend(format!(
            "permute: permutation length {} does not match input rank {rank}",
            permutation.len()
        )));
    }
    let signed_rank = rank as i64;
    let mut order = Vec::with_capacity(rank);
    for &p in permutation {
        let q = if p < 0 { signed_rank + p } else { p };
        if !(0..signed_rank).contains(&q) {
            return Err(Error::backend(format!(
                "permute axis {p} out of range for rank {signed_rank}"
            )));
        }
        order.push(q as usize);
    }
    // `rank` in-range entries form a permutation exactly when none repeats.
    let mut seen = vec![false; rank];
    for &axis in &order {
        if seen[axis] {
            return Err(Error::backend(format!(
                "permute: permutation {permutation:?} is not a valid permutation of 0..{rank}"
            )));
        }
        seen[axis] = true;
    }
    let in_dims = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::shape("permute requires all-static input dims"))
        })
        .collect::<Result<Vec<usize>>>()?;
    // Row-major strides for the given extents.
    let row_major = |extents: &[usize]| -> Vec<usize> {
        let mut strides = vec![1usize; extents.len()];
        for axis in (1..extents.len()).rev() {
            strides[axis - 1] = strides[axis] * extents[axis];
        }
        strides
    };
    let in_strides = row_major(&in_dims[..]);
    let out_dims: Vec<usize> = order.iter().map(|&src| in_dims[src]).collect();
    let out_strides = row_major(&out_dims[..]);
    // The input stride to step along each *output* axis.
    let src_strides: Vec<usize> = order.iter().map(|&src| in_strides[src]).collect();
    let total: usize = out_dims.iter().product();
    let data: Vec<i64> = (0..total)
        .map(|flat| {
            let mut rem = flat;
            let mut src = 0usize;
            for (&stride, &src_stride) in out_strides.iter().zip(src_strides.iter()) {
                src += (rem / stride) * src_stride;
                rem %= stride;
            }
            input.data[src]
        })
        .collect();
    let shape = crate::object::Shape::new(out_dims.into_iter().map(Dim::Static).collect::<Vec<Dim>>());
    Ok(Tensor {
        meta: ObjectMeta {
            shape,
            ..input.meta.clone()
        },
        data,
    })
}
/// Takes the half-open range `start..end` along one axis of a statically
/// shaped i64 tensor.
///
/// Negative `axis` counts from the end. `end == usize::MAX` means "through
/// the end of the axis". All other axes are kept whole. Metadata other than
/// the shape is kept from `input`.
///
/// # Errors
/// - Shape error if any input dimension is not static.
/// - Backend error for an out-of-range axis, `start > end`, `end` past the
///   axis size, or data whose length disagrees with the shape.
pub fn shape_slice_i64(
    input: &Tensor<i64>,
    axis: i64,
    start: usize,
    end: usize,
) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank() as i64;
    let ax = if axis < 0 { rank + axis } else { axis } as usize;
    let dims: Vec<usize> = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::shape("slice requires all-static input dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    if ax >= rank as usize {
        return Err(Error::backend(format!(
            "slice: axis {axis} out of range for rank {rank}"
        )));
    }
    let dim_size = dims[ax];
    let resolved_end = if end == usize::MAX { dim_size } else { end };
    if start > resolved_end {
        return Err(Error::backend(format!(
            "slice: start={start} > end={resolved_end}"
        )));
    }
    if resolved_end > dim_size {
        return Err(Error::backend(format!(
            "slice: end={resolved_end} > dim_size={dim_size}"
        )));
    }
    let expected_len: usize = dims.iter().product();
    if input.data.len() != expected_len {
        return Err(Error::backend(format!(
            "slice: data length {} does not match shape product {expected_len}",
            input.data.len()
        )));
    }
    let mut out_dims = dims.clone();
    out_dims[ax] = resolved_end - start;
    // In row-major order, the kept elements form one contiguous run for each
    // combination of indices on the axes before `ax`. Copying those runs
    // replaces the old per-element index decoding, which used the input
    // strides to decode output indices and so produced wrong results for any
    // axis other than the outermost.
    let outer: usize = dims[..ax].iter().product();
    let inner: usize = dims[ax + 1..].iter().product();
    let block_len = dim_size * inner;
    let (run_start, run_end) = (start * inner, resolved_end * inner);
    let mut out = Vec::with_capacity(outer * (run_end - run_start));
    for block in 0..outer {
        let base = block * block_len;
        out.extend_from_slice(&input.data[base + run_start..base + run_end]);
    }
    let new_shape = crate::object::Shape::new(
        out_dims
            .iter()
            .map(|&n| Dim::Static(n))
            .collect::<Vec<Dim>>(),
    );
    Ok(Tensor {
        meta: ObjectMeta {
            shape: new_shape,
            ..input.meta.clone()
        },
        data: out,
    })
}
/// Concatenates statically shaped i64 tensors along `axis`.
///
/// All inputs must have the rank of `inputs[0]` and match it on every axis
/// except `axis`. The output's size along `axis` is the sum of the inputs'
/// sizes there. Negative `axis` counts from the end. Metadata other than the
/// shape is kept from `inputs[0]`.
///
/// # Errors
/// - Shape error for an empty input list, non-static dims, or a rank mismatch.
/// - Backend error for an out-of-range axis, a mismatched non-concat axis, or
///   an input whose data length disagrees with its shape.
pub fn shape_concat_i64(inputs: &[Tensor<i64>], axis: i64) -> Result<Tensor<i64>> {
    if inputs.is_empty() {
        return Err(Error::shape("concat: empty input list"));
    }
    let rank = inputs[0].meta.shape.rank() as i64;
    let ax = if axis < 0 { rank + axis } else { axis } as usize;
    if ax >= rank as usize {
        return Err(Error::backend(format!(
            "concat: axis {axis} out of range for rank {rank}"
        )));
    }
    let static_dims = |tensor: &Tensor<i64>| -> Result<Vec<usize>> {
        tensor
            .meta
            .shape
            .dims
            .iter()
            .map(|d| {
                d.value()
                    .ok_or_else(|| Error::shape("concat requires all-static input dims"))
            })
            .collect()
    };
    let base_dims = static_dims(&inputs[0])?;
    let mut out_dims = base_dims.clone();
    out_dims[ax] = 0;
    // Each input's extent along the concat axis, in input order. Dims are
    // parsed and validated once here rather than re-parsed during the copy.
    let mut axis_sizes = Vec::with_capacity(inputs.len());
    for tensor in inputs {
        let dims = static_dims(tensor)?;
        if dims.len() != base_dims.len() {
            return Err(Error::shape("concat: rank mismatch between inputs"));
        }
        for (i, (&a, &b)) in dims.iter().zip(base_dims.iter()).enumerate() {
            if i != ax && a != b {
                return Err(Error::backend(format!(
                    "concat: non-concat axis {i} has mismatched size {a} vs {b}"
                )));
            }
        }
        // Guard the slab slicing below; without this, short data panicked.
        let expected_len: usize = dims.iter().product();
        if tensor.data.len() != expected_len {
            return Err(Error::backend(format!(
                "concat: input data length {} does not match its shape product {expected_len}",
                tensor.data.len()
            )));
        }
        out_dims[ax] += dims[ax];
        axis_sizes.push(dims[ax]);
    }
    // Row-major layout: for each combination of leading indices (axes before
    // `ax`), every input contributes one contiguous slab of
    // `size_along_axis * inner` elements, in input order. Copying slabs avoids
    // the old per-element coordinate vector allocation.
    let outer: usize = base_dims[..ax].iter().product();
    let inner: usize = base_dims[ax + 1..].iter().product();
    let total: usize = out_dims.iter().product();
    let mut out = Vec::with_capacity(total);
    for block in 0..outer {
        for (tensor, &size) in inputs.iter().zip(axis_sizes.iter()) {
            let slab = size * inner;
            out.extend_from_slice(&tensor.data[block * slab..(block + 1) * slab]);
        }
    }
    let new_shape = crate::object::Shape::new(
        out_dims
            .iter()
            .map(|&n| Dim::Static(n))
            .collect::<Vec<Dim>>(),
    );
    Ok(Tensor {
        meta: ObjectMeta {
            shape: new_shape,
            ..inputs[0].meta.clone()
        },
        data: out,
    })
}
pub fn shape_broadcast_i64(input: &Tensor<i64>, target_dims: &[usize]) -> Result<Tensor<i64>> {
let in_dims: Vec<usize> = input
.meta
.shape
.dims
.iter()
.map(|d| {
d.value()
.ok_or_else(|| Error::shape("broadcast requires all-static input dims"))
})
.collect::<Result<Vec<_>>>()?;
if in_dims.len() > target_dims.len() {
return Err(Error::backend(format!(
"broadcast: input rank {} > target rank {}",
in_dims.len(),
target_dims.len()
)));
}
let pad = target_dims.len() - in_dims.len();
let mut aligned = vec![1usize; pad];
aligned.extend(in_dims.iter().copied());
for (i, (&a, &b)) in aligned.iter().zip(target_dims.iter()).enumerate() {
if a != b && a != 1 {
return Err(Error::backend(format!(
"broadcast: dim {i} has incompatible size {a} vs target {b}"
)));
}
}
let out_dims: Vec<usize> = target_dims.to_vec();
let total: usize = out_dims.iter().product();
let mut out = vec![0i64; total];
let mut in_strides = vec![1usize; aligned.len()];
for i in (0..aligned.len().saturating_sub(1)).rev() {
in_strides[i] = in_strides[i + 1] * aligned[i + 1];
}
let mut out_strides = vec![1usize; out_dims.len()];
for i in (0..out_dims.len().saturating_sub(1)).rev() {
out_strides[i] = out_strides[i + 1] * out_dims[i + 1];
}
for (out_idx, item) in out.iter_mut().enumerate() {
let mut tmp = out_idx;
let mut in_idx = 0usize;
for d in 0..out_dims.len() {
let coord = tmp / out_strides[d];
tmp %= out_strides[d];
let src_coord = if aligned[d] == 1 { 0 } else { coord };
in_idx += src_coord * in_strides[d];
}
*item = input.data[in_idx];
}
let new_shape = crate::object::Shape::new(
out_dims
.iter()
.map(|&n| Dim::Static(n))
.collect::<Vec<Dim>>(),
);
Ok(Tensor {
meta: ObjectMeta {
shape: new_shape,
..input.meta.clone()
},
data: out,
})
}
/// Reshapes `input` to rank 1. The single static extent equals the length of
/// the data buffer, and the data is copied unchanged.
pub fn shape_flatten_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let flat = crate::object::Shape::new(vec![Dim::Static(input.data.len())]);
    let meta = ObjectMeta {
        shape: flat,
        ..input.meta.clone()
    };
    Ok(Tensor {
        meta,
        data: input.data.clone(),
    })
}
/// Removes every axis whose extent is statically `1`. Dynamic axes are
/// always kept, and the data is copied unchanged.
pub fn shape_squeeze_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let mut kept = input.meta.shape.dims.clone();
    kept.retain(|dim| !matches!(dim, Dim::Static(1)));
    Ok(Tensor {
        meta: ObjectMeta {
            shape: crate::object::Shape::new(kept),
            ..input.meta.clone()
        },
        data: input.data.clone(),
    })
}
/// Inserts a new static axis of extent `1` at position `axis`.
///
/// Valid positions are `0..=rank`. Negative values count from `rank + 1`, so
/// `-1` appends a trailing axis.
///
/// # Errors
/// Backend error when `axis` does not name a valid insertion point.
pub fn shape_unsqueeze_i64(input: &Tensor<i64>, axis: i64) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank() as i64;
    let position = if axis >= 0 { axis } else { axis + rank + 1 };
    if !(0..=rank).contains(&position) {
        return Err(Error::backend(format!(
            "unsqueeze: axis {axis} out of range for rank {rank}"
        )));
    }
    let mut dims = input.meta.shape.dims.clone();
    dims.insert(position as usize, Dim::Static(1));
    let meta = ObjectMeta {
        shape: crate::object::Shape::new(dims),
        ..input.meta.clone()
    };
    Ok(Tensor {
        meta,
        data: input.data.clone(),
    })
}
/// Fixed-point scale of the activation approximations that follow.
///
/// Sigmoid saturates at `NN_APPROX_SCALE` and tanh at `±NN_APPROX_SCALE`,
/// so this value plays the role of `1.0` in their integer outputs.
const NN_APPROX_SCALE: i64 = 1_000_000;
/// Rectified linear unit: every negative element becomes `0` and the rest are
/// unchanged. Metadata, including the shape, is copied from `input`.
pub fn nn_relu_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let rectified: Vec<i64> = input.data.iter().map(|&x| x.max(0)).collect();
    Ok(Tensor {
        data: rectified,
        meta: input.meta.clone(),
    })
}
/// Integer logistic sigmoid: maps each plain-integer input `v` to
/// `NN_APPROX_SCALE / (1 + e^-v)`, rounded to the nearest fixed-point step.
///
/// Only `-7..=7` produce non-saturated results, so those come from an exact
/// table, using the symmetry `sigmoid(-v) = 1 - sigmoid(v)`. Inputs `<= -8`
/// map to `0` and inputs `>= 8` map to `NN_APPROX_SCALE`. A lookup table is
/// exact and keeps the output monotonic in `v`, which a low-order polynomial
/// fit over this range does not.
pub fn nn_sigmoid_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    // round(1e6 * sigmoid(k)) for k = 0..=7, in parts per million.
    const SIGMOID_PPM: [i64; 8] = [
        500_000, 731_059, 880_797, 952_574, 982_014, 993_307, 997_527, 999_089,
    ];
    let data = input
        .data
        .iter()
        .map(|&v| {
            if v <= -8 {
                0
            } else if v >= 8 {
                NN_APPROX_SCALE
            } else {
                // Rescale from ppm so the table stays valid if the scale changes.
                let upper =
                    SIGMOID_PPM[v.unsigned_abs() as usize] * NN_APPROX_SCALE / 1_000_000;
                if v >= 0 {
                    upper
                } else {
                    NN_APPROX_SCALE - upper
                }
            }
        })
        .collect();
    Ok(Tensor {
        meta: input.meta.clone(),
        data,
    })
}
/// Integer tanh for plain-integer inputs, scaled by `NN_APPROX_SCALE`.
///
/// Inputs with `|v| >= 8` saturate to `±NN_APPROX_SCALE`. Inside that range
/// the rational approximation `v(27 + v²) / (27 + 9v²)` is used, truncated
/// toward zero and clamped to `±NN_APPROX_SCALE`.
pub fn nn_tanh_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let data = input
        .data
        .iter()
        .map(|&x| {
            if x <= -8 {
                -NN_APPROX_SCALE
            } else if x >= 8 {
                NN_APPROX_SCALE
            } else {
                let x2 = x * x;
                let scaled = x * (27 + x2) * NN_APPROX_SCALE / (27 + 9 * x2);
                scaled.clamp(-NN_APPROX_SCALE, NN_APPROX_SCALE)
            }
        })
        .collect();
    Ok(Tensor {
        meta: input.meta.clone(),
        data,
    })
}
/// Integer GELU using the tanh form
/// `gelu(v) ≈ v * (1 + tanh(0.797 * (v + 0.045 * v³))) / 2`, scaled by
/// `NN_APPROX_SCALE`.
///
/// Inputs are plain integers. The tanh argument is truncated to an integer
/// before the rational approximation `t(27 + t²) / (27 + 9t²)` is applied, so
/// results for small `|v|` are coarse.
///
/// For `|v| >= 8` the tanh term is already saturated: the truncated argument
/// is at least 24 in magnitude. The result there is `v * NN_APPROX_SCALE`
/// for positive inputs (saturating at `i64::MAX`) and `0` for negative ones.
/// Returning early keeps `v³` from overflowing for large inputs.
pub fn nn_gelu_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    const COEFF_A: i64 = 797;
    const COEFF_B: i64 = 45;
    const SATURATION: i64 = 8;
    let mut data = Vec::with_capacity(input.data.len());
    for &v in &input.data {
        if v >= SATURATION {
            data.push(v.saturating_mul(NN_APPROX_SCALE));
            continue;
        }
        if v <= -SATURATION {
            data.push(0);
            continue;
        }
        let cubic = v + COEFF_B * (v * v * v) / 1000;
        let tanh_arg = COEFF_A * cubic / 1000;
        let t_out: i64 = if tanh_arg >= 8 {
            NN_APPROX_SCALE
        } else if tanh_arg <= -8 {
            -NN_APPROX_SCALE
        } else {
            let sq = tanh_arg * tanh_arg;
            tanh_arg * (27 + sq) * NN_APPROX_SCALE / (27 + 9 * sq)
        };
        let t_out = t_out.clamp(-NN_APPROX_SCALE, NN_APPROX_SCALE);
        data.push(v * (NN_APPROX_SCALE + t_out) / 2);
    }
    Ok(Tensor {
        meta: input.meta.clone(),
        data,
    })
}
/// Integer softmax over the last axis, using base 2 instead of `e`.
///
/// Each element becomes `2^-(max - v)` in `NN_APPROX_SCALE` fixed point,
/// computed by repeated truncating halving; gaps of 32 or more contribute
/// `0`. Each row is then normalised by its sum, so outputs lie in
/// `0..=NN_APPROX_SCALE`. Empty input is returned unchanged.
///
/// # Errors
/// - shape error for rank-0 input or a non-static last dimension;
/// - backend error if the data length is not a multiple of a non-zero last
///   dimension.
pub fn nn_softmax_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    if input.data.is_empty() {
        return Ok(input.clone());
    }
    let rank = input.meta.shape.rank();
    if rank == 0 {
        return Err(Error::shape("softmax: rank-0 input"));
    }
    let last_dim: usize = input
        .meta
        .shape
        .dims
        .last()
        .and_then(|d| d.value())
        .ok_or_else(|| Error::shape("softmax requires static last dim"))?;
    // A zero last dim with non-empty data would otherwise divide by zero.
    if last_dim == 0 || input.data.len() % last_dim != 0 {
        return Err(Error::backend(
            "softmax: data length is not a multiple of last dim",
        ));
    }
    // pow2_neg[k] = NN_APPROX_SCALE / 2^k, halved step by step with truncation.
    let pow2_neg: Vec<i64> = std::iter::successors(Some(NN_APPROX_SCALE), |&x| Some(x / 2))
        .take(32)
        .collect();
    let mut out = Vec::with_capacity(input.data.len());
    for row in input.data.chunks_exact(last_dim) {
        let max_v = *row
            .iter()
            .max()
            .expect("rows are non-empty because last_dim > 0");
        let exps: Vec<i64> = row
            .iter()
            .map(|&v| {
                // `max_v >= v`, so the only failure mode is overflow when the row
                // spans more than i64::MAX; such gaps contribute 0 anyway.
                match max_v.checked_sub(v) {
                    Some(gap) if gap < 32 => pow2_neg[gap as usize],
                    _ => 0,
                }
            })
            .collect();
        let sum = exps.iter().fold(0i64, |acc, &e| acc.saturating_add(e));
        out.extend(exps.iter().map(|&e| {
            if sum == 0 {
                NN_APPROX_SCALE / last_dim as i64
            } else {
                e * NN_APPROX_SCALE / sum
            }
        }));
    }
    Ok(Tensor {
        meta: input.meta.clone(),
        data: out,
    })
}
pub fn nn_layer_norm_i64(
input: &Tensor<i64>,
gamma: &Tensor<i64>,
beta: &Tensor<i64>,
) -> Result<Tensor<i64>> {
if input.data.is_empty() {
return Ok(input.clone());
}
let rank = input.meta.shape.rank();
if rank == 0 {
return Err(Error::shape("layer_norm: rank-0 input"));
}
let last_dim: usize = input
.meta
.shape
.dims
.last()
.and_then(|d| d.value())
.ok_or_else(|| Error::shape("layer_norm requires static last dim"))?;
if gamma.data.len() != last_dim {
return Err(Error::backend(format!(
"layer_norm: gamma length {} != last dim {last_dim}",
gamma.data.len()
)));
}
if beta.data.len() != last_dim {
return Err(Error::backend(format!(
"layer_norm: beta length {} != last dim {last_dim}",
beta.data.len()
)));
}
let outer: usize = input.data.len() / last_dim;
let mut out = Vec::with_capacity(input.data.len());
for row in 0..outer {
let start = row * last_dim;
let row_slice = &input.data[start..start + last_dim];
let sum: i64 = row_slice.iter().sum();
let mean = sum / last_dim as i64;
let var_sum: i64 = row_slice
.iter()
.map(|&v| {
let d = v - mean;
d * d
})
.sum();
let var = var_sum / last_dim as i64;
let stddev = integer_sqrt(var);
for (i, &v) in row_slice.iter().enumerate() {
let centered = v - mean;
let g = gamma.data[i];
let b = beta.data[i];
let normalized = if stddev == 0 {
b
} else {
centered * g / stddev + b
};
out.push(normalized);
}
}
Ok(Tensor {
meta: input.meta.clone(),
data: out,
})
}
/// Floor of the square root of `v`; non-positive inputs yield `0`.
///
/// Binary search over `0..=2^32`. Probes are squared with `checked_mul`, so
/// candidates above `sqrt(i64::MAX)` (about 3.04e9) count as "too large"
/// instead of overflowing. Previously any `v > 2^62` overflowed.
fn integer_sqrt(v: i64) -> i64 {
    if v <= 0 {
        return 0;
    }
    let mut root: i64 = 0;
    let mut lo: i64 = 0;
    let mut hi: i64 = 1 << 32;
    while lo <= hi {
        let mid = lo + (hi - lo) / 2;
        match mid.checked_mul(mid) {
            Some(square) if square <= v => {
                root = mid;
                lo = mid + 1;
            }
            _ => hi = mid - 1,
        }
    }
    root
}
fn resolve_axis(axis: i64, rank: i64) -> Result<usize> {
let a = if axis < 0 { rank + axis } else { axis };
if a < 0 || a >= rank {
return Err(Error::backend(format!(
"index op axis {axis} out of range for rank {rank}"
)));
}
Ok(a as usize)
}
/// Gathers values from `input` along `axis`, like `torch.gather`.
///
/// `indices` must have the same rank as `input` and may not exceed it on any
/// axis other than `axis`. The output element at coordinate `c` is `input` at
/// `c` with `c[axis]` replaced by `indices[c]` (negative values count from
/// the end of the axis), so the result has the shape of `indices` and
/// `input`'s other metadata.
///
/// # Errors
/// Backend error when `axis` is out of range, a dimension is not static,
/// `indices` has a different rank from `input`, an `indices` extent exceeds
/// the `input` extent on a non-gather axis, or an index value is out of
/// range.
pub fn index_gather_i64(
    input: &Tensor<i64>,
    indices: &Tensor<i64>,
    axis: i64,
) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank() as i64;
    let ax = resolve_axis(axis, rank)?;
    let in_dims: Vec<usize> = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::backend("gather requires static input dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    let out_dims: Vec<usize> = indices
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::backend("gather requires static indices dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    // Without these checks a lower-rank `indices` never reaches `ax`, and an
    // oversized extent reads the wrong element or past the end of `input`.
    if out_dims.len() != in_dims.len() {
        return Err(Error::backend(format!(
            "gather: indices rank {} != input rank {}",
            out_dims.len(),
            in_dims.len()
        )));
    }
    for (d, (&idx_n, &in_n)) in out_dims.iter().zip(&in_dims).enumerate() {
        if d != ax && idx_n > in_n {
            return Err(Error::backend(format!(
                "gather: indices dim {d} has size {idx_n} > input size {in_n}"
            )));
        }
    }
    let in_strides = row_major_strides(&in_dims);
    let out_strides = row_major_strides(&out_dims);
    let axis_size = in_dims[ax];
    let total: usize = out_dims.iter().product();
    // When `total == 0` the loop body never runs, so zero strides are never
    // used as divisors; the result still carries the (empty) indices shape.
    let mut out = Vec::with_capacity(total);
    for out_idx in 0..total {
        let mut rem = out_idx;
        let mut in_idx = 0usize;
        for d in 0..out_dims.len() {
            let coord = rem / out_strides[d];
            rem %= out_strides[d];
            let src = if d == ax {
                let raw = indices.data[out_idx];
                let resolved = if raw < 0 { axis_size as i64 + raw } else { raw };
                if resolved < 0 || resolved as usize >= axis_size {
                    return Err(Error::backend(format!(
                        "gather: index {raw} out of range for axis size {axis_size}"
                    )));
                }
                resolved as usize
            } else {
                coord
            };
            in_idx += src * in_strides[d];
        }
        out.push(input.data[in_idx]);
    }
    let out_shape = crate::object::Shape::new(
        out_dims
            .iter()
            .map(|&n| Dim::Static(n))
            .collect::<Vec<Dim>>(),
    );
    Ok(Tensor {
        meta: ObjectMeta {
            shape: out_shape,
            ..input.meta.clone()
        },
        data: out,
    })
}
pub fn index_scatter_i64(
input: &Tensor<i64>,
indices: &Tensor<i64>,
axis: i64,
) -> Result<Tensor<i64>> {
let rank = input.meta.shape.rank() as i64;
let ax = resolve_axis(axis, rank)?;
let in_dims: Vec<usize> = input
.meta
.shape
.dims
.iter()
.map(|d| {
d.value()
.ok_or_else(|| Error::backend("scatter requires static input dims"))
})
.collect::<Result<Vec<_>>>()?;
let in_strides = row_major_strides(&in_dims);
let axis_size = in_dims[ax];
if indices.data.is_empty() {
return Ok(input.clone());
}
let total: usize = in_dims.iter().product();
if indices.data.len() != total {
return Err(Error::backend(format!(
"scatter: indices length {} != input data length {total}",
indices.data.len()
)));
}
let mut out = vec![0i64; total];
for i in 0..total {
let raw = indices.data[i];
let resolved = if raw < 0 { axis_size as i64 + raw } else { raw };
if resolved < 0 || resolved as usize >= axis_size {
return Err(Error::backend(format!(
"scatter: index {raw} out of range for axis size {axis_size}"
)));
}
let mut tmp = i;
let mut dest = 0usize;
for d in 0..in_dims.len() {
let coord = tmp / in_strides[d];
tmp %= in_strides[d];
let c = if d == ax { resolved as usize } else { coord };
dest += c * in_strides[d];
}
out[dest] = input.data[i];
}
Ok(Tensor {
meta: input.meta.clone(),
data: out,
})
}
/// Picks slices of `input` along `axis` in the order given by `indices`.
///
/// `indices` is read as a flat list; negative values count from the end. The
/// output has `input`'s shape with the `axis` extent replaced by
/// `indices.data.len()`, plus `input`'s other metadata.
///
/// # Errors
/// Backend error when `axis` is out of range, an input dimension is not
/// static, or (for a non-empty result) an index is out of range.
pub fn index_select_i64(
    input: &Tensor<i64>,
    indices: &Tensor<i64>,
    axis: i64,
) -> Result<Tensor<i64>> {
    let ax = resolve_axis(axis, input.meta.shape.rank() as i64)?;
    let in_dims: Vec<usize> = input
        .meta
        .shape
        .dims
        .iter()
        .map(|dim| {
            dim.value()
                .ok_or_else(|| Error::backend("index_select requires static input dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    let axis_size = in_dims[ax];
    let mut out_dims = in_dims.clone();
    out_dims[ax] = indices.data.len();
    let total: usize = out_dims.iter().product();
    let meta = ObjectMeta {
        shape: crate::object::Shape::new(
            out_dims
                .iter()
                .map(|&n| Dim::Static(n))
                .collect::<Vec<Dim>>(),
        ),
        ..input.meta.clone()
    };
    if total == 0 {
        return Ok(Tensor { meta, data: vec![] });
    }
    // Validate in `indices` order: the first bad entry is reported, matching
    // a row-major walk over the (non-empty) output.
    let mut picks = Vec::with_capacity(indices.data.len());
    for &raw in &indices.data {
        let resolved = if raw < 0 { axis_size as i64 + raw } else { raw };
        if resolved < 0 || resolved as usize >= axis_size {
            return Err(Error::backend(format!(
                "index_select: index {raw} out of range for axis size {axis_size}"
            )));
        }
        picks.push(resolved as usize);
    }
    // Each picked slice is a contiguous run of `inner` elements per index of
    // the leading axes.
    let outer: usize = in_dims[..ax].iter().product();
    let inner: usize = in_dims[ax + 1..].iter().product();
    let mut data = Vec::with_capacity(total);
    for o in 0..outer {
        for &p in &picks {
            let from = (o * axis_size + p) * inner;
            data.extend_from_slice(&input.data[from..from + inner]);
        }
    }
    Ok(Tensor { meta, data })
}
/// Scatter-adds `source` into a copy of `input` along `axis`.
///
/// `indices` and `source` are read element-wise in lockstep. Source element
/// `i`, at its flat row-major position in `source`'s shape, is added with
/// wrapping `i64` arithmetic to the `input` slot at the same coordinate,
/// except that the `axis` coordinate is `indices.data[i]` (negative values
/// count from the end). Duplicate targets accumulate. Empty `indices` and
/// `source` return `input` unchanged.
///
/// # Errors
/// Backend error when any of the following holds:
/// - `axis` is out of range, or a dimension is not static;
/// - the lengths of `indices` and `source` differ;
/// - `source` has a different rank from `input`, or is larger than `input`
///   on a non-`axis` dimension;
/// - `source` holds more elements than its shape describes;
/// - an index is out of range.
pub fn index_add_i64(
    input: &Tensor<i64>,
    indices: &Tensor<i64>,
    source: &Tensor<i64>,
    axis: i64,
) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank() as i64;
    let ax = resolve_axis(axis, rank)?;
    let in_dims: Vec<usize> = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::backend("index_add requires static input dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    let in_strides = row_major_strides(&in_dims);
    let axis_size = in_dims[ax];
    if indices.data.len() != source.data.len() {
        return Err(Error::backend(format!(
            "index_add: indices length {} != source length {}",
            indices.data.len(),
            source.data.len()
        )));
    }
    if indices.data.is_empty() {
        return Ok(input.clone());
    }
    // The source layout is loop-invariant: resolve and validate it once.
    let src_dims: Vec<usize> = source
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::backend("index_add requires static source dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    // Without these checks a lower-rank source indexes past its coordinate
    // list, and a larger extent writes into the wrong destination slot.
    if src_dims.len() != in_dims.len() {
        return Err(Error::backend(format!(
            "index_add: source rank {} != input rank {}",
            src_dims.len(),
            in_dims.len()
        )));
    }
    for (d, (&src_n, &in_n)) in src_dims.iter().zip(&in_dims).enumerate() {
        if d != ax && src_n > in_n {
            return Err(Error::backend(format!(
                "index_add: source dim {d} has size {src_n} > input size {in_n}"
            )));
        }
    }
    let src_strides = row_major_strides(&src_dims);
    let src_total: usize = src_dims.iter().product();
    if indices.data.len() > src_total {
        return Err(Error::backend("index_add: index out of source range"));
    }
    let mut out = input.data.clone();
    for (i, (&raw, &value)) in indices.data.iter().zip(&source.data).enumerate() {
        let resolved = if raw < 0 { axis_size as i64 + raw } else { raw };
        if resolved < 0 || resolved as usize >= axis_size {
            return Err(Error::backend(format!(
                "index_add: index {raw} out of range for axis size {axis_size}"
            )));
        }
        let mut rem = i;
        let mut dest = 0usize;
        for d in 0..src_dims.len() {
            let coord = rem / src_strides[d];
            rem %= src_strides[d];
            let c = if d == ax { resolved as usize } else { coord };
            dest += c * in_strides[d];
        }
        out[dest] = out[dest].wrapping_add(value);
    }
    Ok(Tensor {
        meta: input.meta.clone(),
        data: out,
    })
}
/// Lists the coordinates of every non-zero element of `input`.
///
/// Returns a `[count, rank]` tensor whose rows are row-major coordinates in
/// ascending flat-index order. For a rank-0 input the result is `[count, 0]`
/// (one empty coordinate per non-zero scalar), the same convention as
/// `numpy.argwhere` on 0-d arrays. The count therefore survives even when
/// there are no coordinate columns.
///
/// # Errors
/// Backend error if an input dimension is not static.
pub fn index_nonzero_i64(input: &Tensor<i64>) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank();
    let in_dims: Vec<usize> = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::backend("nonzero requires static input dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    let strides = row_major_strides(&in_dims);
    let total: usize = in_dims.iter().product();
    let mut count = 0usize;
    let mut out: Vec<i64> = Vec::new();
    for (flat, &value) in input.data[..total].iter().enumerate() {
        if value == 0 {
            continue;
        }
        count += 1;
        // `total > 0` whenever this runs, so every stride is non-zero.
        let mut rem = flat;
        for &stride in &strides {
            out.push((rem / stride) as i64);
            rem %= stride;
        }
    }
    let out_shape =
        crate::object::Shape::new(vec![Dim::Static(count), Dim::Static(rank)]);
    Ok(Tensor {
        meta: ObjectMeta {
            shape: out_shape,
            ..input.meta.clone()
        },
        data: out,
    })
}
/// Row-major (C-order) element strides for `dims`: the last axis has stride
/// 1, and each earlier axis's stride is the product of all later extents.
fn row_major_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    let mut running = 1usize;
    // Walk from the second-to-last slot backwards, folding in the extent of
    // the axis just to its right.
    for (slot, &extent) in strides.iter_mut().rev().skip(1).zip(dims.iter().rev()) {
        running *= extent;
        *slot = running;
    }
    strides
}
use crate::op::reductions::ReductionKind;
/// Reduces `input` along one `axis` and drops that axis from the output
/// shape. Negative axes count from the end.
///
/// Per [`ReductionKind`]:
/// - `Sum` / `Prod` wrap on `i64` overflow (two's complement).
/// - `Mean` is the wrapping sum divided by the axis length, rounded toward
///   negative infinity.
/// - `Max` / `Min` return the extreme value.
/// - `ArgMax` / `ArgMin` return the position of the *first* extreme value.
/// - `Any` / `All` return `1` if any / every element is non-zero, else `0`.
///
/// # Errors
/// Shape error for rank-0 input, non-static dims, or an empty reduction axis;
/// backend error for an out-of-range axis.
pub fn reduce_i64(input: &Tensor<i64>, axis: i64, kind: ReductionKind) -> Result<Tensor<i64>> {
    let rank = input.meta.shape.rank() as i64;
    if rank == 0 {
        return Err(Error::shape("reduce: rank-0 input"));
    }
    // An out-of-range negative axis wraps to a huge usize and is rejected below.
    let ax = if axis < 0 { rank + axis } else { axis } as usize;
    if ax >= rank as usize {
        return Err(Error::backend(format!(
            "reduce: axis {axis} out of range for rank {rank}"
        )));
    }
    let in_dims: Vec<usize> = input
        .meta
        .shape
        .dims
        .iter()
        .map(|d| {
            d.value()
                .ok_or_else(|| Error::shape("reduce requires static input dims"))
        })
        .collect::<Result<Vec<_>>>()?;
    let axis_size = in_dims[ax];
    if axis_size == 0 {
        return Err(Error::shape("reduce: empty axis"));
    }
    let outer: usize = in_dims.iter().take(ax).product();
    let inner: usize = in_dims.iter().skip(ax + 1).product();
    let mut out = Vec::with_capacity(outer * inner);
    let data = &input.data;
    for o in 0..outer {
        for i in 0..inner {
            // Elements of one reduction lane are `inner` apart in row-major order.
            let base = o * axis_size * inner + i;
            let mut lane = (0..axis_size).map(|k| data[base + k * inner]);
            let reduced = match kind {
                // Wrap like `Prod` instead of panicking on overflow in debug builds.
                ReductionKind::Sum => lane.fold(0i64, i64::wrapping_add),
                ReductionKind::Mean => {
                    // `div_euclid` by a positive divisor is floor division; unlike
                    // negating the sum it cannot overflow at `i64::MIN`.
                    lane.fold(0i64, i64::wrapping_add)
                        .div_euclid(axis_size as i64)
                }
                ReductionKind::Max => lane.max().unwrap_or(0),
                ReductionKind::Min => lane.min().unwrap_or(0),
                ReductionKind::ArgMax => {
                    let mut best_v = i64::MIN;
                    let mut best_k: i64 = 0;
                    for (k, v) in lane.enumerate() {
                        if v > best_v {
                            best_v = v;
                            best_k = k as i64;
                        }
                    }
                    best_k
                }
                ReductionKind::ArgMin => {
                    let mut best_v = i64::MAX;
                    let mut best_k: i64 = 0;
                    for (k, v) in lane.enumerate() {
                        if v < best_v {
                            best_v = v;
                            best_k = k as i64;
                        }
                    }
                    best_k
                }
                ReductionKind::Prod => lane.fold(1i64, i64::wrapping_mul),
                ReductionKind::Any => i64::from(lane.any(|v| v != 0)),
                ReductionKind::All => i64::from(lane.all(|v| v != 0)),
            };
            out.push(reduced);
        }
    }
    let out_dims: Vec<Dim> = in_dims
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != ax)
        .map(|(_, &d)| Dim::Static(d))
        .collect();
    let out_shape = crate::object::Shape::new(out_dims);
    Ok(Tensor {
        meta: ObjectMeta {
            shape: out_shape,
            ..input.meta.clone()
        },
        data: out,
    })
}