tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! HTTP sidecar for the trained 0.7B MoE model (gated on `model-server`).
//!
//! Axum + tokio server that exposes the MoE model over HTTP.
//! Loads `arch.json` + `checkpoint.tkp1` at startup and serves
//! `POST /infer` requests. The feature flag `model-server` is
//! separate from `rocm-hip` so the server can be built without
//! the HIP kernels (CPU-only inference).
//!
#![cfg(feature = "model-server")]

//! HTTP sidecar for the trained 0.7B MoE quality-decision model.
//! Step 2 of the tokitai-search integration plan: an axum server
//! that wraps a single `ModelSession` (loaded at startup) and
//! exposes:
//!
//!   - `POST /infer` — body `{"features": [f32; 96]}` or batch;
//!     returns `{"logits": [[f32; 20], ...], "router_weights":
//!     [[f32; 4], ...]}`.
//!   - `GET /healthz` — readiness probe returning the loaded arch
//!     + checkpoint paths so the tokitai-search integration can
//!     verify the sidecar is up and pointing at the right artifacts.
//!
//! The bind address is read from `TOKITAI_MODEL_BIND_ADDR`
//! (default `127.0.0.1:9100`). Only loopback by default — the
//! tokitai-search service is expected to run on the same host and
//! talk to the sidecar via the loopback interface.
//!
//! Concurrency: the axum server is multi-threaded; `ModelSession`
//! is wrapped in `Arc<Mutex<…>>` so concurrent `/infer` calls
//! serialize through the mutex. The model itself is not
//! thread-safe to mutate, but the only mutation is the forward
//! pass, which is bounded by the per-request critical section.

use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::{Arc, Mutex};

use axum::Json;
use axum::Router;
use axum::extract::State;
use axum::http::StatusCode;
use axum::routing::{get, post};
use serde::Deserialize;
use serde_json::{Value, json};
use tokitai_operator::infer::{INFER_IN_DIM, INFER_N_EXPERTS, INFER_OUT_DIM, ModelSession};

/// Shared state handed to every axum handler through `with_state`.
///
/// Cloning only clones the `Arc`, so every handler invocation refers to
/// the same `ModelSession` behind the same mutex.
#[derive(Clone)]
struct AppState {
    /// The single model session loaded at startup; handlers lock the
    /// mutex before touching it.
    session: Arc<Mutex<ModelSession>>,
}

/// JSON body accepted by `POST /infer`.
#[derive(Debug, Deserialize)]
struct InferRequest {
    /// A flat 96-dim vector (treated as batch=1) or a list of
    /// 96-dim vectors. Kept as a raw JSON `Value` here; the shape and
    /// element types are validated later by `parse_features_value`.
    features: Value,
}

/// Process entry point: parse the CLI, load the model, bind the listener
/// and serve HTTP until `axum::serve` returns.
///
/// Exit codes: `2` when the command line is invalid (usage is printed),
/// `ExitCode::FAILURE` when the model cannot be loaded, the bind address
/// is invalid, the socket cannot be bound, or serving fails, and
/// `ExitCode::SUCCESS` otherwise. All diagnostics go to stderr.
#[tokio::main(flavor = "multi_thread")]
async fn main() -> ExitCode {
    let cli_args: Vec<String> = std::env::args().skip(1).collect();
    let cfg = match parse_args(&cli_args) {
        Err(e) => {
            eprintln!("model_server: {e}");
            eprintln!();
            print_usage();
            return ExitCode::from(2);
        }
        Ok(parsed) => parsed,
    };

    // Load the model before touching the network so a bad artifact fails
    // fast without ever opening a socket.
    eprintln!(
        "model_server: loading model from arch={} checkpoint={}",
        cfg.arch_path.display(),
        cfg.checkpoint_path.display(),
    );
    let session = match ModelSession::load(&cfg.arch_path, &cfg.checkpoint_path) {
        Err(e) => {
            eprintln!("model_server: failed to load model: {e}");
            return ExitCode::FAILURE;
        }
        Ok(loaded) => loaded,
    };
    eprintln!(
        "model_server: model loaded (in_dim={INFER_IN_DIM} out_dim={INFER_OUT_DIM} n_experts={INFER_N_EXPERTS})"
    );

    // Any failure to read the variable (unset or not valid Unicode) falls
    // back to the loopback default.
    let bind_addr = match std::env::var("TOKITAI_MODEL_BIND_ADDR") {
        Ok(configured) => configured,
        Err(_) => "127.0.0.1:9100".to_string(),
    };
    // Inside the `Err` arm `bind_addr` still names the raw string, which is
    // what the diagnostic prints.
    let bind_addr: SocketAddr = match bind_addr.parse() {
        Err(e) => {
            eprintln!(
                "model_server: TOKITAI_MODEL_BIND_ADDR={bind_addr:?} is not a valid socket addr: {e}"
            );
            return ExitCode::FAILURE;
        }
        Ok(parsed) => parsed,
    };

    let listener = match tokio::net::TcpListener::bind(bind_addr).await {
        Err(e) => {
            eprintln!("model_server: bind {bind_addr}: {e}");
            return ExitCode::FAILURE;
        }
        Ok(bound) => bound,
    };
    eprintln!("model_server: listening on http://{bind_addr}");

    let app = Router::new()
        .route("/infer", post(infer_handler))
        .route("/healthz", get(healthz_handler))
        .with_state(AppState {
            session: Arc::new(Mutex::new(session)),
        });

    match axum::serve(listener, app).await {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            eprintln!("model_server: serve error: {e}");
            ExitCode::FAILURE
        }
    }
}

