Skip to main content

oxibonsai_model/layers/
flash_decode.rs

1//! Flash Decoding: parallelized decode-phase attention.
2//!
3//! During inference decoding, we have Q=[1, h, d] and K/V=[S, h, d] where S can be large.
4//! Flash Decoding splits the KV sequence into tiles and computes partial softmax in parallel,
5//! then combines using the log-sum-exp trick.
6//!
//! References: Dao et al. 2023 — "Flash-Decoding for long-context inference"
8
9use rayon::prelude::*;
10
11// ─── FlashDecodeConfig ────────────────────────────────────────────────────────
12
/// Tuning parameters for flash decoding.
#[derive(Debug, Clone)]
pub struct FlashDecodeConfig {
    /// How many tiles the KV sequence is partitioned into.
    pub num_tiles: usize,
    /// Multiplier applied to every query·key dot product;
    /// [`FlashDecodeConfig::new`] sets it to `1/sqrt(head_dim)`.
    pub scale: f32,
}

impl FlashDecodeConfig {
    /// Build a config with `num_tiles = 4` and `scale = 1/sqrt(head_dim)`.
    ///
    /// A `head_dim` of zero yields `scale = 1.0` instead of dividing by zero.
    pub fn new(head_dim: usize) -> Self {
        let scale = match head_dim {
            0 => 1.0_f32,
            dim => 1.0_f32 / (dim as f32).sqrt(),
        };
        Self {
            num_tiles: 4,
            scale,
        }
    }

