#![allow(dead_code)]
pub mod transformer;
use crate::attention::AttentionParams;
// FFI declaration of the CBLAS single-precision GEMM routine (macOS only).
// NOTE(review): there is no `#[link(...)]` attribute on this block, so the
// symbol must be supplied by linker configuration elsewhere (presumably the
// Accelerate framework via a build script) — confirm.
#[cfg(target_os = "macos")]
extern "C" {
/// Standard CBLAS `sgemm`: `C = alpha * op(A) * op(B) + beta * C`, where
/// `op(A)` is `m x k`, `op(B)` is `k x n` and `C` is `m x n`.
///
/// The enum-like arguments take the raw `cblas.h` values used by the callers
/// in this file: 101 = CblasRowMajor, 102 = CblasColMajor,
/// 111 = CblasNoTrans, 112 = CblasTrans.
///
/// Callers must guarantee the buffers behind `a`, `b` and `c` are large enough
/// for the given dimensions and leading dimensions; BLAS does no bounds checks.
fn cblas_sgemm(
order: i32, // storage order of all three matrices (101 row-major / 102 column-major)
transa: i32, // whether A is transposed (111 no / 112 yes)
transb: i32, // whether B is transposed (111 no / 112 yes)
m: i32,
n: i32,
k: i32,
alpha: f32,
a: *const f32,
lda: i32, // leading dimension (row/column stride) of A as stored
b: *const f32,
ldb: i32, // leading dimension of B as stored
beta: f32,
c: *mut f32,
ldc: i32, // leading dimension of C
);
}
/// Row-major `c = a * b^T` via CBLAS `sgemm`.
///
/// `a` is `m x k`, `b` is `n x k` (its rows are the columns of `b^T`), and
/// `c` is `m x n`. `c` is overwritten (`beta = 0`), never accumulated into.
///
/// # Panics
/// Panics if a slice is shorter than the matrix it holds, or if a dimension
/// does not fit in an `i32` (the CBLAS integer type).
#[cfg(target_os = "macos")]
fn gemm_at_bt(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    // Raw `cblas.h` enum values.
    const CBLAS_ROW_MAJOR: i32 = 101;
    const CBLAS_NO_TRANS: i32 = 111;
    const CBLAS_TRANS: i32 = 112;

    // Empty output: nothing to write, and calling BLAS would pass ldc = 0,
    // which violates its `ldc >= max(1, n)` argument check.
    if m == 0 || n == 0 {
        return;
    }
    // BLAS reads/writes through raw pointers, so an undersized slice would be
    // undefined behaviour rather than a panic. Check the extents up front.
    assert!(a.len() >= m * k, "gemm_at_bt: `a` has {} elements, needs {}", a.len(), m * k);
    assert!(b.len() >= n * k, "gemm_at_bt: `b` has {} elements, needs {}", b.len(), n * k);
    assert!(c.len() >= m * n, "gemm_at_bt: `c` has {} elements, needs {}", c.len(), m * n);
    let dim = |x: usize| i32::try_from(x).expect("gemm_at_bt: matrix dimension exceeds i32::MAX");
    // Leading dimensions must be >= 1 even when k == 0 (BLAS argument rule).
    let ld_k = dim(k.max(1));
    // SAFETY: the asserts above guarantee `a` holds m*k, `b` holds n*k and `c`
    // holds m*n elements, which is exactly what sgemm touches for row-major
    // A (m x k, lda = k), B^T with B stored n x k (ldb = k) and C (m x n,
    // ldc = n). `c` is uniquely borrowed for the duration of the call.
    unsafe {
        cblas_sgemm(
            CBLAS_ROW_MAJOR,
            CBLAS_NO_TRANS,
            CBLAS_TRANS,
            dim(m),
            dim(n),
            dim(k),
            1.0,
            a.as_ptr(),
            ld_k,
            b.as_ptr(),
            ld_k,
            0.0,
            c.as_mut_ptr(),
            dim(n),
        );
    }
}
/// Portable fallback for `c = a * b^T` on row-major data.
///
/// `a` is `m x k`, `b` is `n x k` and `c` is `m x n`. Each dot product is
/// accumulated in `f64` and rounded to `f32` once at the end.
#[cfg(not(target_os = "macos"))]
fn gemm_at_bt(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    for row in 0..m {
        for col in 0..n {
            // Row `row` of `a` against row `col` of `b` (= column `col` of b^T).
            let lhs = &a[row * k..row * k + k];
            let rhs = &b[col * k..col * k + k];
            let dot = lhs
                .iter()
                .zip(rhs)
                .fold(0.0f64, |acc, (&x, &y)| acc + f64::from(x) * f64::from(y));
            c[row * n + col] = dot as f32;
        }
    }
}
/// Row-major `c = a^T * b` via CBLAS `sgemm`.
///
/// `a` is stored `k x m`, `b` is stored `k x n`, and `c` is `m x n`, i.e.
/// `c[i * n + j] = sum_p a[p * m + i] * b[p * n + j]`. This must produce the
/// same result as the portable (non-macOS) fallback of this function. An
/// earlier column-major call here (lda = ldb = k, ldc = m) computed a
/// different product, so results depended on the platform.
/// `c` is overwritten (`beta = 0`), never accumulated into.
///
/// # Panics
/// Panics if a slice is shorter than the matrix it holds, or if a dimension
/// does not fit in an `i32` (the CBLAS integer type).
#[cfg(target_os = "macos")]
fn gemm_tb_nt(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    // Raw `cblas.h` enum values.
    const CBLAS_ROW_MAJOR: i32 = 101;
    const CBLAS_NO_TRANS: i32 = 111;
    const CBLAS_TRANS: i32 = 112;

    // Empty output: nothing to write, and BLAS rejects zero leading dimensions.
    if m == 0 || n == 0 {
        return;
    }
    // BLAS accesses memory through raw pointers; verify extents to avoid UB.
    assert!(a.len() >= k * m, "gemm_tb_nt: `a` has {} elements, needs {}", a.len(), k * m);
    assert!(b.len() >= k * n, "gemm_tb_nt: `b` has {} elements, needs {}", b.len(), k * n);
    assert!(c.len() >= m * n, "gemm_tb_nt: `c` has {} elements, needs {}", c.len(), m * n);
    let dim = |x: usize| i32::try_from(x).expect("gemm_tb_nt: matrix dimension exceeds i32::MAX");
    // SAFETY: with row-major storage, A is stored k x m (lda = m, transposed
    // to m x k), B is stored k x n (ldb = n) and C is m x n (ldc = n). The
    // asserts above guarantee the slices cover k*m, k*n and m*n elements, and
    // `c` is uniquely borrowed for the duration of the call.
    unsafe {
        cblas_sgemm(
            CBLAS_ROW_MAJOR,
            CBLAS_TRANS,
            CBLAS_NO_TRANS,
            dim(m),
            dim(n),
            dim(k),
            1.0,
            a.as_ptr(),
            dim(m),
            b.as_ptr(),
            dim(n),
            0.0,
            c.as_mut_ptr(),
            dim(n),
        );
    }
}
/// Portable fallback for `c = a^T * b` on row-major data.
///
/// `a` is stored `k x m`, `b` is stored `k x n`, and `c` is `m x n`.
/// Output element `(i, j)` dots column `i` of `a` with column `j` of `b`,
/// accumulating in `f64` and rounding to `f32` once.
#[cfg(not(target_os = "macos"))]
fn gemm_tb_nt(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    for i in 0..m {
        for j in 0..n {
            // Walk down the `k` shared rows, picking column `i` of `a` and
            // column `j` of `b` from each.
            let acc = (0..k).fold(0.0f64, |s, row| {
                s + f64::from(a[row * m + i]) * f64::from(b[row * n + j])
            });
            c[i * n + j] = acc as f32;
        }
    }
}
/// Column-major `c = a * b` with neither operand transposed.
///
/// `a` is `m x k` (element `(i, p)` at `a[i + p * m]`), `b` is `k x n`
/// (element `(p, j)` at `b[p + j * k]`), and `c` is `m x n` (element `(i, j)`
/// at `c[i + j * m]`). Every output element is overwritten. Dot products are
/// accumulated in `f64` and rounded once, matching the other portable GEMMs
/// in this file.
///
/// Previously this was an empty stub that silently left `c` untouched.
///
/// # Panics
/// Panics if a slice is shorter than the matrix it holds.
fn gemm_nt_nt_colmajor(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    // Column-by-column so writes to `c` are contiguous in column-major order.
    for j in 0..n {
        for i in 0..m {
            let mut sum = 0.0f64;
            for p in 0..k {
                sum += a[p * m + i] as f64 * b[j * k + p] as f64;
            }
            c[j * m + i] = sum as f32;
        }
    }
}
/// Row-major `c = a * b` via CBLAS `sgemm`.
///
/// `a` is `m x k`, `b` is `k x n` and `c` is `m x n`. `c` is overwritten
/// (`beta = 0`), never accumulated into.
///
/// # Panics
/// Panics if a slice is shorter than the matrix it holds, or if a dimension
/// does not fit in an `i32` (the CBLAS integer type).
#[cfg(target_os = "macos")]
pub fn gemm_at_b(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    // Raw `cblas.h` enum values.
    const CBLAS_ROW_MAJOR: i32 = 101;
    const CBLAS_NO_TRANS: i32 = 111;

    // Empty output: nothing to write, and BLAS rejects ldb/ldc = 0.
    if m == 0 || n == 0 {
        return;
    }
    // BLAS accesses memory through raw pointers; verify extents to avoid UB.
    assert!(a.len() >= m * k, "gemm_at_b: `a` has {} elements, needs {}", a.len(), m * k);
    assert!(b.len() >= k * n, "gemm_at_b: `b` has {} elements, needs {}", b.len(), k * n);
    assert!(c.len() >= m * n, "gemm_at_b: `c` has {} elements, needs {}", c.len(), m * n);
    let dim = |x: usize| i32::try_from(x).expect("gemm_at_b: matrix dimension exceeds i32::MAX");
    // SAFETY: row-major A (m x k, lda = max(1, k)), B (k x n, ldb = n) and
    // C (m x n, ldc = n); the asserts above guarantee the slices cover m*k,
    // k*n and m*n elements, and `c` is uniquely borrowed during the call.
    unsafe {
        cblas_sgemm(
            CBLAS_ROW_MAJOR,
            CBLAS_NO_TRANS,
            CBLAS_NO_TRANS,
            dim(m),
            dim(n),
            dim(k),
            1.0,
            a.as_ptr(),
            // Leading dimensions must be >= 1 even when k == 0.
            dim(k.max(1)),
            b.as_ptr(),
            dim(n),
            0.0,
            c.as_mut_ptr(),
            dim(n),
        );
    }
}
/// Portable fallback for `c = a * b` on row-major data.
///
/// `a` is `m x k`, `b` is `k x n` and `c` is `m x n`. Each output element is
/// a dot product accumulated in `f64` and rounded to `f32` once.
#[cfg(not(target_os = "macos"))]
pub fn gemm_at_b(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    for i in 0..m {
        for j in 0..n {
            // Row `i` of `a` against column `j` of `b` (stride `n`).
            let a_row = &a[i * k..(i + 1) * k];
            let acc = a_row
                .iter()
                .enumerate()
                .fold(0.0f64, |s, (p, &x)| s + f64::from(x) * f64::from(b[p * n + j]));
            c[i * n + j] = acc as f32;
        }
    }
}
/// Applies a masked, row-wise softmax in place to an `m x n` row-major matrix.
///
/// Row `i` corresponds to the query at absolute position `pos_offset + i`, and
/// column `j` to the key at position `j`.
///
/// * `causal`: row `i` may only attend to columns `< pos_offset + i + 1`
///   (clamped to `n`).
/// * `sliding_window`: when `causal` is set and this is non-zero, row `i`
///   attends to at most the last `sliding_window` permitted columns. It is
///   ignored when `causal` is false.
///
/// Masked entries come out as exactly `0.0`. If every attended score in a row
/// is `-inf` (e.g. the caller pre-masked it), that row is set to all zeros
/// instead of NaN. NaN scores still propagate.
///
/// # Panics
/// Panics if `scores.len() < m * n`.
pub fn softmax_inplace(
    scores: &mut [f32],
    m: usize,
    n: usize,
    causal: bool,
    pos_offset: usize,
    sliding_window: usize,
) {
    for i in 0..m {
        let row = &mut scores[i * n..(i + 1) * n];
        let attend_end = if causal {
            (pos_offset + i + 1).min(n)
        } else {
            n
        };
        let attend_start = if causal && sliding_window > 0 {
            attend_end.saturating_sub(sliding_window)
        } else {
            0
        };

        // Masked positions receive zero probability. This is the same value the
        // -inf -> exp() route produced, without the extra work.
        row[..attend_start].fill(0.0);
        row[attend_end..].fill(0.0);
        let window = &mut row[attend_start..attend_end];

        let mut max_val = f32::NEG_INFINITY;
        for &x in window.iter() {
            if x > max_val {
                max_val = x;
            }
        }
        // All attended scores are -inf: `x - max_val` would be NaN and poison
        // the row (and everything downstream). Treat it as "attend to nothing".
        if max_val == f32::NEG_INFINITY && window.iter().all(|&x| x == f32::NEG_INFINITY) {
            window.fill(0.0);
            continue;
        }

        // Subtract the max for numerical stability before exponentiating.
        let mut sum = 0.0f32;
        for x in window.iter_mut() {
            *x = (*x - max_val).exp();
            sum += *x;
        }
        let inv = 1.0f32 / sum;
        for x in window.iter_mut() {
            *x *= inv;
        }
    }
}
/// Computes scaled dot-product attention on the CPU, one (batch, head) pair at a time.
///
/// Buffer layouts, as fixed by the offset arithmetic below (contiguous, `f32`):
/// * `q` and `out`: `[batch, num_heads, q_len, head_dim]`
/// * `k` and `v`:   `[batch, num_kv_heads, kv_len, head_dim]`
///
/// Query head `h` reads KV head `h / (num_heads / num_kv_heads)`, so several
/// query heads can share one KV head (grouped-query attention). Per head:
/// `scores = softmax(mask(Q * K^T / sqrt(head_dim)))` and `out = scores * V`,
/// with masking delegated to [`softmax_inplace`] (`causal`, `pos_offset`,
/// `sliding_window`).
///
/// # Panics
/// Panics if `num_kv_heads` is zero or does not divide `num_heads`, or if any
/// buffer is shorter than the shape described by `p` requires.
pub fn fused_attention(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], p: &AttentionParams) {
    let nh = p.num_heads;
    let nkv = p.num_kv_heads;
    // Without this, a non-divisor silently maps query heads onto the wrong KV
    // head (possibly another batch's data), and zero divides by zero below.
    assert!(
        nkv > 0 && nh % nkv == 0,
        "fused_attention: num_kv_heads ({nkv}) must be non-zero and divide num_heads ({nh})"
    );
    let n_rep = nh / nkv;
    let sq = p.q_len;
    let sk = p.kv_len;
    let d = p.head_dim;
    let scale = 1.0f32 / (d as f32).sqrt();
    let q_head_len = sq * d;
    let kv_head_len = sk * d;

    // One scratch buffer for all heads. gemm_at_bt overwrites every one of
    // its sq*sk entries, so no per-head reset is needed.
    let mut scores = vec![0.0f32; sq * sk];
    for b in 0..p.batch {
        for h in 0..nh {
            let kv_h = h / n_rep;
            let q_off = (b * nh + h) * q_head_len;
            let kv_off = (b * nkv + kv_h) * kv_head_len;
            let q_slice = &q[q_off..q_off + q_head_len];
            let k_slice = &k[kv_off..kv_off + kv_head_len];
            let v_slice = &v[kv_off..kv_off + kv_head_len];

            // scores (sq x sk) = Q (sq x d) * K^T, then scale and mask/normalise.
            gemm_at_bt(q_slice, k_slice, &mut scores, sq, sk, d);
            for s in scores.iter_mut() {
                *s *= scale;
            }
            softmax_inplace(
                &mut scores,
                sq,
                sk,
                p.causal,
                p.pos_offset,
                p.sliding_window,
            );

            // out (sq x d) = scores (sq x sk) * V (sk x d); same layout as q.
            let o_slice = &mut out[q_off..q_off + q_head_len];
            gemm_at_b(&scores, v_slice, o_slice, sq, d, sk);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Causal masking: row 0 sees only key 0, row 1 sees keys 0 and 1.
    #[test]
    fn test_cpu_attention_causal() {
        let q = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let k = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut out = vec![0.0f32; 8];
        let params = AttentionParams {
            batch: 1,
            num_heads: 1,
            num_kv_heads: 1,
            q_len: 2,
            kv_len: 2,
            head_dim: 4,
            causal: true,
            pos_offset: 0,
            sliding_window: 0,
        };
        fused_attention(&q, &k, &v, &mut out, &params);
        assert!((out[0] - 1.0).abs() < 1e-4, "out[0]={}", out[0]);
        assert!((out[1] - 2.0).abs() < 1e-4);
        // Row 1 scores are [0, 1] * 1/sqrt(4) = [0, 0.5].
        let e0 = (-0.5f32).exp();
        let w0 = e0 / (e0 + 1.0);
        let w1 = 1.0 / (e0 + 1.0);
        let expected = w0 * 1.0 + w1 * 5.0;
        assert!(
            (out[4] - expected).abs() < 1e-3,
            "out[4]={} expected={}",
            out[4],
            expected
        );
    }

    /// Two query heads sharing a single KV head both read the same values.
    #[test]
    fn test_cpu_attention_gqa() {
        let q = vec![1.0, 0.0, 0.0, 1.0];
        let k = vec![1.0, 0.0];
        let v = vec![3.0, 7.0];
        let mut out = vec![0.0f32; 4];
        let params = AttentionParams {
            batch: 1,
            num_heads: 2,
            num_kv_heads: 1,
            q_len: 1,
            kv_len: 1,
            head_dim: 2,
            causal: false,
            pos_offset: 0,
            sliding_window: 0,
        };
        fused_attention(&q, &k, &v, &mut out, &params);
        assert!((out[0] - 3.0).abs() < 1e-4);
        assert!((out[1] - 7.0).abs() < 1e-4);
        assert!((out[2] - 3.0).abs() < 1e-4);
        assert!((out[3] - 7.0).abs() < 1e-4);
    }

    /// A sliding window of 1 restricts every query to its own position.
    #[test]
    fn test_cpu_attention_sliding_window_equals_self() {
        let q = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5];
        let k = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5];
        let v = vec![0.0, 10.0, 1.0, 11.0, 2.0, 12.0, 3.0, 13.0];
        let mut out = vec![0.0f32; 8];
        let params = AttentionParams {
            batch: 1,
            num_heads: 1,
            num_kv_heads: 1,
            q_len: 4,
            kv_len: 4,
            head_dim: 2,
            causal: true,
            pos_offset: 0,
            sliding_window: 1,
        };
        fused_attention(&q, &k, &v, &mut out, &params);
        for i in 0..4 {
            assert!(
                (out[i * 2] - v[i * 2]).abs() < 1e-4,
                "row {i}: expected {}, got {}",
                v[i * 2],
                out[i * 2]
            );
            assert!(
                (out[i * 2 + 1] - v[i * 2 + 1]).abs() < 1e-4,
                "row {i}: expected {}, got {}",
                v[i * 2 + 1],
                out[i * 2 + 1]
            );
        }
    }
}