use super::super::{cpu_matmul, cpu_matmul_transposed_simd, exceeds_gpu_buffer_limit};
use super::model::GpuModel;
use super::types::GpuModelConfig;
use crate::error::{RealizarError, Result};
pub fn generate_gpu(
model: &mut GpuModel,
prompt: &[usize],
max_tokens: usize,
) -> Result<Vec<usize>> {
let mut tokens = prompt.to_vec();
let vocab_size = model.config.vocab_size;
let logits = model.forward_gpu(&tokens)?;
let last_pos_start = (tokens.len() - 1) * vocab_size;
let last_logits = &logits[last_pos_start..last_pos_start + vocab_size];
let next_token = argmax(last_logits);
tokens.push(next_token);
if vocab_size > 8192 {
for _ in 1..max_tokens {
let next_token = forward_single_token_greedy(model, &tokens)?;
tokens.push(next_token);
}
} else {
for _ in 1..max_tokens {
let logits = forward_single_token(model, &tokens)?;
let next_token = argmax(&logits);
tokens.push(next_token);
}
}
Ok(tokens)
}
/// Run a single-token forward pass and return the full logits vector.
///
/// Embeds the last token of `tokens`, pushes it through every transformer
/// block, applies the final layer norm, and projects to vocabulary logits
/// via the LM head (on the SIMD CPU path when the weight exceeds the GPU
/// buffer limit).
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `tokens` is empty or the
/// last token id is out of the vocabulary range; propagates matmul errors.
pub fn forward_single_token(model: &mut GpuModel, tokens: &[usize]) -> Result<Vec<f32>> {
    let hidden_dim = model.config.hidden_dim;
    let vocab_size = model.config.vocab_size;
    let last_token = *tokens.last().ok_or_else(|| RealizarError::InvalidShape {
        reason: "Token list empty".to_string(),
    })?;
    if last_token >= vocab_size {
        return Err(RealizarError::InvalidShape {
            reason: format!("Token {} out of bounds", last_token),
        });
    }
    // Embedding lookup for the single new token.
    let offset = last_token * hidden_dim;
    let mut hidden: Vec<f32> = model.embedding_weights[offset..offset + hidden_dim].to_vec();
    for block_idx in 0..model.block_weights.len() {
        hidden = forward_block_single(model, &hidden, block_idx)?;
    }
    // Final layer norm before the LM head.
    hidden = GpuModel::layer_norm_static(
        &hidden,
        &model.final_norm_weight,
        &model.final_norm_bias,
        hidden_dim,
        model.config.eps,
    );
    let lm_head_elements = hidden_dim * vocab_size;
    let output = if exceeds_gpu_buffer_limit(lm_head_elements) {
        // Weight too large for a single GPU buffer: use the SIMD CPU path
        // over the pre-transposed weight (bias is applied inside).
        cpu_matmul_transposed_simd(
            &hidden,
            &model.lm_head_weight_t,
            &model.lm_head_bias,
            hidden_dim,
            vocab_size,
        )
    } else {
        // Move the LM-head weight out of `model` so we can call the
        // `&mut self` matmul without cloning the whole matrix
        // (hidden_dim * vocab_size floats) on every decoded token.
        // NOTE(review): assumes `do_matmul` does not itself read
        // `model.lm_head_weight` — it receives the weight as an argument.
        let lm_head_weight = std::mem::take(&mut model.lm_head_weight);
        let matmul_result = model.do_matmul(&hidden, &lm_head_weight, 1, hidden_dim, vocab_size);
        // Restore the weight before propagating any error.
        model.lm_head_weight = lm_head_weight;
        let logits = matmul_result?;
        logits
            .iter()
            .zip(model.lm_head_bias.iter())
            .map(|(&x, &b)| x + b)
            .collect()
    };
    Ok(output)
}
/// Run a single-token forward pass and return only the argmax token id.
///
/// Identical pipeline to `forward_single_token`, but for large vocabularies
/// (> 8192) or oversized LM heads it uses a fused matmul+argmax over the
/// transposed weight, never materializing the full logits vector.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `tokens` is empty or the
/// last token id is out of the vocabulary range; propagates matmul errors.
pub fn forward_single_token_greedy(model: &mut GpuModel, tokens: &[usize]) -> Result<usize> {
    let hidden_dim = model.config.hidden_dim;
    let vocab_size = model.config.vocab_size;
    let last_token = *tokens.last().ok_or_else(|| RealizarError::InvalidShape {
        reason: "Token list empty".to_string(),
    })?;
    if last_token >= vocab_size {
        return Err(RealizarError::InvalidShape {
            reason: format!("Token {} out of bounds", last_token),
        });
    }
    // Embedding lookup for the single new token.
    let offset = last_token * hidden_dim;
    let mut hidden: Vec<f32> = model.embedding_weights[offset..offset + hidden_dim].to_vec();
    for block_idx in 0..model.block_weights.len() {
        hidden = forward_block_single(model, &hidden, block_idx)?;
    }
    // Final layer norm before the LM head.
    hidden = GpuModel::layer_norm_static(
        &hidden,
        &model.final_norm_weight,
        &model.final_norm_bias,
        hidden_dim,
        model.config.eps,
    );
    let lm_head_elements = hidden_dim * vocab_size;
    if vocab_size > 8192 || exceeds_gpu_buffer_limit(lm_head_elements) {
        // Fused projection + argmax: streams over rows of the transposed
        // weight and never allocates a vocab_size logits buffer.
        Ok(optimized_lm_head_argmax_transposed(
            &hidden,
            &model.lm_head_weight_t,
            &model.lm_head_bias,
            hidden_dim,
            vocab_size,
        ))
    } else {
        // Move the LM-head weight out of `model` so we can call the
        // `&mut self` matmul without cloning the whole matrix
        // (hidden_dim * vocab_size floats) on every decoded token.
        // NOTE(review): assumes `do_matmul` does not itself read
        // `model.lm_head_weight` — it receives the weight as an argument.
        let lm_head_weight = std::mem::take(&mut model.lm_head_weight);
        let matmul_result = model.do_matmul(&hidden, &lm_head_weight, 1, hidden_dim, vocab_size);
        // Restore the weight before propagating any error.
        model.lm_head_weight = lm_head_weight;
        let logits = matmul_result?;
        let output: Vec<f32> = logits
            .iter()
            .zip(model.lm_head_bias.iter())
            .map(|(&x, &b)| x + b)
            .collect();
        Ok(argmax(&output))
    }
}
/// Forward a single token (seq_len == 1) through transformer block `block_idx`.
///
/// Stages: attention layer-norm -> fused QKV projection -> value
/// passthrough with GQA head expansion -> output projection + residual ->
/// FFN layer-norm -> feed-forward (MoE, SwiGLU, or GELU) + residual.
///
/// NOTE(review): no attention scores are computed and no KV cache is
/// consulted — with a single position, softmax over one score is 1, so
/// the attention output equals `v`. Confirm this function is only used
/// for cache-less single-token paths.
#[allow(clippy::unnecessary_wraps)]
pub fn forward_block_single(
    model: &mut GpuModel,
    input: &[f32],
    block_idx: usize,
) -> Result<Vec<f32>> {
    let hidden_dim = model.config.hidden_dim;
    let intermediate_dim = model.config.intermediate_dim;
    let kv_dim = model.config.kv_dim();
    let qkv_dim = model.config.qkv_dim();
    let block = &model.block_weights[block_idx];
    // Pre-attention layer norm.
    let normed = GpuModel::layer_norm_static(
        input,
        &block.attn_norm_weight,
        &block.attn_norm_bias,
        hidden_dim,
        model.config.eps,
    );
    // Fused QKV projection; the slice below implies the output layout is
    // [q: hidden_dim | k: kv_dim | v: kv_dim].
    let qkv_weight = &model.block_weights[block_idx].qkv_weight;
    let qkv = cpu_matmul(&normed, qkv_weight, 1, hidden_dim, qkv_dim);
    // Only `v` is used (see NOTE above); q and k are discarded.
    let v = &qkv[hidden_dim + kv_dim..];
    let num_kv_heads = model.config.num_kv_heads;
    let heads_per_kv = model.config.num_heads / num_kv_heads;
    let head_dim = model.config.head_dim();
    // Grouped-query attention: replicate each KV head's value vector so the
    // attention output has one slot per query head (hidden_dim total).
    let attn_out: Vec<f32> = if heads_per_kv == 1 {
        v.to_vec()
    } else {
        let mut expanded = Vec::with_capacity(hidden_dim);
        for kv_h in 0..num_kv_heads {
            let v_head = &v[kv_h * head_dim..(kv_h + 1) * head_dim];
            for _ in 0..heads_per_kv {
                expanded.extend_from_slice(v_head);
            }
        }
        expanded
    };
    // Attention output projection, then first residual (bias folded in).
    let out_weight = &model.block_weights[block_idx].out_weight;
    let out_bias = &model.block_weights[block_idx].out_bias;
    let projected = cpu_matmul(&attn_out, out_weight, 1, hidden_dim, hidden_dim);
    let residual1: Vec<f32> = input
        .iter()
        .zip(projected.iter())
        .enumerate()
        .map(|(i, (&inp, &proj))| inp + proj + out_bias[i])
        .collect();
    // Pre-FFN layer norm.
    let ffn_norm_weight = &model.block_weights[block_idx].ffn_norm_weight;
    let ffn_norm_bias = &model.block_weights[block_idx].ffn_norm_bias;
    let ffn_normed = GpuModel::layer_norm_static(
        &residual1,
        ffn_norm_weight,
        ffn_norm_bias,
        hidden_dim,
        model.config.eps,
    );
    let ffn_fc1_weight = &model.block_weights[block_idx].ffn_fc1_weight;
    let ffn_fc1_bias = &model.block_weights[block_idx].ffn_fc1_bias;
    let output: Vec<f32> = if let Some(ref moe) = model.block_weights[block_idx].moe_experts {
        // Mixture-of-experts FFN: dispatch the normalized token to the
        // experts, then add the second residual.
        let moe_out = super::moe_dispatch::moe_forward_token(&ffn_normed, moe, hidden_dim);
        residual1
            .iter()
            .zip(moe_out.iter())
            .map(|(&r, &m)| r + m)
            .collect()
    } else {
        let activated: Vec<f32> = if let Some(ref gate_weight) =
            model.block_weights[block_idx].ffn_gate_weight
        {
            // Gated FFN (SwiGLU): silu(gate) * up.
            // NOTE(review): `ffn_fc1_bias` is not applied on this path —
            // presumably gated checkpoints carry no fc1 bias; confirm.
            let up_out = cpu_matmul(&ffn_normed, ffn_fc1_weight, 1, hidden_dim, intermediate_dim);
            let gate_out = cpu_matmul(&ffn_normed, gate_weight, 1, hidden_dim, intermediate_dim);
            up_out
                .iter()
                .zip(gate_out.iter())
                .map(|(&u, &g)| {
                    // silu(g) = g * sigmoid(g)
                    let silu_g = g / (1.0 + (-g).exp());
                    silu_g * u
                })
                .collect()
        } else {
            // Plain FFN: fc1 + bias followed by the tanh approximation of
            // GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
            let fc1_out = cpu_matmul(&ffn_normed, ffn_fc1_weight, 1, hidden_dim, intermediate_dim);
            fc1_out
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    let x = x + ffn_fc1_bias[i];
                    0.5 * x
                        * (1.0
                            + ((2.0f32 / std::f32::consts::PI).sqrt()
                                * (x + 0.044_715 * x.powi(3)))
                            .tanh())
                })
                .collect()
        };
        // fc2 projection, then second residual (bias folded in).
        let ffn_fc2_weight = &model.block_weights[block_idx].ffn_fc2_weight;
        let ffn_fc2_bias = &model.block_weights[block_idx].ffn_fc2_bias;
        let fc2_out = cpu_matmul(&activated, ffn_fc2_weight, 1, intermediate_dim, hidden_dim);
        residual1
            .iter()
            .zip(fc2_out.iter())
            .enumerate()
            .map(|(i, (&r, &fc))| r + fc + ffn_fc2_bias[i])
            .collect()
    };
    Ok(output)
}
/// Index of the largest logit; ties resolve to the last occurrence (the
/// `Iterator::max_by` convention) and an empty slice yields 0. NaN values
/// compare as equal, so their relative ordering is unspecified.
#[allow(clippy::items_after_statements)]
pub fn argmax(logits: &[f32]) -> usize {
    // Shared comparator: treat incomparable (NaN) pairs as equal.
    let cmp = |a: &f32, b: &f32| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal);

    // Small inputs: one linear scan.
    if logits.len() <= 1024 {
        return logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| cmp(*a, *b))
            .map_or(0, |(i, _)| i);
    }

    // Large inputs: reduce each fixed-size chunk to its local maximum,
    // then take the maximum over the per-chunk winners, all lazily.
    const CHUNK_SIZE: usize = 4096;
    logits
        .chunks(CHUNK_SIZE)
        .enumerate()
        .map(|(chunk_idx, chunk)| {
            let (local_idx, &local_max) = chunk
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| cmp(*a, *b))
                .expect("chunk is non-empty by construction");
            (chunk_idx * CHUNK_SIZE + local_idx, local_max)
        })
        .max_by(|(_, a), (_, b)| cmp(a, b))
        .map_or(0, |(idx, _)| idx)
}
#[allow(clippy::many_single_char_names, clippy::items_after_statements)]
pub fn optimized_lm_head_argmax_transposed(
hidden: &[f32],
weight_t: &[f32], bias: &[f32],
hidden_dim: usize,
vocab_size: usize,
) -> usize {
use rayon::prelude::*;
const CHUNK_SIZE: usize = 4096;
(0..vocab_size)
.into_par_iter()
.step_by(CHUNK_SIZE)
.map(|chunk_start| {
let chunk_end = (chunk_start + CHUNK_SIZE).min(vocab_size);
let mut best_local_idx = chunk_start;
let mut best_local_val = f32::NEG_INFINITY;
for j in chunk_start..chunk_end {
let row = &weight_t[j * hidden_dim..(j + 1) * hidden_dim];
let dot: f32 = row.iter().zip(hidden.iter()).map(|(&w, &h)| w * h).sum();
let logit = dot + bias[j];
if logit > best_local_val {
best_local_val = logit;
best_local_idx = j;
}
}
(best_local_idx, best_local_val)
})
.reduce(
|| (0, f32::NEG_INFINITY),
|a, b| if a.1 > b.1 { a } else { b },
)
.0
}
/// Gather query head `head` from a row-major `[seq_len, hidden_dim]`
/// activation matrix: for each position, copy the `head_dim`-wide slice
/// belonging to that head into a contiguous `[seq_len, head_dim]` buffer.
fn extract_q_head(
    q: &[f32],
    head: usize,
    seq_len: usize,
    hidden_dim: usize,
    head_dim: usize,
) -> Vec<f32> {
    (0..seq_len)
        .flat_map(|pos| {
            let begin = pos * hidden_dim + head * head_dim;
            q[begin..begin + head_dim].iter().copied()
        })
        .collect()
}
/// Gather key/value head `kv_head` from two row-major `[seq_len, kv_dim]`
/// matrices into a pair of contiguous `[seq_len, head_dim]` buffers,
/// copying the same per-position slice from `k` and `v`.
fn extract_kv_head(
    k: &[f32],
    v: &[f32],
    kv_head: usize,
    seq_len: usize,
    kv_dim: usize,
    head_dim: usize,
) -> (Vec<f32>, Vec<f32>) {
    // Byte range of this head's slice within row `pos`.
    let span = |pos: usize| {
        let begin = pos * kv_dim + kv_head * head_dim;
        begin..begin + head_dim
    };
    let k_head: Vec<f32> = (0..seq_len).flat_map(|pos| k[span(pos)].iter().copied()).collect();
    let v_head: Vec<f32> = (0..seq_len).flat_map(|pos| v[span(pos)].iter().copied()).collect();
    (k_head, v_head)
}
// Textual includes: these files are spliced into this module at compile
// time and may reference private items defined above (e.g. the
// `extract_*_head` helpers) — presumably why those helpers look unused
// here; confirm before removing them.
include!("attention.rs");
include!("batch_argmax_single_first.rs");