tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Index-manipulation operators.
//!
//! These ops let users reach into a tensor by integer indices.
//! They mirror the PyTorch family of `gather`, `scatter`,
//! `index_select`, `index_add`, and the boolean `nonzero` reduction.
//!
//! `Gather`, `Scatter`, `IndexSelect`, and `IndexAdd` all take an
//! `axis` argument as the final input (a 1-element i32 tensor).
//! `Nonzero` is axis-less and returns a 2-D i64 tensor whose first
//! dim is the number of nonzero elements and whose second dim is
//! the input rank.
//!
//! Conventions:
//! - `Gather(input, indices, axis)`: for each output position, the
//!   coordinate along `axis` is replaced by the value of `indices` at
//!   that same position, and the element at the resulting coordinate
//!   is read from `input`. Output shape equals `indices.shape`.
//! - `Scatter(input, indices, axis)`: for each position in the
//!   output, write `input[i, ..., j, ...]` to
//!   `output[indices[i, ..., j, ...], ...]`. Output shape equals
//!   `input.shape`. Positions that are not written remain 0.
//! - `IndexSelect(input, indices, axis)`: select slices along
//!   `axis`. Output shape is `input.shape` with the dim at `axis`
//!   replaced by the length of the 1-D `indices` tensor.
//! - `IndexAdd(input, indices, source, axis)`: start from `input`,
//!   then add `source[i, ...]` to `input[indices[i, ...], ...]`
//!   along `axis`. Output shape equals `input.shape`.
//! - `Nonzero(input)`: 2-D output `[N, rank]` where N is the count
//!   of nonzero elements. The N depends on the runtime data so the
//!   op returns a `DataDependent` shape.
//!
//! All five ops work on i64 tensors; the only arithmetic beyond
//! data movement is `IndexAdd`'s addition. They are therefore
//! declared `Deterministic + Exact` (no floating-point
//! approximation is involved).

use crate::domain::{Claim, Contract, ContractId, ContractSet, Evidence, Scope};
use crate::object::{Dim, ObjectKind, ObjectMeta, Shape};
use crate::{Error, Result};

use super::{LayerBehavior, OpInput, OpOutput, OpSignature, Operator};

// ---------------------------------------------------------------------------
// Shared helpers
// ---------------------------------------------------------------------------

/// Build the pair of axiomatic contracts shared by every index op:
/// `Deterministic` (id 50) and `Exact` (id 51), both attached to `scope`.
fn deterministic_exact_contracts(scope: Scope) -> ContractSet {
    let deterministic = Contract::new(
        ContractId(50),
        Claim::Deterministic,
        scope.clone(),
        Evidence::Axiom,
    );
    let exact = Contract::new(ContractId(51), Claim::Exact, scope, Evidence::Axiom);
    ContractSet::from_iter([deterministic, exact])
}

fn tensor_input_check(op: &str, inputs: &[ObjectMeta], expected: usize) -> Result<()> {
    if inputs.len() != expected {
        return Err(Error::operator(format!(
            "{op} expects {expected} input(s), got {}",
            inputs.len()
        )));
    }
    for (i, m) in inputs.iter().enumerate() {
        if m.object_kind != ObjectKind::Tensor {
            return Err(Error::operator(format!(
                "{op} only supports tensor inputs, input {i} is {:?}",
                m.object_kind
            )));
        }
    }
    Ok(())
}

// ---------------------------------------------------------------------------
// Gather
// ---------------------------------------------------------------------------

/// Gather values from `input` along `axis` using integer `indices`.
/// Output shape equals `indices.shape`; for each output position
/// `(i_0, i_1, ..., i_{N-1})` the source coord on `axis` is
/// `indices[i_0, ..., i_{N-1}]` and the other coords pass through.
#[derive(Debug, Clone, Copy, Default)]
pub struct GatherOp;

impl Operator for GatherOp {
    fn name(&self) -> &'static str {
        "gather"
    }

    fn signature(&self) -> OpSignature {
        OpSignature {
            inputs: vec![
                OpInput {
                    name: "input".to_string(),
                },
                OpInput {
                    name: "indices".to_string(),
                },
                OpInput {
                    name: "axis".to_string(),
                },
            ],
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;
        // Output shape equals indices.shape. We don't know the axis
        // value at plan time, so we can't refine the dim-count
        // assertion; the CPU lowering will validate the axis at
        // execution time.
        let mut output = inputs[1].clone();
        output.domain = inputs[0].domain.clone();
        Ok(vec![output])
    }

    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }
    fn provided_contracts(&self) -> ContractSet {
        deterministic_exact_contracts(Scope::Operator(self.name().to_string()))
    }
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise, LayerBehavior::Global]
    }
}

// ---------------------------------------------------------------------------
// Scatter
// ---------------------------------------------------------------------------

/// Scatter `input` along `axis` into a freshly zeroed buffer, with the
/// destination positions taken from `indices`. The result has the same
/// shape as `input`; any position that no index points at stays 0.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterOp;

impl Operator for ScatterOp {
    fn name(&self) -> &'static str {
        "scatter"
    }

    fn signature(&self) -> OpSignature {
        let inputs = ["input", "indices", "axis"]
            .into_iter()
            .map(|name| OpInput {
                name: name.to_string(),
            })
            .collect();
        OpSignature {
            inputs,
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;
        // The output reuses `input`'s metadata wholesale, shape included.
        Ok(inputs[..1].to_vec())
    }

    fn required_contracts(&self) -> ContractSet {
        // Nothing is demanded of upstream producers.
        ContractSet::new()
    }

    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(String::from(self.name()));
        deterministic_exact_contracts(scope)
    }

    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::Global])
    }
}

