aprender-compute 0.32.0

High-performance SIMD compute library with GPU support, LLM inference engine, and GGUF model loading (was: trueno)
Documentation
//! Shared helpers for Pixel FKR tests

use std::collections::HashMap;
use std::sync::OnceLock;

// Tolerance constants (SPEC Section 3.5.1)
// Reserved for cross-backend comparison tests (simd-pixel-fkr, wgpu-pixel-fkr),
// which is why some of them carry a dead_code allowance.
// NOTE(review): presumably these are passed as the `tolerance` argument of
// `vectors_match`, which compares absolute differences. The trailing comments
// speak in ULPs, but the values are plain f32 magnitudes - confirm against the SPEC.
#[allow(dead_code)]
pub const SCALAR_TOLERANCE: f32 = 0.0; // Exact match for baseline
pub const SIMD_TOLERANCE: f32 = 1e-6; // +-1 ULP for SIMD rounding
#[allow(dead_code)]
pub const GPU_TOLERANCE: f32 = 1e-5; // +-2 ULP for GPU FP variance

// Process-wide table of golden baselines (generated by scalar tests), keyed by string.
// Reserved for cross-test baseline sharing in multi-suite runs, hence the
// dead_code allowance.
#[allow(dead_code)]
static GOLDEN_BASELINES: OnceLock<HashMap<String, Vec<f32>>> = OnceLock::new();

/// Returns the shared golden-baseline table, creating it empty on first access.
///
/// Every call yields a reference to the same map. The map is handed out as a
/// shared reference, so this accessor by itself offers no way to add entries.
#[allow(dead_code)]
pub fn get_golden_baselines() -> &'static HashMap<String, Vec<f32>> {
    GOLDEN_BASELINES.get_or_init(HashMap::default)
}

// ============================================================================
// REALIZER CORE OPERATIONS (Scalar implementations as ground truth)
// ============================================================================

/// RMS Norm (LLaMA normalization) - scalar reference implementation.
///
/// Each element of `x` is divided by `sqrt(mean(x^2) + eps)` and then multiplied
/// by the weight at the same index. The output has as many elements as the
/// shorter of `x` and `weight`.
pub fn scalar_rmsnorm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let rms = (mean_sq + eps).sqrt();

    let mut normed = Vec::with_capacity(x.len().min(weight.len()));
    for (&xi, &wi) in x.iter().zip(weight) {
        // Divide first, then scale: keeps the rounding order of the reference.
        normed.push((xi / rms) * wi);
    }
    normed
}

/// SiLU activation (LLaMA FFN) - scalar reference implementation.
///
/// Computes `x * sigmoid(x)` element-wise, with the sigmoid written as
/// `1 / (1 + exp(-x))`.
pub fn scalar_silu(x: &[f32]) -> Vec<f32> {
    let mut activated = Vec::with_capacity(x.len());
    for &value in x {
        let sigmoid = 1.0 / (1.0 + (-value).exp());
        activated.push(value * sigmoid);
    }
    activated
}

/// Softmax - scalar reference implementation with numerical stability.
///
/// The largest input is subtracted from every element before exponentiating,
/// then each exponential is divided by their sum.
pub fn scalar_softmax(x: &[f32]) -> Vec<f32> {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut probs: Vec<f32> = x.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}

/// RoPE (Rotary Position Embedding) - scalar reference implementation.
///
/// Element `i` of the first half of `x` is paired with element `i + half` and
/// the pair is rotated using `freqs_cos[i]` / `freqs_sin[i]`. When `x` has an
/// odd length the final element is not part of any pair and stays `0.0` in the
/// output.
///
/// # Panics
/// Panics if either frequency table holds fewer than `x.len() / 2` entries.
pub fn scalar_rope(x: &[f32], freqs_cos: &[f32], freqs_sin: &[f32]) -> Vec<f32> {
    let half = x.len() / 2;
    let (front, back) = x.split_at(half);
    // Slicing (rather than relying on zip truncation) keeps the panic on
    // frequency tables that are too short.
    let cos_table = &freqs_cos[..half];
    let sin_table = &freqs_sin[..half];

    let mut rotated = vec![0.0f32; x.len()];
    let (out_front, out_back) = rotated.split_at_mut(half);

    let inputs = front.iter().zip(back).zip(cos_table.iter().zip(sin_table));
    let outputs = out_front.iter_mut().zip(out_back.iter_mut());
    for ((dst_lo, dst_hi), ((&lo, &hi), (&c, &s))) in outputs.zip(inputs) {
        *dst_lo = lo * c - hi * s;
        *dst_hi = lo * s + hi * c;
    }

    rotated
}

