1use crate::error::{MlxError, Result};
47use crate::turboquant::hb_centroid;
48
/// Scalar parameters for [`flash_attn_vec_tq_hb_oracle`].
///
/// All values are checked by `validate` before any buffer is indexed.
#[derive(Debug, Clone, Copy)]
pub struct TqHbOracleParams {
    /// Number of query heads.
    ///
    /// `q` and `output` each hold `num_heads * head_dim` values. Must be
    /// non-zero and a multiple of `num_kv_heads`.
    pub num_heads: u32,
    /// Number of KV heads. Must be non-zero.
    ///
    /// Query head `h` reads KV head `h / (num_heads / num_kv_heads)`.
    pub num_kv_heads: u32,
    /// Per-head dimension. Only 256 and 512 are accepted.
    pub head_dim: u32,
    /// Number of cache slots scanned (physical slots `0..kv_seq_len`).
    ///
    /// Also the exclusive upper bound on valid logical positions. Must be in
    /// `1..=kv_capacity`.
    pub kv_seq_len: u32,
    /// Slots allocated per KV head.
    ///
    /// Sets the row stride of the packed and norm buffers and the modulus of
    /// the ring-buffer slot -> logical-position mapping.
    pub kv_capacity: u32,
    /// Multiplier applied to each Q.K dot product before the softmax.
    pub scale: f32,
    /// Mask selector.
    ///
    /// Only the value `2` (sliding window) is inspected by the oracle. Any
    /// other value applies just the ring-buffer / `kv_seq_len` validity check.
    pub mask_type: u32,
    /// Window length used when `mask_type == 2`.
    ///
    /// Only the last `sliding_window` logical positions are attended. `0`
    /// disables the window.
    pub sliding_window: u32,
    /// Logit softcap.
    ///
    /// NOTE(review): not read by `flash_attn_vec_tq_hb_oracle` in this file.
    /// Confirm whether the GPU kernel applies it.
    pub softcap: f32,
    /// Physical slot holding logical position 0.
    ///
    /// Slot `s` maps to logical position `(s - ring_start) mod kv_capacity`.
    pub ring_start: u32,
    /// Divisor applied to each stored per-half norm when `head_dim == 512`.
    ///
    /// The `head_dim == 256` path ignores it and scales by `1/sqrt(head_dim)`.
    pub scale_factor_d512: f32,
    /// HB codebook width in bits, forwarded to `hb_centroid`.
    ///
    /// Must be 5, 6, or 8.
    pub codebook_bits: u32,
}
74
75fn validate(params: &TqHbOracleParams, q_len: usize, k_packed_len: usize, k_norms_len: usize, v_packed_len: usize, v_norms_len: usize, output_len: usize) -> Result<()> {
76 if params.head_dim != 256 && params.head_dim != 512 {
77 return Err(MlxError::InvalidArgument(format!(
78 "tq_oracle: head_dim must be 256 or 512, got {}",
79 params.head_dim
80 )));
81 }
82 if params.num_heads == 0 || params.num_kv_heads == 0 {
83 return Err(MlxError::InvalidArgument(
84 "tq_oracle: num_heads and num_kv_heads must be > 0".into(),
85 ));
86 }
87 if params.num_heads % params.num_kv_heads != 0 {
88 return Err(MlxError::InvalidArgument(format!(
89 "tq_oracle: num_heads ({}) % num_kv_heads ({}) != 0",
90 params.num_heads, params.num_kv_heads
91 )));
92 }
93 if params.kv_seq_len == 0 {
94 return Err(MlxError::InvalidArgument(
95 "tq_oracle: kv_seq_len must be > 0".into(),
96 ));
97 }
98 if params.kv_capacity < params.kv_seq_len {
99 return Err(MlxError::InvalidArgument(format!(
100 "tq_oracle: kv_capacity ({}) < kv_seq_len ({})",
101 params.kv_capacity, params.kv_seq_len
102 )));
103 }
104 if !matches!(params.codebook_bits, 5 | 6 | 8) {
105 return Err(MlxError::InvalidArgument(format!(
106 "tq_oracle: codebook_bits must be 5, 6, or 8, got {}",
107 params.codebook_bits
108 )));
109 }
110 let dk = params.head_dim as usize;
111 let nh = params.num_heads as usize;
112 let nkv = params.num_kv_heads as usize;
113 let cap = params.kv_capacity as usize;
114 let norms_per_pos = if dk == 512 { 2 } else { 1 };
115
116 let need_q = nh * dk;
117 let need_packed = nkv * cap * dk;
118 let need_norms = nkv * cap * norms_per_pos;
119 let need_output = nh * dk;
120
121 if q_len < need_q {
122 return Err(MlxError::InvalidArgument(format!(
123 "tq_oracle: q has {q_len} < {need_q} required"
124 )));
125 }
126 if k_packed_len < need_packed {
127 return Err(MlxError::InvalidArgument(format!(
128 "tq_oracle: k_packed has {k_packed_len} < {need_packed} required"
129 )));
130 }
131 if v_packed_len < need_packed {
132 return Err(MlxError::InvalidArgument(format!(
133 "tq_oracle: v_packed has {v_packed_len} < {need_packed} required"
134 )));
135 }
136 if k_norms_len < need_norms {
137 return Err(MlxError::InvalidArgument(format!(
138 "tq_oracle: k_norms has {k_norms_len} < {need_norms} required"
139 )));
140 }
141 if v_norms_len < need_norms {
142 return Err(MlxError::InvalidArgument(format!(
143 "tq_oracle: v_norms has {v_norms_len} < {need_norms} required"
144 )));
145 }
146 if output_len < need_output {
147 return Err(MlxError::InvalidArgument(format!(
148 "tq_oracle: output has {output_len} < {need_output} required"
149 )));
150 }
151 Ok(())
152}
153
/// CPU reference ("oracle") for single-query attention over an HB-quantized
/// KV cache.
///
/// For every query head `h` the function performs four steps:
/// 1. Builds an additive mask over the physical cache slots `0..kv_seq_len`.
///    The mask is ring-buffer aware and adds a sliding window when
///    `mask_type == 2`.
/// 2. Dequantizes each unmasked K row on the fly, using [`hb_centroid`] and
///    the stored norm(s), and scores it as `dot(q_row, k) * scale`.
/// 3. Applies a max-subtracted softmax over the scores.
/// 4. Writes the softmax-weighted sum of the dequantized V rows into
///    `output[h * head_dim .. (h + 1) * head_dim]`.
///
/// Query head `h` reads KV head `h / (num_heads / num_kv_heads)`.
///
/// Buffer layouts, as indexed below:
/// - `q` and `output`: `num_heads * head_dim` values, one `head_dim` row per
///   head.
/// - `k_packed` and `v_packed`: one code byte per element. Row
///   `(kv_head * kv_capacity + slot)` starts at that index times `head_dim`.
/// - `k_norms` and `v_norms`: one norm per slot for `head_dim == 256`. For
///   `head_dim == 512` there are two per slot, one per 256-wide half.
///
/// Dequantized element values:
/// - `head_dim == 256`: `centroid * norm / sqrt(head_dim)`.
/// - `head_dim == 512`: `centroid * norm_half / scale_factor_d512`.
///
/// If every scanned slot is masked, the head's output row is all zeros.
///
/// All arithmetic is sequential `f32`, and the accumulation order is
/// deliberately fixed (a test checks bit-exact determinism).
///
/// NOTE(review): `params.softcap` is not read here. If the GPU kernel applies
/// a logit softcap, this oracle does not mirror it; confirm.
///
/// # Errors
/// Returns the error from `validate` (an `MlxError::InvalidArgument`) when
/// `params` or any buffer length is invalid. `output` is not written in that
/// case.
pub fn flash_attn_vec_tq_hb_oracle(
    q: &[f32],
    k_packed: &[u8],
    k_norms: &[f32],
    v_packed: &[u8],
    v_norms: &[f32],
    output: &mut [f32],
    params: &TqHbOracleParams,
) -> Result<()> {
    validate(
        params,
        q.len(),
        k_packed.len(),
        k_norms.len(),
        v_packed.len(),
        v_norms.len(),
        output.len(),
    )?;

    let dk = params.head_dim as usize;
    let nh = params.num_heads as usize;
    let nkv = params.num_kv_heads as usize;
    let kv_seq_len = params.kv_seq_len as usize;
    let kv_capacity = params.kv_capacity as usize;
    let ring_start = params.ring_start as usize;
    let cbits = params.codebook_bits;
    // validate() guarantees num_heads % num_kv_heads == 0.
    let heads_per_kv = nh / nkv;

    // Sliding window (mask_type == 2): logical positions below this index are
    // masked. The window only has an effect once the sequence is longer than
    // it.
    let window_start_logical: usize = if params.mask_type == 2
        && params.sliding_window > 0
        && (kv_seq_len as u32) > params.sliding_window
    {
        kv_seq_len - params.sliding_window as usize
    } else {
        0
    };

    let is_d512 = dk == 512;
    let inv_sqrt_dk: f32 = 1.0_f32 / (dk as f32).sqrt();
    // K and V share head_dim here, so the V scale is the same value.
    let inv_sqrt_dv: f32 = inv_sqrt_dk;
    let sf_d512: f32 = params.scale_factor_d512;

    // -65504 is the most negative finite IEEE half-precision value. It is used
    // as a finite stand-in for -inf, presumably to mirror the GPU kernel
    // (confirm).
    let neg_inf_proxy: f32 = -65504.0_f32;
    // Additive mask per physical slot: 0 if attendable, neg_inf_proxy if not.
    // Physical slot -> logical position via the ring buffer's start offset.
    // NOTE(review): only slots 0..kv_seq_len are scanned. With ring_start != 0
    // and kv_seq_len < kv_capacity, slots >= kv_seq_len are never visited;
    // confirm this matches the kernel.
    let mut mask_vec: Vec<f32> = vec![0.0_f32; kv_seq_len];
    for kv_pos in 0..kv_seq_len {
        let logical_idx = ((kv_pos as i64 - ring_start as i64).rem_euclid(kv_capacity as i64))
            as usize;
        let valid = logical_idx >= window_start_logical && logical_idx < kv_seq_len;
        mask_vec[kv_pos] = if valid { 0.0_f32 } else { neg_inf_proxy };
    }

    for h in 0..nh {
        let kv_head = h / heads_per_kv;
        let q_offset = h * dk;
        let q_row: &[f32] = &q[q_offset..q_offset + dk];

        // Masked slots keep the proxy score so the softmax pass below ignores
        // them.
        let mut scores: Vec<f32> = vec![neg_inf_proxy; kv_seq_len];
        for kv_pos in 0..kv_seq_len {
            if mask_vec[kv_pos] <= neg_inf_proxy {
                continue;
            }
            let k_packed_offset = (kv_head * kv_capacity + kv_pos) * dk;
            let k_packed_row: &[u8] = &k_packed[k_packed_offset..k_packed_offset + dk];

            let mut dot: f32 = 0.0_f32;
            if is_d512 {
                // Two 256-wide halves, each with its own norm, both divided by
                // scale_factor_d512.
                let knorm_offset = (kv_head * kv_capacity + kv_pos) * 2;
                let n0 = k_norms[knorm_offset];
                let n1 = k_norms[knorm_offset + 1];
                let sn0 = n0 / sf_d512;
                let sn1 = n1 / sf_d512;
                for d in 0..256 {
                    let centroid = hb_centroid(k_packed_row[d], cbits);
                    dot += q_row[d] * centroid * sn0;
                }
                for d in 256..dk {
                    let centroid = hb_centroid(k_packed_row[d], cbits);
                    dot += q_row[d] * centroid * sn1;
                }
            } else {
                // Single norm per row, scaled by 1/sqrt(head_dim).
                let n = k_norms[kv_head * kv_capacity + kv_pos];
                let sn = n * inv_sqrt_dk;
                for d in 0..dk {
                    let centroid = hb_centroid(k_packed_row[d], cbits);
                    dot += q_row[d] * centroid * sn;
                }
            }
            scores[kv_pos] = dot * params.scale + mask_vec[kv_pos];
        }

        // Max score for a numerically stable softmax.
        let mut m: f32 = f32::NEG_INFINITY;
        for &s in scores.iter() {
            if s > m {
                m = s;
            }
        }
        // No score above the proxy means nothing is attendable; the output
        // stays zero.
        let all_masked = m <= neg_inf_proxy;

        // Masked entries give exp(-65504 - m). This underflows to 0.0 unless a
        // real score is itself close to the proxy.
        let mut sum: f32 = 0.0_f32;
        let mut weights: Vec<f32> = vec![0.0_f32; kv_seq_len];
        if !all_masked {
            for (i, &s) in scores.iter().enumerate() {
                let w = (s - m).exp();
                weights[i] = w;
                sum += w;
            }
        }
        let inv_sum: f32 = if sum > 0.0_f32 { 1.0_f32 / sum } else { 0.0_f32 };

        let out_offset = h * dk;
        for d in 0..dk {
            output[out_offset + d] = 0.0_f32;
        }

        if !all_masked {
            // Accumulate unnormalized w * V, then normalize once at the end.
            for kv_pos in 0..kv_seq_len {
                let w = weights[kv_pos];
                if w == 0.0_f32 {
                    continue;
                }
                let v_packed_offset = (kv_head * kv_capacity + kv_pos) * dk;
                let v_packed_row: &[u8] = &v_packed[v_packed_offset..v_packed_offset + dk];

                if is_d512 {
                    let vnorm_offset = (kv_head * kv_capacity + kv_pos) * 2;
                    let vn0 = v_norms[vnorm_offset];
                    let vn1 = v_norms[vnorm_offset + 1];
                    let sn0 = vn0 / sf_d512;
                    let sn1 = vn1 / sf_d512;
                    for d in 0..256 {
                        let centroid = hb_centroid(v_packed_row[d], cbits);
                        output[out_offset + d] += centroid * sn0 * w;
                    }
                    for d in 256..dk {
                        let centroid = hb_centroid(v_packed_row[d], cbits);
                        output[out_offset + d] += centroid * sn1 * w;
                    }
                } else {
                    let vn = v_norms[kv_head * kv_capacity + kv_pos];
                    let sn = vn * inv_sqrt_dv;
                    for d in 0..dk {
                        let centroid = hb_centroid(v_packed_row[d], cbits);
                        output[out_offset + d] += centroid * sn * w;
                    }
                }
            }

            for d in 0..dk {
                output[out_offset + d] *= inv_sum;
            }
        }
    }

    Ok(())
}
343
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    // Unit tests for the HB codebooks and the CPU attention oracle.
    //
    // Most oracle tests use a single cached position, so the softmax weight is
    // exactly 1 and the output can be compared against a closed form.

    use super::*;
    use crate::turboquant::{
        hb_nearest_centroid, CODEBOOK_HB_5BIT, CODEBOOK_HB_6BIT, CODEBOOK_HB_8BIT,
    };

    /// Test-only encoder for one 256-wide row.
    ///
    /// Steps:
    /// 1. Rotates `x` with `fwht_inplace`.
    /// 2. Takes the L2 norm of the rotated row.
    /// 3. Rescales the row to norm `sqrt(len)` and snaps each element to its
    ///    nearest HB centroid.
    ///
    /// Returns `(codes, norm)`. A (near-)zero row encodes as all-zero codes
    /// with norm 0.
    fn encode_row_d256(x: &[f32], bits: u32) -> (Vec<u8>, f32) {
        let mut rotated = x.to_vec();
        crate::turboquant::fwht_inplace(&mut rotated).expect("fwht");
        let norm_sq: f32 = rotated.iter().map(|&v| v * v).sum();
        let norm = norm_sq.sqrt();
        if norm < 1e-30 {
            return (vec![0u8; x.len()], 0.0);
        }
        let scale = (x.len() as f32).sqrt() / norm;
        let packed = rotated
            .iter()
            .map(|&v| hb_nearest_centroid(v * scale, bits))
            .collect();
        (packed, norm)
    }

    /// Reproducible standard-normal samples.
    ///
    /// Uses a 64-bit LCG and the Box-Muller transform, so tests need no RNG
    /// dependency.
    fn deterministic_gaussian(seed: u64, n: usize) -> Vec<f32> {
        const LCG_MUL: u64 = 6364136223846793005;
        const LCG_INC: u64 = 1442695040888963407;
        let mut state = seed.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        let next_u32 = |s: &mut u64| -> u32 {
            *s = s.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
            (*s >> 32) as u32
        };
        let next_f32 = |s: &mut u64| -> f32 {
            let bits = next_u32(s);
            ((bits as f64 + 0.5) / (u32::MAX as f64 + 1.0)) as f32
        };
        let mut out = Vec::with_capacity(n);
        while out.len() < n {
            // Keep u1 strictly inside (0, 1) so ln(u1) is finite.
            let u1 = next_f32(&mut state).clamp(1e-7, 1.0 - 1e-7);
            let u2 = next_f32(&mut state);
            let r = (-2.0_f32 * u1.ln()).sqrt();
            let theta = 2.0_f32 * std::f32::consts::PI * u2;
            out.push(r * theta.cos());
            if out.len() < n {
                out.push(r * theta.sin());
            }
        }
        out
    }

    #[test]
    fn codebooks_match_metal_shader_constants() {
        assert!((CODEBOOK_HB_5BIT[0] - (-3.2606790)).abs() < 1e-6);
        assert!((CODEBOOK_HB_5BIT[31] - 3.2606790).abs() < 1e-6);
        assert!((CODEBOOK_HB_6BIT[0] - (-3.6996161)).abs() < 1e-6);
        assert!((CODEBOOK_HB_6BIT[63] - 3.6996161).abs() < 1e-6);
        assert!((CODEBOOK_HB_8BIT[0] - (-5.0652659)).abs() < 1e-6);
        assert!((CODEBOOK_HB_8BIT[255] - 5.0652659).abs() < 1e-6);
        // The 8-bit codebook is antisymmetric: entry i mirrors entry 255 - i.
        for i in 0..128 {
            let sum = CODEBOOK_HB_8BIT[i] + CODEBOOK_HB_8BIT[255 - i];
            assert!(sum.abs() < 1e-5, "8-bit asymmetry at i={i}: {sum}");
        }
    }

    #[test]
    fn hb_centroid_lookup_matches_index() {
        for &idx in &[0u8, 1u8, 16u8, 31u8] {
            let v = hb_centroid(idx, 5);
            assert!((v - CODEBOOK_HB_5BIT[(idx & 0x1F) as usize]).abs() < 1e-7);
        }
        for &idx in &[0u8, 1u8, 32u8, 63u8] {
            let v = hb_centroid(idx, 6);
            assert!((v - CODEBOOK_HB_6BIT[(idx & 0x3F) as usize]).abs() < 1e-7);
        }
        for idx in 0u8..=255u8 {
            let v = hb_centroid(idx, 8);
            assert!((v - CODEBOOK_HB_8BIT[idx as usize]).abs() < 1e-7);
        }
    }

    #[test]
    fn hb_centroid_unsupported_bits_returns_zero() {
        assert_eq!(hb_centroid(0, 4), 0.0);
        assert_eq!(hb_centroid(255, 7), 0.0);
        assert_eq!(hb_nearest_centroid(0.0, 4), 0u8);
    }

    #[test]
    fn nearest_centroid_finds_closest() {
        assert_eq!(hb_nearest_centroid(0.005, 8), 128);
        assert_eq!(hb_nearest_centroid(5.5, 8), 255);
        assert_eq!(hb_nearest_centroid(-5.5, 8), 0);
    }

    /// With one cached position the softmax weight is exactly 1.
    ///
    /// So the output must equal the dequantized V row:
    /// `centroid * v_norm / sqrt(head_dim)`.
    #[test]
    fn oracle_single_position_uniform_v_matches_manual() {
        let head_dim = 256u32;
        let num_heads = 1u32;
        let num_kv_heads = 1u32;
        let kv_capacity = 4u32;
        let kv_seq_len = 1u32;
        let bits = 8u32;
        let dk = head_dim as usize;

        let k_row = deterministic_gaussian(0xC25EED, dk);
        let v_row = deterministic_gaussian(0xC25EED ^ 0xDEADBEEF, dk);

        let (k_packed_row, k_norm) = encode_row_d256(&k_row, bits);
        let (v_packed_row, v_norm) = encode_row_d256(&v_row, bits);

        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
        let mut v_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut v_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];

        // Slot 0 of KV head 0.
        k_packed[..dk].copy_from_slice(&k_packed_row);
        v_packed[..dk].copy_from_slice(&v_packed_row);
        k_norms[0] = k_norm;
        v_norms[0] = v_norm;

        let mut q = vec![0.0_f32; (num_heads * head_dim) as usize];
        q[..dk].fill(1.0_f32 / (head_dim as f32).sqrt());

        let params = TqHbOracleParams {
            num_heads,
            num_kv_heads,
            head_dim,
            kv_seq_len,
            kv_capacity,
            scale: 1.0_f32 / (head_dim as f32).sqrt(),
            mask_type: 0,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
            codebook_bits: bits,
        };

        let mut output = vec![0.0_f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut output, &params)
            .expect("oracle ok");

        let inv_sqrt_dk = 1.0_f32 / (head_dim as f32).sqrt();
        for (d, &actual) in output.iter().enumerate().take(dk) {
            let expected = hb_centroid(v_packed_row[d], bits) * v_norm * inv_sqrt_dk;
            let diff = (actual - expected).abs();
            assert!(
                diff < 1e-5,
                "oracle output mismatch at d={d}: expected={expected}, actual={actual}, diff={diff}"
            );
        }
    }

    /// Two runs on identical inputs must produce bit-identical outputs.
    #[test]
    fn oracle_is_bit_deterministic() {
        let head_dim = 256u32;
        let num_heads = 4u32;
        let num_kv_heads = 2u32;
        let kv_capacity = 16u32;
        let kv_seq_len = 8u32;

        let k_packed: Vec<u8> = (0..(num_kv_heads * kv_capacity * head_dim))
            .map(|i| (i.wrapping_mul(31) ^ 0xA5) as u8)
            .collect();
        let v_packed: Vec<u8> = (0..(num_kv_heads * kv_capacity * head_dim))
            .map(|i| (i.wrapping_mul(37) ^ 0x5A) as u8)
            .collect();
        let k_norms: Vec<f32> = (0..(num_kv_heads * kv_capacity))
            .map(|i| 1.0 + (i as f32) * 0.01)
            .collect();
        let v_norms: Vec<f32> = (0..(num_kv_heads * kv_capacity))
            .map(|i| 1.0 + (i as f32) * 0.02)
            .collect();
        let q: Vec<f32> = deterministic_gaussian(0xBEEF, (num_heads * head_dim) as usize);

        let params = TqHbOracleParams {
            num_heads,
            num_kv_heads,
            head_dim,
            kv_seq_len,
            kv_capacity,
            scale: 0.0625,
            mask_type: 0,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
            codebook_bits: 8,
        };

        let mut out_a = vec![0.0_f32; (num_heads * head_dim) as usize];
        let mut out_b = vec![0.0_f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_a, &params)
            .expect("a");
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_b, &params)
            .expect("b");

        for (i, (a, b)) in out_a.iter().zip(&out_b).enumerate() {
            assert_eq!(a.to_bits(), b.to_bits(), "non-deterministic at i={i}: a={a}, b={b}");
        }
    }

    /// A sliding window restricts attention to the newest positions.
    ///
    /// Every position has the same K, so weights are uniform over the attended
    /// set. V norms grow with position, so dropping old positions must change
    /// the averaged output.
    #[test]
    fn oracle_sliding_window_masks_old_positions() {
        let head_dim = 256u32;
        let num_heads = 1u32;
        let num_kv_heads = 1u32;
        let kv_capacity = 32u32;
        let kv_seq_len = 16u32;
        let sliding_window = 4u32;
        let bits = 8u32;
        let dk = head_dim as usize;

        let k_row = deterministic_gaussian(0xCAFE, dk);
        let v_row = deterministic_gaussian(0xBABE, dk);
        let (k_packed_row, k_norm) = encode_row_d256(&k_row, bits);
        let (v_packed_row, v_norm) = encode_row_d256(&v_row, bits);

        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
        let mut v_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut v_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
        for kv_pos in 0..kv_seq_len as usize {
            let off = kv_pos * dk;
            k_packed[off..off + dk].copy_from_slice(&k_packed_row);
            v_packed[off..off + dk].copy_from_slice(&v_packed_row);
            k_norms[kv_pos] = k_norm;
            v_norms[kv_pos] = v_norm * (1.0 + kv_pos as f32);
        }

        let mut q = vec![1.0_f32 / (head_dim as f32).sqrt(); (num_heads * head_dim) as usize];
        crate::turboquant::fwht_inplace(&mut q[..dk]).expect("fwht");

        let params = TqHbOracleParams {
            num_heads,
            num_kv_heads,
            head_dim,
            kv_seq_len,
            kv_capacity,
            scale: 1.0_f32 / (head_dim as f32).sqrt(),
            mask_type: 2,
            sliding_window,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
            codebook_bits: bits,
        };

        let mut out_windowed = vec![0.0_f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_windowed, &params)
            .expect("ok");

        let params_no_mask = TqHbOracleParams { mask_type: 0, sliding_window: 0, ..params };
        let mut out_full = vec![0.0_f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out_full, &params_no_mask)
            .expect("ok");

        let max_diff = out_windowed
            .iter()
            .zip(&out_full)
            .map(|(w, f)| (w - f).abs())
            .fold(0.0_f32, f32::max);
        assert!(max_diff > 1e-3, "sliding window had no effect: max_diff={max_diff}");
    }

    /// When no slot is attendable, the output row must be exactly zero.
    ///
    /// With `ring_start = 2` and `kv_capacity = 4`, the only scanned slot (0)
    /// maps to logical position `(0 - 2) mod 4 = 2`. That is not below
    /// `kv_seq_len = 1`, so the slot is masked. The pre-filled 1.0s must be
    /// overwritten with +0.0.
    #[test]
    fn oracle_all_masked_returns_zeros() {
        let head_dim = 256u32;
        let num_heads = 1u32;
        let num_kv_heads = 1u32;
        let kv_capacity = 4u32;
        let kv_seq_len = 1u32;

        let k_packed = vec![128u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let v_packed = vec![128u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let k_norms = vec![1.0f32; (num_kv_heads * kv_capacity) as usize];
        let v_norms = vec![1.0f32; (num_kv_heads * kv_capacity) as usize];
        let q = vec![0.5_f32; (num_heads * head_dim) as usize];

        let params = TqHbOracleParams {
            num_heads,
            num_kv_heads,
            head_dim,
            kv_seq_len,
            kv_capacity,
            scale: 1.0,
            mask_type: 2,
            sliding_window: kv_seq_len,
            softcap: 0.0,
            ring_start: 2,
            scale_factor_d512: 1.0,
            codebook_bits: 8,
        };

        let mut output = vec![1.0_f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut output, &params)
            .expect("ok");
        for &v in output.iter() {
            assert_eq!(v.to_bits(), 0u32, "expected 0.0 in all-masked output, got {v}");
        }
    }

    /// `head_dim == 512` uses one norm per 256-wide half, divided by
    /// `scale_factor_d512`.
    ///
    /// With a single position, each half of the output must use its own norm.
    #[test]
    fn oracle_d512_per_block_norms() {
        let head_dim = 512u32;
        let num_heads = 1u32;
        let num_kv_heads = 1u32;
        let kv_capacity = 4u32;
        let kv_seq_len = 1u32;
        let bits = 8u32;
        let sf_d512: f32 = 16.0;
        let dk = head_dim as usize;

        let k_row = deterministic_gaussian(0x01234567, dk);
        let mut k_b0 = k_row[0..256].to_vec();
        let mut k_b1 = k_row[256..512].to_vec();
        crate::turboquant::fwht_inplace(&mut k_b0).expect("fwht");
        crate::turboquant::fwht_inplace(&mut k_b1).expect("fwht");
        let n0 = k_b0.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let n1 = k_b1.iter().map(|&v| v * v).sum::<f32>().sqrt();

        let mut k_packed_row = vec![0u8; dk];
        for d in 0..256 {
            k_packed_row[d] = hb_nearest_centroid(k_b0[d] * sf_d512 / n0, bits);
            k_packed_row[256 + d] = hb_nearest_centroid(k_b1[d] * sf_d512 / n1, bits);
        }

        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity * 2) as usize];
        k_packed[..dk].copy_from_slice(&k_packed_row);
        k_norms[0] = n0;
        k_norms[1] = n1;

        let v_packed = k_packed.clone();
        let v_norms = k_norms.clone();
        let q = vec![1.0_f32 / (head_dim as f32).sqrt(); (num_heads * head_dim) as usize];

        let params = TqHbOracleParams {
            num_heads,
            num_kv_heads,
            head_dim,
            kv_seq_len,
            kv_capacity,
            scale: 1.0 / (head_dim as f32).sqrt(),
            mask_type: 0,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: sf_d512,
            codebook_bits: bits,
        };

        let mut out = vec![0.0f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out, &params)
            .expect("ok");

        for d in 0..256 {
            let expected = hb_centroid(k_packed_row[d], bits) * (n0 / sf_d512);
            assert!(
                (out[d] - expected).abs() < 1e-5,
                "d512 block0 mismatch d={d}: expected={expected}, actual={}",
                out[d]
            );
        }
        for d in 256..dk {
            let expected = hb_centroid(k_packed_row[d], bits) * (n1 / sf_d512);
            assert!(
                (out[d] - expected).abs() < 1e-5,
                "d512 block1 mismatch d={d}: expected={expected}, actual={}",
                out[d]
            );
        }
    }

    /// GQA routing: with 8 query heads over 2 KV heads, heads 0..4 read KV
    /// head 0 and heads 4..8 read KV head 1.
    ///
    /// KV head 1 has a 10x larger V norm, so head 4's output must be 10x head
    /// 0's.
    #[test]
    fn oracle_gqa_routes_heads_to_correct_kv_head() {
        let head_dim = 256u32;
        let num_heads = 8u32;
        let num_kv_heads = 2u32;
        let kv_capacity = 4u32;
        let kv_seq_len = 1u32;
        let bits = 8u32;
        let dk = head_dim as usize;

        let k_row = deterministic_gaussian(0x111, dk);
        let v_row = deterministic_gaussian(0x222, dk);
        let (k_packed_row, k_norm) = encode_row_d256(&k_row, bits);
        let (v_packed_row, v_norm) = encode_row_d256(&v_row, bits);

        let mut k_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut k_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];
        let mut v_packed = vec![0u8; (num_kv_heads * kv_capacity * head_dim) as usize];
        let mut v_norms = vec![0.0f32; (num_kv_heads * kv_capacity) as usize];

        // KV head 0, slot 0.
        k_packed[..dk].copy_from_slice(&k_packed_row);
        v_packed[..dk].copy_from_slice(&v_packed_row);
        k_norms[0] = k_norm;
        v_norms[0] = v_norm;

        // KV head 1, slot 0: same codes, 10x V norm.
        let kv1_slot0 = kv_capacity as usize;
        let kv1_off = kv1_slot0 * dk;
        k_packed[kv1_off..kv1_off + dk].copy_from_slice(&k_packed_row);
        v_packed[kv1_off..kv1_off + dk].copy_from_slice(&v_packed_row);
        k_norms[kv1_slot0] = k_norm;
        v_norms[kv1_slot0] = 10.0 * v_norm;

        let q = vec![1.0_f32 / (head_dim as f32).sqrt(); (num_heads * head_dim) as usize];
        let params = TqHbOracleParams {
            num_heads,
            num_kv_heads,
            head_dim,
            kv_seq_len,
            kv_capacity,
            scale: 1.0 / (head_dim as f32).sqrt(),
            mask_type: 0,
            sliding_window: 0,
            softcap: 0.0,
            ring_start: 0,
            scale_factor_d512: 1.0,
            codebook_bits: bits,
        };

        let mut out = vec![0.0f32; (num_heads * head_dim) as usize];
        flash_attn_vec_tq_hb_oracle(&q, &k_packed, &k_norms, &v_packed, &v_norms, &mut out, &params)
            .expect("ok");

        let inv_sqrt_dk = 1.0_f32 / (head_dim as f32).sqrt();
        let expected_h0 = hb_centroid(v_packed_row[0], bits) * v_norm * inv_sqrt_dk;
        let expected_h4 = hb_centroid(v_packed_row[0], bits) * (10.0 * v_norm) * inv_sqrt_dk;

        // Element 0 of head 0 and of head 4.
        let h0_d0 = out[0];
        let h4_d0 = out[4 * dk];

        assert!(
            (h0_d0 - expected_h0).abs() < 1e-4,
            "h0 mismatch: expected={expected_h0}, actual={h0_d0}"
        );
        assert!(
            (h4_d0 - expected_h4).abs() < 1e-3,
            "h4 mismatch: expected={expected_h4}, actual={h4_d0}"
        );
    }
}