Skip to main content

ferrum_kernels/backend/
cpu.rs

1//! CPU backend using Accelerate (macOS) / portable fallback (Linux).
2//! Context = () — all ops execute immediately, no batching needed.
3
4use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
// ── Q4_K block layout ─────────────────────────────────────────────────
//
// Byte layout of one GGML / candle `BlockQ4K` super-block. Used by
// `dequant_q4_k_cpu` (called from `load_quant`) to expand raw GGUF block
// bytes into fp32 row-major weights on the CPU.

/// Number of weights encoded by one Q4_K super-block.
const Q4_K_QK: usize = 256;
/// Bytes of packed 6-bit sub-block scales/mins per super-block.
const Q4_K_SCALE_SIZE: usize = 12;
/// Bytes per super-block: the `d` and `dmin` f16 header (2 × 2 bytes), the
/// packed scales, then one 4-bit nibble per weight — 144 in total.
const Q4_K_BLOCK_BYTES: usize = 2 * 2 + Q4_K_SCALE_SIZE + Q4_K_QK / 2;
17
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the packed
/// `scales` field of a Q4_K block (valid `j` is 0..8 for the 12-byte field).
/// Mirrors candle's `quantized::utils::get_scale_min_k4`.
///
/// Sub-blocks 0..4 keep both values in the low six bits of bytes `j` and
/// `j + 4`. Sub-blocks 4..8 take a nibble from byte `j + 4` and borrow the
/// top two bits of bytes `j - 4` (scale) and `j` (min).
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    const LOW6: u8 = 0x3F;
    match j {
        0..=3 => (q[j] & LOW6, q[j + 4] & LOW6),
        _ => {
            let packed = q[j + 4];
            let scale = (packed & 0x0F) | ((q[j - 4] >> 6) << 4);
            let min = (packed >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
28
29/// Port of candle's CPU `BlockQ4K::to_float`. Bit-identical output for
30/// identical input — the test in `q4_k.rs` verifies our Metal kernel
31/// also matches.
32fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33    debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34    let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35    for b in 0..n_blocks {
36        let off = b * Q4_K_BLOCK_BYTES;
37        let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38        let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39        let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40        let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42        let mut is = 0usize;
43        for j in (0..Q4_K_QK).step_by(64) {
44            let q_chunk = &qs[j / 2..j / 2 + 32];
45            let (sc1, mn1) = get_scale_min_k4(is, scales);
46            let d1 = d * sc1 as f32;
47            let m1 = dmin * mn1 as f32;
48            let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49            let d2 = d * sc2 as f32;
50            let m2 = dmin * mn2 as f32;
51            for q in q_chunk {
52                out.push(d1 * (q & 0xF) as f32 - m1);
53            }
54            for q in q_chunk {
55                out.push(d2 * (q >> 4) as f32 - m2);
56            }
57            is += 2;
58        }
59    }
60    out
61}
62
/// CPU implementation of the kernel [`Backend`] traits.
///
/// Stateless unit struct. Buffers are plain `Vec<f32>` and the per-call
/// context is `()`, so every op executes immediately with no batching.
pub struct CpuBackend;
64
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// Accelerate CBLAS single-precision GEMM:
    /// `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// `order` / `transa` / `transb` take raw `CBLAS_ORDER` /
    /// `CBLAS_TRANSPOSE` enum values (101 = row-major, 111 = no-transpose,
    /// 112 = transpose). Leading dimensions must be at least 1.
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
    /// vDSP dot product: `*result = Σ a[i * a_stride] * b[i * b_stride]`
    /// for `i` in `0..n`.
    ///
    /// Accelerate declares the strides as `vDSP_Stride` (C `long`) and the
    /// length as `vDSP_Length` (C `unsigned long`). Both are 64-bit on 64-bit
    /// Apple targets, so they are bound as `isize` / `u64`. An `i32` stride
    /// binding is an ABI mismatch: the callee reads a full 64-bit register
    /// whose upper half the caller is not required to set.
    fn vDSP_dotpr(
        a: *const f32,
        a_stride: isize,
        b: *const f32,
        b_stride: isize,
        result: *mut f32,
        n: u64,
    );
}
92
/// Dequantized GPTQ weights held on the CPU.
///
/// The packed GPTQ tensors are expanded once at load time into a dense f32
/// matrix. This trades memory for simplicity: afterwards the ordinary GEMM
/// path consumes the matrix directly.
pub struct CpuGptqStore {
    /// Row-major `[n, k]` matrix: `n` rows of `k` contiguous values.
    pub weight_f32: Vec<f32>,
    /// Reduction (input) dimension, i.e. the length of each row.
    pub k: usize,
    /// Number of rows (output features; all experts stacked for MoE).
    pub n: usize,
}
100
/// Dense CPU copy of a GGUF k-quant tensor.
///
/// The CPU path dequantizes eagerly at load time instead of decoding on the
/// fly, so every variant carries plain fp32 weights. The variant tag records
/// which GGUF encoding they came from so that `CpuGgufLinear` can route on
/// it. Supporting another k-quant format (e.g. Q5_K / Q6_K / Q8_0) means
/// adding a variant here plus a matching arm in `load_quant`.
pub enum CpuQuantStore {
    /// Dequantized Q4_K tensor.
    Q4K {
        /// Row-major `[n_rows, n_cols]` fp32 values.
        weights: Vec<f32>,
        /// Number of rows (output features).
        n_rows: usize,
        /// Number of columns (input features).
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117    type Buffer = Vec<f32>;
118    type Context = ();
119    // type GptqStore: removed in Phase C step 4e. CpuGptqStore is now
120    // a private (crate-internal) detail of CpuMarlinExpertStack.
121
122    type Timer = crate::backend::timer::CpuTimer;
123    fn make_timer() -> Self::Timer {
124        crate::backend::timer::CpuTimer::new()
125    }
126
127    fn new_context() -> Self::Context {}
128    fn sync(_ctx: &mut Self::Context) {}
129
130    /// Phase D step 2+3: typed alloc. CPU Buffer is Vec<f32> — bytes
131    /// are dtype-erased, so we size the underlying Vec to hold `n`
132    /// elements of `dtype` (bit-cast at read/write time).
133    fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
134        // f32 storage; for i8 we round up to 4-byte boundary so the
135        // Vec<f32> length covers all i8 elements.
136        let bytes = n * dtype.bytes_per_elem();
137        let f32_len = bytes.div_ceil(4);
138        vec![0.0f32; f32_len]
139    }
140
141    /// Phase D step 2+3: typed upload. Bit-cast host data into f32
142    /// words (CPU buffer is dtype-erased Vec<f32>, see alloc_typed).
143    fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
144        let bytes = data.len() * std::mem::size_of::<T>();
145        let f32_len = bytes.div_ceil(4);
146        let mut out = vec![0.0f32; f32_len];
147        unsafe {
148            std::ptr::copy_nonoverlapping(
149                data.as_ptr() as *const u8,
150                out.as_mut_ptr() as *mut u8,
151                bytes,
152            );
153        }
154        out
155    }
156
157    /// Phase D step 2+3: typed in-place write. Bit-cast bytes into
158    /// the dtype-erased f32 storage.
159    fn write_typed<T: crate::backend::HostDtype>(
160        _ctx: &mut Self::Context,
161        dst: &mut Self::Buffer,
162        data: &[T],
163    ) {
164        let bytes = data.len() * std::mem::size_of::<T>();
165        debug_assert!(
166            bytes <= dst.len() * 4,
167            "CpuBackend::write_typed: src bytes {} > dst bytes {}",
168            bytes,
169            dst.len() * 4
170        );
171        unsafe {
172            std::ptr::copy_nonoverlapping(
173                data.as_ptr() as *const u8,
174                dst.as_mut_ptr() as *mut u8,
175                bytes,
176            );
177        }
178    }
179
180    fn fused_silu_mul_split_strided(
181        _ctx: &mut Self::Context,
182        gate_up: &Self::Buffer,
183        in_row_offset: usize,
184        out: &mut Self::Buffer,
185        out_row_offset: usize,
186        tokens: usize,
187        intermediate: usize,
188    ) {
189        let in_per_row = 2 * intermediate;
190        let in_start = in_row_offset * in_per_row;
191        let out_start = out_row_offset * intermediate;
192        for r in 0..tokens {
193            for c in 0..intermediate {
194                let g = gate_up[in_start + r * in_per_row + c];
195                let u = gate_up[in_start + r * in_per_row + intermediate + c];
196                let silu = g / (1.0 + (-g).exp());
197                out[out_start + r * intermediate + c] = silu * u;
198            }
199        }
200    }
201
202    fn gemm(
203        _ctx: &mut Self::Context,
204        a: &Self::Buffer,
205        b: &Self::Buffer,
206        out: &mut Self::Buffer,
207        m: usize,
208        n: usize,
209        k: usize,
210    ) {
211        assert!(
212            a.len() >= m * k,
213            "gemm: a too small len={} m={m} k={k}",
214            a.len()
215        );
216        assert!(
217            b.len() >= n * k,
218            "gemm: b too small len={} n={n} k={k}",
219            b.len()
220        );
221        assert!(
222            out.len() >= m * n,
223            "gemm: out too small len={} m={m} n={n}",
224            out.len()
225        );
226        #[cfg(target_os = "macos")]
227        unsafe {
228            cblas_sgemm(
229                101,
230                111,
231                112,
232                m as i32,
233                n as i32,
234                k as i32,
235                1.0,
236                a.as_ptr(),
237                k as i32,
238                b.as_ptr(),
239                k as i32,
240                0.0,
241                out.as_mut_ptr(),
242                n as i32,
243            );
244        }
245        #[cfg(not(target_os = "macos"))]
246        {
247            for i in 0..m {
248                for j in 0..n {
249                    let mut sum = 0.0f64;
250                    for p in 0..k {
251                        sum += a[i * k + p] as f64 * b[j * k + p] as f64;
252                    }
253                    out[i * n + j] = sum as f32;
254                }
255            }
256        }
257    }
258
259    fn rms_norm(
260        _ctx: &mut Self::Context,
261        x: &Self::Buffer,
262        w: &Self::Buffer,
263        eps: f32,
264        out: &mut Self::Buffer,
265        tokens: usize,
266        dim: usize,
267    ) {
268        for t in 0..tokens {
269            let row = &x[t * dim..(t + 1) * dim];
270            let o = &mut out[t * dim..(t + 1) * dim];
271            let sum_sq = dot_product(row, row);
272            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
273            for i in 0..dim {
274                o[i] = row[i] * inv * w[i];
275            }
276        }
277    }
278
279    fn fused_add_rms_norm(
280        _ctx: &mut Self::Context,
281        residual: &mut Self::Buffer,
282        x: &Self::Buffer,
283        w: &Self::Buffer,
284        eps: f32,
285        out: &mut Self::Buffer,
286        tokens: usize,
287        dim: usize,
288    ) {
289        for t in 0..tokens {
290            let off = t * dim;
291            for i in 0..dim {
292                residual[off + i] += x[off + i];
293            }
294            let row = &residual[off..off + dim];
295            let o = &mut out[off..off + dim];
296            let sum_sq = dot_product(row, row);
297            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
298            for i in 0..dim {
299                o[i] = row[i] * inv * w[i];
300            }
301        }
302    }
303
304    fn flash_attention(
305        _ctx: &mut Self::Context,
306        q: &Self::Buffer,
307        k: &Self::Buffer,
308        v: &Self::Buffer,
309        out: &mut Self::Buffer,
310        batch: usize,
311        q_len: usize,
312        kv_len: usize,
313        pos_offset: usize,
314        cfg: &AttnConfig,
315    ) {
316        cpu_attention(
317            q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
318        );
319    }
320
321    fn copy_slice(
322        _ctx: &mut Self::Context,
323        src: &Self::Buffer,
324        src_offset: usize,
325        dst: &mut Self::Buffer,
326        dst_offset: usize,
327        len: usize,
328    ) {
329        dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
330    }
331
332    fn embedding_lookup(
333        _ctx: &mut Self::Context,
334        table: &Self::Buffer,
335        ids: &[u32],
336        out: &mut Self::Buffer,
337        dim: usize,
338    ) {
339        for (i, &id) in ids.iter().enumerate() {
340            let src = id as usize * dim;
341            out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
342        }
343    }
344
345    fn split_qkv(
346        _ctx: &mut Self::Context,
347        qkv: &Self::Buffer,
348        q: &mut Self::Buffer,
349        k: &mut Self::Buffer,
350        v: &mut Self::Buffer,
351        tokens: usize,
352        q_dim: usize,
353        kv_dim: usize,
354    ) {
355        let qkv_dim = q_dim + 2 * kv_dim;
356        for t in 0..tokens {
357            let base = t * qkv_dim;
358            q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
359            k[t * kv_dim..(t + 1) * kv_dim]
360                .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
361            v[t * kv_dim..(t + 1) * kv_dim]
362                .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
363        }
364    }
365
366    fn fused_silu_mul_split(
367        _ctx: &mut Self::Context,
368        gate_up: &Self::Buffer,
369        out: &mut Self::Buffer,
370        tokens: usize,
371        im: usize,
372    ) {
373        for t in 0..tokens {
374            for i in 0..im {
375                let g = gate_up[t * 2 * im + i];
376                let u = gate_up[t * 2 * im + im + i];
377                out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
378            }
379        }
380    }
381
382    fn qk_norm_rope(
383        _ctx: &mut Self::Context,
384        input: &Self::Buffer,
385        norm_w: &Self::Buffer,
386        cos: &Self::Buffer,
387        sin: &Self::Buffer,
388        output: &mut Self::Buffer,
389        tokens: usize,
390        heads: usize,
391        head_dim: usize,
392        pos_offset: usize,
393        eps: f32,
394        mode: i32,
395    ) {
396        let half = head_dim / 2;
397        let cos_len = cos.len();
398        let sin_len = sin.len();
399        debug_assert_eq!(cos_len, sin_len);
400
401        for t in 0..tokens {
402            let pos = pos_offset + t;
403            for h in 0..heads {
404                // input row: [t, h, :]  stride = heads * head_dim
405                let src_off = (t * heads + h) * head_dim;
406                // output row: [h, t, :]  stride = tokens * head_dim
407                let dst_off = (h * tokens + t) * head_dim;
408
409                // Mode 0: plain transpose.
410                if mode == 0 {
411                    for i in 0..head_dim {
412                        output[dst_off + i] = input[src_off + i];
413                    }
414                    continue;
415                }
416
417                // Optional RMS norm (mode 1 only).
418                let scale = if mode == 1 {
419                    let mut sum_sq = 0.0f32;
420                    for i in 0..head_dim {
421                        sum_sq += input[src_off + i] * input[src_off + i];
422                    }
423                    1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
424                } else {
425                    1.0
426                };
427
428                if mode == 3 {
429                    // GGUF LLaMA / llama.cpp interleaved RoPE layout.
430                    for i in 0..half {
431                        let j = 2 * i;
432                        let x0 = input[src_off + j];
433                        let x1 = input[src_off + j + 1];
434                        let c = cos[pos * half + i];
435                        let s = sin[pos * half + i];
436                        output[dst_off + j] = x0 * c - x1 * s;
437                        output[dst_off + j + 1] = x1 * c + x0 * s;
438                    }
439                } else {
440                    // Apply (norm?) + half-split RoPE to head-major output.
441                    for i in 0..half {
442                        let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
443                        let (x0, x1) = if mode == 1 {
444                            (
445                                x0_raw * scale * norm_w[i],
446                                x1_raw * scale * norm_w[i + half],
447                            )
448                        } else {
449                            (x0_raw, x1_raw)
450                        };
451                        let c = cos[pos * half + i];
452                        let s = sin[pos * half + i];
453                        output[dst_off + i] = x0 * c - x1 * s;
454                        output[dst_off + i + half] = x1 * c + x0 * s;
455                    }
456                }
457            }
458        }
459    }
460
461    fn kv_cache_append_head_major(
462        _ctx: &mut Self::Context,
463        cache_k: &mut Self::Buffer,
464        cache_v: &mut Self::Buffer,
465        cache_len: usize,
466        cache_capacity: usize,
467        new_k_head_major: &Self::Buffer,
468        new_v_head_major: &Self::Buffer,
469        new_tokens: usize,
470        nkv: usize,
471        hd: usize,
472    ) {
473        debug_assert!(cache_len + new_tokens <= cache_capacity);
474        debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
475        debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
476        // The source buffers may be sized for `max_tokens` (the prefill-
477        // sized scratch) while only the first `nkv * new_tokens * hd`
478        // entries are valid for this call. Allow >= so reusing scratch
479        // across prefill and decode doesn't trip the assert.
480        debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
481        debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
482
483        for h in 0..nkv {
484            let dst_base = h * cache_capacity * hd + cache_len * hd;
485            let src_base = h * new_tokens * hd;
486            cache_k[dst_base..dst_base + new_tokens * hd]
487                .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
488            cache_v[dst_base..dst_base + new_tokens * hd]
489                .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
490        }
491    }
492
493    fn transpose_head_to_token(
494        _ctx: &mut Self::Context,
495        src: &Self::Buffer,
496        dst: &mut Self::Buffer,
497        tokens: usize,
498        heads: usize,
499        dim: usize,
500    ) {
501        for h in 0..heads {
502            for t in 0..tokens {
503                let s = (h * tokens + t) * dim;
504                let d = (t * heads + h) * dim;
505                dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
506            }
507        }
508    }
509
510    fn add_inplace(
511        _ctx: &mut Self::Context,
512        residual: &mut Self::Buffer,
513        x: &Self::Buffer,
514        len: usize,
515    ) {
516        for i in 0..len {
517            residual[i] += x[i];
518        }
519    }
520
521    fn scaled_add_inplace(
522        _ctx: &mut Self::Context,
523        dst: &mut Self::Buffer,
524        src: &Self::Buffer,
525        scale: f32,
526        len: usize,
527    ) {
528        for i in 0..len {
529            dst[i] += scale * src[i];
530        }
531    }
532
533    fn add_bias(
534        _ctx: &mut Self::Context,
535        data: &mut Self::Buffer,
536        bias: &Self::Buffer,
537        rows: usize,
538        cols: usize,
539    ) {
540        debug_assert_eq!(bias.len(), cols);
541        for r in 0..rows {
542            let off = r * cols;
543            for c in 0..cols {
544                data[off + c] += bias[c];
545            }
546        }
547    }
548
549    fn layer_norm(
550        _ctx: &mut Self::Context,
551        x: &Self::Buffer,
552        gamma: &Self::Buffer,
553        beta: &Self::Buffer,
554        eps: f32,
555        out: &mut Self::Buffer,
556        tokens: usize,
557        dim: usize,
558    ) {
559        debug_assert_eq!(gamma.len(), dim);
560        debug_assert_eq!(beta.len(), dim);
561        for t in 0..tokens {
562            let off = t * dim;
563            // Compute mean + variance over `dim` in f64 for stability.
564            let mut mean = 0.0f64;
565            for i in 0..dim {
566                mean += x[off + i] as f64;
567            }
568            mean /= dim as f64;
569            let mut var = 0.0f64;
570            for i in 0..dim {
571                let d = x[off + i] as f64 - mean;
572                var += d * d;
573            }
574            var /= dim as f64;
575            let inv = 1.0f32 / ((var as f32) + eps).sqrt();
576            let mean_f32 = mean as f32;
577            for i in 0..dim {
578                out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
579            }
580        }
581    }
582
583    fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
584        // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
585        // Uses f64 for erf accuracy (matches torch.nn.functional.gelu default).
586        for i in 0..len {
587            let xi = x[i];
588            out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
589        }
590    }
591
592    fn alloc(len: usize) -> Self::Buffer {
593        vec![0.0f32; len]
594    }
595    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
596        buf[..len].to_vec()
597    }
598    fn from_slice(data: &[f32]) -> Self::Buffer {
599        data.to_vec()
600    }
601}
602
// ── Helpers ──────────────────────────────────────────────────────────────

