use super::device::{ComputeBuffer, ComputeDevice};
use crate::expr::codegen::Dialect;
use crate::expr::node::ExprId;
use crate::expr::trace;
/// Host-memory buffer used by the CPU backend: a flat vector of `f32` values.
///
/// `u32` payloads (for example token ids) are stored bit-for-bit in the `f32`
/// slots by `upload_u32` and read back with `f32::to_bits`.
#[derive(Debug)]
pub struct CpuBuffer {
    /// Element storage; the buffer length is the length of this vector.
    data: Vec<f32>,
}
impl ComputeBuffer for CpuBuffer {
    /// Number of `f32` elements held by the buffer.
    fn len(&self) -> usize {
        self.data.as_slice().len()
    }

    /// Copies the contents into a freshly allocated `Vec<f32>`.
    fn to_vec(&self) -> Vec<f32> {
        self.data.as_slice().to_vec()
    }
}
/// Compute backend that runs every operation on the host CPU.
///
/// The type is a stateless unit struct, so handles are free to copy.
#[derive(Debug, Clone, Copy)]
pub struct CpuDevice;
impl CpuDevice {
pub fn new() -> Self {
CpuDevice
}
}
impl Default for CpuDevice {
fn default() -> Self {
Self::new()
}
}
impl ComputeDevice for CpuDevice {
type Buffer = CpuBuffer;
/// Code-generation dialect reported by this backend: always [`Dialect::C`].
fn dialect(&self) -> Dialect {
Dialect::C
}
/// Creates a buffer that owns a copy of `data`.
fn upload(&self, data: &[f32]) -> CpuBuffer {
    let data = data.to_owned();
    CpuBuffer { data }
}
/// Creates a buffer from `u32` values by reinterpreting each value's bit
/// pattern as an `f32`. No numeric conversion takes place.
fn upload_u32(&self, data: &[u32]) -> CpuBuffer {
    let reinterpreted: Vec<f32> = data.iter().copied().map(f32::from_bits).collect();
    CpuBuffer {
        data: reinterpreted,
    }
}
/// Creates a zero-filled buffer holding `len` elements.
fn alloc(&self, len: usize) -> CpuBuffer {
    let zeros = vec![0.0f32; len];
    CpuBuffer { data: zeros }
}
/// Returns a copy of the buffer contents as a `Vec<f32>`.
fn download(&self, buf: &CpuBuffer) -> Vec<f32> {
    buf.data.as_slice().to_vec()
}
/// Applies the scalar expression built by `f` at every element position.
///
/// `f` is traced once with one expression variable per input buffer. The
/// resulting graph is compiled and then evaluated in `f64` for each index in
/// `0..numel`; every result is narrowed back to `f32`.
///
/// # Panics
/// Panics if any input holds fewer than `numel` elements.
fn elementwise(
    &self,
    inputs: &[&CpuBuffer],
    numel: usize,
    f: &dyn Fn(&[ExprId]) -> ExprId,
) -> CpuBuffer {
    let arity = inputs.len();
    // Record the expression graph a single time, independent of `numel`.
    let (graph, output) = trace(|| {
        let mut vars: Vec<ExprId> = Vec::with_capacity(arity);
        for index in 0..arity as u16 {
            vars.push(ExprId::var(index));
        }
        f(&vars)
    });
    let compiled = graph.compile(output);
    // Scratch argument row, refilled for each element instead of reallocated.
    let mut args = vec![0.0f64; arity];
    let data: Vec<f32> = (0..numel)
        .map(|i| {
            for (slot, input) in args.iter_mut().zip(inputs.iter()) {
                *slot = input.data[i] as f64;
            }
            compiled(&args) as f32
        })
        .collect();
    CpuBuffer { data }
}
/// Matrix product of `a` (`m x k`) and `b` (`k x n`), returned as a new
/// buffer of `m * n` elements. The work is delegated to `matmul_impl`.
fn matmul(&self, a: &CpuBuffer, b: &CpuBuffer, m: usize, k: usize, n: usize) -> CpuBuffer {
    let mut product = CpuBuffer {
        data: vec![0.0f32; m * n],
    };
    matmul_impl(&a.data, &b.data, &mut product.data, m, k, n);
    product
}
/// Softmax over each of the first `n_rows` rows of length `row_len`.
///
/// The result starts as a copy of `data`, so any elements past
/// `n_rows * row_len` are passed through unchanged.
///
/// # Panics
/// Panics if `data` holds fewer than `n_rows * row_len` elements.
fn softmax(&self, data: &CpuBuffer, n_rows: usize, row_len: usize) -> CpuBuffer {
    let mut result = CpuBuffer {
        data: data.data.clone(),
    };
    for row_idx in 0..n_rows {
        let offset = row_idx * row_len;
        let row = &mut result.data[offset..offset + row_len];
        // Shift by the row maximum so the largest exponent is exp(0) = 1.
        let peak = row
            .iter()
            .fold(f32::NEG_INFINITY, |best, &x| f32::max(best, x));
        row.iter_mut().for_each(|x| *x = (*x - peak).exp());
        let total: f32 = row.iter().sum();
        let scale = 1.0 / total;
        row.iter_mut().for_each(|x| *x *= scale);
    }
    result
}
/// RMS normalisation over `n_groups` consecutive groups of `dim` elements.
///
/// Each element becomes `x * weight[d] / sqrt(mean(x^2) + eps)`, where the
/// mean is taken over the element's own group and `d` is its index within
/// the group.
///
/// # Panics
/// Panics if `data` holds fewer than `n_groups * dim` elements, or if
/// `weight` holds fewer than `dim` elements while at least one group exists.
fn rms_norm(
    &self,
    data: &CpuBuffer,
    weight: &CpuBuffer,
    n_groups: usize,
    dim: usize,
    eps: f32,
) -> CpuBuffer {
    let mut normalized = vec![0.0f32; n_groups * dim];
    for group in 0..n_groups {
        let lo = group * dim;
        let hi = lo + dim;
        let input = &data.data[lo..hi];
        let mean_sq = input.iter().map(|x| x * x).sum::<f32>() / dim as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        for (d, slot) in normalized[lo..hi].iter_mut().enumerate() {
            *slot = input[d] * inv_rms * weight.data[d];
        }
    }
    CpuBuffer { data: normalized }
}
/// Gathers one `dim`-wide row of `weight` for each of the first `seq_len`
/// ids and concatenates them.
///
/// Ids are read from the bit pattern of each `f32` slot in `ids`.
///
/// # Panics
/// Panics if `ids` holds fewer than `seq_len` elements or an id addresses a
/// row outside `weight`.
fn embedding(
    &self,
    weight: &CpuBuffer,
    ids: &CpuBuffer,
    seq_len: usize,
    dim: usize,
) -> CpuBuffer {
    let mut gathered = vec![0.0f32; seq_len * dim];
    for (pos, raw_id) in ids.data[..seq_len].iter().enumerate() {
        let token = raw_id.to_bits() as usize;
        let src = (token * dim)..(token * dim + dim);
        let dst = (pos * dim)..(pos * dim + dim);
        gathered[dst].copy_from_slice(&weight.data[src]);
    }
    CpuBuffer { data: gathered }
}
/// Sums a row-major tensor of the given `shape` along `axis`.
///
/// The result has the input shape with `axis` removed, flattened row-major.
/// An empty buffer is returned when that output shape has zero elements.
///
/// # Panics
/// Panics if `axis` is not a valid index into `shape`, or if `data` is
/// shorter than `shape` implies.
fn reduce_sum(&self, data: &CpuBuffer, shape: &[usize], axis: usize) -> CpuBuffer {
    let ndim = shape.len();
    assert!(axis < ndim);
    // Element count of the output: every extent except the reduced one.
    let out_len: usize = shape
        .iter()
        .enumerate()
        .filter(|&(dim_idx, _)| dim_idx != axis)
        .map(|(_, &extent)| extent)
        .product();
    if out_len == 0 {
        return CpuBuffer { data: Vec::new() };
    }
    let (leading, rest) = shape.split_at(axis);
    let outer: usize = leading.iter().product();
    let axis_len = rest[0];
    let inner: usize = rest[1..].iter().product();
    let mut sums = vec![0.0f32; out_len];
    for o in 0..outer {
        for i in 0..inner {
            // First element of the strided column that collapses into this slot.
            let column_start = o * axis_len * inner + i;
            sums[o * inner + i] = (0..axis_len)
                .fold(0.0f32, |acc, a| acc + data.data[column_start + a * inner]);
        }
    }
    CpuBuffer { data: sums }
}
/// Causal scaled-dot-product attention over a full sequence.
///
/// `q` is indexed as `seq_len` rows of `n_heads * head_dim` values; `k` and
/// `v` as `seq_len` rows of `n_kv_heads * head_dim` values. Query head `h`
/// reads KV head `h / (n_heads / n_kv_heads)`. Position `i` attends to
/// positions `0..=i`. The softmax is evaluated in a single streaming pass
/// that keeps a running maximum and rescales the partial sums whenever the
/// maximum grows.
///
/// # Panics
/// Panics on division by zero if `n_kv_heads` is 0 or exceeds `n_heads`
/// (when `n_heads > 0`), and on out-of-range access if a buffer is too short.
fn causal_attention(
    &self,
    q: &CpuBuffer,
    k: &CpuBuffer,
    v: &CpuBuffer,
    seq_len: usize,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
) -> CpuBuffer {
    let q_stride = n_heads * head_dim;
    let kv_stride = n_kv_heads * head_dim;
    let group_size = n_heads / n_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; seq_len * q_stride];
    for head in 0..n_heads {
        let q_col = head * head_dim;
        let kv_col = (head / group_size) * head_dim;
        for pos in 0..seq_len {
            let q_start = pos * q_stride + q_col;
            let query = &q.data[q_start..q_start + head_dim];
            let mut max_seen = f32::NEG_INFINITY;
            let mut denom = 0.0f32;
            let mut weighted = vec![0.0f32; head_dim];
            for past in 0..=pos {
                let kv_start = past * kv_stride + kv_col;
                let key = &k.data[kv_start..kv_start + head_dim];
                let value = &v.data[kv_start..kv_start + head_dim];
                let dot = query
                    .iter()
                    .zip(key)
                    .fold(0.0f32, |acc, (qv, kv)| acc + qv * kv);
                let score = dot * scale;
                let next_max = max_seen.max(score);
                let weight = (score - next_max).exp();
                // Factor that re-bases the earlier partial sums onto the new maximum.
                let decay = (max_seen - next_max).exp();
                denom = denom * decay + weight;
                for (acc, val) in weighted.iter_mut().zip(value) {
                    *acc = *acc * decay + weight * val;
                }
                max_seen = next_max;
            }
            let inv = 1.0 / denom;
            for (dst, acc) in out[q_start..q_start + head_dim].iter_mut().zip(&weighted) {
                *dst = acc * inv;
            }
        }
    }
    CpuBuffer { data: out }
}
/// Attention of `q_len` new query rows against a key/value cache.
///
/// Query row `qi` attends to cache positions `0..cache_start + qi + 1`.
/// `q` is indexed as `q_len` rows of `n_heads * head_dim` values; the caches
/// as rows of `n_kv_heads * head_dim` values. Query head `h` reads KV head
/// `h / (n_heads / n_kv_heads)`. The softmax is evaluated in a single
/// streaming pass with a running maximum.
///
/// # Panics
/// Panics on division by zero if `n_kv_heads` is 0 or exceeds `n_heads`
/// (when there is work to do), and on out-of-range access if a buffer is
/// too short.
fn kv_attention(
    &self,
    q: &CpuBuffer,
    k_cache: &CpuBuffer,
    v_cache: &CpuBuffer,
    cache_start: usize,
    q_len: usize,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
) -> CpuBuffer {
    let q_stride = n_heads * head_dim;
    let kv_stride = n_kv_heads * head_dim;
    let group_size = n_heads / n_kv_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; q_len * q_stride];
    for row in 0..q_len {
        // Number of cache rows visible to this query row.
        let visible = cache_start + row + 1;
        for head in 0..n_heads {
            let kv_col = (head / group_size) * head_dim;
            // The output slot uses the same offset as the query slice.
            let q_start = row * q_stride + head * head_dim;
            let query = &q.data[q_start..q_start + head_dim];
            let mut max_seen = f32::NEG_INFINITY;
            let mut denom = 0.0f32;
            let mut weighted = vec![0.0f32; head_dim];
            for past in 0..visible {
                let kv_start = past * kv_stride + kv_col;
                let key = &k_cache.data[kv_start..kv_start + head_dim];
                let value = &v_cache.data[kv_start..kv_start + head_dim];
                let dot = query
                    .iter()
                    .zip(key)
                    .fold(0.0f32, |acc, (qv, kv)| acc + qv * kv);
                let score = dot * scale;
                let next_max = max_seen.max(score);
                let weight = (score - next_max).exp();
                let decay = (max_seen - next_max).exp();
                denom = denom * decay + weight;
                for (acc, val) in weighted.iter_mut().zip(value) {
                    *acc = *acc * decay + weight * val;
                }
                max_seen = next_max;
            }
            let inv = 1.0 / denom;
            for (dst, acc) in out[q_start..q_start + head_dim].iter_mut().zip(&weighted) {
                *dst = acc * inv;
            }
        }
    }
    CpuBuffer { data: out }
}
/// Transposes a row-major `rows x cols` matrix into a row-major
/// `cols x rows` matrix.
///
/// # Panics
/// Panics if `buf` does not hold exactly `rows * cols` elements.
fn transpose_2d(&self, buf: &CpuBuffer, rows: usize, cols: usize) -> CpuBuffer {
    assert_eq!(buf.data.len(), rows * cols);
    let mut transposed = vec![0.0f32; rows * cols];
    // Walk the source linearly and scatter each value to its mirrored slot.
    for (src_idx, &value) in buf.data.iter().enumerate() {
        let (r, c) = (src_idx / cols, src_idx % cols);
        transposed[c * rows + r] = value;
    }
    CpuBuffer { data: transposed }
}
/// Gradient of a row-wise softmax with respect to its input.
///
/// For each row, `grad_in[j] = s[j] * (g[j] - sum_k s[k] * g[k])`, where `s`
/// is the softmax output and `g` the incoming gradient.
///
/// # Panics
/// Panics if either buffer holds fewer than `n_rows * row_len` elements.
fn softmax_backward(
    &self,
    softmax_out: &CpuBuffer,
    grad_output: &CpuBuffer,
    n_rows: usize,
    row_len: usize,
) -> CpuBuffer {
    let mut grad_input = vec![0.0f32; n_rows * row_len];
    for row in 0..n_rows {
        let span = (row * row_len)..((row + 1) * row_len);
        let probs = &softmax_out.data[span.clone()];
        let upstream = &grad_output.data[span.clone()];
        let weighted_sum = probs
            .iter()
            .zip(upstream)
            .fold(0.0f32, |acc, (p, g)| acc + p * g);
        for ((dst, p), g) in grad_input[span].iter_mut().zip(probs).zip(upstream) {
            *dst = p * (g - weighted_sum);
        }
    }
    CpuBuffer { data: grad_input }
}
/// Backward pass of `rms_norm`.
///
/// Returns `(grad_input, grad_weight)`. `grad_input` has `n_groups * dim`
/// elements; `grad_weight` has `dim` elements and is summed over all groups.
///
/// # Panics
/// Panics if `input` or `grad_output` holds fewer than `n_groups * dim`
/// elements, or `weight` fewer than `dim`, while there is work to do.
fn rms_norm_backward(
    &self,
    input: &CpuBuffer,
    weight: &CpuBuffer,
    grad_output: &CpuBuffer,
    n_groups: usize,
    dim: usize,
    eps: f32,
) -> (CpuBuffer, CpuBuffer) {
    let mut grad_input = vec![0.0f32; n_groups * dim];
    let mut grad_weight = vec![0.0f32; dim];
    for group in 0..n_groups {
        let lo = group * dim;
        let x = &input.data[lo..lo + dim];
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / dim as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        // One pass feeds both reductions: the per-dimension weight gradient
        // and the scalar sum(x * w * g) used by the input gradient.
        let mut sum_xwg = 0.0f32;
        for d in 0..dim {
            let g = grad_output.data[lo + d];
            grad_weight[d] += g * x[d] * inv_rms;
            sum_xwg += x[d] * weight.data[d] * g;
        }
        for d in 0..dim {
            let g = grad_output.data[lo + d];
            grad_input[lo + d] = weight.data[d] * inv_rms * g
                - x[d] * inv_rms * inv_rms * inv_rms / dim as f32 * sum_xwg;
        }
    }
    (
        CpuBuffer { data: grad_input },
        CpuBuffer { data: grad_weight },
    )
}
/// Scatter-adds the incoming gradient rows into a `vocab_size * dim`
/// embedding-table gradient.
///
/// Row `i` of `grad_output` is added to the row selected by the id stored
/// (as a bit pattern) at `ids[i]`; repeated ids accumulate.
///
/// # Panics
/// Panics if `ids` holds fewer than `seq_len` elements, `grad_output` fewer
/// than `seq_len * dim`, or an id addresses a row outside the table.
fn embedding_backward(
    &self,
    grad_output: &CpuBuffer,
    ids: &CpuBuffer,
    vocab_size: usize,
    seq_len: usize,
    dim: usize,
) -> CpuBuffer {
    let mut grad_weight = vec![0.0f32; vocab_size * dim];
    for (pos, raw_id) in ids.data[..seq_len].iter().enumerate() {
        let row_start = raw_id.to_bits() as usize * dim;
        let incoming = &grad_output.data[(pos * dim)..(pos * dim + dim)];
        for (slot, g) in grad_weight[row_start..row_start + dim]
            .iter_mut()
            .zip(incoming)
        {
            *slot += g;
        }
    }
    CpuBuffer { data: grad_weight }
}
/// Backward pass of causal multi-head attention.
///
/// For each query head the attention probabilities are recomputed from `q`
/// and `k`, then `grad_output` is propagated back through the value mix, the
/// softmax, and the scaled dot product.
///
/// Returns `(grad_q, grad_k, grad_v)`. `grad_q` has
/// `seq_len * n_heads * head_dim` elements. `grad_k` and `grad_v` have
/// `seq_len * n_kv_heads * head_dim` elements and accumulate the
/// contributions of every query head that maps to the same KV head.
///
/// # Panics
/// Panics on division by zero if `n_kv_heads` is 0, or if it exceeds
/// `n_heads` while `n_heads > 0`; and on out-of-range indexing if a buffer is
/// shorter than the dimensions imply.
fn causal_attention_backward(
&self,
grad_output: &CpuBuffer,
q: &CpuBuffer,
k: &CpuBuffer,
v: &CpuBuffer,
seq_len: usize,
n_heads: usize,
n_kv_heads: usize,
head_dim: usize,
) -> (CpuBuffer, CpuBuffer, CpuBuffer) {
let total_dim = n_heads * head_dim;
let kv_dim = n_kv_heads * head_dim;
// Number of consecutive query heads that read the same KV head.
let heads_per_kv = n_heads / n_kv_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let mut grad_q = vec![0.0f32; seq_len * total_dim];
let mut grad_k = vec![0.0f32; seq_len * kv_dim];
let mut grad_v = vec![0.0f32; seq_len * kv_dim];
for h in 0..n_heads {
let kv_h = h / heads_per_kv;
let q_off = h * head_dim;
let kv_off = kv_h * head_dim;
// Step 1: scaled scores S[i][j] = (q_i . k_j) * scale, with positions
// j > i masked to -inf so they receive zero probability.
let mut scores = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
if j > i {
scores[i * seq_len + j] = f32::NEG_INFINITY;
} else {
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += q.data[i * total_dim + q_off + d]
* k.data[j * kv_dim + kv_off + d];
}
scores[i * seq_len + j] = dot * scale;
}
}
}
// Step 2: row-wise softmax of the scores, shifted by the row maximum.
let mut probs = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
let row = &scores[i * seq_len..(i + 1) * seq_len];
let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for j in 0..seq_len {
let e = (row[j] - max_val).exp();
probs[i * seq_len + j] = e;
sum += e;
}
let inv = 1.0 / sum;
for j in 0..seq_len {
probs[i * seq_len + j] *= inv;
}
}
// Step 3: grad_v[j] += sum_i P[i][j] * grad_output[i]. `+=` because several
// query heads can share this KV head.
for j in 0..seq_len {
for d in 0..head_dim {
let mut sum = 0.0f32;
for i in 0..seq_len {
sum += probs[i * seq_len + j]
* grad_output.data[i * total_dim + q_off + d];
}
grad_v[j * kv_dim + kv_off + d] += sum;
}
}
// Step 4: gradient w.r.t. the probabilities, grad_p[i][j] = grad_output[i] . v[j].
let mut grad_p = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += grad_output.data[i * total_dim + q_off + d]
* v.data[j * kv_dim + kv_off + d];
}
grad_p[i * seq_len + j] = dot;
}
}
// Step 5: softmax backward per row, grad_s = P * (grad_p - sum_j P * grad_p).
let mut grad_s = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
let base = i * seq_len;
let mut dot = 0.0f32;
for j in 0..seq_len {
dot += probs[base + j] * grad_p[base + j];
}
for j in 0..seq_len {
grad_s[base + j] = probs[base + j] * (grad_p[base + j] - dot);
}
}
// Step 6: grad_q[i] = scale * sum_j grad_s[i][j] * k[j]. Plain assignment:
// each query head owns its own slice of grad_q.
for i in 0..seq_len {
for d in 0..head_dim {
let mut sum = 0.0f32;
for j in 0..seq_len {
sum += grad_s[i * seq_len + j] * k.data[j * kv_dim + kv_off + d];
}
grad_q[i * total_dim + q_off + d] = sum * scale;
}
}
// Step 7: grad_k[j] += scale * sum_i grad_s[i][j] * q[i], accumulated across
// the query heads that share this KV head.
for j in 0..seq_len {
for d in 0..head_dim {
let mut sum = 0.0f32;
for i in 0..seq_len {
sum += grad_s[i * seq_len + j]
* q.data[i * total_dim + q_off + d];
}
grad_k[j * kv_dim + kv_off + d] += sum * scale;
}
}
}
(
CpuBuffer { data: grad_q },
CpuBuffer { data: grad_k },
CpuBuffer { data: grad_v },
)
}
/// Mean softmax cross-entropy loss together with its gradient w.r.t. `logits`.
///
/// `logits` is read as `n_positions` rows of `vocab_size` values; targets
/// are read from the bit patterns in `targets`. Positions whose target
/// equals `pad_id` are skipped: they add nothing to the loss and their
/// gradient row stays zero. Loss and gradient are both divided by the number
/// of non-padding positions (left as-is when there are none). The
/// exponentials and their sum are evaluated in `f64`.
///
/// # Panics
/// Panics if a buffer is too short or a non-padding target is not below
/// `vocab_size`.
fn cross_entropy_forward_backward(
    &self,
    logits: &CpuBuffer,
    targets: &CpuBuffer,
    n_positions: usize,
    vocab_size: usize,
    pad_id: u32,
) -> (f32, CpuBuffer) {
    let mut grad = vec![0.0f32; n_positions * vocab_size];
    let mut loss_sum = 0.0f64;
    let mut counted = 0usize;
    for pos in 0..n_positions {
        let target = targets.data[pos].to_bits();
        if target != pad_id {
            counted += 1;
            let span = (pos * vocab_size)..(pos * vocab_size + vocab_size);
            let row = &logits.data[span.clone()];
            let peak = row
                .iter()
                .fold(f32::NEG_INFINITY, |best, &x| f32::max(best, x));
            // exp(logit - max), widened to f64 before exponentiating.
            let shifted_exp = |x: f32| ((x - peak) as f64).exp();
            let denom = row.iter().fold(0.0f64, |acc, &x| acc + shifted_exp(x));
            let target_idx = target as usize;
            loss_sum -= (row[target_idx] - peak) as f64 - denom.ln();
            // d(loss)/d(logit) = softmax - one_hot(target).
            let grad_row = &mut grad[span];
            for (slot, &x) in grad_row.iter_mut().zip(row) {
                *slot = (shifted_exp(x) / denom) as f32;
            }
            grad_row[target_idx] -= 1.0;
        }
    }
    if counted > 0 {
        let inv_count = 1.0 / counted as f32;
        grad.iter_mut().for_each(|g| *g *= inv_count);
        loss_sum /= counted as f64;
    }
    (loss_sum as f32, CpuBuffer { data: grad })
}
/// Intentionally empty: this backend's operations compute their results
/// before returning, so there is no pending work to wait for.
fn sync(&self) {
}
/// Returns a new buffer holding an independent copy of `src`.
fn copy_buffer(&self, src: &CpuBuffer) -> CpuBuffer {
    let data = src.data.as_slice().to_vec();
    CpuBuffer { data }
}
/// Adds `bias[i % dim]` to each of the first `numel` elements of `matrix`,
/// returning the result as a new buffer.
///
/// # Panics
/// Panics if `matrix` holds fewer than `numel` elements, if `dim` is 0 while
/// `numel > 0`, or if `bias` is shorter than the indices used.
fn bias_add(&self, matrix: &CpuBuffer, bias: &CpuBuffer, numel: usize, dim: usize) -> CpuBuffer {
    let mut summed = CpuBuffer {
        data: matrix.data.clone(),
    };
    for (flat_idx, value) in summed.data[..numel].iter_mut().enumerate() {
        *value += bias.data[flat_idx % dim];
    }
    summed
}
/// Element-wise `dst += src`, in place.
///
/// # Panics
/// Panics if the two buffers differ in length.
fn add_assign(&self, dst: &mut CpuBuffer, src: &CpuBuffer) {
    assert_eq!(dst.data.len(), src.data.len());
    dst.data
        .iter_mut()
        .zip(&src.data)
        .for_each(|(acc, &inc)| *acc += inc);
}
/// Overwrites every element of `buf` with `0.0`, keeping its length.
fn zero_buffer(&self, buf: &mut CpuBuffer) {
    buf.data.fill(0.0);
}
/// Applies one AdamW update in place to `param`, `m` and `v`.
///
/// For every element: decoupled weight decay (`param -= lr * weight_decay *
/// param`), then the first/second moment updates, then the bias-corrected
/// Adam step `param -= lr * m_hat / (sqrt(v_hat) + eps)`.
///
/// `step_t` is the 1-based step count used for bias correction. A value of 0
/// is treated as 1, because `beta^0 == 1` would make both corrections divide
/// by zero and turn the parameters into NaN. Values above `i32::MAX`
/// saturate instead of wrapping to a negative exponent.
///
/// # Panics
/// Panics if `grad`, `m` or `v` holds fewer elements than `param`.
fn adamw_step(
    &self,
    param: &mut CpuBuffer,
    grad: &CpuBuffer,
    m: &mut CpuBuffer,
    v: &mut CpuBuffer,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    step_t: usize,
) {
    // Clamp into 1..=i32::MAX before narrowing; a plain `as i32` would wrap.
    let exponent = step_t.max(1).min(i32::MAX as usize) as i32;
    let bias_correction1 = 1.0 - beta1.powi(exponent);
    let bias_correction2 = 1.0 - beta2.powi(exponent);
    let n = param.data.len();
    for i in 0..n {
        let g = grad.data[i];
        // Decoupled weight decay is applied before the gradient step.
        param.data[i] -= lr * weight_decay * param.data[i];
        m.data[i] = beta1 * m.data[i] + (1.0 - beta1) * g;
        v.data[i] = beta2 * v.data[i] + (1.0 - beta2) * g * g;
        let m_hat = m.data[i] / bias_correction1;
        let v_hat = v.data[i] / bias_correction2;
        param.data[i] -= lr * m_hat / (v_hat.sqrt() + eps);
    }
}
}
/// Computes `c = a * b` for `a` (`m x k`) and `b` (`k x n`) by dispatching to
/// the platform backend: `cblas_matmul` on macOS, `naive_matmul` elsewhere.
fn matmul_impl(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    // Exactly one of the two calls survives conditional compilation.
    #[cfg(target_os = "macos")]
    cblas_matmul(a, b, c, m, k, n);
    #[cfg(not(target_os = "macos"))]
    naive_matmul(a, b, c, m, k, n);
}
/// Row-major `c = a * b` through the system CBLAS routine `cblas_sgemm`
/// (compiled on macOS only).
///
/// `a` is `m x k`, `b` is `k x n`, `c` is `m x n`. Only the first `m * n`
/// elements of `c` are written. When `k == 0` those elements are set to
/// zero, matching the portable fallback.
///
/// # Panics
/// Panics if a slice is shorter than its dimensions require, if a dimension
/// product overflows `usize`, or if a dimension does not fit in `i32`.
#[cfg(target_os = "macos")]
fn cblas_matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    extern "C" {
        fn cblas_sgemm(
            order: i32,
            transa: i32,
            transb: i32,
            m: i32,
            n: i32,
            k: i32,
            alpha: f32,
            a: *const f32,
            lda: i32,
            b: *const f32,
            ldb: i32,
            beta: f32,
            c: *mut f32,
            ldc: i32,
        );
    }
    // CBLAS enum values: CblasRowMajor and CblasNoTrans.
    const ROW_MAJOR: i32 = 101;
    const NO_TRANS: i32 = 111;

    // Check the slice lengths up front: the FFI call only sees raw pointers
    // and would read or write out of bounds on a mismatch.
    let a_len = m.checked_mul(k).expect("cblas_matmul: m * k overflows usize");
    let b_len = k.checked_mul(n).expect("cblas_matmul: k * n overflows usize");
    let c_len = m.checked_mul(n).expect("cblas_matmul: m * n overflows usize");
    assert!(
        a.len() >= a_len,
        "cblas_matmul: `a` has {} elements, need {}",
        a.len(),
        a_len
    );
    assert!(
        b.len() >= b_len,
        "cblas_matmul: `b` has {} elements, need {}",
        b.len(),
        b_len
    );
    assert!(
        c.len() >= c_len,
        "cblas_matmul: `c` has {} elements, need {}",
        c.len(),
        c_len
    );

    // Degenerate shapes are handled here because BLAS rejects a leading
    // dimension of zero.
    if m == 0 || n == 0 {
        return;
    }
    if k == 0 {
        for slot in c[..c_len].iter_mut() {
            *slot = 0.0;
        }
        return;
    }

    let to_blas_int = |value: usize| -> i32 {
        assert!(
            value <= i32::MAX as usize,
            "cblas_matmul: dimension {} exceeds i32::MAX",
            value
        );
        value as i32
    };
    let (m_i, n_i, k_i) = (to_blas_int(m), to_blas_int(n), to_blas_int(k));

    // SAFETY: `a`, `b` and `c` are valid for at least `m * k`, `k * n` and
    // `m * n` elements respectively (asserted above), all dimensions are
    // non-zero and fit in `i32`, and `c` is exclusively borrowed, so the
    // output cannot alias either input.
    unsafe {
        cblas_sgemm(
            ROW_MAJOR,
            NO_TRANS,
            NO_TRANS,
            m_i,
            n_i,
            k_i,
            1.0,
            a.as_ptr(),
            k_i,
            b.as_ptr(),
            n_i,
            0.0,
            c.as_mut_ptr(),
            n_i,
        );
    }
}
/// Portable triple-loop matrix product, used when CBLAS is not available.
///
/// Row-major `c = a * b` with `a` of shape `m x k`, `b` of shape `k x n` and
/// `c` of shape `m x n`. The first `m * n` elements of `c` are overwritten.
///
/// # Panics
/// Panics if any slice is shorter than the indices its dimensions produce.
#[cfg(not(target_os = "macos"))]
fn naive_matmul(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    for row in 0..m {
        let a_row_start = row * k;
        for col in 0..n {
            // Dot product of row `row` of `a` with column `col` of `b`.
            c[row * n + col] = (0..k).fold(0.0f32, |acc, inner| {
                acc + a[a_row_start + inner] * b[inner * n + col]
            });
        }
    }
}
// Unit tests for the CPU backend. Each test drives one `CpuDevice` operation
// on a small hand-built input and checks the downloaded result.
#[cfg(test)]
mod tests {
use super::*;
// Uploading then downloading returns the same values.
#[test]
fn cpu_upload_download() {
let dev = CpuDevice::new();
let data = vec![1.0, 2.0, 3.0, 4.0];
let buf = dev.upload(&data);
assert_eq!(dev.download(&buf), data);
}
// Multiplying by the 2x2 identity leaves `b` unchanged.
#[test]
fn cpu_matmul_identity() {
let dev = CpuDevice::new();
let a = dev.upload(&[1.0, 0.0, 0.0, 1.0]);
let b = dev.upload(&[1.0, 2.0, 3.0, 4.0]);
let c = dev.matmul(&a, &b, 2, 2, 2);
let result = dev.download(&c);
assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
}
// (1x2) * (2x1): 1*3 + 2*4 = 11.
#[test]
fn cpu_matmul_basic() {
let dev = CpuDevice::new();
let a = dev.upload(&[1.0, 2.0]);
let b = dev.upload(&[3.0, 4.0]);
let c = dev.matmul(&a, &b, 1, 2, 1);
let result = dev.download(&c);
assert!((result[0] - 11.0).abs() < 1e-5);
}
// (2x3) * (3x2) with distinct entries, checked against hand-computed values.
#[test]
fn cpu_matmul_rectangular() {
let dev = CpuDevice::new();
let a = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = dev.upload(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
let c = dev.matmul(&a, &b, 2, 3, 2);
let result = dev.download(&c);
assert!((result[0] - 58.0).abs() < 1e-4);
assert!((result[1] - 64.0).abs() < 1e-4);
assert!((result[2] - 139.0).abs() < 1e-4);
assert!((result[3] - 154.0).abs() < 1e-4);
}
// Two identical rows: each softmax row sums to 1 and the rows agree.
#[test]
fn cpu_softmax() {
let dev = CpuDevice::new();
let data = dev.upload(&[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
let result = dev.softmax(&data, 2, 3);
let out = dev.download(&result);
let sum0: f32 = out[0..3].iter().sum();
let sum1: f32 = out[3..6].iter().sum();
assert!((sum0 - 1.0).abs() < 1e-5);
assert!((sum1 - 1.0).abs() < 1e-5);
assert!((out[0] - out[3]).abs() < 1e-6);
}
// First group [1, 2] with unit weights: mean square is (1 + 4) / 2 = 2.5.
#[test]
fn cpu_rms_norm() {
let dev = CpuDevice::new();
let data = dev.upload(&[1.0, 2.0, 3.0, 4.0]);
let weight = dev.upload(&[1.0, 1.0]);
let result = dev.rms_norm(&data, &weight, 2, 2, 1e-5);
let out = dev.download(&result);
let rms0 = (2.5f32 + 1e-5).sqrt();
assert!((out[0] - 1.0 / rms0).abs() < 1e-5);
assert!((out[1] - 2.0 / rms0).abs() < 1e-5);
}
// A 3x2 table looked up with ids [2, 0, 1] yields rows 2, 0, 1 in that order.
#[test]
fn cpu_embedding() {
let dev = CpuDevice::new();
let weight = dev.upload(&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
let ids = dev.upload_u32(&[2, 0, 1]);
let result = dev.embedding(&weight, &ids, 3, 2);
let out = dev.download(&result);
assert!((out[0] - 0.5).abs() < 1e-6);
assert!((out[1] - 0.6).abs() < 1e-6);
assert!((out[2] - 0.1).abs() < 1e-6);
assert!((out[3] - 0.2).abs() < 1e-6);
assert!((out[4] - 0.3).abs() < 1e-6);
assert!((out[5] - 0.4).abs() < 1e-6);
}
// Reducing a 2x3 tensor along axis 1 gives the two row sums: 6 and 15.
#[test]
fn cpu_reduce_sum() {
let dev = CpuDevice::new();
let data = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let result = dev.reduce_sum(&data, &[2, 3], 1);
let out = dev.download(&result);
assert_eq!(out.len(), 2);
assert!((out[0] - 6.0).abs() < 1e-5); assert!((out[1] - 15.0).abs() < 1e-5); }
// Reducing the same tensor along axis 0 gives the three column sums: 5, 7, 9.
#[test]
fn cpu_reduce_sum_axis0() {
let dev = CpuDevice::new();
let data = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let result = dev.reduce_sum(&data, &[2, 3], 0);
let out = dev.download(&result);
assert_eq!(out.len(), 3);
assert!((out[0] - 5.0).abs() < 1e-5); assert!((out[1] - 7.0).abs() < 1e-5); assert!((out[2] - 9.0).abs() < 1e-5); }
// Two-input traced expression: element-wise a + b.
#[test]
fn cpu_elementwise_add() {
let dev = CpuDevice::new();
let a = dev.upload(&[1.0, 2.0, 3.0]);
let b = dev.upload(&[4.0, 5.0, 6.0]);
let c = dev.elementwise(&[&a, &b], 3, &|vars| vars[0] + vars[1]);
let out = dev.download(&c);
assert!((out[0] - 5.0).abs() < 1e-5);
assert!((out[1] - 7.0).abs() < 1e-5);
assert!((out[2] - 9.0).abs() < 1e-5);
}
// Single-input traced expression combining two operations: sqrt(x) + 1.
#[test]
fn cpu_elementwise_fused() {
use crate::Scalar;
let dev = CpuDevice::new();
let a = dev.upload(&[1.0, 4.0, 9.0]);
let c = dev.elementwise(&[&a], 3, &|vars| {
vars[0].sqrt() + ExprId::from_f64(1.0)
});
let out = dev.download(&c);
assert!((out[0] - 2.0).abs() < 1e-4);
assert!((out[1] - 3.0).abs() < 1e-4);
assert!((out[2] - 4.0).abs() < 1e-4);
}
// One head, seq_len 2: position 0 can only attend to itself, so its output
// equals the first value row [1, 2].
#[test]
fn cpu_causal_attention() {
let dev = CpuDevice::new();
let q = dev.upload(&[1.0, 0.0, 0.0, 1.0]);
let k = dev.upload(&[1.0, 0.0, 0.0, 1.0]);
let v = dev.upload(&[1.0, 2.0, 3.0, 4.0]);
let result = dev.causal_attention(&q, &k, &v, 2, 1, 1, 2);
let out = dev.download(&result);
assert_eq!(out.len(), 4);
assert!((out[0] - 1.0).abs() < 1e-4);
assert!((out[1] - 2.0).abs() < 1e-4);
}
// One query at cache_start = 1 attends to both cached positions; it matches
// the first key better, so the output leans toward the first value row.
#[test]
fn cpu_kv_attention() {
let dev = CpuDevice::new();
let q = dev.upload(&[1.0, 0.0]);
let k = dev.upload(&[1.0, 0.0, 0.0, 1.0]);
let v = dev.upload(&[1.0, 2.0, 3.0, 4.0]);
let result = dev.kv_attention(&q, &k, &v, 1, 1, 1, 1, 2);
let out = dev.download(&result);
assert_eq!(out.len(), 2);
assert!(out[0] < 2.5); }
// Three queries starting at cache position 0: the first sees only position
// 0 (output equals v[0]); the second blends positions 0 and 1.
#[test]
fn cpu_kv_attention_batched_causal() {
let dev = CpuDevice::new();
let q = dev.upload(&[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
let k = dev.upload(&[1.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
let v = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let result = dev.kv_attention(&q, &k, &v, 0, 3, 1, 1, 2);
let out = dev.download(&result);
assert_eq!(out.len(), 6);
assert!((out[0] - 1.0).abs() < 1e-5);
assert!((out[1] - 2.0).abs() < 1e-5);
assert!(out[2] > 1.5); assert!(out[3] > 2.5);
assert_eq!(out.len(), 6);
}
// 2x3 -> 3x2 transpose.
#[test]
fn cpu_transpose_2d() {
let dev = CpuDevice::new();
let buf = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let result = dev.transpose_2d(&buf, 2, 3);
let out = dev.download(&result);
assert_eq!(out, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
// With s = [0.2, 0.3, 0.5] and g = [1, 0, 0]: sum(s * g) = 0.2, so the
// gradient is s * (g - 0.2) = [0.16, -0.06, -0.1].
#[test]
fn cpu_softmax_backward() {
let dev = CpuDevice::new();
let sm = dev.upload(&[0.2, 0.3, 0.5]);
let grad = dev.upload(&[1.0, 0.0, 0.0]);
let result = dev.softmax_backward(&sm, &grad, 1, 3);
let out = dev.download(&result);
assert!((out[0] - 0.16).abs() < 1e-6);
assert!((out[1] - (-0.06)).abs() < 1e-6);
assert!((out[2] - (-0.1)).abs() < 1e-6);
}
// Vocab 4, dim 3, ids [1, 3]: gradient rows land in table rows 1 and 3;
// rows 0 and 2 stay zero.
#[test]
fn cpu_embedding_backward() {
let dev = CpuDevice::new();
let grad_out = dev.upload(&[1.0, 2.0, 3.0, 0.1, 0.2, 0.3]);
let ids = dev.upload_u32(&[1, 3]);
let result = dev.embedding_backward(&grad_out, &ids, 4, 2, 3);
let out = dev.download(&result);
assert_eq!(out.len(), 12); assert!((out[3] - 1.0).abs() < 1e-6);
assert!((out[4] - 2.0).abs() < 1e-6);
assert!((out[5] - 3.0).abs() < 1e-6);
assert!((out[9] - 0.1).abs() < 1e-6);
assert!((out[10] - 0.2).abs() < 1e-6);
assert!((out[11] - 0.3).abs() < 1e-6);
assert!((out[0]).abs() < 1e-6);
assert!((out[6]).abs() < 1e-6);
}
// No padding (pad_id 99 never occurs): the loss is finite and positive, and
// each gradient row (softmax minus one-hot, scaled) sums to zero.
#[test]
fn cpu_cross_entropy_forward_backward() {
let dev = CpuDevice::new();
let logits = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let targets = dev.upload_u32(&[2, 0]);
let (loss, grad) = dev.cross_entropy_forward_backward(&logits, &targets, 2, 3, 99);
let g = dev.download(&grad);
assert!(loss.is_finite());
assert!(loss > 0.0);
assert_eq!(g.len(), 6);
let row0_sum: f32 = g[0..3].iter().sum();
let row1_sum: f32 = g[3..6].iter().sum();
assert!(row0_sum.abs() < 1e-5);
assert!(row1_sum.abs() < 1e-5);
}
// pad_id = 0 marks the second position (target 0) as padding, so its
// gradient row must stay zero.
#[test]
fn cpu_cross_entropy_with_padding() {
let dev = CpuDevice::new();
let logits = dev.upload(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let targets = dev.upload_u32(&[2, 0]); let (loss, grad) = dev.cross_entropy_forward_backward(&logits, &targets, 2, 3, 0);
let g = dev.download(&grad);
assert!(loss > 0.0);
assert!((g[3]).abs() < 1e-6);
assert!((g[4]).abs() < 1e-6);
assert!((g[5]).abs() < 1e-6);
}
// Smoke test: output sizes are right and every gradient value is finite.
#[test]
fn cpu_rms_norm_backward() {
let dev = CpuDevice::new();
let input = dev.upload(&[1.0, 2.0, 3.0, 4.0]);
let weight = dev.upload(&[1.0, 1.0]);
let grad_out = dev.upload(&[1.0, 0.0, 0.0, 1.0]);
let (grad_input, grad_weight) =
dev.rms_norm_backward(&input, &weight, &grad_out, 2, 2, 1e-5);
let gi = dev.download(&grad_input);
let gw = dev.download(&grad_weight);
assert_eq!(gi.len(), 4);
assert_eq!(gw.len(), 2);
for v in &gi {
assert!(v.is_finite());
}
for v in &gw {
assert!(v.is_finite());
}
}
// Processing all queries in one batched call must match feeding them one at
// a time with a growing cache prefix (two query heads sharing one KV head).
#[test]
fn cpu_kv_attention_batched_matches_sequential() {
let dev = CpuDevice::new();
let n_heads = 2;
let n_kv_heads = 1;
let head_dim = 4;
let q_len = 4;
let total_dim = n_heads * head_dim;
let kv_dim = n_kv_heads * head_dim;
// Deterministic pseudo-random fill values in [0, 1).
let q_data: Vec<f32> = (0..q_len * total_dim).map(|i| ((i * 7 + 3) % 13) as f32 / 13.0).collect();
let kv_data: Vec<f32> = (0..q_len * kv_dim).map(|i| ((i * 11 + 5) % 17) as f32 / 17.0).collect();
let v_data: Vec<f32> = (0..q_len * kv_dim).map(|i| ((i * 13 + 7) % 19) as f32 / 19.0).collect();
let q = dev.upload(&q_data);
let k = dev.upload(&kv_data);
let v = dev.upload(&v_data);
let batched = dev.download(&dev.kv_attention(&q, &k, &v, 0, q_len, n_heads, n_kv_heads, head_dim));
let mut sequential = Vec::new();
for qi in 0..q_len {
let q_slice = dev.upload(&q_data[qi * total_dim..(qi + 1) * total_dim]);
let k_slice = dev.upload(&kv_data[..((qi + 1) * kv_dim)]);
let v_slice = dev.upload(&v_data[..((qi + 1) * kv_dim)]);
let out = dev.download(&dev.kv_attention(&q_slice, &k_slice, &v_slice, qi, 1, n_heads, n_kv_heads, head_dim));
sequential.extend(out);
}
assert_eq!(batched.len(), sequential.len());
for i in 0..batched.len() {
assert!(
(batched[i] - sequential[i]).abs() < 1e-5,
"mismatch at {i}: batched={} sequential={}",
batched[i], sequential[i]
);
}
}
// With positive gradients and positive parameters, one step must lower
// every parameter.
#[test]
fn adamw_step_reduces_loss_proxy() {
let dev = CpuDevice::new();
let mut param = dev.upload(&[1.0, 2.0, 3.0]);
let grad = dev.upload(&[0.1, 0.2, 0.3]);
let mut m = dev.upload(&[0.0, 0.0, 0.0]);
let mut v = dev.upload(&[0.0, 0.0, 0.0]);
dev.adamw_step(&mut param, &grad, &mut m, &mut v, 0.001, 0.9, 0.999, 1e-8, 0.01, 1);
let p = param.to_vec();
assert!(p[0] < 1.0);
assert!(p[1] < 2.0);
assert!(p[2] < 3.0);
}
// In-place element-wise addition.
#[test]
fn add_assign_works() {
let dev = CpuDevice::new();
let mut a = dev.upload(&[1.0, 2.0, 3.0]);
let b = dev.upload(&[0.5, 1.0, 1.5]);
dev.add_assign(&mut a, &b);
assert_eq!(a.to_vec(), vec![1.5, 3.0, 4.5]);
}
// Zeroing keeps the length and clears every element.
#[test]
fn zero_buffer_works() {
let dev = CpuDevice::new();
let mut a = dev.upload(&[1.0, 2.0, 3.0]);
dev.zero_buffer(&mut a);
assert_eq!(a.to_vec(), vec![0.0, 0.0, 0.0]);
}
}