tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Standalone inference CLI for the trained 0.7B MoE model (gated on `rocm-hip`).
//!
//! Loads `arch.json` + `checkpoint.tkp1` and runs a single
//! forward pass on a user-supplied input, printing the raw per-class
//! logits (not probabilities) and the router weights as JSON. Exits 0
//! on success and non-zero on failure (2 for command-line usage errors).
//!
#![cfg(feature = "rocm-hip")]

//! Standalone inference CLI for the trained 0.7B MoE quality-decision
//! model. Step 1 of the tokitai-search integration plan: a single-shot
//! binary that loads `arch.json` + `checkpoint.tkp1`, reads one 96-dim
//! feature vector (or a batch of them) as JSON on stdin, runs a forward
//! pass, and prints the 20-dim logits and the router weights as one line
//! of JSON on stdout. The HTTP sidecar (Step 2:
//! `src/bin/model_server.rs`) is a thin axum wrapper around the
//! same `ModelSession::forward` path.
//!
//! Input protocol (stdin):
//!
//!     {"features": [f32; 96]}
//!
//! or, for a batch:
//!
//!     {"features": [[f32; 96], [f32; 96], ...]}
//!
//! Output protocol (stdout, single line):
//!
//!     {"logits": [[f32; 20], ...], "router_weights": [[f32; 4], ...]}
//!
//! Both `features` and `logits` are plain floating-point values (not
//! milli-units), matching the raw tensor layout used by the training runner. The
//! quality-decision application (in tokitai-search) is responsible
//! for the milli-unit <-> float conversion when constructing the
//! input and interpreting the output.

use std::env;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::process::ExitCode;

use tokitai_operator::infer::ModelSession;

fn main() -> ExitCode {
    let argv: Vec<String> = env::args().skip(1).collect();

    let cfg = match parse_args(&argv) {
        Ok(cfg) => cfg,
        Err(msg) => {
            // Usage errors: the message, a blank line, the full usage
            // text, and a distinct exit status (2).
            eprintln!("infer_quality_moe: {msg}");
            eprintln!();
            print_usage();
            return ExitCode::from(2);
        }
    };

    let json = match run_infer(&cfg) {
        Ok(json) => json,
        Err(msg) => {
            eprintln!("infer_quality_moe: failed: {msg}");
            return ExitCode::FAILURE;
        }
    };

    // The HTTP sidecar parses this output verbatim, so it is kept to a
    // single JSON line terminated by one newline.
    match writeln!(std::io::stdout().lock(), "{json}") {
        Ok(()) => ExitCode::SUCCESS,
        Err(e) => {
            eprintln!("infer_quality_moe: write stdout: {e}");
            ExitCode::FAILURE
        }
    }
}

/// Resolved command-line configuration for one inference run.
#[derive(Debug)]
struct InferConfig {
    /// Value of `--arch`: path to the `arch.json` model description.
    arch_path: PathBuf,
    /// Value of `--checkpoint`: path to the `checkpoint.tkp1` weights file.
    checkpoint_path: PathBuf,
}

/// Parse the command-line arguments (program name already stripped).
///
/// Recognised flags, in any order:
/// * `--arch <path>` — required.
/// * `--checkpoint <path>` — required.
/// * `-h` / `--help` — prints usage to stderr and exits the process with
///   status 0; this function does not return in that case.
///
/// A flag given more than once keeps its last value. The token following
/// `--arch` / `--checkpoint` is taken verbatim as the value, even if it
/// looks like another flag.
///
/// # Errors
///
/// Returns a message for an unknown argument, a flag missing its value,
/// or a missing required flag (`--arch` is checked before `--checkpoint`).
fn parse_args(args: &[String]) -> Result<InferConfig, String> {
    let mut arch_path = None;
    let mut checkpoint_path = None;
    let mut tokens = args.iter();
    while let Some(flag) = tokens.next() {
        match flag.as_str() {
            "--arch" => {
                let value = tokens
                    .next()
                    .ok_or_else(|| "--arch requires a value".to_string())?;
                arch_path = Some(PathBuf::from(value));
            }
            "--checkpoint" => {
                let value = tokens
                    .next()
                    .ok_or_else(|| "--checkpoint requires a value".to_string())?;
                checkpoint_path = Some(PathBuf::from(value));
            }
            "-h" | "--help" => {
                print_usage();
                std::process::exit(0);
            }
            unknown => return Err(format!("unknown arg: {unknown}")),
        }
    }
    Ok(InferConfig {
        arch_path: arch_path.ok_or_else(|| "missing --arch <path>".to_string())?,
        checkpoint_path: checkpoint_path
            .ok_or_else(|| "missing --checkpoint <path>".to_string())?,
    })
}