/// Dot product of two f32 slices over their common prefix.
///
/// Uses Accelerate's `vDSP_dotpr` on macOS and a plain iterator elsewhere.
/// Both paths only read the first `min(a.len(), b.len())` elements. The
/// clamp keeps the vDSP call from reading past the end of a shorter `b`,
/// and matches the `zip` truncation of the portable path.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    #[cfg(target_os = "macos")]
    {
        let mut result = 0.0f32;
        // SAFETY: both pointers are valid for `len` contiguous reads because
        // `len` is at most each slice's length. `result` is a live local that
        // the call writes exactly once.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, len as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a[..len].iter().zip(&b[..len]).map(|(x, y)| x * y).sum()
    }
}
619
/// Applies half-split rotary embedding in place to token-major `data`
/// laid out `[tokens, heads, head_dim]`.
///
/// Within each head, element `i` (for `i < half`) is rotated together with
/// element `i + half`. The angle comes from row `positions[t]` of the
/// `cos`/`sin` tables, which have row stride `half`. Currently has no
/// caller, hence the `dead_code` allowance.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for t in 0..tokens {
        let angle_row = positions[t] as usize * half;
        for h in 0..heads {
            let head_start = (t * heads + h) * head_dim;
            for i in 0..half {
                let (cs, sn) = (cos[angle_row + i], sin[angle_row + i]);
                let (lo, hi) = (head_start + i, head_start + half + i);
                let (a, b) = (data[lo], data[hi]);
                data[lo] = a * cs - b * sn;
                data[hi] = b * cs + a * sn;
            }
        }
    }
}
646
647fn cpu_attention(
648    q: &[f32],
649    k: &[f32],
650    v: &[f32],
651    out: &mut [f32],
652    batch: usize,
653    q_len: usize,
654    kv_len: usize,
655    causal: bool,
656    pos_offset: usize,
657    cfg: &AttnConfig,
658) {
659    let nh = cfg.num_heads;
660    let nkv = cfg.num_kv_heads;
661    let d = cfg.head_dim;
662    let n_rep = nh / nkv;
663    let scale = cfg.scale;
664    // Per-head KV stride: 0 (the default) means contiguous (legacy
665    // `kv_cache_append` path reallocates each layer). A non-zero value means
666    // the cache is pre-allocated to `kv_seq_stride` rows per head but only
667    // the first `kv_len` are valid — we skip the rest via `attend_len`.
668    let kv_stride = if cfg.kv_seq_stride > 0 {
669        cfg.kv_seq_stride
670    } else {
671        kv_len
672    };
673
674    for b in 0..batch {
675        for h in 0..nh {
676            let kv_h = h / n_rep;
677            let q_off = (b * nh + h) * q_len * d;
678            let k_off = (b * nkv + kv_h) * kv_stride * d;
679            let v_off = (b * nkv + kv_h) * kv_stride * d;
680            let o_off = (b * nh + h) * q_len * d;
681
682            for qi in 0..q_len {
683                let attend_end = if causal {
684                    (pos_offset + qi + 1).min(kv_len)
685                } else {
686                    kv_len
687                };
688                let attend_start = if causal && cfg.sliding_window > 0 {
689                    attend_end.saturating_sub(cfg.sliding_window)
690                } else {
691                    0
692                };
693                let mut max_score = f32::NEG_INFINITY;
694                let mut sum_exp = 0.0f32;
695                let mut acc = vec![0.0f32; d];
696
697                for ki in attend_start..attend_end {
698                    let mut dot = 0.0f32;
699                    for di in 0..d {
700                        dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
701                    }
702                    let score = dot * scale;
703                    if score > max_score {
704                        let correction = (max_score - score).exp();
705                        for di in 0..d {
706                            acc[di] *= correction;
707                        }
708                        sum_exp *= correction;
709                        max_score = score;
710                    }
711                    let w = (score - max_score).exp();
712                    sum_exp += w;
713                    for di in 0..d {
714                        acc[di] += w * v[v_off + ki * d + di];
715                    }
716                }
717
718                if sum_exp > 0.0 {
719                    let inv = 1.0 / sum_exp;
720                    for di in 0..d {
721                        out[o_off + qi * d + di] = acc[di] * inv;
722                    }
723                }
724            }
725        }
726    }
727}
728
/// Error function via the Abramowitz & Stegun 7.1.26 rational approximation.
///
/// The approximation's absolute error is at most ~1.5e-7 in exact
/// arithmetic. Evaluating it in f32 adds rounding error of roughly the same
/// order. The polynomial is evaluated on `|x|` and the result is negated for
/// negative inputs (erf is odd).
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = 0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = 1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner form of A1·t − A2·t² + A3·t³ − A4·t⁴ + A5·t⁵.
    let poly = ((((A5 * t - A4) * t + A3) * t - A2) * t + A1) * t;
    let magnitude = 1.0 - poly * (-ax * ax).exp();
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
742
// CPU has no graph-capture analogue (ops already execute immediately), so
// the `BackendGraph` trait's default methods are inherited unchanged.
impl crate::backend::BackendGraph for CpuBackend {}

