// ferrum_kernels/backend/cpu.rs

//! CPU backend using Accelerate (macOS) / portable fallback (Linux).
//! Context = () — all ops execute immediately, no batching needed.
3
use ferrum_types::{FerrumError, Result};
use half::f16;

use super::{AttnConfig, Backend};
8
// ── Q4_K_M block layout ────────────────────────────────────────────────
//
// Mirrors GGML / candle's `BlockQ4K`. Used by `load_quant` (via
// `dequant_q4_k_cpu`) to dequant raw GGUF block bytes to fp32 row-major
// weights on CPU.
//
// One block, as read by `dequant_q4_k_cpu`:
//   bytes 0..4    two little-endian f16 values (`d`, `dmin`)
//   bytes 4..16   12 bytes of packed 6-bit sub-block scales / mins
//   bytes 16..144 256 4-bit quants, two per byte

/// Number of weights encoded by one Q4_K block.
const Q4_K_QK: usize = 256;
/// Bytes of packed sub-block scales / mins per block.
const Q4_K_SCALE_SIZE: usize = 12;
/// Total size of one block in bytes: 4 (d, dmin) + 12 (scales) + 128 (nibbles).
const Q4_K_BLOCK_BYTES: usize = 4 + Q4_K_SCALE_SIZE + Q4_K_QK / 2; // 144
17
/// Bit-unpacker matching candle's `quantized::utils::get_scale_min_k4`.
///
/// Extracts the 6-bit `(scale, min)` pair for sub-block `j` from the packed
/// scale bytes `q`.
///
/// * `j < 4`: the scale is the low 6 bits of `q[j]`, the min the low 6 bits
///   of `q[j + 4]`.
/// * `j >= 4`: the low 4 bits of each value come from the two nibbles of
///   `q[j + 4]`; the high 2 bits are the top 2 bits of `q[j - 4]` (scale)
///   and `q[j]` (min).
///
/// # Panics
/// Indexes `q` up to `j + 4`, so it panics if `q` is too short. Callers in
/// this file pass the 12-byte scale slice of a block with `j` in `0..8`.
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    debug_assert!(j < 8 && q.len() > j + 4, "get_scale_min_k4: j={j} out of range");
    if j < 4 {
        (q[j] & 63, q[j + 4] & 63)
    } else {
        let scale = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let min = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        (scale, min)
    }
}
28
29/// Port of candle's CPU `BlockQ4K::to_float`. Bit-identical output for
30/// identical input — the test in `q4_k.rs` verifies our Metal kernel
31/// also matches.
32fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33    debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34    let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35    for b in 0..n_blocks {
36        let off = b * Q4_K_BLOCK_BYTES;
37        let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38        let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39        let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40        let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42        let mut is = 0usize;
43        for j in (0..Q4_K_QK).step_by(64) {
44            let q_chunk = &qs[j / 2..j / 2 + 32];
45            let (sc1, mn1) = get_scale_min_k4(is, scales);
46            let d1 = d * sc1 as f32;
47            let m1 = dmin * mn1 as f32;
48            let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49            let d2 = d * sc2 as f32;
50            let m2 = dmin * mn2 as f32;
51            for q in q_chunk {
52                out.push(d1 * (q & 0xF) as f32 - m1);
53            }
54            for q in q_chunk {
55                out.push(d2 * (q >> 4) as f32 - m2);
56            }
57            is += 2;
58        }
59    }
60    out
61}
62
/// Zero-sized marker type that implements `Backend` with plain `Vec<f32>`
/// buffers and immediate (unbatched) execution.
pub struct CpuBackend;
64
// Accelerate entry points used by the macOS fast paths.
// NOTE(review): no `#[link]` attribute is visible here — presumably the
// Accelerate framework is linked from the build script; confirm.
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// Single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// # Safety
    /// `a`, `b` and `c` must point to buffers large enough for the given
    /// dimensions and leading dimensions; `c` must be writable.
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );

    /// Single-precision strided dot product of `n` elements.
    ///
    /// The strides are `vDSP_Stride` (C `long`) and the length is
    /// `vDSP_Length` (C `unsigned long`); declaring the strides as 32-bit
    /// integers would mismatch the C ABI.
    ///
    /// # Safety
    /// `a` and `b` must be readable for `n` elements at their strides and
    /// `result` must be writable.
    unsafe fn vDSP_dotpr(
        a: *const f32,
        a_stride: std::ffi::c_long,
        b: *const f32,
        b_stride: std::ffi::c_long,
        result: *mut f32,
        n: std::ffi::c_ulong,
    );
}
92
/// CPU-side GPTQ store — dequantized f32 weights in row-major [n, k] layout.
/// Trades memory for simplicity: repack once at load, then run normal GEMM.
pub struct CpuGptqStore {
    /// Dequantized weights, `[n, k]` row-major (`n * k` elements).
    pub weight_f32: Vec<f32>,
    /// Inner (input-feature) dimension.
    pub k: usize,
    /// Output-feature dimension (number of weight rows).
    pub n: usize,
}
100
/// CPU-side container for any GGUF k-quant flavour. Each variant holds
/// the dense fp32 weights post-eager-dequant — CPU isn't the bench
/// target so we don't pay the complexity of on-the-fly dequant here;
/// the variant tag exists so `gemm_quant` can route consistently.
///
/// New k-quant types (Q5_K / Q6_K / Q8_0) become new variants — no
/// trait churn, just a new arm in `load_quant` and `gemm_quant`.
pub enum CpuQuantStore {
    /// Weights loaded from Q4_K blocks, already expanded to f32.
    Q4K {
        /// Dense weights, `[n_rows, n_cols]` row-major.
        weights: Vec<f32>,
        /// Number of weight rows (the GEMM `n`).
        n_rows: usize,
        /// Number of weight columns (the GEMM `k`).
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117    type Buffer = Vec<f32>;
118    type Context = ();
119    type GptqStore = CpuGptqStore;
120    type QuantStore = CpuQuantStore;
121
122    fn new_context() -> Self::Context {}
123    fn sync(_ctx: &mut Self::Context) {}
124
125    fn load_gptq(
126        qweight: &[i32],
127        scales: &[f32],
128        qzeros: &[i32],
129        _g_idx: Option<&[i32]>,
130        bits: u32,
131        group_size: usize,
132        k: usize,
133        n: usize,
134    ) -> Result<Self::GptqStore> {
135        if bits != 4 {
136            return Err(FerrumError::unsupported(format!(
137                "CPU GPTQ: only bits=4 supported (got {bits})"
138            )));
139        }
140        let num_groups = k / group_size;
141        // Unpack GPTQ [K/8, N] i32 → int4 values, dequantize per-group:
142        //   w_f16 = (q - zero) * scale
143        // Write to [n, k] row-major (matches DenseLinear convention).
144        let mut w = vec![0.0f32; n * k];
145        let packed_rows = k / 8;
146        for pr in 0..packed_rows {
147            for col in 0..n {
148                let packed = qweight[pr * n + col] as u32;
149                for bi in 0..8 {
150                    let ki = pr * 8 + bi;
151                    let q = ((packed >> (bi * 4)) & 0xF) as i32;
152                    let grp = ki / group_size;
153                    let scale = scales[grp * n + col];
154                    // qzeros [num_groups, N/8] i32 packs 8 zero-values per int32
155                    let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
156                    let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
157                    let val = (q - zero) as f32 * scale;
158                    w[col * k + ki] = val;
159                }
160            }
161        }
162        let _ = num_groups; // informational only
163        Ok(CpuGptqStore {
164            weight_f32: w,
165            k,
166            n,
167        })
168    }
169
170    fn gemm_gptq(
171        ctx: &mut Self::Context,
172        a: &Self::Buffer,
173        weight: &Self::GptqStore,
174        out: &mut Self::Buffer,
175        m: usize,
176    ) -> Result<()> {
177        // Just run normal GEMM with dequantized weights.
178        // out[m, n] = a[m, k] @ w[n, k]^T — same contract as B::gemm.
179        Self::gemm(ctx, a, &weight.weight_f32, out, m, weight.n, weight.k);
180        Ok(())
181    }
182
183    fn load_quant(
184        kind: super::GgufQuantType,
185        bytes: &[u8],
186        n_rows: usize,
187        n_cols: usize,
188    ) -> Result<Self::QuantStore> {
189        use super::GgufQuantType;
190        match kind {
191            GgufQuantType::Q4K => {
192                let total_elems = n_rows * n_cols;
193                if total_elems % Q4_K_QK != 0 {
194                    return Err(FerrumError::model(format!(
195                        "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
196                    )));
197                }
198                let n_blocks = total_elems / Q4_K_QK;
199                let expected = n_blocks * Q4_K_BLOCK_BYTES;
200                if bytes.len() != expected {
201                    return Err(FerrumError::model(format!(
202                        "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
203                        bytes.len(),
204                        expected
205                    )));
206                }
207                Ok(CpuQuantStore::Q4K {
208                    weights: dequant_q4_k_cpu(bytes, n_blocks),
209                    n_rows,
210                    n_cols,
211                })
212            }
213            other => Err(FerrumError::unsupported(format!(
214                "CPU load_quant: {other:?} not yet implemented"
215            ))),
216        }
217    }
218
219    fn gemm_quant(
220        ctx: &mut Self::Context,
221        a: &Self::Buffer,
222        weight: &Self::QuantStore,
223        out: &mut Self::Buffer,
224        m: usize,
225    ) -> Result<()> {
226        match weight {
227            CpuQuantStore::Q4K {
228                weights,
229                n_rows,
230                n_cols,
231            } => {
232                Self::gemm(ctx, a, weights, out, m, *n_rows, *n_cols);
233                Ok(())
234            }
235        }
236    }
237
238    fn gemm(
239        _ctx: &mut Self::Context,
240        a: &Self::Buffer,
241        b: &Self::Buffer,
242        out: &mut Self::Buffer,
243        m: usize,
244        n: usize,
245        k: usize,
246    ) {
247        assert!(
248            a.len() >= m * k,
249            "gemm: a too small len={} m={m} k={k}",
250            a.len()
251        );
252        assert!(
253            b.len() >= n * k,
254            "gemm: b too small len={} n={n} k={k}",
255            b.len()
256        );
257        assert!(
258            out.len() >= m * n,
259            "gemm: out too small len={} m={m} n={n}",
260            out.len()
261        );
262        #[cfg(target_os = "macos")]
263        unsafe {
264            cblas_sgemm(
265                101,
266                111,
267                112,
268                m as i32,
269                n as i32,
270                k as i32,
271                1.0,
272                a.as_ptr(),
273                k as i32,
274                b.as_ptr(),
275                k as i32,
276                0.0,
277                out.as_mut_ptr(),
278                n as i32,
279            );
280        }
281        #[cfg(not(target_os = "macos"))]
282        {
283            for i in 0..m {
284                for j in 0..n {
285                    let mut sum = 0.0f64;
286                    for p in 0..k {
287                        sum += a[i * k + p] as f64 * b[j * k + p] as f64;
288                    }
289                    out[i * n + j] = sum as f32;
290                }
291            }
292        }
293    }
294
295    fn rms_norm(
296        _ctx: &mut Self::Context,
297        x: &Self::Buffer,
298        w: &Self::Buffer,
299        eps: f32,
300        out: &mut Self::Buffer,
301        tokens: usize,
302        dim: usize,
303    ) {
304        for t in 0..tokens {
305            let row = &x[t * dim..(t + 1) * dim];
306            let o = &mut out[t * dim..(t + 1) * dim];
307            let sum_sq = dot_product(row, row);
308            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
309            for i in 0..dim {
310                o[i] = row[i] * inv * w[i];
311            }
312        }
313    }
314
315    fn fused_add_rms_norm(
316        _ctx: &mut Self::Context,
317        residual: &mut Self::Buffer,
318        x: &Self::Buffer,
319        w: &Self::Buffer,
320        eps: f32,
321        out: &mut Self::Buffer,
322        tokens: usize,
323        dim: usize,
324    ) {
325        for t in 0..tokens {
326            let off = t * dim;
327            for i in 0..dim {
328                residual[off + i] += x[off + i];
329            }
330            let row = &residual[off..off + dim];
331            let o = &mut out[off..off + dim];
332            let sum_sq = dot_product(row, row);
333            let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
334            for i in 0..dim {
335                o[i] = row[i] * inv * w[i];
336            }
337        }
338    }
339
340    fn flash_attention(
341        _ctx: &mut Self::Context,
342        q: &Self::Buffer,
343        k: &Self::Buffer,
344        v: &Self::Buffer,
345        out: &mut Self::Buffer,
346        batch: usize,
347        q_len: usize,
348        kv_len: usize,
349        pos_offset: usize,
350        cfg: &AttnConfig,
351    ) {
352        cpu_attention(
353            q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
354        );
355    }
356
357    fn copy_slice(
358        _ctx: &mut Self::Context,
359        src: &Self::Buffer,
360        src_offset: usize,
361        dst: &mut Self::Buffer,
362        dst_offset: usize,
363        len: usize,
364    ) {
365        dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
366    }
367
368    fn embedding_lookup(
369        _ctx: &mut Self::Context,
370        table: &Self::Buffer,
371        ids: &[u32],
372        out: &mut Self::Buffer,
373        dim: usize,
374    ) {
375        for (i, &id) in ids.iter().enumerate() {
376            let src = id as usize * dim;
377            out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
378        }
379    }
380
381    fn split_qkv(
382        _ctx: &mut Self::Context,
383        qkv: &Self::Buffer,
384        q: &mut Self::Buffer,
385        k: &mut Self::Buffer,
386        v: &mut Self::Buffer,
387        tokens: usize,
388        q_dim: usize,
389        kv_dim: usize,
390    ) {
391        let qkv_dim = q_dim + 2 * kv_dim;
392        for t in 0..tokens {
393            let base = t * qkv_dim;
394            q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
395            k[t * kv_dim..(t + 1) * kv_dim]
396                .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
397            v[t * kv_dim..(t + 1) * kv_dim]
398                .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
399        }
400    }
401
402    fn fused_silu_mul_split(
403        _ctx: &mut Self::Context,
404        gate_up: &Self::Buffer,
405        out: &mut Self::Buffer,
406        tokens: usize,
407        im: usize,
408    ) {
409        for t in 0..tokens {
410            for i in 0..im {
411                let g = gate_up[t * 2 * im + i];
412                let u = gate_up[t * 2 * im + im + i];
413                out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
414            }
415        }
416    }
417
418    fn qk_norm_rope(
419        _ctx: &mut Self::Context,
420        input: &Self::Buffer,
421        norm_w: &Self::Buffer,
422        cos: &Self::Buffer,
423        sin: &Self::Buffer,
424        output: &mut Self::Buffer,
425        tokens: usize,
426        heads: usize,
427        head_dim: usize,
428        pos_offset: usize,
429        eps: f32,
430        mode: i32,
431    ) {
432        let half = head_dim / 2;
433        let cos_len = cos.len();
434        let sin_len = sin.len();
435        debug_assert_eq!(cos_len, sin_len);
436
437        for t in 0..tokens {
438            let pos = pos_offset + t;
439            for h in 0..heads {
440                // input row: [t, h, :]  stride = heads * head_dim
441                let src_off = (t * heads + h) * head_dim;
442                // output row: [h, t, :]  stride = tokens * head_dim
443                let dst_off = (h * tokens + t) * head_dim;
444
445                // Mode 0: plain transpose.
446                if mode == 0 {
447                    for i in 0..head_dim {
448                        output[dst_off + i] = input[src_off + i];
449                    }
450                    continue;
451                }
452
453                // Optional RMS norm (mode 1 only).
454                let scale = if mode == 1 {
455                    let mut sum_sq = 0.0f32;
456                    for i in 0..head_dim {
457                        sum_sq += input[src_off + i] * input[src_off + i];
458                    }
459                    1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
460                } else {
461                    1.0
462                };
463
464                // Apply (norm?) + RoPE to halves, write to head-major output.
465                for i in 0..half {
466                    let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
467                    let (x0, x1) = if mode == 1 {
468                        (
469                            x0_raw * scale * norm_w[i],
470                            x1_raw * scale * norm_w[i + half],
471                        )
472                    } else {
473                        (x0_raw, x1_raw)
474                    };
475                    let c = cos[pos * half + i];
476                    let s = sin[pos * half + i];
477                    output[dst_off + i] = x0 * c - x1 * s;
478                    output[dst_off + i + half] = x1 * c + x0 * s;
479                }
480            }
481        }
482    }
483
484    fn kv_cache_append_head_major(
485        _ctx: &mut Self::Context,
486        cache_k: &mut Self::Buffer,
487        cache_v: &mut Self::Buffer,
488        cache_len: usize,
489        cache_capacity: usize,
490        new_k_head_major: &Self::Buffer,
491        new_v_head_major: &Self::Buffer,
492        new_tokens: usize,
493        nkv: usize,
494        hd: usize,
495    ) {
496        debug_assert!(cache_len + new_tokens <= cache_capacity);
497        debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
498        debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
499        // The source buffers may be sized for `max_tokens` (the prefill-
500        // sized scratch) while only the first `nkv * new_tokens * hd`
501        // entries are valid for this call. Allow >= so reusing scratch
502        // across prefill and decode doesn't trip the assert.
503        debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
504        debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
505
506        for h in 0..nkv {
507            let dst_base = h * cache_capacity * hd + cache_len * hd;
508            let src_base = h * new_tokens * hd;
509            cache_k[dst_base..dst_base + new_tokens * hd]
510                .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
511            cache_v[dst_base..dst_base + new_tokens * hd]
512                .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
513        }
514    }
515
516    fn transpose_head_to_token(
517        _ctx: &mut Self::Context,
518        src: &Self::Buffer,
519        dst: &mut Self::Buffer,
520        tokens: usize,
521        heads: usize,
522        dim: usize,
523    ) {
524        for h in 0..heads {
525            for t in 0..tokens {
526                let s = (h * tokens + t) * dim;
527                let d = (t * heads + h) * dim;
528                dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
529            }
530        }
531    }
532
533    fn add_inplace(
534        _ctx: &mut Self::Context,
535        residual: &mut Self::Buffer,
536        x: &Self::Buffer,
537        len: usize,
538    ) {
539        for i in 0..len {
540            residual[i] += x[i];
541        }
542    }
543
544    fn scaled_add_inplace(
545        _ctx: &mut Self::Context,
546        dst: &mut Self::Buffer,
547        src: &Self::Buffer,
548        scale: f32,
549        len: usize,
550    ) {
551        for i in 0..len {
552            dst[i] += scale * src[i];
553        }
554    }
555
556    fn add_bias(
557        _ctx: &mut Self::Context,
558        data: &mut Self::Buffer,
559        bias: &Self::Buffer,
560        rows: usize,
561        cols: usize,
562    ) {
563        debug_assert_eq!(bias.len(), cols);
564        for r in 0..rows {
565            let off = r * cols;
566            for c in 0..cols {
567                data[off + c] += bias[c];
568            }
569        }
570    }
571
572    fn layer_norm(
573        _ctx: &mut Self::Context,
574        x: &Self::Buffer,
575        gamma: &Self::Buffer,
576        beta: &Self::Buffer,
577        eps: f32,
578        out: &mut Self::Buffer,
579        tokens: usize,
580        dim: usize,
581    ) {
582        debug_assert_eq!(gamma.len(), dim);
583        debug_assert_eq!(beta.len(), dim);
584        for t in 0..tokens {
585            let off = t * dim;
586            // Compute mean + variance over `dim` in f64 for stability.
587            let mut mean = 0.0f64;
588            for i in 0..dim {
589                mean += x[off + i] as f64;
590            }
591            mean /= dim as f64;
592            let mut var = 0.0f64;
593            for i in 0..dim {
594                let d = x[off + i] as f64 - mean;
595                var += d * d;
596            }
597            var /= dim as f64;
598            let inv = 1.0f32 / ((var as f32) + eps).sqrt();
599            let mean_f32 = mean as f32;
600            for i in 0..dim {
601                out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
602            }
603        }
604    }
605
606    fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
607        // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
608        // Uses f64 for erf accuracy (matches torch.nn.functional.gelu default).
609        for i in 0..len {
610            let xi = x[i];
611            out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
612        }
613    }
614
615    fn alloc(len: usize) -> Self::Buffer {
616        vec![0.0f32; len]
617    }
618    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
619        buf[..len].to_vec()
620    }
621    fn from_slice(data: &[f32]) -> Self::Buffer {
622        data.to_vec()
623    }
624}
625
// ── Helpers ──────────────────────────────────────────────────────────────
627
/// Dot product of `a` and `b` over their common prefix.
///
/// Both platform paths use `min(a.len(), b.len())` elements. The macOS path
/// clamps explicitly so the Accelerate call never reads past the end of the
/// shorter slice; the portable path gets the same behaviour from `zip`.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_os = "macos")]
    {
        let n = a.len().min(b.len());
        let mut result = 0.0f32;
        if n > 0 {
            // SAFETY: both pointers are valid for `n` contiguous f32 reads
            // (n <= a.len() and n <= b.len()), the stride is 1, and `result`
            // is a live local that vDSP writes exactly once.
            unsafe {
                vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
            }
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
642
/// In-place rotary position embedding over token-major
/// `[tokens, heads, head_dim]` data.
///
/// For each token `t` the rotation row is selected by `positions[t]`;
/// element `i` of a head is paired with element `i + half`, and `cos` /
/// `sin` are read at `positions[t] * half + i`.
///
/// # Panics
/// Panics if `positions`, `cos`, `sin` or `data` is too short for the
/// stated shape.
// Nothing in this file calls this helper (`qk_norm_rope` carries its own
// rotation), so the dead-code lint is silenced rather than removing it.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for (t, &position) in positions[..tokens].iter().enumerate() {
        let row = position as usize * half;
        let cos_row = &cos[row..row + half];
        let sin_row = &sin[row..row + half];
        for h in 0..heads {
            let base = (t * heads + h) * head_dim;
            let (lo, hi) = data[base..base + 2 * half].split_at_mut(half);
            let pairs = lo.iter_mut().zip(hi.iter_mut());
            for ((x0, x1), (&c, &s)) in pairs.zip(cos_row.iter().zip(sin_row)) {
                let (a, b) = (*x0, *x1);
                *x0 = a * c - b * s;
                *x1 = b * c + a * s;
            }
        }
    }
}
669
670fn cpu_attention(
671    q: &[f32],
672    k: &[f32],
673    v: &[f32],
674    out: &mut [f32],
675    batch: usize,
676    q_len: usize,
677    kv_len: usize,
678    causal: bool,
679    pos_offset: usize,
680    cfg: &AttnConfig,
681) {
682    let nh = cfg.num_heads;
683    let nkv = cfg.num_kv_heads;
684    let d = cfg.head_dim;
685    let n_rep = nh / nkv;
686    let scale = cfg.scale;
687    // Per-head KV stride: 0 (the default) means contiguous (legacy
688    // `kv_cache_append` path reallocates each layer). A non-zero value means
689    // the cache is pre-allocated to `kv_seq_stride` rows per head but only
690    // the first `kv_len` are valid — we skip the rest via `attend_len`.
691    let kv_stride = if cfg.kv_seq_stride > 0 {
692        cfg.kv_seq_stride
693    } else {
694        kv_len
695    };
696
697    for b in 0..batch {
698        for h in 0..nh {
699            let kv_h = h / n_rep;
700            let q_off = (b * nh + h) * q_len * d;
701            let k_off = (b * nkv + kv_h) * kv_stride * d;
702            let v_off = (b * nkv + kv_h) * kv_stride * d;
703            let o_off = (b * nh + h) * q_len * d;
704
705            for qi in 0..q_len {
706                let attend_end = if causal {
707                    (pos_offset + qi + 1).min(kv_len)
708                } else {
709                    kv_len
710                };
711                let attend_start = if causal && cfg.sliding_window > 0 {
712                    attend_end.saturating_sub(cfg.sliding_window)
713                } else {
714                    0
715                };
716                let mut max_score = f32::NEG_INFINITY;
717                let mut sum_exp = 0.0f32;
718                let mut acc = vec![0.0f32; d];
719
720                for ki in attend_start..attend_end {
721                    let mut dot = 0.0f32;
722                    for di in 0..d {
723                        dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
724                    }
725                    let score = dot * scale;
726                    if score > max_score {
727                        let correction = (max_score - score).exp();
728                        for di in 0..d {
729                            acc[di] *= correction;
730                        }
731                        sum_exp *= correction;
732                        max_score = score;
733                    }
734                    let w = (score - max_score).exp();
735                    sum_exp += w;
736                    for di in 0..d {
737                        acc[di] += w * v[v_off + ki * d + di];
738                    }
739                }
740
741                if sum_exp > 0.0 {
742                    let inv = 1.0 / sum_exp;
743                    for di in 0..d {
744                        out[o_off + qi * d + di] = acc[di] * inv;
745                    }
746                }
747            }
748        }
749    }
750}
751
/// Minimal error-function approximation (Abramowitz & Stegun 7.1.26).
///
/// The formula's stated bound is about 1.5e-7 in exact arithmetic. It is
/// evaluated here entirely in f32, so the realised error also includes f32
/// rounding of similar size — expect agreement with a true `erf` to roughly
/// 1e-6, not to the last bit.
///
/// The approximation is defined for `x >= 0`; negative inputs use the odd
/// symmetry `erf(-x) = -erf(x)`. NaN input yields NaN.
fn libm_erf(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Degree-5 polynomial in t (Horner form) times exp(-x²).
    let y = 1.0
        - (((((1.061_405_4 * t - 1.453_152_1) * t) + 1.421_413_8) * t - 0.284_496_72) * t
            + 0.254_829_6)
            * t
            * (-x * x).exp();
    sign * y
}
764}