/// Print the usage banner (flags plus the stdin/stdout JSON protocol)
/// to stderr, followed by a trailing newline.
fn print_usage() {
    const USAGE: &str = concat!(
        "Usage: infer_quality_moe --arch <arch.json> --checkpoint <checkpoint.tkp1>\n",
        "\n",
        "Reads a single JSON object from stdin of the form:\n",
        "  {\"features\": [f32; 96]}\n",
        "  or for a batch:\n",
        "  {\"features\": [[f32; 96], [f32; 96], ...]}\n",
        "\n",
        "Writes a single JSON object to stdout:\n",
        "  {\"logits\": [[f32; 20], ...], \"router_weights\": [[f32; 4], ...]}\n",
    );
    eprintln!("{}", USAGE);
}

/// Execute one inference request end to end.
///
/// Reads and validates the feature payload from stdin, loads the model
/// described by `cfg`, runs a single forward pass, and returns the result
/// as one line of JSON: `{"logits":[...],"router_weights":[...]}`.
///
/// # Errors
///
/// Returns a message if stdin cannot be read or parsed, if the session
/// fails to load, or if the forward pass fails.
fn run_infer(cfg: &InferConfig) -> Result<String, String> {
    // 1. Read and validate features first. Parsing stdin is cheap, while
    //    loading the checkpoint is the expensive step, so malformed input
    //    now fails fast instead of after a full model load. The caller
    //    still sends the same single payload either way.
    let mut stdin_buf = String::new();
    std::io::stdin()
        .read_to_string(&mut stdin_buf)
        .map_err(|e| format!("read stdin: {e}"))?;
    let batch = parse_features(&stdin_buf)?;

    // 2. Build the session. This loads arch + checkpoint and
    //    restores the trained weights into the freshly-built model.
    let mut session = ModelSession::load(&cfg.arch_path, &cfg.checkpoint_path)?;

    // 3. Forward.
    let out = session.forward(batch)?;

    // 4. Serialize the output as a single-line JSON object so the
    //    HTTP sidecar (Step 2) and the tokitai-search e2e test can
    //    parse it verbatim.
    let logits_str = tensor_to_json(&out.logits);
    let router_str = tensor_to_json(&out.router_weights);
    Ok(format!(
        "{{\"logits\":{logits_str},\"router_weights\":{router_str}}}"
    ))
}

/// Parse the stdin JSON payload into a batch of feature rows.
///
/// Accepts either `{"features": [f32; IN_DIM]}` (a single flat row,
/// returned as a batch of one) or `{"features": [[f32; IN_DIM], ...]}`.
/// The two shapes are told apart by whether the first element is a number.
///
/// Every row is validated up front: element type, `f32` range after
/// narrowing from the JSON `f64`, and row width. The forward call later
/// therefore never sees a tensor of the wrong shape or an infinity
/// produced by silent overflow.
///
/// # Errors
///
/// Returns a human-readable message when the input is not valid JSON,
/// lacks a `features` key, `features` is not a non-empty array, a row is
/// not an array of numbers, a value overflows `f32`, or a row has the
/// wrong width.
fn parse_features(raw: &str) -> Result<Vec<Vec<f32>>, String> {
    use tokitai_operator::infer::INFER_IN_DIM as IN_DIM;
    let v: serde_json::Value =
        serde_json::from_str(raw).map_err(|e| format!("stdin is not valid JSON: {e}"))?;
    let features = v
        .get("features")
        .ok_or_else(|| "stdin JSON missing `features` key".to_string())?;
    let arr = features
        .as_array()
        .ok_or_else(|| "features must be an array".to_string())?;
    if arr.is_empty() {
        return Err("features is empty".to_string());
    }
    // Flat form: `[f32; IN_DIM]` is treated as batch = 1.
    if arr[0].is_number() {
        return Ok(vec![parse_row(arr, "flat features array", IN_DIM)?]);
    }
    // Batch form: `[[f32; IN_DIM], [f32; IN_DIM], ...]`.
    arr.iter()
        .enumerate()
        .map(|(i, row)| {
            let label = format!("features[{i}]");
            let row_arr = row
                .as_array()
                .ok_or_else(|| format!("{label} must be an array"))?;
            parse_row(row_arr, &label, IN_DIM)
        })
        .collect()
}

