Skip to main content

ferrum_kernels/backend/
cpu.rs

//! CPU backend using Accelerate (macOS) / portable Rust fallback (all other targets).
//! Context = () — all ops execute immediately, no batching needed.
3
4use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
// ── Q4_K block layout ──────────────────────────────────────────────────
//
// Mirrors GGML / candle's `BlockQ4K`. `load_quant` validates raw GGUF Q4_K
// payloads against these sizes and `dequant_q4_k_cpu` expands the block
// bytes to fp32 row-major weights on CPU.

/// Number of weights covered by one Q4_K super-block.
const Q4_K_QK: usize = 256;
/// Bytes of packed 6-bit sub-block scales/mins in each super-block.
const Q4_K_SCALE_SIZE: usize = 12;
/// Bytes per super-block: `d` and `dmin` (two f16, i.e. two 2-byte words),
/// the packed scales, then one 4-bit quant per weight. Evaluates to 144.
const Q4_K_BLOCK_BYTES: usize =
    2 * std::mem::size_of::<u16>() + Q4_K_SCALE_SIZE + Q4_K_QK / 2;
17
/// Bit-unpacker matching candle's `quantized::utils::get_scale_min_k4`.
///
/// Returns the 6-bit `(scale, min)` pair for sub-block `j` (0..8) from the
/// 12-byte packed `q` array of a Q4_K super-block.
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j >= 4 {
        // Sub-blocks 4..8: the low nibbles live in bytes 8..12, the two high
        // bits are borrowed from the top of bytes 0..4 (scale) and 4..8 (min).
        let packed = q[j + 4];
        let scale = (packed & 0x0F) | ((q[j - 4] >> 6) << 4);
        let min = (packed >> 4) | ((q[j] >> 6) << 4);
        return (scale, min);
    }
    // Sub-blocks 0..4: both values sit directly in the low 6 bits.
    (q[j] & 0x3F, q[j + 4] & 0x3F)
}
28
29/// Port of candle's CPU `BlockQ4K::to_float`. Bit-identical output for
30/// identical input — the test in `q4_k.rs` verifies our Metal kernel
31/// also matches.
32fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33    debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34    let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35    for b in 0..n_blocks {
36        let off = b * Q4_K_BLOCK_BYTES;
37        let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38        let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39        let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40        let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42        let mut is = 0usize;
43        for j in (0..Q4_K_QK).step_by(64) {
44            let q_chunk = &qs[j / 2..j / 2 + 32];
45            let (sc1, mn1) = get_scale_min_k4(is, scales);
46            let d1 = d * sc1 as f32;
47            let m1 = dmin * mn1 as f32;
48            let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49            let d2 = d * sc2 as f32;
50            let m2 = dmin * mn2 as f32;
51            for q in q_chunk {
52                out.push(d1 * (q & 0xF) as f32 - m1);
53            }
54            for q in q_chunk {
55                out.push(d2 * (q >> 4) as f32 - m2);
56            }
57            is += 2;
58        }
59    }
60    out
61}
62
/// CPU implementation of [`Backend`]: buffers are plain `Vec<f32>` and the
/// context is `()`, since every op executes immediately with no command
/// batching.
pub struct CpuBackend;
64
// Hand-declared Accelerate framework entry points (macOS only).
// NOTE(review): there is no `#[link]` attribute here, so Accelerate is
// presumably linked by the build script — confirm.
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// CBLAS single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    /// `order`, `transa` and `transb` take CBLAS enum values
    /// (101 = `CblasRowMajor`, 111 = `CblasNoTrans`, 112 = `CblasTrans`).
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
    /// vDSP dot product: writes `Σ a[i * a_stride] * b[i * b_stride]` over
    /// `i in 0..n` to `*result`. `n` is a `vDSP_Length`.
    fn vDSP_dotpr(
        a: *const f32,
        a_stride: i32,
        b: *const f32,
        b_stride: i32,
        result: *mut f32,
        n: u64,
    );
}
92
/// CPU-side GPTQ store: eagerly dequantized f32 weights.
///
/// Trades memory for simplicity — the packed GPTQ tensors are expanded once
/// at load time and every matmul then runs as a plain dense GEMM.
pub struct CpuGptqStore {
    /// Dense weights, row-major `[n, k]` (one row per output feature).
    pub weight_f32: Vec<f32>,
    /// Reduction (input-feature) dimension: the length of each row.
    pub k: usize,
    /// Number of rows. For an MoE stack this counts every expert's rows,
    /// since experts are concatenated along `n`.
    pub n: usize,
}
100
/// CPU-side container for GGUF quantized weights.
///
/// Each variant holds the dense fp32 weights produced by eager dequant at
/// load time. CPU isn't the bench target, so we don't pay the complexity of
/// on-the-fly dequant here. The variant tag records which GGUF format the
/// weights came from; `load_quant` wraps the store in a `CpuGgufLinear`.
///
/// New quant types (Q5_K / Q6_K / Q8_0) become new variants — no trait
/// churn, just a new arm in `load_quant` (and wherever `CpuGgufLinear`
/// matches on the store; not visible from this file).
pub enum CpuQuantStore {
    /// Dequantized Q4_K weights.
    Q4K {
        /// Row-major `[n_rows, n_cols]`.
        weights: Vec<f32>,
        n_rows: usize,
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117    type Buffer = Vec<f32>;
118    type Context = ();
119    // type GptqStore: removed in Phase C step 4e. CpuGptqStore is now
120    // a private (crate-internal) detail of CpuMarlinExpertStack.
121
122    type Timer = crate::backend::timer::CpuTimer;
123    fn make_timer() -> Self::Timer {
124        crate::backend::timer::CpuTimer::new()
125    }
126
127    fn new_context() -> Self::Context {}
128    fn sync(_ctx: &mut Self::Context) {}
129
130    /// Phase D step 2+3: typed alloc. CPU Buffer is Vec<f32> — bytes
131    /// are dtype-erased, so we size the underlying Vec to hold `n`
132    /// elements of `dtype` (bit-cast at read/write time).
133    fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
134        // f32 storage; for i8 we round up to 4-byte boundary so the
135        // Vec<f32> length covers all i8 elements.
136        let bytes = n * dtype.bytes_per_elem();
137        let f32_len = bytes.div_ceil(4);
138        vec![0.0f32; f32_len]
139    }
140
141    /// Phase D step 2+3: typed upload. Bit-cast host data into f32
142    /// words (CPU buffer is dtype-erased Vec<f32>, see alloc_typed).
143    fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
144        let bytes = data.len() * std::mem::size_of::<T>();
145        let f32_len = bytes.div_ceil(4);
146        let mut out = vec![0.0f32; f32_len];
147        unsafe {
148            std::ptr::copy_nonoverlapping(
149                data.as_ptr() as *const u8,
150                out.as_mut_ptr() as *mut u8,
151                bytes,
152            );
153        }
154        out
155    }
156
157    /// Phase D step 2+3: typed in-place write. Bit-cast bytes into
158    /// the dtype-erased f32 storage.
159    fn write_typed<T: crate::backend::HostDtype>(
160        _ctx: &mut Self::Context,
161        dst: &mut Self::Buffer,
162        data: &[T],
163    ) {
164        let bytes = data.len() * std::mem::size_of::<T>();
165        debug_assert!(
166            bytes <= dst.len() * 4,
167            "CpuBackend::write_typed: src bytes {} > dst bytes {}",
168            bytes,
169            dst.len() * 4
170        );
171        unsafe {
172            std::ptr::copy_nonoverlapping(
173                data.as_ptr() as *const u8,
174                dst.as_mut_ptr() as *mut u8,
175                bytes,
176            );
177        }
178    }
179
180    fn fused_silu_mul_split_strided(
181        _ctx: &mut Self::Context,
182        gate_up: &Self::Buffer,
183        in_row_offset: usize,
184        out: &mut Self::Buffer,
185        out_row_offset: usize,
186        tokens: usize,
187        intermediate: usize,
188    ) {
189        let in_per_row = 2 * intermediate;
190        let in_start = in_row_offset * in_per_row;
191        let out_start = out_row_offset * intermediate;
192        for r in 0..tokens {
193            for c in 0..intermediate {
194                let g = gate_up[in_start + r * in_per_row + c];
195                let u = gate_up[in_start + r * in_per_row + intermediate + c];
196                let silu = g / (1.0 + (-g).exp());
197                out[out_start + r * intermediate + c] = silu * u;
198            }
199        }
200    }
201
202    fn gemm(
203        _ctx: &mut Self::Context,
204        a: &Self::Buffer,
205        b: &Self::Buffer,
206        out: &mut Self::Buffer,
207        m: usize,
208        n: usize,
209        k: usize,
210    ) {
211        assert!(
212            a.len() >= m * k,
213            "gemm: a too small len={} m={m} k={k}",
214            a.len()
215        );
216        assert!(
217            b.len() >= n * k,
218            "gemm: b too small len={} n={n} k={k}",
219            b.len()
220        );
221        assert!(
222            out.len() >= m * n,
223            "gemm: out too small len={} m={m} n={n}",
224            out.len()
225        );
226        #[cfg(target_os = "macos")]
227        unsafe {
228            cblas_sgemm(
229                101,
230                111,
231                112,
232                m as i32,
233                n as i32,
234                k as i32,
235                1.0,
236                a.as_ptr(),
237                k as i32,
238                b.as_ptr(),
239                k as i32,
240                0.0,
241                out.as_mut_ptr(),
242                n as i32,
243            );
244        }
245        #[cfg(not(target_os = "macos"))]
246        {
247            for i in 0..m {
248                for j in 0..n {
249                    let mut sum = 0.0f64;
250                    for p in 0..k {
251                        sum += a[i * k + p] as f64 * b[j * k + p] as f64;
252                    }
253                    out[i * n + j] = sum as f32;
254                }
255            }
256        }
257    }
258
259    fn rms_norm(
260        _ctx: &mut Self::Context,
261        x: &Self::Buffer,
262        w: &Self::Buffer,
263        eps: f32,
264        out: &mut Self::Buffer,
265        tokens: usize,
266        dim: usize,
267    ) {
268        for t in 0..tokens {
269            let row = &x[t * dim..(t + 1) * dim];
270            let o = &mut out[t * dim..(t + 1) * dim];
271            let sum_sq = dot_product(row, row);
272            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
273            for i in 0..dim {
274                o[i] = row[i] * inv * w[i];
275            }
276        }
277    }
278
279    fn fused_add_rms_norm(
280        _ctx: &mut Self::Context,
281        residual: &mut Self::Buffer,
282        x: &Self::Buffer,
283        w: &Self::Buffer,
284        eps: f32,
285        out: &mut Self::Buffer,
286        tokens: usize,
287        dim: usize,
288    ) {
289        for t in 0..tokens {
290            let off = t * dim;
291            for i in 0..dim {
292                residual[off + i] += x[off + i];
293            }
294            let row = &residual[off..off + dim];
295            let o = &mut out[off..off + dim];
296            let sum_sq = dot_product(row, row);
297            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
298            for i in 0..dim {
299                o[i] = row[i] * inv * w[i];
300            }
301        }
302    }
303
304    fn flash_attention(
305        _ctx: &mut Self::Context,
306        q: &Self::Buffer,
307        k: &Self::Buffer,
308        v: &Self::Buffer,
309        out: &mut Self::Buffer,
310        batch: usize,
311        q_len: usize,
312        kv_len: usize,
313        pos_offset: usize,
314        cfg: &AttnConfig,
315    ) {
316        cpu_attention(
317            q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
318        );
319    }
320
321    fn copy_slice(
322        _ctx: &mut Self::Context,
323        src: &Self::Buffer,
324        src_offset: usize,
325        dst: &mut Self::Buffer,
326        dst_offset: usize,
327        len: usize,
328    ) {
329        dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
330    }
331
332    fn embedding_lookup(
333        _ctx: &mut Self::Context,
334        table: &Self::Buffer,
335        ids: &[u32],
336        out: &mut Self::Buffer,
337        dim: usize,
338    ) {
339        for (i, &id) in ids.iter().enumerate() {
340            let src = id as usize * dim;
341            out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
342        }
343    }
344
345    fn split_qkv(
346        _ctx: &mut Self::Context,
347        qkv: &Self::Buffer,
348        q: &mut Self::Buffer,
349        k: &mut Self::Buffer,
350        v: &mut Self::Buffer,
351        tokens: usize,
352        q_dim: usize,
353        kv_dim: usize,
354    ) {
355        let qkv_dim = q_dim + 2 * kv_dim;
356        for t in 0..tokens {
357            let base = t * qkv_dim;
358            q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
359            k[t * kv_dim..(t + 1) * kv_dim]
360                .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
361            v[t * kv_dim..(t + 1) * kv_dim]
362                .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
363        }
364    }
365
366    fn fused_silu_mul_split(
367        _ctx: &mut Self::Context,
368        gate_up: &Self::Buffer,
369        out: &mut Self::Buffer,
370        tokens: usize,
371        im: usize,
372    ) {
373        for t in 0..tokens {
374            for i in 0..im {
375                let g = gate_up[t * 2 * im + i];
376                let u = gate_up[t * 2 * im + im + i];
377                out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
378            }
379        }
380    }
381
382    fn qk_norm_rope(
383        _ctx: &mut Self::Context,
384        input: &Self::Buffer,
385        norm_w: &Self::Buffer,
386        cos: &Self::Buffer,
387        sin: &Self::Buffer,
388        output: &mut Self::Buffer,
389        tokens: usize,
390        heads: usize,
391        head_dim: usize,
392        pos_offset: usize,
393        eps: f32,
394        mode: i32,
395    ) {
396        let half = head_dim / 2;
397        let cos_len = cos.len();
398        let sin_len = sin.len();
399        debug_assert_eq!(cos_len, sin_len);
400
401        for t in 0..tokens {
402            let pos = pos_offset + t;
403            for h in 0..heads {
404                // input row: [t, h, :]  stride = heads * head_dim
405                let src_off = (t * heads + h) * head_dim;
406                // output row: [h, t, :]  stride = tokens * head_dim
407                let dst_off = (h * tokens + t) * head_dim;
408
409                // Mode 0: plain transpose.
410                if mode == 0 {
411                    for i in 0..head_dim {
412                        output[dst_off + i] = input[src_off + i];
413                    }
414                    continue;
415                }
416
417                // Optional RMS norm (mode 1 only).
418                let scale = if mode == 1 {
419                    let mut sum_sq = 0.0f32;
420                    for i in 0..head_dim {
421                        sum_sq += input[src_off + i] * input[src_off + i];
422                    }
423                    1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
424                } else {
425                    1.0
426                };
427
428                // Apply (norm?) + RoPE to halves, write to head-major output.
429                for i in 0..half {
430                    let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
431                    let (x0, x1) = if mode == 1 {
432                        (
433                            x0_raw * scale * norm_w[i],
434                            x1_raw * scale * norm_w[i + half],
435                        )
436                    } else {
437                        (x0_raw, x1_raw)
438                    };
439                    let c = cos[pos * half + i];
440                    let s = sin[pos * half + i];
441                    output[dst_off + i] = x0 * c - x1 * s;
442                    output[dst_off + i + half] = x1 * c + x0 * s;
443                }
444            }
445        }
446    }
447
448    fn kv_cache_append_head_major(
449        _ctx: &mut Self::Context,
450        cache_k: &mut Self::Buffer,
451        cache_v: &mut Self::Buffer,
452        cache_len: usize,
453        cache_capacity: usize,
454        new_k_head_major: &Self::Buffer,
455        new_v_head_major: &Self::Buffer,
456        new_tokens: usize,
457        nkv: usize,
458        hd: usize,
459    ) {
460        debug_assert!(cache_len + new_tokens <= cache_capacity);
461        debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
462        debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
463        // The source buffers may be sized for `max_tokens` (the prefill-
464        // sized scratch) while only the first `nkv * new_tokens * hd`
465        // entries are valid for this call. Allow >= so reusing scratch
466        // across prefill and decode doesn't trip the assert.
467        debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
468        debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
469
470        for h in 0..nkv {
471            let dst_base = h * cache_capacity * hd + cache_len * hd;
472            let src_base = h * new_tokens * hd;
473            cache_k[dst_base..dst_base + new_tokens * hd]
474                .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
475            cache_v[dst_base..dst_base + new_tokens * hd]
476                .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
477        }
478    }
479
480    fn transpose_head_to_token(
481        _ctx: &mut Self::Context,
482        src: &Self::Buffer,
483        dst: &mut Self::Buffer,
484        tokens: usize,
485        heads: usize,
486        dim: usize,
487    ) {
488        for h in 0..heads {
489            for t in 0..tokens {
490                let s = (h * tokens + t) * dim;
491                let d = (t * heads + h) * dim;
492                dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
493            }
494        }
495    }
496
497    fn add_inplace(
498        _ctx: &mut Self::Context,
499        residual: &mut Self::Buffer,
500        x: &Self::Buffer,
501        len: usize,
502    ) {
503        for i in 0..len {
504            residual[i] += x[i];
505        }
506    }
507
508    fn scaled_add_inplace(
509        _ctx: &mut Self::Context,
510        dst: &mut Self::Buffer,
511        src: &Self::Buffer,
512        scale: f32,
513        len: usize,
514    ) {
515        for i in 0..len {
516            dst[i] += scale * src[i];
517        }
518    }
519
520    fn add_bias(
521        _ctx: &mut Self::Context,
522        data: &mut Self::Buffer,
523        bias: &Self::Buffer,
524        rows: usize,
525        cols: usize,
526    ) {
527        debug_assert_eq!(bias.len(), cols);
528        for r in 0..rows {
529            let off = r * cols;
530            for c in 0..cols {
531                data[off + c] += bias[c];
532            }
533        }
534    }
535
536    fn layer_norm(
537        _ctx: &mut Self::Context,
538        x: &Self::Buffer,
539        gamma: &Self::Buffer,
540        beta: &Self::Buffer,
541        eps: f32,
542        out: &mut Self::Buffer,
543        tokens: usize,
544        dim: usize,
545    ) {
546        debug_assert_eq!(gamma.len(), dim);
547        debug_assert_eq!(beta.len(), dim);
548        for t in 0..tokens {
549            let off = t * dim;
550            // Compute mean + variance over `dim` in f64 for stability.
551            let mut mean = 0.0f64;
552            for i in 0..dim {
553                mean += x[off + i] as f64;
554            }
555            mean /= dim as f64;
556            let mut var = 0.0f64;
557            for i in 0..dim {
558                let d = x[off + i] as f64 - mean;
559                var += d * d;
560            }
561            var /= dim as f64;
562            let inv = 1.0f32 / ((var as f32) + eps).sqrt();
563            let mean_f32 = mean as f32;
564            for i in 0..dim {
565                out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
566            }
567        }
568    }
569
570    fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
571        // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
572        // Uses f64 for erf accuracy (matches torch.nn.functional.gelu default).
573        for i in 0..len {
574            let xi = x[i];
575            out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
576        }
577    }
578
579    fn alloc(len: usize) -> Self::Buffer {
580        vec![0.0f32; len]
581    }
582    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
583        buf[..len].to_vec()
584    }
585    fn from_slice(data: &[f32]) -> Self::Buffer {
586        data.to_vec()
587    }
588}
589
// ── Helpers ──────────────────────────────────────────────────────────────
591
/// Dot product of `a` and `b` over their common prefix
/// (`min(a.len(), b.len())` elements). Uses Accelerate's `vDSP_dotpr` on
/// macOS and a portable iterator loop elsewhere.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    // vDSP reads exactly `n` elements from *both* pointers, so passing
    // `a.len()` alone could read past the end of a shorter `b`. Clamping to
    // the shorter slice also matches the portable `zip` semantics.
    let n = a.len().min(b.len());
    #[cfg(target_os = "macos")]
    {
        let mut result = 0.0f32;
        // SAFETY: both pointers are valid for `n` contiguous f32 reads
        // (n <= each length), strides are 1, and `result` is writable.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a[..n].iter().zip(&b[..n]).map(|(x, y)| x * y).sum()
    }
}
606
/// In-place RoPE on token-major `[tokens, heads, head_dim]` data.
///
/// Element `i` of each head is rotated against element `i + half` by the
/// angle stored at row `positions[t]` of the `[position, half]` `cos`/`sin`
/// tables.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for t in 0..tokens {
        let table_row = positions[t] as usize * half;
        for h in 0..heads {
            let head_start = (t * heads + h) * head_dim;
            for i in 0..half {
                let (c, s) = (cos[table_row + i], sin[table_row + i]);
                let lo = head_start + i;
                let hi = lo + half;
                let (x0, x1) = (data[lo], data[hi]);
                data[lo] = x0 * c - x1 * s;
                data[hi] = x1 * c + x0 * s;
            }
        }
    }
}
633
634fn cpu_attention(
635    q: &[f32],
636    k: &[f32],
637    v: &[f32],
638    out: &mut [f32],
639    batch: usize,
640    q_len: usize,
641    kv_len: usize,
642    causal: bool,
643    pos_offset: usize,
644    cfg: &AttnConfig,
645) {
646    let nh = cfg.num_heads;
647    let nkv = cfg.num_kv_heads;
648    let d = cfg.head_dim;
649    let n_rep = nh / nkv;
650    let scale = cfg.scale;
651    // Per-head KV stride: 0 (the default) means contiguous (legacy
652    // `kv_cache_append` path reallocates each layer). A non-zero value means
653    // the cache is pre-allocated to `kv_seq_stride` rows per head but only
654    // the first `kv_len` are valid — we skip the rest via `attend_len`.
655    let kv_stride = if cfg.kv_seq_stride > 0 {
656        cfg.kv_seq_stride
657    } else {
658        kv_len
659    };
660
661    for b in 0..batch {
662        for h in 0..nh {
663            let kv_h = h / n_rep;
664            let q_off = (b * nh + h) * q_len * d;
665            let k_off = (b * nkv + kv_h) * kv_stride * d;
666            let v_off = (b * nkv + kv_h) * kv_stride * d;
667            let o_off = (b * nh + h) * q_len * d;
668
669            for qi in 0..q_len {
670                let attend_end = if causal {
671                    (pos_offset + qi + 1).min(kv_len)
672                } else {
673                    kv_len
674                };
675                let attend_start = if causal && cfg.sliding_window > 0 {
676                    attend_end.saturating_sub(cfg.sliding_window)
677                } else {
678                    0
679                };
680                let mut max_score = f32::NEG_INFINITY;
681                let mut sum_exp = 0.0f32;
682                let mut acc = vec![0.0f32; d];
683
684                for ki in attend_start..attend_end {
685                    let mut dot = 0.0f32;
686                    for di in 0..d {
687                        dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
688                    }
689                    let score = dot * scale;
690                    if score > max_score {
691                        let correction = (max_score - score).exp();
692                        for di in 0..d {
693                            acc[di] *= correction;
694                        }
695                        sum_exp *= correction;
696                        max_score = score;
697                    }
698                    let w = (score - max_score).exp();
699                    sum_exp += w;
700                    for di in 0..d {
701                        acc[di] += w * v[v_off + ki * d + di];
702                    }
703                }
704
705                if sum_exp > 0.0 {
706                    let inv = 1.0 / sum_exp;
707                    for di in 0..d {
708                        out[o_off + qi * d + di] = acc[di] * inv;
709                    }
710                }
711            }
712        }
713    }
714}
715
/// Error function via Abramowitz & Stegun formula 7.1.26 (a degree-5
/// polynomial in `t = 1 / (1 + p·|x|)` times `exp(-x²)`).
///
/// The formula's maximum absolute error is about 1.5e-7, which is on the
/// order of f32 precision near 1.0 (one ulp ≈ 1.19e-7) — adequate for GELU,
/// but not bit-exact against a true erf. Despite the name, this does not
/// call libm.
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = -0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = -1.453_152_1;
    const A5: f32 = 1.061_405_4;

    // erf is odd: evaluate on |x| and restore the sign at the end.
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = ((((A5 * t + A4) * t + A3) * t + A2) * t + A1) * t;
    let magnitude = 1.0 - poly * (-ax * ax).exp();
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
729
730// CPU has no graph-capture analogue; inherit BackendGraph defaults.
731impl crate::backend::BackendGraph for CpuBackend {}
732
733// CPU has no multi-rank collectives; inherit BackendCollective defaults.
734impl crate::backend::BackendCollective for CpuBackend {}
735
736/// Dequant raw GPTQ tensors → row-major `[n, k]` f32. Shared between
737/// the per-tensor `load_gptq` and the MoE `load_gptq_stacked` impls.
738fn cpu_dequant_gptq(
739    qweight: &[i32],
740    scales: &[f32],
741    qzeros: &[i32],
742    bits: u32,
743    group_size: usize,
744    k: usize,
745    n: usize,
746) -> Result<Vec<f32>> {
747    if bits != 4 {
748        return Err(FerrumError::unsupported(format!(
749            "CPU GPTQ: only bits=4 supported (got {bits})"
750        )));
751    }
752    let mut w = vec![0.0f32; n * k];
753    let packed_rows = k / 8;
754    for pr in 0..packed_rows {
755        for col in 0..n {
756            let packed = qweight[pr * n + col] as u32;
757            for bi in 0..8 {
758                let ki = pr * 8 + bi;
759                let q = ((packed >> (bi * 4)) & 0xF) as i32;
760                let grp = ki / group_size;
761                let scale = scales[grp * n + col];
762                let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
763                let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
764                let val = (q - zero) as f32 * scale;
765                w[col * k + ki] = val;
766            }
767        }
768    }
769    Ok(w)
770}
771
772impl crate::backend::BackendQuantMarlin for CpuBackend {
773    fn load_gptq(
774        qweight: &[i32],
775        scales: &[f32],
776        qzeros: &[i32],
777        _g_idx: Option<&[i32]>,
778        bias_host: Option<&[f32]>,
779        bits: u32,
780        group_size: usize,
781        k: usize,
782        n: usize,
783    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
784        let w = cpu_dequant_gptq(qweight, scales, qzeros, bits, group_size, k, n)?;
785        // Phase 3e/2: dequantized weights become a CpuGptqLinear that
786        // owns the (out_features, in_features) f32 matrix and runs
787        // through the existing Self::gemm CPU path.
788        Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
789            weight_f32: w,
790            bias: bias_host.map(|b| b.to_vec()),
791            in_features: k,
792            out_features: n,
793        }))
794    }
795    fn load_gptq_stacked(
796        qweights: &[&[i32]],
797        scales: &[&[f32]],
798        qzeros: &[&[i32]],
799        _g_idx: Option<&[i32]>,
800        bits: u32,
801        group_size: usize,
802        k: usize,
803        n_per_expert: usize,
804    ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
805        // Phase 3e/2 addition: dequant each expert independently, concat
806        // along N (rows in [n, k] layout). Used by MoE parity tests.
807        let num_experts = qweights.len();
808        if scales.len() != num_experts || qzeros.len() != num_experts {
809            return Err(FerrumError::model(format!(
810                "load_gptq_stacked: input slice lengths disagree (qw {num_experts}, sc {}, qz {})",
811                scales.len(),
812                qzeros.len()
813            )));
814        }
815        let total_n = num_experts * n_per_expert;
816        let mut all_w = Vec::with_capacity(total_n * k);
817        for ((qw_e, sc_e), qz_e) in qweights.iter().zip(scales.iter()).zip(qzeros.iter()) {
818            let w_e = cpu_dequant_gptq(qw_e, sc_e, qz_e, bits, group_size, k, n_per_expert)?;
819            all_w.extend_from_slice(&w_e);
820        }
821        let store = std::sync::Arc::new(CpuGptqStore {
822            weight_f32: all_w,
823            k,
824            n: total_n,
825        });
826        Ok(std::sync::Arc::new(
827            crate::quant_linear::cpu_marlin_stack::CpuMarlinExpertStack::new(
828                store,
829                num_experts,
830                n_per_expert,
831                k,
832            ),
833        ))
834    }
835    // Phase C step 4b: make_stacked_expert_linear inlined into
836    // CpuMarlinExpertStack::make_expert_linear.
837    // Phase C step 4e: make_marlin_expert_stack subsumed by load_gptq_stacked.
838    // gemm_gptq_with_offset_strided body moved to free function
839    // cpu_gemm_gptq_with_offset_strided below — called by
840    // CpuMarlinExpertStack::gemm_phase_batched.
841}
842
843/// Free-function form of the deleted
844/// `BackendQuantMarlin::gemm_gptq_with_offset_strided` (Phase C step 4e).
845/// Single caller: `CpuMarlinExpertStack::gemm_phase_batched`.
846#[allow(clippy::too_many_arguments)]
847pub(crate) fn cpu_gemm_gptq_with_offset_strided(
848    _ctx: &mut <CpuBackend as Backend>::Context,
849    input: &<CpuBackend as Backend>::Buffer,
850    in_row_offset: usize,
851    weight: &CpuGptqStore,
852    expert_offset: usize,
853    expert_n: usize,
854    output: &mut <CpuBackend as Backend>::Buffer,
855    out_row_offset: usize,
856    m: usize,
857    k: usize,
858) -> Result<()> {
859    if expert_offset + expert_n > weight.n {
860        return Err(FerrumError::model(format!(
861            "cpu_gemm_gptq_with_offset_strided OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
862            weight.n
863        )));
864    }
865    if k != weight.k {
866        return Err(FerrumError::model(format!(
867            "cpu_gemm_gptq_with_offset_strided k mismatch: arg {k} vs weight.k {}",
868            weight.k
869        )));
870    }
871    let in_start = in_row_offset * k;
872    let in_end = (in_row_offset + m) * k;
873    let out_start = out_row_offset * expert_n;
874    let out_end = (out_row_offset + m) * expert_n;
875    let row_start = expert_offset * k;
876    let row_end = (expert_offset + expert_n) * k;
877    let weight_slice = weight.weight_f32[row_start..row_end].to_vec();
878    let in_slice = input[in_start..in_end].to_vec();
879    let mut out_slice = vec![0.0f32; m * expert_n];
880    let mut ctx_local = ();
881    CpuBackend::gemm(
882        &mut ctx_local,
883        &in_slice,
884        &weight_slice,
885        &mut out_slice,
886        m,
887        expert_n,
888        k,
889    );
890    output[out_start..out_end].copy_from_slice(&out_slice);
891    Ok(())
892}
893
894impl crate::backend::BackendQuantGguf for CpuBackend {
895    fn load_quant(
896        kind: super::GgufQuantType,
897        bytes: &[u8],
898        n_rows: usize,
899        n_cols: usize,
900    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
901        use super::GgufQuantType;
902        let store = match kind {
903            GgufQuantType::Q4K => {
904                let total_elems = n_rows * n_cols;
905                if total_elems % Q4_K_QK != 0 {
906                    return Err(FerrumError::model(format!(
907                        "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
908                    )));
909                }
910                let n_blocks = total_elems / Q4_K_QK;
911                let expected = n_blocks * Q4_K_BLOCK_BYTES;
912                if bytes.len() != expected {
913                    return Err(FerrumError::model(format!(
914                        "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
915                        bytes.len(),
916                        expected
917                    )));
918                }
919                CpuQuantStore::Q4K {
920                    weights: dequant_q4_k_cpu(bytes, n_blocks),
921                    n_rows,
922                    n_cols,
923                }
924            }
925            other => {
926                return Err(FerrumError::unsupported(format!(
927                    "CPU load_quant: {other:?} not yet implemented"
928                )));
929            }
930        };
931        // Phase 3e/3: dispatch via CpuGgufLinear::forward instead of
932        // a trait method.
933        Ok(Box::new(crate::quant_linear::cpu_gguf::CpuGgufLinear {
934            store,
935            in_features: n_cols,
936            out_features: n_rows,
937        }))
938    }
939}
940
941// CPU has no paged-KV path; inherit unsupported defaults.
942impl crate::backend::BackendPagedKv for CpuBackend {}
943
944// CPU has no native MoE dispatch; inherit unsupported defaults.
945impl crate::backend::BackendMoeFused for CpuBackend {}
946
947// CPU: existing KV cache path treats fp16 buffer as f32 internally; mark as KvFp16 for compatibility.
948impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CpuBackend {
949    type KvBuffer = <Self as crate::backend::Backend>::Buffer;
950    type KvScales = ();
951}