unsloth-rs 1.0.2

Rust implementations of transformer building blocks for LLM inference
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Flash Attention `CubeCL` kernel implementation.
//!
//! This module contains the actual `CubeCL` kernel for Flash Attention 2.
//! The implementation uses tiled computation with online softmax for
//! O(N) memory complexity.
//!
//! ## Algorithm Overview
//!
//! Flash Attention processes attention in tiles:
//!
//! ```text
//! For each Q tile (i = 0..Tr):
//!     Load Q_i into shared memory
//!     Initialize accumulators: O_i = 0, m_i = -∞, l_i = 0
//!     
//!     For each KV tile (j = 0..Tc):
//!         Load K_j, V_j into shared memory
//!         S_ij = Q_i @ K_j^T / sqrt(d)
//!         
//!         # Online softmax update
//!         m_ij = max(m_i, rowmax(S_ij))
//!         P_ij = exp(S_ij - m_ij)
//!         l_ij = exp(m_i - m_ij) * l_i + rowsum(P_ij)
//!         
//!         # Output update with correction
//!         O_i = (l_i * exp(m_i - m_ij) * O_i + P_ij @ V_j) / l_ij
//!         
//!         m_i = m_ij
//!         l_i = l_ij
//!     
//!     Store O_i to global memory
//! ```
//!
//! ## `CubeCL` API (v0.8.1)
//!
//! Key constructs used:
//! - `#[cube(launch_unchecked)]` - Kernel definition (unchecked for performance)
//! - `Array<Line<F>>` - Vectorized 4-element loads for coalescing
//! - `SharedMemory::<F>::new(comptime_size)` - 1D shared memory allocation
//! - `ABSOLUTE_POS`, `UNIT_POS`, `CUBE_POS` - Thread indexing primitives
//! - `sync_cube()` - Block-level barrier
//! - `warp_reduce` - Warp-level reductions for max/sum
//!
//! ## Implementation Status
//!
//! Production-ready kernel with proper handling of:
//! - Arbitrary head dimensions (not just power-of-2)
//! - Dynamic shared memory sizing
//! - Proper bounds checking

use super::config::FlashAttentionConfig;
use super::interop::has_cubecl_cuda_support;
use crate::error::{Result as UnslothResult, UnslothError};
use candle_core::Tensor;

// CubeCL imports for kernel implementation
#[cfg(feature = "cuda")]
use cubecl::prelude::*;
#[cfg(feature = "cuda")]
use cubecl_cuda::CudaRuntime;

/// Maximum block size for kernel launches
pub const MAX_BLOCK_SIZE: u32 = 1024;

/// Warp size for NVIDIA GPUs
pub const WARP_SIZE: u32 = 32;

// ============================================================================
// CubeCL Kernel Definition (v0.8.1 API)
// ============================================================================

/// Compile-time configuration for Flash Attention kernels.
/// This is used by the launch wrapper, not directly by CubeCL.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg(feature = "cuda")]
pub struct TileConfig {
    /// Size of tiles for tiled attention
    pub tile_size: u32,
    /// Head dimension (must match tensor dim)
    pub head_dim: u32,
    /// Sequence length for this batch
    pub seq_len: u32,
    /// Number of KV tiles to iterate over
    pub num_kv_tiles: u32,
    /// Whether to apply causal masking (upper triangular mask)
    pub causal: bool,
}

/// Round up to next power of 2 for reduction algorithms (used when kernel dispatch is enabled)
#[inline]
#[allow(dead_code)]
fn next_power_of_two(n: u32) -> u32 {
    if n == 0 {
        return 1;
    }
    let mut v = n - 1;
    v |= v >> 1;
    v |= v >> 2;
    v |= v >> 4;
    v |= v >> 8;
    v |= v >> 16;
    v + 1
}