/// Convert one JSON row into `f32`s, checking element type, `f32` range
/// and width. `label` prefixes every error message (e.g. `features[3]`).
fn parse_row(
    values: &[serde_json::Value],
    label: &str,
    expected: usize,
) -> Result<Vec<f32>, String> {
    let row = values
        .iter()
        .enumerate()
        .map(|(j, x)| {
            // `as_f64` is `None` for non-numbers, so this replaces the
            // previous type check plus `unwrap_or(0.0)` silent fallback.
            let wide = x
                .as_f64()
                .ok_or_else(|| format!("{label} must be all numbers"))?;
            // An f64 -> f32 `as` cast yields ±infinity on overflow, so a
            // value such as 1e39 would otherwise reach the model as inf.
            let narrow = wide as f32;
            if narrow.is_finite() {
                Ok(narrow)
            } else {
                Err(format!("{label} element {j} ({wide}) is out of f32 range"))
            }
        })
        .collect::<Result<Vec<f32>, String>>()?;
    if row.len() != expected {
        return Err(format!(
            "{label} has length {}, expected {expected}",
            row.len()
        ));
    }
    Ok(row)
}

/// Serialize a `Tensor<f32>` as JSON text.
///
/// A tensor whose shape is exactly two static dims `[rows, cols]` becomes
/// an array of row arrays, reading element `data[r * cols + c]`. Any other
/// shape (1-D, higher rank, or symbolic dims) is emitted as one flat array
/// in `data` iteration order. Each element is rendered by `format_f32`,
/// which maps NaN to `null`. That custom handling of non-finite values is
/// the reason the JSON is assembled by hand.
fn tensor_to_json(t: &tokitai_operator::object::Tensor<f32>) -> String {
    use tokitai_operator::object::Dim;

    /// Append `[v0,v1,...]` for the given values to `out`.
    fn push_array(out: &mut String, values: impl Iterator<Item = f32>) {
        out.push('[');
        for (k, value) in values.enumerate() {
            if k != 0 {
                out.push(',');
            }
            out.push_str(&format_f32(value));
        }
        out.push(']');
    }

    let mut json = String::new();
    match t.meta.shape.dims.as_slice() {
        [Dim::Static(rows), Dim::Static(cols)] => {
            let (rows, cols) = (*rows, *cols);
            json.push('[');
            for r in 0..rows {
                if r != 0 {
                    json.push(',');
                }
                push_array(&mut json, (0..cols).map(|c| t.data[r * cols + c]));
            }
            json.push(']');
        }
        // 1-D, higher-rank or symbolic shapes: flatten in storage order.
        _ => push_array(&mut json, t.data.iter().copied()),
    }
    json
}