// CPU has no multi-rank collectives; the `BackendCollective` trait's
// default methods are inherited unchanged.
impl crate::backend::BackendCollective for CpuBackend {}
748
749/// Dequant raw GPTQ tensors → row-major `[n, k]` f32. Shared between
750/// the per-tensor `load_gptq` and the MoE `load_gptq_stacked` impls.
751fn cpu_dequant_gptq(
752    qweight: &[i32],
753    scales: &[f32],
754    qzeros: &[i32],
755    bits: u32,
756    group_size: usize,
757    k: usize,
758    n: usize,
759) -> Result<Vec<f32>> {
760    if bits != 4 {
761        return Err(FerrumError::unsupported(format!(
762            "CPU GPTQ: only bits=4 supported (got {bits})"
763        )));
764    }
765    let mut w = vec![0.0f32; n * k];
766    let packed_rows = k / 8;
767    for pr in 0..packed_rows {
768        for col in 0..n {
769            let packed = qweight[pr * n + col] as u32;
770            for bi in 0..8 {
771                let ki = pr * 8 + bi;
772                let q = ((packed >> (bi * 4)) & 0xF) as i32;
773                let grp = ki / group_size;
774                let scale = scales[grp * n + col];
775                let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
776                let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
777                let val = (q - zero) as f32 * scale;
778                w[col * k + ki] = val;
779            }
780        }
781    }
782    Ok(w)
783}
784
785impl crate::backend::BackendQuantMarlin for CpuBackend {
786    fn load_gptq(
787        qweight: &[i32],
788        scales: &[f32],
789        qzeros: &[i32],
790        _g_idx: Option<&[i32]>,
791        bias_host: Option<&[f32]>,
792        bits: u32,
793        group_size: usize,
794        k: usize,
795        n: usize,
796    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
797        let w = cpu_dequant_gptq(qweight, scales, qzeros, bits, group_size, k, n)?;
798        // Phase 3e/2: dequantized weights become a CpuGptqLinear that
799        // owns the (out_features, in_features) f32 matrix and runs
800        // through the existing Self::gemm CPU path.
801        Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
802            weight_f32: w,
803            bias: bias_host.map(|b| b.to_vec()),
804            in_features: k,
805            out_features: n,
806        }))
807    }
808    fn load_gptq_stacked(
809        qweights: &[&[i32]],
810        scales: &[&[f32]],
811        qzeros: &[&[i32]],
812        _g_idx: Option<&[i32]>,
813        bits: u32,
814        group_size: usize,
815        k: usize,
816        n_per_expert: usize,
817    ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
818        // Phase 3e/2 addition: dequant each expert independently, concat
819        // along N (rows in [n, k] layout). Used by MoE parity tests.
820        let num_experts = qweights.len();
821        if scales.len() != num_experts || qzeros.len() != num_experts {
822            return Err(FerrumError::model(format!(
823                "load_gptq_stacked: input slice lengths disagree (qw {num_experts}, sc {}, qz {})",
824                scales.len(),
825                qzeros.len()
826            )));
827        }
828        let total_n = num_experts * n_per_expert;
829        let mut all_w = Vec::with_capacity(total_n * k);
830        for ((qw_e, sc_e), qz_e) in qweights.iter().zip(scales.iter()).zip(qzeros.iter()) {
831            let w_e = cpu_dequant_gptq(qw_e, sc_e, qz_e, bits, group_size, k, n_per_expert)?;
832            all_w.extend_from_slice(&w_e);
833        }
834        let store = std::sync::Arc::new(CpuGptqStore {
835            weight_f32: all_w,
836            k,
837            n: total_n,
838        });
839        Ok(std::sync::Arc::new(
840            crate::quant_linear::cpu_marlin_stack::CpuMarlinExpertStack::new(
841                store,
842                num_experts,
843                n_per_expert,
844                k,
845            ),
846        ))
847    }
848    // Phase C step 4b: make_stacked_expert_linear inlined into
849    // CpuMarlinExpertStack::make_expert_linear.
850    // Phase C step 4e: make_marlin_expert_stack subsumed by load_gptq_stacked.
851    // gemm_gptq_with_offset_strided body moved to free function
852    // cpu_gemm_gptq_with_offset_strided below — called by
853    // CpuMarlinExpertStack::gemm_phase_batched.
854}
855
856/// Free-function form of the deleted
857/// `BackendQuantMarlin::gemm_gptq_with_offset_strided` (Phase C step 4e).
858/// Single caller: `CpuMarlinExpertStack::gemm_phase_batched`.
859#[allow(clippy::too_many_arguments)]
860pub(crate) fn cpu_gemm_gptq_with_offset_strided(
861    _ctx: &mut <CpuBackend as Backend>::Context,
862    input: &<CpuBackend as Backend>::Buffer,
863    in_row_offset: usize,
864    weight: &CpuGptqStore,
865    expert_offset: usize,
866    expert_n: usize,
867    output: &mut <CpuBackend as Backend>::Buffer,
868    out_row_offset: usize,
869    m: usize,
870    k: usize,
871) -> Result<()> {
872    if expert_offset + expert_n > weight.n {
873        return Err(FerrumError::model(format!(
874            "cpu_gemm_gptq_with_offset_strided OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
875            weight.n
876        )));
877    }
878    if k != weight.k {
879        return Err(FerrumError::model(format!(
880            "cpu_gemm_gptq_with_offset_strided k mismatch: arg {k} vs weight.k {}",
881            weight.k
882        )));
883    }
884    let in_start = in_row_offset * k;
885    let in_end = (in_row_offset + m) * k;
886    let out_start = out_row_offset * expert_n;
887    let out_end = (out_row_offset + m) * expert_n;
888    let row_start = expert_offset * k;
889    let row_end = (expert_offset + expert_n) * k;
890    let weight_slice = weight.weight_f32[row_start..row_end].to_vec();
891    let in_slice = input[in_start..in_end].to_vec();
892    let mut out_slice = vec![0.0f32; m * expert_n];
893    let mut ctx_local = ();
894    CpuBackend::gemm(
895        &mut ctx_local,
896        &in_slice,
897        &weight_slice,
898        &mut out_slice,
899        m,
900        expert_n,
901        k,
902    );
903    output[out_start..out_end].copy_from_slice(&out_slice);
904    Ok(())
905}
906
907impl crate::backend::BackendQuantGguf for CpuBackend {
908    fn load_quant(
909        kind: super::GgufQuantType,
910        bytes: &[u8],
911        n_rows: usize,
912        n_cols: usize,
913    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
914        use super::GgufQuantType;
915        let store = match kind {
916            GgufQuantType::Q4K => {
917                let total_elems = n_rows * n_cols;
918                if total_elems % Q4_K_QK != 0 {
919                    return Err(FerrumError::model(format!(
920                        "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
921                    )));
922                }
923                let n_blocks = total_elems / Q4_K_QK;
924                let expected = n_blocks * Q4_K_BLOCK_BYTES;
925                if bytes.len() != expected {
926                    return Err(FerrumError::model(format!(
927                        "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
928                        bytes.len(),
929                        expected
930                    )));
931                }
932                CpuQuantStore::Q4K {
933                    weights: dequant_q4_k_cpu(bytes, n_blocks),
934                    n_rows,
935                    n_cols,
936                }
937            }
938            other => {
939                return Err(FerrumError::unsupported(format!(
940                    "CPU load_quant: {other:?} not yet implemented"
941                )));
942            }
943        };
944        // Phase 3e/3: dispatch via CpuGgufLinear::forward instead of
945        // a trait method.
946        Ok(Box::new(crate::quant_linear::cpu_gguf::CpuGgufLinear {
947            store,
948            in_features: n_cols,
949            out_features: n_rows,
950        }))
951    }
952}
953
// CPU has no paged-KV path; the `BackendPagedKv` defaults (unsupported)
// are inherited unchanged.
impl crate::backend::BackendPagedKv for CpuBackend {}

// CPU has no native fused MoE dispatch; the `BackendMoeFused` defaults
// (unsupported) are inherited unchanged.
impl crate::backend::BackendMoeFused for CpuBackend {}

// KV-cache dtype binding. The CPU KV cache is the ordinary f32 `Buffer`,
// even under the `KvFp16` marker: the existing CPU KV path treats the
// "fp16" buffer as f32 internally, and the marker exists for compatibility
// with code that expects a `KvFp16` binding. `KvScales = ()` because no
// per-block scale storage is attached.
impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CpuBackend {
    type KvBuffer = <Self as crate::backend::Backend>::Buffer;
    type KvScales = ();
}