tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! `Shape` and `Dim`.
//!
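//! `Shape` is the static-or-dynamic shape of a tensor: an ordered
//! list of `Dim`s, one per axis (an empty list is a scalar). `Dim`
//! is the per-axis dimension: `Static(n)` (a concrete size known at
//! compile time), `Symbolic(name)` (a named symbolic dim, e.g. the
//! batch size), `Bounded { name, min, max }` (a named dim known to lie
//! in the inclusive range `min..=max`), or `DataDependent(name)` (a
//! named dim that depends on the tensor data). The shape ops
//! (`Reshape`, `Transpose`, etc.) read the dims and update the
//! shape in the output meta.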
//!
use crate::{Error, Result};

/// One axis of a tensor shape.
///
/// Only `Static` carries a concrete size. The other variants are
/// identified by a name, and whether two dims are the same size is
/// decided by `Dim::proves_equal` (derived `PartialEq` is plain
/// structural equality).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Dim {
    /// A concrete size.
    Static(usize),
    /// A named symbolic size with no bounds attached.
    Symbolic(String),
    /// A named size constrained to the inclusive range `min..=max`.
    Bounded {
        /// Symbol identifying this dimension.
        name: String,
        /// Inclusive lower bound.
        min: usize,
        /// Inclusive upper bound.
        max: usize,
    },
    /// A named size that presumably depends on tensor data
    /// (NOTE(review): confirm). `proves_equal` never proves it equal
    /// to anything, not even a `DataDependent` with the same name.
    DataDependent(String),
}

impl Dim {
    /// Returns `Some(n)` for `Dim::Static(n)`.
    ///
    /// Every other variant (`Symbolic`, `Bounded`, `DataDependent`)
    /// returns `None`. This includes a `Bounded` dim whose range holds
    /// a single value.
    pub fn value(&self) -> Option<usize> {
        match self {
            Self::Static(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns `n` for `Dim::Static(n)` and `default` for every other
    /// variant. Useful in code paths that only need a fallback size.
    pub fn value_or(&self, default: usize) -> usize {
        self.value().unwrap_or(default)
    }

    /// Returns `true` only when `self` and `other` are *provably* the
    /// same size. `false` means "not proven", not "proven different".
    ///
    /// These pairs are proven equal:
    /// - two `Static` dims with the same value;
    /// - two `Symbolic` dims with the same name;
    /// - two `Bounded` dims with the same name and identical bounds;
    /// - a `Static(v)` and a `Bounded { min, max, .. }` (either order)
    ///   whose inclusive range collapses to exactly `v`
    ///   (`min == v == max`).
    ///
    /// A wider bounded range only shows the sizes *may* coincide, so
    /// it is not accepted as a proof. `DataDependent` dims, and every
    /// other pairing of different variants, are never proven equal.
    pub fn proves_equal(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Static(lhs), Self::Static(rhs)) => lhs == rhs,
            (Self::Symbolic(lhs), Self::Symbolic(rhs)) => lhs == rhs,
            (
                Self::Bounded {
                    name: lhs_name,
                    min: lhs_min,
                    max: lhs_max,
                },
                Self::Bounded {
                    name: rhs_name,
                    min: rhs_min,
                    max: rhs_max,
                },
            ) => lhs_name == rhs_name && lhs_min == rhs_min && lhs_max == rhs_max,
            (Self::Static(value), Self::Bounded { min, max, .. })
            | (Self::Bounded { min, max, .. }, Self::Static(value)) => {
                // Membership in [min, max] is not enough. Only a singleton
                // range pins the bounded dim to this exact static value.
                *min == *value && *value == *max
            }
            _ => false,
        }
    }
}

/// The shape of a tensor: an ordered list of per-axis `Dim`s.
///
/// The rank is `dims.len()`. An empty `dims` is the scalar (rank-0)
/// shape. Derived `PartialEq` compares the dims structurally.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape {
    /// Per-axis dimensions, indexed by axis number.
    pub dims: Vec<Dim>,
}

impl Shape {
    pub fn new(dims: impl Into<Vec<Dim>>) -> Self {
        Self { dims: dims.into() }
    }

    pub fn scalar() -> Self {
        Self { dims: Vec::new() }
    }

    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Borrow the first dimension, or `None` when the shape is a scalar
    /// (rank-0).
    pub fn first_dim(&self) -> Option<&Dim> {
        self.dims.first()
    }

    /// Borrow the last dimension, or `None` when the shape is a scalar.
    pub fn last_dim(&self) -> Option<&Dim> {
        self.dims.last()
    }

    /// Borrow the dimension at `index`, or `None` when out of range.
    pub fn dim(&self, index: usize) -> Option<&Dim> {
        self.dims.get(index)
    }

    /// Return `Ok(())` when the shape has exactly `rank` dimensions,
    /// or `Error::Shape` otherwise. Useful as a precondition guard
    /// for ops that require a specific rank (e.g. matmul requires rank-2).
    pub fn ensure_rank(&self, rank: usize) -> Result<()> {
        if self.rank() == rank {
            Ok(())
        } else {
            Err(Error::shape(format!(
                "shape {self:?} has rank {}, expected {rank}",
                self.rank()
            )))
        }
    }

    pub fn ensure_same(&self, other: &Self) -> Result<()> {
        if self == other {
            Ok(())
        } else {
            Err(Error::shape(format!(
                "shape mismatch: left={self:?}, right={other:?}"
            )))
        }
    }

    pub fn ensure_dim_proves_equal(
        &self,
        lhs_index: usize,
        other: &Self,
        rhs_index: usize,
    ) -> Result<()> {
        let lhs = self.dims.get(lhs_index).ok_or_else(|| {
            Error::shape(format!(
                "missing left dimension {lhs_index} for shape {self:?}"
            ))
        })?;
        let rhs = other.dims.get(rhs_index).ok_or_else(|| {
            Error::shape(format!(
                "missing right dimension {rhs_index} for shape {other:?}"
            ))
        })?;
        if lhs.proves_equal(rhs) {
            Ok(())
        } else {
            Err(Error::shape(format!(
                "dimension equality is not proven: left={lhs:?}, right={rhs:?}"
            )))
        }
    }
}

impl From<Vec<usize>> for Shape {
    fn from(value: Vec<usize>) -> Self {
        Self {
            dims: value.into_iter().map(Dim::Static).collect(),
        }
    }
}

impl std::fmt::Display for Dim {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Static(value) => write!(f, "{value}"),
            Self::Symbolic(name) => write!(f, "{name}"),
            Self::Bounded { name, min, max } => write!(f, "{name}∈[{min}..{max}]"),
            Self::DataDependent(name) => write!(f, "{name}?"),
        }
    }
}

impl std::fmt::Display for Shape {
    /// Renders `scalar` for a rank-0 shape. Otherwise renders the dims,
    /// comma-separated, inside square brackets, e.g. `[2, batch, 7]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.dims.split_first() {
            None => f.write_str("scalar"),
            Some((head, rest)) => {
                // Stream each dim straight into the formatter instead of
                // collecting and joining intermediate strings.
                write!(f, "[{head}")?;
                for dim in rest {
                    write!(f, ", {dim}")?;
                }
                f.write_str("]")
            }
        }
    }
}