tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Sequential model container.
//!
//! `Sequential` holds an ordered list of `Box<dyn Layer>` and
//! chains `forward` / `backward` calls. The forward cache lives
//! on the model itself (via `RefCell<Option<...>>`) so the public
//! trait signature stays clean. Used by the 0.7B MoE training
//! runner (`src/bin/train_quality_moe.rs`).
//!
//! Keeping the cache on the model lets the `Model` trait stay close to
//! the spec signatures `forward(&self, inputs: &[Tensor<f32>]) ->
//! Vec<Tensor<f32>>` and `backward(&self, grad_outputs: &[Tensor<f32>])
//! -> (Vec<Tensor<f32>>, Vec<Tensor<f32>>)`. The actual methods wrap
//! these return types in `Result`.

use std::any::Any;
use std::cell::RefCell;

use crate::object::Tensor;
use crate::{Error, Result};

use super::layer::Layer;
use super::parameter::Parameter;

/// Per-layer forward cache. Stored as `Box<dyn Any + Send>` so
/// `Sequential` can hold an arbitrary mix of layer types. The
/// corresponding `Layer::backward` implementation is responsible
/// for downcasting back to its concrete cache type.
pub struct ForwardCache {
    /// Copy of the input slice passed to `forward`. `Sequential::backward`
    /// does not read it; it is kept for any callers that need it.
    pub inputs: Vec<Tensor<f32>>,
    /// Per-tensor, per-layer caches: `layer_caches[i][l]` is the cache
    /// returned when layer `l` processed the `i`-th input tensor. The
    /// outer vector is indexed by input tensor; the inner vector is
    /// indexed by layer (in application order).
    pub layer_caches: Vec<Vec<Box<dyn Any + Send>>>,
}

/// `Model` trait, mirroring the Phase 2.1 spec.
///
/// A model maps a slice of input tensors to output tensors and exposes
/// its trainable parameters. The `Send` supertrait allows a model to be
/// moved to another thread.
pub trait Model: Send {
    /// Runs the model on `inputs` and returns the output tensors.
    fn forward(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>>;
    /// Back-propagates `grad_outputs` (gradients w.r.t. the outputs of a
    /// prior `forward`) and returns `(input_grads, param_grads)`.
    ///
    /// Implementations may rely on state saved by `forward`; `Sequential`
    /// does, while the `Box<dyn Layer>` impl in this file always errors.
    /// NOTE(review): `Sequential` returns `param_grads` in `parameters()`
    /// order — confirm whether that is part of the trait contract.
    fn backward(
        &self,
        grad_outputs: &[Tensor<f32>],
    ) -> Result<(Vec<Tensor<f32>>, Vec<Tensor<f32>>)>;
    /// Shared view over every parameter, in layer-declaration order.
    fn parameters(&self) -> Vec<&Parameter>;
    /// Mutable view over every parameter, in layer-declaration order.
    /// Used by the optimizer / smoke test to apply in-place updates.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter>;
}

/// An ordered stack of layers applied front to back.
///
/// `forward` stores a per-(input, layer) cache in `last_cache`, and
/// `backward` takes (consumes) it, so each `forward` supports a single
/// `backward`. Because the cache sits in a `RefCell`, `Sequential` is
/// not `Sync`.
pub struct Sequential {
    /// Layers in application order. This field is public, so callers may
    /// edit it. Doing so between `forward` and `backward` leaves the stored
    /// cache out of sync with the layer list, because `backward` indexes
    /// the cache by layer position.
    pub layers: Vec<Box<dyn Layer>>,
    /// Cache from the most recent `forward`. It is `None` before the first
    /// call and after `backward` has taken it.
    last_cache: RefCell<Option<ForwardCache>>,
}

impl Sequential {
    pub fn new(layers: Vec<Box<dyn Layer>>) -> Self {
        Self {
            layers,
            last_cache: RefCell::new(None),
        }
    }

    pub fn layer_count(&self) -> usize {
        self.layers.len()
    }

    /// Mutable view over every parameter in layer-declaration order.
    /// Used by the training driver / smoke tests to apply an
    /// in-place optimizer step.
    pub fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        let mut out: Vec<&mut Parameter> = Vec::new();
        for layer in &mut self.layers {
            out.extend(layer.parameters_mut());
        }
        out
    }
}

