Skip to main content

mlx_native/
tq_oracle.rs

1//! ADR-007 Path C F-0.1: CPU F32 oracle for `flash_attn_vec_tq_hb` decode.
2//!
3//! Mirrors the math in `src/shaders/flash_attn_vec_tq_hb.metal::flash_attn_vec_tq_hb_impl`
4//! exactly, in pure F32, with deterministic serial reduction. Used by the
5//! Path C F-0.2 layer-by-layer divergence audit and by codec roundtrip
6//! correctness gates.
7//!
8//! Inputs match the kernel call signature 1:1 (see
9//! `crate::ops::flash_attn_vec_tq_hb::FlashAttnVecTqHbParams` and
10//! `flash_attn_vec_tq_hb`). `Q` is consumed as F32 (the kernel converts to
11//! `half` in shared memory; the oracle does NOT mirror that precision loss
12//! since it serves as the ground truth — the F16-Q precision gap is one of
13//! the divergence sources F-0.2 measures).
14//!
15//! No GPU dispatch, no Metal, no panics. Pure-Rust, single-threaded, branchless
16//! at the inner loop where possible. Bit-for-bit deterministic across runs at
17//! fixed inputs.
18//!
19//! ## Buffer layouts (must match the kernel exactly)
20//!
21//! - `q`:        `[num_heads, head_dim]` F32 (one query per head, decode).
22//! - `k_packed`: `[num_kv_heads, kv_capacity, head_dim]` U8 (1 byte per element).
23//! - `k_norms`:
24//!     - D=256: `[num_kv_heads, kv_capacity]` F32 (1 norm per position).
25//!     - D=512: `[num_kv_heads, kv_capacity, 2]` F32 (2 per-block norms per
26//!       position, block 0 = coords 0..255, block 1 = coords 256..511).
27//! - `v_packed`, `v_norms`: same as K.
28//! - `output`:   `[num_heads, head_dim]` F32 written by the oracle.
29//!
30//! ## Codec dequant formula (mirrors kernel lines 170-202)
31//!
32//! For position `kv_pos`, head `kv_head`, coordinate `d`:
33//! - D=256: `value = codebook[byte_idx] * (norm * inv_sqrt(DK))`
34//! - D=512: `value = codebook[byte_idx] * (norm[block_idx] / scale_factor_d512)`
35//!     where `block_idx = d / 256` (0 for coords 0..255, 1 for 256..511).
36//!
37//! ## Mask semantics (mirrors kernel lines 295-318)
38//!
39//! - `mask_type == 2 && sliding_window > 0 && kv_seq_len > sliding_window`:
40//!   set `window_start_logical = kv_seq_len - sliding_window`. Otherwise 0.
41//! - For each k_pos in `0..kv_seq_len`:
42//!     - `logical_idx = (k_pos - ring_start + kv_capacity) % kv_capacity`
43//!     - Position is valid iff `logical_idx >= window_start_logical &&
44//!       logical_idx < kv_seq_len`. Otherwise masked (-65504.0).
45
46use crate::error::{MlxError, Result};
47use crate::turboquant::hb_centroid;
48
/// Parameters for the HB TQ flash attention decode oracle.
///
/// Field-for-field mirror of `crate::ops::flash_attn_vec_tq_hb::FlashAttnVecTqHbParams`
/// — kept independent so the oracle has zero dependency on Metal types.
///
/// Every field is checked by `validate` before the oracle touches a buffer;
/// the constraints enforced there are noted per field below.
#[derive(Debug, Clone, Copy)]
pub struct TqHbOracleParams {
    /// Number of query heads. Must be > 0 and a multiple of `num_kv_heads`.
    pub num_heads: u32,
    /// Number of K/V heads. Must be > 0. Query head `h` reads K/V head
    /// `h / (num_heads / num_kv_heads)` (grouped-query routing).
    pub num_kv_heads: u32,
    /// Per-head dimension. Only 256 and 512 are accepted.
    pub head_dim: u32,
    /// Number of cache slots visited: physical positions `0..kv_seq_len`.
    /// Must be > 0 and <= `kv_capacity`.
    pub kv_seq_len: u32,
    /// Positions allocated per K/V head in the packed and norm buffers; also
    /// the modulus of the ring-buffer logical-index computation.
    pub kv_capacity: u32,
    /// Multiplier applied to each Q·K dot product before the mask is added.
    pub scale: f32,
    /// Mask selector. The oracle only distinguishes `2` (sliding window,
    /// effective when `sliding_window > 0`) from everything else.
    pub mask_type: u32,
    /// Sliding-window length in positions. Only read when `mask_type == 2`;
    /// if `kv_seq_len > sliding_window`, logical indices below
    /// `kv_seq_len - sliding_window` are masked.
    pub sliding_window: u32,
    /// Note: present in kernel params but never read in the kernel body.
    /// Oracle mirrors the kernel by NOT applying softcap. Tracked as F-0
    /// finding: contractual drift vs `flash_attn_vec.metal` where softcap is
    /// also documented but unimplemented.
    pub softcap: f32,
    /// Ring offset: physical slot `kv_pos` maps to logical index
    /// `(kv_pos - ring_start) mod kv_capacity` for masking purposes.
    pub ring_start: u32,
    /// Only used when `head_dim == 512`, where it divides each per-block norm
    /// during dequantization. For D=256 set to any value.
    pub scale_factor_d512: f32,
    /// Codebook bit-width: 5, 6, or 8.
    pub codebook_bits: u32,
}
74
75fn validate(params: &TqHbOracleParams, q_len: usize, k_packed_len: usize, k_norms_len: usize, v_packed_len: usize, v_norms_len: usize, output_len: usize) -> Result<()> {
76    if params.head_dim != 256 && params.head_dim != 512 {
77        return Err(MlxError::InvalidArgument(format!(
78            "tq_oracle: head_dim must be 256 or 512, got {}",
79            params.head_dim
80        )));
81    }
82    if params.num_heads == 0 || params.num_kv_heads == 0 {
83        return Err(MlxError::InvalidArgument(
84            "tq_oracle: num_heads and num_kv_heads must be > 0".into(),
85        ));
86    }
87    if params.num_heads % params.num_kv_heads != 0 {
88        return Err(MlxError::InvalidArgument(format!(
89            "tq_oracle: num_heads ({}) % num_kv_heads ({}) != 0",
90            params.num_heads, params.num_kv_heads
91        )));
92    }
93    if params.kv_seq_len == 0 {
94        return Err(MlxError::InvalidArgument(
95            "tq_oracle: kv_seq_len must be > 0".into(),
96        ));
97    }
98    if params.kv_capacity < params.kv_seq_len {
99        return Err(MlxError::InvalidArgument(format!(
100            "tq_oracle: kv_capacity ({}) < kv_seq_len ({})",
101            params.kv_capacity, params.kv_seq_len
102        )));
103    }
104    if !matches!(params.codebook_bits, 5 | 6 | 8) {
105        return Err(MlxError::InvalidArgument(format!(
106            "tq_oracle: codebook_bits must be 5, 6, or 8, got {}",
107            params.codebook_bits
108        )));
109    }
110    let dk = params.head_dim as usize;
111    let nh = params.num_heads as usize;
112    let nkv = params.num_kv_heads as usize;
113    let cap = params.kv_capacity as usize;
114    let norms_per_pos = if dk == 512 { 2 } else { 1 };
115
116    let need_q = nh * dk;
117    let need_packed = nkv * cap * dk;
118    let need_norms = nkv * cap * norms_per_pos;
119    let need_output = nh * dk;
120
121    if q_len < need_q {
122        return Err(MlxError::InvalidArgument(format!(
123            "tq_oracle: q has {q_len} < {need_q} required"
124        )));
125    }
126    if k_packed_len < need_packed {
127        return Err(MlxError::InvalidArgument(format!(
128            "tq_oracle: k_packed has {k_packed_len} < {need_packed} required"
129        )));
130    }
131    if v_packed_len < need_packed {
132        return Err(MlxError::InvalidArgument(format!(
133            "tq_oracle: v_packed has {v_packed_len} < {need_packed} required"
134        )));
135    }
136    if k_norms_len < need_norms {
137        return Err(MlxError::InvalidArgument(format!(
138            "tq_oracle: k_norms has {k_norms_len} < {need_norms} required"
139        )));
140    }
141    if v_norms_len < need_norms {
142        return Err(MlxError::InvalidArgument(format!(
143            "tq_oracle: v_norms has {v_norms_len} < {need_norms} required"
144        )));
145    }
146    if output_len < need_output {
147        return Err(MlxError::InvalidArgument(format!(
148            "tq_oracle: output has {output_len} < {need_output} required"
149        )));
150    }
151    Ok(())
152}
153
/// CPU F32 oracle for `flash_attn_vec_tq_hb` decode.
///
/// Computes `output = softmax(Q @ K^T * scale + mask) @ V` where K and V are
/// dequantized from the byte-packed HB codec on-the-fly per the kernel formula.
///
/// Caller is responsible for applying FWHT to `q` BEFORE calling this oracle
/// (mirrors kernel contract, see `flash_attn_vec_tq_hb.rs:128-130`). Inverse
/// FWHT of `output` is also the caller's responsibility.
///
/// Only `output[0..num_heads * head_dim]` is written. Every element in that
/// range is zeroed before accumulation, so its prior contents do not matter.
/// A head whose positions are all masked gets an all-zero output row.
/// `params.softcap` is never applied.
///
/// Determinism: bit-identical across runs at fixed inputs (serial reduction,
/// no parallelism). The summation order in every loop below is part of that
/// contract; reordering or vectorizing a reduction changes the F32 result.
/// Non-finite input values, or (for D=512) a zero/non-finite
/// `scale_factor_d512` — which is used as a divisor — can propagate inf/NaN
/// into the output.
///
/// # Errors
///
/// Returns an error, before writing anything to `output`, if `params` or any
/// buffer length fails `validate`.
pub fn flash_attn_vec_tq_hb_oracle(
    q: &[f32],
    k_packed: &[u8],
    k_norms: &[f32],
    v_packed: &[u8],
    v_norms: &[f32],
    output: &mut [f32],
    params: &TqHbOracleParams,
) -> Result<()> {
    validate(
        params,
        q.len(),
        k_packed.len(),
        k_norms.len(),
        v_packed.len(),
        v_norms.len(),
        output.len(),
    )?;

    let dk = params.head_dim as usize;
    let nh = params.num_heads as usize;
    let nkv = params.num_kv_heads as usize;
    let kv_seq_len = params.kv_seq_len as usize;
    let kv_capacity = params.kv_capacity as usize;
    let ring_start = params.ring_start as usize;
    let cbits = params.codebook_bits;
    // Grouped-query routing: `heads_per_kv` consecutive query heads share one
    // K/V head. Exact because validate() enforces num_heads % num_kv_heads == 0.
    let heads_per_kv = nh / nkv;

    // Mirror the kernel's window_start_logical computation (lines 295-298).
    let window_start_logical: usize = if params.mask_type == 2
        && params.sliding_window > 0
        && (kv_seq_len as u32) > params.sliding_window
    {
        kv_seq_len - params.sliding_window as usize
    } else {
        0
    };

    let is_d512 = dk == 512;
    let inv_sqrt_dk: f32 = 1.0_f32 / (dk as f32).sqrt();
    // For D=256, V dequant scales by inv_sqrt_dv where DV=DK in our code path.
    // Mirror line 418 of the kernel: `const float inv_sqrt_dv = rsqrt(float(DV));`
    let inv_sqrt_dv: f32 = inv_sqrt_dk; // DV == DK in flash_attn_vec_tq_hb
    let sf_d512: f32 = params.scale_factor_d512;

    // Pre-compute mask: `mask[kv_pos] = 0.0 if valid, -65504.0 if invalid`.
    // Same predicate as the kernel (lines 308-318). -65504.0 is the most
    // negative finite f16 value and serves as the "minus infinity" sentinel.
    let neg_inf_proxy: f32 = -65504.0_f32;
    let mut mask_vec: Vec<f32> = vec![0.0_f32; kv_seq_len];
    for kv_pos in 0..kv_seq_len {
        // logical_idx = (kv_pos - ring_start + kv_capacity) % kv_capacity
        // Computed in i64 with rem_euclid so a negative difference
        // (kv_pos < ring_start) wraps into 0..kv_capacity instead of
        // underflowing usize.
        let logical_idx = ((kv_pos as i64 - ring_start as i64).rem_euclid(kv_capacity as i64))
            as usize;
        let valid = logical_idx >= window_start_logical && logical_idx < kv_seq_len;
        mask_vec[kv_pos] = if valid { 0.0_f32 } else { neg_inf_proxy };
    }

    // Per-head SDPA loop. Order: q_head h → kv_head h/heads_per_kv → kv_pos →
    // accumulate scores[kv_pos] → online softmax → output.
    for h in 0..nh {
        let kv_head = h / heads_per_kv;
        let q_offset = h * dk;
        let q_row: &[f32] = &q[q_offset..q_offset + dk];

        // Pass 1: Q @ K^T * scale + mask = scores[kv_pos]
        let mut scores: Vec<f32> = vec![neg_inf_proxy; kv_seq_len];
        for kv_pos in 0..kv_seq_len {
            if mask_vec[kv_pos] <= neg_inf_proxy {
                // Already -65504; leave it.
                continue;
            }
            let k_packed_offset = (kv_head * kv_capacity + kv_pos) * dk;
            let k_packed_row: &[u8] = &k_packed[k_packed_offset..k_packed_offset + dk];

            let mut dot: f32 = 0.0_f32;
            if is_d512 {
                // D=512: per-block norms.
                let knorm_offset = (kv_head * kv_capacity + kv_pos) * 2;
                let n0 = k_norms[knorm_offset];
                let n1 = k_norms[knorm_offset + 1];
                let sn0 = n0 / sf_d512;
                let sn1 = n1 / sf_d512;
                // Block 0: coords 0..256
                for d in 0..256 {
                    let centroid = hb_centroid(k_packed_row[d], cbits);
                    dot += q_row[d] * centroid * sn0;
                }
                // Block 1: coords 256..512
                for d in 256..dk {
                    let centroid = hb_centroid(k_packed_row[d], cbits);
                    dot += q_row[d] * centroid * sn1;
                }
            } else {
                // D=256: single norm per position.
                let n = k_norms[kv_head * kv_capacity + kv_pos];
                let sn = n * inv_sqrt_dk;
                for d in 0..dk {
                    let centroid = hb_centroid(k_packed_row[d], cbits);
                    dot += q_row[d] * centroid * sn;
                }
            }
            scores[kv_pos] = dot * params.scale + mask_vec[kv_pos];
        }

        // Pass 2: stable softmax (max-subtraction).
        // Mirror the kernel's online softmax outcome but use the equivalent
        // batch form for clarity. Both are deterministic in F32 for our serial
        // reduction.
        let mut m: f32 = f32::NEG_INFINITY;
        for &s in scores.iter() {
            if s > m {
                m = s;
            }
        }
        // If every position is masked (m == neg_inf_proxy or worse), the kernel
        // writes 0.0 to output (inv_S = 0 path, line 475).
        // Note: an unmasked position whose score is itself <= -65504 is
        // indistinguishable from a masked one by this test.
        let all_masked = m <= neg_inf_proxy;

        let mut sum: f32 = 0.0_f32;
        let mut weights: Vec<f32> = vec![0.0_f32; kv_seq_len];
        if !all_masked {
            for (i, &s) in scores.iter().enumerate() {
                let w = (s - m).exp();
                weights[i] = w;
                sum += w;
            }
        }
        let inv_sum: f32 = if sum > 0.0_f32 { 1.0_f32 / sum } else { 0.0_f32 };

        // Pass 3: accumulate weighted V into output[h, :].
        let out_offset = h * dk;
        for d in 0..dk {
            output[out_offset + d] = 0.0_f32;
        }

        if !all_masked {
            for kv_pos in 0..kv_seq_len {
                let w = weights[kv_pos];
                // Zero weight (masked, or exp() underflowed to 0): adding
                // 0 * finite would not change the accumulator, so skip the
                // dequant work.
                if w == 0.0_f32 {
                    continue;
                }
                let v_packed_offset = (kv_head * kv_capacity + kv_pos) * dk;
                let v_packed_row: &[u8] = &v_packed[v_packed_offset..v_packed_offset + dk];

                if is_d512 {
                    let vnorm_offset = (kv_head * kv_capacity + kv_pos) * 2;
                    let vn0 = v_norms[vnorm_offset];
                    let vn1 = v_norms[vnorm_offset + 1];
                    let sn0 = vn0 / sf_d512;
                    let sn1 = vn1 / sf_d512;
                    for d in 0..256 {
                        let centroid = hb_centroid(v_packed_row[d], cbits);
                        output[out_offset + d] += centroid * sn0 * w;
                    }
                    for d in 256..dk {
                        let centroid = hb_centroid(v_packed_row[d], cbits);
                        output[out_offset + d] += centroid * sn1 * w;
                    }
                } else {
                    let vn = v_norms[kv_head * kv_capacity + kv_pos];
                    let sn = vn * inv_sqrt_dv;
                    for d in 0..dk {
                        let centroid = hb_centroid(v_packed_row[d], cbits);
                        output[out_offset + d] += centroid * sn * w;
                    }
                }
            }

            // Final divide by total softmax denominator (NWG=1 inv_S path).
            for d in 0..dk {
                output[out_offset + d] *= inv_sum;
            }
        }
    }

    Ok(())
}
343
344#[cfg(test)]
345#[allow(clippy::expect_used, clippy::unwrap_used)]
346mod tests {
347    use super::*;
348    use crate::turboquant::{
349        hb_nearest_centroid, CODEBOOK_HB_5BIT, CODEBOOK_HB_6BIT, CODEBOOK_HB_8BIT,
350    };
351
352    /// Helper: encode a single F32 row (head_dim=256) via the CPU encode path.
353    /// FWHT + L2 norm + per-element nearest-centroid lookup. Mirrors what the
354    /// `hadamard_quantize_kv_hb_d256` Metal kernel produces.
355    fn encode_row_d256(x: &[f32], bits: u32) -> (Vec<u8>, f32) {
356        let mut rotated = x.to_vec();
357        crate::turboquant::fwht_inplace(&mut rotated).expect("fwht");
358        let norm_sq: f32 = rotated.iter().map(|&v| v * v).sum();
359        let norm = norm_sq.sqrt();
360        if norm < 1e-30 {
361            return (vec![0u8; x.len()], 0.0);
362        }
363        let scale = (x.len() as f32).sqrt() / norm;
364        let mut packed = Vec::with_capacity(x.len());
365        for &v in rotated.iter() {
366            let scaled = v * scale;
367            packed.push(hb_nearest_centroid(scaled, bits));
368        }
369        (packed, norm)
370    }
371
372    fn deterministic_gaussian(seed: u64, n: usize) -> Vec<f32> {
373        // Box-Muller from a seeded LCG. Deterministic, no dependencies.
374        let mut state = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
375        let next_u32 = |s: &mut u64| -> u32 {
376            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
377            (*s >> 32) as u32
378        };
379        let next_f32 = |s: &mut u64| -> f32 {
380            let bits = next_u32(s);
381            // Open-interval (0, 1).
382            ((bits as f64 + 0.5) / (u32::MAX as f64 + 1.0)) as f32
383        };
384        let mut out = Vec::with_capacity(n);
385        while out.len() < n {
386            let u1 = next_f32(&mut state).max(1e-7).min(1.0 - 1e-7);
387            let u2 = next_f32(&mut state);
388            let r = (-2.0_f32 * u1.ln()).sqrt();
389            let theta = 2.0_f32 * std::f32::consts::PI * u2;
390            out.push(r * theta.cos());
391            if out.len() < n {
392                out.push(r * theta.sin());
393            }
394        }
395        out
396    }
397
398    #[test]
399    fn codebooks_match_metal_shader_constants() {
400        // Cross-check load-bearing values from the metal shader.
401        // 5-bit endpoints (line 53 / 60 of flash_attn_vec_tq_hb.metal).
402        assert!((CODEBOOK_HB_5BIT[0] - (-3.2606790)).abs() < 1e-6);
403        assert!((CODEBOOK_HB_5BIT[31] - 3.2606790).abs() < 1e-6);
404        // 6-bit endpoints.
405        assert!((CODEBOOK_HB_6BIT[0] - (-3.6996161)).abs() < 1e-6);
406        assert!((CODEBOOK_HB_6BIT[63] - 3.6996161).abs() < 1e-6);
407        // 8-bit endpoints (line 92 / 155).
408        assert!((CODEBOOK_HB_8BIT[0] - (-5.0652659)).abs() < 1e-6);
409        assert!((CODEBOOK_HB_8BIT[255] - 5.0652659).abs() < 1e-6);
410        // 8-bit symmetry (declared 3.41e-10 in the shader comment, line 88).
411        for i in 0..128 {
412            let sum = CODEBOOK_HB_8BIT[i] + CODEBOOK_HB_8BIT[255 - i];
413            assert!(sum.abs() < 1e-5, "8-bit asymmetry at i={i}: {sum}");
414        }
415    }
416
417    #[test]
418    fn hb_centroid_lookup_matches_index() {
419        // Spot-check a few indices vs the codebook arrays.
420        for &idx in &[0u8, 1u8, 16u8, 31u8] {
421            let v = hb_centroid(idx, 5);
422            assert!((v - CODEBOOK_HB_5BIT[(idx & 0x1F) as usize]).abs() < 1e-7);
423        }
424        for &idx in &[0u8, 1u8, 32u8, 63u8] {
425            let v = hb_centroid(idx, 6);
426            assert!((v - CODEBOOK_HB_6BIT[(idx & 0x3F) as usize]).abs() < 1e-7);
427        }
428        for idx in 0u8..=255u8 {
429            let v = hb_centroid(idx, 8);
430            assert!((v - CODEBOOK_HB_8BIT[idx as usize]).abs() < 1e-7);
431        }
432    }
433
434    #[test]
435    fn hb_centroid_unsupported_bits_returns_zero() {
436        // No-panic guarantee for invalid bits.
437        assert_eq!(hb_centroid(0, 4), 0.0);
438        assert_eq!(hb_centroid(255, 7), 0.0);
439        assert_eq!(hb_nearest_centroid(0.0, 4), 0u8);
440    }
441
442    #[test]
443    fn nearest_centroid_finds_closest() {
444        // Index 128 is the centroid closest to zero on the 8-bit codebook
445        // (positive side of the symmetric pair).
446        // CODEBOOK_HB_8BIT[127] = -0.0135717, [128] = +0.0135717.
447        // For value 0.005, the nearest is index 128 (positive side, dist 0.008572).
448        // Wait: dist to [127] = abs(0.005 - (-0.0135717)) = 0.0185717
449        //       dist to [128] = abs(0.005 - 0.0135717)   = 0.0085717
450        // So nearest is 128.
451        assert_eq!(hb_nearest_centroid(0.005, 8), 128);
452        // Value 5.5 saturates to the high endpoint (index 255, value 5.0652659).
453        assert_eq!(hb_nearest_centroid(5.5, 8), 255);
454        // Value -5.5 saturates to the low endpoint (index 0).
455        assert_eq!(hb_nearest_centroid(-5.5, 8), 0);
456    }
457
458    /// Sanity: oracle on a single-position cache with a known unit vector
459    /// equals the manually computed attention.
460    #[test]
461    fn oracle_single_position_uniform_v_matches_manual() {
462        let head_dim = 256u32;
463        let num_heads = 1u32;
464        let num_kv_heads = 1u32;
465        let kv_capacity = 4u32;
466        let kv_seq_len = 1u32;
467        let bits = 8u32;
468
469        // Encode a known K row (random gaussian, deterministic).
470        let k_row = deterministic_gaussian(0xC25EED, head_dim as usize);
471        let v_row = deterministic_gaussian(0xC25EED ^ 0xDEADBEEF, head_dim as usize);
472
473        let (k_packed_row, k_norm) = encode_row_d256(&k_row, bits);
474        let (v_packed_row, v_norm) = encode_row_d256(&v_row, bits);
475
476        // Build the cache buffers (positions 1..3 zeroed).
477        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
478        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
479        let mut v_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
480        let mut v_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
481
482        for d in 0..head_dim as usize {
483            k_packed[d] = k_packed_row[d];
484            v_packed[d] = v_packed_row[d];
485        }
486        k_norms[0] = k_norm;
487        v_norms[0] = v_norm;
488
489        // Q is a chosen unit vector (post-FWHT — caller's responsibility).
490        let mut q = vec![0.0_f32; (num_heads * head_dim) as usize];
491        for d in 0..head_dim as usize {
492            q[d] = 1.0_f32 / (head_dim as f32).sqrt();
493        }
494
495        let params = TqHbOracleParams {
496            num_heads,
497            num_kv_heads,
498            head_dim,
499            kv_seq_len,
500            kv_capacity,
501            scale: 1.0_f32 / (head_dim as f32).sqrt(),
502            mask_type: 0,
503            sliding_window: 0,
504            softcap: 0.0,
505            ring_start: 0,
506            scale_factor_d512: 1.0,
507            codebook_bits: bits,
508        };
509
510        let mut output = vec![0.0_f32; (num_heads * head_dim) as usize];
511        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut output, &params).expect("oracle ok");
512
513        // With a single valid kv_pos, softmax weight = 1.0, so output = V_dequant_row.
514        // Compute expected V_dequant manually: centroid * v_norm * inv_sqrt(DK).
515        let inv_sqrt_dk = 1.0_f32 / (head_dim as f32).sqrt();
516        for d in 0..head_dim as usize {
517            let expected = hb_centroid(v_packed_row[d], bits) * v_norm * inv_sqrt_dk;
518            let actual = output[d];
519            let diff = (actual - expected).abs();
520            assert!(
521                diff < 1e-5,
522                "oracle output mismatch at d={d}: expected={expected}, actual={actual}, diff={diff}"
523            );
524        }
525    }
526
527    /// Determinism: same inputs produce bit-identical outputs across runs.
528    #[test]
529    fn oracle_is_bit_deterministic() {
530        let head_dim = 256u32;
531        let num_heads = 4u32;
532        let num_kv_heads = 2u32;
533        let kv_capacity = 16u32;
534        let kv_seq_len = 8u32;
535
536        let k_packed: Vec<u8> = (0..(num_kv_heads * kv_capacity * head_dim))
537            .map(|i| (i.wrapping_mul(31) ^ 0xA5) as u8)
538            .collect();
539        let v_packed: Vec<u8> = (0..(num_kv_heads * kv_capacity * head_dim))
540            .map(|i| (i.wrapping_mul(37) ^ 0x5A) as u8)
541            .collect();
542        let k_norms: Vec<f32> = (0..(num_kv_heads * kv_capacity))
543            .map(|i| 1.0 + (i as f32) * 0.01)
544            .collect();
545        let v_norms: Vec<f32> = (0..(num_kv_heads * kv_capacity))
546            .map(|i| 1.0 + (i as f32) * 0.02)
547            .collect();
548        let q: Vec<f32> = deterministic_gaussian(0xBEEF, (num_heads * head_dim) as usize);
549
550        let params = TqHbOracleParams {
551            num_heads,
552            num_kv_heads,
553            head_dim,
554            kv_seq_len,
555            kv_capacity,
556            scale: 0.0625,
557            mask_type: 0,
558            sliding_window: 0,
559            softcap: 0.0,
560            ring_start: 0,
561            scale_factor_d512: 1.0,
562            codebook_bits: 8,
563        };
564
565        let mut out_a = vec![0.0_f32; (num_heads * head_dim) as usize];
566        let mut out_b = vec![0.0_f32; (num_heads * head_dim) as usize];
567        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_a, &params).expect("a");
568        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_b, &params).expect("b");
569
570        for i in 0..out_a.len() {
571            assert_eq!(out_a[i].to_bits(), out_b[i].to_bits(),
572                "non-deterministic at i={i}: a={}, b={}", out_a[i], out_b[i]);
573        }
574    }
575
576    /// Sliding-window mask: with mask_type=2 and a small window, only the
577    /// most-recent `sliding_window` positions contribute.
578    #[test]
579    fn oracle_sliding_window_masks_old_positions() {
580        let head_dim = 256u32;
581        let num_heads = 1u32;
582        let num_kv_heads = 1u32;
583        let kv_capacity = 32u32;
584        let kv_seq_len = 16u32;
585        let sliding_window = 4u32;
586        let bits = 8u32;
587
588        // Build cache where positions 0..15 each store a distinguishable V
589        // (we'll set v_norm to identify which position dominated).
590        let k_row = deterministic_gaussian(0xCAFE, head_dim as usize);
591        let v_row = deterministic_gaussian(0xBABE, head_dim as usize);
592        let (k_packed_row, k_norm) = encode_row_d256(&k_row, bits);
593        let (v_packed_row, v_norm) = encode_row_d256(&v_row, bits);
594
595        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
596        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
597        let mut v_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
598        let mut v_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
599        for kv_pos in 0..kv_seq_len as usize {
600            let off = kv_pos * head_dim as usize;
601            for d in 0..head_dim as usize {
602                k_packed[off + d] = k_packed_row[d];
603                v_packed[off + d] = v_packed_row[d];
604            }
605            // Make v_norm per-position so the output reveals contributions.
606            v_norms[kv_pos] = v_norm * (1.0 + kv_pos as f32);
607            k_norms[kv_pos] = k_norm;
608        }
609
610        let mut q = vec![1.0_f32 / (head_dim as f32).sqrt(); (num_heads * head_dim) as usize];
611        // Re-FWHT q so it correlates with the encoded K (caller responsibility).
612        crate::turboquant::fwht_inplace(&mut q[..head_dim as usize]).expect("fwht");
613
614        let params = TqHbOracleParams {
615            num_heads,
616            num_kv_heads,
617            head_dim,
618            kv_seq_len,
619            kv_capacity,
620            scale: 1.0_f32 / (head_dim as f32).sqrt(),
621            mask_type: 2,
622            sliding_window,
623            softcap: 0.0,
624            ring_start: 0,
625            scale_factor_d512: 1.0,
626            codebook_bits: bits,
627        };
628
629        let mut out_windowed = vec![0.0_f32; (num_heads * head_dim) as usize];
630        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_windowed, &params).expect("ok");
631
632        // Now disable masking and confirm output differs (sanity that masking
633        // was actually applied).
634        let params_no_mask = TqHbOracleParams { mask_type: 0, sliding_window: 0, ..params };
635        let mut out_full = vec![0.0_f32; (num_heads * head_dim) as usize];
636        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_full, &params_no_mask).expect("ok");
637
638        // The two outputs must differ when sliding_window < kv_seq_len.
639        let mut max_diff = 0.0_f32;
640        for i in 0..out_windowed.len() {
641            max_diff = max_diff.max((out_windowed[i] - out_full[i]).abs());
642        }
643        assert!(max_diff > 1e-3, "sliding window had no effect: max_diff={max_diff}");
644    }
645
646    /// All-masked: when every position is masked out, output should be all zeros.
647    /// Mirrors the kernel's `inv_S = 0` path (line 475).
648    #[test]
649    fn oracle_all_masked_returns_zeros() {
650        let head_dim = 256u32;
651        let num_heads = 1u32;
652        let num_kv_heads = 1u32;
653        let kv_capacity = 4u32;
654        let kv_seq_len = 1u32;
655
656        let k_packed = vec![128u8; (num_kv_heads * kv_capacity * head_dim) as usize];
657        let v_packed = vec![128u8; (num_kv_heads * kv_capacity * head_dim) as usize];
658        let k_norms = vec![1.0f32; (num_kv_heads * kv_capacity) as usize];
659        let v_norms = vec![1.0f32; (num_kv_heads * kv_capacity) as usize];
660        let q = vec![0.5_f32; (num_heads * head_dim) as usize];
661
662        // ring_start outside the valid range so logical_idx is always >= kv_seq_len → masked.
663        let params = TqHbOracleParams {
664            num_heads,
665            num_kv_heads,
666            head_dim,
667            kv_seq_len,
668            kv_capacity,
669            scale: 1.0,
670            mask_type: 2,
671            sliding_window: kv_seq_len, // non-zero so window predicate engages
672            softcap: 0.0,
673            // logical_idx = (0 - 2 + 4) % 4 = 2, which >= kv_seq_len (1) → masked.
674            ring_start: 2,
675            scale_factor_d512: 1.0,
676            codebook_bits: 8,
677        };
678
679        let mut output = vec![1.0_f32; (num_heads * head_dim) as usize]; // pre-fill to detect zeroing
680        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut output, &params).expect("ok");
681        for &v in output.iter() {
682            assert_eq!(v.to_bits(), 0u32, "expected 0.0 in all-masked output, got {v}");
683        }
684    }
685
686    /// D=512 path: per-block norms. Sanity check that the oracle splits coords
687    /// 0..256 / 256..512 with separate norm scales.
688    #[test]
689    fn oracle_d512_per_block_norms() {
690        let head_dim = 512u32;
691        let num_heads = 1u32;
692        let num_kv_heads = 1u32;
693        let kv_capacity = 4u32;
694        let kv_seq_len = 1u32;
695        let bits = 8u32;
696        let sf_d512: f32 = 16.0; // sqrt(256), matches AmesianX convention; orthogonal to oracle math.
697
698        // Encode a 512-vec as two 256-blocks.
699        let k_row = deterministic_gaussian(0x01234567, head_dim as usize);
700        let mut k_b0 = k_row[0..256].to_vec();
701        let mut k_b1 = k_row[256..512].to_vec();
702        crate::turboquant::fwht_inplace(&mut k_b0).expect("fwht");
703        crate::turboquant::fwht_inplace(&mut k_b1).expect("fwht");
704        let n0 = k_b0.iter().map(|&v| v * v).sum::<f32>().sqrt();
705        let n1 = k_b1.iter().map(|&v| v * v).sum::<f32>().sqrt();
706        // Encode each block as if it's an independent unit vector at sf_d512=16
707        // ⇒ scaled coord = rotated * 16 / norm.
708        let mut k_packed_row = vec![0u8; head_dim as usize];
709        for d in 0..256 {
710            let s = k_b0[d] * sf_d512 / n0;
711            k_packed_row[d] = hb_nearest_centroid(s, bits);
712        }
713        for d in 0..256 {
714            let s = k_b1[d] * sf_d512 / n1;
715            k_packed_row[256 + d] = hb_nearest_centroid(s, bits);
716        }
717
718        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
719        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity * 2) as usize];
720        // Position 0 with our encoded row + per-block norms.
721        for d in 0..head_dim as usize {
722            k_packed[d] = k_packed_row[d];
723        }
724        k_norms[0] = n0;
725        k_norms[1] = n1;
726
727        let v_packed = k_packed.clone();
728        let v_norms = k_norms.clone();
729        let q = vec![1.0_f32 / (head_dim as f32).sqrt(); (num_heads * head_dim) as usize];
730
731        let params = TqHbOracleParams {
732            num_heads,
733            num_kv_heads,
734            head_dim,
735            kv_seq_len,
736            kv_capacity,
737            scale: 1.0 / (head_dim as f32).sqrt(),
738            mask_type: 0,
739            sliding_window: 0,
740            softcap: 0.0,
741            ring_start: 0,
742            scale_factor_d512: sf_d512,
743            codebook_bits: bits,
744        };
745
746        let mut out = vec![0.0f32; (num_heads * head_dim) as usize];
747        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out, &params).expect("ok");
748
749        // Single position → output[d] = V_dequant_row[d].
750        // Block 0: centroid * (n0 / sf_d512). Block 1: centroid * (n1 / sf_d512).
751        for d in 0..256 {
752            let expected = hb_centroid(k_packed_row[d], bits) * (n0 / sf_d512);
753            assert!((out[d] - expected).abs() < 1e-5,
754                "d512 block0 mismatch d={d}: expected={expected}, actual={}", out[d]);
755        }
756        for d in 256..head_dim as usize {
757            let expected = hb_centroid(k_packed_row[d], bits) * (n1 / sf_d512);
758            assert!((out[d] - expected).abs() < 1e-5,
759                "d512 block1 mismatch d={d}: expected={expected}, actual={}", out[d]);
760        }
761    }
762
763    /// GQA: with num_heads=8, num_kv_heads=2, heads_per_kv=4. Heads 0..3 share
764    /// kv_head=0, heads 4..7 share kv_head=1.
765    #[test]
766    fn oracle_gqa_routes_heads_to_correct_kv_head() {
767        let head_dim = 256u32;
768        let num_heads = 8u32;
769        let num_kv_heads = 2u32;
770        let kv_capacity = 4u32;
771        let kv_seq_len = 1u32;
772        let bits = 8u32;
773
774        // Two distinguishable K/V rows: one per kv_head. We'll set v_norm so
775        // the V row magnitude reveals which kv_head was consulted.
776        let k_row = deterministic_gaussian(0x111, head_dim as usize);
777        let v_row = deterministic_gaussian(0x222, head_dim as usize);
778        let (k_packed_row, k_norm) = encode_row_d256(&k_row, bits);
779        let (v_packed_row, v_norm) = encode_row_d256(&v_row, bits);
780
781        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
782        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
783        let mut v_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
784        let mut v_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
785
786        // kv_head 0 at position 0: v_norm = 1.0 * v_norm
787        for d in 0..head_dim as usize {
788            k_packed[d] = k_packed_row[d];
789            v_packed[d] = v_packed_row[d];
790        }
791        k_norms[0] = k_norm;
792        v_norms[0] = v_norm;
793
794        // kv_head 1 at position 0: v_norm = 10.0 * v_norm (distinguishable scale)
795        let kv1_off = (kv_capacity * head_dim) as usize;
796        for d in 0..head_dim as usize {
797            k_packed[kv1_off + d] = k_packed_row[d];
798            v_packed[kv1_off + d] = v_packed_row[d];
799        }
800        k_norms[(kv_capacity) as usize] = k_norm;
801        v_norms[(kv_capacity) as usize] = 10.0 * v_norm;
802
803        let q = vec![1.0_f32 / (head_dim as f32).sqrt(); (num_heads * head_dim) as usize];
804        let params = TqHbOracleParams {
805            num_heads,
806            num_kv_heads,
807            head_dim,
808            kv_seq_len,
809            kv_capacity,
810            scale: 1.0 / (head_dim as f32).sqrt(),
811            mask_type: 0,
812            sliding_window: 0,
813            softcap: 0.0,
814            ring_start: 0,
815            scale_factor_d512: 1.0,
816            codebook_bits: bits,
817        };
818
819        let mut out = vec![0.0f32; (num_heads * head_dim) as usize];
820        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out, &params).expect("ok");
821
822        // Heads 0..3 use kv_head 0 (v_norm baseline).
823        // Heads 4..7 use kv_head 1 (v_norm 10× baseline).
824        // For the first non-zero output dim, ratio of head4/head0 should be ≈ 10.
825        let inv_sqrt_dk = 1.0_f32 / (head_dim as f32).sqrt();
826        let expected_h0 = hb_centroid(v_packed_row[0], bits) * v_norm * inv_sqrt_dk;
827        let expected_h4 = hb_centroid(v_packed_row[0], bits) * (10.0 * v_norm) * inv_sqrt_dk;
828
829        let h0_d0 = out[(0 * head_dim) as usize];
830        let h4_d0 = out[(4 * head_dim) as usize];
831
832        assert!((h0_d0 - expected_h0).abs() < 1e-4,
833            "h0 mismatch: expected={expected_h0}, actual={h0_d0}");
834        assert!((h4_d0 - expected_h4).abs() < 1e-3,
835            "h4 mismatch: expected={expected_h4}, actual={h4_d0}");
836    }
837}