use crate::engine::InferenceEngine;
use crate::sampling::SamplingParams;
use oxibonsai_core::config::Qwen3Config;
/// JSON request accepted by [`generate_json`].
///
/// The architecture fields are all required and are copied one-to-one into a
/// [`Qwen3Config`]. The sampling fields are optional; when omitted they fall back
/// to the `default_*` helpers named in their serde attributes.
#[derive(Debug, serde::Deserialize)]
struct WasmInferenceRequest {
    // ---- model architecture (required) ----
    /// Copied into `Qwen3Config::hidden_size`.
    hidden_size: usize,
    /// Copied into `Qwen3Config::num_layers`.
    num_layers: usize,
    /// Copied into `Qwen3Config::num_attention_heads`.
    num_attention_heads: usize,
    /// Copied into `Qwen3Config::num_kv_heads`.
    num_kv_heads: usize,
    /// Copied into `Qwen3Config::intermediate_size`.
    intermediate_size: usize,
    /// Copied into `Qwen3Config::vocab_size`.
    vocab_size: usize,
    /// Copied into `Qwen3Config::max_context_length`.
    max_context_length: usize,
    /// Copied into `Qwen3Config::rms_norm_eps`.
    rms_norm_eps: f32,
    /// Copied into `Qwen3Config::rope_freq_base` (note the different name).
    rope_theta: f32,
    /// Copied into `Qwen3Config::head_dim`.
    head_dim: usize,

    // ---- generation input (required) ----
    /// Prompt token ids handed to `InferenceEngine::generate`.
    prompt_tokens: Vec<u32>,
    /// Token budget handed to `InferenceEngine::generate`.
    max_tokens: usize,

    // ---- sampling (optional) ----
    /// Sampling temperature; defaults to `default_temperature()`.
    #[serde(default = "default_temperature")]
    temperature: f32,
    /// Top-k cutoff; defaults to `default_top_k()`.
    #[serde(default = "default_top_k")]
    top_k: usize,
    /// Nucleus (top-p) cutoff; defaults to `default_top_p()`.
    #[serde(default = "default_top_p")]
    top_p: f32,
    /// RNG seed passed to `InferenceEngine::new`; defaults to `default_seed()`.
    #[serde(default = "default_seed")]
    seed: u64,
}
/// Serde fallback for `WasmInferenceRequest::temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde fallback for `WasmInferenceRequest::top_k`.
fn default_top_k() -> usize {
    40_usize
}

/// Serde fallback for `WasmInferenceRequest::top_p`.
fn default_top_p() -> f32 {
    0.9_f32
}

/// Serde fallback for `WasmInferenceRequest::seed`; a fixed value keeps
/// requests that omit `seed` reproducible.
fn default_seed() -> u64 {
    42_u64
}
/// JSON response produced by [`generate_json`].
///
/// Both fields are always serialized. `error` is emitted as `null` when it is
/// `None`, so consumers can always read both keys.
#[derive(Debug, serde::Serialize)]
struct WasmInferenceResponse {
    /// Generated token ids; left empty by [`WasmInferenceResponse::error`].
    tokens: Vec<u32>,
    /// Failure description; `None` for responses built by [`WasmInferenceResponse::success`].
    error: Option<String>,
}

impl WasmInferenceResponse {
    /// Builds a successful response carrying `tokens` (which may be empty).
    fn success(tokens: Vec<u32>) -> Self {
        Self { error: None, tokens }
    }

