use std::collections::HashMap;
use std::sync::OnceLock;
/// Tolerance value of exactly zero.
///
/// NOTE(review): presumably passed as the `tolerance` argument of
/// `vectors_match` when comparing two scalar code paths — confirm with callers.
#[allow(dead_code)]
pub const SCALAR_TOLERANCE: f32 = 0.0;

/// Tolerance value of `1e-6`.
///
/// NOTE(review): by its name, intended for SIMD-vs-scalar comparisons — confirm
/// with callers.
pub const SIMD_TOLERANCE: f32 = 1e-6;

/// Tolerance value of `1e-5`.
///
/// NOTE(review): by its name, intended for GPU-vs-scalar comparisons — confirm
/// with callers.
#[allow(dead_code)]
pub const GPU_TOLERANCE: f32 = 1e-5;
/// Process-wide store of named baseline vectors, created lazily on first access.
#[allow(dead_code)]
static GOLDEN_BASELINES: OnceLock<HashMap<String, Vec<f32>>> = OnceLock::new();

/// Returns the shared baseline map, initialising it to an empty map on the
/// first call.
///
/// Every call returns a reference to the same map. Nothing in this block
/// inserts entries, so the map is empty as far as this code shows.
#[allow(dead_code)]
pub fn get_golden_baselines() -> &'static HashMap<String, Vec<f32>> {
    let baselines: &'static HashMap<String, Vec<f32>> =
        GOLDEN_BASELINES.get_or_init(HashMap::default);
    baselines
}
/// Scalar RMS normalisation of `x` with a per-element `weight`.
///
/// Computes `rms = sqrt(mean(x^2) + eps)` over the whole of `x`, then returns
/// `(x[i] / rms) * weight[i]` for each index present in both slices. If the
/// slices differ in length the output has the length of the shorter one.
pub fn scalar_rmsnorm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_square = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let rms = (mean_square + eps).sqrt();

    let mut normed = Vec::with_capacity(x.len().min(weight.len()));
    for (&value, &gain) in x.iter().zip(weight) {
        // Divide first, then scale, to keep the same rounding as a plain
        // `(x / rms) * w` evaluation.
        normed.push((value / rms) * gain);
    }
    normed
}
/// Scalar SiLU activation: returns `x[i] * sigmoid(x[i])` for every element,
/// where `sigmoid(v) = 1 / (1 + exp(-v))`.
pub fn scalar_silu(x: &[f32]) -> Vec<f32> {
    let mut activated = Vec::with_capacity(x.len());
    for &value in x {
        let sigmoid = 1.0 / (1.0 + (-value).exp());
        activated.push(value * sigmoid);
    }
    activated
}
/// Scalar softmax over the whole slice.
///
/// The maximum element is subtracted before exponentiation so that large
/// inputs do not overflow; each exponential is then divided by their sum.
/// An empty input yields an empty output.
pub fn scalar_softmax(x: &[f32]) -> Vec<f32> {
    let mut peak = f32::NEG_INFINITY;
    for &value in x {
        peak = peak.max(value);
    }

    let mut probs: Vec<f32> = x.iter().map(|&value| (value - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= total;
    }
    probs
}
/// Scalar rotary embedding that pairs element `i` with element `i + len/2`.
///
/// For each `i` in `0..len/2`:
/// `out[i] = x[i] * cos[i] - x[i + half] * sin[i]` and
/// `out[i + half] = x[i] * sin[i] + x[i + half] * cos[i]`.
/// When `x` has odd length the final output element is left at `0.0`.
///
/// # Panics
///
/// Panics if `freqs_cos` or `freqs_sin` has fewer than `x.len() / 2` elements.
pub fn scalar_rope(x: &[f32], freqs_cos: &[f32], freqs_sin: &[f32]) -> Vec<f32> {
    let pair_count = x.len() / 2;
    let mut rotated = vec![0.0f32; x.len()];

    {
        // `front` holds the first element of each pair, `back` the second.
        let (front, back) = rotated.split_at_mut(pair_count);
        for (k, (lo, hi)) in front.iter_mut().zip(back.iter_mut()).enumerate() {
            let (a, b) = (x[k], x[k + pair_count]);
            let c = freqs_cos[k];
            let s = freqs_sin[k];
            *lo = a * c - b * s;
            *hi = a * s + b * c;
        }
    }

    rotated
}
/// Returns a copy of `scores` in which every entry at flat index
/// `i * seq_len + j` with `j > i` (and `i, j < seq_len`) is replaced by
/// negative infinity; all other entries are copied unchanged.
///
/// # Panics
///
/// Panics if `scores` is too short to contain one of the entries that must be
/// overwritten.
pub fn scalar_causal_mask(scores: &[f32], seq_len: usize) -> Vec<f32> {
    let mut masked = scores.to_vec();
    // The last row has no column to the right of the diagonal, so it is never
    // touched; skipping it keeps the function from indexing past a short input
    // that only the untouched row would exceed.
    for row in 0..seq_len.saturating_sub(1) {
        let row_start = row * seq_len;
        masked[row_start + row + 1..row_start + seq_len].fill(f32::NEG_INFINITY);
    }
    masked
}
/// Unpacks two 4-bit values from every byte and maps each to
/// `(nibble - zero_point) * scale`.
///
/// For each input byte the low nibble (`byte & 0x0F`) is emitted first,
/// followed by the high nibble (`byte >> 4`), so the output has exactly
/// `2 * quantized.len()` elements. A single `scale` and `zero_point` apply to
/// the whole input.
pub fn scalar_q4k_dequant(quantized: &[u8], scale: f32, zero_point: f32) -> Vec<f32> {
    // One up-front allocation for the whole output, rather than a temporary
    // heap-allocated `Vec` for every input byte.
    let mut dequantized = Vec::with_capacity(quantized.len() * 2);
    for &byte in quantized {
        let low = f32::from(byte & 0x0F);
        let high = f32::from(byte >> 4);
        dequantized.push((low - zero_point) * scale);
        dequantized.push((high - zero_point) * scale);
    }
    dequantized
}
/// Builds cosine and sine tables of length `dim / 2`.
///
/// Entry `i` is the cosine (respectively sine) of
/// `1 / base^(2 * i / dim)`. No position factor is applied to the angle.
/// Returns `(cos_table, sin_table)`.
pub fn compute_rope_freqs(dim: usize, base: f32) -> (Vec<f32>, Vec<f32>) {
    (0..dim / 2)
        .map(|pair| {
            let exponent = 2.0 * pair as f32 / dim as f32;
            let angle = 1.0 / base.powf(exponent);
            (angle.cos(), angle.sin())
        })
        .unzip()
}
/// Small deterministic pseudo-random generator built on a 64-bit xorshift
/// update (shifts 13, 7, 17).
///
/// The same seed always yields the same sequence. Not suitable for anything
/// security-sensitive.
#[derive(Debug, Clone)]
pub struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// State substituted for a zero seed. Zero is a fixed point of the
    /// xorshift update, so without this every sample would be `-1.0`.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator from `seed`.
    ///
    /// Any non-zero seed is used as the initial state unchanged. A seed of `0`
    /// is replaced by a fixed non-zero constant so that the generator does not
    /// get stuck.
    pub fn new(seed: u64) -> Self {
        let state = if seed == 0 {
            Self::ZERO_SEED_REPLACEMENT
        } else {
            seed
        };
        Self { state }
    }

    /// Advances the state and returns a sample in the closed range
    /// `[-1.0, 1.0]`.
    pub fn next_f32(&mut self) -> f32 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        // Scale the 64-bit state to [0, 1], then stretch to [-1, 1].
        (self.state as f32 / u64::MAX as f32) * 2.0 - 1.0
    }

    /// Returns `n` consecutive samples from [`SimpleRng::next_f32`].
    pub fn gen_vec(&mut self, n: usize) -> Vec<f32> {
        (0..n).map(|_| self.next_f32()).collect()
    }
}
/// Checks that `a` and `b` agree element-wise within an absolute `tolerance`.
///
/// Returns `true` when the slices have equal length and the largest
/// `|a[i] - b[i]|` does not exceed `tolerance`. A PASS line is then printed to
/// stdout. On failure a diagnostic prefixed with `name` is printed to stderr
/// and `false` is returned.
///
/// Non-finite values:
/// * elements that compare equal (including two infinities of the same sign)
///   contribute a difference of zero;
/// * a NaN at the same index in both slices is treated as agreement;
/// * a NaN in only one of the two slices is always a mismatch. A plain
///   `diff > max_diff` comparison is false for NaN and would let it through.
pub fn vectors_match(a: &[f32], b: &[f32], tolerance: f32, name: &str) -> bool {
    if a.len() != b.len() {
        eprintln!("{name}: length mismatch {} vs {}", a.len(), b.len());
        return false;
    }

    let mut max_diff = 0.0f32;
    let mut max_diff_idx = 0;
    for (i, (&ai, &bi)) in a.iter().zip(b.iter()).enumerate() {
        let diff = if ai == bi || (ai.is_nan() && bi.is_nan()) {
            0.0
        } else {
            (ai - bi).abs()
        };
        // Equal infinities and paired NaNs were handled above, so a NaN
        // difference here means exactly one side is NaN.
        if diff.is_nan() {
            eprintln!("{name}: NaN mismatch at index {i} ({ai} vs {bi})");
            return false;
        }
        if diff > max_diff {
            max_diff = diff;
            max_diff_idx = i;
        }
    }

    if max_diff > tolerance {
        eprintln!(
            "{name}: max diff {max_diff:.2e} at index {max_diff_idx} exceeds tolerance {tolerance:.2e}"
        );
        return false;
    }

    println!("{name}: PASS (max_diff={max_diff:.2e})");
    true
}