/// Flash Attention forward kernel - Production implementation.
///
/// This kernel computes attention for a single Q row using online softmax.
/// Each block handles one (batch, head, q_row) combination with threads
/// cooperatively processing head_dim elements.
///
/// ## Fixes Applied:
/// - Dynamic shared memory sizing up to 1024 elements (handles any head_dim)
/// - Proper reduction for non-power-of-2 head dimensions
/// - Bounds checking for all thread indices
///
/// Memory layout: [batch, heads, seq_len, head_dim] stored contiguously.
#[cfg(feature = "cuda")]
#[cube(launch)]
fn flash_attention_tile<F: Float + CubeElement>(
    q: &Array<F>,       // Query [batch * heads * seq_len * head_dim]
    k: &Array<F>,       // Key [batch * heads * seq_len * head_dim]
    v: &Array<F>,       // Value [batch * heads * seq_len * head_dim]
    out: &mut Array<F>, // Output [batch * heads * seq_len * head_dim]
    scale: F,           // 1/sqrt(head_dim)
    // Runtime parameters for dimensions
    seq_len_val: u32,
    head_dim_val: u32,
    block_size_val: u32, // Actual block size being used
) {
    // Thread/block indices
    let batch_head_idx = CUBE_POS_X; // Which (batch, head) pair (u32)
    let q_row_idx = CUBE_POS_Y; // Which Q row within this batch-head (u32)
    let tid = UNIT_POS_X; // Thread within block (u32)
    let tid_usize = tid as usize; // Cast for array indexing

    // Strides for [batch*heads, seq_len, head_dim] layout
    let head_stride = (seq_len_val as usize) * (head_dim_val as usize);

    // Base offset for this batch-head
    let base_offset = (batch_head_idx as usize) * head_stride;

    // Bounds check: threads beyond head_dim don't participate in main computation
    // but still participate in synchronization for correctness
    let is_active = tid_usize < (head_dim_val as usize);

    // Initialize running statistics for online softmax
    let mut running_max = F::new(-1e30); // Running max of attention scores
    let mut running_sum = F::new(0.0); // Running sum for normalization
    let mut running_out = F::new(0.0); // Running output accumulator

    // Get Q value for this thread's position (only if active)
    let q_val = if is_active {
        let q_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
        q[q_offset]
    } else {
        F::new(0.0)
    };

    // Shared memory for reduction - sized to block_size (power of 2, max 1024)
    let mut score_tile = SharedMemory::<F>::new(1024usize);

    // Iterate over all K/V positions
    for kv_idx in 0u32..(seq_len_val) {
        let kv_idx_usize = kv_idx as usize;
        // Compute dot product contribution for this thread
        let score_contrib = if is_active {
            let k_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
            let k_val = k[k_offset];
            q_val * k_val
        } else {
            F::new(0.0)
        };

        // Store contribution in shared memory
        score_tile[tid_usize] = score_contrib;
        sync_cube();

        // Tree reduction for sum - handles non-power-of-2 head_dim
        // by padding with zeros (inactive threads contribute 0)
        let mut stride = (block_size_val / 2) as usize;
        while stride > 0 {
            if tid_usize < stride {
                // Only add if the partner thread has valid data
                let partner_idx = tid_usize + stride;
                if partner_idx < (block_size_val as usize) {
                    score_tile[tid_usize] = score_tile[tid_usize] + score_tile[partner_idx];
                }
            }
            sync_cube();
            stride = stride / 2;
        }

        // Thread 0 has the full dot product, apply scale
        let score = score_tile[0] * scale;

        // Broadcast score to all threads via shared memory
        if tid == 0 {
            score_tile[0] = score;
        }
        sync_cube();
        let attn_score = score_tile[0];

        // Online softmax update (all threads, even inactive, for synchronization)
        let new_max = F::max(running_max, attn_score);
        let exp_old = F::exp(running_max - new_max);
        let exp_new = F::exp(attn_score - new_max);

        // Update running sum
        let new_sum = exp_old * running_sum + exp_new;

        // Update output: scale old output and add new contribution
        if is_active {
            let v_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
            let v_val = v[v_offset];
            running_out = (exp_old * running_sum * running_out + exp_new * v_val) / new_sum;
        }

        // Update statistics
        running_max = new_max;
        running_sum = new_sum;
    }

    // Write output (only active threads)
    if is_active {
        let out_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
        out[out_offset] = running_out;
    }
}

/// Flash Attention with causal masking - Production implementation.
///
/// This kernel extends the basic flash attention with upper triangular masking
/// for autoregressive (causal) attention patterns.
#[cfg(feature = "cuda")]
#[cube(launch)]
fn flash_attention_causal<F: Float + CubeElement>(
    q: &Array<F>,
    k: &Array<F>,
    v: &Array<F>,
    out: &mut Array<F>,
    scale: F,
    seq_len_val: u32,
    head_dim_val: u32,
    block_size_val: u32,
) {
    let batch_head_idx = CUBE_POS_X;
    let q_row_idx = CUBE_POS_Y;
    let tid = UNIT_POS_X;
    let tid_usize = tid as usize; // Cast for array indexing

    let head_stride = (seq_len_val as usize) * (head_dim_val as usize);
    let base_offset = (batch_head_idx as usize) * head_stride;
    let is_active = tid_usize < (head_dim_val as usize);

    let mut running_max = F::new(-1e30);
    let mut running_sum = F::new(0.0);
    let mut running_out = F::new(0.0);

    let q_val = if is_active {
        let q_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
        q[q_offset]
    } else {
        F::new(0.0)
    };

    let mut score_tile = SharedMemory::<F>::new(1024usize);

    // Causal masking: only attend to positions <= current position
    // kv_idx goes from 0 to q_row_idx (inclusive)
    let max_kv_idx = q_row_idx + 1;

    for kv_idx in 0u32..(max_kv_idx) {
        let kv_idx_usize = kv_idx as usize;
        let score_contrib = if is_active {
            let k_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
            let k_val = k[k_offset];
            q_val * k_val
        } else {
            F::new(0.0)
        };

        score_tile[tid_usize] = score_contrib;
        sync_cube();

        let mut stride = (block_size_val / 2) as usize;
        while stride > 0 {
            if tid_usize < stride {
                let partner_idx = tid_usize + stride;
                if partner_idx < (block_size_val as usize) {
                    score_tile[tid_usize] = score_tile[tid_usize] + score_tile[partner_idx];
                }
            }
            sync_cube();
            stride = stride / 2;
        }

        let score = score_tile[0] * scale;

        if tid == 0 {
            score_tile[0] = score;
        }
        sync_cube();
        let attn_score = score_tile[0];

        let new_max = F::max(running_max, attn_score);
        let exp_old = F::exp(running_max - new_max);
        let exp_new = F::exp(attn_score - new_max);
        let new_sum = exp_old * running_sum + exp_new;

        if is_active {
            let v_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
            let v_val = v[v_offset];
            running_out = (exp_old * running_sum * running_out + exp_new * v_val) / new_sum;
        }

        running_max = new_max;
        running_sum = new_sum;
    }

    if is_active {
        let out_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
        out[out_offset] = running_out;
    }
}