/// Render one `f32` as a JSON value token.
///
/// * NaN has no JSON number form, so it becomes `null`.
/// * ±infinity becomes `1e999` / `-1e999`, a literal that overflows
///   `f64`. Lenient parsers decode it back to ±infinity, but strict
///   parsers may reject it as out of range.
///   NOTE(review): confirm the tokitai-search client accepts it.
/// * Finite values use fixed-point notation with exactly six digits after
///   the decimal point (`{:.6}`). That is six *fractional* digits, not six
///   significant digits: magnitudes below 5e-7 round to zero, and large
///   magnitudes print every integer digit.
fn format_f32(v: f32) -> String {
    if v.is_nan() {
        "null".to_string()
    } else if v.is_infinite() {
        let token = if v.is_sign_positive() { "1e999" } else { "-1e999" };
        token.to_string()
    } else {
        format!("{v:.6}")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokitai_operator::infer::INFER_IN_DIM as IN_DIM;

    #[test]
    fn format_f32_handles_special_values() {
        assert_eq!(format_f32(0.0), "0.000000");
        assert_eq!(format_f32(-1.5), "-1.500000");
        assert_eq!(format_f32(f32::NAN), "null");
        assert_eq!(format_f32(f32::INFINITY), "1e999");
        assert_eq!(format_f32(f32::NEG_INFINITY), "-1e999");
    }

    #[test]
    fn format_f32_uses_six_fractional_digits() {
        // `{:.6}` fixes digits after the decimal point, not significant
        // digits, so tiny magnitudes collapse to zero.
        assert_eq!(format_f32(1.0 / 3.0), "0.333333");
        assert_eq!(format_f32(1234.5), "1234.500000");
        assert_eq!(format_f32(1.5e-7), "0.000000");
    }

    #[test]
    fn parse_features_accepts_flat_array() {
        let row: Vec<f32> = vec![0.0; IN_DIM];
        let raw = serde_json::json!({ "features": row }).to_string();
        let batch = parse_features(&raw).expect("flat array should parse");
        assert_eq!(batch.len(), 1);
        assert_eq!(batch[0].len(), IN_DIM);
    }

    #[test]
    fn parse_features_accepts_batch() {
        let rows: Vec<Vec<f32>> = vec![vec![0.0; IN_DIM], vec![1.0; IN_DIM]];
        let raw = serde_json::json!({ "features": rows }).to_string();
        let batch = parse_features(&raw).expect("batch should parse");
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), IN_DIM);
        assert_eq!(batch[1].len(), IN_DIM);
        assert_eq!(batch[1][0], 1.0);
    }

    #[test]
    fn parse_features_preserves_values() {
        // Multiples of 0.5 are exact in f32, so they must survive the
        // JSON round trip unchanged.
        let row: Vec<f32> = (0..IN_DIM).map(|i| i as f32 * 0.5).collect();
        let raw = serde_json::json!({ "features": row.clone() }).to_string();
        let batch = parse_features(&raw).expect("flat array should parse");
        assert_eq!(batch, vec![row]);
    }

    #[test]
    fn parse_features_rejects_missing_key() {
        let raw = "{\"wrong_key\": []}";
        assert!(parse_features(raw).is_err());
    }

    #[test]
    fn parse_features_rejects_wrong_row_width() {
        let raw = serde_json::json!({ "features": [[1.0_f32, 2.0]] }).to_string();
        assert!(parse_features(&raw).is_err());
    }

    #[test]
    fn parse_features_rejects_flat_wrong_length() {
        let raw = serde_json::json!({ "features": [1.0_f32, 2.0] }).to_string();
        assert!(parse_features(&raw).is_err());
    }

    #[test]
    fn parse_features_rejects_non_numeric_elements() {
        let mut row: Vec<serde_json::Value> = vec![serde_json::json!(0.0); IN_DIM];
        row[IN_DIM - 1] = serde_json::json!("x");
        let flat = serde_json::json!({ "features": row.clone() }).to_string();
        assert!(parse_features(&flat).is_err());
        let batch = serde_json::json!({ "features": [row] }).to_string();
        assert!(parse_features(&batch).is_err());
    }

    #[test]
    fn parse_features_rejects_non_array_shapes() {
        assert!(parse_features("not json").is_err());
        assert!(parse_features("[1, 2, 3]").is_err());
        assert_eq!(
            parse_features("{\"features\": 3}").unwrap_err(),
            "features must be an array"
        );
        assert_eq!(
            parse_features("{\"features\": [\"a\"]}").unwrap_err(),
            "features[0] must be an array"
        );
    }

    #[test]
    fn parse_features_rejects_empty_batch() {
        let raw = "{\"features\": []}";
        assert!(parse_features(raw).is_err());
    }
}