trueno-gpu 0.4.17

Pure Rust PTX generation for NVIDIA CUDA - no LLVM, no nvcc
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
//! Activation Function Kernels
//!
//! GPU kernels for activation functions used in transformer FFN blocks.
//!
//! - `ReluKernel`: Rectified Linear Unit
//! - `SiluKernel`: Sigmoid Linear Unit (SiLU/Swish)
//! - `GeluKernel`: Gaussian Error Linear Unit
//! - `ElementwiseMulKernel`: Element-wise multiplication
//! - `ScaleKernel`: Scalar multiplication

#![allow(clippy::similar_names)]

use crate::kernels::Kernel;
use crate::ptx::builder::{PtxArithmetic, PtxComparison, PtxControl};
use crate::ptx::{PtxKernel, PtxReg, PtxType};

/// Rectified Linear Unit forward kernel.
///
/// Applies `ReLU(x) = max(0, x)` to every element of the input buffer and
/// writes the result to the output buffer.
///
/// # Issue #88: Forward kernel for training pipelines
#[derive(Clone, Debug)]
pub struct ReluKernel {
    /// Number of elements.
    ///
    /// `build_ptx` does not embed this value in the generated code; the
    /// kernel reads the element count from its `n` parameter instead.
    pub n: u32,
}

impl ReluKernel {
    /// Builds a ReLU kernel descriptor covering `n` elements.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}

impl Kernel for ReluKernel {
    /// Entry-point name of the generated kernel.
    fn name(&self) -> &str {
        "relu"
    }

    /// Builds the PTX for element-wise ReLU over f32 data.
    ///
    /// Kernel parameters, in declaration order: `input_ptr` (u64),
    /// `output_ptr` (u64), `n` (u32). Each thread handles the element at its
    /// flat index and skips the load/store when that index is not below `n`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("relu")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat element index = ctaid.x * ntid.x + tid.x
                let thread_in_block = ctx.special_reg(PtxReg::TidX);
                let block_index = ctx.special_reg(PtxReg::CtaIdX);
                let block_size = ctx.special_reg(PtxReg::NtidX);
                let elem = ctx.mad_lo_u32(block_index, block_size, thread_in_block);

                // Fetch the kernel arguments
                let count = ctx.load_param_u32("n");
                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");

                // Threads whose index is not below `n` jump straight to `exit`
                let is_valid = ctx.setp_lt_u32(elem, count);
                ctx.branch_if_not(is_valid, "exit");

                // Byte offset of this element: index * 4 (one f32), as a wide multiply
                let elem_size = ctx.mov_u32_imm(4);
                let byte_off = ctx.mul_wide_u32_reg(elem, elem_size);
                let src = ctx.add_u64(src_base, byte_off);
                let dst = ctx.add_u64(dst_base, byte_off);

                // Read the input element
                let value = ctx.ld_global_f32(src);

                // ReLU(x) = max(x, 0)
                let zero = ctx.mov_f32_imm(0.0);
                let clamped = ctx.max_f32(value, zero);

                // Write the output element
                ctx.st_global_f32(dst, clamped);

                ctx.label("exit");
                ctx.ret();
            })
    }
}

/// Sigmoid Linear Unit (SiLU, also called Swish) forward kernel.
///
/// Computes `SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))` per element.
/// This is the activation used in LLaMA/TinyLlama FFN blocks.
///
/// # PAR-023: Used in GPU-resident FFN block
#[derive(Clone, Debug)]
pub struct SiluKernel {
    /// Number of elements.
    ///
    /// `build_ptx` does not embed this value in the generated code; the
    /// kernel reads the element count from its `n` parameter instead.
    pub n: u32,
}

impl SiluKernel {
    /// Builds a SiLU kernel descriptor covering `n` elements.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}

impl Kernel for SiluKernel {
    /// Entry-point name of the generated kernel.
    fn name(&self) -> &str {
        "silu"
    }

    /// Builds the PTX for element-wise SiLU over f32 data.
    ///
    /// Kernel parameters, in declaration order: `input_ptr` (u64),
    /// `output_ptr` (u64), `n` (u32). Each thread handles the element at its
    /// flat index and skips the load/store when that index is not below `n`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("silu")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat element index = ctaid.x * ntid.x + tid.x
                let thread_in_block = ctx.special_reg(PtxReg::TidX);
                let block_index = ctx.special_reg(PtxReg::CtaIdX);
                let block_size = ctx.special_reg(PtxReg::NtidX);
                let elem = ctx.mad_lo_u32(block_index, block_size, thread_in_block);