/// Command-line configuration produced by `parse_args`.
#[derive(Debug)]
struct ServerConfig {
    /// Value of the required `--arch` flag.
    arch_path: PathBuf,
    /// Value of the required `--checkpoint` flag.
    checkpoint_path: PathBuf,
}

/// Parse the command-line arguments (program name already stripped).
///
/// Recognised flags are `--arch <path>` and `--checkpoint <path>`, both
/// required; when a flag is repeated the last occurrence wins. `-h` /
/// `--help` prints the usage text and exits the process with status 0.
///
/// # Errors
/// Returns a message when a flag is missing its value, an argument is not
/// recognised, or a required flag is absent.
fn parse_args(args: &[String]) -> Result<ServerConfig, String> {
    let mut arch: Option<PathBuf> = None;
    let mut checkpoint: Option<PathBuf> = None;

    // A single cursor lets a flag consume the argument that follows it.
    let mut rest = args.iter();
    while let Some(flag) = rest.next() {
        match flag.as_str() {
            "--arch" => {
                let value = rest
                    .next()
                    .ok_or_else(|| "--arch requires a value".to_string())?;
                arch = Some(PathBuf::from(value));
            }
            "--checkpoint" => {
                let value = rest
                    .next()
                    .ok_or_else(|| "--checkpoint requires a value".to_string())?;
                checkpoint = Some(PathBuf::from(value));
            }
            "-h" | "--help" => {
                print_usage();
                std::process::exit(0);
            }
            unknown => return Err(format!("unknown arg: {unknown}")),
        }
    }

    // A missing `--arch` is reported before a missing `--checkpoint`.
    match (arch, checkpoint) {
        (None, _) => Err("missing --arch <path>".to_string()),
        (Some(_), None) => Err("missing --checkpoint <path>".to_string()),
        (Some(arch_path), Some(checkpoint_path)) => Ok(ServerConfig {
            arch_path,
            checkpoint_path,
        }),
    }
}

/// Write the usage/help text to stderr.
fn print_usage() {
    // One source line per output line; `eprintln!` appends the final
    // newline, so the text ends with a blank line.
    const USAGE: &str = concat!(
        "Usage: model_server --arch <arch.json> --checkpoint <checkpoint.tkp1>\n",
        "\n",
        "Environment:\n",
        "  TOKITAI_MODEL_BIND_ADDR  bind address (default 127.0.0.1:9100)\n",
        "\n",
        "Endpoints:\n",
        "  POST /infer    body: {\"features\": [f32; 96]} or batch\n",
        "  GET  /healthz  returns {\"status\": \"ok\", \"arch\": ..., \"checkpoint\": ...}\n",
    );
    eprintln!("{}", USAGE);
}

async fn infer_handler(
    State(state): State<AppState>,
    Json(req): Json<InferRequest>,
) -> Result<Json<Value>, (StatusCode, String)> {
    let batch = parse_features_value(&req.features).map_err(invalid_request)?;
    let mut session = state.session.lock().map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("session lock poisoned: {e}"),
        )
    })?;
    let out = session.forward(batch).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("forward failed: {e}"),
        )
    })?;
    Ok(Json(json!({
        "logits": tensor_to_json(&out.logits),
        "router_weights": tensor_to_json(&out.router_weights),
    })))
}

