#![allow(clippy::similar_names)]
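//! CUDA decode path for APR v2 models with Q4_K/Q6_K quantized weights:
//! pooled weight upload, config parsing with shape-based inference, and a
//! per-token forward pass that runs projections as GPU GEMV kernels while
//! keeping RoPE, attention, and normalization on the host.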
#[cfg(feature = "cuda")]
use crate::apr::AprV2Model;
#[cfg(feature = "cuda")]
use crate::cuda::CudaExecutor;
use crate::error::{RealizarError, Result};
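// GGML quantization type ids (these match ggml's type enum values).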
#[cfg(feature = "cuda")]
const Q4K_TYPE: u32 = 12;
#[cfg(feature = "cuda")]
const Q6K_TYPE: u32 = 14;
#[cfg(feature = "cuda")]
const F32_TYPE: u32 = 0;
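/// Model hyperparameters for the Q4_K forward pass. The MoE fields are
/// `None` for dense models.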
#[derive(Debug, Clone)]
pub struct AprQ4KConfig {
pub hidden_dim: usize,
pub num_heads: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub num_layers: usize,
pub intermediate_dim: usize,
pub vocab_size: usize,
pub eps: f32,
pub rope_theta: f64,
pub num_experts: Option<usize>,
pub num_experts_per_tok: Option<usize>,
pub moe_intermediate_size: Option<usize>,
}
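/// Summary of a completed weight upload. `num_f32_tensors` counts both F32
/// and F16 tensors, since F16 weights are widened to F32 before upload.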
#[derive(Debug)]
pub struct UploadResult {
pub total_bytes: usize,
pub num_tensors: usize,
pub num_q4k_tensors: usize,
pub num_f32_tensors: usize,
}
#[cfg(feature = "cuda")]
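/// Round `offset` up to the next multiple of `align`; `align` must be a
/// power of two.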
const fn align_up(offset: usize, align: usize) -> usize {
(offset + align - 1) & !(align - 1)
}
#[cfg(feature = "cuda")]
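/// Map GGUF-style tensor names (`blk.N.attn_q.weight`, `token_embd.weight`)
/// to HuggingFace-style names (`model.layers.N.self_attn.q_proj.weight`,
/// `model.embed_tokens.weight`). Unrecognized names pass through unchanged.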
fn normalize_tensor_name(name: &str) -> String {
if let Some(rest) = name.strip_prefix("blk.") {
if let Some(dot_pos) = rest.find('.') {
let layer_num = &rest[..dot_pos];
let suffix = &rest[dot_pos + 1..];
let mapped = match suffix {
"attn_q.weight" => format!("model.layers.{layer_num}.self_attn.q_proj.weight"),
"attn_k.weight" => format!("model.layers.{layer_num}.self_attn.k_proj.weight"),
"attn_v.weight" => format!("model.layers.{layer_num}.self_attn.v_proj.weight"),
"attn_output.weight" => {
format!("model.layers.{layer_num}.self_attn.o_proj.weight")
},
"attn_norm.weight" => {
format!("model.layers.{layer_num}.input_layernorm.weight")
},
"ffn_norm.weight" => {
format!("model.layers.{layer_num}.post_attention_layernorm.weight")
},
"ffn_gate.weight" => format!("model.layers.{layer_num}.mlp.gate_proj.weight"),
"ffn_up.weight" => format!("model.layers.{layer_num}.mlp.up_proj.weight"),
"ffn_down.weight" => format!("model.layers.{layer_num}.mlp.down_proj.weight"),
"attn_q.bias" => format!("model.layers.{layer_num}.self_attn.q_proj.bias"),
"attn_k.bias" => format!("model.layers.{layer_num}.self_attn.k_proj.bias"),
"attn_v.bias" => format!("model.layers.{layer_num}.self_attn.v_proj.bias"),
"attn_q_norm.weight" => {
format!("model.layers.{layer_num}.self_attn.q_norm.weight")
},
"attn_k_norm.weight" => {
format!("model.layers.{layer_num}.self_attn.k_norm.weight")
},
_ => return name.to_string(),
};
return mapped;
}
}
match name {
"token_embd.weight" => "model.embed_tokens.weight".to_string(),
"output_norm.weight" => "model.norm.weight".to_string(),
"output.weight" => "lm_head.weight".to_string(),
_ => name.to_string(),
}
}
#[cfg(feature = "cuda")]
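/// Secondary cache keys (`apr.output_norm`, `apr.layer_N.attn_norm`,
/// `apr.layer_N.ffn_norm`) under which norm gammas are also cached, so they
/// can be looked up by either name. Returns `None` for non-norm tensors.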
fn norm_alias(hf_name: &str) -> Option<String> {
if hf_name == "model.norm.weight" || hf_name == "norm.weight" || hf_name == "output_norm.weight"
{
return Some("apr.output_norm".to_string());
}
if let Some(rest) = hf_name.strip_prefix("model.layers.") {
if let Some(dot_pos) = rest.find('.') {
let layer_num = &rest[..dot_pos];
let suffix = &rest[dot_pos + 1..];
return match suffix {
"input_layernorm.weight" => Some(format!("apr.layer_{layer_num}.attn_norm")),
"post_attention_layernorm.weight" => {
Some(format!("apr.layer_{layer_num}.ffn_norm"))
},
_ => None,
};
}
}
None
}
#[cfg(feature = "cuda")]
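/// Upload all model weights to the GPU in two passes: quantized tensors
/// (Q4_K, Q6_K, and anything `dtype_to_ggml_qtype` recognizes) are packed
/// into a single pooled allocation sized up front, then F32/F16 tensors are
/// uploaded individually, with norm gammas cached under both their HF name
/// and their `apr.*` alias.
///
/// A minimal sketch of the intended call sequence (assuming an already
/// constructed `CudaExecutor` and loaded `AprV2Model`; the constructor
/// names here are illustrative, not the crate's actual API):
///
/// ```ignore
/// let model = AprV2Model::load("model.apr")?;
/// let mut executor = CudaExecutor::new(0)?;
/// let result = upload_apr_q4k_weights(&model, &mut executor)?;
/// println!("{} tensors, {} bytes", result.num_tensors, result.total_bytes);
/// ```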
pub fn upload_apr_q4k_weights(
model: &AprV2Model,
executor: &mut CudaExecutor,
) -> Result<UploadResult> {
let raw_names: Vec<String> = model.tensor_names().into_iter().map(String::from).collect();
let tensor_names: Vec<(String, String)> = raw_names
.iter()
.map(|n| (n.clone(), normalize_tensor_name(n)))
.collect();
let mut pool_size = 0usize;
let mut quantized_entries: Vec<(String, String, u32, usize)> = Vec::new();
for (raw_name, norm_name) in &tensor_names {
let entry = model
.get_tensor(raw_name)
.ok_or_else(|| RealizarError::FormatError {
reason: format!("Tensor disappeared: {raw_name}"),
})?;
let bytes = model.get_tensor_bytes(raw_name)?;
let dtype = entry.dtype.as_str();
let qtype = match dtype {
"Q4_K" | "q4_k" => Some(Q4K_TYPE),
"Q6_K" | "q6_k" => Some(Q6K_TYPE),
other => crate::apr::dequant::dtype_to_ggml_qtype(other),
};
if let Some(qt) = qtype {
let offset = pool_size;
quantized_entries.push((raw_name.clone(), norm_name.clone(), qt, offset));
pool_size = align_up(pool_size + bytes.len(), 256);
}
}
let num_q4k = quantized_entries.len();
println!(
" ALB-098: Pool allocator — {} quantized tensors, {:.1} GB in 1 cuMemAlloc",
num_q4k,
pool_size as f64 / 1e9
);
let mut total_bytes = 0usize;
if pool_size > 0 {
executor
.allocate_quantized_weight_pool(pool_size)
.map_err(|e| RealizarError::GpuError {
reason: format!(
"Failed to allocate {:.1} GB weight pool: {e}",
pool_size as f64 / 1e9
),
})?;
for (raw_name, norm_name, qtype, offset) in &quantized_entries {
let bytes = model.get_tensor_bytes(raw_name)?;
let uploaded = executor
.load_quantized_weights_pooled(norm_name, bytes, *qtype, *offset)
.map_err(|e| RealizarError::GpuError {
reason: format!("Failed to upload {norm_name} to pool: {e}"),
})?;
total_bytes += uploaded;
}
}
let mut num_f32 = 0usize; // counts both F32 and F16 uploads
for (raw_name, norm_name) in &tensor_names {
let entry = model
.get_tensor(raw_name)
.ok_or_else(|| RealizarError::FormatError {
reason: format!("Tensor disappeared: {raw_name}"),
})?;
let dtype = entry.dtype.as_str();
match dtype {
"Q4_K" | "q4_k" | "Q6_K" | "q6_k" => continue,
other if crate::apr::dequant::dtype_to_ggml_qtype(other).is_some() => continue,
_ => {},
}
let bytes = model.get_tensor_bytes(raw_name)?;
match dtype {
"F32" | "f32" => {
let floats: Vec<f32> = bytes
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
if norm_name.contains("norm") {
let uploaded =
executor
.cache_rmsnorm_gamma(norm_name, &floats)
.map_err(|e| RealizarError::GpuError {
reason: format!("Failed to cache norm {norm_name}: {e}"),
})?;
total_bytes += uploaded;
if let Some(alias) = norm_alias(norm_name) {
let _ = executor.cache_rmsnorm_gamma(&alias, &floats);
}
} else {
let uploaded = executor.load_weights(norm_name, &floats).map_err(|e| {
RealizarError::GpuError {
reason: format!("Failed to upload F32 {norm_name}: {e}"),
}
})?;
total_bytes += uploaded;
}
num_f32 += 1;
},
"F16" | "f16" => {
let floats: Vec<f32> = bytes
.chunks_exact(2)
.map(|c| {
let bits = u16::from_le_bytes([c[0], c[1]]);
half::f16::from_bits(bits).to_f32()
})
.collect();
if norm_name.contains("norm") {
let uploaded =
executor
.cache_rmsnorm_gamma(norm_name, &floats)
.map_err(|e| RealizarError::GpuError {
reason: format!("Failed to cache F16 norm {norm_name}: {e}"),
})?;
total_bytes += uploaded;
if let Some(alias) = norm_alias(norm_name) {
let _ = executor.cache_rmsnorm_gamma(&alias, &floats);
}
} else {
let uploaded = executor.load_weights(norm_name, &floats).map_err(|e| {
RealizarError::GpuError {
reason: format!("Failed to upload F16 {norm_name}: {e}"),
}
})?;
total_bytes += uploaded;
}
num_f32 += 1;
},
other => {
eprintln!("[ALB-095] Skipping unsupported dtype {other} for {raw_name}");
},
}
}
Ok(UploadResult {
total_bytes,
num_tensors: num_q4k + num_f32,
num_q4k_tensors: num_q4k,
num_f32_tensors: num_f32,
})
}
#[cfg(feature = "cuda")]
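/// Build an `AprQ4KConfig` from APR metadata, inferring `head_dim` and the
/// MoE parameters from tensor shapes when the metadata omits them.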
pub fn parse_apr_q4k_config(model: &AprV2Model) -> Result<AprQ4KConfig> {
let meta = model.metadata();
let hidden_dim = meta.hidden_size.ok_or_else(|| RealizarError::FormatError {
reason: "APR metadata missing hidden_size".to_string(),
})?;
let num_heads = meta.num_heads.ok_or_else(|| RealizarError::FormatError {
reason: "APR metadata missing num_heads".to_string(),
})?;
let num_kv_heads = meta.num_kv_heads.unwrap_or(num_heads);
let num_layers = meta.num_layers.ok_or_else(|| RealizarError::FormatError {
reason: "APR metadata missing num_layers".to_string(),
})?;
let vocab_size = meta.vocab_size.ok_or_else(|| RealizarError::FormatError {
reason: "APR metadata missing vocab_size".to_string(),
})?;
let intermediate_dim = meta.intermediate_size.unwrap_or(hidden_dim * 4);
let head_dim = meta
.extra
.get("head_dim")
.and_then(|v| v.as_u64())
.map(|v| v as usize)
.unwrap_or_else(|| {
let q_tensor = model
.get_tensor("model.layers.0.self_attn.q_proj.weight")
.or_else(|| model.get_tensor("blk.0.attn_q.weight"));
if let Some(t) = q_tensor {
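// Assuming the q_proj tensor is Q4_K: each 144-byte block holds 256
// values, and 256/144 reduces to 32/18, so (size / 18) * 32 yields the
// element count without referring to the exact block size.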
let num_blocks = (t.size / 18) as usize;
let total_elements = num_blocks * 32;
let q_out_dim = total_elements / hidden_dim;
let inferred = q_out_dim / num_heads;
if inferred != hidden_dim / num_heads {
eprintln!(
"[ALB-095] Inferred head_dim={} from q_proj weight (vs default {})",
inferred,
hidden_dim / num_heads
);
}
inferred
} else {
hidden_dim / num_heads
}
});
let eps = meta.rms_norm_eps.unwrap_or(1e-6);
let rope_theta = meta.rope_theta.unwrap_or(10000.0);
let mut num_experts = meta
.extra
.get("num_experts")
.and_then(|v| v.as_u64())
.map(|v| v as usize);
let mut num_experts_per_tok = meta
.extra
.get("num_experts_per_tok")
.and_then(|v| v.as_u64())
.map(|v| v as usize);
let mut moe_intermediate_size = meta
.extra
.get("moe_intermediate_size")
.and_then(|v| v.as_u64())
.map(|v| v as usize);
if num_experts.is_none() {
let (inferred_experts, inferred_k, inferred_moe_inter) =
infer_moe_config_from_tensors(model, hidden_dim);
if let Some(n) = inferred_experts {
num_experts = Some(n);
num_experts_per_tok = num_experts_per_tok.or(inferred_k);
moe_intermediate_size = moe_intermediate_size.or(inferred_moe_inter);
}
}
Ok(AprQ4KConfig {
hidden_dim,
num_heads,
num_kv_heads,
head_dim,
num_layers,
intermediate_dim,
vocab_size,
eps,
rope_theta: rope_theta as f64,
num_experts,
num_experts_per_tok,
moe_intermediate_size,
})
}
#[cfg(feature = "cuda")]
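/// Infer MoE parameters by scanning layer-0 expert tensor names. The
/// top-k value is not recoverable from names alone, so 8 is used as a
/// fallback when the metadata does not provide `num_experts_per_tok`.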
fn infer_moe_config_from_tensors(
model: &AprV2Model,
hidden_dim: usize,
) -> (Option<usize>, Option<usize>, Option<usize>) {
let names = model.tensor_names();
let mut max_expert: Option<usize> = None;
for name in &names {
if let Some(rest) = name.strip_prefix("model.layers.0.mlp.experts.") {
if let Some(dot_pos) = rest.find('.') {
if let Ok(expert_id) = rest[..dot_pos].parse::<usize>() {
max_expert = Some(max_expert.map_or(expert_id, |m: usize| m.max(expert_id)));
}
}
}
}
let num_experts = max_expert.map(|m| m + 1);
if num_experts.is_none() {
return (None, None, None);
}
let num_experts_per_tok = Some(8);
let moe_intermediate: Option<usize> = model
.get_tensor("model.layers.0.mlp.experts.0.gate_proj.weight")
.map(|t| {
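// Same Q4_K byte math as the head_dim inference above: 32 values per
// 18 bytes (256 per 144-byte block).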
let num_blocks = (t.size / 18) as usize;
let total_elements = num_blocks * 32;
total_elements / hidden_dim
});
if let (Some(n), Some(k)) = (num_experts, num_experts_per_tok) {
let inter = moe_intermediate.unwrap_or(0);
eprintln!(
"[ALB-095] Inferred MoE config from tensor names: {} experts, top-{}, intermediate={}",
n, k, inter
);
}
(num_experts, num_experts_per_tok, moe_intermediate)
}
#[cfg(feature = "cuda")]
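/// Synchronous `output = W @ input` for a cached quantized weight of shape
/// `[n, k]`: uploads `input`, launches the GEMV kernel, downloads `output`.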
pub fn q4k_gemv(
executor: &mut CudaExecutor,
cache_key: &str,
input: &[f32],
n: usize,
k: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; n];
executor
.q4k_gemv_cached(cache_key, input, &mut output, n as u32, k as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Q4K GEMV failed for {cache_key}: {e}"),
})?;
Ok(output)
}
#[cfg(feature = "cuda")]
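/// Like `q4k_gemv`, but reuses whatever vector was last placed in the
/// device input buffer via `q4k_upload_to_input_buffer`, saving one
/// host-to-device copy.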
fn q4k_gemv_reuse_input(
executor: &mut CudaExecutor,
cache_key: &str,
n: usize,
k: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; n];
let out_ptr = executor
.ensure_gemv_output_buffer(n)
.map_err(|e| RealizarError::GpuError {
reason: format!("Output buffer: {e}"),
})?;
executor
.q4k_gemv_launch_to_ptr(cache_key, out_ptr, n as u32, k as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Q4K GEMV launch for {cache_key}: {e}"),
})?;
executor
.sync_stream()
.map_err(|e| RealizarError::GpuError {
reason: format!("Sync: {e}"),
})?;
executor
.download_gemv_output(0, &mut output)
.map_err(|e| RealizarError::GpuError {
reason: format!("Download: {e}"),
})?;
Ok(output)
}
#[cfg(feature = "cuda")]
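/// Compute the Q/K/V projections with one input upload, three GEMV
/// launches into separate output buffers, and a single stream sync.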
fn q4k_batch_qkv(
executor: &mut CudaExecutor,
normed: &[f32],
q_key: &str,
k_key: &str,
v_key: &str,
q_dim: usize,
kv_dim: usize,
hidden_dim: usize,
) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>)> {
executor
.q4k_upload_to_input_buffer(normed, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Upload: {e}"),
})?;
let ptr_a = executor
.ensure_gemv_output_buffer(q_dim)
.map_err(|e| RealizarError::GpuError {
reason: format!("Output A: {e}"),
})?;
let ptr_b =
executor
.ensure_gemv_output_buffer_b(kv_dim)
.map_err(|e| RealizarError::GpuError {
reason: format!("Output B: {e}"),
})?;
let ptr_c =
executor
.ensure_gemv_output_buffer_c(kv_dim)
.map_err(|e| RealizarError::GpuError {
reason: format!("Output C: {e}"),
})?;
executor
.q4k_gemv_launch_to_ptr(q_key, ptr_a, q_dim as u32, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Q launch: {e}"),
})?;
executor
.q4k_gemv_launch_to_ptr(k_key, ptr_b, kv_dim as u32, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("K launch: {e}"),
})?;
executor
.q4k_gemv_launch_to_ptr(v_key, ptr_c, kv_dim as u32, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("V launch: {e}"),
})?;
executor
.sync_stream()
.map_err(|e| RealizarError::GpuError {
reason: format!("Sync: {e}"),
})?;
let mut q = vec![0.0f32; q_dim];
let mut k = vec![0.0f32; kv_dim];
let mut v = vec![0.0f32; kv_dim];
executor
.download_gemv_output(0, &mut q)
.map_err(|e| RealizarError::GpuError {
reason: format!("Q download: {e}"),
})?;
executor
.download_gemv_output(1, &mut k)
.map_err(|e| RealizarError::GpuError {
reason: format!("K download: {e}"),
})?;
executor
.download_gemv_output(2, &mut v)
.map_err(|e| RealizarError::GpuError {
reason: format!("V download: {e}"),
})?;
Ok((q, k, v))
}
#[cfg(feature = "cuda")]
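/// Launch the gate and up projections back to back over the vector already
/// resident in the device input buffer; callers must have invoked
/// `q4k_upload_to_input_buffer` first.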
fn q4k_batch_gate_up(
executor: &mut CudaExecutor,
gate_key: &str,
up_key: &str,
intermediate_dim: usize,
hidden_dim: usize,
) -> Result<(Vec<f32>, Vec<f32>)> {
let ptr_a = executor
.ensure_gemv_output_buffer(intermediate_dim)
.map_err(|e| RealizarError::GpuError {
reason: format!("Output A: {e}"),
})?;
let ptr_b = executor
.ensure_gemv_output_buffer_b(intermediate_dim)
.map_err(|e| RealizarError::GpuError {
reason: format!("Output B: {e}"),
})?;
executor
.q4k_gemv_launch_to_ptr(gate_key, ptr_a, intermediate_dim as u32, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Gate launch: {e}"),
})?;
executor
.q4k_gemv_launch_to_ptr(up_key, ptr_b, intermediate_dim as u32, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Up launch: {e}"),
})?;
executor
.sync_stream()
.map_err(|e| RealizarError::GpuError {
reason: format!("Sync: {e}"),
})?;
let mut gate_out = vec![0.0f32; intermediate_dim];
let mut up_out = vec![0.0f32; intermediate_dim];
executor
.download_gemv_output(0, &mut gate_out)
.map_err(|e| RealizarError::GpuError {
reason: format!("Gate download: {e}"),
})?;
executor
.download_gemv_output(1, &mut up_out)
.map_err(|e| RealizarError::GpuError {
reason: format!("Up download: {e}"),
})?;
Ok((gate_out, up_out))
}
#[cfg(feature = "cuda")]
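/// CPU fallback matmul; `weight` is row-major with shape `[out_dim, in_dim]`.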
pub fn f32_matmul(weight: &[f32], input: &[f32], out_dim: usize, in_dim: usize) -> Vec<f32> {
let mut output = vec![0.0f32; out_dim];
for i in 0..out_dim {
let offset = i * in_dim;
let mut sum = 0.0f32;
for j in 0..in_dim {
sum += weight[offset + j] * input[j];
}
output[i] = sum;
}
output
}
#[cfg(feature = "cuda")]
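/// RMSNorm: `x * gamma / sqrt(mean(x^2) + eps)`.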
fn rms_norm(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
let n = input.len();
let mut sum_sq = 0.0f32;
for &v in input {
sum_sq += v * v;
}
let rms = (sum_sq / n as f32 + eps).sqrt();
let inv_rms = 1.0 / rms;
let mut output = vec![0.0f32; n];
for i in 0..n {
output[i] = input[i] * inv_rms * weight[i];
}
output
}
#[cfg(feature = "cuda")]
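/// NeoX-style rotary embedding: within each head, dimension `i` is rotated
/// together with dimension `i + head_dim / 2`.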
fn apply_rope_neox(
data: &mut [f32],
num_heads: usize,
head_dim: usize,
theta: f64,
position: usize,
) {
for h in 0..num_heads {
let offset = h * head_dim;
let half = head_dim / 2;
for i in 0..half {
let freq = 1.0 / theta.powf(2.0 * i as f64 / head_dim as f64);
let angle = position as f64 * freq;
let cos_val = angle.cos() as f32;
let sin_val = angle.sin() as f32;
let x0 = data[offset + i];
let x1 = data[offset + half + i];
data[offset + i] = x0 * cos_val - x1 * sin_val;
data[offset + half + i] = x0 * sin_val + x1 * cos_val;
}
}
}
#[cfg(feature = "cuda")]
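/// Grouped-query attention for one token over the full KV history: each
/// group of `num_heads / num_kv_heads` query heads shares one KV head, with
/// a numerically stable softmax over the `kv_len` positions.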
fn gqa_attention(
q: &[f32],
full_k: &[f32],
full_v: &[f32],
kv_len: usize,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
) -> Vec<f32> {
let q_per_kv = num_heads / num_kv_heads;
let kv_dim = num_kv_heads * head_dim;
let scale = 1.0 / (head_dim as f32).sqrt();
let mut output = vec![0.0f32; num_heads * head_dim];
for h in 0..num_heads {
let kv_h = h / q_per_kv;
let q_offset = h * head_dim;
let q_head = &q[q_offset..q_offset + head_dim];
let mut scores = vec![0.0f32; kv_len];
for pos in 0..kv_len {
let k_offset = pos * kv_dim + kv_h * head_dim;
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += q_head[d] * full_k[k_offset + d];
}
scores[pos] = dot * scale;
}
let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut exp_sum = 0.0f32;
for s in &mut scores {
*s = (*s - max_score).exp();
exp_sum += *s;
}
for s in &mut scores {
*s /= exp_sum;
}
let out_offset = h * head_dim;
for pos in 0..kv_len {
let v_offset = pos * kv_dim + kv_h * head_dim;
let w = scores[pos];
for d in 0..head_dim {
output[out_offset + d] += w * full_v[v_offset + d];
}
}
}
output
}
#[cfg(feature = "cuda")]
#[inline]
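/// SiLU activation: `x * sigmoid(x)`.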
fn silu(x: f32) -> f32 {
x / (1.0 + (-x).exp())
}
#[cfg(feature = "cuda")]
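/// MoE FFN for one token: softmax the router logits, keep the top-k
/// experts with renormalized weights, run each expert's SwiGLU on the GPU,
/// then add the optional shared expert, scaled by a sigmoid gate when a
/// `shared_expert_gate` weight is present (as in Qwen-style checkpoints).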
fn moe_ffn_forward(
executor: &mut CudaExecutor,
hidden_state: &[f32],
layer_idx: usize,
hidden_dim: usize,
num_experts: usize,
num_experts_per_tok: usize,
moe_intermediate: usize,
) -> Result<Vec<f32>> {
let gate_key = format!("model.layers.{layer_idx}.mlp.gate.weight");
let logits = q4k_gemv(executor, &gate_key, hidden_state, num_experts, hidden_dim)?;
let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut probs = vec![0.0f32; num_experts];
let mut exp_sum = 0.0f32;
for (i, &l) in logits.iter().enumerate() {
probs[i] = (l - max_logit).exp();
exp_sum += probs[i];
}
for p in &mut probs {
*p /= exp_sum;
}
let mut indexed: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let top_k: Vec<(usize, f32)> = indexed.into_iter().take(num_experts_per_tok).collect();
let weight_sum: f32 = top_k.iter().map(|(_, w)| w).sum();
let weights: Vec<f32> = if weight_sum > 0.0 {
top_k.iter().map(|(_, w)| w / weight_sum).collect()
} else {
vec![1.0 / num_experts_per_tok as f32; num_experts_per_tok]
};
let mut routed_out = vec![0.0f32; hidden_dim];
executor
.q4k_upload_to_input_buffer(hidden_state, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Upload: {e}"),
})?;
for (idx, &(expert_id, _)) in top_k.iter().enumerate() {
let gate_key = format!("model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight");
let up_key = format!("model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight");
let down_key = format!("model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight");
let (gate_out, up_out) =
q4k_batch_gate_up(executor, &gate_key, &up_key, moe_intermediate, hidden_dim)?;
let mut act = vec![0.0f32; moe_intermediate];
for i in 0..moe_intermediate {
act[i] = silu(gate_out[i]) * up_out[i];
}
let down_out = q4k_gemv(executor, &down_key, &act, hidden_dim, moe_intermediate)?;
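// The down-projection GEMV uploaded `act` into the shared input buffer,
// so restore `hidden_state` before the next expert's gate/up launches.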
if idx + 1 < top_k.len() {
executor
.q4k_upload_to_input_buffer(hidden_state, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Re-upload: {e}"),
})?;
}
let w = weights[idx];
for i in 0..hidden_dim {
routed_out[i] += w * down_out[i];
}
}
let shared_gate_key = format!("model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight");
if executor.has_quantized_weights(&shared_gate_key) {
let shared_up_key = format!("model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight");
let shared_down_key =
format!("model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight");
executor
.q4k_upload_to_input_buffer(hidden_state, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Upload: {e}"),
})?;
let (gate_out, up_out) = q4k_batch_gate_up(
executor,
&shared_gate_key,
&shared_up_key,
moe_intermediate,
hidden_dim,
)?;
let mut act = vec![0.0f32; moe_intermediate];
for i in 0..moe_intermediate {
act[i] = silu(gate_out[i]) * up_out[i];
}
let shared_out = q4k_gemv(
executor,
&shared_down_key,
&act,
hidden_dim,
moe_intermediate,
)?;
let gate_weight_key = format!("model.layers.{layer_idx}.mlp.shared_expert_gate.weight");
if executor.has_quantized_weights(&gate_weight_key) {
let gate_logit = q4k_gemv(executor, &gate_weight_key, hidden_state, 1, hidden_dim)?;
let gate_scale = 1.0 / (1.0 + (-gate_logit[0]).exp()); // sigmoid
for i in 0..hidden_dim {
routed_out[i] += gate_scale * shared_out[i];
}
} else {
for i in 0..hidden_dim {
routed_out[i] += shared_out[i];
}
}
}
Ok(routed_out)
}
#[cfg(feature = "cuda")]
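/// Dense SwiGLU FFN: `down(silu(gate(x)) * up(x))`, with the gate and up
/// projections batched into one upload/sync round trip.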
fn dense_ffn_forward(
executor: &mut CudaExecutor,
hidden_state: &[f32],
layer_idx: usize,
hidden_dim: usize,
intermediate_dim: usize,
) -> Result<Vec<f32>> {
let gate_key = format!("model.layers.{layer_idx}.mlp.gate_proj.weight");
let up_key = format!("model.layers.{layer_idx}.mlp.up_proj.weight");
let down_key = format!("model.layers.{layer_idx}.mlp.down_proj.weight");
executor
.q4k_upload_to_input_buffer(hidden_state, hidden_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Upload: {e}"),
})?;
let (gate_out, up_out) =
q4k_batch_gate_up(executor, &gate_key, &up_key, intermediate_dim, hidden_dim)?;
let mut act = vec![0.0f32; intermediate_dim];
for i in 0..intermediate_dim {
act[i] = silu(gate_out[i]) * up_out[i];
}
q4k_gemv(executor, &down_key, &act, hidden_dim, intermediate_dim)
}
#[cfg(feature = "cuda")]
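/// Per-head RMSNorm (QK-norm): each head is normalized independently and
/// scaled by a shared `head_dim`-length gamma.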
fn per_head_rms_norm(
data: &mut [f32],
num_heads: usize,
head_dim: usize,
weight: &[f32],
eps: f32,
) {
for h in 0..num_heads {
let offset = h * head_dim;
let head = &data[offset..offset + head_dim];
let mut sum_sq = 0.0f32;
for &v in head {
sum_sq += v * v;
}
let rms = (sum_sq / head_dim as f32 + eps).sqrt();
let inv_rms = 1.0 / rms;
for i in 0..head_dim {
data[offset + i] *= inv_rms * weight[i];
}
}
}
#[cfg(feature = "cuda")]
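/// Single-token decode step. Weight projections run as Q4_K GEMV kernels
/// on the GPU; RoPE, GQA attention over the host-side KV cache, and
/// normalization run on the CPU. `layer_norm_weights` holds, per layer,
/// `(attn_norm, ffn_norm, optional q_norm, optional k_norm)` gammas.
/// Returns the logits over the vocabulary.
///
/// A decode-loop sketch (assuming weights were uploaded with
/// `upload_apr_q4k_weights` and the F32 gammas, biases, and embeddings were
/// extracted beforehand; the variable names are illustrative):
///
/// ```ignore
/// let mut kv_k: Vec<Vec<f32>> = vec![Vec::new(); config.num_layers];
/// let mut kv_v: Vec<Vec<f32>> = vec![Vec::new(); config.num_layers];
/// for (pos, &tok) in prompt.iter().enumerate() {
///     let logits = forward_token_apr_q4k(
///         &mut executor, &config, &embed, &out_norm, &layer_norms,
///         &qkv_biases, &mut kv_k, &mut kv_v, tok, pos,
///     )?;
/// }
/// ```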
#[allow(clippy::too_many_arguments)]
pub fn forward_token_apr_q4k(
executor: &mut CudaExecutor,
config: &AprQ4KConfig,
embedding_weight: &[f32],
output_norm_weight: &[f32],
layer_norm_weights: &[(Vec<f32>, Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>)],
layer_qkv_biases: &[(Option<Vec<f32>>, Option<Vec<f32>>, Option<Vec<f32>>)],
kv_cache_k: &mut Vec<Vec<f32>>,
kv_cache_v: &mut Vec<Vec<f32>>,
token_id: u32,
position: usize,
) -> Result<Vec<f32>> {
contract_pre_prefill_phase!(embedding_weight);
let hidden_dim = config.hidden_dim;
let num_heads = config.num_heads;
let num_kv_heads = config.num_kv_heads;
let head_dim = config.head_dim;
let kv_dim = num_kv_heads * head_dim;
let q_dim = num_heads * head_dim;
let embed_offset = token_id as usize * hidden_dim;
let mut hidden: Vec<f32> = embedding_weight[embed_offset..embed_offset + hidden_dim].to_vec();
for layer_idx in 0..config.num_layers {
let normed = rms_norm(&hidden, &layer_norm_weights[layer_idx].0, config.eps);
let q_key = format!("model.layers.{layer_idx}.self_attn.q_proj.weight");
let k_key = format!("model.layers.{layer_idx}.self_attn.k_proj.weight");
let v_key = format!("model.layers.{layer_idx}.self_attn.v_proj.weight");
let (mut q, mut k, mut v) = q4k_batch_qkv(
executor, &normed, &q_key, &k_key, &v_key, q_dim, kv_dim, hidden_dim,
)?;
if let Some((q_bias, k_bias, v_bias)) = layer_qkv_biases.get(layer_idx) {
if let Some(qb) = q_bias {
for (qi, bi) in q.iter_mut().zip(qb.iter()) {
*qi += *bi;
}
}
if let Some(kb) = k_bias {
for (ki, bi) in k.iter_mut().zip(kb.iter()) {
*ki += *bi;
}
}
if let Some(vb) = v_bias {
for (vi, bi) in v.iter_mut().zip(vb.iter()) {
*vi += *bi;
}
}
}
if let Some(ref q_norm_w) = layer_norm_weights[layer_idx].2 {
per_head_rms_norm(&mut q, num_heads, head_dim, q_norm_w, config.eps);
}
if let Some(ref k_norm_w) = layer_norm_weights[layer_idx].3 {
per_head_rms_norm(&mut k, num_kv_heads, head_dim, k_norm_w, config.eps);
}
apply_rope_neox(&mut q, num_heads, head_dim, config.rope_theta, position);
apply_rope_neox(&mut k, num_kv_heads, head_dim, config.rope_theta, position);
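// kv_len covers the cached positions plus the current token; the cache is
// cloned to form the concatenated K/V views used by attention.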
let kv_len = kv_cache_k[layer_idx].len() / kv_dim + 1;
let mut full_k = kv_cache_k[layer_idx].clone();
full_k.extend_from_slice(&k);
let mut full_v = kv_cache_v[layer_idx].clone();
full_v.extend_from_slice(&v);
let attn_out = gqa_attention(
&q,
&full_k,
&full_v,
kv_len,
num_heads,
num_kv_heads,
head_dim,
);
kv_cache_k[layer_idx].extend_from_slice(&k);
kv_cache_v[layer_idx].extend_from_slice(&v);
let o_key = format!("model.layers.{layer_idx}.self_attn.o_proj.weight");
executor
.q4k_upload_to_input_buffer(&attn_out, q_dim as u32)
.map_err(|e| RealizarError::GpuError {
reason: format!("Upload: {e}"),
})?;
let attn_proj = q4k_gemv_reuse_input(executor, &o_key, hidden_dim, q_dim)?;
for i in 0..hidden_dim {
hidden[i] += attn_proj[i];
}
let ffn_normed = rms_norm(&hidden, &layer_norm_weights[layer_idx].1, config.eps);
let ffn_out = if let (Some(num_experts), Some(k_experts), Some(moe_inter)) = (
config.num_experts,
config.num_experts_per_tok,
config.moe_intermediate_size,
) {
moe_ffn_forward(
executor,
&ffn_normed,
layer_idx,
hidden_dim,
num_experts,
k_experts,
moe_inter,
)?
} else {
dense_ffn_forward(
executor,
&ffn_normed,
layer_idx,
hidden_dim,
config.intermediate_dim,
)?
};
for i in 0..hidden_dim {
hidden[i] += ffn_out[i];
}
}
let final_normed = rms_norm(&hidden, output_norm_weight, config.eps);
// Prefer a dedicated lm_head; otherwise fall back to the tied embedding
// matrix on the GPU, then to a CPU matmul against the F32 embeddings.
let tied_embed_key = "model.embed_tokens.weight";
let lm_head_output_key = "lm_head.weight";
if executor.has_quantized_weights(lm_head_output_key) {
q4k_gemv(
executor,
lm_head_output_key,
&final_normed,
config.vocab_size,
hidden_dim,
)
} else if executor.has_quantized_weights(tied_embed_key) {
q4k_gemv(
executor,
tied_embed_key,
&final_normed,
config.vocab_size,
hidden_dim,
)
} else {
Ok(f32_matmul(
embedding_weight,
&final_normed,
config.vocab_size,
hidden_dim,
))
}
}