tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! The `SemanticGraph` data structure.
//!
//! `SemanticGraph` is the imperative form of an IR. It exposes
//! `add_input(meta)`, `add_op(op, inputs)`, the `nodes()` slice
//! accessor, and the `value(id)` lookup. The closure-based DSL
//! (`src/ir/dsl.rs`) compiles down to the same `SemanticGraph`.
//!
//! The `SemanticNode` carries the op name, inputs, output ids,
//! and the per-node contract set. The planner reads the graph
//! topologically to build an `ExecutionPlan`.
//!
use crate::domain::ContractSet;
use crate::object::ObjectMeta;
use crate::op::{LayerBehavior, Operator};
use crate::{Error, Result};

/// One operator application recorded in a [`SemanticGraph`].
///
/// The field descriptions below reflect how `SemanticGraph::add_op`
/// populates a node; all fields are public, so code that builds a node by
/// hand is responsible for keeping them consistent.
#[derive(Debug, Clone, PartialEq)]
pub struct SemanticNode {
    /// Node id; `add_op` assigns the number of nodes already in the graph,
    /// i.e. the node's position in insertion order.
    pub id: usize,
    /// Operator name, taken from `Operator::name()`.
    pub op_name: String,
    /// Value ids consumed by this node, in the order passed to `add_op`.
    pub inputs: Vec<usize>,
    /// Value ids assigned to this node's outputs in the graph's value table.
    pub output_ids: Vec<usize>,
    /// Output metadata returned by `Operator::infer`; `add_op` stores one
    /// entry per id in `output_ids`, in the same order.
    pub outputs: Vec<ObjectMeta>,
    /// Contract set reported by `Operator::required_contracts()`.
    pub required_contracts: ContractSet,
    /// Contract set reported by `Operator::provided_contracts()`.
    pub provided_contracts: ContractSet,
    /// Layer behaviors reported by `Operator::layer_behavior()`.
    pub layer_behavior: Vec<LayerBehavior>,
}

/// Imperatively built graph of operator nodes over a shared value table.
///
/// Values are identified by their index in the value table; nodes are
/// identified by their index in the node list.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SemanticGraph {
    /// Nodes in the order they were added by `add_op`.
    nodes: Vec<SemanticNode>,
    /// Metadata for every value (graph inputs and operator outputs),
    /// indexed by value id.
    values: Vec<ObjectMeta>,
}

impl SemanticGraph {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn add_input(&mut self, meta: ObjectMeta) -> usize {
        let id = self.values.len();
        self.values.push(meta);
        id
    }

    pub fn add_op<O: Operator>(&mut self, op: O, inputs: &[usize]) -> Result<Vec<usize>> {
        let input_meta = inputs
            .iter()
            .map(|id| {
                self.values
                    .get(*id)
                    .cloned()
                    .ok_or_else(|| Error::ir(format!("unknown value id {id}")))
            })
            .collect::<Result<Vec<_>>>()?;
        let outputs = op.infer(&input_meta)?;
        let output_ids = outputs
            .iter()
            .map(|output| {
                let id = self.values.len();
                self.values.push(output.clone());
                id
            })
            .collect::<Vec<_>>();
        let node = SemanticNode {
            id: self.nodes.len(),
            op_name: op.name().to_string(),
            inputs: inputs.to_vec(),
            output_ids: output_ids.clone(),
            outputs,
            required_contracts: op.required_contracts(),
            provided_contracts: op.provided_contracts(),
            layer_behavior: op.layer_behavior(),
        };
        self.nodes.push(node);
        Ok(output_ids)
    }

    pub fn value(&self, id: usize) -> Option<&ObjectMeta> {
        self.values.get(id)
    }

    pub fn nodes(&self) -> &[SemanticNode] {
        &self.nodes
    }

    pub fn users_of_value(&self, value_id: usize) -> Vec<usize> {
        self.nodes
            .iter()
            .filter(|node| node.inputs.contains(&value_id))
            .map(|node| node.id)
            .collect()
    }

    pub fn single_consumer_of_value(&self, value_id: usize) -> Option<usize> {
        let users = self.users_of_value(value_id);
        if users.len() == 1 {
            Some(users[0])
        } else {
            None
        }
    }

    pub fn output_ids_of_node(&self, node_id: usize) -> Option<&[usize]> {
        self.nodes
            .get(node_id)
            .map(|node| node.output_ids.as_slice())
    }
}