use trueno::Vector;
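
/// Applies rotary position embeddings (RoPE) to `x` in place, using the
/// split-half pairing: element `i` rotates with element `i + head_dim / 2`
/// within each head. `position` is the absolute token position and
/// `rope_theta` is the frequency base (commonly 10000.0).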
pub(super) fn apply_rope_inline(
    x: &mut [f32],
    num_heads: usize,
    head_dim: usize,
    rope_theta: f32,
    position: usize,
) {
    let half_dim = head_dim / 2;
    let head_dim_f32 = head_dim as f32;
    let pos_f32 = position as f32;
    for h in 0..num_heads {
        let head_start = h * head_dim;
        let idx2_start = head_start + half_dim;
        for i in 0..half_dim {
            // Per-dimension frequency: theta^(-2i / head_dim).
            let freq = 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim_f32);
            let angle = pos_f32 * freq;
            let (sin_val, cos_val) = angle.sin_cos();
            // Rotate the (x1, x2) pair by `angle`.
            let x1 = x[head_start + i];
            let x2 = x[idx2_start + i];
            x[head_start + i] = x1 * cos_val - x2 * sin_val;
            x[idx2_start + i] = x1 * sin_val + x2 * cos_val;
        }
    }
}
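
/// Grouped-query attention over a flat KV cache. `k` and `v` are laid out
/// as `[kv_len, num_kv_heads * head_dim]`; each group of
/// `num_heads / num_kv_heads` query heads shares one KV head.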
#[allow(clippy::too_many_arguments)]
pub(super) fn gqa_multihead_attention(
    q: &[f32], k: &[f32], v: &[f32], kv_len: usize,
    num_heads: usize, num_kv_heads: usize, head_dim: usize,
) -> Vec<f32> {
    let hidden_dim = num_heads * head_dim;
    let kv_dim = num_kv_heads * head_dim;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let heads_per_kv = num_heads / num_kv_heads;
    let mut output = vec![0.0; hidden_dim];
    for h in 0..num_heads {
        let q_head = &q[h * head_dim..(h + 1) * head_dim];
        let q_vec = Vector::from_slice(q_head);
        // Map this query head onto its shared KV head.
        let kv_head = h / heads_per_kv;
        let mut scores = Vec::with_capacity(kv_len);
        for pos in 0..kv_len {
            let k_offset = pos * kv_dim + kv_head * head_dim;
            let cached_key = &k[k_offset..k_offset + head_dim];
            let k_vec = Vector::from_slice(cached_key);
            let score = q_vec.dot(&k_vec).unwrap_or(0.0) * scale;
            scores.push(score);
        }
        let scores_vec = Vector::from_slice(&scores);
        // Prefer the vectorized softmax; fall back to a scalar
        // max-subtracted softmax if it fails.
        let attn_weights: Vec<f32> = scores_vec.softmax().map_or_else(
            |_| {
                let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
                let sum_exp: f32 = exp_scores.iter().sum();
                exp_scores.iter().map(|&e| e / sum_exp).collect()
            },
            |v| v.as_slice().to_vec(),
        );
        // Weighted sum of the cached value vectors.
        for (pos, &weight) in attn_weights.iter().enumerate() {
            let v_offset = pos * kv_dim + kv_head * head_dim;
            let v_head = &v[v_offset..v_offset + head_dim];
            for d in 0..head_dim {
                output[h * head_dim + d] += weight * v_head[d];
            }
        }
    }
    output
}
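
/// RMS-style normalization with a learned scale and bias. Despite the
/// name, this does not subtract the mean: each row is divided by
/// `sqrt(mean(x^2) + eps)`, then scaled by `weight` and shifted by `bias`.
/// Trailing elements beyond a whole multiple of `hidden_dim` are ignored.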
#[allow(clippy::cast_precision_loss)]
pub(crate) fn layer_norm_static(
    input: &[f32],
    weight: &[f32],
    bias: &[f32],
    hidden_dim: usize,
    eps: f32,
) -> Vec<f32> {
    let num_rows = input.len() / hidden_dim;
    let mut output = Vec::with_capacity(input.len());
    for row in 0..num_rows {
        let start = row * hidden_dim;
        let row_data = &input[start..start + hidden_dim];
        // Root mean square of the row, with eps for numerical stability.
        let sum_sq: f32 = row_data.iter().map(|&x| x * x).sum();
        let rms = (sum_sq / hidden_dim as f32 + eps).sqrt();
        for (i, &x) in row_data.iter().enumerate() {
            let normalized = x / rms;
            output.push(normalized * weight[i] + bias[i]);
        }
    }
    output
}
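
/// Greedy token selection. The logits are temperature-scaled and
/// softmax-normalized, but the function always returns the index of the
/// highest-probability token, so `top_k` does not change the result
/// (index 0 is returned if `top_k` is 0). Assumes a positive `temperature`.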
pub(super) fn sample_topk(logits: &[f32], temperature: f32, top_k: usize) -> usize {
    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
    // Max-subtracted softmax for numerical stability.
    let max_logit = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exp_logits: Vec<f32> = scaled.iter().map(|&x| (x - max_logit).exp()).collect();
    let sum: f32 = exp_logits.iter().sum();
    let probs: Vec<f32> = exp_logits.iter().map(|&x| x / sum).collect();
    // Sort by descending probability, keep top_k, take the best candidate.
    let mut indexed: Vec<(usize, f32)> = probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    indexed.truncate(top_k);
    indexed.first().map_or(0, |&(idx, _)| idx)
}
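
/// Transposes a row-major `rows x cols` matrix; a thin wrapper over the
/// shared `contract_gate` implementation.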
pub(super) fn transpose_weights(weights: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    crate::contract_gate::transpose_f32(weights, rows, cols)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_norm_static_single_row() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let weight = vec![1.0, 1.0, 1.0, 1.0];
        let bias = vec![0.0, 0.0, 0.0, 0.0];
        let eps = 1e-5;
        let output = layer_norm_static(&input, &weight, &bias, 4, eps);
        assert_eq!(output.len(), 4);
        let sum_sq: f32 = input.iter().map(|x| x * x).sum();
        let rms = (sum_sq / 4.0 + eps).sqrt();
        for (i, &x) in input.iter().enumerate() {
            let expected = x / rms;
            assert!((output[i] - expected).abs() < 1e-5);
        }
    }

    #[test]
    fn test_layer_norm_static_with_weight_bias() {
        let input = vec![2.0, 2.0, 2.0, 2.0];
        let weight = vec![2.0, 2.0, 2.0, 2.0];
        let bias = vec![0.5, 0.5, 0.5, 0.5];
        let eps = 1e-5;
        let output = layer_norm_static(&input, &weight, &bias, 4, eps);
        for &val in &output {
            assert!((val - 2.5).abs() < 0.01);
        }
    }
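
    #[test]
    fn test_layer_norm_static_two_rows_independent() {
        // Rows are normalized independently: [1,1,1,1] and [3,3,3,3]
        // differ only in magnitude, so both normalize to ~1.0 everywhere.
        let input = vec![1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0];
        let weight = vec![1.0; 4];
        let bias = vec![0.0; 4];
        let output = layer_norm_static(&input, &weight, &bias, 4, 1e-5);
        assert_eq!(output.len(), 8);
        for &val in &output {
            assert!((val - 1.0).abs() < 1e-3);
        }
    }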

    #[test]
    fn test_transpose_weights() {
        let weights = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let transposed = transpose_weights(&weights, 2, 3);
        assert_eq!(transposed, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_sample_topk_deterministic() {
        let logits = vec![1.0, 5.0, 2.0, 0.5];
        let result = sample_topk(&logits, 1.0, 1);
        assert_eq!(result, 1);
    }
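
    #[test]
    fn test_sample_topk_temperature_invariant() {
        // Dividing by a positive temperature is monotonic, so the greedy
        // pick is the same at any positive temperature.
        let logits = vec![1.0, 5.0, 2.0, 0.5];
        assert_eq!(sample_topk(&logits, 0.5, 4), 1);
        assert_eq!(sample_topk(&logits, 2.0, 4), 1);
    }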

    #[test]
    fn test_apply_rope_inline() {
        // At position 0 every rotation angle is zero, so x is unchanged.
        let mut x = vec![1.0, 0.0, 0.0, 1.0];
        apply_rope_inline(&mut x, 1, 4, 10000.0, 0);
        assert!((x[0] - 1.0).abs() < 1e-5);
        assert!((x[1] - 0.0).abs() < 1e-5);
    }
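
    #[test]
    fn test_apply_rope_inline_preserves_norm() {
        // RoPE applies a pure rotation to each (x[i], x[i + half]) pair,
        // so the L2 norm of the head is preserved at any position.
        let mut x = vec![0.5, -1.0, 2.0, 0.25];
        let before: f32 = x.iter().map(|v| v * v).sum();
        apply_rope_inline(&mut x, 1, 4, 10000.0, 7);
        let after: f32 = x.iter().map(|v| v * v).sum();
        assert!((before - after).abs() < 1e-4);
    }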

    #[test]
    fn test_gqa_multihead_attention_simple() {
        let q = vec![1.0, 0.0, 0.0, 1.0];
        let k = vec![1.0, 0.0, 0.0, 1.0];
        let v = vec![1.0, 2.0, 3.0, 4.0];
        let output = gqa_multihead_attention(&q, &k, &v, 1, 2, 2, 2);
        assert_eq!(output.len(), 4);
        assert!((output[0] - 1.0).abs() < 1e-5);
        assert!((output[1] - 2.0).abs() < 1e-5);
    }
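
    #[test]
    fn test_gqa_multihead_attention_shared_kv_head() {
        // Two query heads, one KV head: both heads map to kv_head 0, and
        // with a single cached position the softmax weight is 1.0, so each
        // head's output is exactly that position's value vector.
        let q = vec![1.0, 0.0, 0.0, 1.0];
        let k = vec![1.0, 0.0];
        let v = vec![5.0, 6.0];
        let output = gqa_multihead_attention(&q, &k, &v, 1, 2, 1, 2);
        assert_eq!(output.len(), 4);
        for (d, &expected) in [5.0, 6.0, 5.0, 6.0].iter().enumerate() {
            assert!((output[d] - expected).abs() < 1e-5);
        }
    }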
}