                // Fetch the kernel arguments
                let count = ctx.load_param_u32("n");
                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");

                // Threads whose index is not below `n` jump straight to `exit`
                let is_valid = ctx.setp_lt_u32(elem, count);
                ctx.branch_if_not(is_valid, "exit");

                // Byte offset of this element: index * 4 (one f32), as a wide multiply
                let elem_size = ctx.mov_u32_imm(4);
                let byte_off = ctx.mul_wide_u32_reg(elem, elem_size);
                let src = ctx.add_u64(src_base, byte_off);
                let dst = ctx.add_u64(dst_base, byte_off);

                // Read the input element
                let value = ctx.ld_global_f32(src);

                // -x, formed as 0 - x
                let zero = ctx.mov_f32_imm(0.0);
                let negated = ctx.sub_f32(zero, value);
                // exp(-x) through the base-2 exponential: exp(t) = 2^(t * log2(e))
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let exponent = ctx.mul_f32(negated, log2_e);
                let exp_neg = ctx.ex2_f32(exponent);
                // sigmoid(x) = 1 / (1 + exp(-x))
                let one = ctx.mov_f32_imm(1.0);
                let one_plus_exp = ctx.add_f32(one, exp_neg);
                let sigmoid = ctx.div_f32(one, one_plus_exp);
                // SiLU(x) = x * sigmoid(x)
                let activated = ctx.mul_f32(value, sigmoid);

                // Write the output element
                ctx.st_global_f32(dst, activated);

                ctx.label("exit");
                ctx.ret();
            })
    }
}

/// Gaussian Error Linear Unit forward kernel (tanh-based approximation).
///
/// Computes `output ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`
/// per element. GELU is the activation used in GPT/BERT models.
///
/// # PAR-023: Used in GPU-resident FFN block for models using GELU
#[derive(Clone, Debug)]
pub struct GeluKernel {
    /// Number of elements.
    ///
    /// `build_ptx` does not embed this value in the generated code; the
    /// kernel reads the element count from its `n` parameter instead.
    pub n: u32,
}

impl GeluKernel {
    /// Builds a GELU kernel descriptor covering `n` elements.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}

impl Kernel for GeluKernel {
    /// Entry-point name of the generated kernel.
    fn name(&self) -> &str {
        "gelu"
    }

    /// Builds the PTX for element-wise approximate GELU over f32 data.
    ///
    /// Kernel parameters, in declaration order: `input_ptr` (u64),
    /// `output_ptr` (u64), `n` (u32). Each thread handles the element at its
    /// flat index and skips the load/store when that index is not below `n`.
    ///
    /// The formula is `0.5 * x * (1 + tanh(u))` with
    /// `u = sqrt(2/π) * (x + 0.044715 * x³)`, where `tanh(u)` is obtained from
    /// the identity `tanh(u) = 2 * sigmoid(2u) - 1`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("gelu")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat element index = ctaid.x * ntid.x + tid.x
                let thread_in_block = ctx.special_reg(PtxReg::TidX);
                let block_index = ctx.special_reg(PtxReg::CtaIdX);
                let block_size = ctx.special_reg(PtxReg::NtidX);
                let elem = ctx.mad_lo_u32(block_index, block_size, thread_in_block);

                // Fetch the kernel arguments
                let count = ctx.load_param_u32("n");
                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");

                // Threads whose index is not below `n` jump straight to `exit`
                let is_valid = ctx.setp_lt_u32(elem, count);
                ctx.branch_if_not(is_valid, "exit");

                // Byte offset of this element: index * 4 (one f32), as a wide multiply
                let elem_size = ctx.mov_u32_imm(4);
                let byte_off = ctx.mul_wide_u32_reg(elem, elem_size);
                let src = ctx.add_u64(src_base, byte_off);
                let dst = ctx.add_u64(dst_base, byte_off);

                // Read the input element
                let value = ctx.ld_global_f32(src);