/// Handle `GET /healthz`: report the loaded artifact paths and the model
/// dimension constants.
///
/// # Panics
/// Panics if the session mutex is poisoned; the message tells the operator
/// to restart the server.
async fn healthz_handler(State(state): State<AppState>) -> Json<Value> {
    let guard = state
        .session
        .lock()
        .expect("session lock poisoned (unrecoverable; restart server)");
    let arch = guard.arch_path().display().to_string();
    let checkpoint = guard.checkpoint_path().display().to_string();
    let report = json!({
        "status": "ok",
        "arch": arch,
        "checkpoint": checkpoint,
        "in_dim": INFER_IN_DIM,
        "out_dim": INFER_OUT_DIM,
        "n_experts": INFER_N_EXPERTS,
    });
    Json(report)
}

/// Pair a validation message with `400 Bad Request`, producing the error
/// tuple that `infer_handler` returns for rejected request bodies.
fn invalid_request(msg: String) -> (StatusCode, String) {
    (StatusCode::BAD_REQUEST, msg)
}

/// Parse a `features` JSON value (flat number array or batch of
/// arrays) into `Vec<Vec<f32>>`. Row width must equal `INFER_IN_DIM`.
fn parse_features_value(features: &Value) -> Result<Vec<Vec<f32>>, String> {
    let arr = features
        .as_array()
        .ok_or_else(|| "features must be an array".to_string())?;
    if arr.is_empty() {
        return Err("features is empty".to_string());
    }
    if arr[0].is_number() {
        let row: Vec<f32> = arr
            .iter()
            .map(|x| {
                if !x.is_number() {
                    Err("flat features array must be all numbers".to_string())
                } else {
                    Ok(x.as_f64().unwrap_or(0.0) as f32)
                }
            })
            .collect::<Result<Vec<_>, _>>()?;
        if row.len() != INFER_IN_DIM {
            return Err(format!(
                "flat features array has length {}, expected {INFER_IN_DIM}",
                row.len()
            ));
        }
        return Ok(vec![row]);
    }
    let mut batch = Vec::with_capacity(arr.len());
    for (i, row) in arr.iter().enumerate() {
        let row_arr = row
            .as_array()
            .ok_or_else(|| format!("features[{i}] must be an array"))?;
        if !row_arr.iter().all(|x| x.is_number()) {
            return Err(format!("features[{i}] must be all numbers"));
        }
        let v: Vec<f32> = row_arr
            .iter()
            .map(|x| x.as_f64().unwrap_or(0.0) as f32)
            .collect();
        if v.len() != INFER_IN_DIM {
            return Err(format!(
                "features[{i}] has length {}, expected {INFER_IN_DIM}",
                v.len()
            ));
        }
        batch.push(v);
    }
    Ok(batch)
}

/// Render a `Tensor<f32>` as a JSON value.
///
/// When the shape is exactly two static dimensions `(rows, cols)` the
/// result is an array of `rows` arrays of `cols` elements, read from
/// `t.data` in row-major order (`r * cols + c`). Any other shape is
/// rendered as a single flat array of every element of `t.data`.
///
/// Every element is passed through `f64_or_null`, so the on-the-wire
/// value is whatever that helper yields: a number, or JSON `null` where it
/// returns `None`.
///
/// # Panics
/// Panics if a 2-D shape describes more elements than `t.data` holds.
fn tensor_to_json(t: &tokitai_operator::object::Tensor<f32>) -> Value {
    use tokitai_operator::object::Dim;

    let cell = |v: f32| json!(f64_or_null(v));

    if let [Dim::Static(rows), Dim::Static(cols)] = t.meta.shape.dims.as_slice() {
        let (rows, cols) = (*rows, *cols);
        let grid: Vec<Value> = (0..rows)
            .map(|r| {
                let line: Vec<Value> = (0..cols).map(|c| cell(t.data[r * cols + c])).collect();
                Value::Array(line)
            })
            .collect();
        Value::Array(grid)
    } else {
        // Not a static 2-D shape: fall back to one flat array.
        Value::Array(t.data.iter().map(|v| cell(*v)).collect())
    }
}

/// Map an `f32` to the value that should be serialized for it.
///
/// - NaN becomes `None`.
/// - `+inf` / `-inf` become the finite sentinels `1e99` / `-1e99`, which
///   lie far outside the finite `f32` range and so cannot collide with a
///   real value.
/// - Any finite value is widened to `f64`; that widening is exact for
///   every `f32`, so no precision is lost.
fn f64_or_null(v: f32) -> Option<f64> {
    const INF_SENTINEL: f64 = 1e99_f64;

    if v.is_nan() {
        return None;
    }
    let widened = f64::from(v);
    if widened.is_finite() {
        Some(widened)
    } else {
        // Only the infinities reach this branch; carry their sign over.
        Some(INF_SENTINEL.copysign(widened))
    }
}