    /// Builds a failed response with no tokens and the given message.
    fn error(msg: impl Into<String>) -> Self {
        let message: String = msg.into();
        Self {
            tokens: Vec::new(),
            error: Some(message),
        }
    }
}
/// Runs one inference request encoded as JSON and returns a JSON response string.
///
/// `request_json` must deserialize into a `WasmInferenceRequest` (model shape,
/// `prompt_tokens`, `max_tokens`, optional sampling parameters).
///
/// The result always has the shape `{"tokens":[...],"error":null|"<message>"}`.
/// Errors returned while parsing the request or by the engine are reported
/// through the `error` field rather than as a Rust error, so callers only ever
/// need to parse a single JSON shape.
pub fn generate_json(request_json: &str) -> String {
    let response = match run_inference(request_json) {
        Ok(tokens) => WasmInferenceResponse::success(tokens),
        Err(e) => WasmInferenceResponse::error(e),
    };
    serde_json::to_string(&response).unwrap_or_else(|e| {
        // Serializing a `Vec<u32>` plus an optional string is not expected to fail,
        // but the fallback must still be valid JSON. Encoding the message as a JSON
        // string value escapes any quotes, backslashes or control characters in the
        // error text instead of splicing them raw into the document.
        let message = serde_json::Value::String(format!("failed to serialize response: {e}"));
        format!(r#"{{"tokens":[],"error":{message}}}"#)
    })
}
/// Parses and validates `request_json`, builds a Qwen3 engine from it, and
/// generates up to `max_tokens` tokens continuing `prompt_tokens`.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - the JSON is malformed or a required field is missing,
/// - the request fails [`validate_request`], or
/// - `InferenceEngine::generate` returns an error.
fn run_inference(request_json: &str) -> Result<Vec<u32>, String> {
    // The repetition penalty is not exposed through the JSON request; every call uses this value.
    const REPETITION_PENALTY: f32 = 1.1;

    let req: WasmInferenceRequest =
        serde_json::from_str(request_json).map_err(|e| format!("invalid request JSON: {e}"))?;
    validate_request(&req)?;

    let config = Qwen3Config {
        hidden_size: req.hidden_size,
        num_layers: req.num_layers,
        num_attention_heads: req.num_attention_heads,
        num_kv_heads: req.num_kv_heads,
        intermediate_size: req.intermediate_size,
        vocab_size: req.vocab_size,
        max_context_length: req.max_context_length,
        rms_norm_eps: req.rms_norm_eps,
        rope_freq_base: req.rope_theta,
        head_dim: req.head_dim,
        architecture: "qwen3".to_string(),
        model_name: "bonsai".to_string(),
    };
    let sampling = SamplingParams {
        temperature: req.temperature,
        top_k: req.top_k,
        top_p: req.top_p,
        repetition_penalty: REPETITION_PENALTY,
    };
    let mut engine = InferenceEngine::new(config, sampling, req.seed);
    engine
        .generate(&req.prompt_tokens, req.max_tokens)
        .map_err(|e| format!("inference error: {e}"))
}

/// Rejects requests whose values cannot describe a usable model or sampler.
///
/// The request arrives as untrusted JSON, so obviously invalid values are turned
/// into an error message here rather than passed to the engine. There they could
/// plausibly cause a division by zero, an out-of-bounds index or NaN propagation,
/// and a panic inside WASM aborts the whole module.
fn validate_request(req: &WasmInferenceRequest) -> Result<(), String> {
    let must_be_nonzero = [
        ("hidden_size", req.hidden_size),
        ("num_attention_heads", req.num_attention_heads),
        ("num_kv_heads", req.num_kv_heads),
        ("head_dim", req.head_dim),
        ("vocab_size", req.vocab_size),
    ];
    if let Some((name, _)) = must_be_nonzero.iter().find(|(_, value)| *value == 0) {
        return Err(format!("invalid request: {name} must be greater than zero"));
    }
    // Grouped-query attention shares each KV head across an equal number of query heads.
    if req.num_attention_heads % req.num_kv_heads != 0 {
        return Err(format!(
            "invalid request: num_attention_heads ({}) must be a multiple of num_kv_heads ({})",
            req.num_attention_heads, req.num_kv_heads
        ));
    }
    if !(req.rope_theta.is_finite() && req.rope_theta > 0.0) {
        return Err(format!(
            "invalid request: rope_theta must be finite and positive, got {}",
            req.rope_theta
        ));
    }
    if !(req.rms_norm_eps.is_finite() && req.rms_norm_eps >= 0.0) {
        return Err(format!(
            "invalid request: rms_norm_eps must be finite and non-negative, got {}",
            req.rms_norm_eps
        ));
    }
    if !(req.temperature.is_finite() && req.temperature >= 0.0) {
        return Err(format!(
            "invalid request: temperature must be finite and non-negative, got {}",
            req.temperature
        ));
    }
    // `RangeInclusive::contains` is false for NaN, so this also rejects NaN.
    if !(0.0..=1.0).contains(&req.top_p) {
        return Err(format!(
            "invalid request: top_p must be within [0, 1], got {}",
            req.top_p
        ));
    }
    // u32 -> usize is lossless on 32- and 64-bit targets (including wasm32).
    if let Some(&token) = req
        .prompt_tokens
        .iter()
        .find(|&&t| t as usize >= req.vocab_size)
    {
        return Err(format!(
            "invalid request: prompt token {token} is out of range for vocab_size {}",
            req.vocab_size
        ));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for a tiny two-layer model so the tests stay fast.
    ///
    /// Only the required fields are set; sampling parameters fall back to their
    /// serde defaults.
    fn tiny_config_json(prompt_tokens: &[u32], max_tokens: usize) -> String {
        serde_json::json!({
            "hidden_size": 256,
            "num_layers": 2,
            "num_attention_heads": 4,
            "num_kv_heads": 2,
            "intermediate_size": 512,
            "vocab_size": 1024,
            "max_context_length": 128,
            "rms_norm_eps": 1e-6,
            "rope_theta": 10000.0,
            "head_dim": 64,
            "prompt_tokens": prompt_tokens,
            "max_tokens": max_tokens
        })
        .to_string()
    }

    /// Parses a `generate_json` result, panicking with `context` if it is not JSON.
    fn parse_response(raw: &str, context: &str) -> serde_json::Value {
        serde_json::from_str(raw).expect(context)
    }

    #[test]
    fn generate_json_empty_prompt_returns_empty_tokens() {
        let raw = generate_json(&tiny_config_json(&[], 5));
        let response = parse_response(&raw, "valid JSON response");
        assert!(
            response["error"].is_null(),
            "expected no error, got: {response}"
        );
        let generated = response["tokens"].as_array().expect("tokens array");
        assert!(generated.is_empty(), "empty prompt should yield no tokens");
    }

    #[test]
    fn generate_json_invalid_json_returns_error() {
        let raw = generate_json("this is not json");
        let response = parse_response(&raw, "response should be valid JSON");
        assert!(
            !response["error"].is_null(),
            "invalid input should produce an error"
        );
    }

    #[test]
    fn generate_json_missing_required_field_returns_error() {
        let request = r#"{"num_layers": 2, "prompt_tokens": [1], "max_tokens": 1}"#;
        let response = parse_response(&generate_json(request), "response should be valid JSON");
        assert!(
            !response["error"].is_null(),
            "missing fields should produce an error"
        );
    }

    #[test]
    fn response_serialization_success() {
        let encoded = serde_json::to_string(&WasmInferenceResponse::success(vec![1, 2, 3]))
            .expect("serialize");
        for fragment in ["\"tokens\":[1,2,3]", "\"error\":null"] {
            assert!(encoded.contains(fragment));
        }
    }

    #[test]
    fn response_serialization_error() {
        let encoded = serde_json::to_string(&WasmInferenceResponse::error("something went wrong"))
            .expect("serialize");
        for fragment in ["\"tokens\":[]", "\"error\":\"something went wrong\""] {
            assert!(encoded.contains(fragment));
        }
    }
}