                // Constants of the approximation; sqrt(2/π) ≈ 0.7978845608
                let sqrt_two_over_pi = ctx.mov_f32_imm(0.797_884_6);
                let cubic_coeff = ctx.mov_f32_imm(0.044_715);
                let half = ctx.mov_f32_imm(0.5);
                let one = ctx.mov_f32_imm(1.0);

                // x³ as (x * x) * x
                let squared = ctx.mul_f32(value, value);
                let cubed = ctx.mul_f32(squared, value);

                // u = sqrt(2/π) * (x + 0.044715 * x³)
                let cubic_term = ctx.mul_f32(cubic_coeff, cubed);
                let poly = ctx.add_f32(value, cubic_term);
                let tanh_arg = ctx.mul_f32(sqrt_two_over_pi, poly);

                // sigmoid(2u) = 1 / (1 + exp(-2u)), with exp(t) = 2^(t * log2(e))
                let two = ctx.mov_f32_imm(2.0);
                let zero = ctx.mov_f32_imm(0.0);
                let doubled = ctx.mul_f32(two, tanh_arg);
                let neg_doubled = ctx.sub_f32(zero, doubled);
                let log2_e = ctx.mov_f32_imm(std::f32::consts::LOG2_E);
                let exponent = ctx.mul_f32(neg_doubled, log2_e);
                let exp_neg = ctx.ex2_f32(exponent);
                let one_plus_exp = ctx.add_f32(one, exp_neg);
                let sigmoid = ctx.div_f32(one, one_plus_exp);
                // tanh(u) = 2 * sigmoid(2u) - 1
                let sigmoid_x2 = ctx.mul_f32(two, sigmoid);
                let tanh_u = ctx.sub_f32(sigmoid_x2, one);

                // GELU(x) ≈ (0.5 * x) * (1 + tanh(u))
                let gate = ctx.add_f32(one, tanh_u);
                let half_value = ctx.mul_f32(half, value);
                let activated = ctx.mul_f32(half_value, gate);

                // Write the output element
                ctx.st_global_f32(dst, activated);

                ctx.label("exit");
                ctx.ret();
            })
    }
}

/// Element-wise product kernel: `output[i] = input1[i] * input2[i]`.
///
/// Used for gated activations in SwiGLU: `silu(gate) * up`.
///
/// # PAR-023: Used in GPU-resident FFN block
#[derive(Clone, Debug)]
pub struct ElementwiseMulKernel {
    /// Number of elements.
    ///
    /// `build_ptx` does not embed this value in the generated code; the
    /// kernel reads the element count from its `n` parameter instead.
    pub n: u32,
}

impl ElementwiseMulKernel {
    /// Builds an element-wise multiply kernel descriptor covering `n` elements.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}

impl Kernel for ElementwiseMulKernel {
    /// Entry-point name of the generated kernel.
    fn name(&self) -> &str {
        "elementwise_mul"
    }

    /// Builds the PTX for an element-wise product of two f32 buffers.
    ///
    /// Kernel parameters, in declaration order: `input1_ptr` (u64),
    /// `input2_ptr` (u64), `output_ptr` (u64), `n` (u32). Each thread handles
    /// the element at its flat index and skips the loads/store when that
    /// index is not below `n`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("elementwise_mul")
            .param(PtxType::U64, "input1_ptr")
            .param(PtxType::U64, "input2_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat element index = ctaid.x * ntid.x + tid.x
                let thread_in_block = ctx.special_reg(PtxReg::TidX);
                let block_index = ctx.special_reg(PtxReg::CtaIdX);
                let block_size = ctx.special_reg(PtxReg::NtidX);
                let elem = ctx.mad_lo_u32(block_index, block_size, thread_in_block);

                // Fetch the kernel arguments
                let count = ctx.load_param_u32("n");
                let lhs_base = ctx.load_param_u64("input1_ptr");
                let rhs_base = ctx.load_param_u64("input2_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");

                // Threads whose index is not below `n` jump straight to `exit`
                let is_valid = ctx.setp_lt_u32(elem, count);
                ctx.branch_if_not(is_valid, "exit");

                // One shared byte offset (index * 4) addresses all three buffers
                let elem_size = ctx.mov_u32_imm(4);
                let byte_off = ctx.mul_wide_u32_reg(elem, elem_size);
                let lhs_addr = ctx.add_u64(lhs_base, byte_off);
                let rhs_addr = ctx.add_u64(rhs_base, byte_off);
                let dst = ctx.add_u64(dst_base, byte_off);

                // Read one operand from each input
                let lhs = ctx.ld_global_f32(lhs_addr);
                let rhs = ctx.ld_global_f32(rhs_addr);

                let product = ctx.mul_f32(lhs, rhs);

                // Write the output element
                ctx.st_global_f32(dst, product);

                ctx.label("exit");
                ctx.ret();
            })
    }
}

