#![cfg(feature = "model-server")]
use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::{Arc, Mutex};
use axum::Json;
use axum::Router;
use axum::extract::State;
use axum::http::StatusCode;
use axum::routing::{get, post};
use serde::Deserialize;
use serde_json::{Value, json};
use tokitai_operator::infer::{INFER_IN_DIM, INFER_N_EXPERTS, INFER_OUT_DIM, ModelSession};
/// State shared with every route handler through the router's `with_state`.
///
/// Cloning is cheap: the only field is an `Arc`, so each clone refers to the
/// same underlying session.
#[derive(Clone)]
struct AppState {
/// The loaded model session. Handlers lock the mutex before touching it, so
/// at most one handler uses the session at a time.
session: Arc<Mutex<ModelSession>>,
}
/// JSON body accepted by `POST /infer`.
#[derive(Debug, Deserialize)]
struct InferRequest {
/// Raw `features` payload. It is kept as an untyped JSON value because
/// `parse_features_value` accepts two shapes: a flat array of numbers or an
/// array of such arrays.
features: Value,
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> ExitCode {
let args: Vec<String> = std::env::args().skip(1).collect();
let cfg = match parse_args(&args) {
Ok(c) => c,
Err(e) => {
eprintln!("model_server: {e}");
eprintln!();
print_usage();
return ExitCode::from(2);
}
};
eprintln!(
"model_server: loading model from arch={} checkpoint={}",
cfg.arch_path.display(),
cfg.checkpoint_path.display(),
);
let session = match ModelSession::load(&cfg.arch_path, &cfg.checkpoint_path) {
Ok(s) => s,
Err(e) => {
eprintln!("model_server: failed to load model: {e}");
return ExitCode::FAILURE;
}
};
eprintln!(
"model_server: model loaded (in_dim={INFER_IN_DIM} out_dim={INFER_OUT_DIM} n_experts={INFER_N_EXPERTS})"
);
let bind_addr =
std::env::var("TOKITAI_MODEL_BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:9100".to_string());
let bind_addr: SocketAddr = match bind_addr.parse() {
Ok(a) => a,
Err(e) => {
eprintln!(
"model_server: TOKITAI_MODEL_BIND_ADDR={bind_addr:?} is not a valid socket addr: {e}"
);
return ExitCode::FAILURE;
}
};
let state = AppState {
session: Arc::new(Mutex::new(session)),
};
let app = Router::new()
.route("/infer", post(infer_handler))
.route("/healthz", get(healthz_handler))
.with_state(state);
let listener = match tokio::net::TcpListener::bind(bind_addr).await {
Ok(l) => l,
Err(e) => {
eprintln!("model_server: bind {bind_addr}: {e}");
return ExitCode::FAILURE;
}
};
eprintln!("model_server: listening on http://{bind_addr}");
if let Err(e) = axum::serve(listener, app).await {
eprintln!("model_server: serve error: {e}");
return ExitCode::FAILURE;
}
ExitCode::SUCCESS
}
/// Command-line configuration produced by `parse_args`.
#[derive(Debug)]
struct ServerConfig {
/// Value of `--arch`; passed as the first argument to `ModelSession::load`.
arch_path: PathBuf,
/// Value of `--checkpoint`; passed as the second argument to `ModelSession::load`.
checkpoint_path: PathBuf,
}
/// Parses the server's command-line arguments (program name already removed).
///
/// Recognized flags:
/// * `--arch <path>` and `--checkpoint <path>` — both required; when a flag is
///   repeated the last occurrence wins.
/// * `-h` / `--help` — prints the usage text and exits the process with status 0.
///
/// # Errors
///
/// Returns a human-readable message when a flag is missing its value, an
/// unknown argument is seen, or a required flag was never supplied.
fn parse_args(args: &[String]) -> Result<ServerConfig, String> {
    let mut arch: Option<PathBuf> = None;
    let mut checkpoint: Option<PathBuf> = None;

    // A single cursor over the arguments: a flag that takes a value simply
    // pulls the next item from the same iterator.
    let mut cursor = args.iter();
    while let Some(flag) = cursor.next() {
        match flag.as_str() {
            "--arch" => {
                let Some(value) = cursor.next() else {
                    return Err("--arch requires a value".to_string());
                };
                arch = Some(PathBuf::from(value));
            }
            "--checkpoint" => {
                let Some(value) = cursor.next() else {
                    return Err("--checkpoint requires a value".to_string());
                };
                checkpoint = Some(PathBuf::from(value));
            }
            "-h" | "--help" => {
                print_usage();
                std::process::exit(0);
            }
            other => return Err(format!("unknown arg: {other}")),
        }
    }

    // `--arch` is reported first when both flags are missing.
    let Some(arch_path) = arch else {
        return Err("missing --arch <path>".to_string());
    };
    let Some(checkpoint_path) = checkpoint else {
        return Err("missing --checkpoint <path>".to_string());
    };
    Ok(ServerConfig {
        arch_path,
        checkpoint_path,
    })
}
fn print_usage() {
eprintln!(
"Usage: model_server --arch <arch.json> --checkpoint <checkpoint.tkp1>\n\
\n\
Environment:\n \
TOKITAI_MODEL_BIND_ADDR bind address (default 127.0.0.1:9100)\n\
\n\
Endpoints:\n \
POST /infer body: {{\"features\": [f32; 96]}} or batch\n \
GET /healthz returns {{\"status\": \"ok\", \"arch\": ..., \"checkpoint\": ...}}\n"
);
}
async fn infer_handler(
State(state): State<AppState>,
Json(req): Json<InferRequest>,
) -> Result<Json<Value>, (StatusCode, String)> {
let batch = parse_features_value(&req.features).map_err(invalid_request)?;
let mut session = state.session.lock().map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("session lock poisoned: {e}"),
)
})?;
let out = session.forward(batch).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("forward failed: {e}"),
)
})?;
Ok(Json(json!({
"logits": tensor_to_json(&out.logits),
"router_weights": tensor_to_json(&out.router_weights),
})))
}
async fn healthz_handler(State(state): State<AppState>) -> Json<Value> {
let session = state
.session
.lock()
.expect("session lock poisoned (unrecoverable; restart server)");
Json(json!({
"status": "ok",
"arch": session.arch_path().display().to_string(),
"checkpoint": session.checkpoint_path().display().to_string(),
"in_dim": INFER_IN_DIM,
"out_dim": INFER_OUT_DIM,
"n_experts": INFER_N_EXPERTS,
}))
}
/// Pairs a validation message with `400 Bad Request`, giving the error tuple
/// that the handlers return.
fn invalid_request(msg: String) -> (StatusCode, String) {
    let status = StatusCode::BAD_REQUEST;
    (status, msg)
}
/// Converts the request's `features` value into a batch of feature rows.
///
/// Two shapes are accepted:
/// * a flat array of `INFER_IN_DIM` numbers, returned as a one-row batch;
/// * an array of arrays, each holding `INFER_IN_DIM` numbers.
///
/// The shape is chosen by looking at the first element: if it is a number the
/// whole payload is treated as a flat row.
///
/// # Errors
///
/// Returns a message when `features` is not an array, is empty, contains a
/// non-array row (batch form), contains a non-numeric element, contains a
/// number that cannot be represented as a finite `f32`, or has a row of the
/// wrong length.
fn parse_features_value(features: &Value) -> Result<Vec<Vec<f32>>, String> {
    let arr = features
        .as_array()
        .ok_or_else(|| "features must be an array".to_string())?;
    let first = arr
        .first()
        .ok_or_else(|| "features is empty".to_string())?;

    if first.is_number() {
        return parse_feature_row(arr, "flat features array").map(|row| vec![row]);
    }

    arr.iter()
        .enumerate()
        .map(|(i, row)| -> Result<Vec<f32>, String> {
            let cells = row
                .as_array()
                .ok_or_else(|| format!("features[{i}] must be an array"))?;
            parse_feature_row(cells, &format!("features[{i}]"))
        })
        .collect()
}

