#![cfg(feature = "rocm-hip")]
use std::env;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::process::ExitCode;
use tokitai_operator::infer::ModelSession;
/// Entry point: parse the command line, run one inference pass, print the result.
///
/// Exit status:
/// * `2` when the arguments cannot be parsed (the usage text is printed as well),
/// * `ExitCode::FAILURE` when inference fails or stdout cannot be written,
/// * `ExitCode::SUCCESS` otherwise.
fn main() -> ExitCode {
    let argv: Vec<String> = env::args().skip(1).collect();

    let config = match parse_args(&argv) {
        Ok(config) => config,
        Err(msg) => {
            eprintln!("infer_quality_moe: {msg}");
            eprintln!();
            print_usage();
            return ExitCode::from(2);
        }
    };

    let json = match run_infer(&config) {
        Ok(json) => json,
        Err(msg) => {
            eprintln!("infer_quality_moe: failed: {msg}");
            return ExitCode::FAILURE;
        }
    };

    // Lock stdout once and emit the result as a single newline-terminated line.
    let stdout = std::io::stdout();
    let mut handle = stdout.lock();
    match writeln!(handle, "{}", json) {
        Ok(()) => ExitCode::SUCCESS,
        Err(io_err) => {
            eprintln!("infer_quality_moe: write stdout: {io_err}");
            ExitCode::FAILURE
        }
    }
}
/// Command-line configuration produced by `parse_args` and consumed by `run_infer`.
#[derive(Debug)]
struct InferConfig {
/// Value given to `--arch`; passed as the first argument to `ModelSession::load`.
arch_path: PathBuf,
/// Value given to `--checkpoint`; passed as the second argument to `ModelSession::load`.
checkpoint_path: PathBuf,
}
/// Builds an [`InferConfig`] from the command-line arguments (program name excluded).
///
/// Recognised flags are `--arch <path>` and `--checkpoint <path>`; both are required.
/// When a flag is repeated, the last value wins. `-h` / `--help` prints the usage text
/// and terminates the process with status 0 instead of returning.
///
/// # Errors
///
/// Returns a human-readable message when a flag is missing its value, when an
/// unknown argument is seen, or when a required flag was never supplied.
fn parse_args(args: &[String]) -> Result<InferConfig, String> {
    let mut arch: Option<PathBuf> = None;
    let mut checkpoint: Option<PathBuf> = None;

    // Walk the arguments with an explicit iterator so a flag can pull its value.
    let mut remaining = args.iter();
    while let Some(flag) = remaining.next() {
        match flag.as_str() {
            "--arch" => match remaining.next() {
                Some(value) => arch = Some(PathBuf::from(value)),
                None => return Err("--arch requires a value".to_string()),
            },
            "--checkpoint" => match remaining.next() {
                Some(value) => checkpoint = Some(PathBuf::from(value)),
                None => return Err("--checkpoint requires a value".to_string()),
            },
            "-h" | "--help" => {
                print_usage();
                std::process::exit(0);
            }
            other => return Err(format!("unknown arg: {other}")),
        }
    }

    // `--arch` is checked first, so it is the one reported when both are missing.
    let arch_path = match arch {
        Some(path) => path,
        None => return Err("missing --arch <path>".to_string()),
    };
    let checkpoint_path = match checkpoint {
        Some(path) => path,
        None => return Err("missing --checkpoint <path>".to_string()),
    };

    Ok(InferConfig {
        arch_path,
        checkpoint_path,
    })
}
/// Writes the usage text to stderr.
///
/// The text is assembled with `concat!` from one literal per output line; braces are
/// doubled because the result is used as a format string.
fn print_usage() {
    eprintln!(concat!(
        "Usage: infer_quality_moe --arch <arch.json> --checkpoint <checkpoint.tkp1>\n",
        "\n",
        "Reads a single JSON object from stdin of the form:\n",
        " {{\"features\": [f32; 96]}}\n",
        " or for a batch:\n",
        " {{\"features\": [[f32; 96], [f32; 96], ...]}}\n",
        "\n",
        "Writes a single JSON object to stdout:\n",
        " {{\"logits\": [[f32; 20], ...], \"router_weights\": [[f32; 4], ...]}}\n",
    ));
}
/// Runs one inference pass and returns the JSON text to print.
///
/// Steps, in order: load the model session from the configured paths, read all of
/// stdin, parse it with `parse_features`, call `forward` on the session, and
/// serialize the `logits` and `router_weights` tensors of the output.
///
/// # Errors
///
/// Propagates errors from loading, parsing and the forward call; a failure to read
/// stdin is reported as `read stdin: <io error>`.
fn run_infer(cfg: &InferConfig) -> Result<String, String> {
    // The session is loaded before stdin is touched.
    let mut session = ModelSession::load(&cfg.arch_path, &cfg.checkpoint_path)?;

    let mut raw_input = String::new();
    if let Err(io_err) = std::io::stdin().read_to_string(&mut raw_input) {
        return Err(format!("read stdin: {io_err}"));
    }

    let features = parse_features(&raw_input)?;
    let output = session.forward(features)?;

    // Assemble the object by hand; both values are already JSON text.
    let mut json = String::from("{\"logits\":");
    json.push_str(&tensor_to_json(&output.logits));
    json.push_str(",\"router_weights\":");
    json.push_str(&tensor_to_json(&output.router_weights));
    json.push('}');
    Ok(json)
}
/// Parses the stdin payload into a batch of feature rows.
///
/// `raw` must be a JSON value whose `features` key holds a non-empty array. The
/// shape is decided by the first element only:
///
/// * first element is a number: the whole array is one flat row, returned as a
///   batch of one;
/// * otherwise: every element must itself be an array of numbers, one row each.
///
/// Every row must contain exactly `INFER_IN_DIM` values.
///
/// # Errors
///
/// Returns a descriptive message when the input is not valid JSON, the `features`
/// key is missing, `features` is not an array or is empty, a row is not an array,
/// a row contains a non-numeric element, or a row has the wrong length.
fn parse_features(raw: &str) -> Result<Vec<Vec<f32>>, String> {
    use tokitai_operator::infer::INFER_IN_DIM as IN_DIM;

    let doc: serde_json::Value =
        serde_json::from_str(raw).map_err(|e| format!("stdin is not valid JSON: {e}"))?;
    let features = doc
        .get("features")
        .ok_or_else(|| "stdin JSON missing `features` key".to_string())?;
    let arr = features
        .as_array()
        .ok_or_else(|| "features must be an array".to_string())?;
    if arr.is_empty() {
        return Err("features is empty".to_string());
    }

    // A leading number selects the flat, single-row form.
    if arr[0].is_number() {
        let row = parse_feature_row(arr, "flat features array", IN_DIM)?;
        return Ok(vec![row]);
    }

    // Batch form: stops at the first offending row, reporting its index.
    arr.iter()
        .enumerate()
        .map(|(i, row)| {
            let row_arr = row
                .as_array()
                .ok_or_else(|| format!("features[{i}] must be an array"))?;
            parse_feature_row(row_arr, format_args!("features[{i}]"), IN_DIM)
        })
        .collect()
}