impl Model for Sequential {
    fn forward(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>> {
        let mut current: Vec<Tensor<f32>> = inputs.to_vec();
        let mut layer_caches: Vec<Vec<Box<dyn Any + Send>>> = (0..inputs.len())
            .map(|_| Vec::with_capacity(self.layers.len()))
            .collect();
        for layer in &self.layers {
            let mut next = Vec::with_capacity(current.len());
            for (i, inp) in current.iter().enumerate() {
                let (out, cache) = Layer::forward(layer.as_ref(), inp).map_err(|e| {
                    Error::backend(format!(
                        "Sequential::forward layer={} idx={}: {e}",
                        layer.name(),
                        i
                    ))
                })?;
                next.push(out);
                layer_caches[i].push(cache);
            }
            current = next;
        }
        *self.last_cache.borrow_mut() = Some(ForwardCache {
            inputs: inputs.to_vec(),
            layer_caches,
        });
        Ok(current)
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor<f32>],
    ) -> Result<(Vec<Tensor<f32>>, Vec<Tensor<f32>>)> {
        // Take the cache out so we can do interior-mutable work without
        // holding the borrow during the loop.
        let cache = self
            .last_cache
            .borrow_mut()
            .take()
            .ok_or_else(|| Error::backend("Sequential::backward called before forward"))?;
        if cache.layer_caches.len() != grad_outputs.len() {
            return Err(Error::backend(format!(
                "Sequential::backward: {} output tensors but {} grad_output tensors",
                cache.layer_caches.len(),
                grad_outputs.len()
            )));
        }

        // Each input gets its own gradient stream. The param gradients
        // are shared across all input streams (summed across the batch
        // dimension is the caller's responsibility -- the param
        // gradient is already batched by the Linear/Embedding layers).
        let n_inputs = grad_outputs.len();
        let n_layers = self.layers.len();
        let mut per_input_grads: Vec<Tensor<f32>> = grad_outputs.to_vec();
        // Gather param-grad slots: one Vec<Tensor<f32>> per layer,
        // each layer's slot has one Tensor per parameter of that layer.
        let mut param_grads: Vec<Vec<Tensor<f32>>> = Vec::new();

        for layer_idx in (0..n_layers).rev() {
            let layer = &self.layers[layer_idx];
            let n_params = layer.parameters().len();
            let mut summed = vec![Vec::new(); n_params];
            let mut next_per_input_grads = Vec::with_capacity(n_inputs);
            for i in 0..n_inputs {
                let c = cache.layer_caches[i][layer_idx].as_ref();
                let (grad_input, pg) = Layer::backward(layer.as_ref(), &per_input_grads[i], c)
                    .map_err(|e| {
                        Error::backend(format!(
                            "Sequential::backward layer={} idx={}: {e}",
                            layer.name(),
                            i
                        ))
                    })?;
                // Accumulate parameter gradients across the batch dim.
                for (j, g) in pg.into_iter().enumerate() {
                    if summed[j].is_empty() {
                        summed[j] = g.data.clone();
                    } else {
                        for (k, v) in summed[j].iter_mut().zip(g.data.iter()) {
                            *k += *v;
                        }
                    }
                }
                next_per_input_grads.push(grad_input);
            }
            per_input_grads = next_per_input_grads;
            param_grads.push(
                summed
                    .into_iter()
                    .enumerate()
                    .map(|(j, data)| {
                        let param = layer.parameters()[j];
                        Tensor::dense_cpu(
                            param.data.meta.domain.clone(),
                            param.data.meta.shape.clone(),
                            data,
                        )
                    })
                    .collect(),
            );
        }
        param_grads.reverse();

        // Flatten param gradients to a single Vec<Tensor<f32>>, in the
        // order produced by `parameters()`.
        let parameters: Vec<&Parameter> = self.parameters();
        let expected = parameters.len();
        let flat: Vec<Tensor<f32>> = param_grads.into_iter().flatten().collect();
        if flat.len() != expected {
            return Err(Error::backend(format!(
                "Sequential::backward produced {} param gradients but model has {} parameters",
                flat.len(),
                expected
            )));
        }
        Ok((per_input_grads, flat))
    }

    fn parameters(&self) -> Vec<&Parameter> {
        let mut out: Vec<&Parameter> = Vec::new();
        for layer in &self.layers {
            out.extend(layer.parameters());
        }
        out
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        let mut out: Vec<&mut Parameter> = Vec::new();
        for layer in &mut self.layers {
            out.extend(layer.parameters_mut());
        }
        out
    }
}

impl Model for Box<dyn Layer> {
    fn forward(&self, inputs: &[Tensor<f32>]) -> Result<Vec<Tensor<f32>>> {
        let mut outs = Vec::with_capacity(inputs.len());
        for inp in inputs {
            let (out, _cache) = Layer::forward(self.as_ref(), inp)?;
            outs.push(out);
        }
        Ok(outs)
    }
    fn backward(
        &self,
        _grad_outputs: &[Tensor<f32>],
    ) -> Result<(Vec<Tensor<f32>>, Vec<Tensor<f32>>)> {
        // Single-layer backward isn't supported through the boxed-dyn
        // path; the caller is expected to hold their own cache. We
        // return an error so the misuse is loud.
        Err(Error::backend(
            "Box<dyn Layer> does not support Model::backward; use Sequential",
        ))
    }
    fn parameters(&self) -> Vec<&Parameter> {
        Layer::parameters(self.as_ref())
    }
    fn parameters_mut(&mut self) -> Vec<&mut Parameter> {
        Layer::parameters_mut(self.as_mut())
    }
}