use super::MAGIC;
use std::fs;
use std::path::Path;
/// Applies RMSNorm to each `weight.len()`-sized token row of `x`.
///
/// Each row is divided by `sqrt(mean(x_i^2) + eps)` and scaled element-wise
/// by `weight`. Any trailing partial row (when `x.len()` is not a multiple
/// of the hidden dimension) is left as zeros in the output.
pub(crate) fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    #[cfg(feature = "gpu")]
    {
        // GPU path: RMSNorm expressed as LayerNorm with an all-zero bias.
        let hidden_dim = weight.len();
        let zero_bias = vec![0.0f32; hidden_dim];
        crate::gpu::layer_norm_static(x, weight, &zero_bias, hidden_dim, eps)
    }
    #[cfg(not(feature = "gpu"))]
    {
        let dim = weight.len();
        let mut out = vec![0.0f32; x.len()];
        // Walk output and input in lockstep, one complete token row at a time.
        for (dst, src) in out.chunks_exact_mut(dim).zip(x.chunks_exact(dim)) {
            let mean_sq = src.iter().map(|v| v * v).sum::<f32>() / dim as f32;
            let denom = (mean_sq + eps).sqrt();
            for ((o, xi), wi) in dst.iter_mut().zip(src).zip(weight) {
                *o = xi / denom * wi;
            }
        }
        out
    }
}
/// Computes `x * w^T`: `x` is `seq_len` rows of `in_dim`, `w` is `out_dim`
/// rows of `in_dim` (i.e. `w` is stored transposed), and the result is
/// `seq_len * out_dim` in row-major order.
pub(crate) fn matmul(
    x: &[f32],
    w: &[f32],
    seq_len: usize,
    in_dim: usize,
    out_dim: usize,
) -> Vec<f32> {
    #[cfg(feature = "gpu")]
    {
        crate::gpu::cpu_matmul_transpose_b(x, w, seq_len, in_dim, out_dim)
    }
    #[cfg(not(feature = "gpu"))]
    {
        let mut output = vec![0.0f32; seq_len * out_dim];
        for m in 0..seq_len {
            // Hoist the input row once per output row.
            let x_row = &x[m * in_dim..m * in_dim + in_dim];
            for n in 0..out_dim {
                let w_row = &w[n * in_dim..n * in_dim + in_dim];
                // Plain in-order dot product of the two rows.
                output[m * out_dim + n] = x_row.iter().zip(w_row).map(|(a, b)| a * b).sum();
            }
        }
        output
    }
}
/// Transposes a row-major `rows` x `cols` f32 matrix by delegating to
/// `contract_gate::transpose_f32`. Only compiled with the `cuda` feature.
#[cfg(feature = "cuda")]
pub(crate) fn transpose_matrix(m: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    crate::contract_gate::transpose_f32(m, rows, cols)
}
/// Dot product of `a` and `b` over their common prefix.
///
/// On x86_64, dispatches to an AVX2/FMA kernel when the running CPU supports
/// it; otherwise falls back to a portable scalar loop. If the slices differ
/// in length, only the overlapping prefix contributes.
#[inline]
pub fn simd_dot(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86_64")]
    {
        // `simd_dot_avx2` is compiled with `target_feature(enable = "avx2",
        // enable = "fma")`, so BOTH features must be verified at runtime
        // before calling it — checking only avx2 (as before) could execute
        // FMA instructions on a CPU that lacks them.
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the required target features were just verified.
            return unsafe { simd_dot_avx2(a, b) };
        }
    }
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
/// Dot product of the overlapping prefix of `a` and `b` using AVX2 + FMA.
///
/// Processes eight f32 lanes at a time with fused multiply-add, reduces the
/// 256-bit accumulator horizontally, then finishes the `n % 8` tail with a
/// scalar loop. NOTE: FMA rounds once per multiply-add, so results may differ
/// in the last ulp from the scalar fallback in `simd_dot`.
///
/// # Safety
/// The caller must ensure the running CPU supports the `avx2` and `fma`
/// target features (e.g. via `is_x86_feature_detected!`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn simd_dot_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps,
        _mm256_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehl_ps, _mm_shuffle_ps,
    };
    // Only the prefix both slices share is reduced; extra elements ignored.
    let n = a.len().min(b.len());
    let chunks = n / 8;
    unsafe {
        let mut sum = _mm256_setzero_ps();
        for i in 0..chunks {
            // SAFETY: i * 8 + 7 < chunks * 8 <= n <= len of both slices, so
            // each unaligned 8-float load stays in bounds.
            let av = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let bv = _mm256_loadu_ps(b.as_ptr().add(i * 8));
            sum = _mm256_fmadd_ps(av, bv, sum);
        }
        // Horizontal reduction of the 8 partial sums: 256 -> 128 -> 64 -> 32.
        let hi = _mm256_extractf128_ps(sum, 1);
        let lo = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(lo, hi);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
        let mut result = _mm_cvtss_f32(sum32);
        // Scalar tail for the elements that do not fill a full 8-lane chunk.
        for i in (chunks * 8)..n {
            result += a.get(i).copied().unwrap_or(0.0) * b.get(i).copied().unwrap_or(0.0);
        }
        result
    }
}
/// Scaled dot product of `head_dim` elements of `q` starting at `q_offset`
/// with `head_dim` elements of `k` starting at `k_offset`.
///
/// Out-of-range elements are treated as 0.0 rather than panicking.
#[inline]
fn compute_attention_score(
    q: &[f32],
    k: &[f32],
    q_offset: usize,
    k_offset: usize,
    head_dim: usize,
    scale: f32,
) -> f32 {
    let dot: f32 = (0..head_dim)
        .map(|d| {
            let qv = q.get(q_offset + d).copied().unwrap_or(0.0);
            let kv = k.get(k_offset + d).copied().unwrap_or(0.0);
            qv * kv
        })
        .sum();
    dot * scale
}
/// In-place softmax over the causal prefix `scores[..=s]`.
///
/// Subtracts the prefix maximum before exponentiating for numerical
/// stability. Entries past index `s` are left untouched.
#[inline]
fn softmax_causal(scores: &mut [f32], s: usize) {
    let active = &mut scores[..=s];
    let peak = active.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0;
    for v in active.iter_mut() {
        *v = (*v - peak).exp();
        total += *v;
    }
    for v in active.iter_mut() {
        *v /= total;
    }
}
/// Weighted sum of value elements: for each step `t` in `0..=s`, accumulates
/// `scores[t] * v[v_base * t + d]` — `v_base` acts as the row stride into `v`
/// and `d` as the fixed column offset. Missing `v` elements count as 0.0.
#[inline]
fn weighted_value_sum(v: &[f32], scores: &[f32], v_base: usize, d: usize, s: usize) -> f32 {
    scores[..=s]
        .iter()
        .enumerate()
        .map(|(t, &w)| w * v.get(v_base * t + d).copied().unwrap_or(0.0))
        .sum()
}
/// Causal grouped-query attention over full rows of Q/K/V.
///
/// Layouts (from the indexing below): `q` is `[seq_len, num_heads * head_dim]`,
/// `k` and `v` are `[seq_len, num_kv_heads * head_dim]`. Consecutive groups of
/// `num_heads / num_kv_heads` query heads share one KV head. Returns the
/// attention output, `[seq_len, num_heads * head_dim]` row-major.
pub(crate) fn simple_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    seq_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
) -> Vec<f32> {
    let hidden_dim = num_heads * head_dim;
    let kv_dim = num_kv_heads * head_dim;
    let heads_per_kv = num_heads / num_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut output = vec![0.0; seq_len * hidden_dim];
    // Score buffer hoisted out of the loops: previously a fresh Vec was
    // allocated per (position, head) pair. Only scores[..=s] is written
    // before being read each pass, so reuse never leaks stale values.
    let mut scores = vec![0.0; seq_len];
    for s in 0..seq_len {
        for h in 0..num_heads {
            let kv_h = h / heads_per_kv;
            let q_base = s * hidden_dim + h * head_dim;
            let k_base = kv_h * head_dim;
            // `v_base` is the per-step row stride into `v`.
            let v_base = kv_dim;
            // Causal mask: position s attends only to steps 0..=s.
            for t in 0..=s {
                scores[t] =
                    compute_attention_score(q, k, q_base, t * kv_dim + k_base, head_dim, scale);
            }
            softmax_causal(&mut scores, s);
            for d in 0..head_dim {
                let val = weighted_value_sum(v, &scores, v_base, kv_h * head_dim + d, s);
                output[s * hidden_dim + h * head_dim + d] = val;
            }
        }
    }
    output
}
/// Applies rotary position embedding (RoPE) in place to one token's vector
/// of `num_heads * head_dim` values.
///
/// `rope_type == 2` pairs element `i` with element `i + head_dim/2` inside
/// each head (split-half layout, as used by GPT-NeoX-style models —
/// presumably the type ids follow the loader's convention; confirm against
/// the caller). Any other `rope_type` pairs adjacent elements `2i`/`2i+1`.
/// Pairs whose second index falls outside `x` are skipped.
pub(crate) fn apply_rope_norm(
    x: &mut [f32],
    num_heads: usize,
    head_dim: usize,
    position: usize,
    theta: f32,
    rope_type: u32,
) {
    let half_dim = head_dim / 2;
    for h in 0..num_heads {
        let base = h * head_dim;
        for i in 0..half_dim {
            // Per-pair frequency theta^(-2i/head_dim), scaled by position.
            let freq = 1.0 / theta.powf(2.0 * i as f32 / head_dim as f32);
            let angle = position as f32 * freq;
            let cos_t = angle.cos();
            let sin_t = angle.sin();
            // Select the index pair according to the RoPE layout.
            let (i0, i1) = if rope_type == 2 {
                (base + i, base + half_dim + i)
            } else {
                (base + 2 * i, base + 2 * i + 1)
            };
            if i1 < x.len() {
                let (a, b) = (x[i0], x[i1]);
                x[i0] = a * cos_t - b * sin_t;
                x[i1] = a * sin_t + b * cos_t;
            }
        }
    }
}
/// Returns true when the file at `path` exists, is readable, and begins with
/// the 4-byte APR magic. Unreadable or too-short files yield `false`.
pub fn is_apr_file<P: AsRef<Path>>(path: P) -> bool {
    match fs::read(path.as_ref()) {
        Ok(bytes) => bytes.len() >= 4 && bytes[0..4] == MAGIC,
        Err(_) => false,
    }
}
/// Maps a file extension (case-insensitively) to a known model format name;
/// `None` when the path has no extension or the extension is unrecognized.
fn format_from_extension(path: &Path) -> Option<&'static str> {
    let ext = path.extension()?.to_string_lossy().to_lowercase();
    ["apr", "gguf", "safetensors"]
        .iter()
        .copied()
        .find(|&known| known == ext)
}
/// Sniffs the model format from a file's leading bytes: APR magic, GGUF
/// magic, or a safetensors header. Returns "unknown" on read failure, short
/// files, or an unrecognized prefix.
fn format_from_magic(path: &Path) -> &'static str {
    let Ok(data) = fs::read(path) else {
        return "unknown";
    };
    if data.len() < 4 {
        return "unknown";
    }
    if data[0..4] == MAGIC {
        return "apr";
    }
    // GGUF files start with the ASCII magic "GGUF".
    if data[0..4] == *b"GGUF" {
        return "gguf";
    }
    // safetensors files start with an 8-byte little-endian u64 header length
    // followed by a JSON header, which always opens with '{'. The previous
    // check (`data[0] == b'{'`) could never match a real safetensors file,
    // because byte 0 is the low byte of the length, not the JSON.
    if data.len() > 8 && data[8] == b'{' {
        return "safetensors";
    }
    // Legacy fallback kept for backward compatibility: a bare JSON prefix
    // was previously reported as safetensors.
    if data[0] == b'{' {
        return "safetensors";
    }
    "unknown"
}
/// Detects a model file's format, trusting the file extension first and
/// falling back to magic-byte sniffing only when the extension is absent or
/// unrecognized.
pub fn detect_format<P: AsRef<Path>>(path: P) -> &'static str {
    let path = path.as_ref();
    if let Some(fmt) = format_from_extension(path) {
        fmt
    } else {
        format_from_magic(path)
    }
}
include!("helpers_tests.rs");