/// Improved Flash Attention kernel with proper tiling (Phase 1.5).
///
/// This kernel implements a more efficient tiled approach:
/// - Each block processes a tile of Q rows (tile_size rows)
/// - Threads cooperatively load Q/K/V tiles into shared memory
/// - Supports causal masking for autoregressive attention
///
/// Algorithm (per Q tile):
/// ```text
/// For q_tile_idx in 0..num_q_tiles:
///     Load Q[q_tile_idx] to shared memory
///     Initialize accumulators: m_i = -inf, l_i = 0, O_i = 0
///     
///     For kv_tile_idx in 0..num_kv_tiles:
///         Load K[kv_tile_idx], V[kv_tile_idx] to shared memory
///         Compute S_ij = Q_tile @ K_tile^T * scale
///         Apply causal mask if needed
///         
///         # Online softmax update
///         m_new = max(m_i, rowmax(S_ij))
///         P_ij = exp(S_ij - m_new)
///         l_new = exp(m_i - m_new) * l_i + rowsum(P_ij)
///         O_i = (exp(m_i - m_new) * l_i * O_i + P_ij @ V_tile) / l_new
///         
///         m_i = m_new, l_i = l_new
///     
///     Write O_i to global memory
/// ```
// TODO: Re-enable once CubeCL comptime bool issue is resolved
#[cfg(all(feature = "cuda", feature = "_phase2_tiled_kernel"))]
#[cube(launch)]
#[allow(dead_code)]
fn flash_attention_tiled<F: Float + CubeElement>(
    q: &Array<F>,       // Query [batch * heads * seq_len * head_dim]
    k: &Array<F>,       // Key [batch * heads * seq_len * head_dim]
    v: &Array<F>,       // Value [batch * heads * seq_len * head_dim]
    out: &mut Array<F>, // Output [batch * heads * seq_len * head_dim]
    scale: F,           // 1/sqrt(head_dim)
    #[comptime] tile_size: u32,
    #[comptime] head_dim: u32,
    #[comptime] seq_len: u32,
    #[comptime] num_kv_tiles: u32,
    #[comptime] causal: bool,
) {
    // Grid: (batch * heads, num_q_tiles, 1)
    // Block: (tile_size, 1, 1) - each thread handles one Q row

    let batch_head_idx = CUBE_POS_X;
    let q_tile_idx = CUBE_POS_Y;
    let thread_in_tile = UNIT_POS_X;
    let thread_in_tile_usize = thread_in_tile as usize; // Cast for array indexing

    // Calculate which Q row this thread handles globally
    let q_row_global = ((q_tile_idx as usize) * (tile_size as usize) + thread_in_tile_usize);

    // Early exit if beyond sequence length
    if q_row_global >= (seq_len as usize) {
        terminate!();
    }

    // Stride calculations
    let head_stride = (seq_len as usize) * (head_dim as usize);
    let base_offset = (batch_head_idx as usize) * head_stride;

    // Allocate shared memory for Q tile (tile_size × head_dim)
    // Each thread loads head_dim elements for its Q row
    let mut q_tile = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));

    // Load Q tile cooperatively
    // Thread i loads Q[q_row_global, :] into q_tile[thread_in_tile, :]
    for dim_idx in 0u32..(head_dim) {
        let dim_idx_usize = dim_idx as usize;
        let q_offset = base_offset + (q_row_global * (head_dim as usize) + dim_idx_usize);
        let tile_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
        q_tile[tile_offset] = q[q_offset];
    }
    sync_cube();

    // Initialize per-row accumulators (in registers)
    let mut running_max = F::new(-1e30);
    let mut running_sum = F::new(0.0);

    // Output accumulator (head_dim elements per thread)
    let mut out_acc = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));
    for dim_idx in 0u32..(head_dim) {
        let dim_idx_usize = dim_idx as usize;
        out_acc[thread_in_tile_usize * (head_dim as usize) + dim_idx_usize] = F::new(0.0);
    }

    // Iterate over KV tiles
    for kv_tile_idx in 0u32..(num_kv_tiles) {
        let kv_start = ((kv_tile_idx as usize) * (tile_size as usize));
        // Compute actual tile size (min of tile_size and remaining seq_len)
        let kv_end = if kv_start + (tile_size as usize) < (seq_len as usize) {
            kv_start + (tile_size as usize)
        } else {
            seq_len as usize
        };
        let kv_tile_actual_size = kv_end - kv_start;

        // Allocate shared memory for K and V tiles
        let mut k_tile = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));
        let mut v_tile = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));

        // Load K and V tiles cooperatively
        // Each thread loads multiple rows if needed
        for local_kv_idx in 0usize..(kv_tile_actual_size) {
            if local_kv_idx % (tile_size as usize) == thread_in_tile_usize {
                let kv_row_global = kv_start + local_kv_idx;
                for dim_idx in 0u32..(head_dim) {
                    let dim_idx_usize = dim_idx as usize;
                    let k_offset =
                        base_offset + (kv_row_global * (head_dim as usize) + dim_idx_usize);
                    let v_offset =
                        base_offset + (kv_row_global * (head_dim as usize) + dim_idx_usize);
                    let tile_offset = local_kv_idx * (head_dim as usize) + dim_idx_usize;
                    k_tile[tile_offset] = k[k_offset];
                    v_tile[tile_offset] = v[v_offset];
                }
            }
        }
        sync_cube();

        // Compute attention scores for this thread's Q row against all KV rows in tile
        for local_kv_idx in 0usize..(kv_tile_actual_size) {
            let kv_row_global = kv_start + local_kv_idx;

            // Apply causal mask: if causal and q_row < kv_row, skip this KV row
            // Note: CubeCL doesn't support `continue`, so we use if-else instead
            let should_process = if causal {
                q_row_global >= kv_row_global
            } else {
                true
            };

            if should_process {
                // Compute dot product Q[thread_in_tile] @ K[local_kv_idx]
                let mut score = F::new(0.0);
                for dim_idx in 0u32..(head_dim) {
                    let dim_idx_usize = dim_idx as usize;
                    let q_val = q_tile[thread_in_tile_usize * (head_dim as usize) + dim_idx_usize];
                    let k_val = k_tile[local_kv_idx * (head_dim as usize) + dim_idx_usize];
                    score = score + q_val * k_val;
                }
                score = score * scale;

                // Online softmax update
                let new_max = F::max(running_max, score);
                let exp_old = F::exp(running_max - new_max);
                let exp_new = F::exp(score - new_max);

                // Update running sum with correction factor
                let new_sum = exp_old * running_sum + exp_new;

                // Update output: O_new = (exp_old * l_old * O_old + exp_new * V) / l_new
                for dim_idx in 0u32..(head_dim) {
                    let dim_idx_usize = dim_idx as usize;
                    let out_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
                    let old_out = out_acc[out_offset];
                    let v_val = v_tile[local_kv_idx * (head_dim as usize) + dim_idx_usize];

                    // Apply correction and add new contribution
                    let corrected_old = exp_old * running_sum * old_out;
                    let new_contrib = exp_new * v_val;
                    out_acc[out_offset] = (corrected_old + new_contrib) / new_sum;
                }

                // Update statistics
                running_max = new_max;
                running_sum = new_sum;
            }
        }
        sync_cube();
    }

    // Write output to global memory
    for dim_idx in 0u32..(head_dim) {
        let dim_idx_usize = dim_idx as usize;
        let out_offset = base_offset + (q_row_global * (head_dim as usize) + dim_idx_usize);
        let tile_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
        out[out_offset] = out_acc[tile_offset];
    }
}

