#![allow(unused_imports)]
#![allow(unused_variables)]
#![allow(dead_code)]
#[cfg(feature = "wgpu")]
use axum::response::IntoResponse;
use crate::error::{CliError, Result};
use colored::Colorize;
use std::fmt::Write;
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
#[cfg(feature = "inference")]
use super::safetensors::{load_safetensors_tokenizer, SafeTensorsTokenizerInfo};
use super::types::ServerConfig;
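/// Shared state for the WGPU inference server: the GPU forward pass (behind a
/// mutex so concurrent requests serialize GPU access), dequantized embedding /
/// LM-head weights kept on the host, and the vocabulary plus optional BPE tokenizer.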
#[cfg(feature = "wgpu")]
struct WgpuInferenceState {
fwd: std::sync::Mutex<trueno::backends::gpu::WgslForwardPass>,
token_embedding: Vec<f32>,
output_norm_weight: Vec<f32>,
lm_head_f32: Vec<f32>,
vocab: Vec<String>,
token_to_id: std::collections::HashMap<String, u32>,
bpe_tokenizer: Option<realizar::apr::BpeTokenizer>,
num_layers: usize,
vocab_size: usize,
hidden_dim: usize,
}
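/// Decode a single token id to UTF-8 text, skipping special `<|...|>` tokens
/// and reversing the byte-level (GPT-2-style) encoding used by the GGUF vocabulary.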
#[cfg(feature = "wgpu")]
fn wgpu_detokenize_one(id: u32, vocab: &[String]) -> String {
let token = match vocab.get(id as usize) {
Some(t) => t,
None => return String::new(),
};
if token.starts_with("<|") && token.ends_with("|>") {
return String::new();
}
if token.starts_with("<0x") && token.ends_with('>') && token.len() == 6 {
if let Ok(b) = u8::from_str_radix(&token[3..5], 16) {
return String::from_utf8_lossy(&[b]).into_owned();
}
}
let mut bytes = Vec::new();
for c in token.chars() {
let cp = c as u32;
let byte = if (0x21..=0x7E).contains(&cp)
|| (0xA1..=0xAC).contains(&cp)
|| (0xAE..=0xFF).contains(&cp)
{
cp as u8
} else if (0x0100..=0x0143).contains(&cp) {
let off = cp - 0x0100;
match off {
0..=32 => off as u8,
33 => 0x7F,
34..=66 => (0x80 + (off - 34)) as u8,
67 => 0xAD,
_ => {
bytes.push(b'?');
continue;
}
}
} else {
let mut buf = [0u8; 4];
let s = c.encode_utf8(&mut buf);
bytes.extend_from_slice(s.as_bytes());
continue;
};
bytes.push(byte);
}
String::from_utf8_lossy(&bytes).into_owned()
}
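/// OpenAI-compatible `/v1/chat/completions` handler backed by the WGPU forward
/// pass. Supports both streaming (SSE) and non-streaming responses; decoding is
/// greedy argmax over the logits.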
#[cfg(feature = "wgpu")]
#[provable_contracts_macros::contract("streaming-tpot-v1", equation = "tpot_definition")]
async fn wgpu_chat_completion(
state: Arc<WgpuInferenceState>,
axum::Json(body): axum::Json<serde_json::Value>,
) -> axum::response::Response {
let max_tokens = body["max_tokens"].as_u64().unwrap_or(64).min(4096) as usize;
let stream = body["stream"].as_bool().unwrap_or(false);
let messages = body["messages"].as_array();
let prompt = messages
.and_then(|m| m.last())
.and_then(|m| m["content"].as_str())
.unwrap_or("Hello");
let chat_text = format!(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
prompt
);
let prompt_ids: Vec<u32> = if let Some(ref tok) = state.bpe_tokenizer {
tok.encode(&chat_text)
} else {
tokenize_greedy(&chat_text, &state.token_to_id, state.vocab_size)
};
let id = format!(
"chatcmpl-wgpu-{}",
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
);
if stream {
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(32);
let id_clone = id.clone();
let vocab = state.vocab.clone();
let prompt_len = prompt_ids.len();
tokio::task::spawn_blocking(move || {
let gen_start = std::time::Instant::now();
let fwd = state.fwd.lock().expect("lock(");
let mut kv_caches: Vec<(Vec<f32>, Vec<f32>)> = Vec::new();
let mut last_logits = Vec::new();
for (pos, &token_id) in prompt_ids.iter().enumerate() {
match fwd.forward_model(
token_id,
pos,
state.num_layers,
&state.token_embedding,
&state.output_norm_weight,
&state.lm_head_f32,
state.vocab_size,
1e-6,
&mut kv_caches,
) {
Ok(l) => last_logits = l,
Err(_) => return,
}
}
let mut completion_tokens = 0u32;
for step in 0..max_tokens {
let next_token = last_logits
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0);
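// Greedy argmax decoding; stop on the <|im_end|> id (151645) or token 0.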
if next_token == 151645 || next_token == 0 {
break;
}
let text = wgpu_detokenize_one(next_token, &vocab);
let chunk = serde_json::json!({
"id": id_clone, "object": "chat.completion.chunk", "model": "qwen-wgpu",
"choices": [{"index": 0, "delta": {"content": text}, "finish_reason": serde_json::Value::Null}]
});
completion_tokens += 1;
if tx.blocking_send(chunk.to_string()).is_err() {
break;
}
let position = prompt_ids.len() + step;
match fwd.forward_model(
next_token,
position,
state.num_layers,
&state.token_embedding,
&state.output_norm_weight,
&state.lm_head_f32,
state.vocab_size,
1e-6,
&mut kv_caches,
) {
Ok(l) => last_logits = l,
Err(_) => break,
}
}
let elapsed = gen_start.elapsed();
let tok_s = if elapsed.as_secs_f64() > 0.0 {
completion_tokens as f64 / elapsed.as_secs_f64()
} else {
0.0
};
let done = serde_json::json!({
"id": id_clone, "object": "chat.completion.chunk", "model": "qwen-wgpu",
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": prompt_len, "completion_tokens": completion_tokens,
"total_tokens": prompt_len as u32 + completion_tokens},
"x_wgpu_tok_s": tok_s,
});
let _ = tx.blocking_send(done.to_string());
let _ = tx.blocking_send("[DONE]".to_string());
});
let stream = async_stream::stream! {
while let Some(data) = rx.recv().await {
yield Ok::<_, std::convert::Infallible>(
axum::response::sse::Event::default().data(data)
);
}
};
axum::response::sse::Sse::new(stream).into_response()
} else {
let gen_start = std::time::Instant::now();
let mut output_ids: Vec<u32> = Vec::new();
let fwd = state.fwd.lock().expect("lock(");
let mut kv_caches: Vec<(Vec<f32>, Vec<f32>)> = Vec::new();
let mut last_logits = Vec::new();
for (pos, &token_id) in prompt_ids.iter().enumerate() {
match fwd.forward_model(
token_id,
pos,
state.num_layers,
&state.token_embedding,
&state.output_norm_weight,
&state.lm_head_f32,
state.vocab_size,
1e-6,
&mut kv_caches,
) {
Ok(logits) => last_logits = logits,
Err(e) => {
return axum::Json(serde_json::json!({"error": format!("{e}")})).into_response()
}
}
}
for step in 0..max_tokens {
let next_token = last_logits
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i as u32)
.unwrap_or(0);
if next_token == 151645 || next_token == 0 {
break;
}
output_ids.push(next_token);
let position = prompt_ids.len() + step;
match fwd.forward_model(
next_token,
position,
state.num_layers,
&state.token_embedding,
&state.output_norm_weight,
&state.lm_head_f32,
state.vocab_size,
1e-6,
&mut kv_caches,
) {
Ok(logits) => last_logits = logits,
Err(_) => break,
}
}
drop(fwd);
let elapsed = gen_start.elapsed();
let output_text: String = output_ids
.iter()
.map(|&id| wgpu_detokenize_one(id, &state.vocab))
.collect();
let tok_s = if elapsed.as_secs_f64() > 0.0 {
output_ids.len() as f64 / elapsed.as_secs_f64()
} else {
0.0
};
axum::Json(serde_json::json!({
"id": id, "object": "chat.completion", "model": "qwen-wgpu",
"choices": [{"index": 0, "message": {"role": "assistant", "content": output_text},
"finish_reason": if output_ids.len() >= max_tokens { "length" } else { "stop" }}],
"usage": {"prompt_tokens": prompt_ids.len(), "completion_tokens": output_ids.len(),
"total_tokens": prompt_ids.len() + output_ids.len()},
"x_wgpu_latency_ms": elapsed.as_secs_f64() * 1000.0, "x_wgpu_tok_s": tok_s,
}))
.into_response()
}
}
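/// Fallback tokenizer: greedy longest-match against the vocabulary (up to 32
/// bytes per token), used when no BPE merge rules are available. Bytes with no
/// vocab entry fall back to their raw byte value clamped to the vocab range.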
#[cfg(feature = "wgpu")]
fn tokenize_greedy(
text: &str,
token_to_id: &std::collections::HashMap<String, u32>,
vocab_size: usize,
) -> Vec<u32> {
let mut ids = Vec::new();
let bytes = text.as_bytes();
let mut pos = 0;
while pos < bytes.len() {
let mut best_len = 0;
let mut best_id = 0u32;
let max_len = (bytes.len() - pos).min(32);
for len in (1..=max_len).rev() {
if let Ok(s) = std::str::from_utf8(&bytes[pos..pos + len]) {
if let Some(&id) = token_to_id.get(s) {
best_len = len;
best_id = id;
break;
}
}
}
if best_len == 0 {
best_id = (bytes[pos] as u32).min(vocab_size as u32 - 1);
best_len = 1;
}
ids.push(best_id);
pos += best_len;
}
ids
}
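/// Entry point for realizar-backed serving: dispatches on the requested
/// backend (WGPU) or on the detected model format (APR, GGUF, SafeTensors).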
#[cfg(feature = "inference")]
pub(crate) fn start_realizar_server(model_path: &Path, config: &ServerConfig) -> Result<()> {
use realizar::format::{detect_format, ModelFormat};
use std::io::Read;
if config.backend.as_deref() == Some("wgpu") {
println!();
println!("{}", "Backend: WGPU (Vulkan/Metal/WebGPU)".cyan());
println!(
"{}",
"PMAT-333: Loading model for WGPU inference...".dimmed()
);
use realizar::gguf::{MappedGGUFModel, OwnedQuantizedModel};
let mapped = MappedGGUFModel::from_path(model_path)
.map_err(|e| CliError::ModelLoadFailed(format!("GGUF load: {e}")))?;
let quantized = OwnedQuantizedModel::from_mapped(&mapped)
.map_err(|e| CliError::ModelLoadFailed(format!("Quantized model: {e}")))?;
let num_layers = quantized.layers().len();
println!(
"{}",
format!(
"Model: {} layers loaded for WGPU dequantization",
num_layers,
)
.green()
);
let dequant_start = std::time::Instant::now();
let weights = realizar::gpu::adapters::wgpu_adapter::dequant_model_weights(&quantized)
.map_err(|e| CliError::ModelLoadFailed(format!("Dequant: {e}")))?;
let total_mb: f64 = weights
.iter()
.map(|(_, d, _, _)| d.len() * 4)
.sum::<usize>() as f64
/ 1e6;
println!(
"{}",
format!(
"Dequantized {} weights ({:.0} MB) in {:.1}s",
weights.len(),
total_mb,
dequant_start.elapsed().as_secs_f64(),
)
.dimmed()
);
#[cfg(feature = "wgpu")]
{
println!("{}", "Initializing WGPU device...".dimmed());
let gpu_dev = trueno::backends::gpu::GpuDevice::new()
.map_err(|e| CliError::ModelLoadFailed(format!("WGPU init: {e}")))?;
println!("{}", "WGPU device ready (Vulkan/Metal)".green());
let hidden_dim = weights
.iter()
.find(|(n, _, _, _)| n.ends_with(".q_proj"))
.map(|(_, _, _, cols)| *cols)
.unwrap_or(1536);
let intermediate_dim = weights
.iter()
.find(|(n, _, _, _)| n.ends_with(".gate_proj"))
.map(|(_, _, rows, _)| *rows)
.unwrap_or(8960);
let q_dim = weights
.iter()
.find(|(n, _, _, _)| n.ends_with(".q_proj"))
.map(|(_, _, rows, _)| *rows)
.unwrap_or(hidden_dim);
let kv_dim = weights
.iter()
.find(|(n, _, _, _)| n.ends_with(".k_proj"))
.map(|(_, _, rows, _)| *rows)
.unwrap_or(256);
let head_dim = 128;
let num_heads = q_dim / head_dim;
let num_kv_heads = kv_dim / head_dim;
let mut fwd = trueno::backends::gpu::WgslForwardPass::new(
gpu_dev.device.clone(),
gpu_dev.queue.clone(),
hidden_dim,
num_heads,
num_kv_heads,
head_dim,
intermediate_dim,
);
let upload_start = std::time::Instant::now();
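// Opt-in via the WGPU_Q4K env var: upload Q4_K tensors in their raw quantized
// form instead of dequantized f32 (the remaining tensors still go up as f32),
// trading dequantization on the GPU for a much smaller VRAM footprint.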
let use_q4k = std::env::var("WGPU_Q4K").is_ok();
if use_q4k {
let q4k_raw = realizar::gpu::adapters::wgpu_adapter::raw_q4k_weights(&quantized);
let q4k_names: std::collections::HashSet<String> =
q4k_raw.iter().map(|(n, _, _, _)| n.clone()).collect();
for (name, raw_data, _rows, _cols) in &q4k_raw {
fwd.upload_q4k_weight(name, raw_data);
}
for (name, data, _rows, _cols) in &weights {
if q4k_names.contains(name.as_str()) {
continue;
}
fwd.upload_weight(name, data);
}
let q4k_mb: f64 =
q4k_raw.iter().map(|(_, d, _, _)| d.len()).sum::<usize>() as f64 / 1e6;
println!(
"{}",
format!(
"Q4K mode: {} Q4K ({:.0} MB) — 10× VRAM savings",
q4k_raw.len(),
q4k_mb
)
.cyan()
);
} else {
for (name, data, _rows, _cols) in &weights {
fwd.upload_weight(name, data);
}
}
fwd.init_kv_cache(num_layers);
println!(
"{}",
format!(
"Uploaded {} weights to GPU ({:.1} MB VRAM) in {:.1}ms",
weights.len(),
fwd.total_vram_bytes() as f64 / 1e6,
upload_start.elapsed().as_secs_f64() * 1000.0,
)
.green()
);
let token_embedding = quantized.token_embedding().to_vec();
let output_norm_weight = quantized.output_norm_weight().to_vec();
let vocab_size = token_embedding.len() / hidden_dim;
let lm_head_f32 = weights
.iter()
.find(|(n, _, _, _)| n == "lm_head")
.map(|(_, d, _, _)| d.clone())
.unwrap_or_else(|| token_embedding.clone());
println!(
"{}",
format!(
"WGPU inference ready: {} layers, vocab={}, hidden={}",
num_layers, vocab_size, hidden_dim,
)
.green()
);
let test_token = 9707u32;
let test_start = std::time::Instant::now();
let mut test_kv: Vec<(Vec<f32>, Vec<f32>)> = Vec::new();
match fwd.forward_model(
test_token,
0,
num_layers,
&token_embedding,
&output_norm_weight,
&lm_head_f32,
vocab_size,
1e-6,
&mut test_kv,
) {
Ok(logits) => {
let argmax = logits
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).expect("1"))
.map(|(i, _)| i)
.unwrap_or(0);
let elapsed = test_start.elapsed();
println!(
"{}",
format!(
"WGPU test: token {} → logits[{}] max at idx {} ({:.1}ms)",
test_token,
logits.len(),
argmax,
elapsed.as_secs_f64() * 1000.0,
)
.cyan()
);
}
Err(e) => {
println!("{}", format!("WGPU forward failed: {e}").red());
}
}
println!("{}", "Starting WGPU inference server...".cyan());
let vocab: Vec<String> = mapped.model.vocabulary().unwrap_or_else(|| {
eprintln!("Warning: No vocabulary in GGUF, using placeholder");
let mut v: Vec<String> = (0..vocab_size).map(|i| format!("token{i}")).collect();
if !v.is_empty() {
v[0] = "<unk>".to_string();
}
v
});
let merges = mapped.model.merge_rules().unwrap_or_default();
println!(
"{}",
format!(
"Vocab: {} tokens, {} merge rules from GGUF",
vocab.len(),
merges.len()
)
.dimmed()
);
let bpe_tokenizer = if !merges.is_empty() {
match {
let mut t2id = std::collections::HashMap::new();
for (i, tok) in vocab.iter().enumerate() {
t2id.insert(tok.clone(), i as u32);
}
let special: std::collections::HashMap<String, u32> = vocab
.iter()
.enumerate()
.filter(|(_, t)| t.starts_with("<|") && t.ends_with("|>"))
.map(|(i, t)| (t.clone(), i as u32))
.collect();
Ok::<_, String>(realizar::apr::BpeTokenizer {
token_to_id: t2id,
id_to_token: vocab.clone(),
merge_rules: merges,
bos_id: None,
eos_id: Some(151645),
special_tokens: special,
})
} {
Ok(tok) => {
println!("{}", "BPE tokenizer created with merge rules".green());
Some(tok)
}
Err(e) => {
eprintln!("BPE tokenizer failed: {e} — falling back to greedy");
None
}
}
} else {
println!("{}", "No merge rules — using greedy tokenizer".yellow());
None
};
let token_to_id: std::collections::HashMap<String, u32> = vocab
.iter()
.enumerate()
.map(|(i, t)| (t.clone(), i as u32))
.collect();
use std::sync::{Arc, Mutex};
let wgpu_state = Arc::new(WgpuInferenceState {
fwd: Mutex::new(fwd),
token_embedding,
output_norm_weight,
lm_head_f32,
vocab,
token_to_id,
bpe_tokenizer,
num_layers,
vocab_size,
hidden_dim,
});
use axum::{
routing::{get, post},
Json, Router,
};
let state_health = wgpu_state.clone();
let state_chat = wgpu_state.clone();
let app = Router::new()
.route(
"/health",
get(move || async move {
Json(serde_json::json!({
"status": "healthy",
"compute_mode": "wgpu",
"version": env!("CARGO_PKG_VERSION"),
}))
}),
)
.route(
"/v1/chat/completions",
post(move |body: Json<serde_json::Value>| {
let state = state_chat.clone();
async move { wgpu_chat_completion(state, body).await }
}),
);
let bind_addr = config.bind_addr();
let runtime = tokio::runtime::Runtime::new()
.map_err(|e| CliError::InferenceFailed(format!("Runtime: {e}")))?;
runtime.block_on(async move {
let listener = tokio::net::TcpListener::bind(&bind_addr)
.await
.map_err(|e| CliError::InferenceFailed(format!("Bind: {e}")))?;
println!(
"{}",
format!("WGPU inference server listening on http://{}", bind_addr)
.green()
.bold()
);
println!(" POST /v1/chat/completions - Chat completions (WGPU)");
println!(" GET /health - Health check");
axum::serve(listener, app)
.await
.map_err(|e| CliError::InferenceFailed(format!("Serve: {e}")))?;
Ok::<(), CliError>(())
})?;
return Ok(());
}
#[cfg(not(feature = "wgpu"))]
{
println!(
"{}",
"WGPU feature not enabled. Build with --features wgpu".yellow()
);
}
}
let path_str = model_path.to_string_lossy();
if path_str.ends_with(".safetensors.index.json") {
println!();
println!("Detected format: Sharded SafeTensors (index.json)");
return start_safetensors_server_with_fallback(model_path, config);
}
let mut file = std::fs::File::open(model_path)?;
let mut magic = [0u8; 8];
let bytes_read = file.read(&mut magic)?;
if bytes_read < 8 {
return Err(CliError::InvalidFormat(
"File too small for format detection".to_string(),
));
}
let format = detect_format(&magic)
.map_err(|e| CliError::InvalidFormat(format!("Format detection failed: {e}")))?;
println!();
println!("Detected format: {}", format);
match format {
ModelFormat::Apr => {
println!("{}", "Starting APR model server...".cyan());
start_apr_server(model_path, config)
}
ModelFormat::Gguf => {
println!("{}", "Starting GGUF inference server...".cyan());
start_gguf_server(model_path, config)
}
ModelFormat::SafeTensors => start_safetensors_server_with_fallback(model_path, config),
}
}
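/// SafeTensors serving with graceful degradation: try the fused-Q4K GPU path
/// (CUDA builds), then quantized CPU inference, then fall back to the
/// inspection-only server.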
#[cfg(feature = "inference")]
fn start_safetensors_server_with_fallback(model_path: &Path, config: &ServerConfig) -> Result<()> {
#[cfg(feature = "cuda")]
{
let use_gpu = config.gpu && !config.no_gpu;
if use_gpu {
println!(
"{}",
"Starting SafeTensors GPU server (fused Q4K)...".cyan()
);
match start_safetensors_server_gpu(model_path, config) {
Ok(()) => return Ok(()),
Err(e) => {
println!(
"{}",
format!("GPU init failed, falling back to CPU: {e}").yellow()
);
}
}
}
}
#[cfg(feature = "inference")]
{
match start_safetensors_server_cpu_quantized(model_path, config) {
Ok(()) => return Ok(()),
Err(e) => {
println!(
"{}",
format!("Q4K conversion failed, falling back to F32: {e}").yellow()
);
}
}
}
println!("{}", "Starting SafeTensors inspection server...".cyan());
super::safetensors::start_safetensors_server(model_path, config)
}
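/// Shared state for the APR CPU server: the optional transformer (behind a
/// mutex), model metadata, and whichever tokenizer was found (embedded BPE,
/// sibling tokenizer.json, or character-level fallback).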
#[cfg(feature = "inference")]
#[derive(Clone)]
struct AprServerState {
transformer: Option<Arc<std::sync::Mutex<realizar::apr_transformer::AprTransformer>>>,
model_type: String,
architecture: String,
is_transformer: bool,
tokenizer: Option<SafeTensorsTokenizerInfo>,
embedded_tokenizer: Option<realizar::apr::BpeTokenizer>,
model_name: String,
}
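/// Result of a single CPU generation call: decoded text plus token counts and
/// generation latency for usage reporting.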
#[cfg(feature = "inference")]
struct AprInferenceOutput {
text: String,
tokens_generated: usize,
gen_duration: std::time::Duration,
input_token_count: usize,
}
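/// Run one prompt through the APR transformer on the CPU: encode with the best
/// available tokenizer, generate with the KV cache, and decode only the newly
/// produced tokens.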
#[cfg(feature = "inference")]
fn run_apr_cpu_inference(
state: &AprServerState,
prompt: &str,
max_tokens: usize,
temperature: f32,
) -> std::result::Result<AprInferenceOutput, String> {
let transformer = state
.transformer
.as_ref()
.ok_or("Transformer not loaded, inference not supported")?;
let input_tokens: Vec<u32> = if let Some(ref tok) = state.embedded_tokenizer {
tok.encode(prompt)
} else if let Some(ref tok) = state.tokenizer {
tok.tokenizer.encode(prompt)
} else {
prompt.chars().map(|c| c as u32).collect()
};
let input_token_count = input_tokens.len();
let gen_config = realizar::apr_transformer::GenerateConfig {
max_tokens,
temperature,
top_p: 0.9,
top_k: 0,
repetition_penalty: 1.0,
trace: false,
stop_tokens: vec![],
};
let gen_start = Instant::now();
let output_tokens = {
let t = transformer.lock().map_err(|_| {
"Transformer state corrupted (lock poisoned). Please restart the server.".to_string()
})?;
t.generate_with_cache(&input_tokens, &gen_config)
.map_err(|e| format!("Generate failed: {e}"))?
};
let gen_duration = gen_start.elapsed();
let new_tokens = if output_tokens.len() > input_tokens.len() {
&output_tokens[input_tokens.len()..]
} else {
&output_tokens[..]
};
let text = if let Some(ref tok) = state.embedded_tokenizer {
tok.decode(new_tokens)
} else if let Some(ref tok) = state.tokenizer {
tok.tokenizer.decode(new_tokens).unwrap_or_default()
} else {
new_tokens
.iter()
.filter_map(|&t| char::from_u32(t))
.collect()
};
Ok(AprInferenceOutput {
text,
tokens_generated: new_tokens.len(),
gen_duration,
input_token_count,
})
}
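/// Load an APR v2 model from disk and assemble the server state: metadata,
/// tokenizer discovery (embedded, sibling tokenizer.json, or none), and the
/// transformer weights when the model is a transformer.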
#[cfg(feature = "inference")]
fn load_apr_model_state(model_path: &Path, config: &ServerConfig) -> Result<AprServerState> {
use realizar::apr::AprModel;
println!("{}", "Loading APR v2 model...".dimmed());
let model = AprModel::load(model_path)
.map_err(|e| CliError::ModelLoadFailed(format!("Failed to load APR v2 model: {e}")))?;
let model_type = model
.metadata()
.model_type
.clone()
.unwrap_or_else(|| "unknown".to_string());
let architecture = model
.metadata()
.architecture
.clone()
.unwrap_or_else(|| "unknown".to_string());
let tensor_count = model.tensor_count();
let param_count = model.estimated_parameters();
let is_transformer = model.metadata().is_transformer();
println!(
"{}",
format!(
"Loaded {} model (arch: {}, {} tensors, ~{} params)",
model_type, architecture, tensor_count, param_count
)
.green()
);
if is_transformer {
println!(
"{}",
"Transformer model detected - inference enabled".cyan()
);
}
let embedded_tokenizer = model.load_embedded_bpe_tokenizer();
if embedded_tokenizer.is_some() {
println!(
"{}",
"Embedded BPE tokenizer loaded from APR metadata".green()
);
}
let bpe_tokenizer = if embedded_tokenizer.is_none() {
if let Some(tokenizer_path) =
realizar::safetensors::find_sibling_file(model_path, "tokenizer.json")
{
println!(
"{}",
format!("BPE tokenizer loaded from {}", tokenizer_path.display()).green()
);
load_safetensors_tokenizer(&tokenizer_path)
} else {
println!(
"{}",
"No tokenizer found - using character-level fallback".yellow()
);
None
}
} else {
None
};
println!("{}", "Using CPU inference".dimmed());
let transformer = if is_transformer {
match realizar::apr_transformer::AprTransformer::from_apr_file(model_path) {
Ok(t) => {
println!(
"{}",
format!(
"Transformer ready: {} layers, hidden_dim={}",
t.config.num_layers, t.config.hidden_dim
)
.cyan()
);
Some(Arc::new(std::sync::Mutex::new(t)))
}
Err(e) => {
println!(
"{}",
format!("Transformer load failed: {e} - inference disabled").yellow()
);
None
}
}
} else {
None
};
let model_name = model_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("apr")
.to_string();
Ok(AprServerState {
transformer,
model_type,
architecture,
is_transformer,
tokenizer: bpe_tokenizer,
embedded_tokenizer,
model_name,
})
}
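/// Response body for `GET /health` on the APR CPU server.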
#[cfg(feature = "inference")]
#[derive(Clone, serde::Serialize)]
struct AprHealthResponse {
status: String,
model_type: String,
architecture: String,
inference_enabled: bool,
compute_mode: String,
}
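/// Request body for `POST /v1/completions` (non-chat text completion).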
#[cfg(feature = "inference")]
#[derive(serde::Deserialize)]
struct AprCompletionRequest {
prompt: String,
#[serde(default = "default_max_tokens_apr")]
max_tokens: usize,
#[serde(default)]
temperature: Option<f32>,
}
#[cfg(feature = "inference")]
fn default_max_tokens_apr() -> usize {
32
}
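/// Response body for `POST /v1/completions`, including simple throughput stats.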
#[cfg(feature = "inference")]
#[derive(serde::Serialize)]
struct AprCompletionResponse {
text: String,
tokens_generated: usize,
latency_ms: u64,
tok_per_sec: f64,
}
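/// Serve an APR model: prefer the GPU path (CUDA builds), then fused Q4K CPU
/// kernels, and finally the AprTransformer CPU fallback served over axum.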
fn start_apr_server(model_path: &Path, config: &ServerConfig) -> Result<()> {
#[cfg(feature = "cuda")]
{
let use_gpu = config.gpu && !config.no_gpu;
if use_gpu {
match start_apr_server_gpu(model_path, config) {
Ok(()) => return Ok(()),
Err(e) => {
println!(
"{}",
format!("GPU init failed, falling back to CPU: {e}").yellow()
);
}
}
}
}
#[cfg(feature = "inference")]
{
match try_apr_quantized_cpu(model_path, config) {
Ok(()) => return Ok(()),
Err(e) => {
println!(
"{}",
format!("Q4K path unavailable ({e}), using AprTransformer fallback").yellow()
);
}
}
}
let state = load_apr_model_state(model_path, config)?;
let is_transformer = state.is_transformer;
let runtime = tokio::runtime::Runtime::new()
.map_err(|e| CliError::InferenceFailed(format!("Failed to create runtime: {e}")))?;
let bind_addr = config.bind_addr();
runtime.block_on(async move {
let app = build_apr_cpu_router(state);
let listener = tokio::net::TcpListener::bind(&bind_addr)
.await
.map_err(|e| CliError::InferenceFailed(format!("Failed to bind: {e}")))?;
print_apr_cpu_banner(&bind_addr, is_transformer);
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.map_err(|e| CliError::InferenceFailed(format!("Server error: {e}")))?;
println!();
println!("{}", "Server stopped".yellow());
Ok(())
})
}
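/// Attempt the fused-Q4K CPU path: map the APR file, build an
/// `OwnedQuantizedModel`, recover the embedded vocabulary (or a placeholder),
/// and hand off to the shared CPU server loop.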
#[cfg(feature = "inference")]
fn try_apr_quantized_cpu(model_path: &Path, config: &ServerConfig) -> Result<()> {
use realizar::apr::MappedAprModel;
use realizar::gguf::OwnedQuantizedModel;
println!("{}", "Loading APR model (fused Q4K kernels)...".dimmed());
let mapped = MappedAprModel::from_path(model_path)
.map_err(|e| CliError::InferenceFailed(format!("Failed to map APR: {e}")))?;
println!(
"{}",
format!(
"APR loaded: {} tensors, {} metadata entries",
mapped.tensors.len(),
mapped.metadata.extra.len()
)
.dimmed()
);
let quantized = OwnedQuantizedModel::from_apr(&mapped)
.map_err(|e| CliError::InferenceFailed(format!("Failed to create quantized model: {e}")))?;
println!(
"{}",
format!(
"Model ready: {} layers, vocab_size={}, hidden_dim={}",
quantized.layers().len(),
quantized.config().vocab_size,
quantized.config().hidden_dim
)
.green()
);
let vocab = mapped
.metadata
.get_embedded_vocabulary()
.unwrap_or_else(|| {
let vocab_size = mapped.metadata.vocab_size.unwrap_or(32000);
eprintln!("Warning: No embedded vocabulary in APR, using placeholder tokens");
let mut v: Vec<String> = (0..vocab_size).map(|i| format!("token{i}")).collect();
if !v.is_empty() {
v[0] = "<unk>".to_string();
}
v
});
println!("{}", "Q4K CPU inference ready".green());
run_cpu_server(quantized, vocab, config)
}
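/// Build the axum router for APR CPU inference: `/health`, `/v1/completions`,
/// `/v1/chat/completions`, a root banner, and a JSON 404 fallback.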
#[cfg(feature = "inference")]
#[allow(clippy::disallowed_methods)]
fn build_apr_cpu_router(state: AprServerState) -> axum::Router {
use axum::{
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use std::sync::Mutex;
let state_for_health = state.clone();
let state_for_completions = Arc::new(Mutex::new(state.clone()));
let state_for_chat = Arc::new(Mutex::new(state));
Router::new()
.route(
"/health",
get(move || {
let s = state_for_health.clone();
async move {
Json(AprHealthResponse {
status: "healthy".to_string(),
model_type: s.model_type.clone(),
architecture: s.architecture.clone(),
inference_enabled: s.is_transformer,
compute_mode: "cpu".to_string(),
})
}
}),
)
.route(
"/v1/completions",
post(move |Json(req): Json<AprCompletionRequest>| {
let state = state_for_completions.clone();
async move { handle_apr_cpu_completion(&state, &req).await }
}),
)
.route(
"/v1/chat/completions",
post(
move |headers: axum::http::HeaderMap, Json(req): Json<serde_json::Value>| {
let state = state_for_chat.clone();
async move { handle_apr_cpu_chat_completion(&state, &headers, &req).await }
},
),
)
.route(
"/",
get(|| async {
"APR v2 Inference Server - POST /v1/completions, /v1/chat/completions"
}),
)
.fallback(|| async {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"error": "not_found",
"message": "Route not found. Available: /health, /v1/completions, /v1/chat/completions"
})),
)
})
}
include!("handler_apr_cpu_completion.rs");
include!("handler_gpu_completion.rs");
include!("server.rs");