/// Causal mask application - scalar reference implementation.
///
/// Treats `scores` as a row-major `seq_len` x `seq_len` matrix and returns a
/// copy in which every entry strictly to the right of the diagonal is replaced
/// by negative infinity. Elements beyond the matrix are copied unchanged.
///
/// # Panics
/// Panics if `scores` is too short to contain a row that needs masking.
pub fn scalar_causal_mask(scores: &[f32], seq_len: usize) -> Vec<f32> {
    let mut masked = scores.to_vec();
    // The last row has nothing to the right of its diagonal, so only the rows
    // before it are touched.
    for row in 0..seq_len.saturating_sub(1) {
        let row_start = row * seq_len;
        masked[row_start + row + 1..row_start + seq_len].fill(f32::NEG_INFINITY);
    }
    masked
}

/// Q4_K block dequantization (simplified) - scalar reference implementation.
///
/// Every input byte carries two 4-bit values. For each byte the low nibble is
/// emitted first, then the high nibble, each mapped to
/// `(nibble - zero_point) * scale`. The output therefore has twice as many
/// elements as `quantized`.
pub fn scalar_q4k_dequant(quantized: &[u8], scale: f32, zero_point: f32) -> Vec<f32> {
    // One up-front allocation instead of a temporary Vec per input byte.
    let mut dequantized = Vec::with_capacity(quantized.len() * 2);
    for &byte in quantized {
        let low = f32::from(byte & 0x0F);
        // Shifting a u8 right by 4 already leaves only the upper nibble.
        let high = f32::from(byte >> 4);
        dequantized.push((low - zero_point) * scale);
        dequantized.push((high - zero_point) * scale);
    }
    dequantized
}

/// Compute RoPE frequencies.
///
/// For each `i` in `0..dim / 2` the value `1 / base^(2i / dim)` is computed and
/// its cosine and sine are returned as `(cosines, sines)`; both vectors have
/// `dim / 2` entries.
pub fn compute_rope_freqs(dim: usize, base: f32) -> (Vec<f32>, Vec<f32>) {
    (0..dim / 2)
        .map(|pair| {
            let exponent = 2.0 * pair as f32 / dim as f32;
            let angle = 1.0 / base.powf(exponent);
            (angle.cos(), angle.sin())
        })
        .unzip()
}

/// Simple deterministic RNG for test data (xorshift64).
///
/// The same seed always produces the same sequence.
pub struct SimpleRng {
    state: u64,
}

impl SimpleRng {
    /// State used in place of a zero seed.
    const ZERO_SEED_FALLBACK: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator from `seed`.
    ///
    /// Zero is a fixed point of the xorshift update (every shift and XOR of
    /// zero is zero), which would make the generator return `-1.0` forever, so
    /// a zero seed is swapped for a fixed non-zero constant. All other seeds
    /// are used as given.
    pub fn new(seed: u64) -> Self {
        let state = if seed == 0 {
            Self::ZERO_SEED_FALLBACK
        } else {
            seed
        };
        Self { state }
    }

    /// Advances the generator and returns the next value in `[-1, 1]`.
    pub fn next_f32(&mut self) -> f32 {
        // xorshift64
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        // Scale the state to [0, 1], then stretch to [-1, 1].
        (self.state as f32 / u64::MAX as f32) * 2.0 - 1.0
    }

    /// Returns `n` consecutive values from [`SimpleRng::next_f32`].
    pub fn gen_vec(&mut self, n: usize) -> Vec<f32> {
        (0..n).map(|_| self.next_f32()).collect()
    }
}

/// Compare two float vectors with tolerance.
///
/// Returns `true` when `a` and `b` have the same length and no element pair
/// differs by more than `tolerance` in absolute value. `name` prefixes every
/// line printed: failures go to stderr, the PASS line to stdout.
///
/// Non-finite values are handled explicitly:
/// - equal values, including same-signed infinities, count as a difference of zero;
/// - NaN in both vectors at the same index counts as a match;
/// - NaN on one side only is reported as a mismatch, so it cannot slip
///   through the `>` comparison, which is always false for NaN.
pub fn vectors_match(a: &[f32], b: &[f32], tolerance: f32, name: &str) -> bool {
    if a.len() != b.len() {
        eprintln!("{name}: length mismatch {} vs {}", a.len(), b.len());
        return false;
    }

    let mut max_diff = 0.0f32;
    let mut max_diff_idx = 0;

    for (i, (&ai, &bi)) in a.iter().zip(b).enumerate() {
        // Equal values contribute nothing; this also covers inf vs inf, whose
        // subtraction would yield NaN.
        if ai == bi || (ai.is_nan() && bi.is_nan()) {
            continue;
        }

        let diff = (ai - bi).abs();
        if diff.is_nan() {
            // Only reachable when exactly one side is NaN.
            eprintln!("{name}: NaN mismatch at index {i} ({ai} vs {bi})");
            return false;
        }
        if diff > max_diff {
            max_diff = diff;
            max_diff_idx = i;
        }
    }

    if max_diff > tolerance {
        eprintln!(
            "{name}: max diff {max_diff:.2e} at index {max_diff_idx} exceeds tolerance {tolerance:.2e}"
        );
        return false;
    }

    println!("{name}: PASS (max_diff={max_diff:.2e})");
    true
}