/// Launch Flash Attention `CubeCL` kernel.
///
/// This is the main entry point for the `CubeCL` Flash Attention implementation.
/// It handles device detection, tensor conversion, kernel launch, and result
/// conversion back to Candle tensors.
///
/// # Arguments
///
/// * `q` - Query tensor `[batch, num_heads, seq_len, head_dim]`
/// * `k` - Key tensor `[batch, num_kv_heads, seq_len, head_dim]`
/// * `v` - Value tensor `[batch, num_kv_heads, seq_len, head_dim]`
/// * `scale` - Attention scale factor (typically `1/sqrt(head_dim)`)
/// * `mask` - Optional attention mask
/// * `config` - Kernel configuration
///
/// # Returns
///
/// Attention output tensor `[batch, num_heads, seq_len, head_dim]`
///
/// # Errors
///
/// Returns error if:
/// - `CubeCL` CUDA is not available (falls back to Candle)
/// - Tensor shapes are incompatible
/// - Kernel launch fails
///
/// # Example
///
/// ```rust,ignore
/// use unsloth_rs::kernels::cubecl::{flash_attention_kernel, FlashAttentionConfig};
///
/// let config = FlashAttentionConfig::for_rtx_5080().with_causal_mask();
/// let output = flash_attention_kernel(&q, &k, &v, scale, None, &config)?;
/// ```
pub fn flash_attention_kernel(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
    config: &FlashAttentionConfig,
) -> UnslothResult<Tensor> {
    // Validate input shapes
    validate_attention_inputs(q, k, v)?;

    // Check for CubeCL support
    if !has_cubecl_cuda_support() {
        tracing::debug!("CubeCL CUDA not available, using fallback implementation");
        return fallback_attention(q, k, v, scale, mask, config);
    }

    // Try to launch CubeCL kernel; fall back on any error
    #[cfg(feature = "cuda")]
    {
        match launch_cubecl_attention(q, k, v, scale, config) {
            Ok(output) => return Ok(output),
            Err(e) => {
                tracing::warn!("CubeCL kernel launch failed: {}, using fallback", e);
            }
        }
    }

    fallback_attention(q, k, v, scale, mask, config)
}