/// Scalar scaling kernel: `output[i] = input[i] * scale`.
///
/// Every element is multiplied by the same constant factor.
/// Used for attention score scaling (`1/sqrt(d_k)`).
#[derive(Clone, Debug)]
pub struct ScaleKernel {
    /// Number of elements.
    ///
    /// `build_ptx` does not embed this value in the generated code; the
    /// kernel reads the element count from its `n` parameter instead.
    pub n: u32,
}

impl ScaleKernel {
    /// Builds a scale kernel descriptor covering `n` elements.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn new(n: u32) -> Self {
        Self { n }
    }
}

impl Kernel for ScaleKernel {
    /// Entry-point name of the generated kernel.
    fn name(&self) -> &str {
        "scale"
    }

    /// Builds the PTX that multiplies an f32 buffer by a scalar.
    ///
    /// Kernel parameters, in declaration order: `input_ptr` (u64),
    /// `output_ptr` (u64), `scale` (f32), `n` (u32). Each thread handles the
    /// element at its flat index and skips the load/store when that index is
    /// not below `n`.
    fn build_ptx(&self) -> PtxKernel {
        PtxKernel::new("scale")
            .param(PtxType::U64, "input_ptr")
            .param(PtxType::U64, "output_ptr")
            .param(PtxType::F32, "scale")
            .param(PtxType::U32, "n")
            .build(|ctx| {
                // Flat element index = ctaid.x * ntid.x + tid.x
                let thread_in_block = ctx.special_reg(PtxReg::TidX);
                let block_index = ctx.special_reg(PtxReg::CtaIdX);
                let block_size = ctx.special_reg(PtxReg::NtidX);
                let elem = ctx.mad_lo_u32(block_index, block_size, thread_in_block);

                // Fetch the kernel arguments
                let count = ctx.load_param_u32("n");
                let src_base = ctx.load_param_u64("input_ptr");
                let dst_base = ctx.load_param_u64("output_ptr");
                let factor = ctx.load_param_f32("scale");

                // Threads whose index is not below `n` jump straight to `exit`
                let is_valid = ctx.setp_lt_u32(elem, count);
                ctx.branch_if_not(is_valid, "exit");

                // Byte offset of this element: index * 4 (one f32), as a wide multiply
                let elem_size = ctx.mov_u32_imm(4);
                let byte_off = ctx.mul_wide_u32_reg(elem, elem_size);
                let src = ctx.add_u64(src_base, byte_off);
                let dst = ctx.add_u64(dst_base, byte_off);

                // Read the input element
                let value = ctx.ld_global_f32(src);

                let scaled = ctx.mul_f32(value, factor);

                // Write the output element
                ctx.st_global_f32(dst, scaled);

                ctx.label("exit");
                ctx.ret();
            })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless every fragment in `needles` occurs in
    /// `haystack`; fragments are checked in the order given.
    #[track_caller]
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected to find `{}` in:\n{}",
                needle,
                haystack
            );
        }
    }

    // ---------------- ReluKernel ----------------

    #[test]
    fn test_relu_kernel_name() {
        let relu = ReluKernel::new(2048);
        assert_eq!(relu.name(), "relu");
    }

    #[test]
    fn test_relu_ptx_generation() {
        let relu = ReluKernel::new(2048);
        let ptx = relu.emit_ptx();

        // Entry point, then the max instruction that implements ReLU
        assert_contains_all(&ptx, &[".entry relu", "max.f32"]);
    }

    #[test]
    fn test_relu_kernel_debug() {
        let relu = ReluKernel::new(1024);
        let rendered = format!("{:?}", relu);
        assert_contains_all(&rendered, &["ReluKernel", "1024"]);
    }

    #[test]
    fn test_relu_kernel_clone() {
        let original = ReluKernel::new(512);
        let duplicate = original.clone();
        assert_eq!(duplicate.n, 512);
    }

    #[test]
    fn test_relu_kernel_ptx_contains_bounds_check() {
        let relu = ReluKernel::new(100);
        let ptx = relu.emit_ptx();
        // Predicate computation followed by a negated-predicate branch
        assert_contains_all(&ptx, &["setp.lt.u32", "@!"]);
    }

    #[test]
    fn test_relu_kernel_edge_case_n_zero() {
        let relu = ReluKernel::new(0);
        let ptx = relu.emit_ptx();
        // A zero element count must still produce a kernel entry point
        assert_contains_all(&ptx, &[".entry relu"]);
    }

    #[test]
    fn test_relu_kernel_edge_case_n_one() {
        let relu = ReluKernel::new(1);
        let ptx = relu.emit_ptx();
        assert_contains_all(&ptx, &[".entry relu", "max.f32"]);
    }

    #[test]
    fn test_relu_kernel_large_n() {
        let relu = ReluKernel::new(u32::MAX);
        let ptx = relu.emit_ptx();
        assert_contains_all(&ptx, &[".entry relu"]);
    }

    // ---------------- SiluKernel ----------------

    #[test]
    fn test_silu_kernel_name() {
        let silu = SiluKernel::new(2048);
        assert_eq!(silu.name(), "silu");
    }

    #[test]
    fn test_silu_ptx_generation() {
        let silu = SiluKernel::new(2048);
        let ptx = silu.emit_ptx();

        // Entry point, the exp + division that form the sigmoid, and the
        // final x * sigmoid multiply
        assert_contains_all(
            &ptx,
            &[".entry silu", "ex2.approx.f32", "div.rn.f32", "mul.f32"],
        );
    }

    #[test]
    fn test_silu_kernel_debug() {
        let silu = SiluKernel::new(4096);
        let rendered = format!("{:?}", silu);
        assert_contains_all(&rendered, &["SiluKernel", "4096"]);
    }

    #[test]
    fn test_silu_kernel_clone() {
        let original = SiluKernel::new(256);
        let duplicate = original.clone();
        assert_eq!(duplicate.n, 256);
    }

    #[test]
    fn test_silu_kernel_contains_log2e_constant() {
        let silu = SiluKernel::new(1000);
        let ptx = silu.emit_ptx();
        // Only the base-2 exponential is checked here, not the constant itself
        assert_contains_all(&ptx, &["ex2.approx.f32"]);
    }

    #[test]
    fn test_silu_kernel_ptx_structure() {
        let silu = SiluKernel::new(512);
        let ptx = silu.emit_ptx();
        // Parameter declarations, then the exit label
        assert_contains_all(
            &ptx,
            &[
                ".param .u64 input_ptr",
                ".param .u64 output_ptr",
                ".param .u32 n",
                "exit:",
            ],
        );
    }

    // ---------------- GeluKernel ----------------

    #[test]
    fn test_gelu_kernel_name() {
        let gelu = GeluKernel::new(2048);
        assert_eq!(gelu.name(), "gelu");
    }

    #[test]
    fn test_gelu_ptx_generation() {
        let gelu = GeluKernel::new(2048);
        let ptx = gelu.emit_ptx();

        // Entry point, the exp behind the sigmoid-based tanh, and a multiply
        assert_contains_all(&ptx, &[".entry gelu", "ex2.approx.f32", "mul.f32"]);
    }

    #[test]
    fn test_gelu_kernel_debug() {
        let gelu = GeluKernel::new(8192);
        let rendered = format!("{:?}", gelu);
        assert_contains_all(&rendered, &["GeluKernel", "8192"]);
    }

    #[test]
    fn test_gelu_kernel_clone() {
        let original = GeluKernel::new(128);
        let duplicate = original.clone();
        assert_eq!(duplicate.n, 128);
    }

    #[test]
    fn test_gelu_kernel_ptx_contains_tanh_approximation() {
        let gelu = GeluKernel::new(1000);
        let ptx = gelu.emit_ptx();
        // tanh is built as 2*sigmoid - 1: a division (sigmoid) and a subtraction
        assert_contains_all(&ptx, &["div.rn.f32", "sub.f32"]);
    }

    #[test]
    fn test_gelu_kernel_edge_case_n_zero() {
        let gelu = GeluKernel::new(0);
        let ptx = gelu.emit_ptx();
        assert_contains_all(&ptx, &[".entry gelu"]);
    }

    // ---------------- ElementwiseMulKernel ----------------

    #[test]
    fn test_elementwise_mul_kernel_name() {
        let mul = ElementwiseMulKernel::new(2048);
        assert_eq!(mul.name(), "elementwise_mul");
    }

    #[test]
    fn test_elementwise_mul_ptx_generation() {
        let mul = ElementwiseMulKernel::new(2048);
        let ptx = mul.emit_ptx();

        // Entry point, all four parameters, and the multiply itself
        assert_contains_all(
            &ptx,
            &[
                ".entry elementwise_mul",
                ".param .u64 input1_ptr",
                ".param .u64 input2_ptr",
                ".param .u64 output_ptr",
                ".param .u32 n",
                "mul.f32",
            ],
        );
    }

    #[test]
    fn test_elementwise_mul_kernel_debug() {
        let mul = ElementwiseMulKernel::new(1024);
        let rendered = format!("{:?}", mul);
        assert_contains_all(&rendered, &["ElementwiseMulKernel", "1024"]);
    }

    #[test]
    fn test_elementwise_mul_kernel_clone() {
        let original = ElementwiseMulKernel::new(64);
        let duplicate = original.clone();
        assert_eq!(duplicate.n, 64);
    }

    #[test]
    fn test_elementwise_mul_kernel_ptx_contains_bounds_check() {
        let mul = ElementwiseMulKernel::new(500);
        let ptx = mul.emit_ptx();
        assert_contains_all(&ptx, &["setp.lt.u32"]);
    }

    #[test]
    fn test_elementwise_mul_kernel_ptx_loads_two_inputs() {
        let mul = ElementwiseMulKernel::new(100);
        let ptx = mul.emit_ptx();
        // One global load per input operand
        let loads = ptx.matches("ld.global.f32").count();
        assert_eq!(loads, 2, "Should have exactly 2 global loads");
    }

    #[test]
    fn test_elementwise_mul_kernel_edge_case_n_one() {
        let mul = ElementwiseMulKernel::new(1);
        let ptx = mul.emit_ptx();
        assert_contains_all(&ptx, &[".entry elementwise_mul", "mul.f32"]);
    }

    #[test]
    fn test_elementwise_mul_kernel_large_n() {
        let mul = ElementwiseMulKernel::new(1_000_000);
        let ptx = mul.emit_ptx();
        assert_contains_all(&ptx, &[".entry elementwise_mul"]);
    }

    // ---------------- ScaleKernel ----------------

    #[test]
    fn test_scale_kernel_name() {
        let scale = ScaleKernel::new(2048);
        assert_eq!(scale.name(), "scale");
    }

    #[test]
    fn test_scale_ptx_generation() {
        let scale = ScaleKernel::new(2048);
        let ptx = scale.emit_ptx();

        assert_contains_all(&ptx, &[".entry scale", ".param .f32 scale", "mul.f32"]);
    }

    #[test]
    fn test_scale_kernel_debug() {
        let scale = ScaleKernel::new(512);
        let rendered = format!("{:?}", scale);
        assert_contains_all(&rendered, &["ScaleKernel", "512"]);
    }

    #[test]
    fn test_scale_kernel_clone() {
        let original = ScaleKernel::new(32);
        let duplicate = original.clone();
        assert_eq!(duplicate.n, 32);
    }

    #[test]
    fn test_scale_kernel_ptx_structure() {
        let scale = ScaleKernel::new(256);
        let ptx = scale.emit_ptx();
        // Presence of each parameter declaration (relative order is not checked)
        assert_contains_all(
            &ptx,
            &[
                ".param .u64 input_ptr",
                ".param .u64 output_ptr",
                ".param .f32 scale",
                ".param .u32 n",
            ],
        );
    }

    #[test]
    fn test_scale_kernel_edge_case_n_zero() {
        let scale = ScaleKernel::new(0);
        let ptx = scale.emit_ptx();
        assert_contains_all(&ptx, &[".entry scale"]);
    }

    #[test]
    fn test_scale_kernel_ptx_uses_f32_scale_param() {
        let scale = ScaleKernel::new(100);
        let ptx = scale.emit_ptx();
        // The f32 scale parameter is declared and a multiply is emitted
        assert_contains_all(&ptx, &[".param .f32 scale", "mul.f32"]);
    }
}