#[cfg(feature = "cuda")]
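/// A single generation request for the Q4K GPU inference thread.
///
/// The caller tokenizes the prompt itself and receives raw token ids back;
/// the completed result (or an error string) is delivered through
/// `response_tx`.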
pub struct AprQ4kRequest {
    pub prompt_ids: Vec<u32>,
    pub max_tokens: usize,
    pub temperature: f32,
    pub eos_ids: Vec<u32>,
    pub response_tx: tokio::sync::oneshot::Sender<Result<AprQ4kResponse, String>>,
}
#[cfg(feature = "cuda")]
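/// The result of one Q4K generation request, including simple throughput
/// statistics measured end to end (prefill plus decode).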
#[derive(Debug)]
pub struct AprQ4kResponse {
    pub output_tokens: Vec<u32>,
    pub tokens_generated: usize,
    pub generation_time_ms: f64,
    pub tokens_per_second: f64,
}
#[cfg(feature = "cuda")]
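/// Loads a Q4K-quantized APR model, uploads its weights to the GPU, and
/// spawns a dedicated OS thread that owns the CUDA executor and serves
/// generation requests from a bounded channel.
///
/// A minimal usage sketch (marked `ignore` because it needs the `cuda`
/// feature, a CUDA device, and a real model path; the path below is a
/// placeholder):
///
/// ```ignore
/// # async fn demo() -> Result<(), String> {
/// let tx = spawn_apr_q4k_inference_thread("/path/to/model.apr")?;
/// let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
/// tx.send(AprQ4kRequest {
///     prompt_ids: vec![1, 2, 3], // pre-tokenized prompt
///     max_tokens: 64,
///     temperature: 0.7,
///     eos_ids: vec![2],
///     response_tx: resp_tx,
/// })
/// .await
/// .map_err(|e| e.to_string())?;
/// let resp = resp_rx.await.map_err(|e| e.to_string())??;
/// println!("{:.1} tok/s", resp.tokens_per_second);
/// # Ok(())
/// # }
/// ```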
pub fn spawn_apr_q4k_inference_thread(
    model_path: &str,
) -> Result<tokio::sync::mpsc::Sender<AprQ4kRequest>, String> {
    use crate::apr::AprV2Model;
    use crate::cuda::CudaExecutor;
    use crate::gpu::adapters::apr_q4k::{parse_apr_q4k_config, upload_apr_q4k_weights};
    use std::path::Path;

    // Load the model and parse its architecture config on the calling thread
    // so failures are reported before any thread is spawned.
    let path = Path::new(model_path);
    let model = AprV2Model::load(path).map_err(|e| format!("Failed to load APR: {e}"))?;
    let config =
        parse_apr_q4k_config(&model).map_err(|e| format!("Failed to parse config: {e}"))?;
    println!(
        " Q4K GPU: {} layers, hidden={}, heads={}/{}, vocab={}",
        config.num_layers,
        config.hidden_dim,
        config.num_heads,
        config.num_kv_heads,
        config.vocab_size
    );
    if let Some(ne) = config.num_experts {
        println!(
            " MoE: {} experts, top-{}, intermediate={}",
            ne,
            config.num_experts_per_tok.unwrap_or(0),
            config.moe_intermediate_size.unwrap_or(0)
        );
    }
    // Initialize CUDA on device 0 and upload all weights to VRAM up front.
    let mut executor = CudaExecutor::new(0).map_err(|e| format!("CUDA init failed: {e}"))?;
    let upload_result = upload_apr_q4k_weights(&model, &mut executor)
        .map_err(|e| format!("Weight upload failed: {e}"))?;
    println!(
        " Uploaded {} tensors ({} Q4K, {} F32) — {:.1} MB VRAM",
        upload_result.num_tensors,
        upload_result.num_q4k_tensors,
        upload_result.num_f32_tensors,
        upload_result.total_bytes as f64 / (1024.0 * 1024.0)
    );
    // Small F32 tensors (embeddings and norms) are read back as host buffers;
    // look them up under the common HF/GGUF naming variants.
    let embed_name = model
        .find_tensor_name(&[
            "model.embed_tokens.weight",
            "embed_tokens.weight",
            "transformer.wte.weight",
            "embeddings.word_embeddings.weight",
            "tok_embeddings.weight",
            "token_embd.weight",
        ])
        .map_err(|e| format!("Missing embedding: {e}"))?;
    let embedding_weight = model
        .get_tensor_f32(&embed_name)
        .map_err(|e| format!("Missing embedding: {e}"))?;
    let norm_name = model
        .find_tensor_name(&[
            "model.norm.weight",
            "norm.weight",
            "transformer.ln_f.weight",
            "output_norm.weight",
        ])
        .map_err(|e| format!("Missing output norm: {e}"))?;
    let output_norm_weight = model
        .get_tensor_f32(&norm_name)
        .map_err(|e| format!("Missing output norm: {e}"))?;
    // Per-layer norm weights: (attn_norm, ffn_norm, optional q_norm, optional
    // k_norm). The q/k norms only exist on architectures with QK-normalization.
    let mut layer_norm_weights: Vec<(Vec<f32>, Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>)> =
        Vec::with_capacity(config.num_layers);
    for layer_idx in 0..config.num_layers {
        let attn_norm_name = model
            .find_tensor_name(&[
                &format!("model.layers.{layer_idx}.input_layernorm.weight"),
                &format!("layers.{layer_idx}.input_layernorm.weight"),
                &format!("blk.{layer_idx}.attn_norm.weight"),
            ])
            .map_err(|e| format!("Missing attn norm layer {layer_idx}: {e}"))?;
        let attn_norm = model
            .get_tensor_f32(&attn_norm_name)
            .map_err(|e| format!("Missing attn norm layer {layer_idx}: {e}"))?;
        let ffn_norm_name = model
            .find_tensor_name(&[
                &format!("model.layers.{layer_idx}.post_attention_layernorm.weight"),
                &format!("layers.{layer_idx}.post_attention_layernorm.weight"),
                &format!("blk.{layer_idx}.ffn_norm.weight"),
            ])
            .map_err(|e| format!("Missing FFN norm layer {layer_idx}: {e}"))?;
        let ffn_norm = model
            .get_tensor_f32(&ffn_norm_name)
            .map_err(|e| format!("Missing FFN norm layer {layer_idx}: {e}"))?;
        let q_norm = model
            .get_tensor_f32(&format!("model.layers.{layer_idx}.self_attn.q_norm.weight"))
            .ok();
        let k_norm = model
            .get_tensor_f32(&format!("model.layers.{layer_idx}.self_attn.k_norm.weight"))
            .ok();
        layer_norm_weights.push((attn_norm, ffn_norm, q_norm, k_norm));
    }
    // Optional QKV projection biases (present on Qwen2-style checkpoints,
    // absent on LLaMA-style ones).
    let mut layer_qkv_biases: Vec<(Option<Vec<f32>>, Option<Vec<f32>>, Option<Vec<f32>>)> =
        Vec::with_capacity(config.num_layers);
    for layer_idx in 0..config.num_layers {
        let q_bias = model
            .get_tensor_f32(&format!("model.layers.{layer_idx}.self_attn.q_proj.bias"))
            .ok();
        let k_bias = model
            .get_tensor_f32(&format!("model.layers.{layer_idx}.self_attn.k_proj.bias"))
            .ok();
        let v_bias = model
            .get_tensor_f32(&format!("model.layers.{layer_idx}.self_attn.v_proj.bias"))
            .ok();
        layer_qkv_biases.push((q_bias, k_bias, v_bias));
    }
    // Weights are now in VRAM (Q4K) or copied to host Vecs (norms, embeddings),
    // so the memory-mapped CPU pages can be released.
    let _ = model.release_cpu_pages();
    // The tokenizer is loaded here but unused by the inference loop (requests
    // carry raw token ids), so bind it to an underscore-prefixed name.
    let _tokenizer = AprV2Model::load_tokenizer(path);
println!(" Q4K GPU inference thread: ready");
let (tx, mut rx) = tokio::sync::mpsc::channel::<AprQ4kRequest>(64);
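    // All CUDA work happens on one dedicated OS thread: CUDA contexts are
    // bound per-thread, so the executor is made current below and never moves.
    // A single-threaded tokio runtime drives the async channel receiver.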
    std::thread::spawn(move || {
        // Bind the CUDA context to this thread before any GPU calls.
        executor
            .make_context_current()
            .expect("Q4K inference thread: failed to set CUDA context");
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Q4K inference thread: failed to create tokio runtime");
        rt.block_on(async move {
            while let Some(req) = rx.recv().await {
                let result = generate_q4k(
                    &mut executor,
                    &config,
                    &embedding_weight,
                    &output_norm_weight,
                    &layer_norm_weights,
                    &layer_qkv_biases,
                    &req.prompt_ids,
                    req.max_tokens,
                    req.temperature,
                    &req.eos_ids,
                );
                // A dropped receiver just means the requester gave up; ignore it.
                let _ = req.response_tx.send(result);
            }
            eprintln!("[Q4K] Inference thread shutting down (channel closed)");
        });
    });
    Ok(tx)
}
#[cfg(feature = "cuda")]
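/// Runs one generation request on the GPU: prefills the KV cache with the
/// prompt token by token, then decodes autoregressively until an EOS id is
/// produced or `max_tokens` tokens have been emitted.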
fn generate_q4k(
    executor: &mut crate::cuda::CudaExecutor,
    config: &crate::gpu::adapters::apr_q4k::AprQ4KConfig,
    embedding_weight: &[f32],
    output_norm_weight: &[f32],
    layer_norm_weights: &[(Vec<f32>, Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>)],
    layer_qkv_biases: &[(Option<Vec<f32>>, Option<Vec<f32>>, Option<Vec<f32>>)],
    prompt_ids: &[u32],
    max_tokens: usize,
    temperature: f32,
    eos_ids: &[u32],
) -> Result<AprQ4kResponse, String> {
    use crate::cli::inference::{argmax, sample_with_temperature};
    use crate::gpu::adapters::apr_q4k::forward_token_apr_q4k;
    use std::time::Instant;

    // One growable K and V buffer per layer, populated by
    // forward_token_apr_q4k as positions are processed.
    let mut kv_cache_k: Vec<Vec<f32>> = vec![Vec::new(); config.num_layers];
    let mut kv_cache_v: Vec<Vec<f32>> = vec![Vec::new(); config.num_layers];
    let gen_start = Instant::now();
    let mut last_logits = Vec::new();
    // Prefill: run every prompt token through the model to build up the KV
    // cache; only the logits of the final prompt token are kept.
    for (pos, &token_id) in prompt_ids.iter().enumerate() {
        last_logits = forward_token_apr_q4k(
            executor,
            config,
            embedding_weight,
            output_norm_weight,
            layer_norm_weights,
            layer_qkv_biases,
            &mut kv_cache_k,
            &mut kv_cache_v,
            token_id,
            pos,
        )
        .map_err(|e| format!("Prefill failed at pos {pos}: {e}"))?;
    }
    // Temperatures at or below 0.01 are treated as greedy decoding; otherwise
    // sample from the temperature-scaled distribution with a top-k of 40.
    let mut next_token = if temperature <= 0.01 {
        argmax(&last_logits)
    } else {
        sample_with_temperature(&last_logits, temperature, 40)
    };
    let mut output_tokens = vec![next_token];
    // Decode: the first token above came from the prefill logits, so at most
    // max_tokens - 1 further tokens are generated here.
    for step in 0..max_tokens.saturating_sub(1) {
        if eos_ids.contains(&next_token) {
            break;
        }
        let position = prompt_ids.len() + step;
        let logits = forward_token_apr_q4k(
            executor,
            config,
            embedding_weight,
            output_norm_weight,
            layer_norm_weights,
            layer_qkv_biases,
            &mut kv_cache_k,
            &mut kv_cache_v,
            next_token,
            position,
        )
        .map_err(|e| format!("Decode failed at step {step}: {e}"))?;
        next_token = if temperature <= 0.01 {
            argmax(&logits)
        } else {
            sample_with_temperature(&logits, temperature, 40)
        };
        output_tokens.push(next_token);
    }
    // Throughput is measured end to end, so it includes prefill time.
    let gen_time = gen_start.elapsed();
    let tokens_generated = output_tokens.len();
    let tokens_per_second = if gen_time.as_secs_f64() > 0.0 {
        tokens_generated as f64 / gen_time.as_secs_f64()
    } else {
        0.0
    };
    Ok(AprQ4kResponse {
        output_tokens,
        tokens_generated,
        generation_time_ms: gen_time.as_secs_f64() * 1000.0,
        tokens_per_second,
    })
}