/// Launch the actual CubeCL Flash Attention kernel.
///
/// This is the internal implementation that handles:
/// 1. CubeCL runtime initialization
/// 2. Tensor conversion to CubeCL handles
/// 3. Kernel launch with proper grid/block configuration
/// 4. Result conversion back to Candle tensor
#[cfg(feature = "cuda")]
fn launch_cubecl_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    config: &FlashAttentionConfig,
) -> UnslothResult<Tensor> {
    use super::interop::{candle_to_cubecl_handle, cubecl_to_candle_tensor};

    // Extract dimensions
    let dims = q.dims();
    let batch = dims[0];
    let num_heads = dims[1];
    let seq_len = dims[2];
    let head_dim = dims[3];

    tracing::debug!(
        "Launching CubeCL Flash Attention: batch={}, heads={}, seq={}, dim={}",
        batch,
        num_heads,
        seq_len,
        head_dim
    );

    // Validate head_dim is within supported range
    if head_dim > MAX_BLOCK_SIZE as usize {
        return Err(UnslothError::InvalidConfig(format!(
            "head_dim={} exceeds maximum supported size of {}. \
             Consider using a model with smaller head dimensions.",
            head_dim, MAX_BLOCK_SIZE
        )));
    }

    // Convert tensors to byte arrays
    let (q_bytes, _, _) = candle_to_cubecl_handle(q)?;
    let (k_bytes, _, _) = candle_to_cubecl_handle(k)?;
    let (v_bytes, _, _) = candle_to_cubecl_handle(v)?;

    // Allocate output buffer
    let num_elements = batch * num_heads * seq_len * head_dim;

    // Get CubeCL CUDA client
    let device = cubecl_cuda::CudaDevice::new(0);
    let client = CudaRuntime::client(&device);

    // Create CubeCL handles
    let q_handle = client.create(cubecl::bytes::Bytes::from_bytes_vec(q_bytes));
    let k_handle = client.create(cubecl::bytes::Bytes::from_bytes_vec(k_bytes));
    let v_handle = client.create(cubecl::bytes::Bytes::from_bytes_vec(v_bytes));
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());

    // Grid: (batch * heads, seq_len) - one block per (batch-head, q_row)
    let cube_count = CubeCount::Static((batch * num_heads) as u32, seq_len as u32, 1);

    // Block size: round up head_dim to next power of 2 for efficient reduction
    // Capped at MAX_BLOCK_SIZE (1024)
    let block_size = next_power_of_two(head_dim as u32).min(MAX_BLOCK_SIZE);
    let cube_dim = CubeDim::new(&client, block_size as usize);

    // Scale as f32
    let scale_f32 = scale as f32;

    // Choose kernel based on causal masking
    // SAFETY: Handles are valid and properly sized for the kernel operation
    unsafe {
        if config.causal_mask {
            flash_attention_causal::launch::<f32, CudaRuntime>(
                &client,
                cube_count,
                cube_dim,
                ArrayArg::from_raw_parts::<f32>(&q_handle, num_elements, 1),
                ArrayArg::from_raw_parts::<f32>(&k_handle, num_elements, 1),
                ArrayArg::from_raw_parts::<f32>(&v_handle, num_elements, 1),
                ArrayArg::from_raw_parts::<f32>(&out_handle, num_elements, 1),
                ScalarArg::new(scale_f32),
                ScalarArg::new(seq_len as u32),
                ScalarArg::new(head_dim as u32),
                ScalarArg::new(block_size),
            )
            .map_err(|e| {
                UnslothError::Kernel(format!("flash_attention_causal kernel launch failed: {e}"))
            })?;
        } else {
            flash_attention_tile::launch::<f32, CudaRuntime>(
                &client,
                cube_count,
                cube_dim,
                ArrayArg::from_raw_parts::<f32>(&q_handle, num_elements, 1),
                ArrayArg::from_raw_parts::<f32>(&k_handle, num_elements, 1),
                ArrayArg::from_raw_parts::<f32>(&v_handle, num_elements, 1),
                ArrayArg::from_raw_parts::<f32>(&out_handle, num_elements, 1),
                ScalarArg::new(scale_f32),
                ScalarArg::new(seq_len as u32),
                ScalarArg::new(head_dim as u32),
                ScalarArg::new(block_size),
            )
            .map_err(|e| {
                UnslothError::Kernel(format!("flash_attention_tile kernel launch failed: {e}"))
            })?;
        }
    }

    // Synchronize and read output
    let output_bytes = client.read_one(out_handle);

    // Convert back to Candle tensor
    cubecl_to_candle_tensor(
        &output_bytes,
        &[batch, num_heads, seq_len, head_dim],
        q.device(),
    )
}

