tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Graph-builder DSL: `Tokitai::graph(|g| { ... })`.
//!
//! This is a thin closure-based sugar over the imperative
//! `SemanticGraph` API. The DSL compiles down to the same
//! `SemanticGraph` (and the same planner/executor pipeline); it
//! just reads more like a single expression.
//!
//! The DSL is intentionally minimal:
//! - `g.input(meta)` adds a typed input and returns a `TensorHandle`
//! - `g.input_tensor(t)` adds a typed input and pre-populates the
//!   store with the supplied tensor
//! - `g.op(op, &[h1, h2, ...])` adds an op and returns its single
//!   output as a `TensorHandle` (an op that yields any other number
//!   of outputs is rejected with an error)
//! - the closure returns a `Result<TensorHandle>` (the final output)
//!
//! `Tokitai::graph(...)` builds the graph, plans it, executes it
//! on the default `CpuScalarBackend`, and returns a `CompiledGraph`
//! that exposes the output tensor to the caller.

use crate::backend::cpu::CpuScalarBackend;
use crate::backend::{Backend, TensorStore};
use crate::ir::SemanticGraph;
use crate::object::{ObjectMeta, Tensor};
use crate::op::Operator;
use crate::planner::ExecutionPlan;
use crate::planner::HeuristicPlanner;
use crate::{Error, Result};

/// Lightweight, copyable reference to one value held by a
/// `GraphBuilder`. It carries nothing but the value id handed out
/// by the underlying `SemanticGraph`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TensorHandle {
    id: usize,
}

impl TensorHandle {
    /// Raw value id wrapped by this handle. Mostly of interest to
    /// the DSL itself; user code rarely needs it.
    pub fn id(&self) -> usize {
        let Self { id } = *self;
        id
    }
}

/// In-progress graph being built by a `Tokitai::graph(...)`
/// closure. Owns the underlying `SemanticGraph` and a
/// `TensorStore<i64>` that accumulates input data.
pub struct GraphBuilder {
    graph: SemanticGraph,
    store: TensorStore<i64>,
}

impl GraphBuilder {
    /// Add a new input with the supplied metadata. The caller is
    /// responsible for inserting the actual tensor data into the
    /// store before executing the graph.
    pub fn input(&mut self, meta: ObjectMeta) -> Result<TensorHandle> {
        let id = self.graph.add_input(meta);
        Ok(TensorHandle { id })
    }

    /// Add a new input and pre-populate the store with the
    /// supplied tensor. This is the ergonomic single-call
    /// equivalent of `add_input` + `store.insert`.
    pub fn input_tensor(&mut self, t: Tensor<i64>) -> Result<TensorHandle> {
        let id = self.graph.add_input(t.meta.clone());
        self.store.insert(id, t);
        Ok(TensorHandle { id })
    }

    /// Add an op that takes the given input handles and returns
    /// the first output as a handle. Returns an error if the op
    /// does not produce exactly one output or if any input handle
    /// is unknown.
    pub fn op<O: Operator>(&mut self, op: O, inputs: &[TensorHandle]) -> Result<TensorHandle> {
        let ids: Vec<usize> = inputs.iter().map(|h| h.id).collect();
        let outputs = self.graph.add_op(op, &ids)?;
        if outputs.len() != 1 {
            return Err(Error::ir(format!(
                "DSL `op` requires exactly one output, got {}",
                outputs.len()
            )));
        }
        Ok(TensorHandle { id: outputs[0] })
    }

    /// Borrow the underlying `SemanticGraph`. Useful for callers
    /// that want to compose the DSL-built graph with the rest of
    /// the public API.
    pub fn graph(&self) -> &SemanticGraph {
        &self.graph
    }

    /// Borrow the underlying `TensorStore<i64>`. Callers can use
    /// this to inspect or populate the store before
    /// `Tokitai::graph` executes.
    pub fn store(&self) -> &TensorStore<i64> {
        &self.store
    }

    /// Mutable borrow of the underlying `TensorStore<i64>`. This
    /// is the escape hatch for populating inputs that were
    /// declared with `input` (rather than `input_tensor`).
    pub fn store_mut(&mut self) -> &mut TensorStore<i64> {
        &mut self.store
    }

    // ---- P436 DSL sugar: per-op builder methods on GraphBuilder ----
    //
    // These wrap the generic `g.op(...)` call with a typed builder
    // for each high-traffic op family. The closure-based DSL is the
    // canonical user entry point; these methods make a small graph
    // read as a sequence of method calls instead of a flat
    // `g.op(AddOp, &[a, b])` chain.
    //
    // Each method returns a single `TensorHandle` (the op's first
    // output) and propagates the underlying `g.op` error.