/// Converts one JSON array into a feature row of exactly `expected_len` values.
///
/// `label` names the row in error messages (for example `features[3]`). Elements
/// are checked first, then the length, so a non-numeric element is reported in
/// preference to a wrong length.
fn parse_feature_row(
    values: &[serde_json::Value],
    label: impl std::fmt::Display,
    expected_len: usize,
) -> Result<Vec<f32>, String> {
    let row = values
        .iter()
        .map(|x| {
            // An element that cannot be read as `f64` is an error; it is never
            // silently replaced by a default value. The value is then narrowed
            // to `f32` with a plain cast.
            x.as_f64()
                .map(|n| n as f32)
                .ok_or_else(|| format!("{label} must be all numbers"))
        })
        .collect::<Result<Vec<f32>, String>>()?;
    if row.len() != expected_len {
        return Err(format!(
            "{label} has length {}, expected {expected_len}",
            row.len()
        ));
    }
    Ok(row)
}
fn tensor_to_json(t: &tokitai_operator::object::Tensor<f32>) -> String {
use tokitai_operator::object::Dim;
let dims = &t.meta.shape.dims;
let (rows, cols) = match dims.as_slice() {
[Dim::Static(r), Dim::Static(c)] => (*r, *c),
_ => {
let mut s = String::from("[");
for (i, v) in t.data.iter().enumerate() {
if i > 0 {
s.push(',');
}
s.push_str(&format_f32(*v));
}
s.push(']');
return s;
}
};
let mut s = String::from("[");
for r in 0..rows {
if r > 0 {
s.push(',');
}
s.push('[');
for c in 0..cols {
if c > 0 {
s.push(',');
}
s.push_str(&format_f32(t.data[r * cols + c]));
}
s.push(']');
}
s.push(']');
s
}
/// Formats one `f32` as a JSON value token.
///
/// Finite values are printed with six digits after the decimal point. NaN becomes
/// `null`; positive and negative infinity become `1e999` and `-1e999`.
fn format_f32(v: f32) -> String {
    if v.is_finite() {
        return format!("{v:.6}");
    }
    if v.is_nan() {
        return String::from("null");
    }
    // Only the two infinities remain; pick the literal by sign.
    let literal = if v.is_sign_positive() { "1e999" } else { "-1e999" };
    literal.to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokitai_operator::infer::INFER_IN_DIM as IN_DIM;

    // Finite values use six decimals; NaN and the infinities map to fixed literals.
    #[test]
    fn format_f32_handles_special_values() {
        let cases = [
            (0.0_f32, "0.000000"),
            (-1.5, "-1.500000"),
            (f32::NAN, "null"),
            (f32::INFINITY, "1e999"),
            (f32::NEG_INFINITY, "-1e999"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(format_f32(input), expected);
        }
    }

    // A flat numeric array is treated as a batch holding a single row.
    #[test]
    fn parse_features_accepts_flat_array() {
        let single: Vec<f32> = vec![0.0; IN_DIM];
        let payload = serde_json::json!({ "features": single }).to_string();
        let parsed = parse_features(&payload).expect("flat array should parse");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].len(), IN_DIM);
    }

    // An array of arrays yields one row per inner array.
    #[test]
    fn parse_features_accepts_batch() {
        let two_rows: Vec<Vec<f32>> = vec![vec![0.0; IN_DIM], vec![1.0; IN_DIM]];
        let payload = serde_json::json!({ "features": two_rows }).to_string();
        let parsed = parse_features(&payload).expect("batch should parse");
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].len(), IN_DIM);
        assert_eq!(parsed[1].len(), IN_DIM);
    }

    // The `features` key is mandatory.
    #[test]
    fn parse_features_rejects_missing_key() {
        let payload = "{\"wrong_key\": []}";
        assert!(parse_features(payload).is_err());
    }

    // A row with two values does not match the required width.
    #[test]
    fn parse_features_rejects_wrong_row_width() {
        let payload = serde_json::json!({ "features": [[1.0_f32, 2.0]] }).to_string();
        assert!(parse_features(&payload).is_err());
    }

    // An empty `features` array is rejected.
    #[test]
    fn parse_features_rejects_empty_batch() {
        let payload = "{\"features\": []}";
        assert!(parse_features(payload).is_err());
    }
}