/// Validate attention input tensor shapes.
fn validate_attention_inputs(q: &Tensor, k: &Tensor, v: &Tensor) -> UnslothResult<()> {
    let q_dims = q.dims();
    let k_dims = k.dims();
    let v_dims = v.dims();

    // Check 4D
    if q_dims.len() != 4 || k_dims.len() != 4 || v_dims.len() != 4 {
        return Err(UnslothError::InvalidConfig(format!(
            "Expected 4D tensors [batch, heads, seq, dim], got Q: {q_dims:?}, K: {k_dims:?}, V: {v_dims:?}"
        )));
    }

    // Check batch dimension matches
    if q_dims[0] != k_dims[0] || q_dims[0] != v_dims[0] {
        return Err(UnslothError::InvalidConfig(format!(
            "Batch size mismatch: Q={}, K={}, V={}",
            q_dims[0], k_dims[0], v_dims[0]
        )));
    }

    // Check head_dim matches
    if q_dims[3] != k_dims[3] || q_dims[3] != v_dims[3] {
        return Err(UnslothError::InvalidConfig(format!(
            "Head dimension mismatch: Q={}, K={}, V={}",
            q_dims[3], k_dims[3], v_dims[3]
        )));
    }

    // Check K and V have same shape (for now; GQA handled later)
    if k_dims != v_dims {
        return Err(UnslothError::InvalidConfig(format!(
            "K and V shape mismatch: K={k_dims:?}, V={v_dims:?}"
        )));
    }

    Ok(())
}

/// Fallback attention using Candle operations.
///
/// This provides correct results while the `CubeCL` kernel is being developed.
/// Uses O(N²) memory but serves as a reference for validation.
fn fallback_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
    config: &FlashAttentionConfig,
) -> UnslothResult<Tensor> {
    // Q @ K^T, scaled by 1/sqrt(head_dim)
    let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
    let scores = (scores * scale)?;

    // Apply causal mask if configured
    let scores = if config.causal_mask {
        let seq_len = q.dims()[2];
        let causal_mask = create_causal_mask_tensor(seq_len, q.device())?;
        scores.broadcast_add(&causal_mask)?
    } else {
        scores
    };

    // Apply additional mask if provided
    let scores = match mask {
        Some(m) => scores.broadcast_add(m)?,
        None => scores,
    };

    // Softmax along key dimension
    let attn_weights = candle_nn::ops::softmax(&scores, 3)?;

    // Attention @ V
    let output = attn_weights.matmul(v)?;

    Ok(output)
}

