tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Reduction operators.
//!
//! These are the named reductions called out by P339. Each is a
//! thin op struct that names a `ReductionKind`; the CPU scalar
//! backend uses a single generic helper to perform the reduction
//! over the named axis. By default the axis is removed, so an n-D
//! input yields an `(n-1)`-D tensor. Because the axis is only known
//! at run time, the concrete output shape is computed by the
//! lowering rather than by the planner.
//!
//! Reductions covered:
//! - `Sum`, `Mean`     — additive / affine reductions
//! - `Max`, `Min`      — total-order reductions
//! - `ArgMax`, `ArgMin`— index reductions
//! - `Prod`            — multiplicative reduction
//! - `Any`, `All`      — boolean reductions (nonzero / all-nonzero)
//!
//! `Mean` uses i64 floor division on the sum. `Prod` uses wrapping
//! i64 multiplication. `Any` and `All` treat any nonzero element as
//! true.
//!
//! All reductions are `Deterministic`; `Sum`/`Mean`/`Max`/`Min`/
//! `Prod`/`Any`/`All` preserve the integer element type and are
//! therefore `Exact`. `ArgMax`/`ArgMin` emit i64 indices, also
//! `Exact`.

use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{Dim, ObjectKind, ObjectMeta, Shape};
use crate::{Error, Result};

use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};

// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------

/// The named reduction kinds. All CPU lowerings dispatch through
/// this enum so we can share a single generic helper.
///
/// Each variant maps to a stable lowercase operator name via
/// [`ReductionKind::op_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionKind {
    /// Additive reduction.
    Sum,
    /// Mean reduction; per the module docs, computed with i64 floor
    /// division on the sum.
    Mean,
    /// Largest element along the axis.
    Max,
    /// Smallest element along the axis.
    Min,
    /// Index of the largest element (emitted as i64 per the module docs).
    ArgMax,
    /// Index of the smallest element (emitted as i64 per the module docs).
    ArgMin,
    /// Multiplicative reduction; per the module docs, wrapping i64
    /// multiplication.
    Prod,
    /// True if any element is nonzero.
    Any,
    /// True if every element is nonzero.
    All,
}

impl ReductionKind {
    /// Returns the lowercase operator name for this reduction, e.g.
    /// `"argmax"` for [`ReductionKind::ArgMax`].
    ///
    /// The reduction op structs in this module return this string from
    /// their `Operator::name` implementation.
    #[must_use]
    pub fn op_name(&self) -> &'static str {
        match *self {
            Self::Sum => "sum",
            Self::Mean => "mean",
            Self::Max => "max",
            Self::Min => "min",
            Self::ArgMax => "argmax",
            Self::ArgMin => "argmin",
            Self::Prod => "prod",
            Self::Any => "any",
            Self::All => "all",
        }
    }
}

/// Builds the contract set every reduction op provides.
///
/// Contract 60 claims `Deterministic` and contract 61 claims `Exact`.
/// Both are scoped to `scope` and backed by `Evidence::Axiom`.
fn deterministic_exact_contracts(scope: Scope) -> ContractSet {
    let deterministic = Contract::new(
        ContractId(60),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    let exact = Contract::new(ContractId(61), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}

fn binary_signature_with_axis() -> OpSignature {
    OpSignature {
        inputs: vec![
            OpInput {
                name: "input".to_string(),
            },
            OpInput {
                name: "axis".to_string(),
            },
        ],
        outputs: vec![OpOutput {
            name: "out".to_string(),
        }],
    }
}

fn check_two_tensor_inputs(op: &str, inputs: &[ObjectMeta]) -> Result<()> {
    if inputs.len() != 2 {
        return Err(Error::operator(format!(
            "{op} expects 2 inputs (input, axis), got {}",
            inputs.len()
        )));
    }
    for (i, m) in inputs.iter().enumerate() {
        if m.object_kind != ObjectKind::Tensor {
            return Err(Error::operator(format!(
                "{op} only supports tensor inputs, input {i} is {:?}",
                m.object_kind
            )));
        }
    }
    Ok(())
}

// ---------------------------------------------------------------------------
// Macro: declare a reduction op that wraps a ReductionKind.
// ---------------------------------------------------------------------------

/// Declares a unit op struct `$name` that implements `Operator` for
/// `ReductionKind::$kind`.
///
/// The generated op takes `(input, axis)` tensors, requires no
/// contracts, and provides the shared Deterministic + Exact contracts
/// scoped to its own name.
macro_rules! reduction_op {
    ($name:ident, $kind:ident) => {
        #[doc = concat!("Named reduction: `", stringify!($kind), "` over a single axis.")]
        #[derive(Debug, Clone, Copy, Default)]
        pub struct $name;

        impl Operator for $name {
            fn name(&self) -> &'static str {
                ReductionKind::$kind.op_name()
            }

            fn signature(&self) -> OpSignature {
                binary_signature_with_axis()
            }

            /// Plans the output metadata for a single-axis reduction.
            ///
            /// The output is a clone of the `input` metadata whose shape
            /// is replaced by an `(n-1)`-D shape of `DataDependent`
            /// extents. Here `n` is the input rank.
            ///
            /// # Errors
            /// - The inputs are not exactly two tensors.
            /// - The input has rank 0 (there is no axis to reduce).
            fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
                check_two_tensor_inputs(self.name(), inputs)?;
                let input = &inputs[0];
                let rank = input.shape.rank();
                if rank == 0 {
                    return Err(Error::operator(format!(
                        "{}: rank-0 input cannot be reduced over an axis",
                        self.name()
                    )));
                }
                // The axis is a runtime value, so the planner cannot tell
                // which input dimension disappears. It therefore reports an
                // (n-1)-D shape whose extents are all DataDependent, and the
                // CPU lowering computes the concrete extents. Each position
                // gets its own label so distinct output dims are not
                // conflated.
                // Previously the dims were built and then cleared, which
                // planned every reduction as rank-0.
                let dims: Vec<Dim> = (0..rank - 1)
                    .map(|i| Dim::DataDependent(format!("{}_reduced_{i}", self.name())))
                    .collect();
                // NOTE(review): ArgMax/ArgMin are documented to emit i64
                // indices, but this metadata is cloned from `input`. If
                // ObjectMeta carries an element type, it likely needs
                // overriding for those kinds. Confirm.
                let mut output = input.clone();
                output.shape = Shape::new(dims);
                Ok(vec![output])
            }

            /// Reductions impose no contract preconditions on their inputs.
            fn required_contracts(&self) -> ContractSet {
                ContractSet::new()
            }

            /// Deterministic + Exact, scoped to this operator's name.
            fn provided_contracts(&self) -> ContractSet {
                deterministic_exact_contracts(Scope::Operator(self.name().to_string()))
            }

            // NOTE(review): `Pointwise` is unusual for a reduction. Confirm
            // it matches the intended `LayerBehavior` semantics.
            fn layer_behavior(&self) -> Vec<LayerBehavior> {
                vec![LayerBehavior::Pointwise, LayerBehavior::Global]
            }
        }
    };
}

// Arithmetic reductions.
reduction_op!(SumOp, Sum);
reduction_op!(MeanOp, Mean);
reduction_op!(ProdOp, Prod);

// Total-order reductions and their index-returning counterparts.
reduction_op!(MaxOp, Max);
reduction_op!(MinOp, Min);
reduction_op!(ArgMaxOp, ArgMax);
reduction_op!(ArgMinOp, ArgMin);

// Boolean reductions (any nonzero element counts as true).
reduction_op!(AnyOp, Any);
reduction_op!(AllOp, All);