    /// Builder-style override of the tile count; every other field is kept.
    #[must_use]
    pub fn with_num_tiles(self, n: usize) -> Self {
        Self {
            num_tiles: n,
            ..self
        }
    }
}
43
44// ─── flash_decode_tile ───────────────────────────────────────────────────────
45
/// Attend one query vector over a single KV tile.
///
/// `keys_tile` and `values_tile` are read as `tile_len` consecutive rows of
/// `head_dim` floats each. Returns `(output_tile, max_score, log_sum_exp)`:
///
/// ```text
/// score_i = (q · k_i) * scale
/// m       = max_i score_i
/// sum     = Σ exp(score_i - m)
/// lse     = m + ln(sum)
/// output  = Σ (exp(score_i - m) / sum) * v_i
/// ```
///
/// When the maximum score is not finite (for instance an empty tile, or every
/// score being `-inf`), the output is all zeros and both scalars are `-inf`.
fn flash_decode_tile(
    query: &[f32],
    keys_tile: &[f32],
    values_tile: &[f32],
    tile_len: usize,
    head_dim: usize,
    scale: f32,
) -> (Vec<f32>, f32, f32) {
    // Scaled dot product of the query against every key row of the tile.
    let mut weights: Vec<f32> = Vec::with_capacity(tile_len);
    for row in 0..tile_len {
        let base = row * head_dim;
        let key = &keys_tile[base..base + head_dim];
        let dot = query
            .iter()
            .zip(key.iter())
            .map(|(q, k)| q * k)
            .sum::<f32>();
        weights.push(dot * scale);
    }

    // Subtracting the running maximum keeps the exponentials from overflowing.
    let peak = weights
        .iter()
        .fold(f32::NEG_INFINITY, |best, &s| f32::max(best, s));

    if !peak.is_finite() {
        // Nothing usable in this tile: contribute zeros with an LSE of -inf.
        return (
            vec![0.0_f32; head_dim],
            f32::NEG_INFINITY,
            f32::NEG_INFINITY,
        );
    }

    // Turn the scores into unnormalised softmax weights in place.
    weights.iter_mut().for_each(|w| *w = (*w - peak).exp());
    let denom: f32 = weights.iter().sum();
    let lse = peak + denom.ln();

    // Accumulate the weighted value rows.
    let mut acc = vec![0.0_f32; head_dim];
    for (row, &w) in weights.iter().enumerate() {
        let base = row * head_dim;
        let value = &values_tile[base..base + head_dim];
        for (slot, &v) in acc.iter_mut().zip(value.iter()) {
            *slot += w * v;
        }
    }

    // Final normalisation by the softmax denominator.
    if denom > 0.0 {
        acc.iter_mut().for_each(|slot| *slot /= denom);
    }

    (acc, peak, lse)
}
117
118// ─── combine_tile_outputs ────────────────────────────────────────────────────
119
/// Merge per-tile partial results into the final attention output.
///
/// Tile `i` contributes `exp(tile_lse[i] - global_lse) * tile_outputs[i]`,
/// where `global_lse` is the log-sum-exp over every finite entry of `tile_lse`.
/// Tiles whose LSE is not finite are skipped. With no tiles, or no finite
/// tile, the result is `head_dim` zeros; with exactly one candidate tile its
/// output is returned unchanged.
///
/// `tile_max_scores` only takes part in the debug-build length checks.
fn combine_tile_outputs(
    tile_outputs: &[Vec<f32>],
    tile_max_scores: &[f32],
    tile_lse: &[f32],
    head_dim: usize,
) -> Vec<f32> {
    debug_assert_eq!(tile_outputs.len(), tile_lse.len());
    debug_assert_eq!(tile_outputs.len(), tile_max_scores.len());

    // Trivial inputs need no reduction at all.
    match tile_outputs {
        [] => return vec![0.0_f32; head_dim],
        [only] => return only.clone(),
        _ => {}
    }

    // Indices of tiles that actually carry probability mass.
    let live: Vec<usize> = tile_lse
        .iter()
        .enumerate()
        .filter(|(_, lse)| lse.is_finite())
        .map(|(idx, _)| idx)
        .collect();

    match live.as_slice() {
        [] => return vec![0.0_f32; head_dim],
        [idx] => return tile_outputs[*idx].clone(),
        _ => {}
    }

    // Log-sum-exp over the tile LSEs, shifted by their maximum for stability.
    let lse_peak = live
        .iter()
        .fold(f32::NEG_INFINITY, |best, &idx| f32::max(best, tile_lse[idx]));
    let mass: f32 = live
        .iter()
        .map(|&idx| (tile_lse[idx] - lse_peak).exp())
        .sum();
    let global_lse = lse_peak + mass.ln();

    // output = Σ_tile exp(lse_tile - global_lse) * output_tile
    let mut merged = vec![0.0_f32; head_dim];
    for &idx in &live {
        let share = (tile_lse[idx] - global_lse).exp();
        let partial = &tile_outputs[idx][..head_dim];
        for (slot, &x) in merged.iter_mut().zip(partial.iter()) {
            *slot += share * x;
        }
    }

    merged
}
176
177// ─── flash_decode_single_head ─────────────────────────────────────────────────
178
179/// Compute attention for a single query token against full KV cache.
180///
181/// - `query`:   shape `[head_dim]`
182/// - `keys`:    shape `[seq_len * head_dim]` (row-major: token is outer dim)
183/// - `values`:  shape `[seq_len * head_dim]`
184/// - Returns:   shape `[head_dim]`
185pub fn flash_decode_single_head(
186    query: &[f32],
187    keys: &[f32],
188    values: &[f32],
189    seq_len: usize,
190    head_dim: usize,
191    config: &FlashDecodeConfig,
192) -> Result<Vec<f32>, FlashDecodeError> {
193    if seq_len == 0 {
194        return Err(FlashDecodeError::EmptyKv);
195    }
196    if query.len() != head_dim {
197        return Err(FlashDecodeError::DimMismatch {
198            q_dim: query.len(),
199            k_dim: head_dim,
200        });
201    }
202    if keys.len() != seq_len * head_dim {
203        return Err(FlashDecodeError::DimMismatch {
204            q_dim: query.len(),
205            k_dim: keys.len() / seq_len.max(1),
206        });
207    }
208    if values.len() != seq_len * head_dim {
209        return Err(FlashDecodeError::DimMismatch {
210            q_dim: query.len(),
211            k_dim: values.len() / seq_len.max(1),
212        });
213    }
214
215    // Clamp num_tiles to seq_len
216    let num_tiles = config.num_tiles.min(seq_len).max(1);
217
218    let tile_size_base = seq_len / num_tiles;
219    let remainder = seq_len % num_tiles;
220
221    let mut tile_outputs: Vec<Vec<f32>> = Vec::with_capacity(num_tiles);
222    let mut tile_max_scores: Vec<f32> = Vec::with_capacity(num_tiles);
223    let mut tile_lse: Vec<f32> = Vec::with_capacity(num_tiles);
224
225    let mut offset = 0usize;
226    for tile_idx in 0..num_tiles {
227        // Distribute remainder tokens among the first `remainder` tiles
228        let tile_len = tile_size_base + if tile_idx < remainder { 1 } else { 0 };
229        if tile_len == 0 {
230            break;
231        }
232
233        let k_start = offset * head_dim;
234        let k_end = k_start + tile_len * head_dim;
235        let v_start = offset * head_dim;
236        let v_end = v_start + tile_len * head_dim;
237
238        let (out, max_s, lse) = flash_decode_tile(
239            query,
240            &keys[k_start..k_end],
241            &values[v_start..v_end],
242            tile_len,
243            head_dim,
244            config.scale,
245        );
246        tile_outputs.push(out);
247        tile_max_scores.push(max_s);
248        tile_lse.push(lse);
249
250        offset += tile_len;
251    }
252
253    Ok(combine_tile_outputs(
254        &tile_outputs,
255        &tile_max_scores,
256        &tile_lse,
257        head_dim,
258    ))
259}
260
261// ─── flash_decode_multi_head ─────────────────────────────────────────────────
262
263/// Multi-head flash decode: compute attention across all heads in parallel (via rayon).
264///
265/// - `queries`: shape `[num_heads * head_dim]`
266/// - `keys`:    shape `[seq_len * num_heads * head_dim]` (token-major)
267/// - `values`:  same shape as `keys`
268///
269/// Returns flattened `[num_heads * head_dim]`.
270pub fn flash_decode_multi_head(
271    queries: &[f32],
272    keys: &[f32],
273    values: &[f32],
274    num_heads: usize,
275    seq_len: usize,
276    head_dim: usize,
277    config: &FlashDecodeConfig,
278) -> Result<Vec<f32>, FlashDecodeError> {
279    if seq_len == 0 {
280        return Err(FlashDecodeError::EmptyKv);
281    }
282    if queries.len() != num_heads * head_dim {
283        return Err(FlashDecodeError::DimMismatch {
284            q_dim: queries.len(),
285            k_dim: head_dim,
286        });
287    }
288
289    // Re-index keys/values from [seq_len, num_heads, head_dim] to per-head
290    // [seq_len, head_dim] slices for each head.
291    // We build per-head key and value buffers.
292    let per_head_keys: Vec<Vec<f32>> = (0..num_heads)
293        .map(|h| {
294            let mut buf = vec![0.0_f32; seq_len * head_dim];
295            for t in 0..seq_len {
296                let src_start = t * num_heads * head_dim + h * head_dim;
297                let dst_start = t * head_dim;
298                buf[dst_start..dst_start + head_dim]
299                    .copy_from_slice(&keys[src_start..src_start + head_dim]);
300            }
301            buf
302        })
303        .collect();
304
305    let per_head_values: Vec<Vec<f32>> = (0..num_heads)
306        .map(|h| {
307            let mut buf = vec![0.0_f32; seq_len * head_dim];
308            for t in 0..seq_len {
309                let src_start = t * num_heads * head_dim + h * head_dim;
310                let dst_start = t * head_dim;
311                buf[dst_start..dst_start + head_dim]
312                    .copy_from_slice(&values[src_start..src_start + head_dim]);
313            }
314            buf
315        })
316        .collect();
317
318    // Process each head in parallel using rayon
319    let results: Vec<Result<Vec<f32>, FlashDecodeError>> = (0..num_heads)
320        .into_par_iter()
321        .map(|h| {
322            let q_start = h * head_dim;
323            let q_vec = &queries[q_start..q_start + head_dim];
324            flash_decode_single_head(
325                q_vec,
326                &per_head_keys[h],
327                &per_head_values[h],
328                seq_len,
329                head_dim,
330                config,
331            )
332        })
333        .collect();
334
335    // Flatten results
336    let mut output = vec![0.0_f32; num_heads * head_dim];
337    for (h, res) in results.into_iter().enumerate() {
338        let head_out = res?;
339        let start = h * head_dim;
340        output[start..start + head_dim].copy_from_slice(&head_out);
341    }
342
343    Ok(output)
344}
345
346// ─── flash_vs_naive_error ─────────────────────────────────────────────────────
347
348/// Compare flash decode vs naive attention (for testing).
349///
350/// Returns the mean absolute error between the two implementations.
351pub fn flash_vs_naive_error(
352    query: &[f32],
353    keys: &[f32],
354    values: &[f32],
355    seq_len: usize,
356    head_dim: usize,
357) -> Result<f32, FlashDecodeError> {
358    if seq_len == 0 {
359        return Err(FlashDecodeError::EmptyKv);
360    }
361    if query.len() != head_dim {
362        return Err(FlashDecodeError::DimMismatch {
363            q_dim: query.len(),
364            k_dim: head_dim,
365        });
366    }
367
368    // Flash decode output
369    let config = FlashDecodeConfig::new(head_dim);
370    let flash_out = flash_decode_single_head(query, keys, values, seq_len, head_dim, &config)?;
371
372    // Naive attention output
373    let scale = config.scale;
374    let mut scores: Vec<f32> = (0..seq_len)
375        .map(|t| {
376            let k_start = t * head_dim;
377            let k_vec = &keys[k_start..k_start + head_dim];
378            query
379                .iter()
380                .zip(k_vec.iter())
381                .map(|(q, k)| q * k)
382                .sum::<f32>()
383                * scale
384        })
385        .collect();
386
387    // Softmax
388    let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
389    for s in scores.iter_mut() {
390        *s = (*s - max_s).exp();
391    }
392    let sum: f32 = scores.iter().sum();
393    if sum > 0.0 {
394        for s in scores.iter_mut() {
395            *s /= sum;
396        }
397    }
398
399    // Weighted sum of values
400    let mut naive_out = vec![0.0_f32; head_dim];
401    for (t, &w) in scores.iter().enumerate() {
402        let v_start = t * head_dim;
403        for d in 0..head_dim {
404            naive_out[d] += w * values[v_start + d];
405        }
406    }
407
408    // Mean absolute error
409    let mae = flash_out
410        .iter()
411        .zip(naive_out.iter())
412        .map(|(a, b)| (a - b).abs())
413        .sum::<f32>()
414        / head_dim as f32;
415
416    Ok(mae)
417}
418
419// ─── FlashDecodeError ────────────────────────────────────────────────────────
420
421/// Errors from flash decode operations.
422#[derive(Debug, thiserror::Error)]
423pub enum FlashDecodeError {
424    #[error("empty KV sequence")]
425    EmptyKv,
426
427    #[error("dimension mismatch: query has {q_dim}, keys have {k_dim}")]
428    DimMismatch { q_dim: usize, k_dim: usize },
429
430    #[error("num_tiles ({0}) exceeds seq_len ({1})")]
431    TooManyTiles(usize, usize),
432
433    #[error("invalid config: {0}")]
434    InvalidConfig(String),
435}
436
437// ─── Tests ───────────────────────────────────────────────────────────────────
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear-ramp fixture: returns `(query, keys, values)` for a single head.
    fn make_deterministic_data(seq_len: usize, head_dim: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let kv_len = seq_len * head_dim;

        let mut query = Vec::with_capacity(head_dim);
        for i in 0..head_dim {
            query.push(0.1 * i as f32);
        }

        let mut keys = Vec::with_capacity(kv_len);
        let mut values = Vec::with_capacity(kv_len);
        for i in 0..kv_len {
            keys.push(0.05 * i as f32 + 0.01);
            values.push(0.02 * i as f32 + 0.1);
        }

        (query, keys, values)
    }

    /// Sum of absolute element-wise differences divided by `count`.
    fn mean_abs_diff(a: &[f32], b: &[f32], count: usize) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .sum::<f32>()
            / count as f32
    }

    /// Assert that paired elements differ by less than `tol`.
    fn assert_elementwise_close(label: &str, actual: &[f32], expected: &[f32], tol: f32) {
        for (i, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < tol,
                "{label}[{i}] = {got}, expected {want}"
            );
        }
    }

    /// Decode the ramp fixture with an explicit tile count.
    fn decode_with_tiles(
        seq_len: usize,
        head_dim: usize,
        tiles: usize,
    ) -> Result<Vec<f32>, FlashDecodeError> {
        let config = FlashDecodeConfig::new(head_dim).with_num_tiles(tiles);
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        flash_decode_single_head(&q, &k, &v, seq_len, head_dim, &config)
    }

    #[test]
    fn flash_decode_config_default() {
        let head_dim = 64usize;
        let cfg = FlashDecodeConfig::new(head_dim);
        let expected_scale = 1.0_f32 / (head_dim as f32).sqrt();
        let drift = (cfg.scale - expected_scale).abs();
        assert!(
            drift < 1e-6,
            "scale mismatch: {} vs {}",
            cfg.scale,
            expected_scale
        );
        assert_eq!(cfg.num_tiles, 4);
    }

    #[test]
    fn flash_decode_single_head_matches_naive() {
        let (seq_len, head_dim) = (32, 16);
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let mae = flash_vs_naive_error(&q, &k, &v, seq_len, head_dim)
            .expect("flash_vs_naive_error failed");
        assert!(
            mae < 1e-5,
            "MAE between flash and naive exceeds threshold: {mae}"
        );
    }

    #[test]
    fn flash_decode_empty_kv_error() {
        let head_dim = 8;
        let q = vec![0.1f32; head_dim];
        let config = FlashDecodeConfig::new(head_dim);
        let result = flash_decode_single_head(&q, &[], &[], 0, head_dim, &config);
        assert!(
            matches!(result, Err(FlashDecodeError::EmptyKv)),
            "expected EmptyKv, got {result:?}"
        );
    }

    #[test]
    fn flash_decode_dim_mismatch_error() {
        let head_dim = 8;
        let config = FlashDecodeConfig::new(head_dim);
        // The query is two elements too long; K and V are well-formed.
        let q = vec![0.1f32; head_dim + 2];
        let kv_row = vec![0.1f32; head_dim];
        let result = flash_decode_single_head(&q, &kv_row, &kv_row, 1, head_dim, &config);
        assert!(
            matches!(result, Err(FlashDecodeError::DimMismatch { .. })),
            "expected DimMismatch, got {result:?}"
        );
    }

    #[test]
    fn flash_decode_single_token() {
        // With one token the softmax weight is 1.0, so the output is the value row.
        let head_dim = 4;
        let q = vec![1.0f32, 0.0, 0.0, 0.0];
        let k = vec![0.5f32, 0.5, 0.5, 0.5];
        let v = vec![3.0f32, 1.0, 2.0, 4.0];
        let config = FlashDecodeConfig::new(head_dim);

        let out = flash_decode_single_head(&q, &k, &v, 1, head_dim, &config)
            .expect("flash_decode_single_head failed");

        assert_elementwise_close("output", &out, &v, 1e-5);
    }

    #[test]
    fn flash_decode_uniform_keys() {
        // Identical keys give uniform attention weights, so every output
        // dimension is the mean of the value rows.
        let (seq_len, head_dim) = (4, 4);
        let config = FlashDecodeConfig::new(head_dim);
        let q = vec![0.1f32; head_dim];
        let k = vec![0.1f32; seq_len * head_dim];

        // Row t of the values is filled with (t + 1).
        let mut v: Vec<f32> = Vec::with_capacity(seq_len * head_dim);
        for t in 0..seq_len {
            v.extend(std::iter::repeat((t + 1) as f32).take(head_dim));
        }

        let out = flash_decode_single_head(&q, &k, &v, seq_len, head_dim, &config)
            .expect("flash_decode_single_head failed");

        // mean of [1, 2, 3, 4] = 2.5
        let expected = vec![2.5_f32; out.len()];
        assert_elementwise_close("output", &out, &expected, 1e-4);
    }

    #[test]
    fn flash_decode_tile_count_1() {
        let result = decode_with_tiles(16, 8, 1);
        assert!(result.is_ok(), "num_tiles=1 should be valid: {result:?}");
    }

    #[test]
    fn flash_decode_tile_count_many() {
        let result = decode_with_tiles(16, 8, 8);
        assert!(
            result.is_ok(),
            "num_tiles=8 with seq_len=16 failed: {result:?}"
        );
    }

    #[test]
    fn flash_vs_naive_error_small() {
        let (seq_len, head_dim) = (64, 32);
        let (q, k, v) = make_deterministic_data(seq_len, head_dim);
        let mae = flash_vs_naive_error(&q, &k, &v, seq_len, head_dim)
            .expect("flash_vs_naive_error failed");
        assert!(mae < 1e-4, "MAE too large: {mae}");
    }

    #[test]
    fn flash_decode_multi_head_shape() {
        let (num_heads, head_dim, seq_len) = (4, 8, 16);
        let config = FlashDecodeConfig::new(head_dim);
        let kv_len = seq_len * num_heads * head_dim;

        let queries = vec![0.1f32; num_heads * head_dim];
        let keys = vec![0.05f32; kv_len];
        let values = vec![0.2f32; kv_len];

        let out = flash_decode_multi_head(
            &queries, &keys, &values, num_heads, seq_len, head_dim, &config,
        )
        .expect("multi_head flash decode failed");

        assert_eq!(
            out.len(),
            num_heads * head_dim,
            "output shape mismatch: {} vs {}",
            out.len(),
            num_heads * head_dim
        );
    }

    #[test]
    fn flash_decode_multi_head_matches_naive_per_head() {
        let (num_heads, head_dim, seq_len) = (2, 8, 16);
        let token_stride = num_heads * head_dim;
        let config = FlashDecodeConfig::new(head_dim);

        // Deterministic, mildly varied inputs.
        let queries: Vec<f32> = (0..token_stride).map(|i| 0.1 * i as f32).collect();
        let keys: Vec<f32> = (0..seq_len * token_stride)
            .map(|i| 0.05 * (i % 17) as f32 + 0.01)
            .collect();
        let values: Vec<f32> = (0..seq_len * token_stride)
            .map(|i| 0.02 * (i % 13) as f32 + 0.1)
            .collect();

        let flash_out = flash_decode_multi_head(
            &queries, &keys, &values, num_heads, seq_len, head_dim, &config,
        )
        .expect("multi_head flash decode failed");

        // Every head must agree with a single-tile decode of that head alone.
        let single_tile = FlashDecodeConfig::new(head_dim).with_num_tiles(1);
        for h in 0..num_heads {
            let head_range = h * head_dim..(h + 1) * head_dim;
            let q_vec = &queries[head_range.clone()];

            // De-interleave this head's K/V rows.
            let mut k_head: Vec<f32> = Vec::with_capacity(seq_len * head_dim);
            let mut v_head: Vec<f32> = Vec::with_capacity(seq_len * head_dim);
            for t in 0..seq_len {
                let src = t * token_stride + h * head_dim;
                k_head.extend_from_slice(&keys[src..src + head_dim]);
                v_head.extend_from_slice(&values[src..src + head_dim]);
            }

            let naive_out =
                flash_decode_single_head(q_vec, &k_head, &v_head, seq_len, head_dim, &single_tile)
                    .expect("naive single head failed");

            let mae = mean_abs_diff(&flash_out[head_range], &naive_out, head_dim);
            assert!(
                mae < 1e-4,
                "head {h}: MAE between multi_head flash and single-head naive = {mae}"
            );
        }
    }

    #[test]
    fn combine_tiles_single_tile() {
        let head_dim = 4;
        let tile_out = vec![1.0f32, 2.0, 3.0, 4.0];
        let combined = combine_tile_outputs(
            std::slice::from_ref(&tile_out),
            &[0.5_f32],
            &[1.0_f32],
            head_dim,
        );
        // A lone tile passes straight through.
        assert_elementwise_close("combined", &combined, &tile_out, 1e-5);
    }

    #[test]
    fn flash_decode_long_sequence() {
        let head_dim = 16;
        let result = decode_with_tiles(128, head_dim, 8);
        assert!(
            result.is_ok(),
            "long sequence (seq_len=128) failed: {result:?}"
        );
        let out = result.expect("already checked");
        assert_eq!(out.len(), head_dim);
        for (i, &o) in out.iter().enumerate() {
            assert!(o.is_finite(), "output[{i}] = {o} is not finite");
        }
    }
}
703}