    /// Elementwise `lhs + rhs` (arithmetic, P335).
    pub fn add(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::AddOp, &[lhs, rhs])
    }

    /// Elementwise `lhs - rhs` (arithmetic, P335).
    pub fn sub(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::SubOp, &[lhs, rhs])
    }

    /// Elementwise `lhs * rhs` (arithmetic, P335).
    pub fn mul(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::MulOp, &[lhs, rhs])
    }

    /// Elementwise `lhs / rhs` (arithmetic, P335).
    pub fn div(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::DivOp, &[lhs, rhs])
    }

    /// Reshape to a target shape (shape, P336).
    pub fn reshape(&mut self, input: TensorHandle, target: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::ReshapeOp, &[input, target])
    }

    /// Transpose two axes (shape, P336).
    pub fn transpose(&mut self, input: TensorHandle, axes: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::TransposeOp, &[input, axes])
    }

    /// Flatten to 1-D (shape, P336).
    pub fn flatten(&mut self, input: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::FlattenOp, &[input])
    }

    /// Elementwise `max(0, x)` (NN, P337).
    pub fn relu(&mut self, input: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::ReluOp, &[input])
    }

    /// Elementwise GELU approximation (NN, P337).
    pub fn gelu(&mut self, input: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::GeluOp, &[input])
    }

    /// Softmax over the last axis (NN, P337).
    pub fn softmax(&mut self, input: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::SoftmaxOp, &[input])
    }

    /// Per-row layer normalization (NN, P337).
    pub fn layer_norm(
        &mut self,
        input: TensorHandle,
        gamma: TensorHandle,
        beta: TensorHandle,
    ) -> Result<TensorHandle> {
        self.op(crate::op::LayerNormOp, &[input, gamma, beta])
    }

    /// Gather along `axis` (index, P338).
    pub fn gather(
        &mut self,
        input: TensorHandle,
        indices: TensorHandle,
        axis: TensorHandle,
    ) -> Result<TensorHandle> {
        self.op(crate::op::GatherOp, &[input, indices, axis])
    }

    /// Scatter into a fresh zero buffer along `axis` (index, P338).
    pub fn scatter(
        &mut self,
        input: TensorHandle,
        indices: TensorHandle,
        axis: TensorHandle,
    ) -> Result<TensorHandle> {
        self.op(crate::op::ScatterOp, &[input, indices, axis])
    }

    /// Sum-reduce (reductions, P339).
    pub fn sum(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::SumOp, &[input, axis])
    }

    /// Mean-reduce (reductions, P339).
    pub fn mean(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::MeanOp, &[input, axis])
    }

    /// Max-reduce (reductions, P339).
    pub fn max(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::MaxOp, &[input, axis])
    }

    /// Min-reduce (reductions, P339).
    pub fn min(&mut self, input: TensorHandle, axis: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::MinOp, &[input, axis])
    }

    /// Rank-2 matrix multiplication (arithmetic, P335). Backed by the
    /// `tokitai.matmul()` facade builder added in P434.
    pub fn matmul(&mut self, lhs: TensorHandle, rhs: TensorHandle) -> Result<TensorHandle> {
        self.op(crate::op::MatmulOp, &[lhs, rhs])
    }
}

/// What a successful `Tokitai::graph(...)` run hands back: the
/// finished graph, the store after execution, the plan that was
/// executed, and the value id the closure named as its final output.
pub struct CompiledGraph {
    graph: SemanticGraph,
    store: TensorStore<i64>,
    plan: ExecutionPlan,
    output: usize,
}

impl CompiledGraph {
    /// The graph that was built by the closure.
    pub fn graph(&self) -> &SemanticGraph {
        let Self { graph, .. } = self;
        graph
    }

    /// The execution plan derived from the graph.
    pub fn plan(&self) -> &ExecutionPlan {
        let Self { plan, .. } = self;
        plan
    }

    /// The tensor store as left behind by execution.
    pub fn store(&self) -> &TensorStore<i64> {
        let Self { store, .. } = self;
        store
    }

    /// Look up the final output tensor in the store by its id. The
    /// lookup result from the store is returned unchanged; after a
    /// successful DSL execution the tensor is expected to be there.
    pub fn output_tensor(&self) -> Result<&Tensor<i64>> {
        let Self { store, output, .. } = self;
        store.get(*output)
    }

    /// Value id of the final output inside the graph.
    pub fn output_id(&self) -> usize {
        let Self { output, .. } = self;
        *output
    }
}

/// Run the closure-based DSL on the given backend. The closure
/// builds the graph (and may pre-populate the store via
/// `GraphBuilder::input_tensor` or `GraphBuilder::store_mut`).
/// On success, the result owns the built graph, the populated
/// store, the plan, and the final output handle.
pub fn run_graph<F>(backend: &CpuScalarBackend, f: F) -> Result<CompiledGraph>
where
    F: FnOnce(&mut GraphBuilder) -> Result<TensorHandle>,
{
    let mut builder = GraphBuilder {
        graph: SemanticGraph::new(),
        store: TensorStore::new(),
    };
    let out = f(&mut builder)?;
    let plan = HeuristicPlanner::new(backend.capabilities()).plan(&builder.graph)?;
    backend.execute_i64(&builder.graph, &plan, &mut builder.store)?;
    Ok(CompiledGraph {
        graph: builder.graph,
        store: builder.store,
        plan,
        output: out.id,
    })
}