Skip to main content

ferrum_kernels/backend/
cpu.rs

1//! CPU backend using Accelerate (macOS) / portable fallback (Linux).
2//! Context = () — all ops execute immediately, no batching needed.
3
4use super::{AttnConfig, Backend};
5use ferrum_types::{FerrumError, Result};
6
/// CPU implementation of [`Backend`]: buffers are plain `Vec<f32>` and the
/// context is `()`, so every operation runs synchronously when called and
/// `sync` has nothing to wait for.
pub struct CpuBackend;
8
// Accelerate.framework entry points (macOS only). No `#[link]` attribute is
// present here, so linking the framework is presumably configured elsewhere
// (e.g. a build script) — NOTE(review): confirm.
#[cfg(target_os = "macos")]
extern "C" {
    /// CBLAS single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// `order`, `transa` and `transb` take the standard CBLAS enum values
    /// (101 = `CblasRowMajor`, 111 = `CblasNoTrans`, 112 = `CblasTrans`).
    fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
    /// vDSP single-precision dot product with element strides.
    ///
    /// Apple declares the strides as `vDSP_Stride` (C `long`) and the element
    /// count as `vDSP_Length` (C `unsigned long`); declaring them as `i32`
    /// would not match the C ABI width, so the matching C aliases are used.
    fn vDSP_dotpr(
        a: *const f32,
        a_stride: std::os::raw::c_long,
        b: *const f32,
        b_stride: std::os::raw::c_long,
        result: *mut f32,
        n: std::os::raw::c_ulong,
    );
}
36
/// GPTQ weights as held by the CPU backend: already dequantized to f32.
///
/// The int4 checkpoint is unpacked once at load time into a dense matrix so
/// the quantized GEMM can simply reuse the ordinary dense GEMM path. This
/// buys simplicity with memory: 4 bytes per weight instead of half a byte.
pub struct CpuGptqStore {
    /// Input (reduction) dimension, i.e. the length of each weight row.
    pub k: usize,
    /// Output dimension, i.e. the number of weight rows.
    pub n: usize,
    /// Dense weights: `n` rows of `k` values each (row-major `[n, k]`).
    pub weight_f32: Vec<f32>,
}
44
45impl Backend for CpuBackend {
46    type Buffer = Vec<f32>;
47    type Context = ();
48    type GptqStore = CpuGptqStore;
49
50    fn new_context() -> Self::Context {}
51    fn sync(_ctx: &mut Self::Context) {}
52
53    fn load_gptq(
54        qweight: &[i32],
55        scales: &[f32],
56        qzeros: &[i32],
57        _g_idx: Option<&[i32]>,
58        bits: u32,
59        group_size: usize,
60        k: usize,
61        n: usize,
62    ) -> Result<Self::GptqStore> {
63        if bits != 4 {
64            return Err(FerrumError::unsupported(format!(
65                "CPU GPTQ: only bits=4 supported (got {bits})"
66            )));
67        }
68        let num_groups = k / group_size;
69        // Unpack GPTQ [K/8, N] i32 → int4 values, dequantize per-group:
70        //   w_f16 = (q - zero) * scale
71        // Write to [n, k] row-major (matches DenseLinear convention).
72        let mut w = vec![0.0f32; n * k];
73        let packed_rows = k / 8;
74        for pr in 0..packed_rows {
75            for col in 0..n {
76                let packed = qweight[pr * n + col] as u32;
77                for bi in 0..8 {
78                    let ki = pr * 8 + bi;
79                    let q = ((packed >> (bi * 4)) & 0xF) as i32;
80                    let grp = ki / group_size;
81                    let scale = scales[grp * n + col];
82                    // qzeros [num_groups, N/8] i32 packs 8 zero-values per int32
83                    let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
84                    let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
85                    let val = (q - zero) as f32 * scale;
86                    w[col * k + ki] = val;
87                }
88            }
89        }
90        let _ = num_groups; // informational only
91        Ok(CpuGptqStore {
92            weight_f32: w,
93            k,
94            n,
95        })
96    }
97
98    fn gemm_gptq(
99        ctx: &mut Self::Context,
100        a: &Self::Buffer,
101        weight: &Self::GptqStore,
102        out: &mut Self::Buffer,
103        m: usize,
104    ) -> Result<()> {
105        // Just run normal GEMM with dequantized weights.
106        // out[m, n] = a[m, k] @ w[n, k]^T — same contract as B::gemm.
107        Self::gemm(ctx, a, &weight.weight_f32, out, m, weight.n, weight.k);
108        Ok(())
109    }
110
111    fn gemm(
112        _ctx: &mut Self::Context,
113        a: &Self::Buffer,
114        b: &Self::Buffer,
115        out: &mut Self::Buffer,
116        m: usize,
117        n: usize,
118        k: usize,
119    ) {
120        assert!(
121            a.len() >= m * k,
122            "gemm: a too small len={} m={m} k={k}",
123            a.len()
124        );
125        assert!(
126            b.len() >= n * k,
127            "gemm: b too small len={} n={n} k={k}",
128            b.len()
129        );
130        assert!(
131            out.len() >= m * n,
132            "gemm: out too small len={} m={m} n={n}",
133            out.len()
134        );
135        #[cfg(target_os = "macos")]
136        unsafe {
137            cblas_sgemm(
138                101,
139                111,
140                112,
141                m as i32,
142                n as i32,
143                k as i32,
144                1.0,
145                a.as_ptr(),
146                k as i32,
147                b.as_ptr(),
148                k as i32,
149                0.0,
150                out.as_mut_ptr(),
151                n as i32,
152            );
153        }
154        #[cfg(not(target_os = "macos"))]
155        {
156            for i in 0..m {
157                for j in 0..n {
158                    let mut sum = 0.0f64;
159                    for p in 0..k {
160                        sum += a[i * k + p] as f64 * b[j * k + p] as f64;
161                    }
162                    out[i * n + j] = sum as f32;
163                }
164            }
165        }
166    }
167
168    fn rms_norm(
169        _ctx: &mut Self::Context,
170        x: &Self::Buffer,
171        w: &Self::Buffer,
172        eps: f32,
173        out: &mut Self::Buffer,
174        tokens: usize,
175        dim: usize,
176    ) {
177        for t in 0..tokens {
178            let row = &x[t * dim..(t + 1) * dim];
179            let o = &mut out[t * dim..(t + 1) * dim];
180            let sum_sq = dot_product(row, row);
181            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
182            for i in 0..dim {
183                o[i] = row[i] * inv * w[i];
184            }
185        }
186    }
187
188    fn fused_add_rms_norm(
189        _ctx: &mut Self::Context,
190        residual: &mut Self::Buffer,
191        x: &Self::Buffer,
192        w: &Self::Buffer,
193        eps: f32,
194        out: &mut Self::Buffer,
195        tokens: usize,
196        dim: usize,
197    ) {
198        for t in 0..tokens {
199            let off = t * dim;
200            for i in 0..dim {
201                residual[off + i] += x[off + i];
202            }
203            let row = &residual[off..off + dim];
204            let o = &mut out[off..off + dim];
205            let sum_sq = dot_product(row, row);
206            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
207            for i in 0..dim {
208                o[i] = row[i] * inv * w[i];
209            }
210        }
211    }
212
213    fn flash_attention(
214        _ctx: &mut Self::Context,
215        q: &Self::Buffer,
216        k: &Self::Buffer,
217        v: &Self::Buffer,
218        out: &mut Self::Buffer,
219        batch: usize,
220        q_len: usize,
221        kv_len: usize,
222        pos_offset: usize,
223        cfg: &AttnConfig,
224    ) {
225        cpu_attention(
226            q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
227        );
228    }
229
230    fn copy_slice(
231        _ctx: &mut Self::Context,
232        src: &Self::Buffer,
233        src_offset: usize,
234        dst: &mut Self::Buffer,
235        dst_offset: usize,
236        len: usize,
237    ) {
238        dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
239    }
240
241    fn embedding_lookup(
242        _ctx: &mut Self::Context,
243        table: &Self::Buffer,
244        ids: &[u32],
245        out: &mut Self::Buffer,
246        dim: usize,
247    ) {
248        for (i, &id) in ids.iter().enumerate() {
249            let src = id as usize * dim;
250            out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
251        }
252    }
253
254    fn split_qkv(
255        _ctx: &mut Self::Context,
256        qkv: &Self::Buffer,
257        q: &mut Self::Buffer,
258        k: &mut Self::Buffer,
259        v: &mut Self::Buffer,
260        tokens: usize,
261        q_dim: usize,
262        kv_dim: usize,
263    ) {
264        let qkv_dim = q_dim + 2 * kv_dim;
265        for t in 0..tokens {
266            let base = t * qkv_dim;
267            q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
268            k[t * kv_dim..(t + 1) * kv_dim]
269                .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
270            v[t * kv_dim..(t + 1) * kv_dim]
271                .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
272        }
273    }
274
275    fn fused_silu_mul_split(
276        _ctx: &mut Self::Context,
277        gate_up: &Self::Buffer,
278        out: &mut Self::Buffer,
279        tokens: usize,
280        im: usize,
281    ) {
282        for t in 0..tokens {
283            for i in 0..im {
284                let g = gate_up[t * 2 * im + i];
285                let u = gate_up[t * 2 * im + im + i];
286                out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
287            }
288        }
289    }
290
291    fn qk_norm_rope(
292        _ctx: &mut Self::Context,
293        input: &Self::Buffer,
294        norm_w: &Self::Buffer,
295        cos: &Self::Buffer,
296        sin: &Self::Buffer,
297        output: &mut Self::Buffer,
298        tokens: usize,
299        heads: usize,
300        head_dim: usize,
301        pos_offset: usize,
302        eps: f32,
303        mode: i32,
304    ) {
305        let half = head_dim / 2;
306        let cos_len = cos.len();
307        let sin_len = sin.len();
308        debug_assert_eq!(cos_len, sin_len);
309
310        for t in 0..tokens {
311            let pos = pos_offset + t;
312            for h in 0..heads {
313                // input row: [t, h, :]  stride = heads * head_dim
314                let src_off = (t * heads + h) * head_dim;
315                // output row: [h, t, :]  stride = tokens * head_dim
316                let dst_off = (h * tokens + t) * head_dim;
317
318                // Mode 0: plain transpose.
319                if mode == 0 {
320                    for i in 0..head_dim {
321                        output[dst_off + i] = input[src_off + i];
322                    }
323                    continue;
324                }
325
326                // Optional RMS norm (mode 1 only).
327                let scale = if mode == 1 {
328                    let mut sum_sq = 0.0f32;
329                    for i in 0..head_dim {
330                        sum_sq += input[src_off + i] * input[src_off + i];
331                    }
332                    1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
333                } else {
334                    1.0
335                };
336
337                // Apply (norm?) + RoPE to halves, write to head-major output.
338                for i in 0..half {
339                    let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
340                    let (x0, x1) = if mode == 1 {
341                        (
342                            x0_raw * scale * norm_w[i],
343                            x1_raw * scale * norm_w[i + half],
344                        )
345                    } else {
346                        (x0_raw, x1_raw)
347                    };
348                    let c = cos[pos * half + i];
349                    let s = sin[pos * half + i];
350                    output[dst_off + i] = x0 * c - x1 * s;
351                    output[dst_off + i + half] = x1 * c + x0 * s;
352                }
353            }
354        }
355    }
356
357    fn kv_cache_append_head_major(
358        _ctx: &mut Self::Context,
359        cache_k: &mut Self::Buffer,
360        cache_v: &mut Self::Buffer,
361        cache_len: usize,
362        cache_capacity: usize,
363        new_k_head_major: &Self::Buffer,
364        new_v_head_major: &Self::Buffer,
365        new_tokens: usize,
366        nkv: usize,
367        hd: usize,
368    ) {
369        debug_assert!(cache_len + new_tokens <= cache_capacity);
370        debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
371        debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
372        debug_assert_eq!(new_k_head_major.len(), nkv * new_tokens * hd);
373        debug_assert_eq!(new_v_head_major.len(), nkv * new_tokens * hd);
374
375        for h in 0..nkv {
376            let dst_base = h * cache_capacity * hd + cache_len * hd;
377            let src_base = h * new_tokens * hd;
378            cache_k[dst_base..dst_base + new_tokens * hd]
379                .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
380            cache_v[dst_base..dst_base + new_tokens * hd]
381                .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
382        }
383    }
384
385    fn transpose_head_to_token(
386        _ctx: &mut Self::Context,
387        src: &Self::Buffer,
388        dst: &mut Self::Buffer,
389        tokens: usize,
390        heads: usize,
391        dim: usize,
392    ) {
393        for h in 0..heads {
394            for t in 0..tokens {
395                let s = (h * tokens + t) * dim;
396                let d = (t * heads + h) * dim;
397                dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
398            }
399        }
400    }
401
402    fn add_inplace(
403        _ctx: &mut Self::Context,
404        residual: &mut Self::Buffer,
405        x: &Self::Buffer,
406        len: usize,
407    ) {
408        for i in 0..len {
409            residual[i] += x[i];
410        }
411    }
412
413    fn add_bias(
414        _ctx: &mut Self::Context,
415        data: &mut Self::Buffer,
416        bias: &Self::Buffer,
417        rows: usize,
418        cols: usize,
419    ) {
420        debug_assert_eq!(bias.len(), cols);
421        for r in 0..rows {
422            let off = r * cols;
423            for c in 0..cols {
424                data[off + c] += bias[c];
425            }
426        }
427    }
428
429    fn layer_norm(
430        _ctx: &mut Self::Context,
431        x: &Self::Buffer,
432        gamma: &Self::Buffer,
433        beta: &Self::Buffer,
434        eps: f32,
435        out: &mut Self::Buffer,
436        tokens: usize,
437        dim: usize,
438    ) {
439        debug_assert_eq!(gamma.len(), dim);
440        debug_assert_eq!(beta.len(), dim);
441        for t in 0..tokens {
442            let off = t * dim;
443            // Compute mean + variance over `dim` in f64 for stability.
444            let mut mean = 0.0f64;
445            for i in 0..dim {
446                mean += x[off + i] as f64;
447            }
448            mean /= dim as f64;
449            let mut var = 0.0f64;
450            for i in 0..dim {
451                let d = x[off + i] as f64 - mean;
452                var += d * d;
453            }
454            var /= dim as f64;
455            let inv = 1.0f32 / ((var as f32) + eps).sqrt();
456            let mean_f32 = mean as f32;
457            for i in 0..dim {
458                out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
459            }
460        }
461    }
462
463    fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
464        // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
465        // Uses f64 for erf accuracy (matches torch.nn.functional.gelu default).
466        for i in 0..len {
467            let xi = x[i];
468            out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
469        }
470    }
471
472    fn alloc(len: usize) -> Self::Buffer {
473        vec![0.0f32; len]
474    }
475    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
476        buf[..len].to_vec()
477    }
478    fn from_slice(data: &[f32]) -> Self::Buffer {
479        data.to_vec()
480    }
481}
482
483// ── Helpers ──────────────────────────────────────────────────────────────
484
/// Dot product of the overlapping prefix of `a` and `b`.
///
/// Only the first `min(a.len(), b.len())` elements contribute, on every
/// platform. The portable path already behaved this way through `zip`. The
/// macOS path previously passed `a.len()` to vDSP, which read past the end
/// of a shorter `b`.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    #[cfg(target_os = "macos")]
    {
        let mut result = 0.0f32;
        // SAFETY: both pointers are valid for `len` contiguous f32 reads
        // (`len` never exceeds either slice's length), strides are 1, and
        // `result` is a valid, exclusive write target for one f32.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, len as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a[..len].iter().zip(&b[..len]).map(|(x, y)| x * y).sum()
    }
}
499
/// In-place split-half rotary embedding over token-major `[tokens, heads, head_dim]` data.
///
/// Each head's element `i` (for `i < half`) is paired with element
/// `i + half` and rotated by the angle whose cosine and sine are stored at
/// `positions[t] * half + i` in the `cos` / `sin` tables.
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    let row_stride = heads * head_dim;
    for t in 0..tokens {
        let table_off = positions[t] as usize * half;
        for h in 0..heads {
            let lo = t * row_stride + h * head_dim;
            let hi = lo + half;
            for i in 0..half {
                let (cos_v, sin_v) = (cos[table_off + i], sin[table_off + i]);
                let (first, second) = (data[lo + i], data[hi + i]);
                data[lo + i] = first * cos_v - second * sin_v;
                data[hi + i] = second * cos_v + first * sin_v;
            }
        }
    }
}
525
526fn cpu_attention(
527    q: &[f32],
528    k: &[f32],
529    v: &[f32],
530    out: &mut [f32],
531    batch: usize,
532    q_len: usize,
533    kv_len: usize,
534    causal: bool,
535    pos_offset: usize,
536    cfg: &AttnConfig,
537) {
538    let nh = cfg.num_heads;
539    let nkv = cfg.num_kv_heads;
540    let d = cfg.head_dim;
541    let n_rep = nh / nkv;
542    let scale = cfg.scale;
543    // Per-head KV stride: 0 (the default) means contiguous (legacy
544    // `kv_cache_append` path reallocates each layer). A non-zero value means
545    // the cache is pre-allocated to `kv_seq_stride` rows per head but only
546    // the first `kv_len` are valid — we skip the rest via `attend_len`.
547    let kv_stride = if cfg.kv_seq_stride > 0 {
548        cfg.kv_seq_stride
549    } else {
550        kv_len
551    };
552
553    for b in 0..batch {
554        for h in 0..nh {
555            let kv_h = h / n_rep;
556            let q_off = (b * nh + h) * q_len * d;
557            let k_off = (b * nkv + kv_h) * kv_stride * d;
558            let v_off = (b * nkv + kv_h) * kv_stride * d;
559            let o_off = (b * nh + h) * q_len * d;
560
561            for qi in 0..q_len {
562                let attend_end = if causal {
563                    (pos_offset + qi + 1).min(kv_len)
564                } else {
565                    kv_len
566                };
567                let attend_start = if causal && cfg.sliding_window > 0 {
568                    attend_end.saturating_sub(cfg.sliding_window)
569                } else {
570                    0
571                };
572                let mut max_score = f32::NEG_INFINITY;
573                let mut sum_exp = 0.0f32;
574                let mut acc = vec![0.0f32; d];
575
576                for ki in attend_start..attend_end {
577                    let mut dot = 0.0f32;
578                    for di in 0..d {
579                        dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
580                    }
581                    let score = dot * scale;
582                    if score > max_score {
583                        let correction = (max_score - score).exp();
584                        for di in 0..d {
585                            acc[di] *= correction;
586                        }
587                        sum_exp *= correction;
588                        max_score = score;
589                    }
590                    let w = (score - max_score).exp();
591                    sum_exp += w;
592                    for di in 0..d {
593                        acc[di] += w * v[v_off + ki * d + di];
594                    }
595                }
596
597                if sum_exp > 0.0 {
598                    let inv = 1.0 / sum_exp;
599                    for di in 0..d {
600                        out[o_off + qi * d + di] = acc[di] * inv;
601                    }
602                }
603            }
604        }
605    }
606}
607
/// Error-function approximation from Abramowitz & Stegun 7.1.26, evaluated in f32.
///
/// The formula's maximum absolute error is about 1.5e-7. That is of the same
/// order as f32 precision near 1.0 (epsilon ≈ 1.19e-7), not below it. The
/// result is odd-symmetric, and `±∞` maps to exactly `±1`.
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254829592;
    const A2: f32 = 0.284496736; // subtracted
    const A3: f32 = 1.421413741;
    const A4: f32 = 1.453152027; // subtracted
    const A5: f32 = 1.061405429;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner form of a1*t + a2*t^2 + ... + a5*t^5 (same evaluation order as before).
    let poly = ((((A5 * t - A4) * t + A3) * t - A2) * t + A1) * t;
    sign * (1.0 - poly * (-ax * ax).exp())
}
616            + 0.254829592)
617            * t
618            * (-x * x).exp();
619    sign * y
620}