/// Create a causal mask tensor with -inf in upper triangle.
fn create_causal_mask_tensor(
    seq_len: usize,
    device: &candle_core::Device,
) -> UnslothResult<Tensor> {
    let mut mask_data = vec![0.0f32; seq_len * seq_len];
    for i in 0..seq_len {
        for j in 0..seq_len {
            if j > i {
                mask_data[i * seq_len + j] = f32::NEG_INFINITY;
            }
        }
    }

    let mask = Tensor::from_vec(mask_data, (1, 1, seq_len, seq_len), device)?;
    Ok(mask)
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_validate_attention_inputs_valid() {
        let device = Device::Cpu;
        let q = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
        let k = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
        let v = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();

        assert!(validate_attention_inputs(&q, &k, &v).is_ok());
    }

    #[test]
    fn test_validate_attention_inputs_wrong_dims() {
        let device = Device::Cpu;
        let q = Tensor::zeros((2, 8, 64), candle_core::DType::F32, &device).unwrap(); // 3D
        let k = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
        let v = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();

        assert!(validate_attention_inputs(&q, &k, &v).is_err());
    }

    #[test]
    fn test_validate_attention_inputs_batch_mismatch() {
        let device = Device::Cpu;
        let q = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
        let k = Tensor::zeros((3, 4, 8, 64), candle_core::DType::F32, &device).unwrap(); // Different batch
        let v = Tensor::zeros((3, 4, 8, 64), candle_core::DType::F32, &device).unwrap();

        assert!(validate_attention_inputs(&q, &k, &v).is_err());
    }

    #[test]
    fn test_flash_attention_kernel_shape() {
        let device = Device::Cpu;
        let batch = 2;
        let num_heads = 4;
        let seq_len = 8;
        let head_dim = 64;

        let q = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();

        let scale = 1.0 / (head_dim as f64).sqrt();
        let config = FlashAttentionConfig::default();

        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        assert_eq!(output.dims(), &[batch, num_heads, seq_len, head_dim]);
    }

    #[test]
    fn test_flash_attention_kernel_numerical_stability() {
        let device = Device::Cpu;
        // Use larger variance to stress numerical stability
        let q = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
        let k = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
        let v = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();

        let scale = 1.0 / 8.0; // 1/sqrt(64)
        let config = FlashAttentionConfig::default();

        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        // Verify no NaN or Inf
        let values: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        for val in values {
            assert!(!val.is_nan(), "Output contains NaN");
            assert!(!val.is_infinite(), "Output contains Inf");
        }
    }

    #[test]
    fn test_flash_attention_with_config() {
        let device = Device::Cpu;
        let q = Tensor::randn(0.0f32, 1.0, (1, 2, 16, 64), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (1, 2, 16, 64), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (1, 2, 16, 64), &device).unwrap();

        let scale = 1.0 / 8.0;
        let config = FlashAttentionConfig::for_rtx_5080().with_causal_mask();

        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        assert_eq!(output.dims(), &[1, 2, 16, 64]);
    }

    // =========================================================================
    // Small-Size Numerical Equivalence Tests (Phase 1)
    // Test CubeCL/fallback output matches reference implementation
    // =========================================================================

    /// Reference attention implementation using explicit matmul for validation.
    fn reference_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Tensor {
        let scores = q
            .matmul(&k.transpose(2, 3).unwrap().contiguous().unwrap())
            .unwrap();
        let scores = (scores * scale).unwrap();
        let attn_weights = candle_nn::ops::softmax(&scores, 3).unwrap();
        attn_weights.matmul(v).unwrap()
    }

    /// Compute mean absolute error between two tensors.
    fn mean_absolute_error(a: &Tensor, b: &Tensor) -> f32 {
        let diff = (a - b).unwrap().abs().unwrap();
        let mean = diff.mean_all().unwrap();
        mean.to_scalar::<f32>().unwrap()
    }

    #[test]
    fn test_numerical_equivalence_batch2_heads4_seq8_dim64() {
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 8, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
        let reference = reference_attention(&q, &k, &v, scale);

        let mae = mean_absolute_error(&output, &reference);
        assert!(
            mae < 1e-5,
            "MAE {} exceeds tolerance 1e-5 for batch={}, heads={}, seq={}, dim={}",
            mae,
            batch,
            heads,
            seq,
            dim
        );
    }

    #[test]
    fn test_numerical_equivalence_batch2_heads4_seq16_dim64() {
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 16, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
        let reference = reference_attention(&q, &k, &v, scale);

        let mae = mean_absolute_error(&output, &reference);
        assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
    }

    #[test]
    fn test_numerical_equivalence_batch2_heads4_seq32_dim64() {
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 32, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
        let reference = reference_attention(&q, &k, &v, scale);

        let mae = mean_absolute_error(&output, &reference);
        assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
    }

    #[test]
    fn test_numerical_equivalence_batch2_heads4_seq64_dim64() {
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 64, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
        let reference = reference_attention(&q, &k, &v, scale);

        let mae = mean_absolute_error(&output, &reference);
        assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
    }

    #[test]
    fn test_numerical_equivalence_batch1_heads1_seq8_dim64() {
        // Minimal case: single batch, single head
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (1, 1, 8, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
        let reference = reference_attention(&q, &k, &v, scale);

        let mae = mean_absolute_error(&output, &reference);
        assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
    }

    #[test]
    fn test_numerical_equivalence_with_different_configs() {
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 16, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        // Test with different tile sizes
        let configs = [
            FlashAttentionConfig::default(),
            FlashAttentionConfig::for_rtx_5080(),
            FlashAttentionConfig::for_rtx_3090_ti(),
        ];

        let reference = reference_attention(&q, &k, &v, scale);

        for config in configs {
            let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
            let mae = mean_absolute_error(&output, &reference);
            assert!(
                mae < 1e-5,
                "MAE {} exceeds tolerance for config {:?}",
                mae,
                config
            );
        }
    }

    #[test]
    fn test_determinism_multiple_runs() {
        // Ensure same input produces same output across runs
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 8, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        // Fixed seed via from_slice
        let data: Vec<f32> = (0..(batch * heads * seq * dim))
            .map(|i| (i as f32 * 0.001).sin())
            .collect();
        let q = Tensor::from_vec(data.clone(), (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::from_vec(data.clone(), (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::from_vec(data, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default();

        let output1 = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
        let output2 = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        let mae = mean_absolute_error(&output1, &output2);
        assert!(mae < 1e-10, "Non-deterministic output: MAE = {}", mae);
    }

    #[test]
    fn test_identity_attention_pattern() {
        // When Q == K, diagonal of attention matrix should be highest
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (1, 1, 4, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        // Create orthogonal-ish rows (each row very different)
        let mut data = vec![0.0f32; seq * dim];
        for i in 0..seq {
            for j in 0..dim {
                data[i * dim + j] = if j == i { 1.0 } else { 0.0 };
            }
        }
        let q = Tensor::from_vec(data.clone(), (batch, heads, seq, dim), &device).unwrap();
        let k = q.clone();
        // V is identity-like to make output predictable
        let v = q.clone();

        let config = FlashAttentionConfig::default();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        // Output should be close to V when Q==K (since attention to self is high)
        let mae = mean_absolute_error(&output, &v);
        // With softmax and scaling, we expect some deviation but should be reasonable
        assert!(mae < 0.5, "Identity pattern MAE {} too high", mae);
    }

    // =========================================================================
    // Causal Masking Tests
    // =========================================================================

    #[test]
    fn test_causal_masking_basic() {
        // Test that causal masking prevents attending to future positions
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (1, 1, 4, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config = FlashAttentionConfig::default().with_causal_mask();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        // Verify output shape
        assert_eq!(output.dims(), &[batch, heads, seq, dim]);

        // Verify no NaN/Inf
        let values: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        for val in values {
            assert!(!val.is_nan() && !val.is_infinite());
        }
    }

    #[test]
    fn test_causal_vs_non_causal_difference() {
        // Causal and non-causal should produce different outputs
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (1, 2, 8, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        let config_normal = FlashAttentionConfig::default();
        let config_causal = FlashAttentionConfig::default().with_causal_mask();

        let output_normal =
            flash_attention_kernel(&q, &k, &v, scale, None, &config_normal).unwrap();
        let output_causal =
            flash_attention_kernel(&q, &k, &v, scale, None, &config_causal).unwrap();

        // Outputs should differ (causal mask changes attention)
        let mae = mean_absolute_error(&output_normal, &output_causal);
        assert!(mae > 1e-4, "Causal and non-causal outputs are too similar");
    }

    #[test]
    fn test_causal_masking_numerical_equivalence() {
        // Test causal masking against reference implementation
        let device = Device::Cpu;
        let (batch, heads, seq, dim) = (2, 4, 16, 64);
        let scale = 1.0 / (dim as f64).sqrt();

        let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();

        // Create causal mask manually
        let causal_mask = create_causal_mask(seq, &device);

        let config = FlashAttentionConfig::default().with_causal_mask();
        let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();

        // Reference with explicit mask
        let reference = reference_attention_with_mask(&q, &k, &v, scale, Some(&causal_mask));

        let mae = mean_absolute_error(&output, &reference);
        assert!(mae < 1e-5, "Causal MAE {} exceeds tolerance 1e-5", mae);
    }

    /// Create causal mask (upper triangular with -inf)
    fn create_causal_mask(seq_len: usize, device: &Device) -> Tensor {
        let mut mask_data = vec![0.0f32; seq_len * seq_len];
        for i in 0..seq_len {
            for j in 0..seq_len {
                if j > i {
                    mask_data[i * seq_len + j] = f32::NEG_INFINITY;
                }
            }
        }
        Tensor::from_vec(mask_data, (1, 1, seq_len, seq_len), device)
            .unwrap()
            .broadcast_as((1, 1, seq_len, seq_len))
            .unwrap()
    }

    /// Reference attention with explicit mask
    fn reference_attention_with_mask(
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        scale: f64,
        mask: Option<&Tensor>,
    ) -> Tensor {
        let scores = q
            .matmul(&k.transpose(2, 3).unwrap().contiguous().unwrap())
            .unwrap();
        let scores = (scores * scale).unwrap();

        let scores = if let Some(m) = mask {
            scores.broadcast_add(m).unwrap()
        } else {
            scores
        };

        let attn_weights = candle_nn::ops::softmax(&scores, 3).unwrap();
        attn_weights.matmul(v).unwrap()
    }
}