// ---------------------------------------------------------------------------
// IndexSelect
// ---------------------------------------------------------------------------

/// Select slices along `axis` from `input` using `indices` (a
/// 1-D i64 tensor of arbitrary length). At run time the output shape
/// equals `input.shape` with the dim at `axis` replaced by
/// `indices.len()`. Because `axis` is itself a runtime tensor,
/// `infer` cannot tell which dim changes, so it reports every output
/// dim as data-dependent and keeps only the rank (equal to the input rank).
#[derive(Debug, Clone, Copy, Default)]
pub struct IndexSelectOp;

impl Operator for IndexSelectOp {
    fn name(&self) -> &'static str {
        "index_select"
    }

    fn signature(&self) -> OpSignature {
        OpSignature {
            inputs: vec![
                OpInput {
                    name: "input".to_string(),
                },
                OpInput {
                    name: "indices".to_string(),
                },
                OpInput {
                    name: "axis".to_string(),
                },
            ],
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 3)?;
        // The runtime axis is unknown here, so we cannot tell which
        // input dim gets replaced by `indices.len()`. Rather than keep
        // possibly-wrong static sizes, EVERY output dim is marked
        // DataDependent and only the rank (== input rank) is kept.
        // The CPU lowering is expected to produce the concrete shape;
        // downstream static shape inference is limited, but execution
        // is well-defined.
        //
        // NOTE(review): all dims share the symbol name
        // "index_select_axis". If downstream shape checking treats
        // equal DataDependent names as equal sizes, this would imply
        // all output dims are equal — confirm, and consider per-dim names.
        let input_rank = inputs[0].shape.rank();
        let mut output = inputs[0].clone();
        output.shape = Shape::new(
            (0..input_rank)
                .map(|_| Dim::DataDependent("index_select_axis".to_string()))
                .collect::<Vec<Dim>>(),
        );
        Ok(vec![output])
    }

    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }
    fn provided_contracts(&self) -> ContractSet {
        deterministic_exact_contracts(Scope::Operator(self.name().to_string()))
    }
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise, LayerBehavior::Global]
    }
}

// ---------------------------------------------------------------------------
// IndexAdd
// ---------------------------------------------------------------------------

/// Accumulate `source` into `input` along `axis`, at the positions
/// listed in `indices`. `input` supplies the starting values (it is
/// conceptually updated in place) and `source` the addends. The result
/// has the same shape as `input`.
#[derive(Debug, Clone, Copy, Default)]
pub struct IndexAddOp;

impl Operator for IndexAddOp {
    fn name(&self) -> &'static str {
        "index_add"
    }

    fn signature(&self) -> OpSignature {
        let inputs = ["input", "indices", "source", "axis"]
            .into_iter()
            .map(|name| OpInput {
                name: name.to_string(),
            })
            .collect();
        OpSignature {
            inputs,
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 4)?;
        // The output is described exactly like the `input` operand.
        Ok(inputs[..1].to_vec())
    }

    fn required_contracts(&self) -> ContractSet {
        // Nothing is demanded of upstream producers.
        ContractSet::new()
    }

    fn provided_contracts(&self) -> ContractSet {
        let scope = Scope::Operator(String::from(self.name()));
        deterministic_exact_contracts(scope)
    }

    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        Vec::from([LayerBehavior::Pointwise, LayerBehavior::Global])
    }
}

// ---------------------------------------------------------------------------
// Nonzero
// ---------------------------------------------------------------------------

/// Return the indices of nonzero elements as a 2-D i64 tensor of
/// shape `[N, rank]`, where `N` is the number of nonzero elements.
/// Only `N` depends on the input data, so it is the only
/// `DataDependent` dim. The trailing dim is the input rank, which is
/// known at plan time and reported as a static dim. A rank-0 (scalar)
/// input reports a trailing dim of 1.
#[derive(Debug, Clone, Copy, Default)]
pub struct NonzeroOp;

impl Operator for NonzeroOp {
    fn name(&self) -> &'static str {
        "nonzero"
    }

    fn signature(&self) -> OpSignature {
        OpSignature {
            inputs: vec![OpInput {
                name: "input".to_string(),
            }],
            outputs: vec![OpOutput {
                name: "out".to_string(),
            }],
        }
    }

    fn infer(&self, inputs: &[ObjectMeta]) -> Result<Vec<ObjectMeta>> {
        tensor_input_check(self.name(), inputs, 1)?;
        let rank = inputs[0].shape.rank();
        // NOTE(review): apart from the shape, the output inherits the
        // input's metadata wholesale; confirm that matches the i64
        // index output the lowering produces.
        let mut output = inputs[0].clone();
        output.shape = Shape::new(vec![
            // Row count = number of nonzero elements: only known at run time.
            Dim::DataDependent("nonzero_count".to_string()),
            // Column count = input rank, which is known statically.
            // Scalars keep the pre-existing width of 1 (presumably
            // mirroring how the CPU lowering treats rank-0 inputs —
            // confirm).
            Dim::Static(rank.max(1) as _),
        ]);
        Ok(vec![output])
    }

    fn required_contracts(&self) -> ContractSet {
        ContractSet::new()
    }
    fn provided_contracts(&self) -> ContractSet {
        deterministic_exact_contracts(Scope::Operator(self.name().to_string()))
    }
    fn layer_behavior(&self) -> Vec<LayerBehavior> {
        vec![LayerBehavior::Pointwise, LayerBehavior::Global]
    }
}