Skip to main content

ferrum_kernels/backend/
cpu.rs

//! CPU backend: Accelerate (CBLAS / vDSP) on macOS, portable pure-Rust loops on every other target.
//! `Context = ()` — all ops execute immediately, so no batching or synchronisation is needed.
3
4use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
// ── Q4_K_M block layout ────────────────────────────────────────────────
//
// Mirrors GGML / candle's `BlockQ4K`. Used by `load_quant` (through
// `dequant_q4_k_cpu`) to dequant raw GGUF block bytes to fp32 row-major
// weights on CPU.

// Number of weights encoded by one Q4_K block.
const Q4_K_QK: usize = 256;
// Bytes of packed 6-bit sub-block scales/mins stored in each block.
const Q4_K_SCALE_SIZE: usize = 12;
// Total bytes per block: 4 header bytes (two little-endian f16 values, `d`
// and `dmin`), the packed scales, then two 4-bit quants per byte.
const Q4_K_BLOCK_BYTES: usize = 4 + Q4_K_SCALE_SIZE + Q4_K_QK / 2; // 144
17
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the packed
/// Q4_K scale table `q` (same bit layout as candle's
/// `quantized::utils::get_scale_min_k4`).
///
/// For `j < 4` both values are the low 6 bits of `q[j]` and `q[j + 4]`.
/// For `j >= 4` the low 4 bits of each value come from the two nibbles of
/// `q[j + 4]`, and the high 2 bits are borrowed from the top two bits of
/// `q[j - 4]` (scale) and `q[j]` (min).
///
/// Panics if `q` is too short for the indices touched.
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    const LOW6: u8 = 0x3F;
    const LOW4: u8 = 0x0F;

    match j {
        0..=3 => (q[j] & LOW6, q[j + 4] & LOW6),
        _ => {
            let nibbles = q[j + 4];
            let scale = (nibbles & LOW4) | ((q[j - 4] >> 6) << 4);
            let min = (nibbles >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
28
29/// Port of candle's CPU `BlockQ4K::to_float`. Bit-identical output for
30/// identical input — the test in `q4_k.rs` verifies our Metal kernel
31/// also matches.
32fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33    debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34    let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35    for b in 0..n_blocks {
36        let off = b * Q4_K_BLOCK_BYTES;
37        let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38        let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39        let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40        let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42        let mut is = 0usize;
43        for j in (0..Q4_K_QK).step_by(64) {
44            let q_chunk = &qs[j / 2..j / 2 + 32];
45            let (sc1, mn1) = get_scale_min_k4(is, scales);
46            let d1 = d * sc1 as f32;
47            let m1 = dmin * mn1 as f32;
48            let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49            let d2 = d * sc2 as f32;
50            let m2 = dmin * mn2 as f32;
51            for q in q_chunk {
52                out.push(d1 * (q & 0xF) as f32 - m1);
53            }
54            for q in q_chunk {
55                out.push(d2 * (q >> 4) as f32 - m2);
56            }
57            is += 2;
58        }
59    }
60    out
61}
62
/// Zero-sized marker type that the CPU implementations of the backend
/// traits in this file are attached to.
pub struct CpuBackend;
64
// Accelerate framework entry points used by the macOS fast paths.
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// CBLAS single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// `order`, `transa` and `transb` take the integer values of the CBLAS
    /// layout / transpose enums; all dimensions and leading dimensions are
    /// C `int`.
    ///
    /// # Safety
    /// `a`, `b` and `c` must point to buffers large enough for the shapes
    /// implied by `m`, `n`, `k` and the leading dimensions.
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );

    /// vDSP single-precision dot product of two strided vectors.
    ///
    /// The strides are `vDSP_Stride`, which is C `long`, so they are
    /// declared as `c_long` rather than `i32` to match the C prototype.
    /// `n` is `vDSP_Length` (C `unsigned long`); it stays `u64` because
    /// that is its width on 64-bit Apple targets and existing call sites
    /// pass a `u64`.
    ///
    /// # Safety
    /// `a` and `b` must each be readable for `n` elements at the given
    /// stride, and `result` must be a valid, writable `f32`.
    unsafe fn vDSP_dotpr(
        a: *const f32,
        a_stride: std::ffi::c_long,
        b: *const f32,
        b_stride: std::ffi::c_long,
        result: *mut f32,
        n: u64,
    );
}
92
/// CPU-side GPTQ store — dequantized f32 weights in row-major [n, k] layout.
/// Trades memory for simplicity: repack once at load, then run normal GEMM.
///
/// In this file it is filled by `load_gptq_stacked` (all experts
/// concatenated along `n`) and read by `cpu_gemm_gptq_with_offset_strided`.
pub struct CpuGptqStore {
    /// Dense weights; element `(row, col)` lives at `row * k + col`.
    pub weight_f32: Vec<f32>, // [n, k] row-major
    /// Number of columns per weight row (the GEMM reduction dimension).
    pub k: usize,
    /// Number of weight rows held in `weight_f32`.
    pub n: usize,
}
100
/// CPU-side container for any GGUF k-quant flavour. Each variant holds
/// the dense fp32 weights post-eager-dequant — CPU isn't the bench
/// target so we don't pay the complexity of on-the-fly dequant here;
/// the variant tag exists so `gemm_quant` can route consistently.
///
/// New k-quant types (Q5_K / Q6_K / Q8_0) become new variants — no
/// trait churn, just a new arm in `load_quant` and `gemm_quant`.
pub enum CpuQuantStore {
    /// Weights that arrived as Q4_K blocks and were expanded by
    /// `dequant_q4_k_cpu` at load time.
    Q4K {
        /// Dequantized values, `n_rows * n_cols` of them.
        weights: Vec<f32>, // [n_rows, n_cols] row-major
        /// Row count as passed to `load_quant`.
        n_rows: usize,
        /// Column count as passed to `load_quant`.
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117    type Buffer = Vec<f32>;
118    type Context = ();
119    // type GptqStore: removed in Phase C step 4e. CpuGptqStore is now
120    // a private (crate-internal) detail of CpuMarlinExpertStack.
121
122    type Timer = crate::backend::timer::CpuTimer;
123    fn make_timer() -> Self::Timer {
124        crate::backend::timer::CpuTimer::new()
125    }
126
    /// Builds the backend context. `Context` is `()` for the CPU backend, so
    /// there is nothing to construct.
    fn new_context() -> Self::Context {}
    /// No-op: the body is empty, so there is never pending work to wait for.
    fn sync(_ctx: &mut Self::Context) {}
129    fn activation_elem_size_bytes() -> usize {
130        std::mem::size_of::<f32>()
131    }
132
133    /// Phase D step 2+3: typed alloc. CPU Buffer is Vec<f32> — bytes
134    /// are dtype-erased, so we size the underlying Vec to hold `n`
135    /// elements of `dtype` (bit-cast at read/write time).
136    fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
137        // f32 storage; for i8 we round up to 4-byte boundary so the
138        // Vec<f32> length covers all i8 elements.
139        let bytes = n * dtype.bytes_per_elem();
140        let f32_len = bytes.div_ceil(4);
141        vec![0.0f32; f32_len]
142    }
143
144    /// Phase D step 2+3: typed upload. Bit-cast host data into f32
145    /// words (CPU buffer is dtype-erased Vec<f32>, see alloc_typed).
146    fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
147        let bytes = data.len() * std::mem::size_of::<T>();
148        let f32_len = bytes.div_ceil(4);
149        let mut out = vec![0.0f32; f32_len];
150        unsafe {
151            std::ptr::copy_nonoverlapping(
152                data.as_ptr() as *const u8,
153                out.as_mut_ptr() as *mut u8,
154                bytes,
155            );
156        }
157        out
158    }
159
160    /// Phase D step 2+3: typed in-place write. Bit-cast bytes into
161    /// the dtype-erased f32 storage.
162    fn write_typed<T: crate::backend::HostDtype>(
163        _ctx: &mut Self::Context,
164        dst: &mut Self::Buffer,
165        data: &[T],
166    ) {
167        let bytes = data.len() * std::mem::size_of::<T>();
168        debug_assert!(
169            bytes <= dst.len() * 4,
170            "CpuBackend::write_typed: src bytes {} > dst bytes {}",
171            bytes,
172            dst.len() * 4
173        );
174        unsafe {
175            std::ptr::copy_nonoverlapping(
176                data.as_ptr() as *const u8,
177                dst.as_mut_ptr() as *mut u8,
178                bytes,
179            );
180        }
181    }
182
183    fn fused_silu_mul_split_strided(
184        _ctx: &mut Self::Context,
185        gate_up: &Self::Buffer,
186        in_row_offset: usize,
187        out: &mut Self::Buffer,
188        out_row_offset: usize,
189        tokens: usize,
190        intermediate: usize,
191    ) {
192        let in_per_row = 2 * intermediate;
193        let in_start = in_row_offset * in_per_row;
194        let out_start = out_row_offset * intermediate;
195        for r in 0..tokens {
196            for c in 0..intermediate {
197                let g = gate_up[in_start + r * in_per_row + c];
198                let u = gate_up[in_start + r * in_per_row + intermediate + c];
199                let silu = g / (1.0 + (-g).exp());
200                out[out_start + r * intermediate + c] = silu * u;
201            }
202        }
203    }
204
205    fn gemm(
206        _ctx: &mut Self::Context,
207        a: &Self::Buffer,
208        b: &Self::Buffer,
209        out: &mut Self::Buffer,
210        m: usize,
211        n: usize,
212        k: usize,
213    ) {
214        assert!(
215            a.len() >= m * k,
216            "gemm: a too small len={} m={m} k={k}",
217            a.len()
218        );
219        assert!(
220            b.len() >= n * k,
221            "gemm: b too small len={} n={n} k={k}",
222            b.len()
223        );
224        assert!(
225            out.len() >= m * n,
226            "gemm: out too small len={} m={m} n={n}",
227            out.len()
228        );
229        #[cfg(target_os = "macos")]
230        unsafe {
231            cblas_sgemm(
232                101,
233                111,
234                112,
235                m as i32,
236                n as i32,
237                k as i32,
238                1.0,
239                a.as_ptr(),
240                k as i32,
241                b.as_ptr(),
242                k as i32,
243                0.0,
244                out.as_mut_ptr(),
245                n as i32,
246            );
247        }
248        #[cfg(not(target_os = "macos"))]
249        {
250            for i in 0..m {
251                for j in 0..n {
252                    let mut sum = 0.0f64;
253                    for p in 0..k {
254                        sum += a[i * k + p] as f64 * b[j * k + p] as f64;
255                    }
256                    out[i * n + j] = sum as f32;
257                }
258            }
259        }
260    }
261
262    fn rms_norm(
263        _ctx: &mut Self::Context,
264        x: &Self::Buffer,
265        w: &Self::Buffer,
266        eps: f32,
267        out: &mut Self::Buffer,
268        tokens: usize,
269        dim: usize,
270    ) {
271        for t in 0..tokens {
272            let row = &x[t * dim..(t + 1) * dim];
273            let o = &mut out[t * dim..(t + 1) * dim];
274            let sum_sq = dot_product(row, row);
275            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
276            for i in 0..dim {
277                o[i] = row[i] * inv * w[i];
278            }
279        }
280    }
281
282    fn fused_add_rms_norm(
283        _ctx: &mut Self::Context,
284        residual: &mut Self::Buffer,
285        x: &Self::Buffer,
286        w: &Self::Buffer,
287        eps: f32,
288        out: &mut Self::Buffer,
289        tokens: usize,
290        dim: usize,
291    ) {
292        for t in 0..tokens {
293            let off = t * dim;
294            for i in 0..dim {
295                residual[off + i] += x[off + i];
296            }
297            let row = &residual[off..off + dim];
298            let o = &mut out[off..off + dim];
299            let sum_sq = dot_product(row, row);
300            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
301            for i in 0..dim {
302                o[i] = row[i] * inv * w[i];
303            }
304        }
305    }
306
307    fn flash_attention(
308        _ctx: &mut Self::Context,
309        q: &Self::Buffer,
310        k: &Self::Buffer,
311        v: &Self::Buffer,
312        out: &mut Self::Buffer,
313        batch: usize,
314        q_len: usize,
315        kv_len: usize,
316        pos_offset: usize,
317        cfg: &AttnConfig,
318    ) {
319        cpu_attention(
320            q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
321        );
322    }
323
324    fn copy_slice(
325        _ctx: &mut Self::Context,
326        src: &Self::Buffer,
327        src_offset: usize,
328        dst: &mut Self::Buffer,
329        dst_offset: usize,
330        len: usize,
331    ) {
332        dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
333    }
334
335    fn embedding_lookup(
336        _ctx: &mut Self::Context,
337        table: &Self::Buffer,
338        ids: &[u32],
339        out: &mut Self::Buffer,
340        dim: usize,
341    ) {
342        for (i, &id) in ids.iter().enumerate() {
343            let src = id as usize * dim;
344            out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
345        }
346    }
347
348    fn split_qkv(
349        _ctx: &mut Self::Context,
350        qkv: &Self::Buffer,
351        q: &mut Self::Buffer,
352        k: &mut Self::Buffer,
353        v: &mut Self::Buffer,
354        tokens: usize,
355        q_dim: usize,
356        kv_dim: usize,
357    ) {
358        let qkv_dim = q_dim + 2 * kv_dim;
359        for t in 0..tokens {
360            let base = t * qkv_dim;
361            q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
362            k[t * kv_dim..(t + 1) * kv_dim]
363                .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
364            v[t * kv_dim..(t + 1) * kv_dim]
365                .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
366        }
367    }
368
369    fn fused_silu_mul_split(
370        _ctx: &mut Self::Context,
371        gate_up: &Self::Buffer,
372        out: &mut Self::Buffer,
373        tokens: usize,
374        im: usize,
375    ) {
376        for t in 0..tokens {
377            for i in 0..im {
378                let g = gate_up[t * 2 * im + i];
379                let u = gate_up[t * 2 * im + im + i];
380                out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
381            }
382        }
383    }
384
385    fn qk_norm_rope(
386        _ctx: &mut Self::Context,
387        input: &Self::Buffer,
388        norm_w: &Self::Buffer,
389        cos: &Self::Buffer,
390        sin: &Self::Buffer,
391        output: &mut Self::Buffer,
392        tokens: usize,
393        heads: usize,
394        head_dim: usize,
395        pos_offset: usize,
396        eps: f32,
397        mode: i32,
398    ) {
399        let half = head_dim / 2;
400        let cos_len = cos.len();
401        let sin_len = sin.len();
402        debug_assert_eq!(cos_len, sin_len);
403
404        for t in 0..tokens {
405            let pos = pos_offset + t;
406            for h in 0..heads {
407                // input row: [t, h, :]  stride = heads * head_dim
408                let src_off = (t * heads + h) * head_dim;
409                // output row: [h, t, :]  stride = tokens * head_dim
410                let dst_off = (h * tokens + t) * head_dim;
411
412                // Mode 0: plain transpose.
413                if mode == 0 {
414                    for i in 0..head_dim {
415                        output[dst_off + i] = input[src_off + i];
416                    }
417                    continue;
418                }
419
420                // Optional RMS norm (mode 1 only).
421                let scale = if mode == 1 {
422                    let mut sum_sq = 0.0f32;
423                    for i in 0..head_dim {
424                        sum_sq += input[src_off + i] * input[src_off + i];
425                    }
426                    1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
427                } else {
428                    1.0
429                };
430
431                if mode == 3 {
432                    // GGUF LLaMA / llama.cpp interleaved RoPE layout.
433                    for i in 0..half {
434                        let j = 2 * i;
435                        let x0 = input[src_off + j];
436                        let x1 = input[src_off + j + 1];
437                        let c = cos[pos * half + i];
438                        let s = sin[pos * half + i];
439                        output[dst_off + j] = x0 * c - x1 * s;
440                        output[dst_off + j + 1] = x1 * c + x0 * s;
441                    }
442                } else {
443                    // Apply (norm?) + half-split RoPE to head-major output.
444                    for i in 0..half {
445                        let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
446                        let (x0, x1) = if mode == 1 {
447                            (
448                                x0_raw * scale * norm_w[i],
449                                x1_raw * scale * norm_w[i + half],
450                            )
451                        } else {
452                            (x0_raw, x1_raw)
453                        };
454                        let c = cos[pos * half + i];
455                        let s = sin[pos * half + i];
456                        output[dst_off + i] = x0 * c - x1 * s;
457                        output[dst_off + i + half] = x1 * c + x0 * s;
458                    }
459                }
460            }
461        }
462    }
463
464    fn kv_cache_append_head_major(
465        _ctx: &mut Self::Context,
466        cache_k: &mut Self::Buffer,
467        cache_v: &mut Self::Buffer,
468        cache_len: usize,
469        cache_capacity: usize,
470        new_k_head_major: &Self::Buffer,
471        new_v_head_major: &Self::Buffer,
472        new_tokens: usize,
473        nkv: usize,
474        hd: usize,
475    ) {
476        debug_assert!(cache_len + new_tokens <= cache_capacity);
477        debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
478        debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
479        // The source buffers may be sized for `max_tokens` (the prefill-
480        // sized scratch) while only the first `nkv * new_tokens * hd`
481        // entries are valid for this call. Allow >= so reusing scratch
482        // across prefill and decode doesn't trip the assert.
483        debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
484        debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
485
486        for h in 0..nkv {
487            let dst_base = h * cache_capacity * hd + cache_len * hd;
488            let src_base = h * new_tokens * hd;
489            cache_k[dst_base..dst_base + new_tokens * hd]
490                .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
491            cache_v[dst_base..dst_base + new_tokens * hd]
492                .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
493        }
494    }
495
496    fn transpose_head_to_token(
497        _ctx: &mut Self::Context,
498        src: &Self::Buffer,
499        dst: &mut Self::Buffer,
500        tokens: usize,
501        heads: usize,
502        dim: usize,
503    ) {
504        for h in 0..heads {
505            for t in 0..tokens {
506                let s = (h * tokens + t) * dim;
507                let d = (t * heads + h) * dim;
508                dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
509            }
510        }
511    }
512
513    fn add_inplace(
514        _ctx: &mut Self::Context,
515        residual: &mut Self::Buffer,
516        x: &Self::Buffer,
517        len: usize,
518    ) {
519        for i in 0..len {
520            residual[i] += x[i];
521        }
522    }
523
524    fn scaled_add_inplace(
525        _ctx: &mut Self::Context,
526        dst: &mut Self::Buffer,
527        src: &Self::Buffer,
528        scale: f32,
529        len: usize,
530    ) {
531        for i in 0..len {
532            dst[i] += scale * src[i];
533        }
534    }
535
536    fn add_bias(
537        _ctx: &mut Self::Context,
538        data: &mut Self::Buffer,
539        bias: &Self::Buffer,
540        rows: usize,
541        cols: usize,
542    ) {
543        debug_assert_eq!(bias.len(), cols);
544        for r in 0..rows {
545            let off = r * cols;
546            for c in 0..cols {
547                data[off + c] += bias[c];
548            }
549        }
550    }
551
552    fn layer_norm(
553        _ctx: &mut Self::Context,
554        x: &Self::Buffer,
555        gamma: &Self::Buffer,
556        beta: &Self::Buffer,
557        eps: f32,
558        out: &mut Self::Buffer,
559        tokens: usize,
560        dim: usize,
561    ) {
562        debug_assert_eq!(gamma.len(), dim);
563        debug_assert_eq!(beta.len(), dim);
564        for t in 0..tokens {
565            let off = t * dim;
566            // Compute mean + variance over `dim` in f64 for stability.
567            let mut mean = 0.0f64;
568            for i in 0..dim {
569                mean += x[off + i] as f64;
570            }
571            mean /= dim as f64;
572            let mut var = 0.0f64;
573            for i in 0..dim {
574                let d = x[off + i] as f64 - mean;
575                var += d * d;
576            }
577            var /= dim as f64;
578            let inv = 1.0f32 / ((var as f32) + eps).sqrt();
579            let mean_f32 = mean as f32;
580            for i in 0..dim {
581                out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
582            }
583        }
584    }
585
586    fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
587        // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
588        // Uses f64 for erf accuracy (matches torch.nn.functional.gelu default).
589        for i in 0..len {
590            let xi = x[i];
591            out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
592        }
593    }
594
595    fn alloc(len: usize) -> Self::Buffer {
596        vec![0.0f32; len]
597    }
598    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
599        buf[..len].to_vec()
600    }
601    fn from_slice(data: &[f32]) -> Self::Buffer {
602        data.to_vec()
603    }
604}
605
606// ── Helpers ──────────────────────────────────────────────────────────────
607
/// Dot product of `a` and `b` over their common prefix.
///
/// If the slices differ in length, only the first `min(a.len(), b.len())`
/// elements take part. macOS uses Accelerate's `vDSP_dotpr`; every other
/// target uses a plain iterator sum.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_os = "macos")]
    {
        // Clamp to the shorter slice so vDSP never reads past the end of
        // `b`; this also matches the zip-based fallback below.
        let n = a.len().min(b.len());
        let mut result = 0.0f32;
        // SAFETY: both pointers come from live slices that are readable for
        // at least `n` contiguous elements (unit stride), and `result` is a
        // valid, exclusively borrowed `f32`.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
622
/// Applies half-split rotary position embedding in place to token-major
/// data laid out as `[tokens, heads, head_dim]`.
///
/// For every head, element `i` is rotated together with element
/// `half + i` using `cos[pos * half + i]` / `sin[pos * half + i]`, where
/// `pos` is the entry of `positions` for that token.
// Not referenced anywhere in the visible part of this file, hence the allow.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for t in 0..tokens {
        let trig_row = positions[t] as usize;
        for h in 0..heads {
            let head_start = (t * heads + h) * head_dim;
            for i in 0..half {
                let lo = head_start + i;
                let hi = head_start + half + i;
                let (c, s) = (cos[trig_row * half + i], sin[trig_row * half + i]);
                let (x0, x1) = (data[lo], data[hi]);
                data[lo] = x0 * c - x1 * s;
                data[hi] = x1 * c + x0 * s;
            }
        }
    }
}
649
650fn cpu_attention(
651    q: &[f32],
652    k: &[f32],
653    v: &[f32],
654    out: &mut [f32],
655    batch: usize,
656    q_len: usize,
657    kv_len: usize,
658    causal: bool,
659    pos_offset: usize,
660    cfg: &AttnConfig,
661) {
662    let nh = cfg.num_heads;
663    let nkv = cfg.num_kv_heads;
664    let d = cfg.head_dim;
665    let n_rep = nh / nkv;
666    let scale = cfg.scale;
667    // Per-head KV stride: 0 (the default) means contiguous (legacy
668    // `kv_cache_append` path reallocates each layer). A non-zero value means
669    // the cache is pre-allocated to `kv_seq_stride` rows per head but only
670    // the first `kv_len` are valid — we skip the rest via `attend_len`.
671    let kv_stride = if cfg.kv_seq_stride > 0 {
672        cfg.kv_seq_stride
673    } else {
674        kv_len
675    };
676
677    for b in 0..batch {
678        for h in 0..nh {
679            let kv_h = h / n_rep;
680            let q_off = (b * nh + h) * q_len * d;
681            let k_off = (b * nkv + kv_h) * kv_stride * d;
682            let v_off = (b * nkv + kv_h) * kv_stride * d;
683            let o_off = (b * nh + h) * q_len * d;
684
685            for qi in 0..q_len {
686                let attend_end = if causal {
687                    (pos_offset + qi + 1).min(kv_len)
688                } else {
689                    kv_len
690                };
691                let attend_start = if causal && cfg.sliding_window > 0 {
692                    attend_end.saturating_sub(cfg.sliding_window)
693                } else {
694                    0
695                };
696                let mut max_score = f32::NEG_INFINITY;
697                let mut sum_exp = 0.0f32;
698                let mut acc = vec![0.0f32; d];
699
700                for ki in attend_start..attend_end {
701                    let mut dot = 0.0f32;
702                    for di in 0..d {
703                        dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
704                    }
705                    let score = dot * scale;
706                    if score > max_score {
707                        let correction = (max_score - score).exp();
708                        for di in 0..d {
709                            acc[di] *= correction;
710                        }
711                        sum_exp *= correction;
712                        max_score = score;
713                    }
714                    let w = (score - max_score).exp();
715                    sum_exp += w;
716                    for di in 0..d {
717                        acc[di] += w * v[v_off + ki * d + di];
718                    }
719                }
720
721                if sum_exp > 0.0 {
722                    let inv = 1.0 / sum_exp;
723                    for di in 0..d {
724                        out[o_off + qi * d + di] = acc[di] * inv;
725                    }
726                }
727            }
728        }
729    }
730}
731
/// Error-function approximation using the Abramowitz & Stegun 7.1.26
/// polynomial (quoted maximum error about 1.5e-7), evaluated in `f32`.
///
/// The polynomial is evaluated on `|x|` and the sign is re-applied at the
/// end, so the result is exactly odd in `x`.
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = 0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = 1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let sign: f32 = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + P * x);

    // Horner evaluation, highest-order coefficient first.
    let mut poly = A5 * t - A4;
    poly = poly * t + A3;
    poly = poly * t - A2;
    poly = poly * t + A1;

    let y = 1.0 - poly * t * (-x * x).exp();
    sign * y
}
745
// CPU has no graph-capture analogue, so nothing is overridden here: the
// empty body means every `BackendGraph` item comes from the trait's own
// defaults.
impl crate::backend::BackendGraph for CpuBackend {}
748
// CPU has no multi-rank collectives, so nothing is overridden here: the
// empty body means every `BackendCollective` item comes from the trait's
// own defaults.
impl crate::backend::BackendCollective for CpuBackend {}
751
752/// Dequant raw GPTQ tensors → row-major `[n, k]` f32. Shared between
753/// the per-tensor `load_gptq` and the MoE `load_gptq_stacked` impls.
754fn cpu_dequant_gptq(
755    qweight: &[i32],
756    scales: &[f32],
757    qzeros: &[i32],
758    bits: u32,
759    group_size: usize,
760    k: usize,
761    n: usize,
762) -> Result<Vec<f32>> {
763    if bits != 4 {
764        return Err(FerrumError::unsupported(format!(
765            "CPU GPTQ: only bits=4 supported (got {bits})"
766        )));
767    }
768    let mut w = vec![0.0f32; n * k];
769    let packed_rows = k / 8;
770    for pr in 0..packed_rows {
771        for col in 0..n {
772            let packed = qweight[pr * n + col] as u32;
773            for bi in 0..8 {
774                let ki = pr * 8 + bi;
775                let q = ((packed >> (bi * 4)) & 0xF) as i32;
776                let grp = ki / group_size;
777                let scale = scales[grp * n + col];
778                let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
779                let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
780                let val = (q - zero) as f32 * scale;
781                w[col * k + ki] = val;
782            }
783        }
784    }
785    Ok(w)
786}
787
788impl crate::backend::BackendQuantMarlin for CpuBackend {
789    fn load_gptq(
790        qweight: &[i32],
791        scales: &[f32],
792        qzeros: &[i32],
793        _g_idx: Option<&[i32]>,
794        bias_host: Option<&[f32]>,
795        bits: u32,
796        group_size: usize,
797        k: usize,
798        n: usize,
799    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
800        let w = cpu_dequant_gptq(qweight, scales, qzeros, bits, group_size, k, n)?;
801        // Phase 3e/2: dequantized weights become a CpuGptqLinear that
802        // owns the (out_features, in_features) f32 matrix and runs
803        // through the existing Self::gemm CPU path.
804        Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
805            weight_f32: w,
806            bias: bias_host.map(|b| b.to_vec()),
807            in_features: k,
808            out_features: n,
809        }))
810    }
811    fn load_gptq_stacked(
812        qweights: &[&[i32]],
813        scales: &[&[f32]],
814        qzeros: &[&[i32]],
815        _g_idx: Option<&[i32]>,
816        bits: u32,
817        group_size: usize,
818        k: usize,
819        n_per_expert: usize,
820    ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
821        // Phase 3e/2 addition: dequant each expert independently, concat
822        // along N (rows in [n, k] layout). Used by MoE parity tests.
823        let num_experts = qweights.len();
824        if scales.len() != num_experts || qzeros.len() != num_experts {
825            return Err(FerrumError::model(format!(
826                "load_gptq_stacked: input slice lengths disagree (qw {num_experts}, sc {}, qz {})",
827                scales.len(),
828                qzeros.len()
829            )));
830        }
831        let total_n = num_experts * n_per_expert;
832        let mut all_w = Vec::with_capacity(total_n * k);
833        for ((qw_e, sc_e), qz_e) in qweights.iter().zip(scales.iter()).zip(qzeros.iter()) {
834            let w_e = cpu_dequant_gptq(qw_e, sc_e, qz_e, bits, group_size, k, n_per_expert)?;
835            all_w.extend_from_slice(&w_e);
836        }
837        let store = std::sync::Arc::new(CpuGptqStore {
838            weight_f32: all_w,
839            k,
840            n: total_n,
841        });
842        Ok(std::sync::Arc::new(
843            crate::quant_linear::cpu_marlin_stack::CpuMarlinExpertStack::new(
844                store,
845                num_experts,
846                n_per_expert,
847                k,
848            ),
849        ))
850    }
851    // Phase C step 4b: make_stacked_expert_linear inlined into
852    // CpuMarlinExpertStack::make_expert_linear.
853    // Phase C step 4e: make_marlin_expert_stack subsumed by load_gptq_stacked.
854    // gemm_gptq_with_offset_strided body moved to free function
855    // cpu_gemm_gptq_with_offset_strided below — called by
856    // CpuMarlinExpertStack::gemm_phase_batched.
857}
858
859/// Free-function form of the deleted
860/// `BackendQuantMarlin::gemm_gptq_with_offset_strided` (Phase C step 4e).
861/// Single caller: `CpuMarlinExpertStack::gemm_phase_batched`.
862#[allow(clippy::too_many_arguments)]
863pub(crate) fn cpu_gemm_gptq_with_offset_strided(
864    _ctx: &mut <CpuBackend as Backend>::Context,
865    input: &<CpuBackend as Backend>::Buffer,
866    in_row_offset: usize,
867    weight: &CpuGptqStore,
868    expert_offset: usize,
869    expert_n: usize,
870    output: &mut <CpuBackend as Backend>::Buffer,
871    out_row_offset: usize,
872    m: usize,
873    k: usize,
874) -> Result<()> {
875    if expert_offset + expert_n > weight.n {
876        return Err(FerrumError::model(format!(
877            "cpu_gemm_gptq_with_offset_strided OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
878            weight.n
879        )));
880    }
881    if k != weight.k {
882        return Err(FerrumError::model(format!(
883            "cpu_gemm_gptq_with_offset_strided k mismatch: arg {k} vs weight.k {}",
884            weight.k
885        )));
886    }
887    let in_start = in_row_offset * k;
888    let in_end = (in_row_offset + m) * k;
889    let out_start = out_row_offset * expert_n;
890    let out_end = (out_row_offset + m) * expert_n;
891    let row_start = expert_offset * k;
892    let row_end = (expert_offset + expert_n) * k;
893    let weight_slice = weight.weight_f32[row_start..row_end].to_vec();
894    let in_slice = input[in_start..in_end].to_vec();
895    let mut out_slice = vec![0.0f32; m * expert_n];
896    let mut ctx_local = ();
897    CpuBackend::gemm(
898        &mut ctx_local,
899        &in_slice,
900        &weight_slice,
901        &mut out_slice,
902        m,
903        expert_n,
904        k,
905    );
906    output[out_start..out_end].copy_from_slice(&out_slice);
907    Ok(())
908}
909
910impl crate::backend::BackendQuantGguf for CpuBackend {
911    fn load_quant(
912        kind: super::GgufQuantType,
913        bytes: &[u8],
914        n_rows: usize,
915        n_cols: usize,
916    ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
917        use super::GgufQuantType;
918        let store = match kind {
919            GgufQuantType::Q4K => {
920                let total_elems = n_rows * n_cols;
921                if total_elems % Q4_K_QK != 0 {
922                    return Err(FerrumError::model(format!(
923                        "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
924                    )));
925                }
926                let n_blocks = total_elems / Q4_K_QK;
927                let expected = n_blocks * Q4_K_BLOCK_BYTES;
928                if bytes.len() != expected {
929                    return Err(FerrumError::model(format!(
930                        "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
931                        bytes.len(),
932                        expected
933                    )));
934                }
935                CpuQuantStore::Q4K {
936                    weights: dequant_q4_k_cpu(bytes, n_blocks),
937                    n_rows,
938                    n_cols,
939                }
940            }
941            other => {
942                return Err(FerrumError::unsupported(format!(
943                    "CPU load_quant: {other:?} not yet implemented"
944                )));
945            }
946        };
947        // Phase 3e/3: dispatch via CpuGgufLinear::forward instead of
948        // a trait method.
949        Ok(Box::new(crate::quant_linear::cpu_gguf::CpuGgufLinear {
950            store,
951            in_features: n_cols,
952            out_features: n_rows,
953        }))
954    }
955}
956
// CPU has no paged-KV path, so nothing is overridden here and every
// `BackendPagedKv` item comes from the trait's defaults.
// NOTE(review): the original note says those defaults report
// "unsupported" — confirm against the trait definition.
impl crate::backend::BackendPagedKv for CpuBackend {}
959
// CPU has no native MoE dispatch, so nothing is overridden here and every
// `BackendMoeFused` item comes from the trait's defaults.
// NOTE(review): the original note says those defaults report
// "unsupported" — confirm against the trait definition.
impl crate::backend::BackendMoeFused for CpuBackend {}
962
// Registers `CpuBackend` for the `KvFp16` KV-cache dtype marker.
// NOTE(review): per the original note, the existing CPU KV cache path
// stores f32 values internally and the `KvFp16` tag is used only for
// compatibility — confirm against the KV cache code.
impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CpuBackend {
    // The KV cache reuses the backend's ordinary buffer type.
    type KvBuffer = <Self as crate::backend::Backend>::Buffer;
    // Unit type: no separate scale storage is kept for this dtype.
    type KvScales = ();
}