/// Validates one row of JSON values and converts it to `f32`.
///
/// `label` names the row in error messages (for example `features[3]`).
/// Elements are checked first, then the row length is compared with
/// `INFER_IN_DIM`.
fn parse_feature_row(values: &[Value], label: &str) -> Result<Vec<f32>, String> {
    let row = values
        .iter()
        .map(|value| -> Result<f32, String> {
            // `as_f64` yields `None` for anything that is not a usable number,
            // so no placeholder value is ever substituted.
            let wide = value
                .as_f64()
                .ok_or_else(|| format!("{label} must be all numbers"))?;
            // Narrowing an out-of-range f64 produces an infinity; reject it
            // here rather than passing it on to the model.
            let narrow = wide as f32;
            if narrow.is_finite() {
                Ok(narrow)
            } else {
                Err(format!(
                    "{label} contains {wide}, which does not fit in a finite f32"
                ))
            }
        })
        .collect::<Result<Vec<f32>, String>>()?;

    if row.len() != INFER_IN_DIM {
        return Err(format!(
            "{label} has length {}, expected {INFER_IN_DIM}",
            row.len()
        ));
    }
    Ok(row)
}
/// Renders a tensor as JSON.
///
/// When the shape is exactly two static dimensions `[rows, cols]`, the result
/// is an array of `rows` arrays with `cols` entries each, read from `t.data`
/// at index `r * cols + c`. Any other shape yields a single flat array of all
/// elements. Every element goes through `f64_or_null` before serialization.
fn tensor_to_json(t: &tokitai_operator::object::Tensor<f32>) -> Value {
    use tokitai_operator::object::Dim;

    let cell = |x: f32| json!(f64_or_null(x));
    let dims = &t.meta.shape.dims;

    if let [Dim::Static(rows), Dim::Static(cols)] = dims.as_slice() {
        let (rows, cols) = (*rows, *cols);
        let matrix: Vec<Value> = (0..rows)
            .map(|r| {
                let row: Vec<Value> = (0..cols).map(|c| cell(t.data[r * cols + c])).collect();
                Value::Array(row)
            })
            .collect();
        Value::Array(matrix)
    } else {
        Value::Array(t.data.iter().map(|v| cell(*v)).collect())
    }
}
/// Widens an `f32` to `f64` for JSON output, replacing non-finite values.
///
/// * NaN becomes `None`.
/// * Positive infinity becomes `1e99`; negative infinity becomes `-1e99`
///   (presumably stand-ins because JSON has no infinity literal).
/// * Every finite value is widened unchanged.
fn f64_or_null(v: f32) -> Option<f64> {
    if v.is_finite() {
        return Some(f64::from(v));
    }
    if v.is_nan() {
        return None;
    }
    // Only the two infinities remain at this point.
    let sentinel = if v.is_sign_positive() {
        1e99_f64
    } else {
        -1e99_f64
    };
    Some(sentinel)
}