use crate::error::{RealizarError, Result};
use crate::gguf::{GGUFConfig, GGUFTransformer};
use crate::quantize::{fused_q4k_tiled_matvec, QK_K};
use rayon::prelude::*;
use trueno::{Matrix, Vector};
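/// Row tile width used by the blocked matmul kernels below.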
const TILE_SIZE: usize = 64;
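/// Thread budgets for the two inference phases: batch prefill is
/// compute-bound and defaults to every core, while single-token decode is
/// typically memory-bandwidth-bound and defaults to half the cores.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let cfg = ThreadConfig::auto();
/// assert!(cfg.threads_for(true) >= cfg.threads_for(false)); // prefill >= decode
/// ```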
#[derive(Debug, Clone, Copy)]
pub struct ThreadConfig {
pub n_threads_batch: usize,
pub n_threads_decode: usize,
}
impl ThreadConfig {
pub fn auto() -> Self {
let num_cpus = rayon::current_num_threads();
Self {
n_threads_batch: num_cpus,
n_threads_decode: (num_cpus / 2).max(1),
}
}
pub fn new(n_threads_batch: usize, n_threads_decode: usize) -> Self {
Self {
n_threads_batch: n_threads_batch.max(1),
n_threads_decode: n_threads_decode.max(1),
}
}
pub fn threads_for(&self, is_prefill: bool) -> usize {
if is_prefill {
self.n_threads_batch
} else {
self.n_threads_decode
}
}
}
impl Default for ThreadConfig {
fn default() -> Self {
Self::auto()
}
}
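/// Marker for the two inference phases (see [`ThreadConfig::threads_for`]).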
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceMode {
Prefill,
Decode,
}
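/// Installs the global rayon thread pool. rayon allows this only once per
/// process; subsequent calls surface the builder error as
/// `InvalidConfiguration`.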
pub fn configure_thread_pool(num_threads: usize) -> Result<()> {
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build_global()
.map_err(|e| {
RealizarError::InvalidConfiguration(format!("Failed to configure thread pool: {e}"))
})
}
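/// A weight matrix stored in GGUF Q4_K format: each row is packed into
/// 256-element super-blocks of 144 bytes (4.5 bits/weight), roughly 7x
/// smaller than f32.
///
/// Size math for a hypothetical 4096x4096 projection, as a sketch:
///
/// ```ignore
/// // 4096 / 256 = 16 super-blocks per row; 16 * 144 = 2304 bytes per row.
/// let w = Q4KWeight::new(vec![0u8; 4096 * 2304], 4096, 4096)?;
/// assert_eq!(w.memory_bytes(), 4096 * 2304);
/// assert!((w.compression_ratio() - 64.0 / 9.0).abs() < 1e-3); // ~7.1x
/// ```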
#[derive(Clone)]
pub struct Q4KWeight {
pub data: Vec<u8>,
pub in_dim: usize,
pub out_dim: usize,
}
impl Q4KWeight {
pub fn new(data: Vec<u8>, in_dim: usize, out_dim: usize) -> Result<Self> {
let super_blocks_per_row = in_dim.div_ceil(QK_K);
let bytes_per_row = super_blocks_per_row * 144;
let expected_size = out_dim * bytes_per_row;
if data.len() < expected_size {
return Err(RealizarError::InvalidShape {
reason: format!(
"Q4K weight data too small: got {} bytes, expected {} for {}x{}",
data.len(),
expected_size,
out_dim,
in_dim
),
});
}
Ok(Self {
data,
in_dim,
out_dim,
})
}
pub fn matvec(&self, input: &[f32]) -> Result<Vec<f32>> {
if input.len() != self.in_dim {
return Err(RealizarError::InvalidShape {
reason: format!(
"Input length {} doesn't match weight in_dim {}",
input.len(),
self.in_dim
),
});
}
fused_q4k_tiled_matvec(&self.data, input, self.in_dim, self.out_dim, None)
}
#[must_use]
pub fn memory_bytes(&self) -> usize {
self.data.len()
}
#[must_use]
pub fn f32_equivalent_bytes(&self) -> usize {
self.in_dim * self.out_dim * 4
}
#[must_use]
pub fn compression_ratio(&self) -> f32 {
self.f32_equivalent_bytes() as f32 / self.memory_bytes() as f32
}
}
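/// Work-size threshold above which the matmul paths switch to rayon
/// parallelism.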
const PARALLEL_THRESHOLD: usize = 256;
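/// Row-major matrix multiply: `input` is `[seq_len, in_dim]`, `weight` is
/// `[out_dim, in_dim]` (one row per output neuron), result is
/// `[seq_len, out_dim]`. Dispatches between a sequential SIMD loop, a
/// rayon-parallel single-token path (`out_dim >= PARALLEL_THRESHOLD`), and a
/// tiled parallel kernel for large batches.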
pub fn simd_matmul(input: &[f32], weight: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
let seq_len = input.len() / in_dim;
if seq_len == 1 {
if out_dim >= PARALLEL_THRESHOLD {
return parallel_matmul_single(input, weight, in_dim, out_dim);
}
let input_vec = Vector::from_slice(input);
let mut output = Vec::with_capacity(out_dim);
for o in 0..out_dim {
let weight_row = &weight[o * in_dim..(o + 1) * in_dim];
let weight_vec = Vector::from_slice(weight_row);
let dot = input_vec.dot(&weight_vec).unwrap_or(0.0);
output.push(dot);
}
return output;
}
if seq_len * out_dim >= PARALLEL_THRESHOLD * 4 {
return tiled_matmul(input, weight, seq_len, in_dim, out_dim);
}
let input_matrix =
Matrix::from_vec(seq_len, in_dim, input.to_vec()).expect("Valid input matrix");
let weight_matrix =
Matrix::from_vec(out_dim, in_dim, weight.to_vec()).expect("Valid weight matrix");
let weight_t = weight_matrix.transpose();
let result = input_matrix
.matmul(&weight_t)
.expect("Matrix multiplication should succeed");
result.as_slice().to_vec()
}
fn parallel_matmul_single(
input: &[f32],
weight: &[f32],
in_dim: usize,
out_dim: usize,
) -> Vec<f32> {
(0..out_dim)
.into_par_iter()
.map(|o| {
let weight_row = &weight[o * in_dim..(o + 1) * in_dim];
let input_vec = Vector::from_slice(input);
let weight_vec = Vector::from_slice(weight_row);
input_vec.dot(&weight_vec).unwrap_or(0.0)
})
.collect()
}
fn tiled_matmul(
    input: &[f32],
    weight: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
let mut output = vec![0.0f32; m * n];
output
.par_chunks_mut(TILE_SIZE * n)
.enumerate()
.for_each(|(tile_i, out_chunk)| {
let i_start = tile_i * TILE_SIZE;
let i_end = (i_start + TILE_SIZE).min(m);
let rows_in_tile = i_end - i_start;
for (local_i, out_row) in out_chunk.chunks_mut(n).take(rows_in_tile).enumerate() {
let i = i_start + local_i;
let input_row = &input[i * k..(i + 1) * k];
let input_vec = Vector::from_slice(input_row);
for j_tile in (0..n).step_by(TILE_SIZE) {
let j_end = (j_tile + TILE_SIZE).min(n);
for j in j_tile..j_end {
let weight_row = &weight[j * k..(j + 1) * k];
let weight_vec = Vector::from_slice(weight_row);
out_row[j] = input_vec.dot(&weight_vec).unwrap_or(0.0);
}
}
}
});
output
}
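/// SIMD dot product; falls back to 0.0 if the backend rejects the operands
/// (e.g. mismatched lengths).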
pub fn simd_dot(a: &[f32], b: &[f32]) -> f32 {
let va = Vector::from_slice(a);
let vb = Vector::from_slice(b);
va.dot(&vb).unwrap_or(0.0)
}
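/// Element-wise in-place `a[i] += b[i]`. On a backend error (e.g. length
/// mismatch) `a` is left unchanged.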
pub fn simd_add(a: &mut [f32], b: &[f32]) {
let va = Vector::from_slice(a);
let vb = Vector::from_slice(b);
if let Ok(result) = va.add(&vb) {
a.copy_from_slice(result.as_slice());
}
}
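/// Element-wise in-place `a[i] *= b[i]`, with the same error behavior as
/// [`simd_add`].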
pub fn simd_mul(a: &mut [f32], b: &[f32]) {
let va = Vector::from_slice(a);
let vb = Vector::from_slice(b);
if let Ok(result) = va.mul(&vb) {
a.copy_from_slice(result.as_slice());
}
}
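/// In-place SiLU activation, `x * sigmoid(x)` (scalar loop; the `simd_`
/// prefix keeps naming parity with the other kernels).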
pub fn simd_silu(data: &mut [f32]) {
for x in data.iter_mut() {
let x_val = *x;
let sigmoid = 1.0 / (1.0 + (-x_val).exp());
*x = x_val * sigmoid;
}
}
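/// In-place GELU via the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.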
pub fn simd_gelu(data: &mut [f32]) {
const SQRT_2_OVER_PI: f32 = 0.797_884_6;
const C: f32 = 0.044_715;
for x in data.iter_mut() {
let x_val = *x;
let inner = SQRT_2_OVER_PI * (x_val + C * x_val * x_val * x_val);
*x = 0.5 * x_val * (1.0 + inner.tanh());
}
}
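/// In-place numerically stable softmax: the max is subtracted before
/// exponentiating so large logits cannot overflow.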
pub fn simd_softmax(data: &mut [f32]) {
if data.is_empty() {
return;
}
let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for x in data.iter_mut() {
*x = (*x - max_val).exp();
sum += *x;
}
let inv_sum = 1.0 / sum;
for x in data.iter_mut() {
*x *= inv_sum;
}
}
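/// Position-major key/value cache: the entry for `(layer, position)` starts
/// at `(layer * max_seq_len + position) * hidden_dim`. `store` writes at the
/// current position for one layer; call `advance` once per token after all
/// layers have stored. Stores past `max_seq_len` are silently dropped.
///
/// A per-token decode sketch:
///
/// ```ignore
/// let mut cache = KVCache::new(num_layers, hidden_dim, max_seq_len);
/// for layer in 0..num_layers {
///     cache.store(layer, &k, &v); // k, v: one position, hidden_dim floats
/// }
/// cache.advance(); // position consumed
/// ```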
#[derive(Debug, Clone)]
pub struct KVCache {
pub num_layers: usize,
hidden_dim: usize,
max_seq_len: usize,
seq_len: usize,
k_cache: Vec<f32>,
v_cache: Vec<f32>,
}
impl KVCache {
pub fn new(num_layers: usize, hidden_dim: usize, max_seq_len: usize) -> Self {
let cache_size = num_layers * max_seq_len * hidden_dim;
Self {
num_layers,
hidden_dim,
max_seq_len,
seq_len: 0,
k_cache: vec![0.0; cache_size],
v_cache: vec![0.0; cache_size],
}
}
pub fn store(&mut self, layer: usize, k: &[f32], v: &[f32]) {
if self.seq_len >= self.max_seq_len {
            return;
        }
let offset = (layer * self.max_seq_len + self.seq_len) * self.hidden_dim;
let k_len = k.len().min(self.hidden_dim);
let v_len = v.len().min(self.hidden_dim);
self.k_cache[offset..offset + k_len].copy_from_slice(&k[..k_len]);
self.v_cache[offset..offset + v_len].copy_from_slice(&v[..v_len]);
}
pub fn advance(&mut self) {
if self.seq_len < self.max_seq_len {
self.seq_len += 1;
}
}
pub fn get_k(&self, layer: usize) -> &[f32] {
let start = layer * self.max_seq_len * self.hidden_dim;
let end = start + self.seq_len * self.hidden_dim;
&self.k_cache[start..end]
}
pub fn get_v(&self, layer: usize) -> &[f32] {
let start = layer * self.max_seq_len * self.hidden_dim;
let end = start + self.seq_len * self.hidden_dim;
&self.v_cache[start..end]
}
pub fn len(&self) -> usize {
self.seq_len
}
pub fn is_empty(&self) -> bool {
self.seq_len == 0
}
pub fn reset(&mut self) {
self.seq_len = 0;
}
}
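/// Scaled dot-product attention for a single query vector against the cached
/// keys/values, parallelized over heads:
/// `out_h = softmax(q_h . K_h / sqrt(head_dim)) . V_h`.
/// An empty cache returns the query unchanged (degenerate warm-up case).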
pub fn attention_with_cache(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    num_heads: usize,
head_dim: usize,
) -> Vec<f32> {
let seq_len = k_cache.len() / (num_heads * head_dim);
if seq_len == 0 {
        return q.to_vec();
    }
let hidden_dim = num_heads * head_dim;
let mut output = vec![0.0f32; hidden_dim];
let scale = 1.0 / (head_dim as f32).sqrt();
output
.par_chunks_mut(head_dim)
.enumerate()
.for_each(|(h, out_head)| {
let q_head = &q[h * head_dim..(h + 1) * head_dim];
let mut scores = Vec::with_capacity(seq_len);
for pos in 0..seq_len {
let k_start = pos * hidden_dim + h * head_dim;
let k_head = &k_cache[k_start..k_start + head_dim];
let score = simd_dot(q_head, k_head) * scale;
scores.push(score);
}
let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for s in &mut scores {
*s = (*s - max_score).exp();
sum += *s;
}
let inv_sum = 1.0 / sum;
for s in &mut scores {
*s *= inv_sum;
}
for (pos, &score) in scores.iter().enumerate() {
let v_start = pos * hidden_dim + h * head_dim;
let v_head = &v_cache[v_start..v_start + head_dim];
for (i, &v) in v_head.iter().enumerate() {
out_head[i] += score * v;
}
}
});
output
}
#[derive(Debug, Clone)]
pub struct OptimizedKVCache {
pub num_layers: usize,
pub hidden_dim: usize,
pub max_seq_len: usize,
seq_len: usize,
k_cache: Vec<f32>,
v_cache: Vec<f32>,
}
impl OptimizedKVCache {
pub fn new(num_layers: usize, hidden_dim: usize, max_seq_len: usize) -> Self {
let cache_size = num_layers * max_seq_len * hidden_dim;
Self {
num_layers,
hidden_dim,
max_seq_len,
seq_len: 0,
k_cache: vec![0.0; cache_size],
v_cache: vec![0.0; cache_size],
}
}
pub fn store(&mut self, layer: usize, k: &[f32], v: &[f32]) {
if self.seq_len >= self.max_seq_len {
            return;
        }
let k_offset = (layer * self.max_seq_len + self.seq_len) * self.hidden_dim;
let k_len = k.len().min(self.hidden_dim);
self.k_cache[k_offset..k_offset + k_len].copy_from_slice(&k[..k_len]);
let v_base = layer * self.hidden_dim * self.max_seq_len;
let v_len = v.len().min(self.hidden_dim);
for (dim, &val) in v.iter().enumerate().take(v_len) {
let v_offset = v_base + dim * self.max_seq_len + self.seq_len;
self.v_cache[v_offset] = val;
}
}
pub fn advance(&mut self) {
if self.seq_len < self.max_seq_len {
self.seq_len += 1;
}
}
pub fn get_k(&self, layer: usize) -> &[f32] {
let start = layer * self.max_seq_len * self.hidden_dim;
let end = start + self.seq_len * self.hidden_dim;
&self.k_cache[start..end]
}
pub fn get_v_transposed(&self, layer: usize) -> Vec<f32> {
let v_base = layer * self.hidden_dim * self.max_seq_len;
let mut result = Vec::with_capacity(self.hidden_dim * self.seq_len);
for dim in 0..self.hidden_dim {
let dim_start = v_base + dim * self.max_seq_len;
result.extend_from_slice(&self.v_cache[dim_start..dim_start + self.seq_len]);
}
result
}
pub fn get_v_raw(&self, layer: usize) -> &[f32] {
let start = layer * self.hidden_dim * self.max_seq_len;
let end = start + self.hidden_dim * self.max_seq_len;
&self.v_cache[start..end]
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn len(&self) -> usize {
self.seq_len
}
pub fn is_empty(&self) -> bool {
self.seq_len == 0
}
pub fn reset(&mut self) {
self.seq_len = 0;
}
}
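/// Attention over a dimension-major V buffer (see
/// [`OptimizedKVCache::get_v_transposed`]). Scores are computed exactly as in
/// [`attention_with_cache`]; the weighted sum then streams contiguous
/// `v_transposed[dim * seq_len..]` rows instead of strided columns.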
pub fn attention_with_transposed_v(
    q: &[f32],
    k_cache: &[f32],
    v_transposed: &[f32],
    num_heads: usize,
head_dim: usize,
seq_len: usize,
) -> Vec<f32> {
if seq_len == 0 {
        return q.to_vec();
    }
let hidden_dim = num_heads * head_dim;
let mut output = vec![0.0f32; hidden_dim];
let scale = 1.0 / (head_dim as f32).sqrt();
output
.par_chunks_mut(head_dim)
.enumerate()
.for_each(|(h, out_head)| {
let q_head = &q[h * head_dim..(h + 1) * head_dim];
let mut scores = Vec::with_capacity(seq_len);
for pos in 0..seq_len {
let k_start = pos * hidden_dim + h * head_dim;
let k_head = &k_cache[k_start..k_start + head_dim];
let score = simd_dot(q_head, k_head) * scale;
scores.push(score);
}
let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for s in &mut scores {
*s = (*s - max_score).exp();
sum += *s;
}
let inv_sum = 1.0 / sum;
for s in &mut scores {
*s *= inv_sum;
}
for (i, out_val) in out_head.iter_mut().enumerate() {
let dim_idx = h * head_dim + i;
let v_row_start = dim_idx * seq_len;
let v_row = &v_transposed[v_row_start..v_row_start + seq_len];
let mut acc = 0.0f32;
for (pos, &score) in scores.iter().enumerate() {
acc += score * v_row[pos];
}
*out_val = acc;
}
});
output
}
pub fn simd_layer_norm(input: &[f32], weight: &[f32], bias: Option<&[f32]>, eps: f32) -> Vec<f32> {
let hidden_dim = weight.len();
let seq_len = input.len() / hidden_dim;
let mut output = Vec::with_capacity(input.len());
for i in 0..seq_len {
let start = i * hidden_dim;
let end = start + hidden_dim;
let x = &input[start..end];
let x_vec = Vector::from_slice(x);
let mean = x_vec.sum().unwrap_or(0.0) / hidden_dim as f32;
let var: f32 = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / hidden_dim as f32;
let inv_std = (var + eps).sqrt().recip();
for j in 0..hidden_dim {
let normalized = (x[j] - mean) * inv_std;
let scaled = normalized * weight[j];
let out = if let Some(b) = bias {
scaled + b[j]
} else {
scaled
};
output.push(out);
}
}
output
}
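/// Per-row RMSNorm (no mean subtraction):
/// `y = x / sqrt(mean(x^2) + eps) * weight`.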
pub fn simd_rms_norm(input: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
let hidden_dim = weight.len();
let seq_len = input.len() / hidden_dim;
let mut output = Vec::with_capacity(input.len());
for i in 0..seq_len {
let start = i * hidden_dim;
let end = start + hidden_dim;
let x = &input[start..end];
let x_vec = Vector::from_slice(x);
let sum_sq = x_vec.dot(&x_vec).unwrap_or(0.0);
let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
let inv_rms = rms.recip();
for j in 0..hidden_dim {
let normalized = x[j] * inv_rms;
output.push(normalized * weight[j]);
}
}
output
}
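/// In-place rotary position embedding. Within each head, dimension `i` is
/// paired with `i + head_dim/2` (the split-halves convention, not adjacent
/// pairs) and rotated by `angle = position / theta^(2i/head_dim)`:
/// `x1' = x1*cos - x2*sin`, `x2' = x1*sin + x2*cos`.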
pub fn apply_rope(x: &mut [f32], hidden_dim: usize, num_heads: usize, position: usize, theta: f32) {
let head_dim = hidden_dim / num_heads;
let half_dim = head_dim / 2;
for h in 0..num_heads {
let head_start = h * head_dim;
for i in 0..half_dim {
let freq = 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32);
let angle = position as f32 * freq;
let cos_val = angle.cos();
let sin_val = angle.sin();
let idx1 = head_start + i;
let idx2 = head_start + i + half_dim;
if idx2 < x.len() {
let x1 = x[idx1];
let x2 = x[idx2];
x[idx1] = x1 * cos_val - x2 * sin_val;
x[idx2] = x1 * sin_val + x2 * cos_val;
}
}
}
}
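/// Dequantized (f32) transformer inference engine built on trueno SIMD
/// primitives, with full-sequence [`forward`](Self::forward) and KV-cached
/// incremental decoding. RMSNorm vs LayerNorm is chosen from the GGUF
/// architecture string (llama/mistral/qwen families use RMSNorm).
///
/// A generation sketch; `transformer` stands for an already-loaded
/// [`GGUFTransformer`] (loading elided):
///
/// ```ignore
/// let engine = TruenoInferenceEngine::from_gguf_transformer(transformer);
/// let out = engine.generate_with_cache(&[1, 15043], 16, Some(2))?;
/// ```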
pub struct TruenoInferenceEngine {
pub config: GGUFConfig,
token_embedding: Vec<f32>,
layers: Vec<TruenoTransformerLayer>,
output_norm_weight: Vec<f32>,
output_norm_bias: Option<Vec<f32>>,
lm_head_weight: Vec<f32>,
lm_head_bias: Option<Vec<f32>>,
use_rms_norm: bool,
}
struct TruenoTransformerLayer {
attn_norm_weight: Vec<f32>,
attn_norm_bias: Option<Vec<f32>>,
qkv_weight: Vec<f32>,
qkv_bias: Option<Vec<f32>>,
attn_output_weight: Vec<f32>,
attn_output_bias: Option<Vec<f32>>,
ffn_gate_weight: Option<Vec<f32>>,
ffn_gate_bias: Option<Vec<f32>>,
ffn_up_weight: Vec<f32>,
ffn_up_bias: Option<Vec<f32>>,
ffn_down_weight: Vec<f32>,
ffn_down_bias: Option<Vec<f32>>,
ffn_norm_weight: Option<Vec<f32>>,
ffn_norm_bias: Option<Vec<f32>>,
}
impl TruenoInferenceEngine {
pub fn from_gguf_transformer(transformer: GGUFTransformer) -> Self {
        let arch = transformer.config.architecture.to_lowercase();
        // Llama-family architectures use RMSNorm; everything else falls back
        // to classic LayerNorm.
        let use_rms_norm =
            arch.contains("llama") || arch.contains("mistral") || arch.contains("qwen");
let layers = transformer
.layers
.into_iter()
.map(|l| TruenoTransformerLayer {
attn_norm_weight: l.attn_norm_weight,
attn_norm_bias: l.attn_norm_bias,
qkv_weight: l.qkv_weight,
qkv_bias: l.qkv_bias,
attn_output_weight: l.attn_output_weight,
attn_output_bias: l.attn_output_bias,
ffn_gate_weight: l.ffn_gate_weight,
ffn_gate_bias: l.ffn_gate_bias,
ffn_up_weight: l.ffn_up_weight,
ffn_up_bias: l.ffn_up_bias,
ffn_down_weight: l.ffn_down_weight,
ffn_down_bias: l.ffn_down_bias,
ffn_norm_weight: l.ffn_norm_weight,
ffn_norm_bias: l.ffn_norm_bias,
})
.collect();
Self {
config: transformer.config,
token_embedding: transformer.token_embedding,
layers,
output_norm_weight: transformer.output_norm_weight,
output_norm_bias: transformer.output_norm_bias,
lm_head_weight: transformer.lm_head_weight,
lm_head_bias: transformer.lm_head_bias,
use_rms_norm,
}
}
fn embed(&self, token_ids: &[u32]) -> Vec<f32> {
let hidden_dim = self.config.hidden_dim;
let mut embeddings = Vec::with_capacity(token_ids.len() * hidden_dim);
for &token_id in token_ids {
let start = (token_id as usize) * hidden_dim;
let end = start + hidden_dim;
if end <= self.token_embedding.len() {
embeddings.extend_from_slice(&self.token_embedding[start..end]);
} else {
embeddings.extend(std::iter::repeat(0.0).take(hidden_dim));
}
}
embeddings
}
pub fn forward(&self, token_ids: &[u32]) -> Result<Vec<f32>> {
let hidden_dim = self.config.hidden_dim;
let intermediate_dim = self.config.intermediate_dim;
let num_heads = self.config.num_heads;
let mut hidden = self.embed(token_ids);
let seq_len = token_ids.len();
for layer in &self.layers {
let normed = if self.use_rms_norm {
simd_rms_norm(&hidden, &layer.attn_norm_weight, self.config.eps)
} else {
simd_layer_norm(
&hidden,
&layer.attn_norm_weight,
layer.attn_norm_bias.as_deref(),
self.config.eps,
)
};
let qkv_dim = layer.qkv_weight.len() / hidden_dim;
let mut qkv = simd_matmul(&normed, &layer.qkv_weight, hidden_dim, qkv_dim);
if let Some(ref bias) = layer.qkv_bias {
simd_add(&mut qkv, bias);
}
let kv_dim = (qkv_dim - hidden_dim) / 2;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = hidden_dim / num_heads;
let mut k_cache = Vec::with_capacity(seq_len * hidden_dim);
let mut v_cache = Vec::with_capacity(seq_len * hidden_dim);
let mut attn_out = Vec::with_capacity(seq_len * hidden_dim);
for s in 0..seq_len {
let qkv_start = s * qkv_dim;
let mut q = qkv[qkv_start..qkv_start + hidden_dim].to_vec();
let k_raw = &qkv[qkv_start + hidden_dim..qkv_start + hidden_dim + kv_dim];
let v_raw = &qkv[qkv_start + hidden_dim + kv_dim..qkv_start + qkv_dim];
apply_rope(&mut q, hidden_dim, num_heads, s, self.config.rope_theta);
let mut k = k_raw.to_vec();
apply_rope(&mut k, kv_dim, num_kv_heads, s, self.config.rope_theta);
let (k_expanded, v_expanded): (Vec<f32>, Vec<f32>) = if num_kv_heads < num_heads {
let group_size = num_heads / num_kv_heads;
let expand = |raw: &[f32]| -> Vec<f32> {
(0..num_heads)
.flat_map(|h| {
let kv_head = h / group_size;
let start = kv_head * head_dim;
raw[start..start + head_dim].iter().copied()
})
.collect()
};
(expand(&k), expand(v_raw))
} else {
(k, v_raw.to_vec())
};
k_cache.extend_from_slice(&k_expanded);
v_cache.extend_from_slice(&v_expanded);
let attn_output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
attn_out.extend_from_slice(&attn_output);
}
let mut attn_output =
simd_matmul(&attn_out, &layer.attn_output_weight, hidden_dim, hidden_dim);
if let Some(ref bias) = layer.attn_output_bias {
simd_add(&mut attn_output, bias);
}
simd_add(&mut hidden, &attn_output);
let ffn_input = if let Some(ref norm_weight) = layer.ffn_norm_weight {
if self.use_rms_norm {
simd_rms_norm(&hidden, norm_weight, self.config.eps)
} else {
simd_layer_norm(
&hidden,
norm_weight,
layer.ffn_norm_bias.as_deref(),
self.config.eps,
)
}
} else {
hidden.clone()
};
let ffn_output = if let Some(ref gate_weight) = layer.ffn_gate_weight {
let mut gate = simd_matmul(&ffn_input, gate_weight, hidden_dim, intermediate_dim);
if let Some(ref bias) = layer.ffn_gate_bias {
simd_add(&mut gate, bias);
}
simd_silu(&mut gate);
let mut up = simd_matmul(
&ffn_input,
&layer.ffn_up_weight,
hidden_dim,
intermediate_dim,
);
if let Some(ref bias) = layer.ffn_up_bias {
simd_add(&mut up, bias);
}
simd_mul(&mut gate, &up);
let mut output =
simd_matmul(&gate, &layer.ffn_down_weight, intermediate_dim, hidden_dim);
if let Some(ref bias) = layer.ffn_down_bias {
simd_add(&mut output, bias);
}
output
} else {
let mut ffn_hidden = simd_matmul(
&ffn_input,
&layer.ffn_up_weight,
hidden_dim,
intermediate_dim,
);
if let Some(ref bias) = layer.ffn_up_bias {
simd_add(&mut ffn_hidden, bias);
}
simd_gelu(&mut ffn_hidden);
let mut output = simd_matmul(
&ffn_hidden,
&layer.ffn_down_weight,
intermediate_dim,
hidden_dim,
);
if let Some(ref bias) = layer.ffn_down_bias {
simd_add(&mut output, bias);
}
output
};
simd_add(&mut hidden, &ffn_output);
}
let normed = if self.use_rms_norm {
simd_rms_norm(&hidden, &self.output_norm_weight, self.config.eps)
} else {
simd_layer_norm(
&hidden,
&self.output_norm_weight,
self.output_norm_bias.as_deref(),
self.config.eps,
)
};
let last_hidden_start = (seq_len - 1) * hidden_dim;
let last_hidden = &normed[last_hidden_start..last_hidden_start + hidden_dim];
let mut logits = simd_matmul(
last_hidden,
&self.lm_head_weight,
hidden_dim,
self.config.vocab_size,
);
if let Some(ref bias) = self.lm_head_bias {
simd_add(&mut logits, bias);
}
Ok(logits)
}
pub fn predict_next(&self, token_ids: &[u32]) -> Result<u32> {
let logits = self.forward(token_ids)?;
let (max_idx, _) = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| RealizarError::InvalidShape {
reason: "Empty logits".to_string(),
})?;
Ok(max_idx as u32)
}
pub fn generate(
&self,
prompt: &[u32],
max_tokens: usize,
eos_token_id: Option<u32>,
) -> Result<Vec<u32>> {
if prompt.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Prompt cannot be empty".to_string(),
});
}
let mut tokens = prompt.to_vec();
for _ in 0..max_tokens {
let next_token = self.predict_next(&tokens)?;
if let Some(eos) = eos_token_id {
if next_token == eos {
break;
}
}
tokens.push(next_token);
}
Ok(tokens)
}
pub fn forward_one_token(
&self,
token_id: u32,
position: usize,
kv_cache: &mut KVCache,
) -> Result<Vec<f32>> {
let hidden_dim = self.config.hidden_dim;
let intermediate_dim = self.config.intermediate_dim;
let num_heads = self.config.num_heads;
let head_dim = hidden_dim / num_heads;
let start = (token_id as usize) * hidden_dim;
let end = start + hidden_dim;
let mut hidden = if end <= self.token_embedding.len() {
self.token_embedding[start..end].to_vec()
} else {
vec![0.0; hidden_dim]
};
for (layer_idx, layer) in self.layers.iter().enumerate() {
let normed = if self.use_rms_norm {
simd_rms_norm(&hidden, &layer.attn_norm_weight, self.config.eps)
} else {
simd_layer_norm(
&hidden,
&layer.attn_norm_weight,
layer.attn_norm_bias.as_deref(),
self.config.eps,
)
};
let qkv_dim = layer.qkv_weight.len() / hidden_dim;
let mut qkv = simd_matmul(&normed, &layer.qkv_weight, hidden_dim, qkv_dim);
if let Some(ref bias) = layer.qkv_bias {
simd_add(&mut qkv, bias);
}
let kv_dim = (qkv_dim - hidden_dim) / 2;
let num_kv_heads = self.config.num_kv_heads;
let mut q = qkv[0..hidden_dim].to_vec();
let k_raw = &qkv[hidden_dim..hidden_dim + kv_dim];
let v_raw = &qkv[hidden_dim + kv_dim..qkv_dim];
apply_rope(
&mut q,
hidden_dim,
num_heads,
position,
self.config.rope_theta,
);
let mut k = k_raw.to_vec();
apply_rope(
&mut k,
kv_dim,
num_kv_heads,
position,
self.config.rope_theta,
);
let (k_expanded, v_expanded): (Vec<f32>, Vec<f32>) = if num_kv_heads < num_heads {
let group_size = num_heads / num_kv_heads;
let expand = |raw: &[f32]| -> Vec<f32> {
(0..num_heads)
.flat_map(|h| {
let kv_head = h / group_size;
let start = kv_head * head_dim;
raw[start..start + head_dim].iter().copied()
})
.collect()
};
(expand(&k), expand(v_raw))
} else {
(k, v_raw.to_vec())
};
let k = k_expanded;
let v = v_expanded;
kv_cache.store(layer_idx, &k, &v);
let cached_keys = kv_cache.get_k(layer_idx);
let cached_values = kv_cache.get_v(layer_idx);
let attn_out =
attention_with_cache(&q, cached_keys, cached_values, num_heads, head_dim);
let mut attn_output =
simd_matmul(&attn_out, &layer.attn_output_weight, hidden_dim, hidden_dim);
if let Some(ref bias) = layer.attn_output_bias {
simd_add(&mut attn_output, bias);
}
simd_add(&mut hidden, &attn_output);
let ffn_input = if let Some(ref norm_weight) = layer.ffn_norm_weight {
if self.use_rms_norm {
simd_rms_norm(&hidden, norm_weight, self.config.eps)
} else {
simd_layer_norm(
&hidden,
norm_weight,
layer.ffn_norm_bias.as_deref(),
self.config.eps,
)
}
} else {
hidden.clone()
};
let ffn_output = if let Some(ref gate_weight) = layer.ffn_gate_weight {
let mut gate = simd_matmul(&ffn_input, gate_weight, hidden_dim, intermediate_dim);
if let Some(ref bias) = layer.ffn_gate_bias {
simd_add(&mut gate, bias);
}
simd_silu(&mut gate);
let mut up = simd_matmul(
&ffn_input,
&layer.ffn_up_weight,
hidden_dim,
intermediate_dim,
);
if let Some(ref bias) = layer.ffn_up_bias {
simd_add(&mut up, bias);
}
simd_mul(&mut gate, &up);
let mut output =
simd_matmul(&gate, &layer.ffn_down_weight, intermediate_dim, hidden_dim);
if let Some(ref bias) = layer.ffn_down_bias {
simd_add(&mut output, bias);
}
output
} else {
let mut ffn_hidden = simd_matmul(
&ffn_input,
&layer.ffn_up_weight,
hidden_dim,
intermediate_dim,
);
if let Some(ref bias) = layer.ffn_up_bias {
simd_add(&mut ffn_hidden, bias);
}
simd_gelu(&mut ffn_hidden);
let mut output = simd_matmul(
&ffn_hidden,
&layer.ffn_down_weight,
intermediate_dim,
hidden_dim,
);
if let Some(ref bias) = layer.ffn_down_bias {
simd_add(&mut output, bias);
}
output
};
simd_add(&mut hidden, &ffn_output);
}
kv_cache.advance();
let normed = if self.use_rms_norm {
simd_rms_norm(&hidden, &self.output_norm_weight, self.config.eps)
} else {
simd_layer_norm(
&hidden,
&self.output_norm_weight,
self.output_norm_bias.as_deref(),
self.config.eps,
)
};
let mut logits = simd_matmul(
&normed,
&self.lm_head_weight,
hidden_dim,
self.config.vocab_size,
);
if let Some(ref bias) = self.lm_head_bias {
simd_add(&mut logits, bias);
}
Ok(logits)
}
pub fn generate_with_cache(
&self,
prompt: &[u32],
max_tokens: usize,
eos_token_id: Option<u32>,
) -> Result<Vec<u32>> {
if prompt.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Prompt cannot be empty".to_string(),
});
}
let hidden_dim = self.config.hidden_dim;
let num_layers = self.layers.len();
let max_seq_len = prompt.len() + max_tokens;
        let mut kv_cache = KVCache::new(num_layers, hidden_dim, max_seq_len);
        // Prefill: populate the KV cache with every prompt token except the
        // last. The last prompt token is forwarded as the first decode step
        // below, so its logits are actually consumed (the previous version
        // ran the prefill twice and re-stored the last token at a shifted
        // position).
        for (pos, &token_id) in prompt[..prompt.len() - 1].iter().enumerate() {
            let _logits = self.forward_one_token(token_id, pos, &mut kv_cache)?;
        }
        let mut tokens = prompt.to_vec();
        for i in 0..max_tokens {
            // Absolute position of the token being fed: the last prompt token
            // sits at prompt.len() - 1, each generated token one past it.
            let position = prompt.len() - 1 + i;
            let current_token = *tokens.last().unwrap();
            let logits = self.forward_one_token(current_token, position, &mut kv_cache)?;
            let next_token = logits
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map_or(0, |(idx, _)| idx as u32);
            if let Some(eos) = eos_token_id {
                if next_token == eos {
                    break;
                }
            }
            tokens.push(next_token);
        }
        Ok(tokens)
}
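    /// Sampled generation: temperature scaling, nucleus (top-p, fixed at 0.9)
    /// truncation, and a repetition penalty over the last 64 tokens.
    /// Temperatures below 0.01 degenerate to greedy argmax. Randomness comes
    /// from a fixed-seed LCG, so output is deterministic for a given prompt.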
pub fn generate_with_sampling(
&self,
prompt: &[u32],
max_tokens: usize,
temperature: f32,
repetition_penalty: f32,
eos_token_id: Option<u32>,
) -> Result<Vec<u32>> {
if prompt.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Prompt cannot be empty".to_string(),
});
}
let hidden_dim = self.config.hidden_dim;
let num_layers = self.layers.len();
let max_seq_len = prompt.len() + max_tokens;
        let mut kv_cache = KVCache::new(num_layers, hidden_dim, max_seq_len);
        // Prefill all but the last prompt token; the last one is forwarded in
        // the sampling loop below so its logits seed the first sampled token.
        for (pos, &token_id) in prompt[..prompt.len() - 1].iter().enumerate() {
            let _logits = self.forward_one_token(token_id, pos, &mut kv_cache)?;
        }
let mut tokens = prompt.to_vec();
let mut rng_state: u64 = 12345;
let next_rng = |state: &mut u64| -> f32 {
*state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
(*state as f32) / (u64::MAX as f32)
};
for i in 0..max_tokens {
            // Absolute position of the token being fed: the last prompt token
            // sits at prompt.len() - 1.
            let position = prompt.len() - 1 + i;
let current_token = *tokens.last().unwrap();
let mut logits = self.forward_one_token(current_token, position, &mut kv_cache)?;
if repetition_penalty > 1.0 {
let window_size = 64.min(tokens.len());
let recent_tokens = &tokens[tokens.len() - window_size..];
for &token_id in recent_tokens {
let idx = token_id as usize;
if idx < logits.len() {
let logit = logits[idx];
logits[idx] = if logit > 0.0 {
logit / repetition_penalty
} else {
logit * repetition_penalty
};
}
}
}
            // Near-zero (or non-positive) temperature degenerates to greedy
            // argmax; `< 0.01` already covers `<= 0.0`.
            let next_token = if temperature < 0.01 {
logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or(0, |(idx, _)| idx as u32)
} else {
let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
let max_logit = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = scaled.iter().map(|&x| (x - max_logit).exp()).sum();
let probs: Vec<f32> = scaled
.iter()
.map(|&x| (x - max_logit).exp() / exp_sum)
.collect();
let top_p = 0.9;
let mut indexed: Vec<(usize, f32)> =
probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut cumsum = 0.0;
let mut cutoff_idx = indexed.len();
for (i, (_, p)) in indexed.iter().enumerate() {
cumsum += p;
if cumsum >= top_p {
cutoff_idx = i + 1;
break;
}
}
let truncated = &indexed[..cutoff_idx];
let norm: f32 = truncated.iter().map(|(_, p)| p).sum();
let r = next_rng(&mut rng_state) * norm;
let mut acc = 0.0;
let mut chosen = truncated[0].0;
for &(idx, p) in truncated {
acc += p;
if acc >= r {
chosen = idx;
break;
}
}
chosen as u32
};
if let Some(eos) = eos_token_id {
if next_token == eos {
break;
}
}
tokens.push(next_token);
}
Ok(tokens)
}
}
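/// Transformer layer whose projection matrices stay in Q4_K form and are
/// applied via fused dequantize-matvec kernels; norm weights and biases
/// remain f32.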
pub struct QuantizedTransformerLayer {
pub attn_norm_weight: Vec<f32>,
pub attn_norm_bias: Option<Vec<f32>>,
pub qkv_weight: Q4KWeight,
pub qkv_bias: Option<Vec<f32>>,
pub attn_output_weight: Q4KWeight,
pub attn_output_bias: Option<Vec<f32>>,
pub ffn_gate_weight: Option<Q4KWeight>,
pub ffn_gate_bias: Option<Vec<f32>>,
pub ffn_up_weight: Q4KWeight,
pub ffn_up_bias: Option<Vec<f32>>,
pub ffn_down_weight: Q4KWeight,
pub ffn_down_bias: Option<Vec<f32>>,
pub ffn_norm_weight: Option<Vec<f32>>,
pub ffn_norm_bias: Option<Vec<f32>>,
}
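/// Inference engine that keeps weights quantized (Q4_K) end to end, trading
/// per-token dequantization work for a roughly 7x smaller weight footprint
/// (see [`Self::memory_stats`]). Uses LayerNorm unconditionally, unlike
/// [`TruenoInferenceEngine`].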
pub struct QuantizedInferenceEngine {
pub config: GGUFConfig,
token_embedding: Vec<f32>,
layers: Vec<QuantizedTransformerLayer>,
output_norm_weight: Vec<f32>,
output_norm_bias: Option<Vec<f32>>,
lm_head_weight: Q4KWeight,
lm_head_bias: Option<Vec<f32>>,
}
impl QuantizedInferenceEngine {
pub fn new(
config: GGUFConfig,
token_embedding: Vec<f32>,
layers: Vec<QuantizedTransformerLayer>,
output_norm_weight: Vec<f32>,
output_norm_bias: Option<Vec<f32>>,
lm_head_weight: Q4KWeight,
lm_head_bias: Option<Vec<f32>>,
) -> Self {
Self {
config,
token_embedding,
layers,
output_norm_weight,
output_norm_bias,
lm_head_weight,
lm_head_bias,
}
}
#[must_use]
pub fn memory_stats(&self) -> QuantizedMemoryStats {
let mut quantized_bytes = 0usize;
let mut f32_equivalent_bytes = 0usize;
quantized_bytes += self.lm_head_weight.memory_bytes();
f32_equivalent_bytes += self.lm_head_weight.f32_equivalent_bytes();
for layer in &self.layers {
quantized_bytes += layer.qkv_weight.memory_bytes();
f32_equivalent_bytes += layer.qkv_weight.f32_equivalent_bytes();
quantized_bytes += layer.attn_output_weight.memory_bytes();
f32_equivalent_bytes += layer.attn_output_weight.f32_equivalent_bytes();
quantized_bytes += layer.ffn_up_weight.memory_bytes();
f32_equivalent_bytes += layer.ffn_up_weight.f32_equivalent_bytes();
quantized_bytes += layer.ffn_down_weight.memory_bytes();
f32_equivalent_bytes += layer.ffn_down_weight.f32_equivalent_bytes();
if let Some(ref gate) = layer.ffn_gate_weight {
quantized_bytes += gate.memory_bytes();
f32_equivalent_bytes += gate.f32_equivalent_bytes();
}
}
let embedding_bytes = self.token_embedding.len() * 4;
QuantizedMemoryStats {
quantized_weight_bytes: quantized_bytes,
f32_equivalent_bytes,
embedding_bytes,
compression_ratio: f32_equivalent_bytes as f32 / quantized_bytes as f32,
}
}
fn embed(&self, token_ids: &[u32]) -> Vec<f32> {
let hidden_dim = self.config.hidden_dim;
let mut embeddings = Vec::with_capacity(token_ids.len() * hidden_dim);
for &token_id in token_ids {
let start = (token_id as usize) * hidden_dim;
let end = start + hidden_dim;
if end <= self.token_embedding.len() {
embeddings.extend_from_slice(&self.token_embedding[start..end]);
} else {
embeddings.extend(std::iter::repeat(0.0).take(hidden_dim));
}
}
embeddings
}
pub fn forward(&self, token_ids: &[u32]) -> Result<Vec<f32>> {
let hidden_dim = self.config.hidden_dim;
let num_heads = self.config.num_heads;
let mut hidden = self.embed(token_ids);
let seq_len = token_ids.len();
for layer in &self.layers {
let normed = simd_layer_norm(
&hidden,
&layer.attn_norm_weight,
layer.attn_norm_bias.as_deref(),
self.config.eps,
);
let qkv_dim = layer.qkv_weight.out_dim;
let mut qkv = if seq_len == 1 {
layer.qkv_weight.matvec(&normed)?
} else {
let mut batch_qkv = Vec::with_capacity(seq_len * qkv_dim);
for s in 0..seq_len {
let pos_input = &normed[s * hidden_dim..(s + 1) * hidden_dim];
let pos_qkv = layer.qkv_weight.matvec(pos_input)?;
batch_qkv.extend(pos_qkv);
}
batch_qkv
};
if let Some(ref bias) = layer.qkv_bias {
for s in 0..seq_len {
let offset = s * qkv_dim;
for (i, b) in bias.iter().enumerate() {
qkv[offset + i] += b;
}
}
}
let kv_dim = (qkv_dim - hidden_dim) / 2;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = hidden_dim / num_heads;
let mut k_cache = Vec::with_capacity(seq_len * hidden_dim);
let mut v_cache = Vec::with_capacity(seq_len * hidden_dim);
let mut attn_out = Vec::with_capacity(seq_len * hidden_dim);
for s in 0..seq_len {
let qkv_start = s * qkv_dim;
let mut q = qkv[qkv_start..qkv_start + hidden_dim].to_vec();
let k_raw = &qkv[qkv_start + hidden_dim..qkv_start + hidden_dim + kv_dim];
let v_raw = &qkv[qkv_start + hidden_dim + kv_dim..qkv_start + qkv_dim];
apply_rope(&mut q, hidden_dim, num_heads, s, self.config.rope_theta);
let mut k = k_raw.to_vec();
apply_rope(&mut k, kv_dim, num_kv_heads, s, self.config.rope_theta);
let (k_expanded, v_expanded): (Vec<f32>, Vec<f32>) = if num_kv_heads < num_heads {
let group_size = num_heads / num_kv_heads;
let expand = |raw: &[f32]| -> Vec<f32> {
(0..num_heads)
.flat_map(|h| {
let kv_head = h / group_size;
let start = kv_head * head_dim;
raw[start..start + head_dim].iter().copied()
})
.collect()
};
(expand(&k), expand(v_raw))
} else {
(k, v_raw.to_vec())
};
k_cache.extend_from_slice(&k_expanded);
v_cache.extend_from_slice(&v_expanded);
let attn_output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
attn_out.extend_from_slice(&attn_output);
}
let mut attn_output = if seq_len == 1 {
layer.attn_output_weight.matvec(&attn_out)?
} else {
let mut batch_out = Vec::with_capacity(seq_len * hidden_dim);
for s in 0..seq_len {
let pos_input = &attn_out[s * hidden_dim..(s + 1) * hidden_dim];
let pos_out = layer.attn_output_weight.matvec(pos_input)?;
batch_out.extend(pos_out);
}
batch_out
};
if let Some(ref bias) = layer.attn_output_bias {
for s in 0..seq_len {
let offset = s * hidden_dim;
for (i, b) in bias.iter().enumerate() {
attn_output[offset + i] += b;
}
}
}
simd_add(&mut hidden, &attn_output);
let ffn_input = if let Some(ref norm_weight) = layer.ffn_norm_weight {
simd_layer_norm(
&hidden,
norm_weight,
layer.ffn_norm_bias.as_deref(),
self.config.eps,
)
} else {
hidden.clone()
};
let intermediate_dim = self.config.intermediate_dim;
let ffn_output = if let Some(ref gate_weight) = layer.ffn_gate_weight {
let mut gate = if seq_len == 1 {
gate_weight.matvec(&ffn_input)?
} else {
let mut batch = Vec::with_capacity(seq_len * intermediate_dim);
for s in 0..seq_len {
let pos = &ffn_input[s * hidden_dim..(s + 1) * hidden_dim];
batch.extend(gate_weight.matvec(pos)?);
}
batch
};
if let Some(ref bias) = layer.ffn_gate_bias {
for s in 0..seq_len {
let offset = s * intermediate_dim;
for (i, b) in bias.iter().enumerate() {
gate[offset + i] += b;
}
}
}
simd_silu(&mut gate);
let mut up = if seq_len == 1 {
layer.ffn_up_weight.matvec(&ffn_input)?
} else {
let mut batch = Vec::with_capacity(seq_len * intermediate_dim);
for s in 0..seq_len {
let pos = &ffn_input[s * hidden_dim..(s + 1) * hidden_dim];
batch.extend(layer.ffn_up_weight.matvec(pos)?);
}
batch
};
if let Some(ref bias) = layer.ffn_up_bias {
for s in 0..seq_len {
let offset = s * intermediate_dim;
for (i, b) in bias.iter().enumerate() {
up[offset + i] += b;
}
}
}
simd_mul(&mut gate, &up);
let mut output = if seq_len == 1 {
layer.ffn_down_weight.matvec(&gate)?
} else {
let mut batch = Vec::with_capacity(seq_len * hidden_dim);
for s in 0..seq_len {
let pos = &gate[s * intermediate_dim..(s + 1) * intermediate_dim];
batch.extend(layer.ffn_down_weight.matvec(pos)?);
}
batch
};
if let Some(ref bias) = layer.ffn_down_bias {
for s in 0..seq_len {
let offset = s * hidden_dim;
for (i, b) in bias.iter().enumerate() {
output[offset + i] += b;
}
}
}
output
} else {
let mut ffn_hidden = if seq_len == 1 {
layer.ffn_up_weight.matvec(&ffn_input)?
} else {
let mut batch = Vec::with_capacity(seq_len * intermediate_dim);
for s in 0..seq_len {
let pos = &ffn_input[s * hidden_dim..(s + 1) * hidden_dim];
batch.extend(layer.ffn_up_weight.matvec(pos)?);
}
batch
};
if let Some(ref bias) = layer.ffn_up_bias {
for s in 0..seq_len {
let offset = s * intermediate_dim;
for (i, b) in bias.iter().enumerate() {
ffn_hidden[offset + i] += b;
}
}
}
simd_gelu(&mut ffn_hidden);
let mut output = if seq_len == 1 {
layer.ffn_down_weight.matvec(&ffn_hidden)?
} else {
let mut batch = Vec::with_capacity(seq_len * hidden_dim);
for s in 0..seq_len {
let pos = &ffn_hidden[s * intermediate_dim..(s + 1) * intermediate_dim];
batch.extend(layer.ffn_down_weight.matvec(pos)?);
}
batch
};
if let Some(ref bias) = layer.ffn_down_bias {
for s in 0..seq_len {
let offset = s * hidden_dim;
for (i, b) in bias.iter().enumerate() {
output[offset + i] += b;
}
}
}
output
};
simd_add(&mut hidden, &ffn_output);
}
let normed = simd_layer_norm(
&hidden,
&self.output_norm_weight,
self.output_norm_bias.as_deref(),
self.config.eps,
);
let last_hidden_start = (seq_len - 1) * hidden_dim;
let last_hidden = &normed[last_hidden_start..last_hidden_start + hidden_dim];
let mut logits = self.lm_head_weight.matvec(last_hidden)?;
if let Some(ref bias) = self.lm_head_bias {
simd_add(&mut logits, bias);
}
Ok(logits)
}
pub fn forward_one_token(
&self,
token_id: u32,
position: usize,
kv_cache: &mut KVCache,
) -> Result<Vec<f32>> {
let hidden_dim = self.config.hidden_dim;
let num_heads = self.config.num_heads;
let head_dim = hidden_dim / num_heads;
let start = (token_id as usize) * hidden_dim;
let end = start + hidden_dim;
let mut hidden = if end <= self.token_embedding.len() {
self.token_embedding[start..end].to_vec()
} else {
vec![0.0; hidden_dim]
};
for (layer_idx, layer) in self.layers.iter().enumerate() {
let normed = simd_layer_norm(
&hidden,
&layer.attn_norm_weight,
layer.attn_norm_bias.as_deref(),
self.config.eps,
);
let mut qkv = layer.qkv_weight.matvec(&normed)?;
if let Some(ref bias) = layer.qkv_bias {
simd_add(&mut qkv, bias);
}
let qkv_dim = layer.qkv_weight.out_dim;
let kv_dim = (qkv_dim - hidden_dim) / 2;
let num_kv_heads = self.config.num_kv_heads;
let mut q = qkv[0..hidden_dim].to_vec();
let k_raw = &qkv[hidden_dim..hidden_dim + kv_dim];
let v_raw = &qkv[hidden_dim + kv_dim..qkv_dim];
apply_rope(
&mut q,
hidden_dim,
num_heads,
position,
self.config.rope_theta,
);
let mut k = k_raw.to_vec();
apply_rope(
&mut k,
kv_dim,
num_kv_heads,
position,
self.config.rope_theta,
);
let (k_expanded, v_expanded): (Vec<f32>, Vec<f32>) = if num_kv_heads < num_heads {
let group_size = num_heads / num_kv_heads;
let expand = |raw: &[f32]| -> Vec<f32> {
(0..num_heads)
.flat_map(|h| {
let kv_head = h / group_size;
let start = kv_head * head_dim;
raw[start..start + head_dim].iter().copied()
})
.collect()
};
(expand(&k), expand(v_raw))
} else {
(k, v_raw.to_vec())
};
kv_cache.store(layer_idx, &k_expanded, &v_expanded);
let cached_keys = kv_cache.get_k(layer_idx);
let cached_values = kv_cache.get_v(layer_idx);
let attn_out =
attention_with_cache(&q, cached_keys, cached_values, num_heads, head_dim);
let mut attn_output = layer.attn_output_weight.matvec(&attn_out)?;
if let Some(ref bias) = layer.attn_output_bias {
simd_add(&mut attn_output, bias);
}
simd_add(&mut hidden, &attn_output);
let ffn_input = if let Some(ref norm_weight) = layer.ffn_norm_weight {
simd_layer_norm(
&hidden,
norm_weight,
layer.ffn_norm_bias.as_deref(),
self.config.eps,
)
} else {
hidden.clone()
};
let ffn_output = if let Some(ref gate_weight) = layer.ffn_gate_weight {
let mut gate = gate_weight.matvec(&ffn_input)?;
if let Some(ref bias) = layer.ffn_gate_bias {
simd_add(&mut gate, bias);
}
simd_silu(&mut gate);
let mut up = layer.ffn_up_weight.matvec(&ffn_input)?;
if let Some(ref bias) = layer.ffn_up_bias {
simd_add(&mut up, bias);
}
simd_mul(&mut gate, &up);
let mut output = layer.ffn_down_weight.matvec(&gate)?;
if let Some(ref bias) = layer.ffn_down_bias {
simd_add(&mut output, bias);
}
output
} else {
let mut ffn_hidden = layer.ffn_up_weight.matvec(&ffn_input)?;
if let Some(ref bias) = layer.ffn_up_bias {
simd_add(&mut ffn_hidden, bias);
}
simd_gelu(&mut ffn_hidden);
let mut output = layer.ffn_down_weight.matvec(&ffn_hidden)?;
if let Some(ref bias) = layer.ffn_down_bias {
simd_add(&mut output, bias);
}
output
};
simd_add(&mut hidden, &ffn_output);
}
kv_cache.advance();
let normed = simd_layer_norm(
&hidden,
&self.output_norm_weight,
self.output_norm_bias.as_deref(),
self.config.eps,
);
let mut logits = self.lm_head_weight.matvec(&normed)?;
if let Some(ref bias) = self.lm_head_bias {
simd_add(&mut logits, bias);
}
Ok(logits)
}
pub fn generate_with_cache(
&self,
prompt: &[u32],
max_tokens: usize,
eos_token_id: Option<u32>,
) -> Result<Vec<u32>> {
if prompt.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Prompt cannot be empty".to_string(),
});
}
let hidden_dim = self.config.hidden_dim;
let num_layers = self.layers.len();
let max_seq_len = prompt.len() + max_tokens;
        let mut kv_cache = KVCache::new(num_layers, hidden_dim, max_seq_len);
        // Prefill all but the last prompt token; the last is forwarded as the
        // first decode step below so its logits are used.
        for (pos, &token_id) in prompt[..prompt.len() - 1].iter().enumerate() {
            let _logits = self.forward_one_token(token_id, pos, &mut kv_cache)?;
        }
let mut tokens = prompt.to_vec();
for i in 0..max_tokens {
            // Last prompt token sits at prompt.len() - 1; generated tokens follow.
            let position = prompt.len() - 1 + i;
let current_token = *tokens.last().unwrap();
let logits = self.forward_one_token(current_token, position, &mut kv_cache)?;
let next_token = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or(0, |(idx, _)| idx as u32);
if let Some(eos) = eos_token_id {
if next_token == eos {
break;
}
}
tokens.push(next_token);
}
Ok(tokens)
}
}
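/// Weight-memory accounting for a [`QuantizedInferenceEngine`]; embeddings
/// are counted separately because they stay f32.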
#[derive(Debug, Clone)]
pub struct QuantizedMemoryStats {
pub quantized_weight_bytes: usize,
pub f32_equivalent_bytes: usize,
pub embedding_bytes: usize,
pub compression_ratio: f32,
}
impl std::fmt::Display for QuantizedMemoryStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let quantized_mb = self.quantized_weight_bytes as f64 / (1024.0 * 1024.0);
let f32_mb = self.f32_equivalent_bytes as f64 / (1024.0 * 1024.0);
let embed_mb = self.embedding_bytes as f64 / (1024.0 * 1024.0);
write!(
f,
"Quantized weights: {:.1} MB (vs {:.1} MB f32, {:.1}x compression)\nEmbeddings: {:.1} MB",
quantized_mb, f32_mb, self.compression_ratio, embed_mb
)
}
}
#[cfg(test)]
mod attention_tests {
use super::*;
#[test]
fn test_attention_causal_flow() {
let num_heads = 2;
let head_dim = 2;
let hidden_dim = num_heads * head_dim;
let seq_len = 3;
let mut k_cache = Vec::new();
let mut v_cache = Vec::new();
let mut outputs = Vec::new();
for s in 0..seq_len {
let q: Vec<f32> = (0..hidden_dim)
.map(|i| (s * hidden_dim + i) as f32 * 0.1)
.collect();
let k: Vec<f32> = (0..hidden_dim)
.map(|i| (s * hidden_dim + i) as f32 * 0.2)
.collect();
let v: Vec<f32> = (0..hidden_dim).map(|_| (s + 1) as f32).collect();
k_cache.extend_from_slice(&k);
v_cache.extend_from_slice(&v);
let attn_output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
outputs.push(attn_output);
}
assert_eq!(outputs[0].len(), hidden_dim);
for val in &outputs[0] {
assert!((val - 1.0).abs() < 1e-4, "Pos 0 should be 1.0, got {}", val);
}
for s in 1..seq_len {
for val in &outputs[s] {
assert!(val.is_finite() && *val >= 0.9 && *val <= (seq_len + 1) as f32);
}
}
}
#[test]
fn test_gqa_expansion() {
let num_heads = 4;
let num_kv_heads = 2;
let head_dim = 2;
let group_size = num_heads / num_kv_heads;
let k_raw: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let k_expanded: Vec<f32> = (0..num_heads)
.flat_map(|h| {
let kv_head = h / group_size;
let start = kv_head * head_dim;
k_raw[start..start + head_dim].iter().copied()
})
.collect();
assert_eq!(k_expanded, vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
}
#[test]
fn test_attention_single_position() {
let output = attention_with_cache(
            &[1.0, 2.0, 3.0, 4.0],
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            2,
            2,
        );
assert_eq!(output.len(), 4);
for (i, &expected) in [5.0, 6.0, 7.0, 8.0].iter().enumerate() {
assert!((output[i] - expected).abs() < 1e-4);
}
}
#[test]
fn test_embedding_lookup_basic() {
let embedding = vec![
            1.0, 2.0, 3.0, 4.0, // token 0
            5.0, 6.0, 7.0, 8.0, // token 1
        ];
let hidden_dim = 4;
let token_id = 1u32;
let start = (token_id as usize) * hidden_dim;
let end = start + hidden_dim;
let result = &embedding[start..end];
assert_eq!(result, &[5.0, 6.0, 7.0, 8.0]);
}
#[test]
fn test_lm_head_projection() {
let hidden_dim = 4;
let vocab_size = 3;
let hidden = vec![1.0, 0.0, 0.0, 0.0];
let lm_head = vec![
            1.0, 0.0, 0.0, 0.0, // vocab 0
            0.0, 1.0, 0.0, 0.0, // vocab 1
            0.0, 0.0, 1.0, 0.0, // vocab 2
        ];
let logits: Vec<f32> = (0..vocab_size)
.map(|v| {
let row_start = v * hidden_dim;
(0..hidden_dim)
.map(|i| lm_head[row_start + i] * hidden[i])
.sum()
})
.collect();
assert_eq!(logits, vec![1.0, 0.0, 0.0]);
let max_idx = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(i, _)| i)
.unwrap();
assert_eq!(max_idx, 0);
}
#[test]
fn test_repetition_penalty_reduces_repeats() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let recent_tokens = vec![3u32];
        let penalty = 1.5;
let mut penalized = logits.clone();
for &token_id in &recent_tokens {
let idx = token_id as usize;
let logit = penalized[idx];
penalized[idx] = if logit > 0.0 {
logit / penalty
} else {
logit * penalty
};
}
assert!((penalized[3] - 2.667_f32).abs() < 0.01);
assert_eq!(penalized[0], 1.0);
assert_eq!(penalized[1], 2.0);
assert_eq!(penalized[2], 3.0);
let max_idx = penalized
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(i, _)| i)
.unwrap();
assert_eq!(max_idx, 2, "After penalty, token 2 should be chosen");
}
#[test]
fn test_simd_rms_norm_correctness() {
let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0, 1.0, 1.0, 1.0];
        let eps = 1e-5;
let result = simd_rms_norm(&input, &weight, eps);
let sum_sq: f32 = input.iter().map(|x| x * x).sum();
let rms = (sum_sq / 4.0 + eps).sqrt();
let expected: Vec<f32> = input.iter().map(|&x| x / rms).collect();
for (i, (&r, &e)) in result.iter().zip(expected.iter()).enumerate() {
assert!(
(r - e).abs() < 1e-5,
"RMSNorm mismatch at {}: got {} expected {}",
i,
r,
e
);
}
}
#[test]
fn test_simd_rms_norm_with_gamma() {
let input = vec![2.0, 2.0, 2.0, 2.0];
        let weight = vec![0.5, 1.0, 1.5, 2.0];
        let eps = 1e-5;
let result = simd_rms_norm(&input, &weight, eps);
let sum_sq: f32 = input.iter().map(|x| x * x).sum();
        let rms = (sum_sq / 4.0 + eps).sqrt();
        let expected: Vec<f32> = input
.iter()
.zip(weight.iter())
.map(|(&x, &w)| x / rms * w)
.collect();
for (i, (&r, &e)) in result.iter().zip(expected.iter()).enumerate() {
assert!(
(r - e).abs() < 1e-5,
"RMSNorm with gamma mismatch at {}: got {} expected {}",
i,
r,
e
);
}
}
}
#[cfg(all(test, feature = "heavy-tests"))]
mod tests {
use super::*;
#[test]
fn test_simd_matmul_single_token() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let weight = vec![
            1.0, 0.0, 0.0, 0.0, // row 0 selects input[0]
            0.0, 1.0, 0.0, 0.0, // row 1 selects input[1]
        ];
let output = simd_matmul(&input, &weight, 4, 2);
assert_eq!(output.len(), 2);
assert!((output[0] - 1.0).abs() < 1e-5);
assert!((output[1] - 2.0).abs() < 1e-5);
}
#[test]
fn test_simd_matmul_batch() {
let input = vec![1.0, 2.0, 3.0, 4.0];
let weight = vec![
            1.0, 0.0, // row 0
            0.0, 1.0, // row 1
            1.0, 1.0, // row 2
        ];
let output = simd_matmul(&input, &weight, 2, 3);
assert_eq!(output.len(), 6);
assert!((output[0] - 1.0).abs() < 1e-5);
assert!((output[1] - 2.0).abs() < 1e-5);
assert!((output[2] - 3.0).abs() < 1e-5);
}
#[test]
fn test_simd_dot() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![1.0, 1.0, 1.0, 1.0];
let result = simd_dot(&a, &b);
assert!((result - 10.0).abs() < 1e-5);
}
#[test]
fn test_simd_add() {
let mut a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![10.0, 20.0, 30.0, 40.0];
simd_add(&mut a, &b);
assert!((a[0] - 11.0).abs() < 1e-5);
assert!((a[3] - 44.0).abs() < 1e-5);
}
#[test]
fn test_simd_gelu() {
let mut data = vec![0.0, 1.0, -1.0];
simd_gelu(&mut data);
assert!((data[0]).abs() < 1e-5);
assert!((data[1] - 0.841).abs() < 0.01);
assert!((data[2] + 0.159).abs() < 0.01);
}
#[test]
fn test_simd_softmax() {
let mut data = vec![1.0, 2.0, 3.0];
simd_softmax(&mut data);
let sum: f32 = data.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
assert!(data[2] > data[1]);
assert!(data[1] > data[0]);
}
#[test]
fn test_simd_softmax_empty() {
let mut data: Vec<f32> = vec![];
simd_softmax(&mut data);
assert!(data.is_empty());
}
#[test]
fn test_simd_layer_norm() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0, 1.0, 1.0, 1.0];
let bias = Some(vec![0.0, 0.0, 0.0, 0.0]);
let output = simd_layer_norm(&input, &weight, bias.as_deref(), 1e-5);
assert_eq!(output.len(), 4);
let mean: f32 = output.iter().sum::<f32>() / 4.0;
assert!(mean.abs() < 1e-5);
}
#[test]
fn test_apply_rope() {
        let mut x = vec![1.0, 0.0, 0.0, 1.0];
        apply_rope(&mut x, 4, 1, 0, 10000.0);
assert!((x[0] - 1.0).abs() < 0.01);
assert!((x[3] - 1.0).abs() < 0.01);
}
#[test]
fn test_apply_rope_position() {
let mut x1 = vec![1.0, 0.0, 0.0, 1.0];
let mut x2 = vec![1.0, 0.0, 0.0, 1.0];
apply_rope(&mut x1, 4, 1, 0, 10000.0);
apply_rope(&mut x2, 4, 1, 10, 10000.0);
assert!((x1[0] - x2[0]).abs() > 0.001 || (x1[2] - x2[2]).abs() > 0.001);
}
#[test]
fn test_parallel_matmul_large_output() {
let in_dim = 16;
let out_dim = 512;
let input: Vec<f32> = (0..in_dim).map(|i| i as f32 * 0.1).collect();
let weight: Vec<f32> = (0..out_dim * in_dim)
.map(|i| ((i % 7) as f32 - 3.0) * 0.01)
.collect();
let output = simd_matmul(&input, &weight, in_dim, out_dim);
assert_eq!(output.len(), out_dim);
let expected_0: f32 = (0..in_dim).map(|i| input[i] * weight[i]).sum();
assert!((output[0] - expected_0).abs() < 1e-4);
}
#[test]
fn test_parallel_matmul_correctness() {
let in_dim = 8;
        let out_dim = 300;
        let input: Vec<f32> = (0..in_dim).map(|i| (i + 1) as f32).collect();
let weight: Vec<f32> = (0..out_dim * in_dim)
.map(|i| if i % in_dim == 0 { 1.0 } else { 0.0 })
.collect();
let output = simd_matmul(&input, &weight, in_dim, out_dim);
for (i, &val) in output.iter().enumerate() {
assert!(
(val - 1.0).abs() < 1e-5,
"Output[{}] = {} expected 1.0",
i,
val
);
}
}
#[test]
fn test_tiled_matmul_batch() {
let seq_len = 4;
let in_dim = 8;
let out_dim = 256;
let input: Vec<f32> = (0..seq_len * in_dim)
.map(|i| (i % in_dim + 1) as f32)
.collect();
let weight: Vec<f32> = (0..out_dim * in_dim)
.map(|i| if i % in_dim == 0 { 1.0 } else { 0.0 })
.collect();
let output = simd_matmul(&input, &weight, in_dim, out_dim);
assert_eq!(output.len(), seq_len * out_dim);
for t in 0..seq_len {
assert!(
(output[t * out_dim] - 1.0).abs() < 1e-5,
"Token {} output[0] = {} expected 1.0",
t,
output[t * out_dim]
);
}
}
#[test]
fn test_tiled_matmul_large_batch() {
let seq_len = 128;
let in_dim = 16;
let out_dim = 32;
let input: Vec<f32> = vec![1.0; seq_len * in_dim];
let weight: Vec<f32> = vec![0.0625; out_dim * in_dim];
let output = simd_matmul(&input, &weight, in_dim, out_dim);
assert_eq!(output.len(), seq_len * out_dim);
for val in &output {
assert!((*val - 1.0).abs() < 1e-4, "Expected 1.0, got {}", val);
}
}
#[test]
fn test_kv_cache_new() {
let cache = KVCache::new(4, 64, 128);
assert_eq!(cache.num_layers, 4);
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
#[test]
fn test_kv_cache_store_and_retrieve() {
let mut cache = KVCache::new(2, 4, 8);
let k0 = vec![1.0, 2.0, 3.0, 4.0];
let v0 = vec![5.0, 6.0, 7.0, 8.0];
cache.store(0, &k0, &v0);
cache.advance();
assert_eq!(cache.len(), 1);
assert!(!cache.is_empty());
let k = cache.get_k(0);
let v = cache.get_v(0);
assert_eq!(k.len(), 4);
assert_eq!(v.len(), 4);
assert!((k[0] - 1.0).abs() < 1e-5);
assert!((v[3] - 8.0).abs() < 1e-5);
}
#[test]
fn test_kv_cache_multiple_layers() {
let mut cache = KVCache::new(3, 2, 4);
cache.store(0, &[1.0, 2.0], &[3.0, 4.0]);
cache.store(1, &[5.0, 6.0], &[7.0, 8.0]);
cache.store(2, &[9.0, 10.0], &[11.0, 12.0]);
cache.advance();
assert!((cache.get_k(0)[0] - 1.0).abs() < 1e-5);
assert!((cache.get_k(1)[0] - 5.0).abs() < 1e-5);
assert!((cache.get_k(2)[0] - 9.0).abs() < 1e-5);
assert!((cache.get_v(0)[1] - 4.0).abs() < 1e-5);
assert!((cache.get_v(1)[1] - 8.0).abs() < 1e-5);
assert!((cache.get_v(2)[1] - 12.0).abs() < 1e-5);
}
#[test]
fn test_kv_cache_multiple_positions() {
let mut cache = KVCache::new(1, 2, 8);
cache.store(0, &[1.0, 1.0], &[2.0, 2.0]);
cache.advance();
cache.store(0, &[3.0, 3.0], &[4.0, 4.0]);
cache.advance();
cache.store(0, &[5.0, 5.0], &[6.0, 6.0]);
cache.advance();
assert_eq!(cache.len(), 3);
let k = cache.get_k(0);
        assert_eq!(k.len(), 6);
        assert!((k[0] - 1.0).abs() < 1e-5);
        assert!((k[2] - 3.0).abs() < 1e-5);
        assert!((k[4] - 5.0).abs() < 1e-5);
    }
#[test]
fn test_kv_cache_reset() {
let mut cache = KVCache::new(2, 4, 16);
cache.store(0, &[1.0; 4], &[2.0; 4]);
cache.advance();
cache.store(0, &[3.0; 4], &[4.0; 4]);
cache.advance();
assert_eq!(cache.len(), 2);
cache.reset();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
#[test]
fn test_kv_cache_full() {
let mut cache = KVCache::new(1, 2, 2);
cache.store(0, &[1.0, 1.0], &[1.0, 1.0]);
cache.advance();
cache.store(0, &[2.0, 2.0], &[2.0, 2.0]);
cache.advance();
assert_eq!(cache.len(), 2);
cache.store(0, &[3.0, 3.0], &[3.0, 3.0]);
cache.advance();
assert_eq!(cache.len(), 2);
let k = cache.get_k(0);
assert!((k[2] - 2.0).abs() < 1e-5);
}
#[test]
fn test_attention_with_cache_empty() {
let q = vec![1.0, 2.0, 3.0, 4.0];
let k_cache: Vec<f32> = vec![];
let v_cache: Vec<f32> = vec![];
let output = attention_with_cache(&q, &k_cache, &v_cache, 2, 2);
assert_eq!(output.len(), 4);
assert!((output[0] - 1.0).abs() < 1e-5);
assert!((output[3] - 4.0).abs() < 1e-5);
}
#[test]
fn test_attention_with_cache_single_position() {
let num_heads = 2;
let head_dim = 4;
let hidden_dim = num_heads * head_dim;
let q: Vec<f32> = vec![1.0; hidden_dim];
        let k_cache: Vec<f32> = vec![1.0; hidden_dim];
        let v_cache: Vec<f32> = (0..hidden_dim).map(|i| i as f32).collect();
let output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
assert_eq!(output.len(), hidden_dim);
for i in 0..hidden_dim {
assert!(
(output[i] - v_cache[i]).abs() < 1e-4,
"output[{}] = {} expected {}",
i,
output[i],
v_cache[i]
);
}
}
#[test]
fn test_attention_with_cache_multiple_positions() {
let num_heads = 1;
let head_dim = 4;
let hidden_dim = num_heads * head_dim;
let q = vec![0.0, 0.0, 1.0, 1.0];
let k_cache = vec![
            1.0, 1.0, 0.0, 0.0, // position 0
            0.0, 0.0, 1.0, 1.0, // position 1
        ];
let v_cache = vec![
            1.0, 0.0, 0.0, 0.0, // position 0
            0.0, 1.0, 0.0, 0.0, // position 1
        ];
let output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
assert_eq!(output.len(), hidden_dim);
assert!(output[1] > output[0], "Should attend more to position 1");
}
#[test]
fn test_attention_with_cache_multi_head() {
let num_heads = 4;
let head_dim = 2;
let hidden_dim = num_heads * head_dim;
let q: Vec<f32> = (0..hidden_dim).map(|i| (i % 3) as f32).collect();
        let k_cache: Vec<f32> = vec![1.0; hidden_dim * 3];
        let v_cache: Vec<f32> = (0..hidden_dim * 3).map(|i| (i % 5) as f32 * 0.1).collect();
let output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
assert_eq!(output.len(), hidden_dim);
for val in &output {
assert!(val.is_finite(), "Output should be finite");
}
}
#[test]
fn test_attention_scale_factor() {
let num_heads = 1;
        let head_dim = 64;
        let hidden_dim = num_heads * head_dim;
let q: Vec<f32> = vec![10.0; hidden_dim];
let k_cache: Vec<f32> = vec![10.0; hidden_dim];
let v_cache: Vec<f32> = vec![1.0; hidden_dim];
let output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
for val in &output {
assert!(val.is_finite());
assert!((*val - 1.0).abs() < 1e-4);
}
}
#[test]
fn test_simd_mul() {
let mut a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![2.0, 3.0, 4.0, 5.0];
simd_mul(&mut a, &b);
assert!((a[0] - 2.0).abs() < 1e-5);
assert!((a[1] - 6.0).abs() < 1e-5);
assert!((a[2] - 12.0).abs() < 1e-5);
assert!((a[3] - 20.0).abs() < 1e-5);
}
#[test]
fn test_simd_silu() {
let mut data = vec![0.0, 1.0, -1.0, 2.0];
simd_silu(&mut data);
assert!(data[0].abs() < 1e-5);
assert!((data[1] - 0.731).abs() < 0.01);
assert!((data[2] + 0.269).abs() < 0.01);
assert!((data[3] - 1.762).abs() < 0.01);
}
#[test]
fn test_kv_cache_attention_flow() {
let num_layers = 2;
let hidden_dim = 8;
let max_seq = 4;
let num_heads = 2;
let head_dim = 4;
let mut cache = KVCache::new(num_layers, hidden_dim, max_seq);
for pos in 0..3 {
for layer in 0..num_layers {
let k: Vec<f32> = (0..hidden_dim)
.map(|i| (pos + layer + i) as f32 * 0.1)
.collect();
let v: Vec<f32> = (0..hidden_dim)
.map(|i| (pos + layer + i) as f32 * 0.2)
.collect();
cache.store(layer, &k, &v);
}
cache.advance();
}
assert_eq!(cache.len(), 3);
for layer in 0..num_layers {
let k = cache.get_k(layer);
let v = cache.get_v(layer);
assert_eq!(k.len(), 24); // 3 positions * hidden_dim
assert_eq!(v.len(), 24);
}
let q: Vec<f32> = vec![1.0; hidden_dim];
let output = attention_with_cache(&q, cache.get_k(0), cache.get_v(0), num_heads, head_dim);
assert_eq!(output.len(), hidden_dim);
for val in &output {
assert!(val.is_finite());
}
}
#[test]
fn test_kv_cache_incremental_attention() {
let num_heads = 2;
let head_dim = 4;
let hidden_dim = num_heads * head_dim;
let mut cache = KVCache::new(1, hidden_dim, 10);
let k0 = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
let v0 = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
cache.store(0, &k0, &v0);
cache.advance();
let q1 = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
let out1 = attention_with_cache(&q1, cache.get_k(0), cache.get_v(0), num_heads, head_dim);
let k1 = vec![0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
let v1 = vec![2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0];
cache.store(0, &k1, &v1);
cache.advance();
let q2 = vec![1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0];
let out2 = attention_with_cache(&q2, cache.get_k(0), cache.get_v(0), num_heads, head_dim);
assert_eq!(out2.len(), hidden_dim);
for val in out1.iter().chain(out2.iter()) {
assert!(val.is_finite());
}
}
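// head_dim = 4 gives a 1/2 scale, so the position-0 logit is about
// (4 * 100 * 100) / 2 = 20_000 and position 1 is its negative; a naive exp()
// would overflow, so finite output requires a max-subtracting softmax that
// collapses the weights onto V0.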
#[test]
fn test_attention_softmax_stability() {
let num_heads = 1;
let head_dim = 4;
let q = vec![100.0, 100.0, 100.0, 100.0];
let k_cache = vec![
100.0, 100.0, 100.0, 100.0, // position 0: large positive logit
-100.0, -100.0, -100.0, -100.0, // position 1: large negative logit
];
let v_cache = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
for (i, val) in output.iter().enumerate() {
assert!(val.is_finite(), "Output[{}] = {} should be finite", i, val);
}
assert!(output[0] < 2.0, "Should be close to V0 (1.0)");
}
#[test]
fn test_causal_self_attention_flow() {
let num_heads = 2;
let head_dim = 2;
let hidden_dim = num_heads * head_dim;
let seq_len = 3;
let mut k_cache = Vec::new();
let mut v_cache = Vec::new();
let mut outputs = Vec::new();
for s in 0..seq_len {
let q: Vec<f32> = (0..hidden_dim)
.map(|i| (s * hidden_dim + i) as f32 * 0.1)
.collect();
let k: Vec<f32> = (0..hidden_dim)
.map(|i| (s * hidden_dim + i) as f32 * 0.2)
.collect();
let v: Vec<f32> = (0..hidden_dim).map(|_| (s + 1) as f32).collect();
k_cache.extend_from_slice(&k);
v_cache.extend_from_slice(&v);
let attn_output = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
outputs.push(attn_output);
}
assert_eq!(outputs[0].len(), hidden_dim);
for val in &outputs[0] {
assert!(
(val - 1.0).abs() < 1e-4,
"Pos 0 should output V[0]=1.0, got {}",
val
);
}
for val in &outputs[1] {
assert!(
*val >= 0.9 && *val <= 2.1,
"Pos 1 should be between V[0] and V[1], got {}",
val
);
}
for val in &outputs[2] {
assert!(
*val >= 0.9 && *val <= 3.1,
"Pos 2 should be in range of V values, got {}",
val
);
}
}
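// Grouped-query attention: with num_heads = 4 and num_kv_heads = 2, each KV
// head serves group_size = 2 query heads, so the kv_dim-sized K row is tiled
// per group up to the full hidden_dim before attention.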
#[test]
fn test_gqa_kv_expansion() {
let num_heads = 4;
let num_kv_heads = 2;
let head_dim = 2;
let hidden_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let group_size = num_heads / num_kv_heads;
let k_raw: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
assert_eq!(k_raw.len(), kv_dim);
let k_expanded: Vec<f32> = (0..num_heads)
.flat_map(|h| {
let kv_head = h / group_size;
let start = kv_head * head_dim;
k_raw[start..start + head_dim].iter().copied()
})
.collect();
assert_eq!(k_expanded.len(), hidden_dim);
assert_eq!(k_expanded, vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 4.0]);
}
#[test]
fn test_kv_cache_layer_isolation() {
let mut cache = KVCache::new(3, 4, 8);
cache.store(0, &[1.0, 1.0, 1.0, 1.0], &[1.0, 1.0, 1.0, 1.0]);
cache.store(1, &[2.0, 2.0, 2.0, 2.0], &[2.0, 2.0, 2.0, 2.0]);
cache.store(2, &[3.0, 3.0, 3.0, 3.0], &[3.0, 3.0, 3.0, 3.0]);
cache.advance();
let k0 = cache.get_k(0);
let k1 = cache.get_k(1);
let k2 = cache.get_k(2);
assert!((k0[0] - 1.0).abs() < 1e-5);
assert!((k1[0] - 2.0).abs() < 1e-5);
assert!((k2[0] - 3.0).abs() < 1e-5);
}
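/// Builds synthetic Q4_K data. Each 144-byte super-block follows the layout
/// assumed by `Q4KWeight::new`: d (f16, 2 bytes) + dmin (f16, 2 bytes) +
/// 12 bytes of packed 6-bit scales/mins + 128 bytes of 4-bit quants covering
/// QK_K = 256 values.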
fn create_q4k_test_data(num_super_blocks: usize) -> Vec<u8> {
const SUPER_BLOCK_BYTES: usize = 144;
let mut data = Vec::with_capacity(num_super_blocks * SUPER_BLOCK_BYTES);
for _ in 0..num_super_blocks {
data.extend_from_slice(&[0x00, 0x3C]); // d = 1.0 (f16 0x3C00, little-endian)
data.extend_from_slice(&[0x00, 0x00]); // dmin = 0.0
data.extend_from_slice(&[0u8; 12]); // packed 6-bit scales/mins, all zero
for _ in 0..128 {
data.push(0x77); // two 4-bit quants (0x7, 0x7) per byte
}
}
data
}
#[test]
fn test_q4k_weight_creation() {
let data = create_q4k_test_data(1);
let weight = Q4KWeight::new(data, 256, 1).unwrap();
assert_eq!(weight.in_dim, 256);
assert_eq!(weight.out_dim, 1);
assert_eq!(weight.memory_bytes(), 144);
assert_eq!(weight.f32_equivalent_bytes(), 1024); // 256 * 1 * 4 bytes
assert!(weight.compression_ratio() > 7.0); // 1024 / 144 ≈ 7.1
}
#[test]
fn test_q4k_weight_invalid_size() {
let data = vec![0u8; 100]; // one 256-wide row needs 144 bytes
let result = Q4KWeight::new(data, 256, 1);
assert!(result.is_err());
}
#[test]
fn test_q4k_weight_matvec_dimension_mismatch() {
let data = create_q4k_test_data(1);
let weight = Q4KWeight::new(data, 256, 1).unwrap();
let input = vec![1.0f32; 128]; // weight expects in_dim = 256
let result = weight.matvec(&input);
assert!(result.is_err());
}
#[test]
fn test_q4k_weight_matvec_correct_output_size() {
let data = create_q4k_test_data(2); // 512 / 256 = 2 super-blocks per row
let weight = Q4KWeight::new(data, 512, 1).unwrap();
let input = vec![1.0f32; 512];
let output = weight.matvec(&input).unwrap();
assert_eq!(output.len(), 1);
}
#[test]
fn test_quantized_memory_stats_display() {
let stats = QuantizedMemoryStats {
quantized_weight_bytes: 1024 * 1024,
f32_equivalent_bytes: 8 * 1024 * 1024,
embedding_bytes: 512 * 1024,
compression_ratio: 8.0,
};
let display = format!("{}", stats);
assert!(display.contains("1.0 MB"));
assert!(display.contains("8.0 MB"));
assert!(display.contains("8.0x"));
}
#[test]
fn test_quantized_inference_engine_memory_stats() {
let config = GGUFConfig {
architecture: "test".to_string(),
vocab_size: 100,
hidden_dim: 256,
intermediate_dim: 512,
num_heads: 4,
num_kv_heads: 4,
num_layers: 1,
context_length: 2048,
eps: 1e-5,
rope_theta: 10000.0,
};
let qkv_data = create_q4k_test_data(768); // 256 -> 768, one super-block per row
let attn_out_data = create_q4k_test_data(256); // 256 -> 256
let ffn_up_data = create_q4k_test_data(512); // 256 -> 512
let ffn_down_data = create_q4k_test_data(256 * 2); // 512 -> 256, two super-blocks per row
let lm_head_data = create_q4k_test_data(100); // 256 -> 100
let layer = QuantizedTransformerLayer {
attn_norm_weight: vec![1.0; 256],
attn_norm_bias: None,
qkv_weight: Q4KWeight::new(qkv_data, 256, 768).unwrap(),
qkv_bias: None,
attn_output_weight: Q4KWeight::new(attn_out_data, 256, 256).unwrap(),
attn_output_bias: None,
ffn_gate_weight: None,
ffn_gate_bias: None,
ffn_up_weight: Q4KWeight::new(ffn_up_data, 256, 512).unwrap(),
ffn_up_bias: None,
ffn_down_weight: Q4KWeight::new(ffn_down_data, 512, 256).unwrap(),
ffn_down_bias: None,
ffn_norm_weight: None,
ffn_norm_bias: None,
};
let token_embedding = vec![0.0f32; 100 * 256];
let engine = QuantizedInferenceEngine::new(
config,
token_embedding,
vec![layer],
vec![1.0; 256],
None,
Q4KWeight::new(lm_head_data, 256, 100).unwrap(),
None,
);
let stats = engine.memory_stats();
assert!(stats.quantized_weight_bytes > 0);
assert!(stats.f32_equivalent_bytes > 0);
assert!(stats.embedding_bytes > 0);
assert!(stats.compression_ratio > 1.0);
assert!(stats.compression_ratio > 5.0);
assert!(stats.compression_ratio < 10.0);
}
#[test]
fn test_quantized_engine_embed() {
let config = GGUFConfig {
architecture: "test".to_string(),
vocab_size: 10,
hidden_dim: 256,
intermediate_dim: 512,
num_heads: 4,
num_kv_heads: 4,
num_layers: 0,
context_length: 2048,
eps: 1e-5,
rope_theta: 10000.0,
};
let mut token_embedding = vec![0.0f32; 10 * 256];
for token_id in 0..10 {
for j in 0..256 {
token_embedding[token_id * 256 + j] = (token_id + 1) as f32;
}
}
let lm_head_data = create_q4k_test_data(10);
let engine = QuantizedInferenceEngine::new(
config,
token_embedding,
vec![],
vec![1.0; 256],
None,
Q4KWeight::new(lm_head_data, 256, 10).unwrap(),
None,
);
assert_eq!(engine.config.vocab_size, 10);
}
#[test]
fn test_q4k_weight_compression_ratio() {
let data = create_q4k_test_data(4); // 1024 / 256 = 4 super-blocks per row
let weight = Q4KWeight::new(data, 1024, 1).unwrap();
let ratio = weight.compression_ratio();
assert!(ratio > 7.0, "Compression ratio {} should be > 7.0", ratio);
assert!(ratio < 8.0, "Compression ratio {} should be < 8.0", ratio);
}
#[test]
fn test_quantized_transformer_layer_fields() {
let qkv_data = create_q4k_test_data(3);
let attn_out_data = create_q4k_test_data(1);
let gate_data = create_q4k_test_data(2);
let up_data = create_q4k_test_data(2);
let down_data = create_q4k_test_data(4);
let layer = QuantizedTransformerLayer {
attn_norm_weight: vec![1.0; 256],
attn_norm_bias: Some(vec![0.0; 256]),
qkv_weight: Q4KWeight::new(qkv_data, 256, 3).unwrap(),
qkv_bias: Some(vec![0.0; 768]),
attn_output_weight: Q4KWeight::new(attn_out_data, 256, 1).unwrap(),
attn_output_bias: Some(vec![0.0; 256]),
ffn_gate_weight: Some(Q4KWeight::new(gate_data, 256, 2).unwrap()),
ffn_gate_bias: Some(vec![0.0; 512]),
ffn_up_weight: Q4KWeight::new(up_data, 256, 2).unwrap(),
ffn_up_bias: Some(vec![0.0; 512]),
ffn_down_weight: Q4KWeight::new(down_data, 512, 2).unwrap(),
ffn_down_bias: Some(vec![0.0; 256]),
ffn_norm_weight: Some(vec![1.0; 256]),
ffn_norm_bias: Some(vec![0.0; 256]),
};
assert_eq!(layer.attn_norm_weight.len(), 256);
assert!(layer.attn_norm_bias.is_some());
assert!(layer.ffn_gate_weight.is_some());
assert!(layer.ffn_norm_weight.is_some());
}
#[test]
fn test_optimized_kv_cache_creation() {
let cache = OptimizedKVCache::new(2, 64, 512);
assert_eq!(cache.num_layers, 2);
assert_eq!(cache.hidden_dim, 64);
assert_eq!(cache.max_seq_len, 512);
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
#[test]
fn test_optimized_kv_cache_store_and_get() {
let mut cache = OptimizedKVCache::new(1, 4, 8);
let k1 = vec![1.0, 2.0, 3.0, 4.0];
let v1 = vec![0.1, 0.2, 0.3, 0.4];
cache.store(0, &k1, &v1);
cache.advance();
assert_eq!(cache.len(), 1);
let k_cached = cache.get_k(0);
assert_eq!(k_cached.len(), 4);
assert!((k_cached[0] - 1.0).abs() < 1e-6);
assert!((k_cached[3] - 4.0).abs() < 1e-6);
let v_cached = cache.get_v_transposed(0);
assert_eq!(v_cached.len(), 4);
assert!((v_cached[0] - 0.1).abs() < 1e-6);
}
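// OptimizedKVCache stores V transposed as v_t[dim * seq_len + pos], so each
// dimension's values across positions are contiguous and the position-weighted
// sum in attention becomes a sequential read.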
#[test]
fn test_optimized_kv_cache_v_transpose_layout() {
let mut cache = OptimizedKVCache::new(1, 4, 8);
cache.store(0, &[1.0; 4], &[1.0, 2.0, 3.0, 4.0]);
cache.advance();
cache.store(0, &[1.0; 4], &[5.0, 6.0, 7.0, 8.0]);
cache.advance();
cache.store(0, &[1.0; 4], &[9.0, 10.0, 11.0, 12.0]);
cache.advance();
let v_t = cache.get_v_transposed(0);
assert_eq!(v_t.len(), 12);
assert!((v_t[0] - 1.0).abs() < 1e-6); // dim 0, pos 0
assert!((v_t[1] - 5.0).abs() < 1e-6); // dim 0, pos 1
assert!((v_t[2] - 9.0).abs() < 1e-6); // dim 0, pos 2
assert!((v_t[3] - 2.0).abs() < 1e-6); // dim 1, pos 0
assert!((v_t[4] - 6.0).abs() < 1e-6); // dim 1, pos 1
}
#[test]
fn test_attention_with_transposed_v_correctness() {
let hidden_dim = 8;
let num_heads = 2;
let head_dim = 4;
let seq_len = 3;
let q: Vec<f32> = (0..hidden_dim).map(|i| i as f32 * 0.1).collect();
let k_cache: Vec<f32> = (0..seq_len * hidden_dim)
.map(|i| (i % 7) as f32 * 0.1)
.collect();
let v_cache: Vec<f32> = (0..seq_len * hidden_dim)
.map(|i| (i % 5) as f32 * 0.1)
.collect();
let original = attention_with_cache(&q, &k_cache, &v_cache, num_heads, head_dim);
let mut v_transposed = vec![0.0f32; hidden_dim * seq_len];
for pos in 0..seq_len {
for dim in 0..hidden_dim {
v_transposed[dim * seq_len + pos] = v_cache[pos * hidden_dim + dim];
}
}
let optimized =
attention_with_transposed_v(&q, &k_cache, &v_transposed, num_heads, head_dim, seq_len);
assert_eq!(original.len(), optimized.len());
for i in 0..original.len() {
assert!(
(original[i] - optimized[i]).abs() < 1e-5,
"Mismatch at {}: {} vs {}",
i,
original[i],
optimized[i]
);
}
}
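// Mirrors test_attention_with_cache_empty: with nothing cached, the transposed
// variant should also return q unchanged.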
#[test]
fn test_attention_with_transposed_v_empty_cache() {
let q = vec![1.0, 2.0, 3.0, 4.0];
let k_cache: Vec<f32> = vec![];
let v_transposed: Vec<f32> = vec![];
let result = attention_with_transposed_v(&q, &k_cache, &v_transposed, 2, 2, 0);
assert_eq!(result, q);
}
#[test]
fn test_optimized_kv_cache_reset() {
let mut cache = OptimizedKVCache::new(1, 4, 8);
cache.store(0, &[1.0; 4], &[1.0; 4]);
cache.advance();
assert_eq!(cache.len(), 1);
cache.reset();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
}
#[test]
fn test_optimized_kv_cache_multiple_layers() {
let mut cache = OptimizedKVCache::new(3, 4, 8);
cache.store(0, &[1.0; 4], &[1.0; 4]);
cache.store(1, &[2.0; 4], &[2.0; 4]);
cache.store(2, &[3.0; 4], &[3.0; 4]);
cache.advance();
let k0 = cache.get_k(0);
let k1 = cache.get_k(1);
let k2 = cache.get_k(2);
assert!((k0[0] - 1.0).abs() < 1e-6);
assert!((k1[0] - 2.0).abs() < 1e-6);
assert!((k2[0] - 3.0).abs() < 1e-6);
}
#[test]
fn test_thread_config_auto() {
let config = ThreadConfig::auto();
let num_cpus = rayon::current_num_threads();
assert_eq!(config.n_threads_batch, num_cpus);
assert_eq!(config.n_threads_decode, (num_cpus / 2).max(1));
}
#[test]
fn test_thread_config_new() {
let config = ThreadConfig::new(8, 4);
assert_eq!(config.n_threads_batch, 8);
assert_eq!(config.n_threads_decode, 4);
}
#[test]
fn test_thread_config_new_clamps_to_one() {
let config = ThreadConfig::new(0, 0);
assert_eq!(config.n_threads_batch, 1);
assert_eq!(config.n_threads_decode, 1);
}
#[test]
fn test_thread_config_threads_for() {
let config = ThreadConfig::new(8, 2);
assert_eq!(config.threads_for(true), 8);
assert_eq!(config.threads_for(false), 2);
}
#[test]
fn test_thread_config_default() {
let config = ThreadConfig::default();
let auto = ThreadConfig::auto();
assert_eq!(config.n_threads_batch, auto.n_threads_batch);
assert_eq!(config.n_threads_decode, auto.n_threads_decode);
}
#[test]
fn test_inference_mode_equality() {
assert_eq!(InferenceMode::Prefill, InferenceMode::Prefill);
assert_eq!(InferenceMode::Decode, InferenceMode::Decode);
assert_ne!(InferenceMode::Prefill, InferenceMode::Decode);
}
#[test]
fn test_inference_mode_debug() {
let prefill = InferenceMode::Prefill;
let decode = InferenceMode::Decode;
assert!(format!("{:?}", prefill).contains("Prefill"));
assert!(format!("{:?}", decode).contains("Decode"));
}
#[test]
fn test_inference_mode_clone() {
let original = InferenceMode::Prefill;
let cloned = original;
assert_eq!(original, cloned);
}
#[test]
fn test_thread_config_with_inference_mode() {
let config = ThreadConfig::new(16, 4);
let mode = InferenceMode::Prefill;
let threads = config.threads_for(mode == InferenceMode::Prefill);
assert_eq!(threads, 16);
let mode = InferenceMode::Decode;
let threads = config.threads_for(mode == InferenceMode::Prefill);
assert_eq!(threads, 4);
}
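// All-equal large-negative inputs must still yield a valid distribution:
// max-subtraction maps every logit to 0.0, giving uniform weights that sum to 1.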
#[test]
fn test_simd_operations_edge_cases_additional() {
let mut large_neg = vec![-1000.0, -1000.0, -1000.0];
simd_softmax(&mut large_neg);
let sum: f32 = large_neg.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
#[test]
fn test_optimized_kv_cache_v_transposed_multi() {
let mut cache = OptimizedKVCache::new(1, 4, 8);
cache.store(0, &[0.0; 4], &[1.0, 2.0, 3.0, 4.0]);
cache.advance();
cache.store(0, &[0.0; 4], &[5.0, 6.0, 7.0, 8.0]);
cache.advance();
let v_t = cache.get_v_transposed(0);
assert_eq!(v_t.len(), 4 * 2);
assert!((v_t[0] - 1.0).abs() < 1e-5); // dim 0, pos 0
assert!((v_t[1] - 5.0).abs() < 1e-5); // dim 0, pos 1
}
#[test]
fn test_configure_thread_pool_call() {
let result = configure_thread_pool(4);
// The global rayon pool may already be initialized by another test, so only
// exercise the call path and ignore the result.
let _ = result;
}
#[test]
fn test_q4k_weight_multiple_output_rows() {
let data = create_q4k_test_data(4); // 4 output rows, one super-block each
let weight = Q4KWeight::new(data, 256, 4).unwrap();
assert_eq!(weight.out_dim, 4);
let input = vec![1.0f32; 256];
let output = weight.matvec(&input).unwrap();
assert_eq!(output.len(), 4);
}
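// Sanity sketch: with d = 1.0 and uniform 0x7 nibbles from
// `create_q4k_test_data`, the fused matvec should at least produce finite
// values; no exact value is pinned, since that depends on the kernel's
// scale/min decoding.
#[test]
fn test_q4k_weight_matvec_finite_values() {
let data = create_q4k_test_data(1);
let weight = Q4KWeight::new(data, 256, 1).unwrap();
let input = vec![0.5f32; 256];
let output = weight.matvec(&input).unwrap();
assert!(output.iter().all(|v| v.is_finite()));
}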
}