Skip to main content

ferrotorch_gpu/
kernels.rs

//! Custom PTX CUDA kernels for GPU tensor operations.
//!
//! Each operation has two code paths:
//!
//! 1. **PTX kernel** -- a hand-written PTX string that is loaded into the CUDA
//!    driver at runtime via [`cudarc`]. This is the fast path and runs entirely
//!    on the GPU.
//! 2. **CPU fallback** -- copies data to the host, performs the operation with
//!    standard Rust iterators, and copies the result back. Correct but slow;
//!    used when the PTX module cannot be loaded (e.g. architecture mismatch).
//!
//! Kernels are written for f32; simple ones can be turned into f64 variants
//! mechanically with `ptx_f32_to_f64`. Besides the elementwise operations
//! listed below, this module also holds PTX for scaling, 2-D transpose and
//! 4-D permute, f16/bf16 casts, small matmuls, slice writes, attention masks,
//! and embedding lookups.
//!
//! # Supported operations
//!
//! | Function | Formula |
//! |----------|---------|
//! | [`gpu_add`] | `out[i] = a[i] + b[i]` |
//! | [`gpu_sub`] | `out[i] = a[i] - b[i]` |
//! | [`gpu_mul`] | `out[i] = a[i] * b[i]` |
//! | [`gpu_neg`] | `out[i] = -a[i]` |
//! | [`gpu_relu`] | `out[i] = max(a[i], 0.0)` |
21
22#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, cpu_to_gpu, gpu_to_cpu};
30
31// ---------------------------------------------------------------------------
32// f32 → f64 PTX auto-conversion
33// ---------------------------------------------------------------------------
34
/// Convert an f32 PTX kernel string to its f64 equivalent by applying
/// mechanical substitutions.
///
/// Works for "simple" kernels where the only difference between f32 and f64
/// is register types, load/store widths, byte offsets, and float literals.
/// First every occurrence of `f32_kernel_name` is renamed to `f64_kernel_name`
/// (so the name should be unique within the PTX). Then each `(from, to)` pair
/// of the substitution table is applied, top to bottom, with [`str::replace`].
/// Each step operates on the previous step's output.
///
/// Instructions that exist only for f32 (`div.approx`, `div.full`,
/// `sqrt.approx`, `rcp.approx`) have no f64 form. They are mapped to the IEEE
/// round-to-nearest f64 instructions (`div.rn.f64`, `sqrt.rn.f64`, `rcp.rn.f64`).
///
/// # Limitations
///
/// - Does NOT work for kernels that use `ex2.approx.f32` or `lg2.approx.f32`
///   (transcendentals) — those need hand-written f64 implementations.
/// - Element byte offsets are rewritten only for the exact instruction
///   `shl.b64 %off, %off, 2`. A kernel that scales indices into any other
///   register (e.g. `%off_in`, `%a_off`, `%src_off`) keeps a 4-byte stride
///   after conversion and will address memory incorrectly.
/// - Vector memory operations (`ld.global.v4.f32`, ...) are not converted.
/// - Only the float literals listed in the table are converted; any other
///   `0f...` literal is left untouched.
/// - `mov.b32` is rewritten to `mov.b64` unconditionally. That is only correct
///   when every `mov.b32` in the kernel moves float bits.
#[cfg(feature = "cuda")]
pub(crate) fn ptx_f32_to_f64(
    f32_ptx: &str,
    f32_kernel_name: &str,
    f64_kernel_name: &str,
) -> String {
    // Ordered `(f32 text, f64 text)` substitutions, applied top to bottom.
    const SUBSTITUTIONS: &[(&str, &str)] = &[
        // Register declarations
        (".reg .f32", ".reg .f64"),
        // Memory operations
        ("ld.global.f32", "ld.global.f64"),
        ("ld.global.nc.f32", "ld.global.nc.f64"),
        ("st.global.f32", "st.global.f64"),
        ("ld.shared.f32", "ld.shared.f64"),
        ("st.shared.f32", "st.shared.f64"),
        ("ld.param.f32", "ld.param.f64"),
        (".param .f32", ".param .f64"),
        // Shared memory declarations (f64 needs 8-byte alignment)
        (".shared .align 4 .f32", ".shared .align 8 .f64"),
        // Arithmetic
        ("add.f32", "add.f64"),
        ("add.rn.f32", "add.rn.f64"),
        ("sub.f32", "sub.f64"),
        ("sub.rn.f32", "sub.rn.f64"),
        ("mul.f32", "mul.f64"),
        ("mul.rn.f32", "mul.rn.f64"),
        ("div.rn.f32", "div.rn.f64"),
        ("div.f32", "div.f64"),
        ("neg.f32", "neg.f64"),
        ("abs.f32", "abs.f64"),
        ("max.f32", "max.f64"),
        ("min.f32", "min.f64"),
        ("sqrt.rn.f32", "sqrt.rn.f64"),
        ("sqrt.f32", "sqrt.f64"),
        ("rcp.rn.f32", "rcp.rn.f64"),
        ("fma.rn.f32", "fma.rn.f64"),
        ("mov.f32", "mov.f64"),
        ("selp.f32", "selp.f64"),
        // f32-only approximate forms: f64 has no `.approx`/`.full` division
        // or `.approx` sqrt, so use IEEE round-to-nearest instead.
        ("div.approx.ftz.f32", "div.rn.f64"),
        ("div.approx.f32", "div.rn.f64"),
        ("div.full.ftz.f32", "div.rn.f64"),
        ("div.full.f32", "div.rn.f64"),
        ("sqrt.approx.ftz.f32", "sqrt.rn.f64"),
        ("sqrt.approx.f32", "sqrt.rn.f64"),
        ("rcp.approx.ftz.f32", "rcp.rn.f64"),
        ("rcp.approx.f32", "rcp.rn.f64"),
        // Comparisons (ordered forms)
        ("setp.gt.f32", "setp.gt.f64"),
        ("setp.ge.f32", "setp.ge.f64"),
        ("setp.lt.f32", "setp.lt.f64"),
        ("setp.le.f32", "setp.le.f64"),
        ("setp.eq.f32", "setp.eq.f64"),
        ("setp.ne.f32", "setp.ne.f64"),
        // Comparisons (unordered forms and NaN tests)
        ("setp.gtu.f32", "setp.gtu.f64"),
        ("setp.geu.f32", "setp.geu.f64"),
        ("setp.ltu.f32", "setp.ltu.f64"),
        ("setp.leu.f32", "setp.leu.f64"),
        ("setp.equ.f32", "setp.equ.f64"),
        ("setp.neu.f32", "setp.neu.f64"),
        ("setp.num.f32", "setp.num.f64"),
        ("setp.nan.f32", "setp.nan.f64"),
        // Conversions: integer -> float
        ("cvt.rn.f32.u32", "cvt.rn.f64.u32"),
        ("cvt.rn.f32.s32", "cvt.rn.f64.s32"),
        ("cvt.rn.f32.u64", "cvt.rn.f64.u64"),
        ("cvt.rn.f32.s64", "cvt.rn.f64.s64"),
        // Conversions: float -> integer (the source register is now .f64)
        ("cvt.rzi.u32.f32", "cvt.rzi.u32.f64"),
        ("cvt.rzi.s32.f32", "cvt.rzi.s32.f64"),
        ("cvt.rni.u32.f32", "cvt.rni.u32.f64"),
        ("cvt.rni.s32.f32", "cvt.rni.s32.f64"),
        // Bit reinterpretation (for NaN/inf checks)
        ("mov.b32", "mov.b64"),
        // Byte offset: 4 bytes per f32 -> 8 bytes per f64 (only for `%off`)
        ("shl.b64 %off, %off, 2", "shl.b64 %off, %off, 3"),
        // Atomics (normally already rewritten by the `add.f32` rule above,
        // because `add.f32` is a substring of this instruction)
        ("atom.global.add.f32", "atom.global.add.f64"),
        // Common float hex literals
        ("0f00000000", "0d0000000000000000"), // 0.0
        ("0f3F800000", "0d3FF0000000000000"), // 1.0
        ("0fBF800000", "0dBFF0000000000000"), // -1.0
        ("0f40000000", "0d4000000000000000"), // 2.0
        ("0f3F000000", "0d3FE0000000000000"), // 0.5
        ("0fFF800000", "0dFFF0000000000000"), // -inf
        ("0f7F800000", "0d7FF0000000000000"), // +inf
        ("0f3FB8AA3B", "0d3FF71547652B82FE"), // log2(e)
        ("0f3F317218", "0d3FE62E42FEFA39EF"), // ln(2)
    ];

    // With an empty pattern, `str::replace` inserts the replacement between
    // every character, which would corrupt the PTX; skip the rename instead.
    let renamed = if f32_kernel_name.is_empty() {
        f32_ptx.to_owned()
    } else {
        f32_ptx.replace(f32_kernel_name, f64_kernel_name)
    };

    SUBSTITUTIONS
        .iter()
        .fold(renamed, |ptx, &(from, to)| ptx.replace(from, to))
}
99
/// Return the f64 PTX for a kernel, converting it from the f32 source on first use.
///
/// The first call on a given `cache` runs [`ptx_f32_to_f64`] and stores the
/// result. Every later call returns the stored string without converting again.
/// The returned `&str` borrows from `cache`, so it is valid for as long as the
/// cache is; when the cache is a `static`, that is the lifetime of the program.
///
/// The cache is not keyed on the arguments. Once it is initialized, the
/// `f32_ptx`, `f32_name` and `f64_name` passed to later calls are ignored, so
/// use one `OnceLock` per kernel.
#[cfg(feature = "cuda")]
pub(crate) fn get_f64_ptx<'a>(
    cache: &'a std::sync::OnceLock<String>,
    f32_ptx: &str,
    f32_name: &str,
    f64_name: &str,
) -> &'a str {
    let converted: &'a String =
        cache.get_or_init(|| ptx_f32_to_f64(f32_ptx, f32_name, f64_name));
    converted.as_str()
}
113
114// ---------------------------------------------------------------------------
115// PTX kernel source strings
116// ---------------------------------------------------------------------------
117
/// PTX source for `add_kernel`: `out[i] = a[i] + b[i]`.
///
/// Kernel parameters:
/// - `a_ptr`, `b_ptr`, `out_ptr` are device addresses of f32 buffers.
/// - `n` is the element count.
///
/// Each thread handles one element at global index
/// `ctaid.x * ntid.x + tid.x`; only the `.x` launch dimension is used. Threads
/// whose index is `>= n` return without touching memory. Index math is 32-bit,
/// and the byte offset is `index << 2` (4 bytes per f32).
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    add.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
165
166
/// PTX source for `add_vec4_kernel`: vectorized add, 4 elements per thread.
///
/// Uses `ld.global.v4.f32` / `st.global.v4.f32` (128-bit accesses), so it
/// issues fewer, wider memory transactions than the scalar kernel. Thread `i`
/// processes elements `4*i ..= 4*i + 3`, at byte offset `i << 4`.
///
/// `n4` is the number of 4-element groups, not the element count. Threads with
/// index `>= n4` return, so elements past `4 * n4` are never touched. Presumably
/// the caller handles any remainder separately (TODO confirm).
///
/// PTX vector accesses must be aligned to the full vector width, so `a_ptr`,
/// `b_ptr` and `out_ptr` must be 16-byte aligned.
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    // Byte offset = tid * 16 (4 floats × 4 bytes)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    add.f32 %r0, %a0, %b0;
    add.f32 %r1, %a1, %b1;
    add.f32 %r2, %a2, %b2;
    add.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
223
/// PTX source for `mul_vec4_kernel`: vectorized multiply, 4 elements per thread.
///
/// Same layout as `add_vec4_kernel`: thread `i` multiplies elements
/// `4*i ..= 4*i + 3` using 128-bit `v4.f32` loads and stores at byte offset
/// `i << 4`.
///
/// `n4` is the number of 4-element groups, so elements past `4 * n4` are not
/// touched; presumably the caller handles any remainder (TODO confirm).
/// `a_ptr`, `b_ptr` and `out_ptr` must be 16-byte aligned, because PTX vector
/// accesses require full-vector alignment.
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    mul.f32 %r0, %a0, %b0;
    mul.f32 %r1, %a1, %b1;
    mul.f32 %r2, %a2, %b2;
    mul.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
276
/// PTX source for `sub_kernel`: `out[i] = a[i] - b[i]`.
///
/// Kernel parameters:
/// - `a_ptr`, `b_ptr`, `out_ptr` are device addresses of f32 buffers.
/// - `n` is the element count.
///
/// Each thread handles one element at global index `ctaid.x * ntid.x + tid.x`.
/// Threads with index `>= n` return without touching memory. Byte offsets are
/// `index << 2` (4 bytes per f32).
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    sub.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
324
325
/// PTX source for `mul_kernel`: `out[i] = a[i] * b[i]`.
///
/// Kernel parameters:
/// - `a_ptr`, `b_ptr`, `out_ptr` are device addresses of f32 buffers.
/// - `n` is the element count.
///
/// Each thread handles one element at global index `ctaid.x * ntid.x + tid.x`.
/// Threads with index `>= n` return without touching memory. Byte offsets are
/// `index << 2` (4 bytes per f32).
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    mul.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
373
374
/// PTX source for `neg_kernel`: `out[i] = -a[i]`.
///
/// Unary kernel:
/// - `a_ptr` is the f32 input and `out_ptr` the f32 output.
/// - `n` is the element count.
///
/// One thread per element at global index `ctaid.x * ntid.x + tid.x`. Threads
/// with index `>= n` return immediately. Byte offsets are `index << 2`.
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    neg.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
418
419
/// PTX source for `relu_kernel`: `out[i] = max(a[i], 0.0)`.
///
/// Unary kernel:
/// - `a_ptr` is the f32 input and `out_ptr` the f32 output.
/// - `n` is the element count; threads with index `>= n` return immediately.
///
/// The zero is the literal `0f00000000`, and the result comes from PTX
/// `max.f32`. That instruction returns the non-NaN operand when exactly one
/// input is NaN, so a NaN input yields `0.0` here rather than propagating.
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %zero;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mov.f32 %zero, 0f00000000;
    max.f32 %vr, %va, %zero;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
464
465
/// PTX source for `scale_kernel`: `out[i] = a[i] * scalar`.
///
/// Kernel parameters:
/// - `a_ptr` and `out_ptr` are device addresses of f32 buffers.
/// - `scalar` is passed by value as an f32 kernel parameter.
/// - `n` is the element count.
///
/// One thread per element; threads with index `>= n` return immediately.
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %s;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %s, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mul.f32 %vr, %va, %s;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
511
512
/// PTX for 2D matrix transpose: `out[j * M + i] = in[i * N + j]`.
/// Thread `tid` maps to output index; computes the corresponding input index.
///
/// The input is a row-major `[M, N]` f32 matrix and the output a row-major
/// `[N, M]` one. Output index `tid` is split into `out_row = tid / M` and
/// `out_col = tid % M`, and the thread reads input index `out_col * N + out_row`.
///
/// `total` only bounds the launch. It is expected to equal `M * N`; that is
/// the caller's responsibility and is not checked here. All index math is
/// 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 M,\n\
    .param .u32 N,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
    .reg .u32 %out_row, %out_col, %in_idx;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %M_reg, [M];\n\
    ld.param.u32 %N_reg, [N];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
    div.u32 %out_row, %r_tid, %M_reg;\n\
    rem.u32 %out_col, %r_tid, %M_reg;\n\
    // Input index: out_col * N + out_row (transposed).\n\
    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
568
569
570// ---------------------------------------------------------------------------
571// 4D permute (0,2,1,3) PTX kernel — swap dims 1 and 2
572// ---------------------------------------------------------------------------
573// Input:  [d0, d1, d2, d3]
574// Output: [d0, d2, d1, d3]
575// Thread i computes output[i] by mapping to the transposed input index.
576
/// PTX for `permute_0213_kernel`: swap dims 1 and 2 of a contiguous 4-D f32
/// tensor, i.e. `out[i0, i2, i1, i3] = in[i0, i1, i2, i3]`.
///
/// The input shape is `[d0, d1, d2, d3]` and the output shape is
/// `[d0, d2, d1, d3]`, both row-major. Thread `tid` owns output element `tid`:
/// - It decomposes `tid` using the output strides `(d2*d1*d3, d1*d3, d3, 1)`.
/// - It gathers from input index `((i0*d1 + i1) * d2 + i2) * d3 + i3`.
///
/// `total` bounds the launch and is expected to equal `d0*d1*d2*d3`; this is
/// not checked here. All index math is 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 d0,\n\
    .param .u32 d1,\n\
    .param .u32 d2,\n\
    .param .u32 d3,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %d0r, [d0];\n\
    ld.param.u32 %d1r, [d1];\n\
    ld.param.u32 %d2r, [d2];\n\
    ld.param.u32 %d3r, [d3];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape: [d0, d2, d1, d3]\n\
    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
    div.u32 %i0, %r_tid, %s_out1;\n\
    rem.u32 %rem, %r_tid, %s_out1;\n\
    div.u32 %i2, %rem, %s_out2;\n\
    rem.u32 %rem, %rem, %s_out2;\n\
    div.u32 %i1, %rem, %d3r;\n\
    rem.u32 %i3, %rem, %d3r;\n\
\n\
    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
    mul.lo.u32 %in_idx, %i0, %d1r;\n\
    add.u32 %in_idx, %in_idx, %i1;\n\
    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
    add.u32 %in_idx, %in_idx, %i3;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
650
651
652// ---------------------------------------------------------------------------
653// f32-to-f16 conversion PTX kernel: out_f16[i] = float2half(in_f32[i])
654// ---------------------------------------------------------------------------
655// Used by gpu_matmul_f16 to cast f32 inputs to f16 on-GPU before calling
656// cublasGemmEx. The output is stored as u16 (IEEE 754 half-precision bits).
657
/// PTX for `f32_to_f16_kernel`: `out[i] = f16(in[i])`.
///
/// Reads f32 values from `in_ptr` (4-byte stride). Writes IEEE half-precision
/// bit patterns to `out_ptr` (2-byte stride, stored with `st.global.b16`).
/// The conversion uses `cvt.rn.f16.f32`, which rounds to nearest even.
/// `n` is the element count; threads with index `>= n` return immediately.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .b16 %vh;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Compute input offset: i * 4 (f32 = 4 bytes)
    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    // Compute output offset: i * 2 (f16 = 2 bytes)
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32, convert to f16 (round-to-nearest-even), store as u16
    ld.global.f32 %vf, [%in];
    cvt.rn.f16.f32 %vh, %vf;
    st.global.b16 [%out], %vh;

DONE:
    ret;
}
";
706
/// PTX source for `f32_to_bf16_kernel`: convert f32 → bf16 (stored as u16).
///
/// BF16 keeps the top 16 bits of an f32. Finite values and infinities are
/// rounded to nearest-even with integer bit ops: add a rounding bias of
/// `0x7FFF + bit[16]`, then shift right by 16. A finite value that rounds past
/// the largest bf16 carries into the exponent and correctly becomes ±Inf.
///
/// NaN inputs (`|bits| > 0x7F800000`) must not take that path, because the
/// carry from the bias can leave the mantissa:
/// - `0x7F800001` would become `0x7F80` (+Inf).
/// - `0x7FFFFFFF` would carry into the sign bit and become `0x8000` (-0.0).
///
/// For NaN the kernel instead keeps the top 16 bits (sign included) and forces
/// the bf16 quiet bit `0x0040`, so NaN stays NaN. The selection is branchless
/// (`selp`).
///
/// Only integer instructions are used, so this works on sm_52+ (no special
/// bf16 instructions needed). `n` is the element count; input stride is 4
/// bytes and output stride is 2 bytes.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .u32 %bits, %round, %lsb, %result, %abs, %nan_bits;
    .reg .pred %p, %is_nan;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32 as raw bits
    ld.global.u32 %bits, [%in];

    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
    shr.u32 %lsb, %bits, 16;
    and.b32 %lsb, %lsb, 1;
    add.u32 %round, %bits, 0x7FFF;
    add.u32 %round, %round, %lsb;
    shr.u32 %result, %round, 16;

    // NaN (exponent all ones, mantissa non-zero): the rounding carry above
    // could turn it into Inf or -0.0, so keep the top 16 bits (sign included)
    // and force the bf16 quiet-NaN bit 0x0040 instead.
    and.b32 %abs, %bits, 0x7FFFFFFF;
    setp.gt.u32 %is_nan, %abs, 0x7F800000;
    shr.u32 %nan_bits, %bits, 16;
    or.b32 %nan_bits, %nan_bits, 0x0040;
    selp.u32 %result, %nan_bits, %result, %is_nan;

    // Store as u16
    st.global.u16 [%out], %result;

DONE:
    ret;
}
";
767
768// ---------------------------------------------------------------------------
769// Small matmul PTX kernel: C = A @ B, one thread per output element
770// ---------------------------------------------------------------------------
771// For small matrices where cuBLAS JIT compilation overhead > compute time.
772// Compiles once via module_cache, never JIT-recompiles for different sizes.
773
/// PTX for `small_matmul_kernel`: `C = A @ B`, one thread per output element.
///
/// Layouts, as implied by the indexing `a[row*K + p]`, `b[p*N + col]`, `c[tid]`:
/// - `A` is row-major `[M, K]`.
/// - `B` is row-major `[K, N]`.
/// - `C` is row-major `[M, N]`.
///
/// All three are f32. Thread `tid` computes `C[tid / N, tid % N]` as a serial
/// `fma.rn.f32` dot product over `K`.
///
/// `total` bounds the launch and is expected to be `M * N`. The `M` parameter
/// is loaded but not otherwise used.
///
/// Note: in this kernel `%p` is a `u32` loop counter, not a predicate.
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 K,
    .param .u32 N,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
    .reg .u32 %row, %col, %p, %idx;
    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
    .reg .f32 %sum, %va, %vb;
    .reg .pred %bounds_p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %c, [c_ptr];
    ld.param.u32 %M_reg, [M];
    ld.param.u32 %K_reg, [K];
    ld.param.u32 %N_reg, [N];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    div.u32 %row, %r_tid, %N_reg;
    rem.u32 %col, %r_tid, %N_reg;

    mov.f32 %sum, 0f00000000;
    mov.u32 %p, 0;
DOT:
    setp.ge.u32 %loop_p, %p, %K_reg;
    @%loop_p bra DOT_DONE;

    mad.lo.u32 %idx, %row, %K_reg, %p;
    cvt.u64.u32 %a_off, %idx;
    shl.b64 %a_off, %a_off, 2;
    add.u64 %a_off, %a, %a_off;
    ld.global.f32 %va, [%a_off];

    mad.lo.u32 %idx, %p, %N_reg, %col;
    cvt.u64.u32 %b_off, %idx;
    shl.b64 %b_off, %b_off, 2;
    add.u64 %b_off, %b, %b_off;
    ld.global.f32 %vb, [%b_off];

    fma.rn.f32 %sum, %va, %vb, %sum;
    add.u32 %p, %p, 1;
    bra DOT;
DOT_DONE:

    cvt.u64.u32 %c_off, %r_tid;
    shl.b64 %c_off, %c_off, 2;
    add.u64 %c_off, %c, %c_off;
    st.global.f32 [%c_off], %sum;

DONE:
    ret;
}
";
846
847// ---------------------------------------------------------------------------
848// Slice-write PTX kernel: copy [N, D] into row `pos` of [N, max_len, D]
849// ---------------------------------------------------------------------------
850
/// PTX for `slice_write_kernel`: copy a contiguous `[N, D]` f32 block into
/// position `pos` along the middle axis of a `[N, max_len, D]` buffer:
/// `dst[(b * max_len + pos) * D + d] = src[b * D + d]`.
///
/// `n` is the number of source elements (expected to be `N * D`), with one
/// thread per source element.
///
/// `pos < max_len` is not checked here. An out-of-range `pos` writes into the
/// next batch row, or past the end of `dst` for the last row.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u32 pos
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u32 %pos_reg, [pos];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
906
907
/// PTX for `slice_write_indirect_kernel`: same as `slice_write_kernel` but
/// reads `pos` from a device pointer. This enables CUDA graph capture — the
/// graph records the pointer address (fixed), and we update the u32 value
/// at that address before each graph replay.
///
/// `pos_ptr` must point to a `u32` in global memory (read with
/// `ld.global.u32`). Every thread loads it, including threads past `n`,
/// because the load happens before the bounds check. As in the direct variant,
/// `pos < max_len` is not checked.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u64 pos_ptr
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u64 %pos_p, [pos_ptr];
    ld.global.u32 %pos_reg, [%pos_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
968
/// PTX for `causal_mask_indirect_kernel`: builds an attention mask where
/// `out[h, col] = 0.0` for `col < total_len` and `-1e9` for `col >= total_len`.
/// `total_len` is read from a device pointer (for CUDA graph capture).
/// Output shape: `[n_head, max_pos]` — one mask row per head (all identical).
/// Thread `tid` maps to flat index; column = `tid % max_pos`.
///
/// `total_len_ptr` must point to a `u32` in global memory. Every thread loads
/// it, including threads past `total`, because the load precedes the bounds
/// check. `total` bounds the launch and is expected to be `n_head * max_pos`.
/// The masked value is the f32 literal `0fCE6E6B28`, which is exactly
/// `-1.0e9`.
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
    .param .u64 total_len_ptr,
    .param .u64 out_ptr,
    .param .u32 max_pos,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
    .reg .u64 %out, %off, %tl_p;
    .reg .f32 %val;
    .reg .pred %bounds_p, %mask_p;

    ld.param.u64 %tl_p, [total_len_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %max_pos_reg, [max_pos];
    ld.param.u32 %total_reg, [total];

    ld.global.u32 %tlen, [%tl_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    rem.u32 %col, %r_tid, %max_pos_reg;
    setp.lt.u32 %mask_p, %col, %tlen;
    @%mask_p bra WRITE_ZERO;

    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
    // to effectively zero out masked positions after softmax.
    mov.f32 %val, 0fCE6E6B28;
    bra WRITE;

WRITE_ZERO:
    mov.f32 %val, 0f00000000;

WRITE:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %val;

DONE:
    ret;
}
";
1028
1029// ---------------------------------------------------------------------------
1030// Embedding lookup PTX kernel: output[d] = weight[token_id * D + d]
1031// ---------------------------------------------------------------------------
1032
/// PTX for `embed_lookup_kernel`: gather one embedding row,
/// `out[d] = weight[row * D + d]` for `d < D`.
///
/// `idx_ptr` points at a single f32 holding the token id. It is converted to
/// an integer row with `cvt.rzi.u32.f32` (truncation toward zero), and every
/// thread re-reads the same index. One thread per output column; threads with
/// index `>= D` return immediately.
///
/// The row is not bounds-checked against the vocabulary size, which this
/// kernel does not receive.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %D_reg;
    @%p bra DONE;

    ld.global.f32 %idx_f, [%idx_addr];
    cvt.rzi.u32.f32 %row, %idx_f;

    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1081
1082
1083// ---------------------------------------------------------------------------
1084// Batch embedding lookup PTX kernel
1085// ---------------------------------------------------------------------------
1086// Given N f32 indices and a weight matrix [V, D], gather N rows into [N, D].
1087// Thread `tid` computes one element: row = tid / D, col = tid % D.
1088// out[tid] = weight[indices[row] * D + col]
1089
/// PTX source for `embed_lookup_batch_kernel`: gathers `N` embedding rows in
/// one launch.
///
/// Parameters (in launch order): `idx_ptr` (the `N` row indices stored as f32),
/// `weight_ptr` (row-major weights with `D` columns), `out_ptr` (`[N, D]` f32
/// output), `D`, and `total` (the output element count; presumably `N * D`).
/// Thread `r_tid < total` writes
/// `out[r_tid] = weight[indices[r_tid / D] * D + r_tid % D]`.
///
/// The global thread index lives in `%r_tid`, not `%tid`. `%tid` is the
/// predefined PTX special register (read here as `%tid.x`), so it must not be
/// redeclared as a scalar `.reg`. This matches every other kernel in this file.
///
/// NOTE(review): indices are truncated with `cvt.rzi.u32.f32` and are not
/// range-checked against the vocabulary size, so an index past the last row
/// reads out of bounds. The address math is 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // row = r_tid / D, col = r_tid % D
    div.u32 %row, %r_tid, %D_reg;
    rem.u32 %col, %r_tid, %D_reg;

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %src_idx, %idx_f;

    // src_idx = indices[row] * D + col
    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    // Write to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1151
1152
1153// ---------------------------------------------------------------------------
1154// Scatter-add rows PTX kernel (for embedding backward)
1155// ---------------------------------------------------------------------------
1156// Given grad_output [N, D] and indices [N] (f32), atomically accumulate:
1157//   grad_weight[indices[row], col] += grad_output[row * D + col]
1158// Thread `tid` handles one element: row = tid / D, col = tid % D.
1159
/// PTX source for `scatter_add_rows_kernel` (embedding backward).
///
/// Parameters (in launch order): `grad_output_ptr` (`[N, D]` f32), `indices_ptr`
/// (the `N` row indices stored as f32), `grad_weight_ptr` (row-major, `D`
/// columns), `D`, and `total` (presumably `N * D`). Thread `r_tid < total`
/// performs `grad_weight[indices[r_tid / D] * D + r_tid % D] += grad_output[r_tid]`.
/// The kernel adds into the existing contents of `grad_weight`, so callers
/// presumably zero it first.
///
/// Repeated indices make several threads target the same element, so the add
/// is atomic. The old value is never needed, which is why `red.global.add.f32`
/// is used rather than `atom` with a dead destination register. The order of
/// the float additions is unspecified, so results may differ in the last bits
/// between runs.
///
/// The global thread index lives in `%r_tid`, not `%tid`, because `%tid` is
/// the predefined PTX special register read as `%tid.x`.
/// NOTE(review): indices are not range-checked against the vocabulary size, and
/// the address math is 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_weight_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %dst_idx;
    .reg .u64 %go, %idx_addr, %gw, %off;
    .reg .f32 %idx_f, %grad_val;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %idx_addr, [indices_ptr];
    ld.param.u64 %gw, [grad_weight_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // row = r_tid / D, col = r_tid % D
    div.u32 %row, %r_tid, %D_reg;
    rem.u32 %col, %r_tid, %D_reg;

    // Read grad_output[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %go, %off;
    ld.global.f32 %grad_val, [%off];

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %dst_idx, %idx_f;

    // dst_idx = indices[row] * D + col
    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %gw, %off;
    // Atomic reduction; the previous value is not needed.
    red.global.add.f32 [%off], %grad_val;

DONE:
    ret;
}
";
1221
1222
1223// ---------------------------------------------------------------------------
// Slice-read PTX kernel: copy the first `len` rows of each [max_len, D] slab of [N, max_len, D]
1225// ---------------------------------------------------------------------------
1226// Thread i writes: dst[i] = src[batch_idx * max_len * D + (i % (len*D))]
1227// where batch_idx = i / (len * D)
1228
/// PTX source for `slice_read_kernel`: copies the first `len` rows of every
/// `[max_len, D]` slab of an f32 source into a dense output.
///
/// Parameters (in launch order): `src_ptr`, `dst_ptr`, `total` (the output
/// element count, presumably `N * len * D`), `D`, `len`, `max_len`.
/// Thread `r_tid < total` computes `batch = r_tid / (len * D)`,
/// `row = (r_tid % (len * D)) / D`, `col = r_tid % D`, then writes
/// `dst[r_tid] = src[(batch * max_len + row) * D + col]`.
///
/// NOTE(review): `len <= max_len` is assumed, not checked, and all index math
/// is 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 total,
    .param .u32 D,
    .param .u32 len,
    .param .u32 max_len
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
    .reg .u32 %len_d;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %total_reg, [total];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %len_reg, [len];
    ld.param.u32 %max_len_reg, [max_len];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // dst index = r_tid
    // batch_idx = r_tid / (len * D)
    // within = r_tid % (len * D)
    // row = within / D
    // col = within % D
    // src_idx = batch_idx * max_len * D + row * D + col
    mul.lo.u32 %len_d, %len_reg, %D_reg;
    div.u32 %batch_idx, %r_tid, %len_d;
    rem.u32 %within, %r_tid, %len_d;
    div.u32 %row, %within, %D_reg;
    rem.u32 %col, %within, %D_reg;

    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
    add.u32 %src_idx, %src_idx, %row;
    mul.lo.u32 %src_idx, %src_idx, %D_reg;
    add.u32 %src_idx, %src_idx, %col;

    cvt.u64.u32 %src_off, %src_idx;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src_off, %src, %src_off;
    ld.global.f32 %val, [%src_off];

    cvt.u64.u32 %dst_off, %r_tid;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst_off, %dst, %dst_off;
    st.global.f32 [%dst_off], %val;

DONE:
    ret;
}
";
1296
1297
1298// ---------------------------------------------------------------------------
1299// GELU PTX kernel: gelu(x) = x * sigmoid(1.702 * x)
1300//
1301// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
1302// for performance. These have reduced precision (~2^-22 relative error)
1303// compared to the full-precision variants, which is acceptable for neural
1304// network training/inference where f32 precision is already limited.
1305// ---------------------------------------------------------------------------
1306
/// PTX source for `gelu_kernel`: the sigmoid approximation of GELU,
/// `out[i] = x * sigmoid(k * x)` with `sigmoid(v) = 1 / (1 + 2^(-v * log2(e)))`.
/// The exponential uses `ex2.approx.f32` and the reciprocal uses `rcp.approx.f32`.
///
/// NOTE(review): the literal `0f3FDA2720` loaded into `%k` encodes
/// 1.704315185546875, not the 1.702 named in the section comment. Confirm which
/// value the CPU fallback and the backward kernel use before changing it.
#[cfg(feature = "cuda")]
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    mov.f32 %k, 0f3FDA2720;
    mul.f32 %neg_kx, %k, %x;
    neg.f32 %neg_kx, %neg_kx;
    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;
    mul.f32 %result, %x, %sig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1357
/// PTX source for `gelu_f64_kernel`: `out[i] = x * sigmoid(k * x)` (f64), the
/// sigmoid approximation of GELU.
///
/// `exp(-k*x)` is evaluated entirely in f64 in three steps:
/// 1. Cody-Waite range reduction `v = n*ln2 + r` with |r| <= ln2/2.
/// 2. A degree-12 Taylor polynomial in Horner form (coefficients 1/j!, j = 0..12).
/// 3. Scaling by `2^n`, assembled directly in the exponent bits.
///
/// The exp argument is clamped to `[-708, 709]` so that `n + 1023` stays a
/// valid normal-double exponent. Without the clamp the exponent field would
/// wrap for |x| above roughly 416 and return garbage. At the clamp bounds the
/// sigmoid has already saturated, so results are unaffected.
///
/// NOTE(review): the literal `0d3FFB44E400000000` used for `k` is
/// 1.704315185546875 (the f32 constant `0f3FDA2720` widened), not 1.702. It is
/// kept so the f32 and f64 kernels agree; confirm the intended value against the
/// CPU reference and the backward kernel.
#[cfg(feature = "cuda")]
pub(crate) const GELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;

    // k (this literal encodes ~1.704315, the widened f32 constant)
    mov.f64 %k, 0d3FFB44E400000000;
    mul.f64 %neg_kx, %k, %x;
    neg.f64 %neg_kx, %neg_kx;

    // --- exp(%neg_kx) via Cody-Waite + degree-12 Horner ---
    // Clamp to [-708, 709] so 2^n below is a normal double.
    max.f64 %neg_kx, %neg_kx, 0dC086200000000000;
    min.f64 %neg_kx, %neg_kx, 0d4086280000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg, %exp_neg, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %exp_neg;
    div.rn.f64 %sig, %one, %denom;
    mul.f64 %result, %x, %sig;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1439
/// PTX source for `gelu_tanh_kernel`: tanh approximation of GELU.
/// `out[i] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`
///
/// `tanh(y)` is computed as `1 - 2 / (e^(2y) + 1)`, with `e^(2y)` from
/// `ex2.approx.f32` and the reciprocal from `rcp.approx.f32`.
///
/// This form saturates cleanly at both ends:
/// - When `e^(2y)` overflows to `+inf` (x above roughly 10), the reciprocal is 0
///   and tanh is exactly 1.
/// - When `e^(2y)` underflows to 0, tanh is -1.
///
/// The `(e - 1) / (e + 1)` form would evaluate `inf * 0 = NaN` in the overflow case.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f32 %e2y_p1, %inv2, %th, %one, %half, %log2e, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
    // sqrt(2/pi) = 0.7978845608 = 0x3F4C422A
    // 0.044715 = 0x3D372713
    mul.f32 %x3, %x, %x;
    mul.f32 %x3, %x3, %x;
    mov.f32 %c, 0f3D372713;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) = 1 - 2 / (exp(2y) + 1), exp(2y) = 2^(2y * log2(e)).
    // If exp(2y) overflows to +inf the reciprocal is 0 and tanh = 1
    // (no inf * 0 = NaN); if it underflows to 0, tanh = -1.
    mov.f32 %log2e, 0f3FB8AA3B;
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    mov.f32 %one, 0f3F800000;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    add.f32 %inv2, %e2y_p1, %e2y_p1;
    sub.f32 %th, %one, %inv2;

    // out = 0.5 * x * (1 + tanh)
    add.f32 %th, %one, %th;
    mov.f32 %half, 0f3F000000;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %th;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1515
/// PTX source for `gelu_tanh_f64_kernel`: tanh-approx GELU (f64).
/// `out[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
///
/// `tanh(y)` is evaluated as `(e^(2y) - 1) / (e^(2y) + 1)`, with `e^(2y)`
/// computed in full f64 in three steps:
/// 1. Cody-Waite range reduction.
/// 2. A degree-12 Taylor polynomial in Horner form (coefficients 1/j!, j = 0..12).
/// 3. Scaling by `2^n`, assembled from exponent bits.
///
/// The exp argument is clamped to `[-708, 709]`, which keeps `2^n` a normal
/// double. tanh has already saturated to -1/+1 in f64 well inside those bounds,
/// so the clamp does not change results. Without it the exponent field would
/// wrap for |x| above roughly 21.
///
/// `sqrt(2/pi)` and `0.044715` are written as decimal PTX literals, which
/// ptxas parses as f64. This gives full double precision instead of widened f32
/// bit patterns.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f64 %e2y_m1, %e2y_p1, %th, %one, %half, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;

    // inner = sqrt(2/pi) * (x + 0.044715 * x^3), constants in full f64
    mul.f64 %x3, %x, %x;
    mul.f64 %x3, %x3, %x;
    mov.f64 %c, 0.044715;
    mul.f64 %x3, %c, %x3;
    add.f64 %inner, %x, %x3;
    mov.f64 %sqrt2pi, 0.7978845608028654;
    mul.f64 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y)-1)/(exp(2y)+1), exp(2y) in full f64
    add.f64 %two_y, %y, %y;

    // --- exp(%two_y) via Cody-Waite + degree-12 Horner ---
    // Clamp to [-708, 709] so 2^n is a normal double; tanh is +/-1 there.
    max.f64 %two_y, %two_y, 0dC086200000000000;
    min.f64 %two_y, %two_y, 0d4086280000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2y, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2y, %e2y, %e_nf;
    // --- end exp ---

    sub.f64 %e2y_m1, %e2y, %one;
    add.f64 %e2y_p1, %e2y, %one;
    div.rn.f64 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f64 %th, %one, %th;
    mov.f64 %half, 0d3FE0000000000000;
    mul.f64 %result, %half, %x;
    mul.f64 %result, %result, %th;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1611
/// PTX source for `gelu_erf_kernel`: exact GELU using erf.
/// `out[i] = x * 0.5 * (1 + erf(x / sqrt(2)))`
///
/// `erf` uses Abramowitz & Stegun formula 7.1.26:
/// `erf(z) = 1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) e^(-z^2)` with
/// `t = 1 / (1 + p z)`. The formula's own approximation error is below 1.5e-7.
/// The f32 arithmetic and the `ex2.approx.f32` / `rcp.approx.f32` instructions
/// add rounding error on top of that. Negative arguments use `erf(-z) = -erf(z)`.
///
/// The coefficient literals are the nearest f32 values of the published
/// constants: p = 0.3275911, a1 = 0.254829592, a2 = -0.284496736,
/// a3 = 1.421413741, a4 = -1.453152027, a5 = 1.061405429.
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %z, %ax, %one, %half, %log2e;
    .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // A&S 7.1.26 coefficients as nearest f32:
    // a5 = 1.061405429, a4 = -1.453152027, a3 = 1.421413741,
    // a2 = -0.284496736, a1 = 0.254829592
    mov.f32 %a5, 0f3F87DC22;
    mov.f32 %a4, 0fBFBA00E3;
    mov.f32 %a3, 0f3FB5F0E3;
    mov.f32 %a2, 0fBE91A98E;
    mov.f32 %a1, 0f3E827906;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // out = x * 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %erf_val, %one, %erf_val;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %erf_val;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1710
/// PTX source for `gelu_erf_f64_kernel`: exact erf GELU (f64).
/// `out[i] = x * 0.5 * (1 + erf(x / sqrt(2)))`
///
/// `erf` uses Abramowitz & Stegun 7.1.26 (approximation error below 1.5e-7),
/// with the published coefficients written as decimal PTX literals. ptxas
/// parses these at f64 precision. `1/sqrt(2)` is the correctly rounded f64
/// constant `0x3FE6A09E667F3BCD`.
///
/// `exp(-z^2)` is computed in full f64 in three steps:
/// 1. Cody-Waite range reduction.
/// 2. A degree-12 Taylor polynomial in Horner form (coefficients 1/j!, j = 0..12).
/// 3. Scaling by `2^n`, assembled from exponent bits.
///
/// The argument is clamped from below at -708 so `2^n` stays a normal double.
/// At that point `exp(-z^2)` is ~3e-308 and erf has saturated to 1. Without the
/// clamp the exponent field would wrap for |x| above roughly 37.
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %z, %ax, %one, %half;
    .reg .f64 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f64 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;

    // z = x / sqrt(2) = x * 0.7071067811865476 (full f64 constant)
    mov.f64 %z, 0d3FE6A09E667F3BCD;
    mul.f64 %z, %x, %z;

    abs.f64 %ax, %z;

    // t = 1 / (1 + p * |z|), p = 0.3275911
    // (decimal PTX float literals are parsed as f64)
    mov.f64 %p, 0.3275911;
    mul.f64 %t, %p, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // Abramowitz & Stegun 7.1.26 coefficients
    mov.f64 %a5, 1.061405429;
    mov.f64 %a4, -1.453152027;
    mov.f64 %a3, 1.421413741;
    mov.f64 %a2, -0.284496736;
    mov.f64 %a1, 0.254829592;

    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // exp(-z^2) in full f64
    mul.f64 %z2, %ax, %ax;
    neg.f64 %neg_z2, %z2;

    // --- exp(%neg_z2) via Cody-Waite + degree-12 Horner ---
    // neg_z2 <= 0; clamp from below so 2^n stays a normal double.
    max.f64 %neg_z2, %neg_z2, 0dC086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
    // --- end exp ---

    mul.f64 %erf_val, %pt, %exp_neg_z2;
    sub.f64 %erf_val, %one, %erf_val;

    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    add.f64 %erf_val, %one, %erf_val;
    mul.f64 %result, %half, %x;
    mul.f64 %result, %result, %erf_val;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1829
/// PTX source for `gelu_backward_tanh_kernel`:
/// Backward for tanh approximation of GELU.
/// Let `u = sqrt(2/π) * (x + 0.044715 * x³)`, `t = tanh(u)`.
/// `d/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t²) * sqrt(2/π) * (1 + 3*0.044715*x²)`
/// `out[i] = grad[i] * d/dx`
///
/// `t` is computed as `1 - 2 / (2^(2u * log2(e)) + 1)` using `ex2.approx.f32` and
/// `rcp.approx.f32`. This form saturates to +1 when the exponential overflows to
/// `+inf`. The `(e - 1) / (e + 1)` form would give `inf * 0 = NaN` there, which
/// happens for x above roughly 10.
///
/// `3 * 0.044715 = 0.134145` is loaded as its nearest f32, `0f3E095D4F`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f32 %two_y, %e2y, %e2y_p1, %inv2, %th, %one, %half, %log2e;
    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mov.f32 %c, 0f3D372713;
    // 3 * 0.044715 = 0.134145 = 0x3E095D4F
    mov.f32 %c3, 0f3E095D4F;

    // u = sqrt(2/pi) * (x + 0.044715 * x^3)
    mul.f32 %x2, %x, %x;
    mul.f32 %x3, %x2, %x;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) = 1 - 2 / (exp(2y) + 1); saturates to +/-1 when exp(2y)
    // overflows to +inf or underflows to 0 (no inf * 0 = NaN).
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    add.f32 %inv2, %e2y_p1, %e2y_p1;
    sub.f32 %th, %one, %inv2;

    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh^2)*sqrt(2/pi)*(1+3*0.044715*x^2)
    // term1 = 0.5 * (1 + th)
    add.f32 %term1, %one, %th;
    mul.f32 %term1, %half, %term1;

    // (1 - th^2)
    mul.f32 %th2, %th, %th;
    sub.f32 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/pi) * (1 + 3*0.044715*x^2)
    mul.f32 %d_inner, %c3, %x2;
    add.f32 %d_inner, %one, %d_inner;
    mul.f32 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th^2) * d_inner
    mul.f32 %term2, %half, %x;
    mul.f32 %term2, %term2, %one_m_th2;
    mul.f32 %term2, %term2, %d_inner;

    add.f32 %d_gelu, %term1, %term2;
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1927
/// PTX source for `gelu_backward_tanh_f64_kernel`: tanh-approx GELU backward (f64).
/// Let `u = sqrt(2/pi) * (x + 0.044715 * x^3)`, `t = tanh(u)`.
/// `out[i] = grad[i] * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * sqrt(2/pi) * (1 + 0.134145 * x^2))`
///
/// `e^(2u)` is computed in full f64 in three steps:
/// 1. Cody-Waite range reduction.
/// 2. A degree-12 Taylor polynomial in Horner form (coefficients 1/j!, j = 0..12).
/// 3. Scaling by `2^n`, assembled from exponent bits.
///
/// Its argument is clamped to `[-708, 709]` so the exponent bits cannot wrap.
/// `t` has already saturated to +/-1 well inside that range. Without the clamp,
/// results were NaN or garbage for |x| above roughly 21.
///
/// `sqrt(2/pi)`, `0.044715` and `3 * 0.044715 = 0.134145` are decimal PTX
/// literals, which ptxas parses at f64 precision.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f64 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half;
    .reg .f64 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    // Constants in full f64 (decimal PTX literals are parsed as f64)
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;
    mov.f64 %sqrt2pi, 0.7978845608028654;
    mov.f64 %c, 0.044715;
    // 3 * 0.044715 = 0.134145
    mov.f64 %c3, 0.134145;

    mul.f64 %x2, %x, %x;
    mul.f64 %x3, %x2, %x;
    mul.f64 %x3, %c, %x3;
    add.f64 %inner, %x, %x3;
    mul.f64 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y)-1)/(exp(2y)+1) in full f64
    add.f64 %two_y, %y, %y;

    // --- exp(%two_y) via Cody-Waite + degree-12 Horner ---
    // Clamp to [-708, 709] so 2^n is a normal double; tanh is +/-1 there.
    max.f64 %two_y, %two_y, 0dC086200000000000;
    min.f64 %two_y, %two_y, 0d4086280000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2y, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2y, %e2y, %e_nf;
    // --- end exp ---

    sub.f64 %e2y_m1, %e2y, %one;
    add.f64 %e2y_p1, %e2y, %one;
    div.rn.f64 %th, %e2y_m1, %e2y_p1;

    add.f64 %term1, %one, %th;
    mul.f64 %term1, %half, %term1;

    mul.f64 %th2, %th, %th;
    sub.f64 %one_m_th2, %one, %th2;

    mul.f64 %d_inner, %c3, %x2;
    add.f64 %d_inner, %one, %d_inner;
    mul.f64 %d_inner, %sqrt2pi, %d_inner;

    mul.f64 %term2, %half, %x;
    mul.f64 %term2, %term2, %one_m_th2;
    mul.f64 %term2, %term2, %d_inner;

    add.f64 %d_gelu, %term1, %term2;
    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
2042
2043// ---------------------------------------------------------------------------
2044// SiLU / ELU / Mish activation kernels (forward + backward)
2045// ---------------------------------------------------------------------------
2046
/// PTX source for `silu_kernel` (f32): `out[i] = x * sigmoid(x)` with `x = a[i]`.
///
/// SiLU (Sigmoid Linear Unit), also known as Swish-1. `sigmoid(x)` is
/// evaluated as `1 / (1 + 2^(-x * log2(e)))` using the fast-math instructions
/// `ex2.approx.f32` and `rcp.approx.f32`, so results are approximate rather
/// than correctly rounded.
///
/// Parameters: `a_ptr` and `out_ptr` each point to `n` f32 values; `n` is the
/// element count. Each thread handles the element at its global index
/// (`ctaid.x * ntid.x + tid.x`) and exits immediately if that index is `>= n`.
#[cfg(feature = "cuda")]
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    // exp(-x) = 2^(-x * log2(e))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;
    // silu(x) = x * sigmoid(x)
    mul.f32 %vr, %x, %sig;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2101
/// PTX source for `silu_f64_kernel` (f64): `out[i] = x * sigmoid(x)` with `x = a[i]`.
///
/// `sigmoid(x) = 1 / (1 + exp(-x))`. `exp(-x)` is computed in double precision:
///
/// * Cody-Waite reduction `-x = n*ln(2) + r` with `|r| <= ln(2)/2`, using `ln(2)`
///   split into a high and a low part.
/// * A degree-12 Taylor polynomial for `exp(r)` in Horner form.
/// * A `2^n` scale assembled directly from exponent bits.
///
/// The exp argument is clamped to `[-708, 709]`, so the assembled scale is
/// always a normal double. Arguments above `ln(f64::MAX)` produce `+inf`.
/// As a result, large positive `x` gives `silu(x) = x` and large negative `x`
/// gives `-0.0`, rather than garbage from a wrapped exponent field. The clamp
/// uses compare/select, so NaN inputs still propagate to NaN.
///
/// Parameters: `a_ptr` and `out_ptr` each point to `n` f64 values; `n` is the
/// element count. One thread per element; threads with index `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const SILU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %neg_x, %e, %denom, %sig, %vr, %one;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_arg, %e_lim;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %e_clamp;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    neg.f64 %neg_x, %x;

    // --- exp(%neg_x): Cody-Waite reduction + degree-12 Taylor/Horner ---
    // Clamp the argument to [-708, 709] so the 2^n scale assembled from
    // exponent bits stays a normal double (n + 1023 in [2, 2046]). setp/selp
    // are used instead of min/max so that a NaN argument propagates.
    // Arguments in (709, ln(DBL_MAX)] saturate at exp(709); larger -> +inf.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %neg_x, %e_lim;
    selp.f64 %e_arg, %e_lim, %neg_x, %e_clamp;
    mov.f64 %e_lim, 0d4086280000000000;
    setp.gt.f64 %e_clamp, %e_arg, %e_lim;
    selp.f64 %e_arg, %e_lim, %e_arg, %e_clamp;
    mov.f64 %e_half, 0d3FE0000000000000;
    // n = round(arg * log2(e)); r = arg - n*ln2_hi - n*ln2_lo
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) = sum of r^k / k! for k = 0..12
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e, %e_p, %e_r, %one;
    // scale by 2^n
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e, %e, %e_nf;
    // above ln(DBL_MAX) the true result overflows: return +inf
    mov.f64 %e_lim, 0d40862E42FEFA39EF;
    setp.gt.f64 %e_clamp, %neg_x, %e_lim;
    mov.f64 %e_lim, 0d7FF0000000000000;
    selp.f64 %e, %e_lim, %e, %e_clamp;
    // --- end exp ---

    // sigmoid(x) = 1 / (1 + exp(-x)); silu(x) = x * sigmoid(x)
    add.f64 %denom, %one, %e;
    div.rn.f64 %sig, %one, %denom;
    mul.f64 %vr, %x, %sig;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2179
/// PTX source for `silu_backward_kernel` (f32):
/// `out[i] = grad[i] * (sig + x * sig * (1 - sig))` where `x = input[i]` and
/// `sig = sigmoid(x)`.
///
/// The bracketed factor is the derivative of `x * sigmoid(x)`. `sigmoid` is
/// evaluated as `1 / (1 + 2^(-x * log2(e)))` with the fast-math instructions
/// `ex2.approx.f32` and `rcp.approx.f32`, so results are approximate.
///
/// Parameters: `grad_ptr`, `input_ptr` and `out_ptr` each point to `n` f32
/// values; `n` is the element count. One thread per element; threads whose
/// global index is `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
    .reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(x) = 1 / (1 + exp(-x))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f32 %one_m_sig, %one, %sig;
    mul.f32 %x_sig_omsig, %x, %sig;
    mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f32 %deriv, %sig, %x_sig_omsig;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2243
/// PTX source for `silu_backward_f64_kernel` (f64):
/// `out[i] = grad[i] * (sig + x * sig * (1 - sig))` where `x = input[i]` and
/// `sig = 1 / (1 + exp(-x))`.
///
/// `exp(-x)` is computed in double precision:
///
/// * Cody-Waite reduction `-x = n*ln(2) + r` with `|r| <= ln(2)/2`.
/// * A degree-12 Taylor polynomial for `exp(r)` in Horner form.
/// * A `2^n` scale assembled from exponent bits.
///
/// The exp argument is clamped to `[-708, 709]`, so the scale is always a
/// normal double, and arguments above `ln(f64::MAX)` produce `+inf`. As a
/// result, the derivative tends to `1` for large positive finite `x` and `0`
/// for large negative `x`, instead of garbage from a wrapped exponent field.
/// NaN inputs propagate.
///
/// Parameters: `grad_ptr`, `input_ptr` and `out_ptr` each point to `n` f64
/// values; `n` is the element count. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %neg_x, %e, %denom, %sig, %one;
    .reg .f64 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_arg, %e_lim;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %e_clamp;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    neg.f64 %neg_x, %x;

    // --- exp(%neg_x): Cody-Waite reduction + degree-12 Taylor/Horner ---
    // Clamp the argument to [-708, 709] so the 2^n scale assembled from
    // exponent bits stays a normal double (n + 1023 in [2, 2046]). setp/selp
    // are used instead of min/max so that a NaN argument propagates.
    // Arguments in (709, ln(DBL_MAX)] saturate at exp(709); larger -> +inf.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %neg_x, %e_lim;
    selp.f64 %e_arg, %e_lim, %neg_x, %e_clamp;
    mov.f64 %e_lim, 0d4086280000000000;
    setp.gt.f64 %e_clamp, %e_arg, %e_lim;
    selp.f64 %e_arg, %e_lim, %e_arg, %e_clamp;
    mov.f64 %e_half, 0d3FE0000000000000;
    // n = round(arg * log2(e)); r = arg - n*ln2_hi - n*ln2_lo
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) = sum of r^k / k! for k = 0..12
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e, %e_p, %e_r, %one;
    // scale by 2^n
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e, %e, %e_nf;
    // above ln(DBL_MAX) the true result overflows: return +inf
    mov.f64 %e_lim, 0d40862E42FEFA39EF;
    setp.gt.f64 %e_clamp, %neg_x, %e_lim;
    mov.f64 %e_lim, 0d7FF0000000000000;
    selp.f64 %e, %e_lim, %e, %e_clamp;
    // --- end exp ---

    // sig = 1 / (1 + exp(-x))
    add.f64 %denom, %one, %e;
    div.rn.f64 %sig, %one, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f64 %one_m_sig, %one, %sig;
    mul.f64 %x_sig_omsig, %x, %sig;
    mul.f64 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f64 %deriv, %sig, %x_sig_omsig;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
2332
/// PTX source for `elu_kernel` (f32):
/// `out[i] = x > 0 ? x : alpha * (exp(x) - 1)` with `x = a[i]`.
///
/// `exp(x)` is evaluated as `2^(x * log2(e))` with the fast-math instruction
/// `ex2.approx.f32`. `exp(x) - 1` is formed by plain subtraction (no `expm1`),
/// so for `x` just below zero the negative branch has reduced relative
/// accuracy. Both branches are computed for every element and `selp` picks
/// one; `x == 0` takes the `alpha * (exp(x) - 1)` branch.
///
/// Parameters: `a_ptr` and `out_ptr` each point to `n` f32 values, `n` is the
/// element count, and `alpha` is passed as an extra `.param .f32`.
/// One thread per element; threads whose global index is `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    sub.f32 %em1, %ex, %one;
    mul.f32 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha*(exp(x)-1)
    mov.f32 %vr, 0f00000000;
    setp.gt.f32 %pos, %x, %vr;
    selp.f32 %vr, %x, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2391
/// PTX source for `elu_f64_kernel` (f64):
/// `out[i] = x > 0 ? x : alpha * (exp(x) - 1)` with `x = a[i]`.
///
/// `exp(x)` is computed in double precision:
///
/// * Cody-Waite reduction `x = n*ln(2) + r` with `|r| <= ln(2)/2`.
/// * A degree-12 Taylor polynomial for `exp(r)` in Horner form.
/// * A `2^n` scale assembled from exponent bits.
///
/// The exp argument is clamped to `[-708, 709]`, so the scale is always a
/// normal double. Very negative inputs therefore yield `-alpha`, not garbage
/// from a wrapped exponent field. Arguments above `ln(f64::MAX)` produce
/// `+inf`, but that branch is discarded for `x > 0`. NaN inputs propagate.
///
/// Parameters: `a_ptr` and `out_ptr` each point to `n` f64 values, `n` is the
/// element count, and `alpha` is passed as an extra `.param .f64`.
/// One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const ELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f64 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %alpha_r, %one, %ex, %em1, %neg_branch, %vr;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_arg, %e_lim;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %pos, %e_clamp;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f64 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // --- exp(%x): Cody-Waite reduction + degree-12 Taylor/Horner ---
    // Clamp the argument to [-708, 709] so the 2^n scale assembled from
    // exponent bits stays a normal double (n + 1023 in [2, 2046]). setp/selp
    // are used instead of min/max so that a NaN argument propagates.
    // Arguments in (709, ln(DBL_MAX)] saturate at exp(709); larger -> +inf.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %x, %e_lim;
    selp.f64 %e_arg, %e_lim, %x, %e_clamp;
    mov.f64 %e_lim, 0d4086280000000000;
    setp.gt.f64 %e_clamp, %e_arg, %e_lim;
    selp.f64 %e_arg, %e_lim, %e_arg, %e_clamp;
    mov.f64 %e_half, 0d3FE0000000000000;
    // n = round(arg * log2(e)); r = arg - n*ln2_hi - n*ln2_lo
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) = sum of r^k / k! for k = 0..12
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    // scale by 2^n
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // above ln(DBL_MAX) the true result overflows: return +inf
    mov.f64 %e_lim, 0d40862E42FEFA39EF;
    setp.gt.f64 %e_clamp, %x, %e_lim;
    mov.f64 %e_lim, 0d7FF0000000000000;
    selp.f64 %ex, %e_lim, %ex, %e_clamp;
    // --- end exp ---

    // negative branch: alpha * (exp(x) - 1)
    sub.f64 %em1, %ex, %one;
    mul.f64 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha * (exp(x) - 1)
    mov.f64 %vr, 0d0000000000000000;
    setp.gt.f64 %pos, %x, %vr;
    selp.f64 %vr, %x, %neg_branch, %pos;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2473
/// PTX source for `elu_backward_kernel` (f32):
/// `out[i] = x > 0 ? grad[i] : grad[i] * alpha * exp(x)` with `x = input[i]`.
///
/// The comparison is strict, so `x == 0` takes the `grad * alpha * exp(x)`
/// branch. `exp(x)` is evaluated as `2^(x * log2(e))` with the fast-math
/// instruction `ex2.approx.f32`. The negative branch is computed for every
/// element and then discarded by `selp` when `x > 0`.
///
/// Parameters: `grad_ptr`, `input_ptr` and `out_ptr` each point to `n` f32
/// values, `n` is the element count, and `alpha` is passed as an extra
/// `.param .f32`. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %lg2e, 0f3FB8AA3B;
    mov.f32 %zero, 0f00000000;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    // negative branch: grad * alpha * exp(x)
    mul.f32 %neg_branch, %vg, %alpha_r;
    mul.f32 %neg_branch, %neg_branch, %ex;

    // x > 0 ? grad : grad * alpha * exp(x)
    setp.gt.f32 %pos, %x, %zero;
    selp.f32 %vr, %vg, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2537
/// PTX source for `elu_backward_f64_kernel` (f64):
/// `out[i] = x > 0 ? grad[i] : grad[i] * alpha * exp(x)` with `x = input[i]`.
///
/// `exp(x)` is computed in double precision:
///
/// * Cody-Waite reduction `x = n*ln(2) + r` with `|r| <= ln(2)/2`.
/// * A degree-12 Taylor polynomial for `exp(r)` in Horner form.
/// * A `2^n` scale assembled from exponent bits.
///
/// The exp argument is clamped to `[-708, 709]`, so the scale is always a
/// normal double. For very negative `x` the gradient therefore becomes
/// negligibly small, instead of being corrupted by a wrapped exponent field.
/// Arguments above `ln(f64::MAX)` produce `+inf` in the negative branch,
/// which `selp` discards for `x > 0`. `x == 0` takes the negative branch.
///
/// Parameters: `grad_ptr`, `input_ptr` and `out_ptr` each point to `n` f64
/// values, `n` is the element count, and `alpha` is passed as an extra
/// `.param .f64`. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f64 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %alpha_r, %ex, %neg_branch, %vr, %zero, %one;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_arg, %e_lim;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %pos, %e_clamp;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f64 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %zero, 0d0000000000000000;
    mov.f64 %one, 0d3FF0000000000000;

    // --- exp(%x): Cody-Waite reduction + degree-12 Taylor/Horner ---
    // Clamp the argument to [-708, 709] so the 2^n scale assembled from
    // exponent bits stays a normal double (n + 1023 in [2, 2046]). setp/selp
    // are used instead of min/max so that a NaN argument propagates.
    // Arguments in (709, ln(DBL_MAX)] saturate at exp(709); larger -> +inf.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %x, %e_lim;
    selp.f64 %e_arg, %e_lim, %x, %e_clamp;
    mov.f64 %e_lim, 0d4086280000000000;
    setp.gt.f64 %e_clamp, %e_arg, %e_lim;
    selp.f64 %e_arg, %e_lim, %e_arg, %e_clamp;
    mov.f64 %e_half, 0d3FE0000000000000;
    // n = round(arg * log2(e)); r = arg - n*ln2_hi - n*ln2_lo
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) = sum of r^k / k! for k = 0..12
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    // scale by 2^n
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // above ln(DBL_MAX) the true result overflows: return +inf
    mov.f64 %e_lim, 0d40862E42FEFA39EF;
    setp.gt.f64 %e_clamp, %x, %e_lim;
    mov.f64 %e_lim, 0d7FF0000000000000;
    selp.f64 %ex, %e_lim, %ex, %e_clamp;
    // --- end exp ---

    // negative branch: grad * alpha * exp(x)
    mul.f64 %neg_branch, %vg, %alpha_r;
    mul.f64 %neg_branch, %neg_branch, %ex;

    // x > 0 ? grad : grad * alpha * exp(x)
    setp.gt.f64 %pos, %x, %zero;
    selp.f64 %vr, %vg, %neg_branch, %pos;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2624
/// PTX source for `mish_kernel` (f32): `out[i] = x * tanh(softplus(x))` with `x = a[i]`.
///
/// `softplus(x) = ln(1 + exp(x))` and `tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)`
/// are evaluated with the fast-math instructions `ex2.approx.f32`,
/// `lg2.approx.f32` and `rcp.approx.f32`, so results are approximate.
///
/// For `x > 20`, `softplus(x) >= x` and `tanh(20)` already rounds to exactly
/// `1.0` in f32, so the kernel stores `x` unchanged. This shortcut is also
/// required for correctness. Evaluating `tanh(x)` through `exp(2x)` overflows
/// to `+inf` once `x` exceeds roughly 44, and `(inf - 1) * rcp(inf + 1)` is
/// `inf * 0 = NaN`.
///
/// Parameters: `a_ptr` and `out_ptr` each point to `n` f32 values; `n` is the
/// element count. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0 = 0x41A00000
    mov.f32 %threshold, 0f41A00000;

    // softplus(x) = ln(1 + exp(x))
    // For large x (> 20), mish(x) == x in f32 (see LARGE_X)
    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    // ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
    lg2.approx.f32 %lg_ep1, %ep1;
    // 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %th, %e2sp_m1, %e2sp_p1;

    mul.f32 %vr, %x, %th;
    st.global.f32 [%out], %vr;
    bra DONE;

LARGE_X:
    // softplus(x) >= x > 20 and tanh(20) = 1 - 8.5e-18 rounds to exactly
    // 1.0 in f32, so mish(x) == x. Computing tanh(x) through exp(2x) here
    // would overflow to inf for x > ~44 and yield inf * 0 = NaN.
    st.global.f32 [%out], %x;

DONE:
    ret;
}
";
2714
/// PTX source for `mish_f64_kernel` (f64): `out[i] = x * tanh(softplus(x))` with `x = a[i]`.
///
/// * `exp` uses three steps:
///   * Cody-Waite reduction `arg = n*ln(2) + r`.
///   * A degree-12 Taylor polynomial for `exp(r)`.
///   * A `2^n` scale assembled from exponent bits.
///
///   The argument is clamped to `[-708, 709]` (NaN-preserving), so the scale
///   is always a normal double; arguments above `ln(f64::MAX)` give `+inf`.
/// * `softplus(x) = ln(1 + exp(x))`, with `ln` computed in three steps:
///   * Split `1 + exp(x) = 2^k * m` with `m` in `[sqrt(2)/2, sqrt(2))`.
///   * Set `f = (m - 1) / (m + 1)`, so `|f| <= 0.1716`.
///   * Sum `ln(m) = 2 * atanh(f)` through the `f^21 / 21` term.
/// * `tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)`.
///
/// For `x > 20`, `softplus(x) >= x` and `tanh(20)` rounds to exactly `1.0` in
/// f64, so the kernel stores `x` unchanged. This also avoids evaluating
/// `exp(2x)` for large `x`. Softplus is formed as `ln(1 + exp(x))` (no
/// `log1p`), so for strongly negative `x` the result is accurate in absolute
/// rather than relative terms.
///
/// Parameters: `a_ptr` and `out_ptr` each point to `n` f64 values; `n` is the
/// element count. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const MISH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %one, %ex, %ep1, %sp;
    .reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
    .reg .f64 %threshold;
    // exp subroutine regs
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_arg, %e_lim;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %e_clamp;
    // log subroutine regs
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2, %l_sqrt2;
    .reg .pred %l_hi;
    .reg .pred %p, %large;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    // threshold = 20.0
    mov.f64 %threshold, 0d4034000000000000;

    setp.gt.f64 %large, %x, %threshold;
    @%large bra LARGE_X;

    // === softplus: sp = ln(1 + exp(x)) ===
    // --- exp(%x): Cody-Waite reduction + degree-12 Taylor/Horner ---
    // Clamp the argument to [-708, 709] so the 2^n scale assembled from
    // exponent bits stays a normal double (n + 1023 in [2, 2046]). setp/selp
    // are used instead of min/max so that a NaN argument propagates.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %x, %e_lim;
    selp.f64 %e_arg, %e_lim, %x, %e_clamp;
    mov.f64 %e_lim, 0d4086280000000000;
    setp.gt.f64 %e_clamp, %e_arg, %e_lim;
    selp.f64 %e_arg, %e_lim, %e_arg, %e_clamp;
    // n = round(arg * log2(e)); r = arg - n*ln2_hi - n*ln2_lo
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) = sum of r^k / k! for k = 0..12
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    // scale by 2^n
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // above ln(DBL_MAX) the true result overflows: return +inf
    mov.f64 %e_lim, 0d40862E42FEFA39EF;
    setp.gt.f64 %e_clamp, %x, %e_lim;
    mov.f64 %e_lim, 0d7FF0000000000000;
    selp.f64 %ex, %e_lim, %ex, %e_clamp;
    // --- end exp ---

    // ep1 = 1 + exp(x), in [1, 1 + exp(20)]
    add.f64 %ep1, %ex, %one;

    // --- ln(ep1): ep1 = 2^k * m, m in [sqrt(2)/2, sqrt(2)) ---
    mov.b64 %l_xbits, %ep1;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;
    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;
    // fold m from [sqrt(2), 2) down to [sqrt(2)/2, 1) so |f| <= 0.1716
    mov.f64 %l_sqrt2, 0d3FF6A09E667F3BCD;
    setp.gt.f64 %l_hi, %l_m, %l_sqrt2;
    @%l_hi mul.f64 %l_m, %l_m, %e_half;
    @%l_hi add.f64 %l_nf, %l_nf, %one;
    // ln(m) = 2*atanh(f) = 2f * (1 + f^2/3 + f^4/5 + ... + f^20/21)
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FA8618618618618;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAAF286BCA1AF28;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAE1E1E1E1E1E1E;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB1111111111111;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    // sp = k*ln(2) + ln(m)
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;
    // --- end ln ---

    // === tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1), sp in [0, ~20] ===
    add.f64 %two_sp, %sp, %sp;
    // --- exp(%two_sp), same scheme as above ---
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %two_sp, %e_lim;
    selp.f64 %e_arg, %e_lim, %two_sp, %e_clamp;
    mov.f64 %e_lim, 0d4086280000000000;
    setp.gt.f64 %e_clamp, %e_arg, %e_lim;
    selp.f64 %e_arg, %e_lim, %e_arg, %e_clamp;
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2sp, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2sp, %e2sp, %e_nf;
    mov.f64 %e_lim, 0d40862E42FEFA39EF;
    setp.gt.f64 %e_clamp, %two_sp, %e_lim;
    mov.f64 %e_lim, 0d7FF0000000000000;
    selp.f64 %e2sp, %e_lim, %e2sp, %e_clamp;
    // --- end exp ---

    sub.f64 %e2sp_m1, %e2sp, %one;
    add.f64 %e2sp_p1, %e2sp, %one;
    div.rn.f64 %th, %e2sp_m1, %e2sp_p1;

    mul.f64 %vr, %x, %th;
    st.global.f64 [%out], %vr;
    bra DONE;

LARGE_X:
    // softplus(x) >= x > 20 and tanh(20) = 1 - 8.5e-18 rounds to exactly
    // 1.0 in f64, so mish(x) == x. Evaluating tanh through exp(2x) here
    // would overflow for x > ~354.
    st.global.f64 [%out], %x;

DONE:
    ret;
}
";
2892
/// PTX source for `mish_backward_kernel` (f32):
/// ```text
/// sp = ln(1 + exp(x))        // softplus
/// t  = tanh(sp)
/// sig = sigmoid(x) = 1/(1+exp(-x))
/// out[i] = grad[i] * (t + x * sig * (1 - t*t))
/// ```
/// where `x = input[i]`. All transcendentals use the fast-math instructions
/// `ex2.approx.f32`, `lg2.approx.f32` and `rcp.approx.f32`.
///
/// For `x > 20`, `t` and `sig` both round to exactly `1.0` in f32, and
/// `x * (1 - t*t)` is below `1e-15`. The derivative is therefore exactly
/// `1.0` and the kernel stores `grad[i]` unchanged. Evaluating `tanh(x)`
/// through `exp(2x)` instead would overflow for `x` above roughly 44 and
/// produce a NaN gradient via `inf * 0`.
///
/// Parameters: `grad_ptr`, `input_ptr` and `out_ptr` each point to `n` f32
/// values; `n` is the element count. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
    .reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0
    mov.f32 %threshold, 0f41A00000;

    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // --- Normal path ---
    // softplus: sp = ln(1 + exp(x))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    lg2.approx.f32 %lg_ep1, %ep1;
    // ln(2) = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %t, %e2sp_m1, %e2sp_p1;

    // sig = sigmoid(x) = 1/(1+exp(-x))
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %en, %neg;
    add.f32 %denom, %one, %en;
    rcp.approx.f32 %sig, %denom;

    // deriv = t + x * sig * (1 - t*t)
    mul.f32 %t2, %t, %t;
    sub.f32 %one_m_t2, %one, %t2;
    mul.f32 %x_sig_omt2, %x, %sig;
    mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
    add.f32 %deriv, %t, %x_sig_omt2;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;
    bra DONE;

LARGE_X:
    // t = tanh(softplus(x)) and sig both round to 1.0 in f32 for x > 20 and
    // x * (1 - t*t) < 1e-15, so deriv == 1.0 exactly and out = grad.
    // Computing tanh through exp(2x) here would overflow for x > ~44 and
    // give inf * 0 = NaN.
    st.global.f32 [%out], %vg;

DONE:
    ret;
}
";
3009
/// PTX source for `mish_backward_f64_kernel` (f64).
///
/// Computes `out[i] = grad[i] * mish'(x[i])`, where
/// `mish'(x) = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)` and
/// `sp = softplus(x) = ln(1 + e^x)`.
///
/// All factors are derived from a single `e = exp(x)` (Cody-Waite reduction
/// plus a degree-12 Taylor polynomial in Horner form, full f64 precision):
///
/// * `q = (1 + e)^2 - 1 = e * (e + 2)`
/// * `tanh(sp) = q / (q + 2)`
/// * `1 - tanh(sp)^2 = 4 * (q + 1) / (q + 2)^2`
/// * `sigmoid(x) = e / (1 + e)`
///
/// No `log` and no second or third `exp` is needed. The tails also avoid the
/// `exp(2*sp) - 1` cancellation of the textbook formulation.
///
/// * `x > 20`: `mish'(x)` is within ~2 ulp of 1.0, so the gradient is stored
///   unchanged. This also keeps `q` from overflowing for large `x`.
/// * `x < -708`: `exp(x)` is flushed to 0, giving a zero gradient. Otherwise
///   the 2^n scale would wrap the exponent bits.
/// * NaN inputs propagate to the output.
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %one, %two, %threshold, %ex, %ep1;
    .reg .f64 %q, %q1, %q2, %q2sq, %t, %omt2, %sig, %deriv, %result;
    // exp subroutine regs
    .reg .f64 %e_a, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %large, %e_lo;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;
    mov.f64 %threshold, 0d4034000000000000;

    // x > 20: the derivative is within ~2 ulp of 1.0, pass grad through.
    setp.gt.f64 %large, %x, %threshold;
    @%large bra LARGE_X;

    // === ex = exp(x), Cody-Waite + degree-12 Taylor/Horner ===
    // x <= 20 here, so only underflow needs guarding: below -708 the 2^n
    // scale would wrap the exponent bits, so reduce a clamped value and
    // flush the result to 0 afterwards.
    setp.lt.f64 %e_lo, %x, 0dC086200000000000;
    mov.f64 %e_a, %x;
    @%e_lo mov.f64 %e_a, 0dC086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    mul.f64 %e_nf, %e_a, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    @%e_lo mov.f64 %ex, 0d0000000000000000;

    // === closed forms in terms of ex (no log / tanh needed) ===
    // q = (1+ex)^2 - 1 = ex*(ex+2)
    // t = tanh(softplus(x)) = q/(q+2)
    // 1 - t*t = 4*(q+1)/(q+2)^2
    // sig = sigmoid(x) = ex/(1+ex)
    add.f64 %q, %ex, %two;
    mul.f64 %q, %q, %ex;
    add.f64 %q2, %q, %two;
    div.rn.f64 %t, %q, %q2;
    add.f64 %q1, %q, %one;
    mul.f64 %q1, %q1, 0d4010000000000000;
    mul.f64 %q2sq, %q2, %q2;
    div.rn.f64 %omt2, %q1, %q2sq;
    add.f64 %ep1, %ex, %one;
    div.rn.f64 %sig, %ex, %ep1;

    // deriv = t + x * sig * (1 - t*t)
    mul.f64 %deriv, %x, %sig;
    fma.rn.f64 %deriv, %deriv, %omt2, %t;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;
    bra DONE;

LARGE_X:
    st.global.f64 [%out], %vg;

DONE:
    ret;
}
";
3231
/// PTX source for `clamp_kernel`: `out[i] = min(max(x[i], min_val), max_val)`.
/// Takes two extra f32 params: min_val, max_val.
///
/// One thread per element; threads with global index `>= n` exit early.
/// Because `min` is applied last, `min_val > max_val` yields `max_val` for
/// every element.
///
/// PTX `max`/`min` return the non-NaN operand, which would silently turn a
/// NaN input into `min_val`. The kernel therefore re-selects `x` when it is
/// NaN, so NaN propagates.
#[cfg(feature = "cuda")]
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry clamp_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 min_val,
    .param .f32 max_val
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %mn, %mx, %result;
    .reg .pred %p, %is_nan;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %mn, [min_val];
    ld.param.f32 %mx, [max_val];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    max.f32 %result, %x, %mn;
    min.f32 %result, %result, %mx;
    // max/min drop NaN operands; keep NaN inputs as NaN.
    setp.nan.f32 %is_nan, %x, %x;
    selp.f32 %result, %x, %result, %is_nan;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3280
3281
3282// ---------------------------------------------------------------------------
3283// Backward activation kernels
3284// ---------------------------------------------------------------------------
3285
/// PTX source for `relu_backward_kernel`: `out[i] = (input[i] > 0) ? grad[i] : 0`.
/// Takes two inputs: grad (upstream gradient) and input (forward activation input).
///
/// One thread per f32 element (4-byte stride). The global index is
/// `ctaid.x * ntid.x + tid.x`, and threads with index `>= n` return without
/// touching memory. The test is an ordered `setp.gt.f32`, so an input of
/// `0.0`, `-0.0` or NaN yields a zero gradient.
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %zero, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3336
3337
/// PTX source for `gelu_backward_kernel`:
/// `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
/// where `sig = sigmoid(1.702 * x)`.
/// This is the exact derivative of `gelu(x) = x * sigmoid(1.702 * x)`.
///
/// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
/// for performance. These have reduced precision (~2^-22 relative error)
/// compared to the full-precision variants, which is acceptable for neural
/// network training/inference where f32 precision is already limited.
///
/// One thread per f32 element; threads with global index `>= n` exit early.
///
/// NOTE(review): the constant actually loaded into `%k` is `0f3FDA2720`.
/// That decodes to ~1.70431 rather than 1.702; 1.702 would be `0f3FD9DB23`.
/// Check which value the forward GELU kernel uses before changing either, so
/// that forward and backward stay consistent.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(1.702 * x)
    mov.f32 %k, 0f3FDA2720;
    mul.f32 %kx, %k, %x;
    neg.f32 %neg_kx, %kx;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_kx, %neg_kx, %log2e;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;

    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
    sub.f32 %one_minus_sig, %one, %sig;
    mul.f32 %kx_sig_oms, %kx, %sig;
    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f32 %dsig, %sig, %kx_sig_oms;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %dsig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3413
/// PTX source for `gelu_backward_f64_kernel`: sigmoid-approx backward (f64).
///
/// `out[i] = grad[i] * (sig + k*x*sig*(1 - sig))` with `sig = sigmoid(k*x)`.
/// `exp(-k*x)` is evaluated in full f64 via Cody-Waite reduction and a
/// degree-12 Taylor polynomial in Horner form.
///
/// Arguments outside `[-708, 709]` would push the `2^n` scale out of the
/// normal f64 range and wrap the exponent bits. Such arguments are reduced
/// from a clamped value, and the result then saturates to `0` / `+inf`, so
/// `sig` becomes exactly `1` / `0`.
///
/// NOTE(review): `k` is `0d3FFB44E400000000` ≈ 1.70431, which is the f32
/// constant `0f3FDA2720` widened, not 1.702. It is kept as-is for consistency
/// with the f32 kernel; confirm against the forward kernel before changing it.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %k, %kx, %neg_kx, %exp_neg, %one, %denom, %sig;
    .reg .f64 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .f64 %e_a, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %e_lo, %e_hi;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %k, 0d3FFB44E400000000;
    mul.f64 %kx, %k, %x;
    neg.f64 %neg_kx, %kx;

    // --- exp(%neg_kx) via Cody-Waite + degree-12 Taylor/Horner ---
    // Outside -708 .. 709 the 2^n scale would wrap the exponent bits, so
    // reduce a clamped argument and saturate the result to 0 / +inf.
    setp.lt.f64 %e_lo, %neg_kx, 0dC086200000000000;
    setp.gt.f64 %e_hi, %neg_kx, 0d4086280000000000;
    mov.f64 %e_a, %neg_kx;
    @%e_lo mov.f64 %e_a, 0dC086200000000000;
    @%e_hi mov.f64 %e_a, 0d4086280000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg, %exp_neg, %e_nf;
    @%e_lo mov.f64 %exp_neg, 0d0000000000000000;
    @%e_hi mov.f64 %exp_neg, 0d7FF0000000000000;
    // --- end exp ---

    add.f64 %denom, %one, %exp_neg;
    div.rn.f64 %sig, %one, %denom;

    sub.f64 %one_minus_sig, %one, %sig;
    mul.f64 %kx_sig_oms, %kx, %sig;
    mul.f64 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f64 %dsig, %sig, %kx_sig_oms;

    mul.f64 %result, %vg, %dsig;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3505
/// PTX source for `gelu_backward_erf_kernel`:
/// Exact GELU backward using erf: `d/dx gelu(x) = Φ(x) + x·φ(x)`
/// where `Φ(x) = 0.5·(1 + erf(x/√2))` and `φ(x) = exp(-x²/2) / √(2π)`.
///
/// Uses Abramowitz & Stegun formula 7.1.26 for erf (|ε| < 1.5×10⁻⁷):
///   `erf(x) = 1 - (a₁t + a₂t² + a₃t³ + a₄t⁴ + a₅t⁵) · exp(-x²)`
///   where `t = 1/(1 + 0.3275911·|x|)`.
/// The coefficients are `a₁ = 0.254829592`, `a₂ = -0.284496736`,
/// `a₃ = 1.421413741`, `a₄ = -1.453152027` and `a₅ = 1.061405429`, stored as
/// their nearest f32 encodings.
///
/// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`) for
/// `exp` and the reciprocal. One thread per f32 element; threads with global
/// index `>= n` exit early.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f32 %d_gelu, %result;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // A&S 7.1.26: a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741,
    //             a4 = -1.453152027, a5 = 1.061405429
    mov.f32 %a5, 0f3F87DC22;
    mov.f32 %a4, 0fBFBA00E3;
    mov.f32 %a3, 0f3FB5F0E3;
    mov.f32 %a2, 0fBE91A98E;
    mov.f32 %a1, 0f3E827906;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // Phi(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %cdf, %one, %erf_val;
    mul.f32 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi)
    // exp(-x^2/2):
    mul.f32 %neg_x2h, %x, %x;
    mul.f32 %neg_x2h, %neg_x2h, %half;
    neg.f32 %neg_x2h, %neg_x2h;
    mul.f32 %neg_x2h, %neg_x2h, %log2e;
    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;

    // 1/sqrt(2*pi) = 0.39894228
    mov.f32 %inv_sqrt_2pi, 0f3ECC422A;
    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Phi(x) + x * phi(x)
    mul.f32 %x_pdf, %x, %pdf;
    add.f32 %d_gelu, %cdf, %x_pdf;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3633
/// PTX source for `gelu_backward_erf_f64_kernel`: exact-erf GELU backward (f64).
///
/// `d/dx gelu(x) = Phi(x) + x*phi(x)`. Here `Phi(x) = 0.5*(1 + erf(x/sqrt(2)))`
/// and `phi(x) = exp(-x^2/2)/sqrt(2*pi)`.
///
/// * **erf**: Abramowitz & Stegun 7.1.26, so the erf term is accurate to
///   ~1.5e-7, not to full f64 precision. The coefficients are written as
///   decimal PTX literals, which ptxas rounds to f64.
/// * **Gaussian factor**: `exp(-x^2/2)` equals `exp(-z^2)` with
///   `z = x/sqrt(2)`. It is computed once in full f64 (Cody-Waite plus a
///   degree-12 Taylor polynomial in Horner form) and shared by the erf term
///   and `phi(x)`.
/// * **Underflow**: for arguments below -708 the Gaussian factor is flushed
///   to 0, because the `2^n` scale would otherwise wrap the exponent bits.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %ax, %z, %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f64 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f64 %d_gelu, %result;
    .reg .f64 %p_coef, %a1, %a2, %a3, %a4, %a5;
    .reg .f64 %e_a, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg, %e_lo;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;

    // z = x / sqrt(2); |z| feeds the erf rational term
    mov.f64 %z, 0d3FE6A09E667F3BCD;
    mul.f64 %z, %x, %z;
    abs.f64 %ax, %z;

    // t = 1 / (1 + p*|z|), p = 0.3275911
    mov.f64 %p_coef, 0.3275911;
    mul.f64 %t, %p_coef, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    // A&S 7.1.26 coefficients (decimal literals, rounded to f64 by ptxas)
    mov.f64 %a5, 1.061405429;
    mov.f64 %a4, 1.453152027;
    neg.f64 %a4, %a4;
    mov.f64 %a3, 1.421413741;
    mov.f64 %a2, 0.284496736;
    neg.f64 %a2, %a2;
    mov.f64 %a1, 0.254829592;

    // poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // g = exp(-x^2/2) = exp(-z^2), shared by erf(z) and phi(x)
    mul.f64 %neg_x2h, %x, %x;
    mul.f64 %neg_x2h, %neg_x2h, %half;
    neg.f64 %neg_x2h, %neg_x2h;

    // --- exp(%neg_x2h) via Cody-Waite + degree-12 Taylor/Horner ---
    // The argument is <= 0; below -708 the 2^n scale would wrap the
    // exponent bits, so reduce a clamped value and flush the result to 0.
    setp.lt.f64 %e_lo, %neg_x2h, 0dC086200000000000;
    mov.f64 %e_a, %neg_x2h;
    @%e_lo mov.f64 %e_a, 0dC086200000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_x2h, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_x2h, %exp_neg_x2h, %e_nf;
    @%e_lo mov.f64 %exp_neg_x2h, 0d0000000000000000;
    // --- end exp ---

    // erf(|z|) = 1 - poly * exp(-z^2); erf(-z) = -erf(z)
    mul.f64 %erf_val, %pt, %exp_neg_x2h;
    sub.f64 %erf_val, %one, %erf_val;
    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    // Phi(x) = 0.5 * (1 + erf(z))
    add.f64 %cdf, %one, %erf_val;
    mul.f64 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi), 1/sqrt(2*pi) = 0.3989422804014327
    mov.f64 %inv_sqrt_2pi, 0d3FD9884533D43651;
    mul.f64 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Phi(x) + x * phi(x)
    mul.f64 %x_pdf, %x, %pdf;
    add.f64 %d_gelu, %cdf, %x_pdf;

    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3793
3794// ---------------------------------------------------------------------------
3795// Index-select (1-D gather) PTX kernel
3796// ---------------------------------------------------------------------------
3797// Thread i: output[i] = input[indices[i]]
3798// Indices are stored as f32 on the GPU (cast to u32 via truncation).
3799
/// PTX source for `index_select_1d_kernel`: `out[i] = input[indices[i]]` for
/// `i < n_indices`, one thread per output element (f32, 4-byte stride).
///
/// Indices are read as f32 and converted with `cvt.rzi.u32.f32`, which
/// truncates toward zero. f32 represents integers exactly only up to 2^24.
///
/// NOTE(review): the kernel receives no length for `input`, so it performs no
/// bounds check on `indices[i]`; callers must validate indices beforehand.
/// PTX float-to-int conversions saturate (negative -> 0, NaN -> 0), so bad
/// indices read `input[0]` instead of failing. Confirm that this is acceptable.
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
3853
3854
3855// ---------------------------------------------------------------------------
3856// Scatter-add (1-D) PTX kernel — backward of index_select
3857// ---------------------------------------------------------------------------
3858// Thread i: atomicAdd(grad_input[indices[i]], grad_output[i])
3859// The output buffer (grad_input) must be pre-zeroed.
3860// Uses atom.global.add.f32 for safe concurrent accumulation when
3861// duplicate indices map multiple threads to the same output slot.
3862
/// PTX source for `scatter_add_1d_kernel`, the backward pass of index-select.
/// For each `i < n_indices` it performs
/// `grad_input[indices[i]] += grad_output[i]` via `atom.global.add.f32`.
///
/// `grad_input` must be zeroed by the caller. Colliding indices are combined
/// with float atomics in whatever order threads happen to execute. The
/// summation order, and hence the rounding of the result, can therefore
/// differ between runs.
///
/// Indices use the same truncating f32 -> u32 `cvt.rzi` conversion as the
/// index-select kernel. No bounds check against the length of `grad_input`
/// is performed; `%dummy` only receives the unused old value from the atomic.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_input_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %go, %indices, %gi, %off, %addr;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %gi, [grad_input_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read grad_output[tid]
    add.u64 %addr, %go, %off;
    ld.global.f32 %grad_val, [%addr];

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Atomic add: grad_input[idx] += grad_val
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %gi, %addr;
    atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
    ret;
}
";
3916
3917
3918// ---------------------------------------------------------------------------
3919// Masked-fill PTX kernel
3920// ---------------------------------------------------------------------------
3921// Thread i: output[i] = mask[i] >= 0.5 ? fill_value : input[i]
3922// Mask is stored as f32 (1.0 = true, 0.0 = false).
3923
/// PTX source for `masked_fill_kernel`:
/// `out[i] = mask[i] >= 0.5 ? fill_value : input[i]` for `i < n`.
///
/// The parameter order is `(input_ptr, mask_ptr, out_ptr, fill_value, n)`;
/// note that the scalar `fill_value` comes before `n`. The mask is f32 with
/// 1.0 meaning true. The test is an ordered `setp.ge.f32`, so a NaN mask
/// entry counts as false and keeps the input value.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
    .param .u64 input_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .f32 fill_value,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %input, %mask, %out, %off;
    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %fill, [fill_value];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %input, %input, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %in_val, [%input];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %fill, %in_val, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3974
3975
3976// ---------------------------------------------------------------------------
3977// Masked-zero PTX kernel — backward of masked_fill
3978// ---------------------------------------------------------------------------
3979// Thread i: output[i] = mask[i] >= 0.5 ? 0.0 : grad_output[i]
3980// Zeroes gradient at positions where the forward mask was true.
3981
3982#[cfg(feature = "cuda")]
3983pub(crate) const MASKED_ZERO_PTX: &str = "\
3984.version 7.0
3985.target sm_52
3986.address_size 64
3987
3988.visible .entry masked_zero_kernel(
3989    .param .u64 grad_ptr,
3990    .param .u64 mask_ptr,
3991    .param .u64 out_ptr,
3992    .param .u32 n
3993) {
3994    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3995    .reg .u64 %grad, %mask, %out, %off;
3996    .reg .f32 %vg, %mask_val, %zero, %result, %half;
3997    .reg .pred %p, %pmask;
3998
3999    ld.param.u64 %grad, [grad_ptr];
4000    ld.param.u64 %mask, [mask_ptr];
4001    ld.param.u64 %out, [out_ptr];
4002    ld.param.u32 %n_reg, [n];
4003
4004    mov.u32 %bid, %ctaid.x;
4005    mov.u32 %bdim, %ntid.x;
4006    mov.u32 %r_tid, %tid.x;
4007    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4008
4009    setp.ge.u32 %p, %r_tid, %n_reg;
4010    @%p bra DONE;
4011
4012    cvt.u64.u32 %off, %r_tid;
4013    shl.b64 %off, %off, 2;
4014
4015    add.u64 %grad, %grad, %off;
4016    add.u64 %mask, %mask, %off;
4017    add.u64 %out, %out, %off;
4018
4019    ld.global.f32 %vg, [%grad];
4020    ld.global.f32 %mask_val, [%mask];
4021    mov.f32 %zero, 0f00000000;
4022    mov.f32 %half, 0f3F000000;
4023    setp.ge.f32 %pmask, %mask_val, %half;
4024    selp.f32 %result, %zero, %vg, %pmask;
4025    st.global.f32 [%out], %result;
4026
4027DONE:
4028    ret;
4029}
4030";
4031
4032
4033// ---------------------------------------------------------------------------
4034// Sigmoid backward PTX kernel: out[i] = grad[i] * output[i] * (1 - output[i])
4035// ---------------------------------------------------------------------------
4036
4037#[cfg(feature = "cuda")]
4038pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
4039.version 7.0
4040.target sm_52
4041.address_size 64
4042
4043.visible .entry sigmoid_backward_kernel(
4044    .param .u64 grad_ptr,
4045    .param .u64 output_ptr,
4046    .param .u64 out_ptr,
4047    .param .u32 n
4048) {
4049    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4050    .reg .u64 %grad, %output, %out, %off;
4051    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
4052    .reg .pred %p;
4053
4054    ld.param.u64 %grad, [grad_ptr];
4055    ld.param.u64 %output, [output_ptr];
4056    ld.param.u64 %out, [out_ptr];
4057    ld.param.u32 %n_reg, [n];
4058
4059    mov.u32 %bid, %ctaid.x;
4060    mov.u32 %bdim, %ntid.x;
4061    mov.u32 %r_tid, %tid.x;
4062    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4063
4064    setp.ge.u32 %p, %r_tid, %n_reg;
4065    @%p bra DONE;
4066
4067    cvt.u64.u32 %off, %r_tid;
4068    shl.b64 %off, %off, 2;
4069
4070    add.u64 %grad, %grad, %off;
4071    add.u64 %output, %output, %off;
4072    add.u64 %out, %out, %off;
4073
4074    ld.global.f32 %vg, [%grad];
4075    ld.global.f32 %vo, [%output];
4076    mov.f32 %one, 0f3F800000;
4077    sub.f32 %one_minus_o, %one, %vo;
4078    mul.f32 %result, %vo, %one_minus_o;
4079    mul.f32 %result, %vg, %result;
4080    st.global.f32 [%out], %result;
4081
4082DONE:
4083    ret;
4084}
4085";
4086
4087
4088// ---------------------------------------------------------------------------
4089// Tanh backward PTX kernel: out[i] = grad[i] * (1 - output[i]^2)
4090// ---------------------------------------------------------------------------
4091
4092#[cfg(feature = "cuda")]
4093pub(crate) const TANH_BACKWARD_PTX: &str = "\
4094.version 7.0
4095.target sm_52
4096.address_size 64
4097
4098.visible .entry tanh_backward_kernel(
4099    .param .u64 grad_ptr,
4100    .param .u64 output_ptr,
4101    .param .u64 out_ptr,
4102    .param .u32 n
4103) {
4104    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4105    .reg .u64 %grad, %output, %out, %off;
4106    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
4107    .reg .pred %p;
4108
4109    ld.param.u64 %grad, [grad_ptr];
4110    ld.param.u64 %output, [output_ptr];
4111    ld.param.u64 %out, [out_ptr];
4112    ld.param.u32 %n_reg, [n];
4113
4114    mov.u32 %bid, %ctaid.x;
4115    mov.u32 %bdim, %ntid.x;
4116    mov.u32 %r_tid, %tid.x;
4117    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4118
4119    setp.ge.u32 %p, %r_tid, %n_reg;
4120    @%p bra DONE;
4121
4122    cvt.u64.u32 %off, %r_tid;
4123    shl.b64 %off, %off, 2;
4124
4125    add.u64 %grad, %grad, %off;
4126    add.u64 %output, %output, %off;
4127    add.u64 %out, %out, %off;
4128
4129    ld.global.f32 %vg, [%grad];
4130    ld.global.f32 %vo, [%output];
4131    mov.f32 %one, 0f3F800000;
4132    mul.f32 %o_sq, %vo, %vo;
4133    sub.f32 %one_minus_sq, %one, %o_sq;
4134    mul.f32 %result, %vg, %one_minus_sq;
4135    st.global.f32 [%out], %result;
4136
4137DONE:
4138    ret;
4139}
4140";
4141
4142
4143// ---------------------------------------------------------------------------
4144// Softmax backward PTX kernel (row-wise, shared-memory dot product)
4145// ---------------------------------------------------------------------------
4146// For each row of length `cols`:
4147//   dot = sum(grad[row] * output[row])
4148//   out[i] = output[i] * (grad[i] - dot)
4149// One block per row, 256 threads per block.
4150
4151#[cfg(feature = "cuda")]
4152pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
4153.version 7.0\n\
4154.target sm_52\n\
4155.address_size 64\n\
4156\n\
4157.shared .align 4 .f32 sdata[256];\n\
4158\n\
4159.visible .entry softmax_backward_kernel(\n\
4160    .param .u64 grad_ptr,\n\
4161    .param .u64 output_ptr,\n\
4162    .param .u64 out_ptr,\n\
4163    .param .u32 rows,\n\
4164    .param .u32 cols\n\
4165) {\n\
4166    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
4167    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
4168    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
4169    .reg .pred %p, %loop_p, %reduce_p;\n\
4170\n\
4171    ld.param.u64 %grad, [grad_ptr];\n\
4172    ld.param.u64 %output, [output_ptr];\n\
4173    ld.param.u64 %out, [out_ptr];\n\
4174    ld.param.u32 %rows_reg, [rows];\n\
4175    ld.param.u32 %cols_reg, [cols];\n\
4176\n\
4177    mov.u32 %bid, %ctaid.x;\n\
4178    mov.u32 %bdim, %ntid.x;\n\
4179    mov.u32 %r_tid, %tid.x;\n\
4180    mov.u64 %sbase, sdata;\n\
4181\n\
4182    setp.ge.u32 %p, %bid, %rows_reg;\n\
4183    @%p bra DONE;\n\
4184\n\
4185    // row_off = bid * cols * 4 (byte offset)\n\
4186    cvt.u64.u32 %row_off, %bid;\n\
4187    cvt.u64.u32 %off, %cols_reg;\n\
4188    mul.lo.u64 %row_off, %row_off, %off;\n\
4189    shl.b64 %row_off, %row_off, 2;\n\
4190\n\
4191    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
4192    mov.f32 %dot, 0f00000000;\n\
4193    mov.u32 %j, %r_tid;\n\
4194DOT_LOOP:\n\
4195    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4196    @%loop_p bra DOT_LOOP_DONE;\n\
4197    cvt.u64.u32 %off, %j;\n\
4198    shl.b64 %off, %off, 2;\n\
4199    add.u64 %saddr, %grad, %off;\n\
4200    add.u64 %saddr, %saddr, %row_off;\n\
4201    ld.global.f32 %vg, [%saddr];\n\
4202    add.u64 %saddr, %output, %off;\n\
4203    add.u64 %saddr, %saddr, %row_off;\n\
4204    ld.global.f32 %vo, [%saddr];\n\
4205    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
4206    add.u32 %j, %j, %bdim;\n\
4207    bra DOT_LOOP;\n\
4208DOT_LOOP_DONE:\n\
4209\n\
4210    // Store partial dot into shared memory and reduce\n\
4211    cvt.u64.u32 %off, %r_tid;\n\
4212    shl.b64 %off, %off, 2;\n\
4213    add.u64 %saddr, %sbase, %off;\n\
4214    st.shared.f32 [%saddr], %dot;\n\
4215    bar.sync 0;\n\
4216\n\
4217    mov.u32 %half, %bdim;\n\
4218DOT_REDUCE:\n\
4219    shr.u32 %half, %half, 1;\n\
4220    setp.eq.u32 %reduce_p, %half, 0;\n\
4221    @%reduce_p bra DOT_REDUCE_DONE;\n\
4222    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
4223    @%reduce_p bra DOT_REDUCE_SKIP;\n\
4224    add.u32 %other_tid, %r_tid, %half;\n\
4225    cvt.u64.u32 %off, %other_tid;\n\
4226    shl.b64 %off, %off, 2;\n\
4227    add.u64 %saddr, %sbase, %off;\n\
4228    ld.shared.f32 %other_val, [%saddr];\n\
4229    cvt.u64.u32 %off, %r_tid;\n\
4230    shl.b64 %off, %off, 2;\n\
4231    add.u64 %saddr, %sbase, %off;\n\
4232    ld.shared.f32 %dot, [%saddr];\n\
4233    add.f32 %dot, %dot, %other_val;\n\
4234    st.shared.f32 [%saddr], %dot;\n\
4235DOT_REDUCE_SKIP:\n\
4236    bar.sync 0;\n\
4237    bra DOT_REDUCE;\n\
4238DOT_REDUCE_DONE:\n\
4239\n\
4240    // Broadcast dot to all threads\n\
4241    ld.shared.f32 %dot, [sdata];\n\
4242    bar.sync 0;\n\
4243\n\
4244    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
4245    mov.u32 %j, %r_tid;\n\
4246WRITE_LOOP:\n\
4247    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4248    @%loop_p bra WRITE_LOOP_DONE;\n\
4249    cvt.u64.u32 %off, %j;\n\
4250    shl.b64 %off, %off, 2;\n\
4251    add.u64 %saddr, %grad, %off;\n\
4252    add.u64 %saddr, %saddr, %row_off;\n\
4253    ld.global.f32 %vg, [%saddr];\n\
4254    add.u64 %saddr, %output, %off;\n\
4255    add.u64 %saddr, %saddr, %row_off;\n\
4256    ld.global.f32 %vo, [%saddr];\n\
4257    sub.f32 %diff, %vg, %dot;\n\
4258    mul.f32 %result, %vo, %diff;\n\
4259    add.u64 %saddr, %out, %off;\n\
4260    add.u64 %saddr, %saddr, %row_off;\n\
4261    st.global.f32 [%saddr], %result;\n\
4262    add.u32 %j, %j, %bdim;\n\
4263    bra WRITE_LOOP;\n\
4264WRITE_LOOP_DONE:\n\
4265\n\
4266DONE:\n\
4267    ret;\n\
4268}\n\
4269";
4270
4271
4272// ---------------------------------------------------------------------------
4273// LogSoftmax forward PTX kernel (row-wise, shared-memory max + log-sum-exp)
4274// ---------------------------------------------------------------------------
4275// For each row of length `cols`:
4276//   m = max(x[j])
4277//   log_sum_exp = m + log(sum(exp(x[j] - m)))
4278//   out[j] = x[j] - log_sum_exp
4279// One block per row, 256 threads per block.
4280
4281#[cfg(feature = "cuda")]
4282pub(crate) const LOG_SOFTMAX_PTX: &str = "\
4283.version 7.0\n\
4284.target sm_52\n\
4285.address_size 64\n\
4286\n\
4287.shared .align 4 .f32 sdata[256];\n\
4288\n\
4289.visible .entry log_softmax_kernel(\n\
4290    .param .u64 input_ptr,\n\
4291    .param .u64 output_ptr,\n\
4292    .param .u32 rows,\n\
4293    .param .u32 cols\n\
4294) {\n\
4295    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
4296    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
4297    .reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
4298    .reg .pred %p, %loop_p;\n\
4299    .reg .u32 %half, %other_tid;\n\
4300    .reg .f32 %other_val;\n\
4301    .reg .pred %reduce_p;\n\
4302\n\
4303    ld.param.u64 %in, [input_ptr];\n\
4304    ld.param.u64 %out, [output_ptr];\n\
4305    ld.param.u32 %rows_reg, [rows];\n\
4306    ld.param.u32 %cols_reg, [cols];\n\
4307\n\
4308    mov.u32 %bid, %ctaid.x;\n\
4309    mov.u32 %bdim, %ntid.x;\n\
4310    mov.u32 %r_tid, %tid.x;\n\
4311    mov.u64 %sbase, sdata;\n\
4312\n\
4313    setp.ge.u32 %p, %bid, %rows_reg;\n\
4314    @%p bra DONE;\n\
4315\n\
4316    // row_off = bid * cols * 4 (byte offset)\n\
4317    cvt.u64.u32 %row_off, %bid;\n\
4318    cvt.u64.u32 %off, %cols_reg;\n\
4319    mul.lo.u64 %row_off, %row_off, %off;\n\
4320    shl.b64 %row_off, %row_off, 2;\n\
4321\n\
4322    // Phase 1: find max across row (grid-stride over columns)\n\
4323    mov.f32 %max_val, 0fFF800000;\n\
4324    mov.u32 %j, %r_tid;\n\
4325FIND_MAX:\n\
4326    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4327    @%loop_p bra FIND_MAX_DONE;\n\
4328    cvt.u64.u32 %off, %j;\n\
4329    shl.b64 %off, %off, 2;\n\
4330    add.u64 %off, %in, %off;\n\
4331    add.u64 %off, %off, %row_off;\n\
4332    ld.global.f32 %val, [%off];\n\
4333    max.f32 %max_val, %max_val, %val;\n\
4334    add.u32 %j, %j, %bdim;\n\
4335    bra FIND_MAX;\n\
4336FIND_MAX_DONE:\n\
4337\n\
4338    // Shared-memory tree reduction for max\n\
4339    cvt.u64.u32 %off, %r_tid;\n\
4340    shl.b64 %off, %off, 2;\n\
4341    add.u64 %saddr, %sbase, %off;\n\
4342    st.shared.f32 [%saddr], %max_val;\n\
4343    bar.sync 0;\n\
4344\n\
4345    mov.u32 %half, %bdim;\n\
4346MAX_REDUCE:\n\
4347    shr.u32 %half, %half, 1;\n\
4348    setp.eq.u32 %reduce_p, %half, 0;\n\
4349    @%reduce_p bra MAX_REDUCE_DONE;\n\
4350    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
4351    @%reduce_p bra MAX_REDUCE_SKIP;\n\
4352    add.u32 %other_tid, %r_tid, %half;\n\
4353    cvt.u64.u32 %off, %other_tid;\n\
4354    shl.b64 %off, %off, 2;\n\
4355    add.u64 %saddr, %sbase, %off;\n\
4356    ld.shared.f32 %other_val, [%saddr];\n\
4357    cvt.u64.u32 %off, %r_tid;\n\
4358    shl.b64 %off, %off, 2;\n\
4359    add.u64 %saddr, %sbase, %off;\n\
4360    ld.shared.f32 %max_val, [%saddr];\n\
4361    max.f32 %max_val, %max_val, %other_val;\n\
4362    add.u64 %saddr, %sbase, %off;\n\
4363    st.shared.f32 [%saddr], %max_val;\n\
4364MAX_REDUCE_SKIP:\n\
4365    bar.sync 0;\n\
4366    bra MAX_REDUCE;\n\
4367MAX_REDUCE_DONE:\n\
4368\n\
4369    // Broadcast max to all threads\n\
4370    ld.shared.f32 %max_val, [sdata];\n\
4371    bar.sync 0;\n\
4372\n\
4373    // Phase 2: compute partial sum of exp(x[j] - max)\n\
4374    mov.f32 %sum_val, 0f00000000;\n\
4375    mov.u32 %j, %r_tid;\n\
4376SUM_EXP:\n\
4377    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4378    @%loop_p bra SUM_EXP_DONE;\n\
4379    cvt.u64.u32 %off, %j;\n\
4380    shl.b64 %off, %off, 2;\n\
4381    add.u64 %off, %in, %off;\n\
4382    add.u64 %off, %off, %row_off;\n\
4383    ld.global.f32 %val, [%off];\n\
4384    sub.f32 %val, %val, %max_val;\n\
4385    // exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
4386    mul.f32 %val, %val, 0f3FB8AA3B;\n\
4387    ex2.approx.f32 %exp_val, %val;\n\
4388    add.f32 %sum_val, %sum_val, %exp_val;\n\
4389    add.u32 %j, %j, %bdim;\n\
4390    bra SUM_EXP;\n\
4391SUM_EXP_DONE:\n\
4392\n\
4393    // Shared-memory tree reduction for sum\n\
4394    cvt.u64.u32 %off, %r_tid;\n\
4395    shl.b64 %off, %off, 2;\n\
4396    add.u64 %saddr, %sbase, %off;\n\
4397    st.shared.f32 [%saddr], %sum_val;\n\
4398    bar.sync 0;\n\
4399\n\
4400    mov.u32 %half, %bdim;\n\
4401SUM_REDUCE:\n\
4402    shr.u32 %half, %half, 1;\n\
4403    setp.eq.u32 %reduce_p, %half, 0;\n\
4404    @%reduce_p bra SUM_REDUCE_DONE;\n\
4405    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
4406    @%reduce_p bra SUM_REDUCE_SKIP;\n\
4407    add.u32 %other_tid, %r_tid, %half;\n\
4408    cvt.u64.u32 %off, %other_tid;\n\
4409    shl.b64 %off, %off, 2;\n\
4410    add.u64 %saddr, %sbase, %off;\n\
4411    ld.shared.f32 %other_val, [%saddr];\n\
4412    cvt.u64.u32 %off, %r_tid;\n\
4413    shl.b64 %off, %off, 2;\n\
4414    add.u64 %saddr, %sbase, %off;\n\
4415    ld.shared.f32 %sum_val, [%saddr];\n\
4416    add.f32 %sum_val, %sum_val, %other_val;\n\
4417    add.u64 %saddr, %sbase, %off;\n\
4418    st.shared.f32 [%saddr], %sum_val;\n\
4419SUM_REDUCE_SKIP:\n\
4420    bar.sync 0;\n\
4421    bra SUM_REDUCE;\n\
4422SUM_REDUCE_DONE:\n\
4423\n\
4424    // Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
4425    ld.shared.f32 %sum_val, [sdata];\n\
4426    bar.sync 0;\n\
4427    // log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
4428    // ln(2) = 0x3F317218\n\
4429    lg2.approx.f32 %log_sum_exp, %sum_val;\n\
4430    mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
4431    add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
4432\n\
4433    // Phase 3: out[j] = x[j] - log_sum_exp\n\
4434    mov.u32 %j, %r_tid;\n\
4435WRITE_OUTPUT:\n\
4436    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4437    @%loop_p bra WRITE_OUTPUT_DONE;\n\
4438    cvt.u64.u32 %off, %j;\n\
4439    shl.b64 %off, %off, 2;\n\
4440    add.u64 %saddr, %in, %off;\n\
4441    add.u64 %saddr, %saddr, %row_off;\n\
4442    ld.global.f32 %val, [%saddr];\n\
4443    sub.f32 %result, %val, %log_sum_exp;\n\
4444    cvt.u64.u32 %off, %j;\n\
4445    shl.b64 %off, %off, 2;\n\
4446    add.u64 %saddr, %out, %off;\n\
4447    add.u64 %saddr, %saddr, %row_off;\n\
4448    st.global.f32 [%saddr], %result;\n\
4449    add.u32 %j, %j, %bdim;\n\
4450    bra WRITE_OUTPUT;\n\
4451WRITE_OUTPUT_DONE:\n\
4452\n\
4453DONE:\n\
4454    ret;\n\
4455}\n\
4456";
4457
/// PTX source for `log_softmax_f64_kernel`: row-wise log-softmax (f64).
///
/// The input is a contiguous row-major `rows x cols` matrix. For each row:
///
/// * `m = max_j x[j]`
/// * `lse = m + ln(sum_j exp(x[j] - m))`
/// * `out[j] = x[j] - lse`
///
/// Launch layout:
///
/// * One block per row. `ctaid.x` selects the row, and blocks with
///   `ctaid.x >= rows` exit immediately.
/// * Threads stride over the columns in steps of `ntid.x`.
/// * The max and the sum are combined by a halving tree over the 256-entry
///   shared array `sdata`. The block size must therefore be a power of two no
///   larger than 256.
///
/// PTX has no f64 `ex2`/`lg2`, so `exp` and `ln` are evaluated inline in f64:
///
/// * `exp(v)` for `v = x - m <= 0`:
///   * Reduce the argument: `n = rint(v * log2(e))` and `r = v - n*ln2`, with
///     `ln2` split into hi/lo parts.
///   * Evaluate `exp(v) = 2^n * P(r)`, where `P` is the degree-12 Taylor
///     polynomial.
///   * For `v < -708` the exponent bit trick cannot build `2^n` as a normal
///     number, so the term is set to `0`. Such terms lie far below the
///     rounding granularity of the sum, which always contains `exp(0) = 1`.
///     Masked logits (e.g. `-inf`, `-1e9`) therefore contribute nothing
///     instead of corrupting the sum.
/// * `ln(s)`:
///   * Write `s = 2^e * m'` and re-centre `m'` into `(sqrt(2)/2, sqrt(2)]`.
///   * Compute `ln(m') = 2*atanh(f)` with `f = (m'-1)/(m'+1)`, so
///     `|f| <= 0.172`, using the odd series through `f^19`.
///   * A NaN sum is propagated rather than decoded into a finite exponent
///     and mantissa.
///
/// NOTE(review): `log_softmax_backward_f64_kernel` inlines its own copy of the
/// `exp` sequence; keep the two in sync.
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .pred %e_under;\n\
    .reg .u64 %l_xbits, %l_mbits, %l_bias;\n\
    .reg .s64 %l_exp64;\n\
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;\n\
    .reg .pred %l_big, %l_nan;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    // exp(v) = 2^n * P(r): n = rint(v * log2(e)), r = v - n*ln2 (hi/lo split)\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    // P(r) = sum_{k=0..12} r^k / k!, Horner from 1/12! down to 1/0!\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    // Below v = -708 the biased exponent of 2^n is not positive; the term is negligible\n\
    setp.lt.f64 %e_under, %val, 0dC086200000000000;\n\
    @%e_under mov.f64 %exp_val, 0d0000000000000000;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // ln(sum): sum = 2^e * m with m in [1, 2); re-centre m into (sqrt(2)/2, sqrt(2)]\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.b64 %l_xbits, %sum_val;\n\
    shr.u64 %l_exp64, %l_xbits, 52;\n\
    and.b64 %l_exp64, %l_exp64, 2047;\n\
    sub.s64 %l_exp64, %l_exp64, 1023;\n\
    cvt.rn.f64.s64 %l_nf, %l_exp64;\n\
    mov.u64 %l_bias, 0x3FF0000000000000;\n\
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;\n\
    or.b64 %l_mbits, %l_mbits, %l_bias;\n\
    mov.b64 %l_m, %l_mbits;\n\
    setp.gt.f64 %l_big, %l_m, 0d3FF6A09E667F3BCD;\n\
    @%l_big mul.f64 %l_m, %l_m, 0d3FE0000000000000;\n\
    @%l_big add.f64 %l_nf, %l_nf, %e_one;\n\
    // ln(m) = 2*f*(1 + f^2/3 + f^4/5 + ... + f^18/19), f = (m-1)/(m+1), |f| <= 0.172\n\
    sub.f64 %l_f, %l_m, %e_one;\n\
    add.f64 %l_s, %l_m, %e_one;\n\
    div.rn.f64 %l_f, %l_f, %l_s;\n\
    mul.f64 %l_f2, %l_f, %l_f;\n\
    mov.f64 %l_p, 0d3FAAF286BCA1AF28;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAE1E1E1E1E1E1E;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB1111111111111;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;\n\
    mul.f64 %l_p, %l_p, %l_f;\n\
    add.f64 %l_p, %l_p, %l_p;\n\
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;\n\
    fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2, %l_p;\n\
    // A NaN sum would otherwise be decoded into a finite exponent/mantissa\n\
    setp.nan.f64 %l_nan, %sum_val, %sum_val;\n\
    @%l_nan mov.f64 %log_sum_exp, %sum_val;\n\
    add.f64 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %val, [%saddr];\n\
    sub.f64 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4672
4673// ---------------------------------------------------------------------------
4674// LogSoftmax backward PTX kernel (row-wise, shared-memory sum reduction)
4675// ---------------------------------------------------------------------------
4676// For each row of length `cols`:
4677//   sum_grad = sum(grad[j])
4678//   out[j] = grad[j] - exp(output[j]) * sum_grad
4679// where output[j] is the log-softmax output, so exp(output[j]) = softmax[j].
4680// One block per row, 256 threads per block.
4681
4682#[cfg(feature = "cuda")]
4683pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
4684.version 7.0\n\
4685.target sm_52\n\
4686.address_size 64\n\
4687\n\
4688.shared .align 4 .f32 sdata[256];\n\
4689\n\
4690.visible .entry log_softmax_backward_kernel(\n\
4691    .param .u64 grad_ptr,\n\
4692    .param .u64 output_ptr,\n\
4693    .param .u64 out_ptr,\n\
4694    .param .u32 rows,\n\
4695    .param .u32 cols\n\
4696) {\n\
4697    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
4698    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
4699    .reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
4700    .reg .pred %p, %loop_p, %reduce_p;\n\
4701\n\
4702    ld.param.u64 %grad, [grad_ptr];\n\
4703    ld.param.u64 %output, [output_ptr];\n\
4704    ld.param.u64 %out, [out_ptr];\n\
4705    ld.param.u32 %rows_reg, [rows];\n\
4706    ld.param.u32 %cols_reg, [cols];\n\
4707\n\
4708    mov.u32 %bid, %ctaid.x;\n\
4709    mov.u32 %bdim, %ntid.x;\n\
4710    mov.u32 %r_tid, %tid.x;\n\
4711    mov.u64 %sbase, sdata;\n\
4712\n\
4713    setp.ge.u32 %p, %bid, %rows_reg;\n\
4714    @%p bra DONE;\n\
4715\n\
4716    // row_off = bid * cols * 4 (byte offset)\n\
4717    cvt.u64.u32 %row_off, %bid;\n\
4718    cvt.u64.u32 %off, %cols_reg;\n\
4719    mul.lo.u64 %row_off, %row_off, %off;\n\
4720    shl.b64 %row_off, %row_off, 2;\n\
4721\n\
4722    // Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
4723    mov.f32 %sum_grad, 0f00000000;\n\
4724    mov.u32 %j, %r_tid;\n\
4725SUM_LOOP:\n\
4726    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4727    @%loop_p bra SUM_LOOP_DONE;\n\
4728    cvt.u64.u32 %off, %j;\n\
4729    shl.b64 %off, %off, 2;\n\
4730    add.u64 %saddr, %grad, %off;\n\
4731    add.u64 %saddr, %saddr, %row_off;\n\
4732    ld.global.f32 %vg, [%saddr];\n\
4733    add.f32 %sum_grad, %sum_grad, %vg;\n\
4734    add.u32 %j, %j, %bdim;\n\
4735    bra SUM_LOOP;\n\
4736SUM_LOOP_DONE:\n\
4737\n\
4738    // Store partial sum into shared memory and reduce\n\
4739    cvt.u64.u32 %off, %r_tid;\n\
4740    shl.b64 %off, %off, 2;\n\
4741    add.u64 %saddr, %sbase, %off;\n\
4742    st.shared.f32 [%saddr], %sum_grad;\n\
4743    bar.sync 0;\n\
4744\n\
4745    mov.u32 %half, %bdim;\n\
4746SUM_REDUCE:\n\
4747    shr.u32 %half, %half, 1;\n\
4748    setp.eq.u32 %reduce_p, %half, 0;\n\
4749    @%reduce_p bra SUM_REDUCE_DONE;\n\
4750    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
4751    @%reduce_p bra SUM_REDUCE_SKIP;\n\
4752    add.u32 %other_tid, %r_tid, %half;\n\
4753    cvt.u64.u32 %off, %other_tid;\n\
4754    shl.b64 %off, %off, 2;\n\
4755    add.u64 %saddr, %sbase, %off;\n\
4756    ld.shared.f32 %other_val, [%saddr];\n\
4757    cvt.u64.u32 %off, %r_tid;\n\
4758    shl.b64 %off, %off, 2;\n\
4759    add.u64 %saddr, %sbase, %off;\n\
4760    ld.shared.f32 %sum_grad, [%saddr];\n\
4761    add.f32 %sum_grad, %sum_grad, %other_val;\n\
4762    st.shared.f32 [%saddr], %sum_grad;\n\
4763SUM_REDUCE_SKIP:\n\
4764    bar.sync 0;\n\
4765    bra SUM_REDUCE;\n\
4766SUM_REDUCE_DONE:\n\
4767\n\
4768    // Broadcast sum_grad to all threads\n\
4769    ld.shared.f32 %sum_grad, [sdata];\n\
4770    bar.sync 0;\n\
4771\n\
4772    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
4773    mov.u32 %j, %r_tid;\n\
4774WRITE_LOOP:\n\
4775    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4776    @%loop_p bra WRITE_LOOP_DONE;\n\
4777    cvt.u64.u32 %off, %j;\n\
4778    shl.b64 %off, %off, 2;\n\
4779    add.u64 %saddr, %grad, %off;\n\
4780    add.u64 %saddr, %saddr, %row_off;\n\
4781    ld.global.f32 %vg, [%saddr];\n\
4782    add.u64 %saddr, %output, %off;\n\
4783    add.u64 %saddr, %saddr, %row_off;\n\
4784    ld.global.f32 %vo, [%saddr];\n\
4785    // exp(log_softmax_output) = softmax probability\n\
4786    mul.f32 %vo, %vo, 0f3FB8AA3B;\n\
4787    ex2.approx.f32 %softmax_j, %vo;\n\
4788    // out[j] = grad[j] - softmax[j] * sum_grad\n\
4789    mul.f32 %result, %softmax_j, %sum_grad;\n\
4790    sub.f32 %result, %vg, %result;\n\
4791    add.u64 %saddr, %out, %off;\n\
4792    add.u64 %saddr, %saddr, %row_off;\n\
4793    st.global.f32 [%saddr], %result;\n\
4794    add.u32 %j, %j, %bdim;\n\
4795    bra WRITE_LOOP;\n\
4796WRITE_LOOP_DONE:\n\
4797\n\
4798DONE:\n\
4799    ret;\n\
4800}\n\
4801";
4802
4803/// PTX source for `log_softmax_backward_f64_kernel` (f64).
4804#[cfg(feature = "cuda")]
4805pub(crate) const LOG_SOFTMAX_BACKWARD_F64_PTX: &str = "\
4806.version 7.0\n\
4807.target sm_52\n\
4808.address_size 64\n\
4809\n\
4810.shared .align 8 .f64 sdata[256];\n\
4811\n\
4812.visible .entry log_softmax_backward_f64_kernel(\n\
4813    .param .u64 grad_ptr,\n\
4814    .param .u64 output_ptr,\n\
4815    .param .u64 out_ptr,\n\
4816    .param .u32 rows,\n\
4817    .param .u32 cols\n\
4818) {\n\
4819    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
4820    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
4821    .reg .f64 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
4822    .reg .pred %p, %loop_p, %reduce_p;\n\
4823    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
4824    .reg .s32 %e_ni;\n\
4825    .reg .s64 %e_ni64, %e_bits;\n\
4826\n\
4827    ld.param.u64 %grad, [grad_ptr];\n\
4828    ld.param.u64 %output, [output_ptr];\n\
4829    ld.param.u64 %out, [out_ptr];\n\
4830    ld.param.u32 %rows_reg, [rows];\n\
4831    ld.param.u32 %cols_reg, [cols];\n\
4832\n\
4833    mov.u32 %bid, %ctaid.x;\n\
4834    mov.u32 %bdim, %ntid.x;\n\
4835    mov.u32 %r_tid, %tid.x;\n\
4836    mov.u64 %sbase, sdata;\n\
4837\n\
4838    setp.ge.u32 %p, %bid, %rows_reg;\n\
4839    @%p bra DONE;\n\
4840\n\
4841    cvt.u64.u32 %row_off, %bid;\n\
4842    cvt.u64.u32 %off, %cols_reg;\n\
4843    mul.lo.u64 %row_off, %row_off, %off;\n\
4844    shl.b64 %row_off, %row_off, 3;\n\
4845\n\
4846    mov.f64 %sum_grad, 0d0000000000000000;\n\
4847    mov.u32 %j, %r_tid;\n\
4848SUM_LOOP:\n\
4849    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4850    @%loop_p bra SUM_LOOP_DONE;\n\
4851    cvt.u64.u32 %off, %j;\n\
4852    shl.b64 %off, %off, 3;\n\
4853    add.u64 %saddr, %grad, %off;\n\
4854    add.u64 %saddr, %saddr, %row_off;\n\
4855    ld.global.f64 %vg, [%saddr];\n\
4856    add.f64 %sum_grad, %sum_grad, %vg;\n\
4857    add.u32 %j, %j, %bdim;\n\
4858    bra SUM_LOOP;\n\
4859SUM_LOOP_DONE:\n\
4860\n\
4861    cvt.u64.u32 %off, %r_tid;\n\
4862    shl.b64 %off, %off, 3;\n\
4863    add.u64 %saddr, %sbase, %off;\n\
4864    st.shared.f64 [%saddr], %sum_grad;\n\
4865    bar.sync 0;\n\
4866\n\
4867    mov.u32 %half, %bdim;\n\
4868SUM_REDUCE:\n\
4869    shr.u32 %half, %half, 1;\n\
4870    setp.eq.u32 %reduce_p, %half, 0;\n\
4871    @%reduce_p bra SUM_REDUCE_DONE;\n\
4872    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
4873    @%reduce_p bra SUM_REDUCE_SKIP;\n\
4874    add.u32 %other_tid, %r_tid, %half;\n\
4875    cvt.u64.u32 %off, %other_tid;\n\
4876    shl.b64 %off, %off, 3;\n\
4877    add.u64 %saddr, %sbase, %off;\n\
4878    ld.shared.f64 %other_val, [%saddr];\n\
4879    cvt.u64.u32 %off, %r_tid;\n\
4880    shl.b64 %off, %off, 3;\n\
4881    add.u64 %saddr, %sbase, %off;\n\
4882    ld.shared.f64 %sum_grad, [%saddr];\n\
4883    add.f64 %sum_grad, %sum_grad, %other_val;\n\
4884    st.shared.f64 [%saddr], %sum_grad;\n\
4885SUM_REDUCE_SKIP:\n\
4886    bar.sync 0;\n\
4887    bra SUM_REDUCE;\n\
4888SUM_REDUCE_DONE:\n\
4889\n\
4890    ld.shared.f64 %sum_grad, [sdata];\n\
4891    bar.sync 0;\n\
4892\n\
4893    mov.u32 %j, %r_tid;\n\
4894WRITE_LOOP:\n\
4895    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
4896    @%loop_p bra WRITE_LOOP_DONE;\n\
4897    cvt.u64.u32 %off, %j;\n\
4898    shl.b64 %off, %off, 3;\n\
4899    add.u64 %saddr, %grad, %off;\n\
4900    add.u64 %saddr, %saddr, %row_off;\n\
4901    ld.global.f64 %vg, [%saddr];\n\
4902    add.u64 %saddr, %output, %off;\n\
4903    add.u64 %saddr, %saddr, %row_off;\n\
4904    ld.global.f64 %vo, [%saddr];\n\
4905    // exp(log_softmax_output) — inline f64 exp\n\
4906    mov.f64 %e_one, 0d3FF0000000000000;\n\
4907    mov.f64 %e_half, 0d3FE0000000000000;\n\
4908    mul.f64 %e_nf, %vo, 0d3FF71547652B82FE;\n\
4909    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
4910    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
4911    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %vo;\n\
4912    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
4913    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
4914    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
4915    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
4916    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
4917    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
4918    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
4919    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
4920    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;\n\
4921    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
4922    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
4923    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
4924    fma.rn.f64 %softmax_j, %e_p, %e_r, %e_one;\n\
4925    cvt.s64.s32 %e_ni64, %e_ni;\n\
4926    add.s64 %e_ni64, %e_ni64, 1023;\n\
4927    shl.b64 %e_bits, %e_ni64, 52;\n\
4928    mov.b64 %e_nf, %e_bits;\n\
4929    mul.f64 %softmax_j, %softmax_j, %e_nf;\n\
4930    mul.f64 %result, %softmax_j, %sum_grad;\n\
4931    sub.f64 %result, %vg, %result;\n\
4932    add.u64 %saddr, %out, %off;\n\
4933    add.u64 %saddr, %saddr, %row_off;\n\
4934    st.global.f64 [%saddr], %result;\n\
4935    add.u32 %j, %j, %bdim;\n\
4936    bra WRITE_LOOP;\n\
4937WRITE_LOOP_DONE:\n\
4938\n\
4939DONE:\n\
4940    ret;\n\
4941}\n\
4942";
4943
// ---------------------------------------------------------------------------
// Reduction PTX kernels: block-level sum (`reduce_sum_kernel`, directly below)
// and sum along one axis of a tensor (`sum_axis_kernel`, defined after it)
// ---------------------------------------------------------------------------
// `sum_axis_kernel` parameters: input_ptr, output_ptr, outer_size, axis_size,
// inner_size, total_output
/// PTX source for `reduce_sum_kernel`: parallel block-level sum reduction.
///
/// Each thread first accumulates a grid-stride partial sum: it starts at
/// `blockIdx.x * blockDim.x + threadIdx.x` and advances by
/// `blockDim.x * gridDim.x` while the index is below `n`. The per-thread sums
/// are stored in shared memory and combined with a tree reduction, and thread
/// 0 writes the block's partial sum to `output[blockIdx.x]`.
///
/// Launch requirements: `blockDim.x` must be a power of two no larger than 256
/// (the size of the shared `sdata` buffer), and `output` must hold at least
/// `gridDim.x` elements.
///
/// For a full reduction, launch once to get partial sums, then launch
/// again on the partial sums (or reduce on CPU if few blocks).
///
/// Implementation notes:
/// - Shared memory is addressed as `%sbase + byte offset`. PTX address
///   operands only accept `[reg]`, `[reg+imm]`, `[var+imm]` or `[imm]`, so the
///   form `[sdata + %off]` is not valid.
/// - The thread index lives in `%r_tid`, because `%tid` is a predefined
///   special register.
/// - The tree reduction starts at `blockDim.x / 2` instead of a hard-coded
///   128. Smaller blocks therefore never read shared slots that no thread
///   wrote.
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off, %sbase, %saddr;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    // %tid is a predefined special register, so the thread index lives in %r_tid.
    mov.u32 %r_tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;
    // Address operands cannot be var + register; use a base register instead.
    mov.u64 %sbase, sdata;

    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %r_tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;

GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;

GRID_DONE:
    // Write thread's partial sum to shared memory.
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum;
    bar.sync 0;

    // Tree reduction in shared memory: half = bdim/2, bdim/4, ..., 1.
    mov.u32 %half, %bdim;
TREE_LOOP:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %p, %half, 0;
    @%p bra TREE_DONE;

    setp.ge.u32 %ptid, %r_tid, %half;
    @%ptid bra TREE_SKIP;

    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %r_tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other, [%saddr];
    // Load own value, add, and store back to sdata[tid].
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum, [%saddr];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%saddr], %sum;

TREE_SKIP:
    bar.sync 0;
    bra TREE_LOOP;

TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %r_tid, 0;
    @%ptid bra END;

    ld.shared.f32 %sum, [%sbase];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;

END:
    ret;
}
";
5051
5052
/// PTX source for `sum_axis_kernel`: sum along one axis of a tensor.
///
/// The input is indexed as a contiguous `[outer_size, axis_size, inner_size]`
/// array, with one thread per output element. Thread i computes
/// `output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]`
/// where `outer_idx = i / inner_size` and `inner_idx = i % inner_size`.
/// Threads with `i >= total_output` exit without writing.
///
/// Parameters: input_ptr, output_ptr, outer_size, axis_size, inner_size,
/// total_output.
///
/// NOTE(review): `total_output` is presumably `outer_size * inner_size`;
/// `outer_size` itself is loaded but never used by the kernel body.
/// NOTE(review): element indices are computed in 32-bit arithmetic before
/// being widened to a byte offset. Inputs with more than `u32::MAX` elements
/// would wrap, so confirm callers never launch that large.
#[cfg(feature = "cuda")]
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;

    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;

    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;

    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:

    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
";
5131
5132// ---------------------------------------------------------------------------
5133// Cumulative scan PTX kernels
5134//
5135// One thread per (outer_idx, inner_idx) pair. Each thread does a sequential
5136// scan along dim_size elements. Parallelism comes from outer*inner threads.
5137// ---------------------------------------------------------------------------
5138
/// PTX source for `cumsum_kernel`: inclusive prefix sum along an axis.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Thread i handles the scan line with `outer_idx = i / inner_size` and
/// `inner_idx = i % inner_size`. That line's elements start at
/// `base = outer_idx * dim_size * inner_size + inner_idx` and are
/// `inner_size` elements apart:
/// `output[base + k*inner] = sum_{j=0}^{k} input[base + j*inner]`
///
/// Threads with `i >= outer_size * inner_size` exit without writing. The
/// `total` parameter is loaded but not used for that check.
///
/// NOTE(review): element indices are 32-bit, so tensors with more than
/// `u32::MAX` elements would wrap. Confirm callers stay below that.
#[cfg(feature = "cuda")]
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumsum_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    // total threads = outer * inner
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    add.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5216
5217
/// PTX source for `cumprod_kernel`: inclusive prefix product along an axis.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Thread i handles the scan line with `outer_idx = i / inner_size` and
/// `inner_idx = i % inner_size`. That line's elements start at
/// `base = outer_idx * dim_size * inner_size + inner_idx` and are
/// `inner_size` elements apart:
/// `output[base + k*inner] = prod_{j=0}^{k} input[base + j*inner]`
/// The accumulator starts at 1.0.
///
/// Threads with `i >= outer_size * inner_size` exit without writing. The
/// `total` parameter is loaded but not used for that check.
///
/// NOTE(review): element indices are 32-bit, so tensors with more than
/// `u32::MAX` elements would wrap. Confirm callers stay below that.
#[cfg(feature = "cuda")]
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumprod_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = 1.0
    mov.f32 %acc, 0f3F800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    mul.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5293
5294
/// PTX source for `cummax_kernel`: running maximum along an axis.
///
/// Thread i processes the scan for outer_idx = i / inner, inner_idx = i % inner.
/// Outputs both values and argmax indices (as f32 for uniform buffer handling).
/// `values[idx] = max_{j=0}^{k} input[base + j*inner]`
/// `indices[idx] = argmax_{j=0}^{k} input[base + j*inner]`
///
/// Details visible in the kernel:
/// - The running value starts at `-inf` and the index at 0.
/// - The index only moves on a strictly greater value (`setp.gt.f32`), so ties
///   keep the earliest position.
/// - A NaN input never updates the index (the comparison is false), and
///   `max.f32` returns the non-NaN operand, so NaNs are skipped rather than
///   propagated.
/// - Indices go through `cvt.rn.f32.u32`, so they are exact only up to 2^24.
///
/// NOTE(review): PyTorch's `cummax` propagates NaN and reports the last index
/// on ties. Confirm which semantics callers and the CPU fallback expect.
#[cfg(feature = "cuda")]
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummax_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_max;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0xFF800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.gt.f32 %is_new_max, %val, %acc;
    @%is_new_max mov.u32 %best_k, %k;
    max.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5380
5381
/// PTX source for `cummin_kernel`: running minimum along an axis.
///
/// Thread i processes the scan for outer_idx = i / inner, inner_idx = i % inner.
/// Outputs both values and argmin indices (as f32 for uniform buffer handling).
/// `values[idx] = min_{j=0}^{k} input[base + j*inner]`
/// `indices[idx] = argmin_{j=0}^{k} input[base + j*inner]`
///
/// Details visible in the kernel:
/// - The running value starts at `+inf` and the index at 0.
/// - The index only moves on a strictly smaller value (`setp.lt.f32`), so ties
///   keep the earliest position.
/// - A NaN input never updates the index, and `min.f32` returns the non-NaN
///   operand, so NaNs are skipped rather than propagated.
/// - Indices go through `cvt.rn.f32.u32`, so they are exact only up to 2^24.
///
/// NOTE(review): PyTorch's `cummin` propagates NaN and reports the last index
/// on ties. Confirm which semantics callers and the CPU fallback expect.
#[cfg(feature = "cuda")]
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummin_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_min;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0x7F800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.lt.f32 %is_new_min, %val, %acc;
    @%is_new_min mov.u32 %best_k, %k;
    min.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5465
5466
/// PTX source for `logcumsumexp_kernel`: numerically stable log-cumulative-sum-exp.
///
/// Thread i processes the scan for outer_idx = i / inner, inner_idx = i % inner.
/// `acc = log(exp(acc) + exp(x))` computed as `m + log(exp(acc-m) + exp(x-m))`
/// where `m = max(acc, x)` for numerical stability.
///
/// Special values are resolved before the stable formula. Whenever `m` is
/// infinite (e.g. a scan line that starts with `-inf`, or any `+inf` input),
/// `acc - m` or `x - m` would be `inf - inf = NaN`. The kernel therefore:
/// - makes the result NaN when either operand is NaN. This check runs first,
///   so a NaN is never masked by the infinity shortcut;
/// - stores an infinite `m` as-is (both operands `-inf`, or either `+inf`),
///   since that is the exact result.
///
/// Uses `ex2.approx.f32` for exp and `lg2.approx.f32` for log with
/// log2(e) and ln(2) conversion constants.
#[cfg(feature = "cuda")]
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc, %m, %ea, %ev, %s, %ls, %log2e, %ln2;
    .reg .pred %p, %lp, %p_nan, %p_inf;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    // log2(e) = 1.4426950408...  -> 0x3FB8AA3B
    mov.b32 %log2e, 0x3FB8AA3B;
    // ln(2) = 0.6931471805... -> 0x3F317218
    mov.b32 %ln2, 0x3F317218;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b32 %acc, 0xFF800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // NaN in either operand: the result is NaN (acc + x propagates it).
    // Checked first so the infinity shortcut below never hides a NaN.
    setp.nan.f32 %p_nan, %acc, %val;
    @%p_nan add.f32 %acc, %acc, %val;
    @%p_nan bra SCAN_STORE;

    // Numerically stable: m = max(acc, x)
    max.f32 %m, %acc, %val;
    // Infinite m is the exact result; the formula below would hit inf - inf.
    testp.infinite.f32 %p_inf, %m;
    @%p_inf mov.f32 %acc, %m;
    @%p_inf bra SCAN_STORE;

    // exp(acc - m): (acc - m) * log2(e) -> ex2
    sub.f32 %ea, %acc, %m;
    mul.f32 %ea, %ea, %log2e;
    ex2.approx.f32 %ea, %ea;
    // exp(x - m): (x - m) * log2(e) -> ex2
    sub.f32 %ev, %val, %m;
    mul.f32 %ev, %ev, %log2e;
    ex2.approx.f32 %ev, %ev;
    // sum
    add.f32 %s, %ea, %ev;
    // log(sum) = lg2(sum) * ln(2)
    lg2.approx.f32 %ls, %s;
    mul.f32 %ls, %ls, %ln2;
    // acc = m + log(sum)
    add.f32 %acc, %m, %ls;

SCAN_STORE:
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5567
/// PTX source for `logcumsumexp_f64_kernel`: numerically stable
/// log-cumulative-sum-exp (f64).
///
/// Thread i processes the scan for `outer_idx = i / inner` and
/// `inner_idx = i % inner`, using the same layout as `logcumsumexp_kernel`.
/// Each step computes `acc = log(exp(acc) + exp(x)) = max + log1p(exp(min - max))`,
/// so only one exponential of a non-positive argument is needed.
///
/// PTX has no f64 transcendentals, so exp and log are evaluated inline:
/// - `exp(d)` for `d <= 0`:
///   - Cody-Waite reduction `d = n*ln2 + r` with `|r| <= ln2/2`, using a
///     degree-12 Taylor polynomial in `r`.
///   - The result is scaled by `2^n` by writing the exponent bits directly.
///   - Arguments below -708 (including `-inf`) return 0. This keeps `2^n` in
///     the normal range, and it is exact for `1 + exp(d)`.
/// - `log1p(e)` for `e in [0, 1]`:
///   - Uses the series `2*atanh(f)`, with `f = e/(2+e)`.
///   - When `e > sqrt(2)-1`, it uses `f = (e-1)/(e+3)` plus `ln 2` instead.
///   - This keeps `|f| <= 0.1716`, so the odd series through `f^19` reaches
///     double precision.
///
/// Special values:
/// - A NaN operand makes the result NaN.
/// - An infinite `max` (both operands `-inf`, or either `+inf`) is returned
///   as-is, avoiding `inf - inf`.
#[cfg(feature = "cuda")]
pub(crate) const LOGCUMSUMEXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_f64_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f64 %val, %acc, %m, %dlt, %ex, %ls, %one;
    .reg .pred %p, %lp, %p_nan, %p_inf, %p_uf, %p_hi;
    .reg .f64 %e_nf, %e_r, %e_p;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .f64 %l_k, %l_num, %l_den, %l_f, %l_f2, %l_p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.f64 %one, 0d3FF0000000000000;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.f64 %acc, 0dFFF0000000000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 3;
    add.u64 %addr, %in, %off;
    ld.global.f64 %val, [%addr];

    // NaN in either operand: the result is NaN (acc + x propagates it).
    // Checked first so the infinity shortcut below never hides a NaN.
    setp.nan.f64 %p_nan, %acc, %val;
    @%p_nan add.f64 %acc, %acc, %val;
    @%p_nan bra SCAN_STORE;

    // Infinite max is the exact result; the general path would hit inf - inf.
    max.f64 %m, %acc, %val;
    testp.infinite.f64 %p_inf, %m;
    @%p_inf mov.f64 %acc, %m;
    @%p_inf bra SCAN_STORE;

    // dlt = min - max = -|acc - x|, never positive (-inf while acc is -inf).
    sub.f64 %dlt, %acc, %val;
    abs.f64 %dlt, %dlt;
    neg.f64 %dlt, %dlt;

    // --- ex = exp(dlt) ---
    // Below -708 exp is under the smallest normal double, so 1 + ex rounds
    // to 1 anyway; returning 0 also keeps 2^n inside the normal range.
    mov.f64 %ex, 0d0000000000000000;
    setp.lt.f64 %p_uf, %dlt, 0dC086200000000000;
    @%p_uf bra EXP_DONE;
    // n = round(dlt * log2(e)); r = dlt - n*ln2, with ln2 split into hi + lo
    mul.f64 %e_nf, %dlt, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %dlt;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) = sum over i of r^i / i! for i = 12 down to 0 (Horner)
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FE0000000000000;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    // Scale by 2^n by building the exponent bits directly (n >= -1021 here).
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
EXP_DONE:

    // --- ls = log(1 + ex), with ex between 0 and 1 ---
    // ln(1+e) = k*ln2 + 2*atanh(f) where
    //   k = 0, f = e / (2 + e)        when e <= sqrt(2) - 1
    //   k = 1, f = (e - 1) / (e + 3)  otherwise
    // which keeps |f| <= 0.1716.
    mov.f64 %l_k, 0d0000000000000000;
    mov.f64 %l_num, %ex;
    add.f64 %l_den, %ex, 0d4000000000000000;
    setp.gt.f64 %p_hi, %ex, 0d3FDA827999FCEF32;
    @%p_hi mov.f64 %l_k, %one;
    @%p_hi sub.f64 %l_num, %ex, %one;
    @%p_hi add.f64 %l_den, %ex, 0d4008000000000000;
    div.rn.f64 %l_f, %l_num, %l_den;
    mul.f64 %l_f2, %l_f, %l_f;
    // atanh(f) / f = 1 + f^2/3 + f^4/5 + ... + f^18/19 (Horner in f^2)
    mov.f64 %l_p, 0d3FAAF286BCA1AF28;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAE1E1E1E1E1E1E;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB1111111111111;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    fma.rn.f64 %ls, %l_k, 0d3FE62E42FEFA39EF, %l_p;

    // acc = max + log1p(exp(min - max))
    add.f64 %acc, %m, %ls;

SCAN_STORE:
    add.u64 %addr, %out, %off;
    st.global.f64 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5722
5723// ---------------------------------------------------------------------------
5724// LayerNorm PTX kernel (row-wise: mean, var, normalize+affine)
5725//
5726// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
5727// `rcp.approx.f32`) for performance. These have reduced precision (~2^-22
5728// relative error) compared to the full-precision variants, which is
5729// acceptable for neural network training/inference.
5730// ---------------------------------------------------------------------------
5731
/// PTX source for `layernorm_kernel`: row-wise layer normalization with an
/// affine transform.
///
/// One block handles one row: `blockIdx.x` selects the row, and blocks with
/// `blockIdx.x >= rows` exit immediately. Within a block, each thread strides
/// over the row's `cols` elements in steps of `blockDim.x`. Partial results
/// are combined with a shared-memory tree reduction. The block then:
/// 1. sums the row and divides by `cols` to get the mean;
/// 2. sums `(x - mean)^2` the same way and divides by `cols`. This is the
///    biased (population) variance;
/// 3. writes `out[j] = w[j] * (x[j] - mean) * inv_std + b[j]`, where
///    `inv_std = rcp(sqrt(var + eps))` uses the `.approx` instructions.
///
/// Launch requirements, inferred from the reduction loop (`half` starts at
/// `blockDim.x` and is halved until 0): `blockDim.x` should be a power of two
/// no larger than 256, the size of the shared `sdata` buffer. `w` and `b` are
/// indexed by column only, so each needs at least `cols` elements.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u64 b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra SM;
SMD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
MRS:
    bar.sync 0;
    bra MR;
MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra SV;
SVD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
VRS:
    bar.sync 0;
    bra VR;
VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %normed, %val, %mean;
    mul.f32 %normed, %normed, %inv_std;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %b, %off;
    ld.global.f32 %bv, [%off];
    fma.rn.f32 %result, %wv, %normed, %bv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
5906
5907
5908// ---------------------------------------------------------------------------
5909// LayerNorm backward PTX kernel
5910// ---------------------------------------------------------------------------
5911//
5912// One block per batch element (row). Each block:
5913//   1. Recompute mean and variance from input
5914//   2. Compute x_hat = (x - mean) * rsqrt(var + eps)
5915//   3. Compute dl_dx_hat = grad_output * weight
5916//   4. Reduce dl_dx_hat and dl_dx_hat * x_hat across the normalized dimension
5917//   5. Compute grad_input = rsqrt(var+eps) * (dl_dx_hat - mean(dl_dx_hat) - x_hat * mean(dl_dx_hat * x_hat))
5918//   6. Accumulate grad_weight (atomicAdd) and grad_bias (atomicAdd) across batch elements
5919//
5920// Uses shared memory for per-row reductions, 256 threads per block.
5921// Parameters:
5922//   in_ptr      - pointer to input f32 buffer [rows * cols]
5923//   grad_out_ptr - pointer to grad_output f32 buffer [rows * cols]
5924//   w_ptr       - pointer to weight f32 buffer [cols]
5925//   grad_in_ptr - pointer to grad_input f32 output buffer [rows * cols]
5926//   grad_w_ptr  - pointer to grad_weight f32 output buffer [cols] (atomicAdd)
5927//   grad_b_ptr  - pointer to grad_bias f32 output buffer [cols] (atomicAdd)
5928//   rows        - number of batch elements
5929//   cols        - normalized dimension size
5930//   eps         - epsilon for numerical stability
5931
/// PTX source for `layernorm_backward_kernel` (f32).
///
/// Launch contract, as implied by the kernel body:
/// * One block per row: block `ctaid.x` handles row `ctaid.x` (byte offset
///   `ctaid.x * cols * 4`); blocks with `ctaid.x >= rows` return immediately.
/// * `blockDim.x` must be a power of two no larger than 256. Per-thread
///   partial sums are staged in the 256-entry shared `sdata` array and
///   combined by a tree reduction that starts at `blockDim.x / 2` and halves
///   each step. A larger block would overrun `sdata`, and a non-power-of-two
///   block would silently drop partial sums.
/// * `grad_w_ptr` and `grad_b_ptr` are only ever updated with
///   `atom.global.add.f32` and are never initialized here. They therefore
///   accumulate across rows *and* onto whatever they held before launch, so
///   zero them first unless accumulation into existing gradients is intended.
///
/// Mean, biased variance (divided by `cols`) and `inv_std` are recomputed from
/// the input. This uses the same `div.approx` / `sqrt.approx` / `rcp.approx`
/// sequence as the LayerNorm forward kernel, so `x_hat` matches the forward
/// pass.
///
/// Nothing is cached between phases. The input row is read from global memory
/// four times (mean, variance, gradient sums, grad_input); `grad_output` and
/// `weight` are each read twice.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u64 grad_b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u64 %gb, [grad_b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra LNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute mean =====
    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SM;
LNB_SMD:
    // Shared memory reduce for mean
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    st.shared.f32 [%saddr], %mean;
LNB_MRS:
    bar.sync 0;
    bra LNB_MR;
LNB_MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    // ===== Phase 2: Compute variance =====
    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SV;
LNB_SVD:
    // Shared memory reduce for variance
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    st.shared.f32 [%saddr], %var;
LNB_VRS:
    bar.sync 0;
    bra LNB_VR;
LNB_VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
    // Also accumulate grad_weight and grad_bias via atomicAdd
    mov.f32 %sum1, 0f00000000;
    mov.f32 %sum2, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_S12:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_S12D;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // x_hat = (val - mean) * inv_std
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dl_dx_hat = grad_output * weight
    mul.f32 %dl_dx_hat, %gov, %wv;
    // Accumulate sums
    add.f32 %sum1, %sum1, %dl_dx_hat;
    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
    // atomicAdd grad_weight[j] += grad_output * x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %x_hat;
    atom.global.add.f32 %result, [%addr], %result;
    // atomicAdd grad_bias[j] += grad_output
    add.u64 %addr, %gb, %off;
    atom.global.add.f32 %result, [%addr], %gov;
    add.u32 %j, %j, %r_bdim;
    bra LNB_S12;
LNB_S12D:
    // Reduce sum1 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum1;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R1:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R1D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R1S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum1, [%saddr];
    add.f32 %sum1, %sum1, %other_val;
    st.shared.f32 [%saddr], %sum1;
LNB_R1S:
    bar.sync 0;
    bra LNB_R1;
LNB_R1D:
    ld.shared.f32 %sum1, [%sbase];
    // mean1 = sum1 / n
    div.approx.f32 %mean1, %sum1, %n_f;
    bar.sync 0;

    // Reduce sum2 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum2;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R2:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R2D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R2S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum2, [%saddr];
    add.f32 %sum2, %sum2, %other_val;
    st.shared.f32 [%saddr], %sum2;
LNB_R2S:
    bar.sync 0;
    bra LNB_R2;
LNB_R2D:
    ld.shared.f32 %sum2, [%sbase];
    // mean2 = sum2 / n
    div.approx.f32 %mean2, %sum2, %n_f;
    bar.sync 0;

    // ===== Phase 4: Compute grad_input =====
    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
    mov.u32 %j, %r_tid;
LNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_GID;
    // Reload input to recompute x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Reload grad_output and weight to recompute dl_dx_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    mul.f32 %dl_dx_hat, %gov, %wv;
    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
    sub.f32 %result, %dl_dx_hat, %mean1;
    mul.f32 %diff, %x_hat, %mean2;
    sub.f32 %result, %result, %diff;
    mul.f32 %result, %inv_std, %result;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra LNB_GI;
LNB_GID:

LNB_DONE:
    ret;
}
";
6236
6237
6238// ---------------------------------------------------------------------------
6239// RMSNorm PTX kernel (row-wise: rms, normalize+scale)
6240//
6241// Like LayerNorm but without mean centering or bias:
6242//   out[j] = x[j] * rsqrt(mean(x^2) + eps) * weight[j]
6243//
6244// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
6245// `rcp.approx.f32`) for performance. These have reduced precision (~2^-22
6246// relative error) compared to the full-precision variants, which is
6247// acceptable for neural network training/inference.
6248// ---------------------------------------------------------------------------
6249
/// PTX source for `rmsnorm_kernel` (f32 forward).
///
/// Computes `out[r, j] = x[r, j] * inv_rms[r] * weight[j]`, where
/// `inv_rms[r] = 1 / sqrt(sum_j x[r, j]^2 / cols + eps)`.
///
/// Launch contract, as implied by the kernel body:
/// * One block per row (`ctaid.x`); blocks with `ctaid.x >= rows` return
///   immediately.
/// * `blockDim.x` must be a power of two no larger than 256. Partial sums of
///   squares are staged in the 256-entry shared `sdata` array and combined by
///   a tree reduction that halves `blockDim.x` each step.
///
/// NOTE(review): in the reduction loop, the `add.u64 %saddr, %sbase, %off;`
/// immediately before `st.shared.f32` recomputes the address already held in
/// `%saddr`. It is redundant but harmless.
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra SS;
SSD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
SRS:
    bar.sync 0;
    bra SR;
SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Normalize and scale =====
    // out[j] = x[j] * inv_rms * weight[j]
    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    mul.f32 %result, %val, %inv_rms;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    mul.f32 %result, %result, %wv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
6373
6374
6375// ---------------------------------------------------------------------------
6376// RMSNorm backward PTX kernel
6377// ---------------------------------------------------------------------------
6378//
6379// One block per batch element (row). Each block:
6380//   1. Recompute inv_rms = 1/sqrt(mean(x^2) + eps)
6381//   2. Compute dot = sum(grad_output[j] * x[j] * weight[j])
6382//   3. Compute grad_input[j] = inv_rms * weight[j] * go[j]
6383//                              - x[j] * inv_rms^3 * dot / cols
6384//   4. Accumulate grad_weight[j] (atomicAdd) = go[j] * x[j] * inv_rms
6385//
6386// Uses shared memory for per-row reductions, 256 threads per block.
6387// No grad_bias (RMSNorm has no bias parameter).
6388// Parameters:
6389//   in_ptr       - pointer to input f32 buffer [rows * cols]
6390//   grad_out_ptr - pointer to grad_output f32 buffer [rows * cols]
6391//   w_ptr        - pointer to weight f32 buffer [cols]
6392//   grad_in_ptr  - pointer to grad_input f32 output buffer [rows * cols]
6393//   grad_w_ptr   - pointer to grad_weight f32 output buffer [cols] (atomicAdd)
6394//   rows         - number of batch elements
6395//   cols         - normalized dimension size
6396//   eps          - epsilon for numerical stability
6397
/// PTX source for `rmsnorm_backward_kernel` (f32).
///
/// For each row `r`:
/// * `inv_rms` is recomputed from the input using the same
///   `div.approx` / `sqrt.approx` / `rcp.approx` sequence as `RMSNORM_PTX`,
///   so it matches the forward pass.
/// * `dot = sum_j go[r, j] * x[r, j] * w[j]` is reduced across the block.
/// * `grad_in[r, j] = inv_rms * w[j] * go[r, j] - x[r, j] * inv_rms^3 * dot / cols`.
/// * `grad_w[j] += go[r, j] * x[r, j] * inv_rms` via `atom.global.add.f32`.
///   The buffer is never initialized here, so zero it before launch unless
///   accumulation onto existing gradients is intended.
///
/// Launch contract, as implied by the kernel body:
/// * One block per row (`ctaid.x`); blocks with `ctaid.x >= rows` return
///   immediately.
/// * `blockDim.x` must be a power of two no larger than 256. Reductions go
///   through the 256-entry shared `sdata` array and a halving tree
///   reduction.
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
    .reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra RNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) -> inv_rms =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra RNB_SS;
RNB_SSD:
    // Shared memory reduce for sum(x^2)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    st.shared.f32 [%saddr], %sq_sum;
RNB_SRS:
    bar.sync 0;
    bra RNB_SR;
RNB_SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    // inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
    mul.f32 %inv_rms3, %inv_rms, %inv_rms;
    mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
    // Also accumulate grad_weight via atomicAdd
    mov.f32 %dot, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_DOT:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_DOTD;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dot += go * x * w
    mul.f32 %tmp, %gov, %val;
    fma.rn.f32 %dot, %tmp, %wv, %dot;
    // atomicAdd grad_weight[j] += go * x * inv_rms
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %val;
    mul.f32 %result, %result, %inv_rms;
    atom.global.add.f32 %result, [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_DOT;
RNB_DOTD:
    // Reduce dot in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %dot;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_DR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_DRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_DRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %dot, [%saddr];
    add.f32 %dot, %dot, %other_val;
    st.shared.f32 [%saddr], %dot;
RNB_DRS:
    bar.sync 0;
    bra RNB_DR;
RNB_DRD:
    ld.shared.f32 %dot, [%sbase];
    // coeff = dot * inv_rms3 / n
    mul.f32 %coeff, %dot, %inv_rms3;
    div.approx.f32 %coeff, %coeff, %n_f;
    bar.sync 0;

    // ===== Phase 3: Compute grad_input =====
    // grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
    mov.u32 %j, %r_tid;
RNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_GID;
    // Reload input
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Reload grad_output and weight
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // result = inv_rms * w * go - x * coeff
    mul.f32 %result, %inv_rms, %wv;
    mul.f32 %result, %result, %gov;
    mul.f32 %tmp, %val, %coeff;
    sub.f32 %result, %result, %tmp;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_GI;
RNB_GID:

RNB_DONE:
    ret;
}
";
6612
6613
6614// ---------------------------------------------------------------------------
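// Softmax PTX kernel (row-wise, numerically stable)
// NOTE: this header is orphaned -- the definition immediately following it is
// the BatchNorm2d forward kernel, not the softmax kernel described here.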
6616// ---------------------------------------------------------------------------
6617//
6618// One thread block per row. Each block:
6619//   1. Finds the max in shared memory (for numerical stability)
6620//   2. Computes exp(x - max) and sums in shared memory
6621//   3. Normalizes by the sum
6622//
6623// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
6624// for performance. These have reduced precision (~2^-22 relative error)
6625// compared to the full-precision variants, which is acceptable for neural
6626// network training/inference.
6627//
6628// Parameters:
6629//   input_ptr  - pointer to input f32 buffer
6630//   output_ptr - pointer to output f32 buffer
6631//   rows       - number of rows (outer dimension)
6632//   cols       - number of columns (softmax dimension, = last_dim)
6633
/// PTX kernel for BatchNorm2d forward: per-channel normalize + affine.
///
/// Input layout: `[B, C, S]` flattened row-major, where `S = H * W`.
/// Element `idx` (`0 <= idx < B*S`) of channel `ch` lives at
/// `((idx / S) * C + ch) * S + idx % S`. That offset is computed in 64 bits,
/// so tensors with more than 2^32 elements address correctly.
///
/// Launch contract:
/// * One block per channel (`grid.x >= channels`; surplus blocks exit).
/// * `blockDim.x` must be a power of two no larger than 256. The reduction
///   buffers hold 256 entries, and the tree reduction halves `blockDim.x`
///   each step.
///
/// Behaviour:
/// * `training != 0`:
///   - The block reduces the sum and sum-of-squares of its channel.
///   - It derives the batch mean and biased variance
///     (`E[x^2] - E[x]^2`, clamped at 0 against cancellation) and normalizes
///     with them.
///   - Thread 0 updates the running statistics in place:
///     `running = (1 - momentum) * running + momentum * batch_stat`.
///     `running_var` uses the unbiased variance `var * n / (n - 1)`, or
///     `var` when `n == 1`.
/// * `training == 0`: `running_mean` / `running_var` are read and used
///   directly; they are not modified.
///
/// In both modes:
/// * Thread 0 writes the mean and `1 / sqrt(var + eps)` actually used into
///   `save_mean[ch]` / `save_invstd[ch]`.
/// * Every element of the channel is written as
///   `output = (x - mean) * invstd * weight[ch] + bias[ch]`.
/// * A channel with `total_per_channel == 0` is left untouched.
///
/// Parameters (in order):
///   input[B*C*S], output[B*C*S], weight[C], bias[C],
///   running_mean[C], running_var[C], save_mean[C], save_invstd[C],
///   channels, spatial, eps, momentum, total_per_channel (= B*S),
///   training (0 or 1)
#[cfg(feature = "cuda")]
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for the per-channel block reduction.
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];

.visible .entry batchnorm_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 rmean_ptr,
    .param .u64 rvar_ptr,
    .param .u64 save_mean_ptr,
    .param .u64 save_invstd_ptr,
    .param .u32 channels,
    .param .u32 spatial,
    .param .f32 eps,
    .param .f32 momentum,
    .param .u32 total_per_ch,
    .param .u32 training
) {
    .reg .u32 %r_tid, %r_bdim, %ch, %n_ch, %spat, %tpc, %train;
    .reg .u32 %idx, %plane, %sidx, %half, %r_otid;
    .reg .u64 %in, %out, %w, %b, %rm, %rv, %smean, %sinv;
    .reg .u64 %ch_off, %off64, %tmp64, %addr, %ssum, %ssq, %saddr;
    .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd, %other;
    .reg .f32 %gamma, %beta, %eps_reg, %mom, %keep, %n_f, %one, %tmp, %unbiased;
    .reg .pred %p, %ptrain, %ptid0;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %b, [bias_ptr];
    ld.param.u64 %rm, [rmean_ptr];
    ld.param.u64 %rv, [rvar_ptr];
    ld.param.u64 %smean, [save_mean_ptr];
    ld.param.u64 %sinv, [save_invstd_ptr];
    ld.param.u32 %n_ch, [channels];
    ld.param.u32 %spat, [spatial];
    ld.param.f32 %eps_reg, [eps];
    ld.param.f32 %mom, [momentum];
    ld.param.u32 %tpc, [total_per_ch];
    ld.param.u32 %train, [training];

    mov.u32 %ch, %ctaid.x;
    mov.u32 %r_tid, %tid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u64 %ssum, smem_sum;
    mov.u64 %ssq, smem_sq;
    mov.f32 %one, 0f3F800000;

    // Block-uniform early exits: surplus blocks and empty channels.
    setp.ge.u32 %p, %ch, %n_ch;
    @%p bra END;
    setp.eq.u32 %p, %tpc, 0;
    @%p bra END;

    setp.ne.u32 %ptrain, %train, 0;
    setp.ne.u32 %ptid0, %r_tid, 0;
    cvt.rn.f32.u32 %n_f, %tpc;

    // Byte offset of this channel's entry in the [C]-sized buffers.
    cvt.u64.u32 %ch_off, %ch;
    shl.b64 %ch_off, %ch_off, 2;

    @%ptrain bra TRAIN_STATS;

    // ---- Eval: normalize with the running statistics ----
    add.u64 %addr, %rm, %ch_off;
    ld.global.f32 %mean, [%addr];
    add.u64 %addr, %rv, %ch_off;
    ld.global.f32 %var, [%addr];
    bra HAVE_STATS;

TRAIN_STATS:
    // ---- Pass 1: per-thread sum and sum-of-squares, striding by blockDim ----
    mov.f32 %sum, 0f00000000;
    mov.f32 %sqsum, 0f00000000;
    mov.u32 %idx, %r_tid;
PASS1_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra PASS1_DONE;
    // offset = ((idx / spatial) * channels + ch) * spatial + idx % spatial
    div.u32 %plane, %idx, %spat;
    mul.lo.u32 %sidx, %plane, %spat;
    sub.u32 %sidx, %idx, %sidx;
    mad.lo.u32 %plane, %plane, %n_ch, %ch;
    mul.wide.u32 %off64, %plane, %spat;
    cvt.u64.u32 %tmp64, %sidx;
    add.u64 %off64, %off64, %tmp64;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr, %in, %off64;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;
    fma.rn.f32 %sqsum, %val, %val, %sqsum;
    add.u32 %idx, %idx, %r_bdim;
    bra PASS1_LOOP;

PASS1_DONE:
    // ---- Tree reduction over the blockDim.x partial sums ----
    cvt.u64.u32 %off64, %r_tid;
    shl.b64 %off64, %off64, 2;
    add.u64 %saddr, %ssum, %off64;
    st.shared.f32 [%saddr], %sum;
    add.u64 %saddr, %ssq, %off64;
    st.shared.f32 [%saddr], %sqsum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
REDUCE_LOOP:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %p, %half, 0;
    @%p bra REDUCE_DONE;
    setp.ge.u32 %p, %r_tid, %half;
    @%p bra REDUCE_SKIP;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %tmp64, %r_otid;
    shl.b64 %tmp64, %tmp64, 2;
    add.u64 %saddr, %ssum, %tmp64;
    ld.shared.f32 %other, [%saddr];
    add.f32 %sum, %sum, %other;
    add.u64 %saddr, %ssum, %off64;
    st.shared.f32 [%saddr], %sum;
    add.u64 %saddr, %ssq, %tmp64;
    ld.shared.f32 %other, [%saddr];
    add.f32 %sqsum, %sqsum, %other;
    add.u64 %saddr, %ssq, %off64;
    st.shared.f32 [%saddr], %sqsum;
REDUCE_SKIP:
    bar.sync 0;
    bra REDUCE_LOOP;

REDUCE_DONE:
    // Slot 0 holds the block totals after the last barrier; all threads read it.
    ld.shared.f32 %sum, [%ssum];
    ld.shared.f32 %sqsum, [%ssq];
    div.rn.f32 %mean, %sum, %n_f;
    // Biased variance E[x^2] - mean^2, clamped at 0 against cancellation.
    div.rn.f32 %var, %sqsum, %n_f;
    neg.f32 %other, %mean;
    fma.rn.f32 %var, %other, %mean, %var;
    max.f32 %var, %var, 0f00000000;

HAVE_STATS:
    // invstd = 1 / sqrt(var + eps)
    add.f32 %tmp, %var, %eps_reg;
    sqrt.rn.f32 %tmp, %tmp;
    rcp.rn.f32 %invstd, %tmp;

    @%ptid0 bra STATS_READY;
    // Thread 0: record the statistics used for normalization.
    add.u64 %addr, %smean, %ch_off;
    st.global.f32 [%addr], %mean;
    add.u64 %addr, %sinv, %ch_off;
    st.global.f32 [%addr], %invstd;
    @!%ptrain bra STATS_READY;
    // Thread 0, training: running = (1 - momentum) * running + momentum * batch.
    mov.f32 %unbiased, %var;
    setp.le.u32 %p, %tpc, 1;
    @%p bra RUNNING_UPDATE;
    sub.f32 %tmp, %n_f, %one;
    div.rn.f32 %tmp, %n_f, %tmp;
    mul.f32 %unbiased, %var, %tmp;
RUNNING_UPDATE:
    sub.f32 %keep, %one, %mom;
    add.u64 %addr, %rm, %ch_off;
    ld.global.f32 %tmp, [%addr];
    mul.f32 %tmp, %tmp, %keep;
    fma.rn.f32 %tmp, %mom, %mean, %tmp;
    st.global.f32 [%addr], %tmp;
    add.u64 %addr, %rv, %ch_off;
    ld.global.f32 %tmp, [%addr];
    mul.f32 %tmp, %tmp, %keep;
    fma.rn.f32 %tmp, %mom, %unbiased, %tmp;
    st.global.f32 [%addr], %tmp;

STATS_READY:
    // ---- Pass 2: output = (x - mean) * invstd * gamma + beta ----
    add.u64 %addr, %w, %ch_off;
    ld.global.f32 %gamma, [%addr];
    add.u64 %addr, %b, %ch_off;
    ld.global.f32 %beta, [%addr];
    mov.u32 %idx, %r_tid;
PASS2_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra END;
    div.u32 %plane, %idx, %spat;
    mul.lo.u32 %sidx, %plane, %spat;
    sub.u32 %sidx, %idx, %sidx;
    mad.lo.u32 %plane, %plane, %n_ch, %ch;
    mul.wide.u32 %off64, %plane, %spat;
    cvt.u64.u32 %tmp64, %sidx;
    add.u64 %off64, %off64, %tmp64;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr, %in, %off64;
    ld.global.f32 %val, [%addr];
    sub.f32 %val, %val, %mean;
    mul.f32 %val, %val, %invstd;
    fma.rn.f32 %val, %val, %gamma, %beta;
    add.u64 %addr, %out, %off64;
    st.global.f32 [%addr], %val;
    add.u32 %idx, %idx, %r_bdim;
    bra PASS2_LOOP;

END:
    ret;
}
";
6832
6833
/// PTX kernel for MaxPool2d forward: sliding window max.
///
/// Entry point: `maxpool2d_forward_kernel`. Each thread handles output
/// elements `idx, idx + stride, ...` (with `stride = blockDim.x * gridDim.x`)
/// while they are below `total`, so any grid size covers every output.
///
/// Output element `idx` is decomposed as
/// `idx = ((b * channels + c) * h_out + oh) * w_out + ow`, and input taps are
/// read at `((b * channels + c) * h_in + ih) * w_in + iw`, i.e. both tensors
/// are treated as contiguous NCHW. For output `(oh, ow)` the window covers
/// rows `ih = oh * sh + i - ph` (`0 <= i < kh`) and columns
/// `iw = ow * sw + j - pw` (`0 <= j < kw`).
///
/// All index arithmetic is 32-bit unsigned: a negative padded coordinate
/// wraps to a large value and fails the `< h_in` / `< w_in` bounds check, so
/// padding taps are skipped. The running max starts at `-inf`, so a window
/// with no in-bounds taps writes `-inf`. `batch` is loaded but not needed for
/// the index math; `total` is presumably `batch * channels * h_out * w_out`
/// (TODO confirm against the launcher).
#[cfg(feature = "cuda")]
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry maxpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %max_val, %cur_val, %neg_inf;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

    // -inf for max initialization
    mov.f32 %neg_inf, 0fFF800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx = ((b * C + c) * H_out + oh) * W_out + ow, innermost first.
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %max_val, %neg_inf;

    // Slide the kernel window
    mov.u32 %i, 0;
KH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra KH_DONE;

    mov.u32 %j, 0;
KW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra KW_DONE;

    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
    // Negative coordinates wrap to large unsigned values, so < h_in / < w_in suffices
    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra KW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra KW_NEXT;

    // input_offset = ((b * C + c) * H + ih) * W + iw
    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    max.f32 %max_val, %max_val, %cur_val;

KW_NEXT:
    add.u32 %j, %j, 1;
    bra KW_LOOP;

KW_DONE:
    add.u32 %i, %i, 1;
    bra KH_LOOP;

KH_DONE:
    // Store output
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %max_val;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
6975
6976
/// PTX kernel for AvgPool2d forward: sliding window average.
///
/// Entry point: `avgpool2d_forward_kernel`. The structure matches the MaxPool2d
/// kernel. Each thread handles outputs `idx, idx + blockDim.x * gridDim.x, ...`
/// below `total`. Both tensors are indexed as contiguous NCHW, and the window
/// taps for output `(oh, ow)` are rows `oh * sh + i - ph` and columns
/// `ow * sw + j - pw`.
///
/// The result is `sum / count`, where both run over in-bounds taps only
/// (`count_include_pad = false` semantics). A window with no in-bounds taps
/// (possible if the caller's `h_out` / `w_out` place a window entirely inside
/// the padding) writes `0.0` instead of `0 / 0 = NaN`.
///
/// Index arithmetic is 32-bit unsigned. Negative padded coordinates wrap and
/// fail the bounds check. `batch` is loaded but not needed for the index math.
#[cfg(feature = "cuda")]
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry avgpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx = ((b * C + c) * H_out + oh) * W_out + ow, innermost first.
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %sum_val, 0f00000000;
    mov.u32 %count, 0;

    mov.u32 %i, 0;
AKH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra AKH_DONE;

    mov.u32 %j, 0;
AKW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra AKW_DONE;

    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra AKW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra AKW_NEXT;

    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    add.f32 %sum_val, %sum_val, %cur_val;
    add.u32 %count, %count, 1;

AKW_NEXT:
    add.u32 %j, %j, 1;
    bra AKW_LOOP;

AKW_DONE:
    add.u32 %i, %i, 1;
    bra AKH_LOOP;

AKH_DONE:
    // avg = sum / count over in-bounds taps only (count_include_pad = false).
    // Clamp count to 1 so an all-padding window (sum == 0) stores 0, not NaN.
    max.u32 %count, %count, 1;
    cvt.rn.f32.u32 %count_f, %count;
    div.rn.f32 %avg, %sum_val, %count_f;

    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %avg;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
7112
7113
/// PTX source for `softmax_kernel`: row-wise softmax (f32).
///
/// Launch one block per row. `blockIdx.x` selects the row, and blocks with
/// `blockIdx.x >= rows` exit immediately. The threads of a block walk the
/// row's `cols` elements in steps of `blockDim.x` and run three passes:
///
/// 1. Row maximum, combined across threads by a shared-memory tree reduction.
/// 2. `exp(x - max)`, computed as `ex2.approx.f32((x - max) * log2(e))`. Each
///    value is written to the output and accumulated, and the sum is reduced
///    the same way. Masked `-inf` inputs give exactly `0`.
/// 3. Every output is multiplied by `rcp.approx.f32(sum)`.
///
/// The reductions halve `blockDim.x` each step and use a 256-entry shared
/// buffer, so `blockDim.x` must be a power of two no larger than 256. Row
/// offsets use 64-bit arithmetic; column indices are 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f32 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    rcp.approx.f32 %sum_val, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    mul.f32 %result, %val, %sum_val;\n\
    st.global.f32 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7278
/// PTX source for `softmax_f64_kernel`: row-wise softmax (f64).
///
/// Launch contract:
///
/// * One block per row. `blockIdx.x` selects the row, and blocks with
///   `blockIdx.x >= rows` exit.
/// * Threads walk the row's `cols` elements in steps of `blockDim.x`.
/// * Row max and row sum are combined by shared-memory tree reductions that
///   halve `blockDim.x` each step over a 256-entry buffer. `blockDim.x` must
///   therefore be a power of two no larger than 256.
///
/// For `d = x - max`, `exp(d)` is evaluated in software:
///
/// * `n = rint(d * log2(e))`.
/// * `r = d - n * ln2`, using a hi/lo split of `ln2` applied with two FMAs so
///   the reduced argument stays accurate.
/// * `exp(r)` uses the degree-12 Taylor polynomial `sum_{k=0..12} r^k / k!`,
///   evaluated with Horner's rule.
/// * The result is scaled by `2^n`, built directly from the IEEE-754 exponent
///   bits `(n + 1023) << 52`.
///
/// Inputs with `d < -708` produce exactly `0.0`. Past that point `n + 1023`
/// leaves the normal exponent range, and the bit-built `2^n` would be garbage;
/// for masked `-inf` inputs the reduction would produce NaN. The true value
/// there is below `3.3e-308`, negligible next to the row maximum's term
/// `exp(0) = 1`. NaN inputs still propagate.
///
/// The final normalisation multiplies by `1 / sum` computed with `div.rn.f64`.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %result, %one;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p, %under_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
    mov.f64 %one, 0d3FF0000000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    setp.lt.f64 %under_p, %val, 0dC086200000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    @%under_p mov.f64 %exp_val, 0d0000000000000000;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f64 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    div.rn.f64 %sum_val, %one, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    mul.f64 %result, %val, %sum_val;\n\
    st.global.f64 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7468
7469// ---------------------------------------------------------------------------
7470// Dropout PTX kernel (inverted dropout with xorshift RNG)
7471// ---------------------------------------------------------------------------
7472
/// PTX source for `dropout_kernel`: elementwise dropout on f32 data.
///
/// One thread per element. Thread `i = blockIdx.x * blockDim.x + threadIdx.x`
/// handles element `i` and returns if `i >= n`. There is no stride loop, so the
/// launch must provide at least `n` threads.
///
/// Each thread derives a 32-bit pseudo-random value from its index and `seed`.
/// It computes `i * 2654435761`, XORs in `seed`, then applies one xorshift
/// step with shifts 13/17/5. The element is zeroed when that value is below
/// `threshold` (unsigned compare); otherwise it is multiplied by `scale`. The
/// mask is a pure function of `(i, seed)`, so reusing a seed reproduces it.
///
/// NOTE(review): presumably the host passes `threshold ~= p * 2^32` and
/// `scale = 1 / (1 - p)` (inverted dropout). Confirm against the launcher.
#[cfg(feature = "cuda")]
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 n,\n\
    .param .u32 threshold,\n\
    .param .f32 scale,\n\
    .param .u32 seed\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
    .reg .u64 %in, %out, %off;\n\
    .reg .f32 %val, %scale_reg, %zero;\n\
    .reg .pred %p, %drop_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %n_reg, [n];\n\
    ld.param.u32 %thresh, [threshold];\n\
    ld.param.f32 %scale_reg, [scale];\n\
    ld.param.u32 %seed_reg, [seed];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %n_reg;\n\
    @%p bra DONE;\n\
\n\
    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
    xor.b32 %rng, %rng, %seed_reg;\n\
    shl.b32 %tmp, %rng, 13;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shr.b32 %tmp, %rng, 17;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shl.b32 %tmp, %rng, 5;\n\
    xor.b32 %rng, %rng, %tmp;\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %in, %in, %off;\n\
    add.u64 %out, %out, %off;\n\
    ld.global.f32 %val, [%in];\n\
\n\
    setp.lo.u32 %drop_p, %rng, %thresh;\n\
    mov.f32 %zero, 0f00000000;\n\
    @%drop_p mov.f32 %val, %zero;\n\
    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
    st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7533
7534
7535// ---------------------------------------------------------------------------
7536// General N-dimensional broadcast binary PTX kernels
7537// ---------------------------------------------------------------------------
7538//
7539// Each thread computes one output element. The kernel decomposes the flat
7540// output index into N-dimensional coordinates, maps each coordinate through
7541// broadcast strides for A and B, and loads from the correct flat position.
7542//
7543// Parameters:
7544//   a_ptr         - pointer to A's device buffer
7545//   b_ptr         - pointer to B's device buffer
7546//   out_ptr       - pointer to output device buffer
7547//   a_strides_ptr - pointer to u32[ndim] broadcast strides for A
7548//   b_strides_ptr - pointer to u32[ndim] broadcast strides for B
7549//   out_shape_ptr - pointer to u32[ndim] output shape
7550//   n             - total output elements
7551//   ndim          - number of dimensions
7552//
7553// Broadcast strides: for each dimension d, stride is the normal
7554// C-contiguous stride if dim_size > 1, or 0 if dim_size == 1 (broadcast).
7555
/// PTX for general broadcast add: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
///
/// Entry point: `broadcast_add_kernel`. One thread per output element:
/// `i = blockIdx.x * blockDim.x + threadIdx.x`, and threads with `i >= n`
/// exit. There is no stride loop.
///
/// The flat index `i` is decomposed from the last dimension to the first using
/// `out_shape[d]`. Each coordinate is multiplied by `a_strides[d]` /
/// `b_strides[d]` (per-dimension `u32` element strides read from device memory)
/// and accumulated into the source indices. A stride of 0 therefore reads the
/// same element for every coordinate along `d`. All index math is 32-bit; data
/// is f32.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    // Global thread index.
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Decompose flat index into N-d coordinates and compute A/B indices.
    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;

LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;

    sub.u32 %d, %d, 1;

    // Byte offset for dimension d: d * 4.
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;

    // Load out_shape[d].
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];

    // Load a_strides[d] and b_strides[d].
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];

    // coord = remaining % shape_d; remaining /= shape_d.
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;

    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;

    bra LOOP;
END_LOOP:

    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];

    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    // Operation: add.
    add.f32 %vr, %va, %vb;

    // Store to out[tid].
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;

DONE:
    ret;
}
";
7660
7661
/// PTX for general broadcast sub: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
///
/// Entry point: `broadcast_sub_kernel`. One thread per output element:
/// `i = blockIdx.x * blockDim.x + threadIdx.x`, and threads with `i >= n`
/// exit. There is no stride loop.
///
/// The flat index `i` is decomposed from the last dimension to the first using
/// `out_shape[d]`. Each coordinate is multiplied by `a_strides[d]` /
/// `b_strides[d]` (per-dimension `u32` element strides read from device memory)
/// and accumulated into the source indices. A stride of 0 therefore reads the
/// same element for every coordinate along `d`. All index math is 32-bit; data
/// is f32.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    sub.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
7745
7746
/// PTX for general broadcast mul: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
///
/// Entry point: `broadcast_mul_kernel`. One thread per output element:
/// `i = blockIdx.x * blockDim.x + threadIdx.x`, and threads with `i >= n`
/// exit. There is no stride loop.
///
/// The flat index `i` is decomposed from the last dimension to the first using
/// `out_shape[d]`. Each coordinate is multiplied by `a_strides[d]` /
/// `b_strides[d]` (per-dimension `u32` element strides read from device memory)
/// and accumulated into the source indices. A stride of 0 therefore reads the
/// same element for every coordinate along `d`. All index math is 32-bit; data
/// is f32.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    mul.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
7830
7831
/// PTX source for `broadcast_div_kernel`: `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
///
/// Identical structure to `broadcast_mul_kernel` (`BROADCAST_MUL_PTX`).
///
/// Thread mapping and indexing:
/// - One thread per output element; threads with global index `>= n` exit.
/// - The flat index is decomposed from the last dimension to the first using
///   the `u32` arrays behind `out_shape_ptr`, `a_strides_ptr` and
///   `b_strides_ptr`. Each array holds `ndim` entries.
///
/// The quotient uses `div.rn.f32`, which is IEEE 754 round-to-nearest, the
/// same as `div_kernel`. A bare `div.f32` is not valid here: since PTX ISA 1.4,
/// f32 division must carry `.approx`, `.full` or a rounding modifier. Under
/// `.version 7.0` the unqualified form makes the module fail to JIT.
///
/// All index arithmetic is 32-bit, and elements are 4-byte `f32`.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    div.rn.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
7916
7917
/// PTX source for `strided_split_kernel`: extract a sub-tensor along a given axis.
///
/// Thread `i` computes:
///   `outer_idx = i / (split_size * inner_size)`
///   `within    = i % (split_size * inner_size)`
///   `src_idx   = outer_idx * total_along_axis * inner_size + (split_offset * inner_size) + within`
///   `out[i]    = in[src_idx]`
///
/// Parameters:
/// - `n` is the number of output elements. Threads with `i >= n` exit.
///   Per the formula above it is expected to be
///   `outer * split_size * inner_size`.
/// - `total_along_axis` is the extent of the split axis in the source tensor.
/// - `split_offset` and `split_size` select the slice `[offset, offset + size)`
///   along that axis.
///
/// All index arithmetic is 32-bit (`mul.lo.u32` / `add.u32`), so indices must
/// fit in `u32`. Elements are 4-byte `f32`. The declared `%tmp` register is unused.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_split_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 split_offset,
    .param .u32 split_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %sp_off, [split_offset];
    ld.param.u32 %sp_sz, [split_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = split_size * inner_size
    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = split_offset * inner_size
    mul.lo.u32 %base_off, %sp_off, %inner_sz;

    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
    add.u32 %src_idx, %src_idx, %base_off;
    add.u32 %src_idx, %src_idx, %within;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
7997
7998
/// PTX source for `strided_cat_kernel`: write a sub-tensor into a larger tensor
/// at an offset along an axis.
///
/// Thread `i` computes:
///   `outer_idx = i / (part_size * inner_size)`
///   `within    = i % (part_size * inner_size)`
///   `dst_idx   = outer_idx * total_along_axis * inner_size + (cat_offset * inner_size) + within`
///   `out[dst_idx] = in[i]`
///
/// Parameters:
/// - `n` is the number of elements in the input part. Threads with `i >= n`
///   exit.
/// - `total_along_axis` is the extent of the concatenation axis in the
///   destination tensor.
/// - `cat_offset` is where this part starts along that axis.
///
/// Each launch scatters one part. Only the positions computed above are
/// written; the rest of `out` is left untouched.
///
/// All index arithmetic is 32-bit, so indices must fit in `u32`. Elements are
/// 4-byte `f32`.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_cat_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 cat_offset,
    .param .u32 part_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %cat_off, [cat_offset];
    ld.param.u32 %part_sz, [part_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = part_size * inner_size
    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = cat_offset * inner_size
    mul.lo.u32 %base_off, %cat_off, %inner_sz;

    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
    add.u32 %dst_idx, %dst_idx, %base_off;
    add.u32 %dst_idx, %dst_idx, %within;

    // Load from in[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[dst_idx]
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
8079
8080
/// PTX source for `strided_copy_kernel`: general strided→contiguous
/// gather with up to 8 dimensions. CL-496.
///
/// Thread `i` computes:
///   flat = i
///   src = src_offset_base
///   for d in 0..8:
///       coord = flat / out_stride[d]
///       flat  = flat % out_stride[d]
///       src  += coord * src_stride[d]
///   out[i] = in[src]
///
/// Dimensions are peeled in order `d = 0, 1, ..., 7` (parameters `os0..os7`
/// and `ss0..ss7`). `out_stride[d]` therefore goes from the largest
/// contiguous output stride to the smallest. The remainder is computed as
/// `flat - coord * out_stride[d]`.
///
/// For tensors with fewer than 8 dims, unused positions must be
/// padded with `out_stride[d] = n + 1` (so `flat / out_stride[d] = 0`)
/// and `src_stride[d] = 0` (so the contribution is zero).
///
/// Constraints on the strides:
/// - Every `out_stride[d]` must be non-zero. Per the PTX ISA, `div.u32` by
///   zero yields an unspecified value.
/// - `n + 1` must not overflow `u32`.
/// - All index arithmetic is 32-bit, and elements are 4-byte `f32`.
///
/// Each stride is passed as an individual u32 kernel parameter to
/// avoid needing a device-side stride array. 20 params total is well
/// within the ~4KB param limit.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_COPY_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_copy_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 src_offset_base,
    .param .u32 n,
    .param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
    .param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
    .param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
    .param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %src_idx, [src_offset_base];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %flat, %r_tid;

    // Dim 0
    ld.param.u32 %os, [os0];
    ld.param.u32 %ss, [ss0];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 1
    ld.param.u32 %os, [os1];
    ld.param.u32 %ss, [ss1];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 2
    ld.param.u32 %os, [os2];
    ld.param.u32 %ss, [ss2];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 3
    ld.param.u32 %os, [os3];
    ld.param.u32 %ss, [ss3];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 4
    ld.param.u32 %os, [os4];
    ld.param.u32 %ss, [ss4];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 5
    ld.param.u32 %os, [os5];
    ld.param.u32 %ss, [ss5];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 6
    ld.param.u32 %os, [os6];
    ld.param.u32 %ss, [ss6];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 7
    ld.param.u32 %os, [os7];
    ld.param.u32 %ss, [ss7];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
8225
8226
/// PTX source for `div_kernel`: `out[i] = a[i] / b[i]`.
///
/// Elementwise over `n` contiguous `f32` values (4-byte offsets). Threads whose
/// global index is `>= n` exit immediately.
///
/// Uses `div.rn.f32`: IEEE 754 division with round-to-nearest-even. Division by
/// zero therefore follows IEEE rules (±inf, or NaN for 0/0).
#[cfg(feature = "cuda")]
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    div.rn.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8274
8275
/// PTX source for `exp_kernel`: `out[i] = exp(a[i])`.
///
/// Fast f32 path that computes `exp(x) = 2^(x * log2(e))`:
/// 1. Multiply the input by the f32 constant `0f3FB8AA3B` (≈ 1.4426950,
///    i.e. log2(e)).
/// 2. Pass the product to the hardware approximation `ex2.approx.f32`.
///
/// Because the product is rounded to f32 before the `ex2`, the relative error
/// grows with `|x|`. This is not a correctly-rounded `expf`.
///
/// One thread per element over `n` contiguous values. Threads with global
/// index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
    // log2(e) = 1.4426950408889634
    mul.f32 %va, %va, 0f3FB8AA3B;
    ex2.approx.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8322
/// PTX source for `exp_f64_kernel`: `out[i] = exp(a[i])` in double precision.
///
/// Algorithm (Cody-Waite range reduction + Taylor polynomial):
///
/// 1. `n = rint(x * log2(e))` and `r = x - n*ln2_hi - n*ln2_lo`, so `|r|` is
///    about `ln(2)/2`. `ln2_hi` has 11 trailing zero mantissa bits, so
///    `n*ln2_hi` is exact for every reachable `n`.
/// 2. `exp(r) = 1 + r*(1 + r*P(r))`, where
///    `P(r) = 1/2! + r/3! + ... + r^10/12!` is evaluated with an FMA Horner chain.
///    Every coefficient from 1/12! down to 1/2! is present.
/// 3. `exp(x) = exp(r) * 2^n`. The scale is applied as two exact power-of-two
///    factors, `2^(n>>1) * 2^(n - (n>>1))`. A single constructed exponent field
///    would wrap for `n >= 1024` (results just below `f64::MAX`) and for
///    `n <= -1023` (subnormal results).
///
/// Special values:
/// - `x > ln(f64::MAX)` (including `+inf`) gives `+inf`.
/// - `x < -745.13...`, the smallest input with a non-zero result (including
///   `-inf`), gives `+0.0`.
/// - NaN fails both range tests and propagates through the arithmetic.
///
/// Accuracy: the first omitted Taylor term `r^13/13!` is below one ULP for
/// `|r| <= ln(2)/2`. Including Horner rounding, the result is expected to be
/// within a couple of ULP. It is not correctly rounded.
#[cfg(feature = "cuda")]
pub(crate) const EXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %vr, %nf, %r, %p, %one, %scale;
    .reg .s32 %ni, %n_lo, %n_hi;
    .reg .s64 %bits;
    .reg .pred %p_tid, %p_ovf, %p_unf;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // n = rint(x * log2(e)), log2(e) = 1.4426950408889634
    mul.f64 %nf, %x, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %nf, %nf;

    // r = x - n*ln2_hi - n*ln2_lo (Cody-Waite two-step reduction)
    fma.rn.f64 %r, %nf, 0dBFE62E42FEFA3800, %x;
    fma.rn.f64 %r, %nf, 0dBD2EF35793C76730, %r;

    // P(r) = 1/2! + r*(1/3! + r*(1/4! + ... + r*(1/12!)))
    mov.f64 %p, 0d3E21EED8EFF8D898;              // 1/12!
    fma.rn.f64 %p, %p, %r, 0d3E5AE64567F544E4;   // 1/11!
    fma.rn.f64 %p, %p, %r, 0d3E927E4FB7789F5C;   // 1/10!
    fma.rn.f64 %p, %p, %r, 0d3EC71DE3A556C734;   // 1/9!
    fma.rn.f64 %p, %p, %r, 0d3EFA01A01A01A01A;   // 1/8!
    fma.rn.f64 %p, %p, %r, 0d3F2A01A01A01A01A;   // 1/7!
    fma.rn.f64 %p, %p, %r, 0d3F56C16C16C16C17;   // 1/6!
    fma.rn.f64 %p, %p, %r, 0d3F81111111111111;   // 1/5!
    fma.rn.f64 %p, %p, %r, 0d3FA5555555555555;   // 1/4!
    fma.rn.f64 %p, %p, %r, 0d3FC5555555555555;   // 1/3!
    fma.rn.f64 %p, %p, %r, 0d3FE0000000000000;   // 1/2!

    // exp(r) = 1 + r*(1 + r*P(r))
    fma.rn.f64 %p, %p, %r, %one;
    fma.rn.f64 %vr, %p, %r, %one;

    // Scale by 2^n = 2^n_lo * 2^n_hi; both halves stay in the normal
    // exponent range for every n reachable below the overflow threshold.
    cvt.rni.s32.f64 %ni, %nf;
    shr.s32 %n_lo, %ni, 1;
    sub.s32 %n_hi, %ni, %n_lo;
    cvt.s64.s32 %bits, %n_lo;
    add.s64 %bits, %bits, 1023;
    shl.b64 %bits, %bits, 52;
    mov.b64 %scale, %bits;
    mul.f64 %vr, %vr, %scale;
    cvt.s64.s32 %bits, %n_hi;
    add.s64 %bits, %bits, 1023;
    shl.b64 %bits, %bits, 52;
    mov.b64 %scale, %bits;
    mul.f64 %vr, %vr, %scale;

    // Overflow -> +inf, underflow -> +0 (covers +/-inf; NaN fails both tests)
    setp.gt.f64 %p_ovf, %x, 0d40862E42FEFA39EF;
    @%p_ovf mov.f64 %vr, 0d7FF0000000000000;
    setp.lt.f64 %p_unf, %x, 0dC0874910D52D3051;
    @%p_unf mov.f64 %vr, 0d0000000000000000;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8421
/// PTX source for `log_kernel`: `out[i] = ln(a[i])`.
///
/// Fast f32 path that computes `ln(x) = log2(x) * ln(2)`:
/// 1. Take `lg2.approx.f32` of the input.
/// 2. Multiply by the f32 constant `0f3F317218` (≈ 0.6931472).
///
/// Special-value behaviour is that of `lg2.approx.f32` as specified in the PTX
/// ISA: `-inf` for a zero input and NaN for negative inputs. This is not a
/// correctly-rounded `logf`.
///
/// One thread per element over `n` contiguous values. Threads with global
/// index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
    // 1/log2(e) = ln(2) = 0.6931471805599453
    lg2.approx.f32 %vr, %va;
    mul.f32 %vr, %vr, 0f3F317218;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8468
/// PTX source for `log_f64_kernel`: `out[i] = ln(a[i])` in double precision.
///
/// Algorithm:
///
/// 1. Subnormal inputs are first scaled by `2^54`, with the 54 removed from the
///    exponent again, so the bit-level decomposition always sees a normal
///    number.
/// 2. Decompose `x = 2^n * m` with `m` in `[1, 2)`. If `m > sqrt(2)`, halve `m`
///    and increment `n`. Now `m` is in `[sqrt(2)/2, sqrt(2)]` and
///    `f = (m - 1)/(m + 1)` satisfies `|f| <= 0.1716`.
/// 3. `ln(m) = 2*f*(1 + f^2/3 + f^4/5 + ... + f^18/19)`. The first omitted term,
///    `f^20/21`, is about 2e-17 relative.
/// 4. `ln(x) = n*ln2_hi + (n*ln2_lo + ln(m))`, where `ln2_hi` has trailing zero
///    bits so `n*ln2_hi` is exact.
///
/// Special values, matching `f64::ln`:
/// - `±0` gives `-inf`.
/// - `x < 0` (including `-inf`) gives NaN.
/// - `+inf` gives `+inf`.
/// - NaN gives NaN.
///
/// Accuracy: expected to be within a couple of ULP for positive finite inputs.
/// It is not correctly rounded.
#[cfg(feature = "cuda")]
pub(crate) const LOG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .u64 %xbits, %mantissa_bits;
    .reg .s64 %exp64, %bias;
    .reg .f64 %x, %xs, %vr, %m, %f, %f2, %s, %p, %nf, %one;
    .reg .pred %p_tid, %p_sub, %p_big, %p_zero, %p_neg, %p_inf, %p_nan;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // Subnormals (x < 2^-1022) have a zero exponent field: scale by 2^54 and
    // compensate in the exponent bias (-1023 - 54 = -1077).
    mov.f64 %xs, %x;
    mov.s64 %bias, -1023;
    setp.lt.f64 %p_sub, %x, 0d0010000000000000;
    @%p_sub mul.f64 %xs, %x, 0d4350000000000000;
    @%p_sub mov.s64 %bias, -1077;

    // n = exponent_field + bias
    mov.b64 %xbits, %xs;
    shr.u64 %exp64, %xbits, 52;
    and.b64 %exp64, %exp64, 2047;
    add.s64 %exp64, %exp64, %bias;

    // m = mantissa with exponent 1023, i.e. m in [1, 2)
    and.b64 %mantissa_bits, %xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %mantissa_bits, %mantissa_bits, 0x3FF0000000000000;
    mov.b64 %m, %mantissa_bits;

    // Re-center m into [sqrt(2)/2, sqrt(2)] so that |f| <= 0.1716
    setp.gt.f64 %p_big, %m, 0d3FF6A09E667F3BCD;
    @%p_big mul.f64 %m, %m, 0d3FE0000000000000;
    @%p_big add.s64 %exp64, %exp64, 1;
    cvt.rn.f64.s64 %nf, %exp64;

    // f = (m - 1) / (m + 1)
    sub.f64 %f, %m, %one;
    add.f64 %s, %m, %one;
    div.rn.f64 %f, %f, %s;
    mul.f64 %f2, %f, %f;

    // p = 1 + f^2/3 + f^4/5 + ... + f^18/19 (Horner in f^2)
    mov.f64 %p, 0d3FAAF286BCA1AF28;              // 1/19
    fma.rn.f64 %p, %p, %f2, 0d3FAE1E1E1E1E1E1E;  // 1/17
    fma.rn.f64 %p, %p, %f2, 0d3FB1111111111111;  // 1/15
    fma.rn.f64 %p, %p, %f2, 0d3FB3B13B13B13B14;  // 1/13
    fma.rn.f64 %p, %p, %f2, 0d3FB745D1745D1746;  // 1/11
    fma.rn.f64 %p, %p, %f2, 0d3FBC71C71C71C71C;  // 1/9
    fma.rn.f64 %p, %p, %f2, 0d3FC2492492492492;  // 1/7
    fma.rn.f64 %p, %p, %f2, 0d3FC999999999999A;  // 1/5
    fma.rn.f64 %p, %p, %f2, 0d3FD5555555555555;  // 1/3
    fma.rn.f64 %p, %p, %f2, %one;                // 1

    // ln(m) = 2*f*p
    mul.f64 %p, %p, %f;
    add.f64 %p, %p, %p;

    // ln(x) = n*ln2_hi + (n*ln2_lo + ln(m))
    fma.rn.f64 %p, %nf, 0d3D2EF35793C76730, %p;
    fma.rn.f64 %vr, %nf, 0d3FE62E42FEFA3800, %p;

    // Special values: +/-0 -> -inf, x < 0 -> NaN, +inf -> +inf, NaN -> NaN
    setp.eq.f64 %p_zero, %x, 0d0000000000000000;
    @%p_zero mov.f64 %vr, 0dFFF0000000000000;
    setp.lt.f64 %p_neg, %x, 0d0000000000000000;
    @%p_neg mov.f64 %vr, 0d7FF8000000000000;
    setp.eq.f64 %p_inf, %x, 0d7FF0000000000000;
    @%p_inf mov.f64 %vr, %x;
    setp.nan.f64 %p_nan, %x, %x;
    @%p_nan mov.f64 %vr, %x;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8572
/// PTX source for `sqrt_kernel`: `out[i] = sqrt(a[i])`.
///
/// Uses `sqrt.rn.f32`: IEEE 754 square root, correctly rounded to nearest.
/// Negative inputs therefore yield NaN, per IEEE.
///
/// One thread per element over `n` contiguous `f32` values. Threads with global
/// index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sqrt_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    sqrt.rn.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8616
8617
/// PTX source for `pow_kernel`: `out[i] = a[i] ^ exponent` (f32).
///
/// The magnitude is computed as `|x|^e = 2^(e * log2|x|)` using the hardware
/// approximations `lg2.approx.f32` and `ex2.approx.f32`. These are fast but
/// not correctly rounded.
///
/// Sign and special cases, following `f32::powf`:
/// - Negative base with an odd integral exponent: the result is negated,
///   e.g. `(-2)^3 = -8`.
/// - Negative base with an even integral exponent: the result is positive,
///   e.g. `(-2)^2 = 4`. Previously `lg2` of the negative base made both of
///   these cases NaN.
/// - Negative base with a non-integral or NaN exponent: NaN.
/// - `exponent == 0`: `1.0` for every base, including `0` and NaN.
/// - Zero base: `0^e` is `+0` for `e > 0` and `+inf` for `e < 0`, through
///   `lg2(0) = -inf`.
///
/// The exponent is classified as "integral" when `trunc(e) == e` and as "odd"
/// when additionally `trunc(e/2) != e/2`. A `-0.0` base is treated like
/// `+0.0`, so the sign of a zero base is not propagated.
#[cfg(feature = "cuda")]
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %exp, %lg, %ax, %e_trunc, %e_half, %e_half_trunc;
    .reg .pred %p, %p_neg, %p_frac, %p_odd, %p_flip, %p_bad, %p_e0;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %exp, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];

    // |x|^e = 2^(e * log2|x|)
    abs.f32 %ax, %va;
    lg2.approx.f32 %lg, %ax;
    mul.f32 %lg, %lg, %exp;
    ex2.approx.f32 %vr, %lg;

    // Classify e: frac = trunc(e) != e, odd = trunc(e/2) != e/2.
    // NaN compares unordered, so a NaN exponent counts as fractional.
    cvt.rzi.f32.f32 %e_trunc, %exp;
    setp.neu.f32 %p_frac, %e_trunc, %exp;
    mul.f32 %e_half, %exp, 0f3F000000;
    cvt.rzi.f32.f32 %e_half_trunc, %e_half;
    setp.neu.f32 %p_odd, %e_half_trunc, %e_half;

    // Negative base: odd integral exponent flips the sign, non-integral -> NaN.
    setp.lt.f32 %p_neg, %va, 0f00000000;
    and.pred %p_flip, %p_neg, %p_odd;
    @%p_flip neg.f32 %vr, %vr;
    and.pred %p_bad, %p_neg, %p_frac;
    @%p_bad mov.f32 %vr, 0f7FFFFFFF;

    // x^0 = 1 for every x, including 0 and NaN.
    setp.eq.f32 %p_e0, %exp, 0f00000000;
    @%p_e0 mov.f32 %vr, 0f3F800000;

    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8667
/// PTX source for `pow_f64_kernel`: `out[i] = a[i] ^ exponent` (f64).
///
/// Computes `|x|^e = exp(e * ln|x|)` with inline double-precision routines.
///
/// The `ln` step:
/// - Subnormals are pre-scaled by `2^54`.
/// - `|x| = 2^n * m`, with `m` re-centred into `[sqrt(2)/2, sqrt(2)]`.
/// - `ln(m) = 2f(1 + f^2/3 + ... + f^18/19)` with `f = (m-1)/(m+1)`.
/// - `n*ln(2)` uses a hi/lo split.
/// - `ln(0) = -inf`, `ln(inf) = inf`, and NaN propagates.
///
/// The `exp` step:
/// - Cody-Waite reduction, then the full Taylor chain `1/12! ... 1/2!`.
/// - The `2^n` scale is applied as two power-of-two factors, so subnormal and
///   near-overflow results are correct.
/// - `z > ln(f64::MAX)` gives `+inf`; `z < -745.13...` gives `+0`.
///
/// Sign and special cases, following `f64::powf`:
/// - Negative base with an odd integral exponent: the result is negated.
/// - Negative base with an even integral exponent: the result is positive.
/// - Negative base with a non-integral or NaN exponent: NaN.
/// - `exponent == 0` or base `== 1`: `1.0`.
/// - `-0.0` is treated like `+0.0`.
///
/// Accuracy: `z = e * ln|x|` is formed in plain double precision. The relative
/// error of the result therefore grows roughly like `|z| * 2^-53` and is not
/// correctly rounded. For moderate `|z|` it is within a few ULP.
#[cfg(feature = "cuda")]
pub(crate) const POW_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f64 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %exp64, %one, %ax;
    // log registers
    .reg .u64 %l_xbits, %l_mbits;
    .reg .s64 %l_exp64, %l_bias;
    .reg .f64 %l_x, %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_lnx;
    // exp registers
    .reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_scale;
    .reg .s32 %e_ni, %e_lo, %e_hi;
    .reg .s64 %e_bits;
    // sign / special-value registers
    .reg .f64 %s_trunc, %s_half, %s_half_trunc;
    .reg .pred %p, %p_sub, %p_big, %p_lzero, %p_linf, %p_lnan, %p_ovf, %p_unf;
    .reg .pred %p_neg, %p_frac, %p_odd, %p_flip, %p_bad, %p_one, %p_e0;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f64 %exp64, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    abs.f64 %ax, %va;

    // === ln|va| ===
    // Subnormals: scale by 2^54 and fold the 54 into the exponent bias.
    mov.f64 %l_x, %ax;
    mov.s64 %l_bias, -1023;
    setp.lt.f64 %p_sub, %ax, 0d0010000000000000;
    @%p_sub mul.f64 %l_x, %ax, 0d4350000000000000;
    @%p_sub mov.s64 %l_bias, -1077;

    mov.b64 %l_xbits, %l_x;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    add.s64 %l_exp64, %l_exp64, %l_bias;

    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, 0x3FF0000000000000;
    mov.b64 %l_m, %l_mbits;
    // Re-center m into [sqrt(2)/2, sqrt(2)] so that |f| <= 0.1716
    setp.gt.f64 %p_big, %l_m, 0d3FF6A09E667F3BCD;
    @%p_big mul.f64 %l_m, %l_m, 0d3FE0000000000000;
    @%p_big add.s64 %l_exp64, %l_exp64, 1;
    cvt.rn.f64.s64 %l_nf, %l_exp64;

    // f = (m-1)/(m+1)
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;

    // p = 1 + f^2/3 + f^4/5 + ... + f^18/19
    mov.f64 %l_p, 0d3FAAF286BCA1AF28;                  // 1/19
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAE1E1E1E1E1E1E;  // 1/17
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB1111111111111;  // 1/15
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14;  // 1/13
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746;  // 1/11
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;  // 1/9
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;  // 1/7
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;  // 1/5
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;  // 1/3
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;                // 1

    // ln(m) = 2*f*p; ln|x| = n*ln2_hi + (n*ln2_lo + ln(m))
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    fma.rn.f64 %l_p, %l_nf, 0d3D2EF35793C76730, %l_p;
    fma.rn.f64 %l_lnx, %l_nf, 0d3FE62E42FEFA3800, %l_p;

    // ln(0) = -inf; ln(inf) = inf; ln(NaN) = NaN
    setp.eq.f64 %p_lzero, %ax, 0d0000000000000000;
    @%p_lzero mov.f64 %l_lnx, 0dFFF0000000000000;
    setp.eq.f64 %p_linf, %ax, 0d7FF0000000000000;
    @%p_linf mov.f64 %l_lnx, %ax;
    setp.nan.f64 %p_lnan, %ax, %ax;
    @%p_lnan mov.f64 %l_lnx, %ax;

    // === exp(exponent * ln|x|) ===
    mul.f64 %e_z, %exp64, %l_lnx;
    mul.f64 %e_nf, %e_z, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                  // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;   // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;   // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;   // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;   // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;   // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;   // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;   // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;   // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;   // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FE0000000000000;   // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %vr, %e_p, %e_r, %one;

    // Scale by 2^n = 2^n_lo * 2^n_hi (both halves stay in the normal range)
    cvt.rni.s32.f64 %e_ni, %e_nf;
    shr.s32 %e_lo, %e_ni, 1;
    sub.s32 %e_hi, %e_ni, %e_lo;
    cvt.s64.s32 %e_bits, %e_lo;
    add.s64 %e_bits, %e_bits, 1023;
    shl.b64 %e_bits, %e_bits, 52;
    mov.b64 %e_scale, %e_bits;
    mul.f64 %vr, %vr, %e_scale;
    cvt.s64.s32 %e_bits, %e_hi;
    add.s64 %e_bits, %e_bits, 1023;
    shl.b64 %e_bits, %e_bits, 52;
    mov.b64 %e_scale, %e_bits;
    mul.f64 %vr, %vr, %e_scale;

    // Overflow -> +inf, underflow -> +0 (NaN fails both tests and propagates)
    setp.gt.f64 %p_ovf, %e_z, 0d40862E42FEFA39EF;
    @%p_ovf mov.f64 %vr, 0d7FF0000000000000;
    setp.lt.f64 %p_unf, %e_z, 0dC0874910D52D3051;
    @%p_unf mov.f64 %vr, 0d0000000000000000;

    // === sign of a negative base ===
    // frac = trunc(e) != e (NaN counts as fractional), odd = trunc(e/2) != e/2
    cvt.rzi.f64.f64 %s_trunc, %exp64;
    setp.neu.f64 %p_frac, %s_trunc, %exp64;
    mul.f64 %s_half, %exp64, 0d3FE0000000000000;
    cvt.rzi.f64.f64 %s_half_trunc, %s_half;
    setp.neu.f64 %p_odd, %s_half_trunc, %s_half;
    setp.lt.f64 %p_neg, %va, 0d0000000000000000;
    and.pred %p_flip, %p_neg, %p_odd;
    @%p_flip neg.f64 %vr, %vr;
    and.pred %p_bad, %p_neg, %p_frac;
    @%p_bad mov.f64 %vr, 0d7FF8000000000000;

    // 1^e = 1 and x^0 = 1 (including 0 and NaN)
    setp.eq.f64 %p_one, %va, %one;
    @%p_one mov.f64 %vr, %one;
    setp.eq.f64 %p_e0, %exp64, 0d0000000000000000;
    @%p_e0 mov.f64 %vr, %one;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8788
/// PTX source for `abs_kernel`: `out[i] = |a[i]|` (f32).
///
/// Kernel parameters: `a_ptr` (input), `out_ptr` (output), `n` (element
/// count, `u32`). Each thread handles the single index
/// `i = ctaid.x * ntid.x + tid.x`; threads with `i >= n` exit immediately.
/// Elements are addressed with a 4-byte stride (`shl.b64 %off, %off, 2`).
#[cfg(feature = "cuda")]
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    abs.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8832
8833
/// PTX source for `sigmoid_kernel`: `out[i] = 1 / (1 + exp(-a[i]))` (f32).
///
/// `exp(-x)` is evaluated as `ex2.approx.f32(-x * log2(e))`, where the
/// immediate `0f3FB8AA3B` is `log2(e) ~= 1.4427` in f32. The final reciprocal
/// uses round-to-nearest `div.rn.f32`.
///
/// Kernel parameters: `a_ptr`, `out_ptr`, `n` (`u32`). Each thread handles
/// the single index `i = ctaid.x * ntid.x + tid.x` and exits when `i >= n`.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f32 %neg, %va;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %vr, %one, %denom;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8884
/// PTX source for `sigmoid_f64_kernel`: `out[i] = 1 / (1 + exp(-a[i]))` (f64).
///
/// `exp(-x)` is computed in full f64 precision. The steps are:
/// - Cody-Waite range reduction `-x = n*ln2 + r` with `|r| <= ln2/2`, using a
///   split hi/lo `ln2`.
/// - A degree-12 Taylor polynomial for `exp(r)` (coefficients `1/12!` down to
///   `1/0!`), evaluated with Horner's rule.
/// - Scaling by `2^n`, built directly in the exponent field.
///
/// The exp argument is clamped to `[-708, 708]` so that `n + 1023` always
/// stays a valid biased exponent. Outside that range the result saturates:
/// - `x < -708`: `exp(-x)` is forced to `+inf`, so the result is `0`.
/// - `x > 708`: `exp(-x)` is forced to `0`, so the result is `1`.
///
/// NaN inputs propagate to NaN, because `setp` comparisons with NaN are false.
///
/// Kernel parameters: `a_ptr`, `out_ptr`, `n` (`u32`). Each thread handles
/// the single index `i = ctaid.x * ntid.x + tid.x` and exits when `i >= n`.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %denom, %one, %neg_x;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %p_hi, %p_lo;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f64 %neg_x, %va;

    // Clamp the exp argument to [-708, 708] so 2^n below stays a normal
    // double; the flags are reused after the polynomial to saturate.
    // setp is false for NaN, so NaN flows through untouched.
    setp.gt.f64 %p_hi, %neg_x, 0d4086200000000000;
    setp.lt.f64 %p_lo, %neg_x, 0dC086200000000000;
    @%p_hi mov.f64 %neg_x, 0d4086200000000000;
    @%p_lo mov.f64 %neg_x, 0dC086200000000000;

    // --- exp(%neg_x) via Cody-Waite + degree-12 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e64, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e64, %e64, %e_nf;
    // Saturate: exp(-x) -> +inf for x < -708, -> 0 for x > 708.
    @%p_hi mov.f64 %e64, 0d7FF0000000000000;
    @%p_lo mov.f64 %e64, 0d0000000000000000;
    // --- end exp ---

    add.f64 %denom, %one, %e64;
    div.rn.f64 %vr, %one, %denom;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8965
/// PTX source for `tanh_kernel`: `out[i] = tanh(a[i])` (f32).
///
/// Uses the identity tanh(x) = 2*sigmoid(2x) - 1. `exp(-2x)` is evaluated as
/// `ex2.approx.f32(-2x * log2(e))`, where `0f3FB8AA3B` is `log2(e) ~= 1.4427`.
///
/// Precision note: for |x| near 0, sigmoid(2x) is close to 0.5 and the final
/// `2*sig - 1` cancels. Small results therefore carry a relative error far
/// above one f32 ulp, and very small inputs can come out as exactly 0.
///
/// Kernel parameters: `a_ptr`, `out_ptr`, `n` (`u32`). Each thread handles
/// the single index `i = ctaid.x * ntid.x + tid.x` and exits when `i >= n`.
#[cfg(feature = "cuda")]
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // tanh(x) = 2*sigmoid(2x) - 1
    mov.f32 %two, 0f40000000;
    mul.f32 %neg2x, %va, %two;
    neg.f32 %neg2x, %neg2x;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg2x, %neg2x, %lg2e;
    ex2.approx.f32 %e, %neg2x;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %sig, %one, %denom;
    mul.f32 %vr, %two, %sig;
    sub.f32 %vr, %vr, %one;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
9021
/// PTX source for `tanh_f64_kernel`: `out[i] = tanh(a[i])` (f64).
///
/// The kernel evaluates `tanh(|x|) = (1 - exp(-2|x|)) / (1 + exp(-2|x|))` and
/// then copies the sign of `x` onto the result with `copysign`, which keeps
/// `tanh(-0) = -0`. Because the argument `-2|x|` is never positive, `exp`
/// cannot overflow.
///
/// `exp` is computed in full f64 precision:
/// - Cody-Waite range reduction with a split hi/lo `ln2`.
/// - A degree-12 Taylor polynomial (`1/12!` down to `1/0!`), evaluated with
///   Horner's rule.
/// - Scaling by `2^n`, built directly in the exponent field.
///
/// When `-2|x| < -708`, the argument is clamped and `exp` is forced to `0`, so
/// the result saturates to `+-1`. This also keeps `n + 1023` a valid biased
/// exponent. NaN inputs propagate to NaN, because `setp` comparisons with NaN
/// are false.
///
/// Kernel parameters: `a_ptr`, `out_ptr`, `n` (`u32`). Each thread handles
/// the single index `i = ctaid.x * ntid.x + tid.x` and exits when `i >= n`.
#[cfg(feature = "cuda")]
pub(crate) const TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %num, %denom, %one, %two, %neg2x;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %p_lo;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;

    // tanh(|x|) = (1 - exp(-2|x|)) / (1 + exp(-2|x|)); sign restored below.
    // The exp argument is <= 0, so exp stays in (0, 1] and cannot overflow.
    abs.f64 %neg2x, %va;
    mul.f64 %neg2x, %neg2x, %two;
    neg.f64 %neg2x, %neg2x;

    // Below -708, 2^n would leave the exponent field: clamp here and force
    // exp to 0 after the polynomial (tanh saturates to 1).
    // setp is false for NaN, so NaN flows through untouched.
    setp.lt.f64 %p_lo, %neg2x, 0dC086200000000000;
    @%p_lo mov.f64 %neg2x, 0dC086200000000000;

    // --- exp(%neg2x) via Cody-Waite + degree-12 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg2x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg2x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e64, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e64, %e64, %e_nf;
    @%p_lo mov.f64 %e64, 0d0000000000000000;
    // --- end exp ---

    sub.f64 %num, %one, %e64;
    add.f64 %denom, %one, %e64;
    div.rn.f64 %vr, %num, %denom;
    // tanh is odd: copy the sign of x onto tanh(|x|).
    copysign.f64 %vr, %va, %vr;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
9105
/// PTX source for `fused_adam_kernel`: in-place Adam optimizer update (f32).
///
/// For each element i (each thread handles one index; threads with `i >= n`
/// exit):
///   g = grad[i] + weight_decay * param[i]  (only if weight_decay > 0)
///   exp_avg[i] = beta1 * exp_avg[i] + (1-beta1) * g
///   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1-beta2) * g * g
///   m_hat = exp_avg[i] / bc1
///   v_hat = exp_avg_sq[i] / bc2
///   param[i] = param[i] - lr * m_hat / (sqrt(v_hat) + eps)
///
/// `param`, `exp_avg` and `exp_avg_sq` are overwritten in place; `grad` is only
/// read, and the decayed `g` is never written back.
///
/// The weight decay is the coupled L2 form: it is added to the gradient,
/// unlike AdamW's decoupled decay.
///
/// `bc1` and `bc2` are supplied by the caller rather than computed here.
/// Presumably they are `1 - beta1^t` and `1 - beta2^t` for step `t`; confirm
/// at the launch site.
///
/// Note: the register named `%one` briefly holds 0.0 for the
/// `weight_decay > 0` test before it is reloaded with 1.0.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_adam_kernel(
    .param .u64 param_ptr,
    .param .u64 grad_ptr,
    .param .u64 exp_avg_ptr,
    .param .u64 exp_avg_sq_ptr,
    .param .f32 beta1,
    .param .f32 beta2,
    .param .f32 lr,
    .param .f32 eps,
    .param .f32 bc1,
    .param .f32 bc2,
    .param .f32 weight_decay,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %p, %g, %m, %v, %off;
    .reg .f32 %vp, %vg, %vm, %vv;
    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
    .reg .f32 %one;
    .reg .pred %p_bound, %p_wd;

    ld.param.u64 %p, [param_ptr];
    ld.param.u64 %g, [grad_ptr];
    ld.param.u64 %m, [exp_avg_ptr];
    ld.param.u64 %v, [exp_avg_sq_ptr];
    ld.param.f32 %b1, [beta1];
    ld.param.f32 %b2, [beta2];
    ld.param.f32 %f_lr, [lr];
    ld.param.f32 %f_eps, [eps];
    ld.param.f32 %f_bc1, [bc1];
    ld.param.f32 %f_bc2, [bc2];
    ld.param.f32 %f_wd, [weight_decay];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_bound, %r_tid, %n_reg;
    @%p_bound bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %p, %p, %off;
    add.u64 %g, %g, %off;
    add.u64 %m, %m, %off;
    add.u64 %v, %v, %off;

    ld.global.f32 %vp, [%p];
    ld.global.f32 %vg, [%g];
    ld.global.f32 %vm, [%m];
    ld.global.f32 %vv, [%v];

    // L2 weight decay: g = g + wd * p
    mov.f32 %one, 0f00000000;
    setp.gt.f32 %p_wd, %f_wd, %one;
    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;

    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
    mov.f32 %one, 0f3F800000;
    sub.f32 %t1, %one, %b1;
    mul.f32 %vm, %vm, %b1;
    fma.rn.f32 %vm, %t1, %vg, %vm;

    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    sub.f32 %t2, %one, %b2;
    mul.f32 %vv, %vv, %b2;
    mul.f32 %t1, %vg, %vg;
    fma.rn.f32 %vv, %t2, %t1, %vv;

    // m_hat = exp_avg / bc1
    div.rn.f32 %m_hat, %vm, %f_bc1;

    // v_hat = exp_avg_sq / bc2
    div.rn.f32 %v_hat, %vv, %f_bc2;

    // denom = sqrt(v_hat) + eps
    sqrt.rn.f32 %denom, %v_hat;
    add.f32 %denom, %denom, %f_eps;

    // param = param - lr * m_hat / denom
    div.rn.f32 %update, %m_hat, %denom;
    mul.f32 %update, %update, %f_lr;
    sub.f32 %vp, %vp, %update;

    st.global.f32 [%p], %vp;
    st.global.f32 [%m], %vm;
    st.global.f32 [%v], %vv;

DONE:
    ret;
}
";
9217
/// PTX source for the fused GRU cell forward kernel (`fused_gru_forward_kernel`, f32).
///
/// Takes pre-computed input_gates [B, 3*H] and hidden_gates [B, 3*H]
/// (from cuBLAS GEMMs), biases, and previous hidden state. Computes all
/// gate activations and the new hidden state in a single kernel launch.
///
/// Work distribution: a grid-stride loop over `idx in 0..total` (stride =
/// `ntid.x * nctaid.x`). Each `idx = b * hsz + h` addresses one hidden unit
/// of one batch row; `total` is presumably `B * H` (confirm at the launch
/// site).
///
/// For each `idx`, with `offset3 = b*3*hsz + h`, the kernel reads:
/// - the r/z/n components at `offset3`, `offset3 + hsz` and `offset3 + 2*hsz`
///   of both gate tensors;
/// - the biases at `h`, `h + hsz` and `h + 2*hsz`;
/// - `hx[idx]`.
///
/// It then computes
///   r  = sigmoid(ir + hr + b_ir + b_hr)
///   z  = sigmoid(ii + hi + b_ii + b_hi)
///   n  = tanh(in + b_in + r * (hn + b_hn))
///   hy = n + z * (hx - n)
/// and writes `hy[idx]`. It also writes the backward-pass workspace
/// `[r, z, n, hx, hn + b_hn]` at `offset5 = b*5*hsz + h`, with the five values
/// `hsz` elements apart (5*H values per batch row).
///
/// Sigmoid and tanh use `ex2.approx.f32`; tanh is computed as 2*sigmoid(2x)-1.
///
/// Matches PyTorch's _thnn_fused_gru_cell kernel from RNN.cu.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_gru_forward_kernel(
    .param .u64 input_gates_ptr,
    .param .u64 hidden_gates_ptr,
    .param .u64 bias_ih_ptr,
    .param .u64 bias_hh_ptr,
    .param .u64 hx_ptr,
    .param .u64 hy_ptr,
    .param .u64 workspace_ptr,
    .param .u32 hsz,
    .param .u32 total
) {
    // %r_tid (not %tid) so the special register %tid.x stays addressable.
    .reg .u32 %r_tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
    .reg .u64 %off64, %tmp64;
    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
    .reg .pred %p;

    ld.param.u64 %ig, [input_gates_ptr];
    ld.param.u64 %hg, [hidden_gates_ptr];
    ld.param.u64 %b1, [bias_ih_ptr];
    ld.param.u64 %b2, [bias_hh_ptr];
    ld.param.u64 %hx, [hx_ptr];
    ld.param.u64 %hy, [hy_ptr];
    ld.param.u64 %ws, [workspace_ptr];
    ld.param.u32 %hsz_reg, [hsz];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %r_tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %one, 0f3F800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // offset3 = (idx/hsz)*3*hsz + idx%hsz  (into [B, 3*H] gates tensor)
    div.u32 %batch_idx, %idx, %hsz_reg;
    rem.u32 %hmod, %idx, %hsz_reg;
    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset3, %offset3, 3;
    add.u32 %offset3, %offset3, %hmod;

    // Load input gate components: ir, ii, in
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ig, %off64;
    ld.global.f32 %ir, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %ii, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %in, [%tmp64];

    // Load hidden gate components: hr, hi, hn
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hg, %off64;
    ld.global.f32 %hr, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hi, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hn, [%tmp64];

    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b1, %off64;
    ld.global.f32 %b1r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1n, [%tmp64];

    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b2, %off64;
    ld.global.f32 %b2r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2n, [%tmp64];

    // Load hx[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hx, %off64;
    ld.global.f32 %hx_val, [%tmp64];

    // r = sigmoid(ir + hr + b1r + b2r)
    add.f32 %rg, %ir, %hr;
    add.f32 %rg, %rg, %b1r;
    add.f32 %rg, %rg, %b2r;
    neg.f32 %tmp, %rg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %rg, %one, %denom;

    // z = sigmoid(ii + hi + b1i + b2i)
    add.f32 %zg, %ii, %hi;
    add.f32 %zg, %zg, %b1i;
    add.f32 %zg, %zg, %b2i;
    neg.f32 %tmp, %zg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %zg, %one, %denom;

    // n = tanh(in + b1n + r*(hn + b2n))
    add.f32 %tmp, %hn, %b2n;
    fma.rn.f32 %ng, %rg, %tmp, %in;
    add.f32 %ng, %ng, %b1n;
    // tanh via 2*sigmoid(2x)-1
    mul.f32 %tmp, %ng, 0f40000000;
    neg.f32 %tmp, %tmp;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %ng, %one, %denom;
    mul.f32 %ng, %ng, 0f40000000;
    sub.f32 %ng, %ng, %one;

    // hy = n + z * (hx - n)
    sub.f32 %tmp, %hx_val, %ng;
    fma.rn.f32 %hy_val, %zg, %tmp, %ng;

    // Store hy[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hy, %off64;
    st.global.f32 [%tmp64], %hy_val;

    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset5, %offset5, 5;
    add.u32 %offset5, %offset5, %hmod;

    cvt.u64.u32 %off64, %offset5;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ws, %off64;
    st.global.f32 [%tmp64], %rg;
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %zg;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %ng;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %hx_val;
    add.u64 %tmp64, %tmp64, %off64;
    add.f32 %tmp, %hn, %b2n;
    st.global.f32 [%tmp64], %tmp;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
9410
9411// ---------------------------------------------------------------------------
9412// Launch configuration helper
9413// ---------------------------------------------------------------------------
9414
/// Standard 1-D launch config for `n` elements.
///
/// Uses 256 threads per block, which is a good default for elementwise ops
/// on all modern NVIDIA architectures.
///
/// The grid has `ceil(n / 256)` blocks, with a minimum of one so that `n == 0`
/// still yields a valid launch. The grid may therefore contain more than `n`
/// threads, and kernels must bounds-check their index against `n`.
///
/// # Errors
///
/// Returns [`GpuError::ShapeMismatch`] if `n` exceeds `u32::MAX`, which
/// would silently truncate the `u32` element count passed to the kernels.
#[cfg(feature = "cuda")]
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    const BLOCK: u32 = 256;

    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    // Round up in 64-bit arithmetic. A saturating u32 add would under-count
    // the blocks whenever n is within BLOCK - 1 of u32::MAX, which would leave
    // up to 255 tail elements unprocessed. The quotient is at most 2^24, so
    // narrowing it back to u32 cannot truncate.
    let grid = (n as u64).div_ceil(u64::from(BLOCK)) as u32;
    Ok(LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
9441
9442// ---------------------------------------------------------------------------
9443// Validation helpers
9444// ---------------------------------------------------------------------------
9445
/// Check that both operands of a binary op live on `device` and are the same length.
///
/// # Errors
///
/// Returns [`GpuError::DeviceMismatch`] for the first buffer whose ordinal
/// differs from `device.ordinal()`; `a` is checked before `b`. Otherwise it
/// returns [`GpuError::LengthMismatch`] when `a.len() != b.len()`.
#[cfg(feature = "cuda")]
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    validate_device(a, device)?;
    validate_device(b, device)?;

    let (len_a, len_b) = (a.len(), b.len());
    if len_a == len_b {
        Ok(())
    } else {
        Err(GpuError::LengthMismatch { a: len_a, b: len_b })
    }
}
9469
/// Check that the single operand of a unary `f32` op lives on `device`.
///
/// # Errors
///
/// Returns [`GpuError::DeviceMismatch`] (buffer ordinal as `expected`, device
/// ordinal as `got`) when the two differ.
#[cfg(feature = "cuda")]
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    validate_device(a, device)
}
9481
/// Check that a buffer of any element type lives on `device`.
///
/// # Errors
///
/// Returns [`GpuError::DeviceMismatch`] with the buffer's ordinal as
/// `expected` and the device's ordinal as `got` when they differ.
#[cfg(feature = "cuda")]
fn validate_device<T>(a: &CudaBuffer<T>, device: &GpuDevice) -> GpuResult<()> {
    let (buffer_ordinal, device_ordinal) = (a.device_ordinal(), device.ordinal());
    if buffer_ordinal == device_ordinal {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: buffer_ordinal,
            got: device_ordinal,
        })
    }
}
9493
9494// ---------------------------------------------------------------------------
9495// PTX kernel launch helpers
9496// ---------------------------------------------------------------------------
9497
/// Try to launch a binary `f32` PTX kernel as `kernel(a, b, out, n)` with
/// `n = a.len()`, writing into a freshly allocated output buffer.
///
/// Returns `Ok(Some(buf))` on success, or `Ok(None)` if the PTX module failed
/// to load; in that case the caller should fall back to CPU.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` holds fewer than `a.len()` elements,
///   because the kernel would read past its end.
/// - [`GpuError::ShapeMismatch`] if `a.len()` exceeds `u32::MAX`. This is
///   checked before any allocation.
/// - Any allocation or launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if b.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: b.len() });
    }
    // Validate the launch size up front so an oversize request fails before
    // a (possibly multi-GB) output buffer is allocated.
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    let ctx = device.context();
    let stream = device.stream();

    // Attempt to load the kernel (cached after first compilation).
    // If it fails (e.g. unsupported arch), return None so the caller
    // can use the CPU fallback.
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;

    // SAFETY:
    // - The kernel reads `n` f32 values from `a` and `b` and writes `n` f32
    //   values to `out`. `a` has exactly `n` elements, `b` has at least `n`
    //   (checked above), and `out` was just allocated with `n`.
    // - `n <= u32::MAX` (checked by `launch_cfg`), so `n_u32` is exact.
    // - The grid may contain more than `n` threads, so the kernel must
    //   bounds-check its index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9547
/// Try to launch a vectorized (vec4) binary PTX kernel.
///
/// Each thread processes 4 consecutive `f32` elements using 128-bit
/// loads/stores, so the kernel is called as `kernel(a, b, out, n / 4)`.
///
/// Returns `Ok(Some(buf))` on success. It returns `Ok(None)`, so the caller
/// takes its fallback path, in two cases:
/// - the PTX module failed to compile;
/// - `a.len()` is not a multiple of 4. The vec4 kernel would otherwise leave
///   the trailing elements of the output as zeros.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` holds fewer than `a.len()` elements.
/// - [`GpuError::ShapeMismatch`] if `a.len() / 4` exceeds `u32::MAX`.
/// - Any allocation or launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_binary_vec4(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // The kernel only handles whole groups of four; never silently drop a tail.
    if n % 4 != 0 {
        return Ok(None);
    }
    if b.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: b.len() });
    }
    // Validate the vector count before narrowing it. `launch_cfg` rejects
    // counts above u32::MAX, so the cast below cannot truncate.
    let cfg = launch_cfg(n / 4)?;
    let n4 = (n / 4) as u32;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;

    // SAFETY:
    // - The kernel reads `4 * n4 == n` f32 values from `a` and `b` and writes
    //   the same number to `out`. `a` has exactly `n` elements, `b` has at
    //   least `n` (checked above), and `out` was just allocated with `n`.
    // - The grid may contain more than `n4` threads, so the kernel must
    //   bounds-check its vector index against `n4`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n4)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9592
/// Try to launch a unary `f32` PTX kernel as `kernel(a, out, n)` with
/// `n = a.len()`, writing into a freshly allocated output buffer.
///
/// Returns `Ok(Some(buf))` on success, or `Ok(None)` if the PTX module failed
/// to load; in that case the caller should fall back to CPU.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `a.len()` exceeds `u32::MAX`. This is
///   checked before any allocation.
/// - Any allocation or launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // Validate the launch size before allocating the output buffer.
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    let ctx = device.context();
    let stream = device.stream();

    // Attempt to load the kernel (cached after first compilation).
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;

    // SAFETY:
    // - The kernel reads `n` f32 values from `a` and writes `n` f32 values to
    //   `out`. Both buffers hold exactly `n` elements.
    // - `n <= u32::MAX` (checked by `launch_cfg`).
    // - The grid may contain more than `n` threads, so the kernel must
    //   bounds-check its index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9636
9637// ---------------------------------------------------------------------------
9638// _into helpers — write to pre-allocated output buffer (no allocation)
9639// ---------------------------------------------------------------------------
9640
/// Launch a binary `f32` PTX kernel into a pre-allocated output buffer.
///
/// The kernel is called as `kernel(a, b, out, n)` with `n = a.len()`; no
/// device memory is allocated here.
///
/// Returns `Ok(true)` on success, or `Ok(false)` if the PTX module failed to
/// load; in that case the caller should fall back to another path.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` or `out` holds fewer than `a.len()`
///   elements. Launching anyway would read or write past the end of the
///   device allocation.
/// - [`GpuError::ShapeMismatch`] if `a.len()` exceeds `u32::MAX`.
/// - Any launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_binary_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // The kernel touches indices 0..n of every buffer; refuse to launch
    // rather than corrupt device memory when one of them is too short.
    if b.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: b.len() });
    }
    if out.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: out.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY:
    // - `a`, `b` and `out` each hold at least `n` f32 elements (checked above).
    // - `n <= u32::MAX` (checked by `launch_cfg`), so `n_u32` is exact.
    // - The grid may contain more than `n` threads, so the kernel must
    //   bounds-check its index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
9683
/// Launch a unary `f32` PTX kernel into a pre-allocated output buffer.
///
/// The kernel is called as `kernel(a, out, n)` with `n = a.len()`; no device
/// memory is allocated here.
///
/// Returns `Ok(true)` on success, or `Ok(false)` if the PTX module failed to
/// load; in that case the caller should fall back to another path.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `out` holds fewer than `a.len()`
///   elements. Launching anyway would write past the end of the device
///   allocation.
/// - [`GpuError::ShapeMismatch`] if `a.len()` exceeds `u32::MAX`.
/// - Any launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_unary_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // The kernel writes indices 0..n of `out`; never launch into a shorter buffer.
    if out.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: out.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY:
    // - `a` holds `n` f32 elements and `out` at least `n` (checked above).
    // - `n <= u32::MAX` (checked by `launch_cfg`).
    // - The grid may contain more than `n` threads, so the kernel must
    //   bounds-check its index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
9724
9725// ---------------------------------------------------------------------------
9726// f64 launch helpers
9727// ---------------------------------------------------------------------------
9728
/// Try to launch a binary `f64` PTX kernel as `kernel(a, b, out, n)` with
/// `n = a.len()`, writing into a freshly allocated output buffer.
///
/// Returns `Ok(Some(buf))` on success, or `Ok(None)` if the PTX module failed
/// to load; in that case the caller should fall back to CPU.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` holds fewer than `a.len()` elements.
/// - [`GpuError::ShapeMismatch`] if `a.len()` exceeds `u32::MAX`. This is
///   checked before any allocation.
/// - Any allocation or launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if b.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: b.len() });
    }
    // Validate the launch size before allocating the output buffer.
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, ptx_src, kernel_name, device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f64(n, device)?;

    // SAFETY:
    // - The kernel reads `n` f64 values from `a` and `b` and writes `n` to
    //   `out`. `a` has `n` elements, `b` at least `n` (checked above), and
    //   `out` was just allocated with `n`.
    // - `n <= u32::MAX` (checked by `launch_cfg`).
    // - The grid may contain more than `n` threads, so the kernel must
    //   bounds-check its index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(Some(out))
}
9766
/// Try to launch a unary `f64` PTX kernel as `kernel(a, out, n)` with
/// `n = a.len()`, writing into a freshly allocated output buffer.
///
/// Returns `Ok(Some(buf))` on success, or `Ok(None)` if the PTX module failed
/// to load; in that case the caller should fall back to CPU.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `a.len()` exceeds `u32::MAX`. This is
///   checked before any allocation.
/// - Any allocation or launch error reported by CUDA.
#[cfg(feature = "cuda")]
fn try_launch_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // Validate the launch size before allocating the output buffer.
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, ptx_src, kernel_name, device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f64(n, device)?;

    // SAFETY:
    // - The kernel reads `n` f64 values from `a` and writes `n` to `out`;
    //   both buffers hold exactly `n` elements.
    // - `n <= u32::MAX` (checked by `launch_cfg`).
    // - The grid may contain more than `n` threads, so the kernel must
    //   bounds-check its index against `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(Some(out))
}
9802
/// Host-side fallback for binary `f64` ops.
///
/// Downloads `a` and `b`, combines them pairwise with `op`, and uploads the
/// result. Pairs are formed with `zip`, so the output has the length of the
/// shorter input.
#[cfg(feature = "cuda")]
fn cpu_fallback_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;

    let mut combined: Vec<f64> = Vec::with_capacity(lhs.len().min(rhs.len()));
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        combined.push(op(x, y));
    }
    cpu_to_gpu(&combined, device)
}
9816
/// Host-side fallback for unary `f64` ops.
///
/// Downloads `a`, applies `op` to each element, and uploads the result.
#[cfg(feature = "cuda")]
fn cpu_fallback_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let host = gpu_to_cpu(a, device)?;
    let mapped: Vec<f64> = host.iter().copied().map(op).collect();
    cpu_to_gpu(&mapped, device)
}
9828
/// Launch a general N-dimensional broadcast binary `f64` PTX kernel.
///
/// This is the `f64` counterpart of [`try_launch_broadcast_binary`].
/// - The broadcast strides and output shape are uploaded as small device
///   buffers.
/// - The kernel is launched over `out_numel` output elements with the
///   arguments `(a, b, out, a_strides, b_strides, out_shape, n, ndim)`.
///
/// Returns `Ok(None)` if the kernel cannot be compiled or loaded.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(kernel) => kernel,
        // A compile/load failure is not an error here: report "no kernel".
        Err(_) => return Ok(None),
    };

    // Index-mapping metadata: the kernel turns each flat output index into
    // per-operand offsets using these strides and the output shape.
    let a_strides_dev = cpu_to_gpu(a_strides, device)?;
    let b_strides_dev = cpu_to_gpu(b_strides, device)?;
    let out_shape_dev = cpu_to_gpu(out_shape, device)?;

    let mut dst = alloc_zeros_f64(out_numel, device)?;
    let config = launch_cfg(out_numel)?;
    let numel_arg = out_numel as u32;
    let ndim_arg = out_shape.len() as u32;

    // SAFETY: `dst` holds `out_numel` elements and the kernel receives
    // `out_numel` as its element count.
    // NOTE(review): assumes both stride slices have `out_shape.len()` entries
    // and that `a`/`b` are large enough for every broadcast offset.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(dst.inner_mut())
            .arg(a_strides_dev.inner())
            .arg(b_strides_dev.inner())
            .arg(out_shape_dev.inner())
            .arg(&numel_arg)
            .arg(&ndim_arg)
            .launch(config)?;
    }

    Ok(Some(dst))
}
9887
/// Host-side fallback for `f64` broadcast binary ops.
///
/// 1. Downloads both operands.
/// 2. Maps every flat output index to per-operand offsets using
///    [`broadcast_strides`].
/// 3. Applies `op` and uploads the result.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let total: usize = out_shape.iter().product();

    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    let mut combined: Vec<f64> = Vec::with_capacity(total);
    for flat in 0..total {
        let mut rem = flat;
        let (mut off_l, mut off_r) = (0usize, 0usize);
        // Peel coordinates off starting from the innermost (last) dimension.
        for ((&extent, &sl), &sr) in out_shape
            .iter()
            .zip(&lhs_strides)
            .zip(&rhs_strides)
            .rev()
        {
            let coord = rem % extent;
            rem /= extent;
            off_l += coord * sl as usize;
            off_r += coord * sr as usize;
        }
        combined.push(op(lhs[off_l], rhs[off_r]));
    }
    cpu_to_gpu(&combined, device)
}
9921
/// Launch a general N-dimensional broadcast binary `f32` PTX kernel.
///
/// Inputs:
/// - `a_strides` / `b_strides` are broadcast strides: the normal C-contiguous
///   stride for real dimensions, 0 for broadcast (size-1) dimensions.
/// - `out_shape` is the broadcast-resolved output shape.
///
/// All three slices have length `ndim`. They are uploaded as small device
/// buffers and passed to the kernel together with `out_numel` and `ndim`.
///
/// Returns `Ok(None)` if the kernel cannot be compiled or loaded, so the
/// caller can fall back to the CPU implementation.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(kernel) => kernel,
        // A compile/load failure is not an error here: report "no kernel".
        Err(_) => return Ok(None),
    };

    // Index-mapping metadata for the kernel.
    let a_strides_dev = cpu_to_gpu(a_strides, device)?;
    let b_strides_dev = cpu_to_gpu(b_strides, device)?;
    let out_shape_dev = cpu_to_gpu(out_shape, device)?;

    let mut dst = alloc_zeros_f32(out_numel, device)?;
    let config = launch_cfg(out_numel)?;
    let numel_arg = out_numel as u32;
    let ndim_arg = out_shape.len() as u32;

    // SAFETY: Kernel reads from a, b using broadcast indices computed from
    // the stride/shape buffers. Output buffer has out_numel elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(dst.inner_mut())
            .arg(a_strides_dev.inner())
            .arg(b_strides_dev.inner())
            .arg(out_shape_dev.inner())
            .arg(&numel_arg)
            .arg(&ndim_arg)
            .launch(config)?;
    }

    Ok(Some(dst))
}
9985
/// Compute per-dimension broadcast strides of `in_shape` against `out_shape`.
///
/// Shapes are aligned on their trailing dimensions (numpy style). Each output
/// dimension gets the input's C-contiguous stride, except that the stride is
/// 0 when the aligned input dimension has size 1, or when it does not exist
/// because the input has fewer dimensions. The result has
/// `out_shape.len()` entries.
///
/// Input sizes are not checked against `out_shape`. If `in_shape` has more
/// dimensions than `out_shape`, only its trailing `out_shape.len()` dimensions
/// are used.
#[cfg(feature = "cuda")]
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    let rank = out_shape.len();
    let mut strides = vec![0u32; rank];
    let mut contiguous: u32 = 1;

    for (d, slot) in strides.iter_mut().enumerate().rev() {
        // Index of the trailing-aligned input dimension, if the input has one.
        if let Some(in_d) = (d + in_shape.len()).checked_sub(rank) {
            let extent = in_shape[in_d];
            if extent != 1 {
                *slot = contiguous;
            }
            contiguous *= extent as u32;
        }
    }

    strides
}
10019
10020// ---------------------------------------------------------------------------
10021// CPU fallback helpers
10022// ---------------------------------------------------------------------------
10023
/// Host-side fallback for binary `f32` ops.
///
/// Downloads `a` and `b`, applies `op` pairwise, and uploads the result.
/// Pairs are formed with `zip`, so the output has the shorter input's length.
#[cfg(feature = "cuda")]
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;

    let mut combined: Vec<f32> = Vec::with_capacity(lhs.len().min(rhs.len()));
    combined.extend(lhs.iter().zip(rhs.iter()).map(|(&x, &y)| op(x, y)));
    cpu_to_gpu(&combined, device)
}
10042
/// Host-side fallback for unary `f32` ops.
///
/// Downloads `a`, applies `op` to every element, and uploads the result.
#[cfg(feature = "cuda")]
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mut mapped: Vec<f32> = Vec::with_capacity(host.len());
    for &x in host.iter() {
        mapped.push(op(x));
    }
    cpu_to_gpu(&mapped, device)
}
10054
10055// ---------------------------------------------------------------------------
10056// Public API -- binary ops
10057// ---------------------------------------------------------------------------
10058
/// Elementwise addition: `out[i] = a[i] + b[i]`.
///
/// Tries the following in order:
/// 1. the vectorised `add_vec4_kernel`, only when the length is at least 16
///    and a multiple of 4;
/// 2. the scalar `add_kernel`;
/// 3. a CPU round-trip, if neither PTX module can be loaded.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // The vec4 kernel uses 128-bit loads of four elements at a time, so it is
    // only attempted when the length divides evenly into groups of four.
    let len = a.len();
    let use_vec4 = len >= 16 && len % 4 == 0;
    if use_vec4 {
        let vec4 = try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?;
        if let Some(out) = vec4 {
            return Ok(out);
        }
    }

    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
10094
/// Elementwise subtraction: `out[i] = a[i] - b[i]`.
///
/// Uses the `sub_kernel` PTX kernel when it can be loaded. Otherwise the
/// inputs are copied to the host, subtracted there, and copied back.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
10120
/// Elementwise multiplication: `out[i] = a[i] * b[i]`.
///
/// Tries the following in order:
/// 1. the vectorised `mul_vec4_kernel`, only when the length is at least 16
///    and a multiple of 4;
/// 2. the scalar `mul_kernel`;
/// 3. a CPU round-trip, if neither PTX module can be loaded.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // The vec4 kernel processes groups of four elements; only try it when the
    // length divides evenly.
    let len = a.len();
    let use_vec4 = len >= 16 && len % 4 == 0;
    if use_vec4 {
        let vec4 = try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?;
        if let Some(out) = vec4 {
            return Ok(out);
        }
    }

    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
10155
10156// ---------------------------------------------------------------------------
10157// Public API -- broadcast binary ops
10158// ---------------------------------------------------------------------------
10159
/// Broadcast addition: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
///
/// Handles arbitrary N-dimensional broadcasting on the GPU. The kernel
/// decomposes each output index into coordinates, maps them through
/// broadcast strides, and loads from the correct positions in A and B.
///
/// Shape arguments:
/// - `a_shape` and `b_shape` are the original C-contiguous shapes of `a` and
///   `b`.
/// - `out_shape` is the already-resolved broadcast output shape supplied by
///   the caller; this function neither computes nor checks it.
///
/// Shapes are aligned on their trailing dimensions, and size-1 or missing
/// leading input dimensions are broadcast. `a.len()` is expected to equal the
/// product of `a_shape` (likewise for `b`); this is not verified here.
///
/// Falls back to a CPU round-trip if the PTX module cannot be loaded.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` or `b` does not reside on `device`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Both operands are read by the same kernel, so both must live on
    // `device`. The non-broadcast ops get this check from `validate_binary`,
    // which cannot be used here because the lengths legitimately differ.
    validate_unary(a, device)?;
    validate_unary(b, device)?;

    let a_str = broadcast_strides(a_shape, out_shape);
    let b_str = broadcast_strides(b_shape, out_shape);
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let out_numel: usize = out_shape.iter().product();

    if let Some(out) = try_launch_broadcast_binary(
        a,
        b,
        &a_str,
        &b_str,
        &shape_u32,
        out_numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback for broadcast.
    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
}
10199
/// Broadcast subtraction: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
///
/// Shape arguments:
/// - `a_shape` and `b_shape` are the original C-contiguous shapes of `a` and
///   `b`.
/// - `out_shape` is the caller-resolved broadcast output shape; it is neither
///   computed nor checked here.
///
/// `a.len()` is expected to equal the product of `a_shape` (likewise for
/// `b`). Runs `broadcast_sub_kernel` when the PTX module can be loaded and a
/// CPU round-trip otherwise.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` or `b` does not reside on `device`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Both operands are read by the same kernel, so both must live on
    // `device` (lengths differ under broadcasting, so `validate_binary` does
    // not apply).
    validate_unary(a, device)?;
    validate_unary(b, device)?;

    let a_str = broadcast_strides(a_shape, out_shape);
    let b_str = broadcast_strides(b_shape, out_shape);
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let out_numel: usize = out_shape.iter().product();

    if let Some(out) = try_launch_broadcast_binary(
        a,
        b,
        &a_str,
        &b_str,
        &shape_u32,
        out_numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )? {
        return Ok(out);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
}
10231
/// Broadcast multiplication: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
///
/// Shape arguments:
/// - `a_shape` and `b_shape` are the original C-contiguous shapes of `a` and
///   `b`.
/// - `out_shape` is the caller-resolved broadcast output shape; it is neither
///   computed nor checked here.
///
/// `a.len()` is expected to equal the product of `a_shape` (likewise for
/// `b`). Runs `broadcast_mul_kernel` when the PTX module can be loaded and a
/// CPU round-trip otherwise.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` or `b` does not reside on `device`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Both operands are read by the same kernel, so both must live on
    // `device` (lengths differ under broadcasting, so `validate_binary` does
    // not apply).
    validate_unary(a, device)?;
    validate_unary(b, device)?;

    let a_str = broadcast_strides(a_shape, out_shape);
    let b_str = broadcast_strides(b_shape, out_shape);
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let out_numel: usize = out_shape.iter().product();

    if let Some(out) = try_launch_broadcast_binary(
        a,
        b,
        &a_str,
        &b_str,
        &shape_u32,
        out_numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )? {
        return Ok(out);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
}
10263
/// Broadcast division: `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
///
/// Shape arguments:
/// - `a_shape` and `b_shape` are the original C-contiguous shapes of `a` and
///   `b`.
/// - `out_shape` is the caller-resolved broadcast output shape; it is neither
///   computed nor checked here.
///
/// `a.len()` is expected to equal the product of `a_shape` (likewise for
/// `b`). Runs `broadcast_div_kernel` when the PTX module can be loaded and a
/// CPU round-trip otherwise.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` or `b` does not reside on `device`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Both operands are read by the same kernel, so both must live on
    // `device` (lengths differ under broadcasting, so `validate_binary` does
    // not apply).
    validate_unary(a, device)?;
    validate_unary(b, device)?;

    let a_str = broadcast_strides(a_shape, out_shape);
    let b_str = broadcast_strides(b_shape, out_shape);
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let out_numel: usize = out_shape.iter().product();

    if let Some(out) = try_launch_broadcast_binary(
        a,
        b,
        &a_str,
        &b_str,
        &shape_u32,
        out_numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )? {
        return Ok(out);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
}
10295
/// Host-side fallback for `f32` broadcast binary ops.
///
/// Downloads both operands. For each flat output index, it peels off
/// coordinates from the innermost dimension outward, turns them into
/// per-operand offsets via [`broadcast_strides`], and applies `op`. The
/// result is then uploaded.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let total: usize = out_shape.iter().product();

    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    let combined: Vec<f32> = (0..total)
        .map(|flat| {
            let mut rem = flat;
            let mut off_l = 0usize;
            let mut off_r = 0usize;
            for ((&extent, &sl), &sr) in out_shape
                .iter()
                .zip(&lhs_strides)
                .zip(&rhs_strides)
                .rev()
            {
                let coord = rem % extent;
                rem /= extent;
                off_l += coord * sl as usize;
                off_r += coord * sr as usize;
            }
            op(lhs[off_l], rhs[off_r])
        })
        .collect();
    cpu_to_gpu(&combined, device)
}
10330
10331// ---------------------------------------------------------------------------
10332// Public API -- unary ops
10333// ---------------------------------------------------------------------------
10334
/// Elementwise negation: `out[i] = -a[i]`.
///
/// Runs the `neg_kernel` PTX kernel when it can be loaded, otherwise negates
/// on the host via a CPU round-trip.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
///   CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    try_launch_unary(a, device, NEG_PTX, "neg_kernel")?
        .map_or_else(|| cpu_fallback_unary(a, device, |x| -x), Ok)
}
10355
/// Elementwise ReLU: `out[i] = max(a[i], 0.0)`.
///
/// Runs the `relu_kernel` PTX kernel when it can be loaded. Otherwise it
/// applies `f32::max(x, 0.0)` on the host via a CPU round-trip.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
///   CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
10376
/// ReLU backward: `out[i] = (input[i] > 0) ? grad[i] : 0`.
///
/// Runs `relu_backward_kernel` when the PTX module can be loaded, otherwise
/// computes the mask on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad`, `input`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad` and `input` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    let launched = try_launch_binary(
        grad,
        input,
        device,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback: let the gradient through only where the input was > 0.
    let grads = gpu_to_cpu(grad, device)?;
    let inputs = gpu_to_cpu(input, device)?;
    let mut masked: Vec<f32> = Vec::with_capacity(grads.len());
    for (&g, &x) in grads.iter().zip(inputs.iter()) {
        masked.push(if x > 0.0 { g } else { 0.0 });
    }
    cpu_to_gpu(&masked, device)
}
10406
/// GELU backward (sigmoid approximation):
/// `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
/// where `sig = sigmoid(1.702 * x)`.
///
/// Runs `gelu_backward_kernel` when the PTX module can be loaded, otherwise
/// evaluates the same formula on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad`, `input`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad` and `input` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Scale factor of the sigmoid-based GELU approximation.
    const K: f32 = 1.702;

    validate_binary(grad, input, device)?;

    let launched = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_PTX,
        "gelu_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let grads = gpu_to_cpu(grad, device)?;
    let inputs = gpu_to_cpu(input, device)?;
    let scaled: Vec<f32> = grads
        .iter()
        .zip(inputs.iter())
        .map(|(&g, &x)| {
            let sig = 1.0 / (1.0 + (-K * x).exp());
            g * (sig + K * x * sig * (1.0 - sig))
        })
        .collect();
    cpu_to_gpu(&scaled, device)
}
10441
/// GELU backward (exact erf mode):
/// `out[i] = grad[i] * (Φ(x) + x·φ(x))`
/// where Φ = normal CDF, φ = normal PDF.
///
/// Runs `gelu_backward_erf_kernel` when the PTX module can be loaded.
/// Otherwise it falls back to a CPU round-trip that evaluates
/// `Φ(x) + x·φ(x)` with an Abramowitz & Stegun erf approximation.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad`, `input`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad` and `input` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_PTX,
        "gelu_backward_erf_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback.
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| g * gelu_erf_grad_factor_f32(x))
        .collect();
    cpu_to_gpu(&result, device)
}

/// Exact-GELU derivative factor `Φ(x) + x·φ(x)` for a single `f32`.
///
/// `erf` is evaluated with Abramowitz & Stegun formula 7.1.26 (absolute
/// error below 1.5e-7):
/// `erf(z) ≈ 1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·exp(-z²)`,
/// with `t = 1 / (1 + p·|z|)`, extended to negative `z` by odd symmetry.
///
/// NOTE(review): verify that `GELU_BACKWARD_ERF_PTX` uses the same
/// coefficients (in particular a5 = 1.061405) so both paths agree.
// Deliberately not gated on the `cuda` feature so it can be unit-tested
// without a GPU; it is only called from CUDA code paths.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
fn gelu_erf_grad_factor_f32(x: f32) -> f32 {
    // A&S 7.1.26 constants.
    const P: f32 = 0.3275911;
    const A1: f32 = 0.2548296;
    const A2: f32 = -0.2844967;
    const A3: f32 = 1.421414;
    const A4: f32 = -1.453152;
    const A5: f32 = 1.061405;
    // 1/sqrt(2π) = (2/sqrt(π)) · (1/sqrt(2)) / 2
    const INV_SQRT_2PI: f32 =
        std::f32::consts::FRAC_2_SQRT_PI * std::f32::consts::FRAC_1_SQRT_2 * 0.5;

    let z = x * std::f32::consts::FRAC_1_SQRT_2;
    let az = z.abs();
    let t = 1.0 / (1.0 + P * az);
    let poly = t * (A1 + t * (A2 + t * (A3 + t * (A4 + t * A5))));
    let erf_abs = 1.0 - poly * (-az * az).exp();
    let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };

    let cdf = 0.5 * (1.0 + erf_val);
    let pdf = INV_SQRT_2PI * (-0.5 * x * x).exp();
    cdf + x * pdf
}
10485
10486// ---------------------------------------------------------------------------
10487// Public API -- Index-select 1-D (gather)
10488// ---------------------------------------------------------------------------
10489
/// Gather elements from `input` at positions given by `indices`.
///
/// `indices` is a GPU buffer of f32 values encoding integer indices.
/// Output has `indices.len()` elements: `out[i] = input[indices[i]]`.
///
/// Uses the `index_select_1d_kernel` PTX kernel when it can be loaded and a
/// CPU round-trip otherwise. On the CPU path:
/// - each index is converted with `as usize`, so negative or NaN values
///   become 0;
/// - an index past the end of `input` panics.
///
/// The GPU kernel's handling of out-of-range indices is not checked here.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` or `indices` does not reside on
///   `device`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;
    // The kernel reads `indices` too, so it must live on the same device.
    validate_unary(indices, device)?;

    let n = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let input_host = gpu_to_cpu(input, device)?;
            let indices_host = gpu_to_cpu(indices, device)?;
            let result: Vec<f32> = indices_host
                .iter()
                .map(|&idx_f| input_host[idx_f as usize])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `out` holds `n = indices.len()` elements and the kernel is told
    // `n`. NOTE(review): assumes the kernel guards `idx < n` and that every
    // index value is within `input.len()`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10543
10544// ---------------------------------------------------------------------------
10545// Public API -- Scatter-add 1-D (backward of index_select)
10546// ---------------------------------------------------------------------------
10547
/// Scatter-add `grad_output` back into an output buffer of `input_len`
/// elements, using positions from `indices`.
///
/// `indices` is a GPU buffer of f32 values encoding integer indices and must
/// have the same length as `grad_output` (one target position per gradient
/// element).
/// Output: `out = zeros(input_len); for i: out[indices[i]] += grad_output[i]`
///
/// Duplicate indices accumulate. The GPU kernel relies on atomic adds for
/// this. On the CPU fallback path:
/// - indices are converted with `as usize`, so negative or NaN values become 0;
/// - an index `>= input_len` panics.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad_output`, `indices`, or `device`
///   refer to different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad_output` and `indices` have
///   different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // The kernel reads `indices[i]` for every `i < grad_output.len()`, so
    // both buffers must be on `device` and have equal lengths; otherwise the
    // GPU path would read past the end of `indices`.
    validate_binary(grad_output, indices, device)?;

    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; input_len];
            for (&idx_f, &g) in idx_host.iter().zip(go_host.iter()) {
                result[idx_f as usize] += g;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `grad_output` and `indices` both hold `n` elements (checked by
    // `validate_binary`) and the kernel is launched over `n`. `out` holds
    // `input_len` zero-initialised elements.
    // NOTE(review): assumes every index value is `< input_len`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10605
10606// ---------------------------------------------------------------------------
10607// Public API -- Masked fill
10608// ---------------------------------------------------------------------------
10609
/// Fill elements of `input` with `value` where `mask` is true.
///
/// `mask` is a GPU buffer of f32 values (1.0 = true, 0.0 = false).
/// Output: `out[i] = mask[i] >= 0.5 ? value : input[i]`
///
/// Runs `masked_fill_kernel` when the PTX module can be loaded, otherwise
/// performs the fill on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input`, `mask`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `input` and `mask` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(input, mask, device)?;

    let len = input.len();
    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(kernel) => kernel,
        Err(_) => {
            // Host fallback when the PTX module is unavailable.
            let values = gpu_to_cpu(input, device)?;
            let flags = gpu_to_cpu(mask, device)?;
            let mut filled: Vec<f32> = Vec::with_capacity(values.len());
            for (&x, &m) in values.iter().zip(flags.iter()) {
                filled.push(if m >= 0.5 { value } else { x });
            }
            return cpu_to_gpu(&filled, device);
        }
    };

    let mut dst = alloc_zeros_f32(len, device)?;
    let config = launch_cfg(len)?;
    let len_arg = len as u32;

    // SAFETY: `input` and `mask` both hold `len` elements (checked by
    // `validate_binary`), `dst` was allocated with `len` elements, and the
    // kernel is launched over `len`.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(dst.inner_mut())
            .arg(&value)
            .arg(&len_arg)
            .launch(config)?;
    }

    Ok(dst)
}
10666
10667// ---------------------------------------------------------------------------
10668// Public API -- Masked zero (backward of masked_fill)
10669// ---------------------------------------------------------------------------
10670
/// Zero out gradient at positions where `mask` is true.
///
/// `mask` is a GPU buffer of f32 values (1.0 = true, 0.0 = false).
/// Output: `out[i] = mask[i] >= 0.5 ? 0.0 : grad[i]`
///
/// Runs `masked_zero_kernel` when the PTX module can be loaded, otherwise
/// applies the mask on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad`, `mask`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad` and `mask` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;

    let launched = try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let grads = gpu_to_cpu(grad, device)?;
    let flags = gpu_to_cpu(mask, device)?;
    let mut kept: Vec<f32> = Vec::with_capacity(grads.len());
    for (&g, &m) in grads.iter().zip(flags.iter()) {
        kept.push(if m >= 0.5 { 0.0 } else { g });
    }
    cpu_to_gpu(&kept, device)
}
10698
10699// ---------------------------------------------------------------------------
10700// Public API -- Sigmoid backward
10701// ---------------------------------------------------------------------------
10702
/// Sigmoid backward: `out[i] = grad[i] * output[i] * (1 - output[i])`.
///
/// `grad` and `output` must have the same length and reside on `device`.
/// `output` is the forward sigmoid result. Runs `sigmoid_backward_kernel`
/// when the PTX module can be loaded, otherwise evaluates on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad`, `output`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad` and `output` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    let launched = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let grads = gpu_to_cpu(grad, device)?;
    let outs = gpu_to_cpu(output, device)?;
    let mut dx: Vec<f32> = Vec::with_capacity(grads.len());
    for (&g, &o) in grads.iter().zip(outs.iter()) {
        dx.push(g * o * (1.0 - o));
    }
    cpu_to_gpu(&dx, device)
}
10734
10735// ---------------------------------------------------------------------------
10736// Public API -- Tanh backward
10737// ---------------------------------------------------------------------------
10738
10739/// Tanh backward: `out[i] = grad[i] * (1 - output[i]^2)`.
10740///
10741/// `grad` and `output` must have the same length and reside on `device`.
10742#[cfg(feature = "cuda")]
10743pub fn gpu_tanh_backward(
10744    grad: &CudaBuffer<f32>,
10745    output: &CudaBuffer<f32>,
10746    device: &GpuDevice,
10747) -> GpuResult<CudaBuffer<f32>> {
10748    validate_binary(grad, output, device)?;
10749
10750    if let Some(out) = try_launch_binary(
10751        grad,
10752        output,
10753        device,
10754        TANH_BACKWARD_PTX,
10755        "tanh_backward_kernel",
10756    )? {
10757        return Ok(out);
10758    }
10759
10760    // CPU fallback
10761    let grad_host = gpu_to_cpu(grad, device)?;
10762    let output_host = gpu_to_cpu(output, device)?;
10763    let result: Vec<f32> = grad_host
10764        .iter()
10765        .zip(output_host.iter())
10766        .map(|(&g, &o)| g * (1.0 - o * o))
10767        .collect();
10768    cpu_to_gpu(&result, device)
10769}
10770
10771// ---------------------------------------------------------------------------
10772// Public API -- Softmax backward
10773// ---------------------------------------------------------------------------
10774
10775/// Softmax backward (row-wise): one block per row, shared-memory dot reduction.
10776///
10777/// For each row of length `cols`:
10778///   `dot = sum(grad[row] * output[row])`
10779///   `out[i] = output[i] * (grad[i] - dot)`
10780///
10781/// `rows` = total elements / cols. Both `grad` and `output` have `rows * cols` elements.
10782#[cfg(feature = "cuda")]
10783pub fn gpu_softmax_backward(
10784    grad: &CudaBuffer<f32>,
10785    output: &CudaBuffer<f32>,
10786    cols: usize,
10787    device: &GpuDevice,
10788) -> GpuResult<CudaBuffer<f32>> {
10789    use cudarc::driver::PushKernelArg;
10790
10791    validate_binary(grad, output, device)?;
10792
10793    let total = grad.len();
10794    let rows = total / cols;
10795
10796    let ctx = device.context();
10797    let stream = device.stream();
10798
10799    let f = match crate::module_cache::get_or_compile(
10800        ctx,
10801        SOFTMAX_BACKWARD_PTX,
10802        "softmax_backward_kernel",
10803        device.ordinal() as u32,
10804    ) {
10805        Ok(f) => f,
10806        Err(_) => {
10807            // CPU fallback
10808            let grad_host = gpu_to_cpu(grad, device)?;
10809            let output_host = gpu_to_cpu(output, device)?;
10810            let mut result = vec![0.0f32; total];
10811            for r in 0..rows {
10812                let base = r * cols;
10813                let mut dot = 0.0f32;
10814                for c in 0..cols {
10815                    dot += grad_host[base + c] * output_host[base + c];
10816                }
10817                for c in 0..cols {
10818                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
10819                }
10820            }
10821            return cpu_to_gpu(&result, device);
10822        }
10823    };
10824
10825    let mut out = alloc_zeros_f32(total, device)?;
10826    let rows_u32 = rows as u32;
10827    let cols_u32 = cols as u32;
10828
10829    // One block per row, 256 threads per block.
10830    let cfg = LaunchConfig {
10831        grid_dim: ((rows as u32).max(1), 1, 1),
10832        block_dim: (256, 1, 1),
10833        shared_mem_bytes: 256 * 4,
10834    };
10835
10836    unsafe {
10837        stream
10838            .launch_builder(&f)
10839            .arg(grad.inner())
10840            .arg(output.inner())
10841            .arg(out.inner_mut())
10842            .arg(&rows_u32)
10843            .arg(&cols_u32)
10844            .launch(cfg)?;
10845    }
10846
10847    Ok(out)
10848}
10849
10850// ---------------------------------------------------------------------------
10851// Public API -- LogSoftmax forward & backward
10852// ---------------------------------------------------------------------------
10853
10854/// Row-wise log-softmax on GPU.
10855///
10856/// For each row: `out[j] = x[j] - log(sum(exp(x - max(x))))`.
10857///
10858/// One block per row, 256 threads per block, shared-memory reductions for max
10859/// and sum-exp.
10860#[cfg(feature = "cuda")]
10861pub fn gpu_log_softmax(
10862    input: &CudaBuffer<f32>,
10863    cols: usize,
10864    device: &GpuDevice,
10865) -> GpuResult<CudaBuffer<f32>> {
10866    use cudarc::driver::PushKernelArg;
10867
10868    validate_unary(input, device)?;
10869
10870    let total = input.len();
10871    let rows = total / cols;
10872
10873    let ctx = device.context();
10874    let stream = device.stream();
10875
10876    let f = match crate::module_cache::get_or_compile(
10877        ctx,
10878        LOG_SOFTMAX_PTX,
10879        "log_softmax_kernel",
10880        device.ordinal() as u32,
10881    ) {
10882        Ok(f) => f,
10883        Err(_) => {
10884            // CPU fallback
10885            let host = gpu_to_cpu(input, device)?;
10886            let mut out = vec![0.0f32; total];
10887            for r in 0..rows {
10888                let base = r * cols;
10889                let mut max_v = f32::NEG_INFINITY;
10890                for c in 0..cols {
10891                    max_v = max_v.max(host[base + c]);
10892                }
10893                let mut sum_exp = 0.0f32;
10894                for c in 0..cols {
10895                    sum_exp += (host[base + c] - max_v).exp();
10896                }
10897                let log_sum_exp = max_v + sum_exp.ln();
10898                for c in 0..cols {
10899                    out[base + c] = host[base + c] - log_sum_exp;
10900                }
10901            }
10902            return cpu_to_gpu(&out, device);
10903        }
10904    };
10905
10906    let mut out = alloc_zeros_f32(total, device)?;
10907    let rows_u32 = rows as u32;
10908    let cols_u32 = cols as u32;
10909
10910    // One block per row, 256 threads per block.
10911    let cfg = LaunchConfig {
10912        grid_dim: ((rows as u32).max(1), 1, 1),
10913        block_dim: (256, 1, 1),
10914        shared_mem_bytes: 256 * 4,
10915    };
10916
10917    unsafe {
10918        stream
10919            .launch_builder(&f)
10920            .arg(input.inner())
10921            .arg(out.inner_mut())
10922            .arg(&rows_u32)
10923            .arg(&cols_u32)
10924            .launch(cfg)?;
10925    }
10926
10927    Ok(out)
10928}
10929
10930/// Row-wise log-softmax backward on GPU.
10931///
10932/// For each row:
10933///   `sum_grad = sum(grad[j])`
10934///   `out[j] = grad[j] - exp(output[j]) * sum_grad`
10935///
10936/// where `output` is the log-softmax forward output.
10937#[cfg(feature = "cuda")]
10938pub fn gpu_log_softmax_backward(
10939    grad: &CudaBuffer<f32>,
10940    output: &CudaBuffer<f32>,
10941    cols: usize,
10942    device: &GpuDevice,
10943) -> GpuResult<CudaBuffer<f32>> {
10944    use cudarc::driver::PushKernelArg;
10945
10946    validate_binary(grad, output, device)?;
10947
10948    let total = grad.len();
10949    let rows = total / cols;
10950
10951    let ctx = device.context();
10952    let stream = device.stream();
10953
10954    let f = match crate::module_cache::get_or_compile(
10955        ctx,
10956        LOG_SOFTMAX_BACKWARD_PTX,
10957        "log_softmax_backward_kernel",
10958        device.ordinal() as u32,
10959    ) {
10960        Ok(f) => f,
10961        Err(_) => {
10962            // CPU fallback
10963            let grad_host = gpu_to_cpu(grad, device)?;
10964            let output_host = gpu_to_cpu(output, device)?;
10965            let mut result = vec![0.0f32; total];
10966            for r in 0..rows {
10967                let base = r * cols;
10968                let mut sum_grad = 0.0f32;
10969                for c in 0..cols {
10970                    sum_grad += grad_host[base + c];
10971                }
10972                for c in 0..cols {
10973                    result[base + c] =
10974                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
10975                }
10976            }
10977            return cpu_to_gpu(&result, device);
10978        }
10979    };
10980
10981    let mut out = alloc_zeros_f32(total, device)?;
10982    let rows_u32 = rows as u32;
10983    let cols_u32 = cols as u32;
10984
10985    // One block per row, 256 threads per block.
10986    let cfg = LaunchConfig {
10987        grid_dim: ((rows as u32).max(1), 1, 1),
10988        block_dim: (256, 1, 1),
10989        shared_mem_bytes: 256 * 4,
10990    };
10991
10992    unsafe {
10993        stream
10994            .launch_builder(&f)
10995            .arg(grad.inner())
10996            .arg(output.inner())
10997            .arg(out.inner_mut())
10998            .arg(&rows_u32)
10999            .arg(&cols_u32)
11000            .launch(cfg)?;
11001    }
11002
11003    Ok(out)
11004}
11005
11006// ---------------------------------------------------------------------------
11007// Public API -- Sum axis
11008// ---------------------------------------------------------------------------
11009
11010/// Reduce along one axis of a tensor.
11011///
11012/// Thread i computes:
11013/// Full parallel sum reduction on GPU.
11014///
11015/// Uses a two-pass approach: first pass reduces `n` elements to `num_blocks`
11016/// partial sums via the `reduce_sum_kernel`, second pass reduces the partial
11017/// sums to a single scalar. For small inputs (< 256 blocks), the second pass
11018/// runs on CPU to avoid kernel launch overhead.
11019#[cfg(feature = "cuda")]
11020pub fn gpu_reduce_sum(
11021    a: &CudaBuffer<f32>,
11022    device: &GpuDevice,
11023) -> GpuResult<CudaBuffer<f32>> {
11024    use cudarc::driver::PushKernelArg;
11025
11026    let n = a.len();
11027    if n == 0 {
11028        return cpu_to_gpu(&[0.0f32], device);
11029    }
11030
11031    let ctx = device.context();
11032    let stream = device.stream();
11033
11034    let f = match crate::module_cache::get_or_compile(
11035        ctx,
11036        REDUCE_SUM_PTX,
11037        "reduce_sum_kernel",
11038        device.ordinal() as u32,
11039    ) {
11040        Ok(f) => f,
11041        Err(_) => {
11042            // CPU fallback
11043            let host = gpu_to_cpu(a, device)?;
11044            let total: f32 = host.iter().sum();
11045            return cpu_to_gpu(&[total], device);
11046        }
11047    };
11048
11049    // Pass 1: reduce to partial sums (one per block).
11050    const BLOCK: u32 = 256;
11051    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
11052    // Cap blocks to avoid excessive partial sums.
11053    let num_blocks = num_blocks.min(1024);
11054
11055    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
11056    let n_u32 = n as u32;
11057
11058    let cfg = cudarc::driver::LaunchConfig {
11059        grid_dim: (num_blocks.max(1), 1, 1),
11060        block_dim: (BLOCK, 1, 1),
11061        shared_mem_bytes: 0, // Statically allocated in PTX
11062    };
11063
11064    unsafe {
11065        stream
11066            .launch_builder(&f)
11067            .arg(a.inner())
11068            .arg(partials.inner_mut())
11069            .arg(&n_u32)
11070            .launch(cfg)?;
11071    }
11072
11073    // Pass 2: reduce partial sums.
11074    if num_blocks <= 1 {
11075        return Ok(partials);
11076    }
11077
11078    // For small number of blocks, reduce on CPU (cheaper than another kernel launch).
11079    if num_blocks <= 256 {
11080        let host_partials = gpu_to_cpu(&partials, device)?;
11081        let total: f32 = host_partials.iter().sum();
11082        return cpu_to_gpu(&[total], device);
11083    }
11084
11085    // For many blocks, recurse with another kernel launch.
11086    gpu_reduce_sum(&partials, device)
11087}
11088
11089/// Stub -- always returns [`GpuError::NoCudaFeature`].
11090#[cfg(not(feature = "cuda"))]
11091pub fn gpu_reduce_sum(
11092    _a: &CudaBuffer<f32>,
11093    _device: &GpuDevice,
11094) -> GpuResult<CudaBuffer<f32>> {
11095    Err(GpuError::NoCudaFeature)
11096}
11097
11098///   `output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]`
11099///
11100/// where `outer_idx = i / inner_size`, `inner_idx = i % inner_size`.
11101#[cfg(feature = "cuda")]
11102pub fn gpu_sum_axis(
11103    a: &CudaBuffer<f32>,
11104    outer: usize,
11105    axis_size: usize,
11106    inner: usize,
11107    device: &GpuDevice,
11108) -> GpuResult<CudaBuffer<f32>> {
11109    use cudarc::driver::PushKernelArg;
11110
11111    validate_unary(a, device)?;
11112
11113    let total_output = outer * inner;
11114    let ctx = device.context();
11115    let stream = device.stream();
11116
11117    let f = match crate::module_cache::get_or_compile(
11118        ctx,
11119        SUM_AXIS_PTX,
11120        "sum_axis_kernel",
11121        device.ordinal() as u32,
11122    ) {
11123        Ok(f) => f,
11124        Err(_) => {
11125            // CPU fallback
11126            let host = gpu_to_cpu(a, device)?;
11127            let mut result = vec![0.0f32; total_output];
11128            for (i, out) in result.iter_mut().enumerate() {
11129                let outer_idx = i / inner;
11130                let inner_idx = i % inner;
11131                let mut sum = 0.0f32;
11132                for k in 0..axis_size {
11133                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
11134                }
11135                *out = sum;
11136            }
11137            return cpu_to_gpu(&result, device);
11138        }
11139    };
11140
11141    let mut out = alloc_zeros_f32(total_output, device)?;
11142    let cfg = launch_cfg(total_output)?;
11143    let outer_u32 = outer as u32;
11144    let axis_size_u32 = axis_size as u32;
11145    let inner_u32 = inner as u32;
11146    let total_u32 = total_output as u32;
11147
11148    unsafe {
11149        stream
11150            .launch_builder(&f)
11151            .arg(a.inner())
11152            .arg(out.inner_mut())
11153            .arg(&outer_u32)
11154            .arg(&axis_size_u32)
11155            .arg(&inner_u32)
11156            .arg(&total_u32)
11157            .launch(cfg)?;
11158    }
11159
11160    Ok(out)
11161}
11162
11163// ---------------------------------------------------------------------------
11164// Public API -- Cumulative scan operations
11165// ---------------------------------------------------------------------------
11166
11167/// Cumulative sum (prefix sum) along an axis on GPU.
11168///
11169/// `output[base + k*inner] = sum_{j=0}^{k} input[base + j*inner]`
11170/// where `base = outer_idx * dim_size * inner + inner_idx`.
11171///
11172/// One thread per (outer_idx, inner_idx) pair; each thread does a sequential
11173/// scan along `dim_size` elements.
11174///
11175/// # Errors
11176///
11177/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
11178/// - [`GpuError::Driver`] on CUDA runtime errors.
11179#[cfg(feature = "cuda")]
11180pub fn gpu_cumsum(
11181    input: &CudaBuffer<f32>,
11182    outer: usize,
11183    dim_size: usize,
11184    inner: usize,
11185    device: &GpuDevice,
11186) -> GpuResult<CudaBuffer<f32>> {
11187    use cudarc::driver::PushKernelArg;
11188
11189    validate_unary(input, device)?;
11190
11191    let total = outer * dim_size * inner;
11192    let num_threads = outer * inner;
11193    let ctx = device.context();
11194    let stream = device.stream();
11195
11196    let f = match crate::module_cache::get_or_compile(
11197        ctx,
11198        CUMSUM_PTX,
11199        "cumsum_kernel",
11200        device.ordinal() as u32,
11201    ) {
11202        Ok(f) => f,
11203        Err(_) => {
11204            // CPU fallback
11205            let host = gpu_to_cpu(input, device)?;
11206            let mut result = vec![0.0f32; total];
11207            for i in 0..num_threads {
11208                let outer_idx = i / inner;
11209                let inner_idx = i % inner;
11210                let base = outer_idx * dim_size * inner + inner_idx;
11211                let mut acc = 0.0f32;
11212                for k in 0..dim_size {
11213                    let idx = base + k * inner;
11214                    acc += host[idx];
11215                    result[idx] = acc;
11216                }
11217            }
11218            return cpu_to_gpu(&result, device);
11219        }
11220    };
11221
11222    let mut out = alloc_zeros_f32(total, device)?;
11223    let cfg = launch_cfg(num_threads)?;
11224    let outer_u32 = outer as u32;
11225    let dim_size_u32 = dim_size as u32;
11226    let inner_u32 = inner as u32;
11227    let total_u32 = total as u32;
11228
11229    unsafe {
11230        stream
11231            .launch_builder(&f)
11232            .arg(input.inner())
11233            .arg(out.inner_mut())
11234            .arg(&outer_u32)
11235            .arg(&dim_size_u32)
11236            .arg(&inner_u32)
11237            .arg(&total_u32)
11238            .launch(cfg)?;
11239    }
11240
11241    Ok(out)
11242}
11243
11244/// Cumulative product (prefix product) along an axis on GPU.
11245///
11246/// `output[base + k*inner] = prod_{j=0}^{k} input[base + j*inner]`
11247/// where `base = outer_idx * dim_size * inner + inner_idx`.
11248///
11249/// # Errors
11250///
11251/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
11252/// - [`GpuError::Driver`] on CUDA runtime errors.
11253#[cfg(feature = "cuda")]
11254pub fn gpu_cumprod(
11255    input: &CudaBuffer<f32>,
11256    outer: usize,
11257    dim_size: usize,
11258    inner: usize,
11259    device: &GpuDevice,
11260) -> GpuResult<CudaBuffer<f32>> {
11261    use cudarc::driver::PushKernelArg;
11262
11263    validate_unary(input, device)?;
11264
11265    let total = outer * dim_size * inner;
11266    let num_threads = outer * inner;
11267    let ctx = device.context();
11268    let stream = device.stream();
11269
11270    let f = match crate::module_cache::get_or_compile(
11271        ctx,
11272        CUMPROD_PTX,
11273        "cumprod_kernel",
11274        device.ordinal() as u32,
11275    ) {
11276        Ok(f) => f,
11277        Err(_) => {
11278            // CPU fallback
11279            let host = gpu_to_cpu(input, device)?;
11280            let mut result = vec![0.0f32; total];
11281            for i in 0..num_threads {
11282                let outer_idx = i / inner;
11283                let inner_idx = i % inner;
11284                let base = outer_idx * dim_size * inner + inner_idx;
11285                let mut acc = 1.0f32;
11286                for k in 0..dim_size {
11287                    let idx = base + k * inner;
11288                    acc *= host[idx];
11289                    result[idx] = acc;
11290                }
11291            }
11292            return cpu_to_gpu(&result, device);
11293        }
11294    };
11295
11296    let mut out = alloc_zeros_f32(total, device)?;
11297    let cfg = launch_cfg(num_threads)?;
11298    let outer_u32 = outer as u32;
11299    let dim_size_u32 = dim_size as u32;
11300    let inner_u32 = inner as u32;
11301    let total_u32 = total as u32;
11302
11303    unsafe {
11304        stream
11305            .launch_builder(&f)
11306            .arg(input.inner())
11307            .arg(out.inner_mut())
11308            .arg(&outer_u32)
11309            .arg(&dim_size_u32)
11310            .arg(&inner_u32)
11311            .arg(&total_u32)
11312            .launch(cfg)?;
11313    }
11314
11315    Ok(out)
11316}
11317
11318/// Cumulative maximum (running max) along an axis on GPU.
11319///
11320/// `output[base + k*inner] = max_{j=0}^{k} input[base + j*inner]`
11321/// where `base = outer_idx * dim_size * inner + inner_idx`.
11322///
11323/// # Errors
11324///
11325/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
11326/// - [`GpuError::Driver`] on CUDA runtime errors.
11327#[cfg(feature = "cuda")]
11328pub fn gpu_cummax(
11329    input: &CudaBuffer<f32>,
11330    outer: usize,
11331    dim_size: usize,
11332    inner: usize,
11333    device: &GpuDevice,
11334) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
11335    use cudarc::driver::PushKernelArg;
11336
11337    validate_unary(input, device)?;
11338
11339    let total = outer * dim_size * inner;
11340    let num_threads = outer * inner;
11341    let ctx = device.context();
11342    let stream = device.stream();
11343
11344    let f = match crate::module_cache::get_or_compile(
11345        ctx,
11346        CUMMAX_PTX,
11347        "cummax_kernel",
11348        device.ordinal() as u32,
11349    ) {
11350        Ok(f) => f,
11351        Err(_) => {
11352            let host = gpu_to_cpu(input, device)?;
11353            let mut vals = vec![0.0f32; total];
11354            let mut idxs = vec![0.0f32; total];
11355            for i in 0..num_threads {
11356                let outer_idx = i / inner;
11357                let inner_idx = i % inner;
11358                let base = outer_idx * dim_size * inner + inner_idx;
11359                let mut acc = f32::NEG_INFINITY;
11360                let mut best = 0u32;
11361                for k in 0..dim_size {
11362                    let idx = base + k * inner;
11363                    if host[idx] > acc {
11364                        acc = host[idx];
11365                        best = k as u32;
11366                    }
11367                    vals[idx] = acc;
11368                    idxs[idx] = best as f32;
11369                }
11370            }
11371            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
11372        }
11373    };
11374
11375    let mut out = alloc_zeros_f32(total, device)?;
11376    let mut out_idx = alloc_zeros_f32(total, device)?;
11377    let cfg = launch_cfg(num_threads)?;
11378    let outer_u32 = outer as u32;
11379    let dim_size_u32 = dim_size as u32;
11380    let inner_u32 = inner as u32;
11381    let total_u32 = total as u32;
11382
11383    unsafe {
11384        stream
11385            .launch_builder(&f)
11386            .arg(input.inner())
11387            .arg(out.inner_mut())
11388            .arg(out_idx.inner_mut())
11389            .arg(&outer_u32)
11390            .arg(&dim_size_u32)
11391            .arg(&inner_u32)
11392            .arg(&total_u32)
11393            .launch(cfg)?;
11394    }
11395
11396    Ok((out, out_idx))
11397}
11398
11399/// Cumulative minimum (running min) along an axis on GPU.
11400///
11401/// `output[base + k*inner] = min_{j=0}^{k} input[base + j*inner]`
11402/// where `base = outer_idx * dim_size * inner + inner_idx`.
11403///
11404/// # Errors
11405///
11406/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
11407/// - [`GpuError::Driver`] on CUDA runtime errors.
11408#[cfg(feature = "cuda")]
11409pub fn gpu_cummin(
11410    input: &CudaBuffer<f32>,
11411    outer: usize,
11412    dim_size: usize,
11413    inner: usize,
11414    device: &GpuDevice,
11415) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
11416    use cudarc::driver::PushKernelArg;
11417
11418    validate_unary(input, device)?;
11419
11420    let total = outer * dim_size * inner;
11421    let num_threads = outer * inner;
11422    let ctx = device.context();
11423    let stream = device.stream();
11424
11425    let f = match crate::module_cache::get_or_compile(
11426        ctx,
11427        CUMMIN_PTX,
11428        "cummin_kernel",
11429        device.ordinal() as u32,
11430    ) {
11431        Ok(f) => f,
11432        Err(_) => {
11433            let host = gpu_to_cpu(input, device)?;
11434            let mut vals = vec![0.0f32; total];
11435            let mut idxs = vec![0.0f32; total];
11436            for i in 0..num_threads {
11437                let outer_idx = i / inner;
11438                let inner_idx = i % inner;
11439                let base = outer_idx * dim_size * inner + inner_idx;
11440                let mut acc = f32::INFINITY;
11441                let mut best = 0u32;
11442                for k in 0..dim_size {
11443                    let idx = base + k * inner;
11444                    if host[idx] < acc {
11445                        acc = host[idx];
11446                        best = k as u32;
11447                    }
11448                    vals[idx] = acc;
11449                    idxs[idx] = best as f32;
11450                }
11451            }
11452            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
11453        }
11454    };
11455
11456    let mut out = alloc_zeros_f32(total, device)?;
11457    let mut out_idx = alloc_zeros_f32(total, device)?;
11458    let cfg = launch_cfg(num_threads)?;
11459    let outer_u32 = outer as u32;
11460    let dim_size_u32 = dim_size as u32;
11461    let inner_u32 = inner as u32;
11462    let total_u32 = total as u32;
11463
11464    unsafe {
11465        stream
11466            .launch_builder(&f)
11467            .arg(input.inner())
11468            .arg(out.inner_mut())
11469            .arg(out_idx.inner_mut())
11470            .arg(&outer_u32)
11471            .arg(&dim_size_u32)
11472            .arg(&inner_u32)
11473            .arg(&total_u32)
11474            .launch(cfg)?;
11475    }
11476
11477    Ok((out, out_idx))
11478}
11479
11480/// Numerically stable log-cumulative-sum-exp along an axis on GPU.
11481///
11482/// `acc = log(exp(acc) + exp(x))` computed as `m + log(exp(acc-m) + exp(x-m))`
11483/// where `m = max(acc, x)` for numerical stability.
11484///
11485/// # Errors
11486///
11487/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
11488/// - [`GpuError::Driver`] on CUDA runtime errors.
11489#[cfg(feature = "cuda")]
11490pub fn gpu_logcumsumexp(
11491    input: &CudaBuffer<f32>,
11492    outer: usize,
11493    dim_size: usize,
11494    inner: usize,
11495    device: &GpuDevice,
11496) -> GpuResult<CudaBuffer<f32>> {
11497    use cudarc::driver::PushKernelArg;
11498
11499    validate_unary(input, device)?;
11500
11501    let total = outer * dim_size * inner;
11502    let num_threads = outer * inner;
11503    let ctx = device.context();
11504    let stream = device.stream();
11505
11506    let f = match crate::module_cache::get_or_compile(
11507        ctx,
11508        LOGCUMSUMEXP_PTX,
11509        "logcumsumexp_kernel",
11510        device.ordinal() as u32,
11511    ) {
11512        Ok(f) => f,
11513        Err(_) => {
11514            // CPU fallback
11515            let host = gpu_to_cpu(input, device)?;
11516            let mut result = vec![0.0f32; total];
11517            for i in 0..num_threads {
11518                let outer_idx = i / inner;
11519                let inner_idx = i % inner;
11520                let base = outer_idx * dim_size * inner + inner_idx;
11521                let mut acc = f32::NEG_INFINITY;
11522                for k in 0..dim_size {
11523                    let idx = base + k * inner;
11524                    let x = host[idx];
11525                    let m = acc.max(x);
11526                    acc = m + ((acc - m).exp() + (x - m).exp()).ln();
11527                    result[idx] = acc;
11528                }
11529            }
11530            return cpu_to_gpu(&result, device);
11531        }
11532    };
11533
11534    let mut out = alloc_zeros_f32(total, device)?;
11535    let cfg = launch_cfg(num_threads)?;
11536    let outer_u32 = outer as u32;
11537    let dim_size_u32 = dim_size as u32;
11538    let inner_u32 = inner as u32;
11539    let total_u32 = total as u32;
11540
11541    unsafe {
11542        stream
11543            .launch_builder(&f)
11544            .arg(input.inner())
11545            .arg(out.inner_mut())
11546            .arg(&outer_u32)
11547            .arg(&dim_size_u32)
11548            .arg(&inner_u32)
11549            .arg(&total_u32)
11550            .launch(cfg)?;
11551    }
11552
11553    Ok(out)
11554}
11555
11556// ---------------------------------------------------------------------------
11557// Public API -- Strided split
11558// ---------------------------------------------------------------------------
11559
11560/// Extract a sub-tensor along one axis entirely on GPU.
11561///
11562/// Given an input buffer representing a tensor with `total_along_axis` elements
11563/// along the split axis, extracts the slice `[split_offset .. split_offset + split_size]`
11564/// along that axis.
11565///
11566/// - `inner_size` = product of dimensions after the split axis.
11567/// - `n` = total number of output elements (outer * split_size * inner_size).
11568///
11569/// # Errors
11570///
11571/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
11572/// - [`GpuError::Driver`] on CUDA runtime errors.
11573#[cfg(feature = "cuda")]
11574pub fn gpu_strided_split(
11575    input: &CudaBuffer<f32>,
11576    total_along_axis: usize,
11577    split_offset: usize,
11578    split_size: usize,
11579    inner_size: usize,
11580    n: usize,
11581    device: &GpuDevice,
11582) -> GpuResult<CudaBuffer<f32>> {
11583    use cudarc::driver::PushKernelArg;
11584
11585    validate_unary(input, device)?;
11586
11587    let ctx = device.context();
11588    let stream = device.stream();
11589
11590    let f = match crate::module_cache::get_or_compile(
11591        ctx,
11592        STRIDED_SPLIT_PTX,
11593        "strided_split_kernel",
11594        device.ordinal() as u32,
11595    ) {
11596        Ok(f) => f,
11597        Err(_) => {
11598            // CPU fallback
11599            let host = gpu_to_cpu(input, device)?;
11600            let outer = n / (split_size * inner_size);
11601            let mut result = vec![0.0f32; n];
11602            for (i, out) in result.iter_mut().enumerate() {
11603                let outer_idx = i / (split_size * inner_size);
11604                let within = i % (split_size * inner_size);
11605                let src_idx =
11606                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
11607                *out = host[src_idx];
11608            }
11609            let _ = outer;
11610            return cpu_to_gpu(&result, device);
11611        }
11612    };
11613
11614    let mut out = alloc_zeros_f32(n, device)?;
11615    let cfg = launch_cfg(n)?;
11616    let total_ax_u32 = total_along_axis as u32;
11617    let offset_u32 = split_offset as u32;
11618    let split_sz_u32 = split_size as u32;
11619    let inner_u32 = inner_size as u32;
11620    let n_u32 = n as u32;
11621
11622    unsafe {
11623        stream
11624            .launch_builder(&f)
11625            .arg(input.inner())
11626            .arg(out.inner_mut())
11627            .arg(&total_ax_u32)
11628            .arg(&offset_u32)
11629            .arg(&split_sz_u32)
11630            .arg(&inner_u32)
11631            .arg(&n_u32)
11632            .launch(cfg)?;
11633    }
11634
11635    Ok(out)
11636}
11637
11638// ---------------------------------------------------------------------------
11639// Public API -- Strided cat
11640// ---------------------------------------------------------------------------
11641
/// Write a sub-tensor into a larger output buffer at an offset along one axis,
/// entirely on GPU.
///
/// Given an input buffer representing a chunk with `part_size` elements along
/// the cat axis, writes it into `output` at position `cat_offset` along that axis.
/// Input element `i` lands at
/// `outer * total_along_axis * inner_size + cat_offset * inner_size + within`,
/// where `outer = i / (part_size * inner_size)` and
/// `within = i % (part_size * inner_size)`.
///
/// - `total_along_axis` = size of the cat axis in `output`.
/// - `inner_size` = product of dimensions after the cat axis.
/// - `n` = total number of input elements (outer * part_size * inner_size).
///
/// `n == 0` is a no-op.
///
/// # Caller contract
///
/// `output` must be large enough to hold the written region. The caller is
/// responsible for ensuring non-overlapping writes when multiple chunks are
/// written into the same output buffer.
///
/// When the PTX kernel cannot be loaded, the CPU fallback downloads
/// `output`, patches it on the host, and replaces `*output` with a freshly
/// uploaded buffer instead of writing in place.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if buffers and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
// The signature mirrors the kernel's scalar parameter list one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    // Both buffers are handed to the kernel, so both must live on `device`.
    validate_unary(input, device)?;
    validate_unary(&*output, device)?;

    // Nothing to write; this also avoids building a zero-sized launch config.
    if n == 0 {
        return Ok(());
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            let chunk = part_size * inner_size;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / chunk;
                let within = i % chunk;
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: both buffers were validated against `device`; the arguments
    // are pushed in the same order as before (in, out, total, offset, part,
    // inner, n), which is assumed to match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
11725
11726// ---------------------------------------------------------------------------
11727// Public API -- Strided copy (general N-d gather) -- CL-496
11728// ---------------------------------------------------------------------------
11729
/// Maximum rank supported by [`gpu_strided_copy`] and [`gpu_strided_copy_f64`].
/// Matches the unrolled PTX kernel's dimension count: eight output strides and
/// eight source strides are passed as separate kernel arguments.
pub const STRIDED_COPY_MAX_DIMS: usize = 8;

/// Pad-and-validate the (out_shape, src_strides) pair for the
/// strided-copy kernel.
///
/// Returns a fixed-size `[MAX_DIMS]` pair of arrays where:
/// - `out_stride[d]` is the contiguous output stride (in elements)
///   for that dim, with unused trailing dims filled with `n + 1` so
///   that `flat / out_stride[d] == 0` in the kernel (no contribution).
/// - `src_stride[d]` is the source stride (in elements) for that
///   dim, with unused trailing dims filled with 0 so the source-
///   offset contribution is zero.
///
/// `out_shape` and `src_strides` must have the same length, at most
/// `STRIDED_COPY_MAX_DIMS`. `n` is the product of `out_shape`.
///
/// # Errors
///
/// [`GpuError::ShapeMismatch`] if the ranks differ, the rank exceeds
/// `STRIDED_COPY_MAX_DIMS`, any source stride is negative, any output or
/// source stride does not fit in a `u32`, or `n` itself does not fit in a
/// `u32`. The launch passes the element count as `u32`, so a larger `n` would
/// otherwise be silently truncated and only part of the view copied.
#[cfg(feature = "cuda")]
fn pad_strided_copy_params(
    out_shape: &[usize],
    src_strides: &[isize],
    n: usize,
) -> GpuResult<([u32; STRIDED_COPY_MAX_DIMS], [u32; STRIDED_COPY_MAX_DIMS])> {
    if out_shape.len() != src_strides.len() {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![out_shape.len()],
            got: vec![src_strides.len()],
        });
    }
    if out_shape.len() > STRIDED_COPY_MAX_DIMS {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![STRIDED_COPY_MAX_DIMS],
            got: vec![out_shape.len()],
        });
    }
    // Reject negative source strides — the kernel treats them as u32
    // which would wrap around and produce garbage indices.
    if let Some(&s) = src_strides.iter().find(|&&s| s < 0) {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad_negative_stride",
            expected: vec![0],
            got: vec![s.unsigned_abs()],
        });
    }
    // The element count bounds a u32 flat index (kernel `n` argument and the
    // host fallback's `i as u32`); refuse instead of truncating.
    let Ok(n_u32) = u32::try_from(n) else {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_numel_overflow",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    };

    let rank = out_shape.len();
    // Compute contiguous output strides: stride[rank-1] = 1,
    // stride[d] = stride[d+1] * shape[d+1].
    let mut out_stride = [0u32; STRIDED_COPY_MAX_DIMS];
    let mut acc: usize = 1;
    for d in (0..rank).rev() {
        let Ok(stride) = u32::try_from(acc) else {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_stride_overflow",
                expected: vec![u32::MAX as usize],
                got: vec![acc],
            });
        };
        out_stride[d] = stride;
        acc = acc.saturating_mul(out_shape[d]);
    }

    // Pad unused dims with `n + 1` so `flat / out_stride[d] == 0`
    // in the kernel (any flat < n is strictly less than n + 1). The value is
    // always >= 1, so it is never a zero divisor.
    out_stride[rank..].fill(n_u32.saturating_add(1));

    // src_stride with 0 fill for unused dims (no contribution).
    let mut src_stride_out = [0u32; STRIDED_COPY_MAX_DIMS];
    for (slot, &s) in src_stride_out.iter_mut().zip(src_strides) {
        // `s` is known non-negative here, so the conversion only fails on
        // values above u32::MAX.
        *slot = u32::try_from(s).map_err(|_| GpuError::ShapeMismatch {
            op: "strided_copy_src_stride_overflow",
            expected: vec![u32::MAX as usize],
            got: vec![s as usize],
        })?;
    }

    Ok((out_stride, src_stride_out))
}
11821
/// Gather a non-contiguous strided view of `input` into a new
/// contiguous output buffer, entirely on GPU. CL-496.
///
/// # Arguments
///
/// * `input`      — the storage backing the strided view. Must be
///   on `device`.
/// * `out_shape`  — shape of the contiguous output (and of the
///   logical view). `out_shape.len() <= STRIDED_COPY_MAX_DIMS`.
/// * `src_strides` — source element strides per dim, aligned with
///   `out_shape`. Must be non-negative (no reverse views yet).
/// * `src_offset`  — base element offset into `input` for the view.
/// * `device`     — CUDA device.
///
/// # Returns
///
/// A contiguous `CudaBuffer<f32>` with `product(out_shape)` elements.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` differ.
/// - [`GpuError::ShapeMismatch`] on rank mismatch, too many dims,
///   negative strides, stride overflow of `u32::MAX`, or when the view
///   would read past the end of `input` or beyond element `u32::MAX`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy(
    input: &CudaBuffer<f32>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        return alloc_zeros_f32(0, device);
    }

    // The kernel has no bounds check and the index math (kernel and host
    // fallback) is u32, so verify that the furthest element the view touches
    // is inside `input` and representable. Strides are non-negative (checked
    // above) and every dim is >= 1 because `n > 0`.
    let max_src = out_shape
        .iter()
        .zip(src_strides)
        .try_fold(src_offset, |acc, (&dim, &stride)| {
            (dim - 1)
                .checked_mul(stride as usize)
                .and_then(|step| acc.checked_add(step))
        });
    match max_src {
        Some(last) if last < input.len() && last <= u32::MAX as usize => {}
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_out_of_bounds",
                expected: vec![input.len()],
                got: vec![max_src.map_or(usize::MAX, |last| last.saturating_add(1))],
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback — decode indices on the host exactly like the
            // kernel does. The bounds check above guarantees no u32 overflow.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f32> = (0..n)
                .map(|i| {
                    let mut flat = i as u32;
                    let mut src_idx = src_offset as u32;
                    for (&os, &ss) in out_stride.iter().zip(&src_stride) {
                        let coord = flat / os;
                        flat -= coord * os;
                        src_idx += coord * ss;
                    }
                    host[src_idx as usize]
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;

    // SAFETY: `input` was validated against `device` and every index the
    // view can produce was bounds-checked above; `out` holds `n` elements.
    // Argument order is unchanged (in, out, offset, n, 8 out strides,
    // 8 src strides).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }

    Ok(out)
}
11928
/// f64 variant of [`gpu_strided_copy`]. CL-496.
///
/// Gathers the strided view (`out_shape`, `src_strides`, `src_offset`) of
/// `input` into a new contiguous buffer of `product(out_shape)` elements. The
/// kernel is produced from `STRIDED_COPY_PTX` by `get_f64_ptx` and cached in a
/// function-local `OnceLock`.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` differ.
/// - [`GpuError::ShapeMismatch`] on rank mismatch, too many dims,
///   negative strides, stride overflow of `u32::MAX`, or when the view
///   would read past the end of `input` or beyond element `u32::MAX`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy_f64(
    input: &CudaBuffer<f64>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        return alloc_zeros_f64(0, device);
    }

    // The kernel has no bounds check and the index math (kernel and host
    // fallback) is u32, so verify that the furthest element the view touches
    // is inside `input` and representable. Strides are non-negative (checked
    // above) and every dim is >= 1 because `n > 0`.
    let max_src = out_shape
        .iter()
        .zip(src_strides)
        .try_fold(src_offset, |acc, (&dim, &stride)| {
            (dim - 1)
                .checked_mul(stride as usize)
                .and_then(|step| acc.checked_add(step))
        });
    match max_src {
        Some(last) if last < input.len() && last <= u32::MAX as usize => {}
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_out_of_bounds",
                expected: vec![input.len()],
                got: vec![max_src.map_or(usize::MAX, |last| last.saturating_add(1))],
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        "strided_copy_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_copy_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback — same index decode as the kernel. The bounds
            // check above guarantees no u32 overflow.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = (0..n)
                .map(|i| {
                    let mut flat = i as u32;
                    let mut src_idx = src_offset as u32;
                    for (&os, &ss) in out_stride.iter().zip(&src_stride) {
                        let coord = flat / os;
                        flat -= coord * os;
                        src_idx += coord * ss;
                    }
                    host[src_idx as usize]
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;

    // SAFETY: `input` was validated against `device` and every index the
    // view can produce was bounds-checked above; `out` holds `n` elements.
    // Argument order is unchanged (in, out, offset, n, 8 out strides,
    // 8 src strides).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }

    Ok(out)
}
12018
/// Scalar multiply: `out[i] = a[i] * scalar`.
///
/// Returns a new buffer holding every element of `a` scaled by the constant
/// `scalar`. If the PTX kernel cannot be loaded, the scaling is done on the
/// host and the result uploaded.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let context = device.context();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        context,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // PTX unavailable: scale on the host and upload the result.
        let host = gpu_to_cpu(a, device)?;
        let scaled: Vec<f32> = host.iter().map(|&v| v * scalar).collect();
        return cpu_to_gpu(&scaled, device);
    };

    let mut scaled = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `a` is on `device` and `scaled` holds `len` elements; the
    // arguments are pushed as (in, out, scalar, n).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(scaled.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(scaled)
}
12072
12073// ---------------------------------------------------------------------------
12074// Public API -- softmax
12075// ---------------------------------------------------------------------------
12076
/// Row-wise softmax on GPU: one thread block per row, shared-memory reduction.
///
/// `rows` = product of all dims except the last. `cols` = last dim size.
/// The result is a new buffer of exactly `rows * cols` elements. For row `r`,
/// with `x = input[r * cols..(r + 1) * cols]`, it holds
/// `exp(x - max(x)) / sum(exp(x - max(x)))`. An empty problem
/// (`rows == 0` or `cols == 0`) returns an empty buffer without launching.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` differ.
/// - [`GpuError::ShapeMismatch`] if `input` holds fewer than `rows * cols` elements.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = rows * cols;
    // Both the kernel and the host fallback read `rows * cols` elements;
    // reject a short input instead of reading out of bounds / panicking.
    if input.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "softmax",
            expected: vec![total],
            got: vec![input.len()],
        });
    }
    if total == 0 {
        return alloc_zeros_f32(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: numerically stable softmax per row. The output
            // size matches the GPU path (`rows * cols`), not `input.len()`.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            // `total > 0` here, so `cols > 0` and `chunks_exact` is valid.
            for (src, dst) in host[..total]
                .chunks_exact(cols)
                .zip(out.chunks_exact_mut(cols))
            {
                let max_v = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for (slot, &x) in dst.iter_mut().zip(src) {
                    let e = (x - max_v).exp();
                    *slot = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                dst.iter_mut().for_each(|slot| *slot *= inv);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row (rows >= 1 here), 256 threads per block.
    let cfg = LaunchConfig {
        grid_dim: (rows_u32, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4, // sdata[256] f32
    };

    // SAFETY: `input` is on `device` and holds at least `rows * cols`
    // elements (checked above); `out` holds exactly `rows * cols`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12149
12150// ---------------------------------------------------------------------------
12151// Public API -- dropout
12152// ---------------------------------------------------------------------------
12153
/// Inverted dropout on GPU: each element is either zeroed or multiplied by
/// `scale`, with drop probability `p`.
///
/// The caller derives the parameters from `p`:
/// - `threshold` = `(p * u32::MAX as f64) as u32`. An element is dropped when
///   its pseudorandom draw is strictly below this cutoff.
/// - `scale` = `1.0 / (1.0 - p)`.
/// - `seed` = random seed mixed into every element's draw.
///
/// **Known limitation**: the draw is a simple per-element hash
/// (`tid * 2654435761 ^ seed` followed by xorshift mixing), not the full
/// Philox 4x32-10 counter-based RNG that PyTorch uses. A proper Philox
/// dropout kernel would generate the mask via `philox_uniform_kernel`
/// and then threshold it. That would give higher-quality randomness and exact
/// reproducibility across CPU/GPU. The current hash is adequate for
/// training but should be upgraded for research requiring strict
/// statistical properties.
#[cfg(feature = "cuda")]
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let context = device.context();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        context,
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(kernel) => kernel,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut masked = Vec::with_capacity(host.len());
            for (tid, &x) in host.iter().enumerate() {
                // Stateless hash of (tid, seed), mirroring the per-thread
                // computation in the kernel; nothing carries between elements.
                let mut r = (tid as u32).wrapping_mul(2654435761) ^ seed;
                r ^= r << 13;
                r ^= r >> 17;
                r ^= r << 5;
                masked.push(if r >= threshold { x * scale } else { 0.0 });
            }
            return cpu_to_gpu(&masked, device);
        }
    };

    let mut masked = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `input` is on `device` and `masked` holds `len` elements; the
    // arguments are pushed as (in, out, n, threshold, scale, seed).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(masked.inner_mut())
            .arg(&len_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(masked)
}
12230
/// Inverted dropout for f64 tensors; the f64 counterpart of [`gpu_dropout`].
///
/// Uses the same per-element hash and the same `threshold` / `scale` / `seed`
/// semantics as the f32 version. An element is zeroed when its draw is below
/// `threshold` and otherwise multiplied by `scale`. The kernel is derived from
/// `DROPOUT_PTX` via `get_f64_ptx` and cached in a function-local `OnceLock`.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` differ.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_dropout_f64(
    input: &CudaBuffer<f64>,
    threshold: u32,
    scale: f64,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Every other wrapper validates its input's device; this one did not,
    // so a buffer from another GPU would have been launched on `device`.
    validate_device(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, DROPOUT_PTX, "dropout_kernel", "dropout_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "dropout_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: same stateless (tid, seed) hash as the f32 path.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `input` was validated against `device` and `out` holds `n`
    // elements; arguments are pushed as (in, out, n, threshold, scale, seed).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
12287
12288#[cfg(not(feature = "cuda"))]
12289pub fn gpu_dropout_f64(_input: &CudaBuffer<f64>, _threshold: u32, _scale: f64, _seed: u32, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
12290
12291// ---------------------------------------------------------------------------
12292// Public API -- 2D transpose
12293// ---------------------------------------------------------------------------
12294
/// 2D matrix transpose on GPU: row-major `[M, N]` -> `[N, M]`.
///
/// Element `(i, j)` of the input (at `i * n + j`) is written to `j * m + i`
/// of a new buffer holding `m * n` elements. If the PTX kernel cannot be
/// loaded, the transpose runs on the host.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let numel = m * n;
    let context = device.context();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        context,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // PTX unavailable: transpose on the host, walking the input in
        // row-major order.
        let host = gpu_to_cpu(input, device)?;
        let mut transposed = vec![0.0f32; numel];
        for (flat, &value) in host[..numel].iter().enumerate() {
            let (row, col) = (flat / n, flat % n);
            transposed[col * m + row] = value;
        }
        return cpu_to_gpu(&transposed, device);
    };

    let mut transposed = alloc_zeros_f32(numel, device)?;
    let cfg = launch_cfg(numel)?;
    let (m_u32, n_u32, numel_u32) = (m as u32, n as u32, numel as u32);

    // SAFETY: `input` is on `device` and `transposed` holds `m * n`
    // elements; arguments are pushed as (in, out, m, n, total).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(transposed.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&numel_u32)
            .launch(cfg)?;
    }

    Ok(transposed)
}
12350
12351// ---------------------------------------------------------------------------
12352// Public API -- 4D permute (0,2,1,3)
12353// ---------------------------------------------------------------------------
12354
/// Swaps the two middle axes of a contiguous 4D tensor on GPU:
/// `[d0, d1, d2, d3]` -> `[d0, d2, d1, d3]`, returned as a new buffer.
///
/// Typical use is attention head reshaping: `[B, S, H, D_h]` -> `[B, H, S, D_h]`.
/// If the PTX kernel cannot be loaded, the permutation runs on the host.
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let numel = d0 * d1 * d2 * d3;
    let context = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        context,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => kernel,
        Err(_) => {
            // PTX unavailable: build the output in order, copying each
            // contiguous innermost run of `d3` values from its source row.
            let host = gpu_to_cpu(input, device)?;
            let mut permuted = Vec::with_capacity(numel);
            for a in 0..d0 {
                for c in 0..d2 {
                    for b in 0..d1 {
                        let start = ((a * d1 + b) * d2 + c) * d3;
                        permuted.extend_from_slice(&host[start..start + d3]);
                    }
                }
            }
            return cpu_to_gpu(&permuted, device);
        }
    };

    let mut permuted = alloc_zeros_f32(numel, device)?;
    let cfg = launch_cfg(numel)?;
    let dims_u32 = [d0 as u32, d1 as u32, d2 as u32, d3 as u32];
    let numel_u32 = numel as u32;

    // SAFETY: `input` is on `device` and `permuted` holds `numel` elements;
    // arguments are pushed as (in, out, d0, d1, d2, d3, total).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(permuted.inner_mut())
            .arg(&dims_u32[0])
            .arg(&dims_u32[1])
            .arg(&dims_u32[2])
            .arg(&dims_u32[3])
            .arg(&numel_u32)
            .launch(cfg)?;
    }

    Ok(permuted)
}
12423
12424// ---------------------------------------------------------------------------
12425// Public API -- Small matmul (bypasses cuBLAS JIT)
12426// ---------------------------------------------------------------------------
12427
/// Small matrix multiply using our own PTX kernel. Avoids cuBLAS JIT
/// compilation overhead for tiny matrices where JIT cost > compute cost.
///
/// `a`: `[M, K]`, `b`: `[K, N]` → `c`: `[M, N]`, returned as a new buffer of
/// `m * n` elements. If the PTX kernel cannot be loaded, the product is
/// computed with `crate::blas::gpu_matmul_f32` instead.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` or `b` is not on `device`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Both operands are launched on `device`; check them like every other
    // wrapper in this module does.
    validate_unary(a, device)?;
    validate_unary(b, device)?;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Fall back to cuBLAS if our kernel can't compile.
            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
        }
    };

    let mut c = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: `a` and `b` were validated against `device`; `c` holds
    // `m * n` elements. Arguments are pushed as (a, b, c, m, k, n, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
12482
/// Small batched matmul: `C[i] = A[i] @ B[i]` for `i in 0..batch`, where each
/// `A[i]` is `[M, K]` and each `B[i]` is `[K, N]`.
///
/// Only `batch == 1` uses the custom PTX kernel, via [`gpu_small_matmul`],
/// which itself falls back to cuBLAS if the kernel cannot be loaded. This is
/// the single-matrix decode case, e.g. attention scores. Every other batch
/// size is delegated to cuBLAS through `crate::blas::gpu_bmm_f32`; there is no
/// batched PTX kernel yet.
#[cfg(feature = "cuda")]
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        // Single matrix: the PTX path avoids cuBLAS JIT overhead.
        1 => gpu_small_matmul(a, b, m, k, n, device),
        // Batched: cuBLAS handles it (no batched PTX kernel implemented).
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
12509
12510// ---------------------------------------------------------------------------
12511// Public API -- Embedding lookup (GPU-native)
12512// ---------------------------------------------------------------------------
12513
/// GPU embedding lookup: reads the token id from `idx[0]` (stored as an f32),
/// gathers that row from `weight` `[V, D]`, and returns it as a new `[D]` buffer.
///
/// On the fast path the whole lookup runs on the GPU, so the token id is never
/// copied to the host. If the PTX kernel cannot be loaded, both buffers are
/// downloaded and the row is gathered on the host. The f32 id is converted
/// with `as usize`, so negative or NaN ids become row 0.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `idx` or `weight` is not on `device`.
/// - [`GpuError::ShapeMismatch`] (CPU fallback only) if `idx` is empty or the
///   token id selects a row outside `weight`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(idx, device)?;
    validate_unary(weight, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: report bad indices as errors instead of panicking.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let Some(&token) = idx_host.first() else {
                return Err(GpuError::ShapeMismatch {
                    op: "embed_lookup_idx",
                    expected: vec![1],
                    got: vec![0],
                });
            };
            let row = token as usize;
            let row_slice = row
                .checked_mul(d)
                .and_then(|start| weight_host.get(start..start.checked_add(d)?));
            let Some(row_slice) = row_slice else {
                return Err(GpuError::ShapeMismatch {
                    op: "embed_lookup_row",
                    expected: vec![weight_host.len()],
                    got: vec![row.saturating_mul(d).saturating_add(d)],
                });
            };
            let out = row_slice.to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // SAFETY: all buffers were validated against `device` and `out` holds
    // `d` elements. Arguments are pushed as (idx, weight, out, d). The row
    // index itself is read on-device and not range-checked here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12563
12564// ---------------------------------------------------------------------------
12565// Public API -- Slice write (for KV cache)
12566// ---------------------------------------------------------------------------
12567
/// Write `src` of shape `[N, D]` into row `pos` of `dst` of shape `[N, max_len, D]`.
///
/// `n_batch` is `N` and `d` is `D`. For every batch `b` and column `i`, this
/// sets `dst[b, pos, i] = src[b, i]`. On the GPU path `dst` is modified in
/// place. On the CPU fallback path `dst` is replaced by a freshly uploaded
/// buffer holding the updated contents.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when the inputs cannot hold the
/// documented shapes:
/// - `pos >= max_len`, reported as `{ a: pos, b: max_len }`;
/// - `dst` shorter than `n_batch * max_len * d`;
/// - `src` shorter than `n_batch * d`.
///
/// Without these checks the CPU fallback would panic. The (unsafe) GPU
/// kernel, which presumably uses the same addressing, would write out of
/// bounds. Allocation, transfer, and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    // Validate shapes before any memory is touched.
    if pos >= max_len {
        // Writing past the end of the cache. Reported as (pos, max_len).
        return Err(GpuError::LengthMismatch { a: pos, b: max_len });
    }
    // Checked so that a huge shape cannot wrap around and slip past the check.
    let dst_needed = n_batch
        .checked_mul(max_len)
        .and_then(|rows| rows.checked_mul(d))
        .unwrap_or(usize::MAX);
    if dst.len() < dst_needed {
        return Err(GpuError::LengthMismatch { a: dst.len(), b: dst_needed });
    }
    // Cannot overflow: max_len >= 1 and n_batch * max_len * d fits in usize.
    let total = n_batch * d;
    if src.len() < total {
        return Err(GpuError::LengthMismatch { a: src.len(), b: total });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: round-trip `dst` through the host, then upload a
            // new buffer in its place.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                let src_off = b * d;
                let dst_off = (b * max_len + pos) * d;
                dst_host[dst_off..dst_off + d]
                    .copy_from_slice(&src_host[src_off..src_off + d]);
            }
            *dst = cpu_to_gpu(&dst_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    // NOTE(review): these `as u32` casts truncate for shapes beyond u32::MAX
    // elements; confirm whether callers can reach that size.
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // SAFETY: shapes were validated above, so (assuming the kernel indexes
    // `dst` the same way as the CPU fallback) every write stays inside `dst`.
    // Argument order must match the `slice_write_kernel` PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
12628
12629// ---------------------------------------------------------------------------
12630// Public API -- Slice read (for KV cache)
12631// ---------------------------------------------------------------------------
12632
/// Read first `len` rows from each batch of `[N, max_len, D]` → `[N, len, D]`.
///
/// `n_batch` is `N` and `d` is `D`. Returns a new buffer of `n_batch * len * d`
/// elements where `out[b, r, i] = src[b, r, i]` for every `r < len`.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`]:
/// - `{ a: len, b: max_len }` if `len > max_len`. That request would silently
///   read the next batch's rows, and run past the end of `src` for the last
///   batch.
/// - `{ a: src.len(), b: n_batch * max_len * d }` if `src` is shorter than
///   its documented shape.
///
/// Allocation, transfer, and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Validate shapes before the unsafe kernel launch.
    if len > max_len {
        return Err(GpuError::LengthMismatch { a: len, b: max_len });
    }
    // Checked so that a huge shape cannot wrap around and slip past the check.
    let src_needed = n_batch
        .checked_mul(max_len)
        .and_then(|rows| rows.checked_mul(d))
        .unwrap_or(usize::MAX);
    if src.len() < src_needed {
        return Err(GpuError::LengthMismatch { a: src.len(), b: src_needed });
    }

    // Cannot overflow: len <= max_len and n_batch * max_len * d fits in usize.
    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback. A batch's first `len` rows form one contiguous run
            // of `len * d` elements in both the source and output layouts.
            let host = gpu_to_cpu(src, device)?;
            let run = len * d;
            let mut out = vec![0.0f32; total];
            for b in 0..n_batch {
                let src_off = b * max_len * d;
                let dst_off = b * run;
                out[dst_off..dst_off + run].copy_from_slice(&host[src_off..src_off + run]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    // NOTE(review): these `as u32` casts truncate for shapes beyond u32::MAX
    // elements; confirm whether callers can reach that size.
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // SAFETY: `out` holds `total` elements and `src` was validated against
    // `[n_batch, max_len, d]`. Argument order must match the
    // `slice_read_kernel` PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12691
12692// ---------------------------------------------------------------------------
12693// Public API -- GELU
12694// ---------------------------------------------------------------------------
12695
/// Elementwise GELU on GPU, sigmoid-based approximation:
/// `gelu(x) ≈ x * sigmoid(1.702 * x)`.
///
/// Runs `gelu_kernel` from `GELU_PTX`. When the launcher produces no GPU
/// result, the same formula is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |v| {
            let gate = 1.0 / (1.0 + (-1.702 * v).exp());
            v * gate
        }),
    }
}
12708
/// Elementwise GELU on GPU, tanh approximation:
/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
///
/// Matches PyTorch `nn.GELU(approximate="tanh")`. If `gelu_tanh_kernel`
/// produces no GPU result, the same formula is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const CUBIC_COEFF: f32 = 0.044715;

    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |v| {
            let arg = SQRT_2_OVER_PI * (v + CUBIC_COEFF * v * v * v);
            0.5 * v * (1.0 + arg.tanh())
        }),
    }
}
12726
/// Elementwise GELU on GPU, erf formulation:
/// `gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2)))`.
///
/// Matches PyTorch `nn.GELU(approximate="none")` (the default). `erf` itself
/// is not exact: it is evaluated with the Abramowitz & Stegun 7.1.26
/// polynomial approximation. The host fallback below uses the same
/// polynomial as the PTX kernel.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    // Abramowitz & Stegun 7.1.26 coefficients.
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254829592;
    const A2: f32 = -0.284496736;
    const A3: f32 = 1.421413741;
    const A4: f32 = -1.453152027;
    const A5: f32 = 1.061405429;

    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |v| {
            let z = v * std::f32::consts::FRAC_1_SQRT_2;
            let mag = z.abs();
            let t = 1.0 / (1.0 + P * mag);
            let horner = t * (A1 + t * (A2 + t * (A3 + t * (A4 + t * A5))));
            let erf_mag = 1.0 - horner * (-mag * mag).exp();
            // erf is odd: restore the sign of z.
            let erf_z = if z < 0.0 { -erf_mag } else { erf_mag };
            v * 0.5 * (1.0 + erf_z)
        }),
    }
}
12748
/// Gradient of the tanh-approximated GELU.
///
/// With `x = input[i]`, `u = sqrt(2/π) * (x + 0.044715 * x³)` and
/// `t = tanh(u)`:
/// `out[i] = grad[i] * (0.5 * (1 + t) + 0.5 * x * (1 - t²) * sqrt(2/π) * (1 + 3*0.044715*x²))`.
///
/// Falls back to evaluating this on the host when `gelu_backward_tanh_kernel`
/// produces no GPU result.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const CUBIC_COEFF: f32 = 0.044715;
    // Coefficient from differentiating the cubic term: 3 * 0.044715.
    const CUBIC_COEFF_X3: f32 = 0.134145;

    validate_binary(grad, input, device)?;
    match try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )? {
        Some(out) => Ok(out),
        None => {
            // Host fallback: apply the analytic derivative per element.
            let grads = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let mut result: Vec<f32> = Vec::with_capacity(grads.len());
            for (&g, &x) in grads.iter().zip(xs.iter()) {
                let t = (SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x)).tanh();
                let sech2 = 1.0 - t * t;
                let du_dx = SQRT_2_OVER_PI * (1.0 + CUBIC_COEFF_X3 * x * x);
                result.push(g * (0.5 * (1.0 + t) + 0.5 * x * sech2 * du_dx));
            }
            cpu_to_gpu(&result, device)
        }
    }
}
12787
12788// ---------------------------------------------------------------------------
12789// Public API -- SiLU (Swish)
12790// ---------------------------------------------------------------------------
12791
/// Elementwise SiLU (Swish) on GPU: `silu(x) = x * sigmoid(x)`.
///
/// If `silu_kernel` produces no GPU result, the formula is evaluated on the
/// host instead.
#[cfg(feature = "cuda")]
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, SILU_PTX, "silu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |v| v * (1.0 / (1.0 + (-v).exp()))),
    }
}
12804
/// SiLU gradient: `out[i] = grad[i] * (s + x * s * (1 - s))`, where
/// `x = input[i]` and `s = sigmoid(x)`.
///
/// Falls back to a host-side loop when `silu_backward_kernel` produces no GPU
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    match try_launch_binary(grad, input, device, SILU_BACKWARD_PTX, "silu_backward_kernel")? {
        Some(out) => Ok(out),
        None => {
            // Host fallback.
            let grads = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let mut result: Vec<f32> = Vec::with_capacity(grads.len());
            for (&g, &x) in grads.iter().zip(xs.iter()) {
                let s = 1.0 / (1.0 + (-x).exp());
                result.push(g * (s + x * s * (1.0 - s)));
            }
            cpu_to_gpu(&result, device)
        }
    }
}
12838
12839// ---------------------------------------------------------------------------
12840// Public API -- ELU
12841// ---------------------------------------------------------------------------
12842
/// Elementwise ELU on GPU: `elu(x) = x` for `x > 0`, otherwise
/// `alpha * (exp(x) - 1)`.
///
/// The kernel takes `alpha` as an extra scalar argument, so `elu_kernel` is
/// launched directly rather than through the generic unary launcher. If
/// `get_or_compile` fails, the formula is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_elu(
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ELU_PTX,
        "elu_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(input, device)?;
        let result: Vec<f32> = host
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `out` was just allocated with `len` elements, the same length as
    // `input`. Arguments are pushed as (input, output, n, alpha), which is
    // assumed to be the `elu_kernel` PTX signature; TODO confirm.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(out)
}
12893
/// ELU gradient: `out[i] = grad[i]` where `input[i] > 0`, otherwise
/// `grad[i] * alpha * exp(input[i])`.
///
/// Launched directly, not via the generic binary launcher, because the kernel
/// takes `alpha` as an extra scalar argument. If `get_or_compile` fails, the
/// gradient is computed on the host.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, input, device)?;

    let len = grad.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_PTX,
        "elu_backward_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let grads = gpu_to_cpu(grad, device)?;
        let xs = gpu_to_cpu(input, device)?;
        let result: Vec<f32> = grads
            .iter()
            .zip(xs.iter())
            .map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
            .collect();
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `out` was just allocated with `len` elements. Arguments are
    // pushed as (grad, input, output, n, alpha), assumed to be the
    // `elu_backward_kernel` PTX signature; TODO confirm.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad.inner())
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(out)
}
12948
12949// ---------------------------------------------------------------------------
12950// Public API -- Mish
12951// ---------------------------------------------------------------------------
12952
/// Elementwise Mish on GPU: `mish(x) = x * tanh(softplus(x))`, with
/// `softplus(x) = ln(1 + exp(x))`.
///
/// For `x > 20` the host fallback uses `softplus(x) = x`. At that point the
/// difference is below f32 resolution, and skipping `exp` avoids overflow
/// for large inputs.
#[cfg(feature = "cuda")]
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, MISH_PTX, "mish_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |v| {
            let softplus = if v > 20.0 { v } else { (1.0 + v.exp()).ln() };
            v * softplus.tanh()
        }),
    }
}
12965
/// Mish gradient:
/// `out[i] = grad[i] * (tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2))`,
/// where `x = input[i]` and `sp = softplus(x) = ln(1 + exp(x))`.
///
/// The host fallback uses the same `x > 20 ⇒ softplus(x) = x` cut-off as the
/// forward pass.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    match try_launch_binary(grad, input, device, MISH_BACKWARD_PTX, "mish_backward_kernel")? {
        Some(out) => Ok(out),
        None => {
            // Host fallback.
            let grads = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let mut result: Vec<f32> = Vec::with_capacity(grads.len());
            for (&g, &x) in grads.iter().zip(xs.iter()) {
                let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
                let th = softplus.tanh();
                let s = 1.0 / (1.0 + (-x).exp());
                result.push(g * (th + x * s * (1.0 - th * th)));
            }
            cpu_to_gpu(&result, device)
        }
    }
}
13002
/// Elementwise clamp: `out[i] = max(min_val, min(max_val, x[i]))`.
///
/// Launched directly because the kernel takes `min_val` and `max_val` as two
/// extra `f32` arguments. The host fallback computes
/// `x.max(min_val).min(max_val)`.
///
/// NOTE(review): for `min_val > max_val`, or for NaN inputs, that ordering
/// gives a different result from the formula above. Confirm which one
/// `clamp_kernel` implements.
#[cfg(feature = "cuda")]
pub fn gpu_clamp(
    input: &CudaBuffer<f32>,
    min_val: f32,
    max_val: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        CLAMP_PTX,
        "clamp_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(input, device)?;
        let result: Vec<f32> = host
            .iter()
            .map(|&x| x.max(min_val).min(max_val))
            .collect();
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `out` was just allocated with `len` elements. Arguments are
    // pushed as (input, output, n, min, max), assumed to be the
    // `clamp_kernel` PTX signature; TODO confirm.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&min_val)
            .arg(&max_val)
            .launch(cfg)?;
    }

    Ok(out)
}
13055
13056// ---------------------------------------------------------------------------
13057// Public API -- elementwise transcendentals & math ops
13058// ---------------------------------------------------------------------------
13059
/// Elementwise division: `out[i] = a[i] / b[i]`.
///
/// If `div_kernel` produces no GPU result, the quotient is computed on the
/// host.
#[cfg(feature = "cuda")]
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        Some(out) => Ok(out),
        None => {
            // Host fallback.
            let numer = gpu_to_cpu(a, device)?;
            let denom = gpu_to_cpu(b, device)?;
            let quotient: Vec<f32> = numer
                .iter()
                .zip(denom.iter())
                .map(|(&x, &y)| x / y)
                .collect();
            cpu_to_gpu(&quotient, device)
        }
    }
}
13083
/// Elementwise exponential: `out[i] = exp(a[i])`.
///
/// Falls back to `f32::exp` on the host when `exp_kernel` produces no GPU
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::exp),
    }
}
13093
/// Elementwise natural logarithm: `out[i] = ln(a[i])`.
///
/// Falls back to `f32::ln` on the host when `log_kernel` produces no GPU
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::ln),
    }
}
13103
/// Elementwise square root: `out[i] = sqrt(a[i])`.
///
/// Falls back to `f32::sqrt` on the host when `sqrt_kernel` produces no GPU
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::sqrt),
    }
}
13113
/// Elementwise power with a scalar exponent: `out[i] = a[i] ^ exponent`.
///
/// Launched directly because the kernel takes `exponent` as an extra scalar
/// argument. Falls back to `f32::powf` on the host if `get_or_compile` fails.
#[cfg(feature = "cuda")]
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(a, device)?;
        let result: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `out` was just allocated with `len` elements. Unlike the ELU and
    // clamp launches, this kernel receives the scalar *before* the length:
    // (input, output, exponent, n). That order is assumed to match
    // `pow_kernel`; TODO confirm.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13159
/// Elementwise absolute value: `out[i] = |a[i]|`.
///
/// Falls back to `f32::abs` on the host when `abs_kernel` produces no GPU
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::abs),
    }
}
13169
/// Elementwise logistic sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
///
/// Evaluated on the host when `sigmoid_kernel` produces no GPU result.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |v| 1.0 / (1.0 + (-v).exp())),
    }
}
13179
/// Elementwise hyperbolic tangent: `out[i] = tanh(a[i])`.
///
/// Falls back to `f32::tanh` on the host when `tanh_kernel` produces no GPU
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::tanh),
    }
}
13189
13190// ---------------------------------------------------------------------------
13191// Public API -- f64 elementwise ops
13192// ---------------------------------------------------------------------------
13193
/// Elementwise f64 addition: `out[i] = a[i] + b[i]`.
///
/// The f64 PTX is obtained via `get_f64_ptx` from the f32 `ADD_PTX` source,
/// using a function-local `OnceLock` as its cache. The conversion presumably
/// happens once, on first use.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&PTX_CACHE, ADD_PTX, "add_kernel", "add_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "add_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs + rhs),
    }
}
13211
/// Elementwise f64 subtraction: `out[i] = a[i] - b[i]`.
///
/// The f64 PTX is obtained via `get_f64_ptx` from the f32 `SUB_PTX` source,
/// using a function-local `OnceLock` as its cache.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&PTX_CACHE, SUB_PTX, "sub_kernel", "sub_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "sub_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs - rhs),
    }
}
13229
/// Elementwise f64 multiplication: `out[i] = a[i] * b[i]`.
///
/// The f64 PTX is obtained via `get_f64_ptx` from the f32 `MUL_PTX` source,
/// using a function-local `OnceLock` as its cache.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&PTX_CACHE, MUL_PTX, "mul_kernel", "mul_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "mul_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs * rhs),
    }
}
13247
/// Elementwise f64 division: `out[i] = a[i] / b[i]`.
///
/// The f64 PTX is obtained via `get_f64_ptx` from the f32 `DIV_PTX` source,
/// using a function-local `OnceLock` as its cache.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&PTX_CACHE, DIV_PTX, "div_kernel", "div_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "div_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs / rhs),
    }
}
13265
/// Elementwise f64 negation: `out[i] = -a[i]`.
///
/// Uses the f64 PTX derived (via `get_f64_ptx`) from the f32 `NEG_PTX`
/// source, and negates on the host if no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_neg_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&PTX_CACHE, NEG_PTX, "neg_kernel", "neg_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "neg_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| -v),
    }
}
13276
/// Elementwise f64 ReLU: `out[i] = max(a[i], 0.0)`.
///
/// Uses the f64 PTX derived (via `get_f64_ptx`) from the f32 `RELU_PTX`
/// source, and applies `max(x, 0.0)` on the host if no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_relu_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&PTX_CACHE, RELU_PTX, "relu_kernel", "relu_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "relu_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| v.max(0.0)),
    }
}
13287
/// Elementwise f64 scaling by a scalar: `out[i] = a[i] * scalar`.
///
/// The kernel is the f64 conversion (via `get_f64_ptx`) of the f32
/// `SCALE_PTX`. It is launched directly because it takes the scalar as an
/// extra argument. If `get_or_compile` fails, the product is computed on the
/// host.
#[cfg(feature = "cuda")]
pub fn gpu_scale_f64(
    a: &CudaBuffer<f64>,
    scalar: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&PTX_CACHE, SCALE_PTX, "scale_kernel", "scale_f64_kernel");
    match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scale_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let cfg = launch_cfg(len)?;
            let len_u32 = len as u32;

            // SAFETY: `out` was just allocated with `len` elements. Arguments
            // are pushed as (input, output, scalar, n), assumed to match the
            // converted `scale_f64_kernel` signature; TODO confirm.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(a.inner())
                    .arg(out.inner_mut())
                    .arg(&scalar)
                    .arg(&len_u32)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(a, device)?;
            let scaled: Vec<f64> = host.iter().map(|&x| x * scalar).collect();
            cpu_to_gpu(&scaled, device)
        }
    }
}
13326
/// Elementwise f64 exponential: `out[i] = exp(a[i])`.
///
/// Uses the dedicated `EXP_F64_PTX` kernel rather than a mechanical f32→f64
/// conversion. Transcendentals need hand-written f64 PTX. Falls back to the
/// host when no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_exp_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, EXP_F64_PTX, "exp_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| v.exp()),
    }
}
13335
/// Elementwise f64 natural logarithm: `out[i] = ln(a[i])`.
///
/// Uses the dedicated `LOG_F64_PTX` kernel and falls back to the host when no
/// GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_log_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, LOG_F64_PTX, "log_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| v.ln()),
    }
}
13344
/// Elementwise f64 square root: `out[i] = sqrt(a[i])`.
///
/// Uses the f64 PTX derived (via `get_f64_ptx`) from the f32 `SQRT_PTX`
/// source, and falls back to the host when no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_sqrt_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&PTX_CACHE, SQRT_PTX, "sqrt_kernel", "sqrt_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "sqrt_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| v.sqrt()),
    }
}
13355
/// Elementwise f64 power with a scalar exponent: `out[i] = a[i] ^ exponent`.
///
/// Launches the dedicated `POW_F64_PTX` kernel directly, since it takes
/// `exponent` as an extra argument. Falls back to `f64::powf` on the host if
/// `get_or_compile` fails.
#[cfg(feature = "cuda")]
pub fn gpu_pow_f64(
    a: &CudaBuffer<f64>,
    exponent: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    match crate::module_cache::get_or_compile(
        ctx,
        POW_F64_PTX,
        "pow_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let cfg = launch_cfg(len)?;
            let len_u32 = len as u32;

            // SAFETY: `out` was just allocated with `len` elements. Arguments
            // are pushed as (input, output, exponent, n), assumed to match
            // `pow_f64_kernel`; TODO confirm.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(a.inner())
                    .arg(out.inner_mut())
                    .arg(&exponent)
                    .arg(&len_u32)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(a, device)?;
            let powered: Vec<f64> = host.iter().map(|&x| x.powf(exponent)).collect();
            cpu_to_gpu(&powered, device)
        }
    }
}
13392
/// Elementwise f64 absolute value: `out[i] = |a[i]|`.
///
/// Uses the f64 PTX derived (via `get_f64_ptx`) from the f32 `ABS_PTX`
/// source, and falls back to the host when no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_abs_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&PTX_CACHE, ABS_PTX, "abs_kernel", "abs_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "abs_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| v.abs()),
    }
}
13403
/// Elementwise f64 logistic sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
///
/// Uses the dedicated `SIGMOID_F64_PTX` kernel and evaluates the formula on
/// the host when no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, SIGMOID_F64_PTX, "sigmoid_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| 1.0 / (1.0 + (-v).exp())),
    }
}
13412
/// Elementwise f64 hyperbolic tangent: `out[i] = tanh(a[i])`.
///
/// Uses the dedicated `TANH_F64_PTX` kernel and falls back to the host when
/// no GPU result is produced.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, TANH_F64_PTX, "tanh_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |v| v.tanh()),
    }
}
13421
13422// ---------------------------------------------------------------------------
13423// Public API -- f64 backward ops
13424// ---------------------------------------------------------------------------
13425
/// ReLU gradient (f64): `out[i] = grad[i]` where `input[i] > 0`, else `0`.
///
/// Uses the f64 PTX derived (via `get_f64_ptx`) from the f32
/// `RELU_BACKWARD_PTX` source, and falls back to the host when no GPU result
/// is produced.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] if `grad` and `input` differ in
/// length.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (n_grad, n_input) = (grad.len(), input.len());
    if n_grad != n_input {
        return Err(GpuError::LengthMismatch { a: n_grad, b: n_input });
    }

    let ptx = get_f64_ptx(
        &PTX_CACHE,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
        "relu_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, input, device, ptx, "relu_backward_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| if x > 0.0 { g } else { 0.0 }),
    }
}
13449
/// Sigmoid backward (f64): `out[i] = grad[i] * output[i] * (1 - output[i])`.
///
/// `output` is the forward sigmoid result, so the derivative is expressed in terms of
/// it rather than the original input.
///
/// # Errors
/// Returns `GpuError::LengthMismatch { a: grad.len(), b: output.len() }` when the two
/// buffers differ in length. Launch/transfer errors from the helpers are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    const KERNEL: &str = "sigmoid_backward_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (grad_len, output_len) = (grad.len(), output.len());
    if grad_len != output_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: output_len });
    }

    let source = get_f64_ptx(&F64_PTX, SIGMOID_BACKWARD_PTX, "sigmoid_backward_kernel", KERNEL);
    match try_launch_binary_f64(grad, output, device, source, KERNEL)? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * o * (1.0 - o)),
    }
}
13473
/// Tanh backward (f64): `out[i] = grad[i] * (1 - output[i]^2)`.
///
/// `output` is the forward tanh result, so the derivative is expressed in terms of it.
///
/// # Errors
/// Returns `GpuError::LengthMismatch { a: grad.len(), b: output.len() }` when the two
/// buffers differ in length. Launch/transfer errors from the helpers are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    const KERNEL: &str = "tanh_backward_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (grad_len, output_len) = (grad.len(), output.len());
    if grad_len != output_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: output_len });
    }

    let source = get_f64_ptx(&F64_PTX, TANH_BACKWARD_PTX, "tanh_backward_kernel", KERNEL);
    match try_launch_binary_f64(grad, output, device, source, KERNEL)? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * (1.0 - o * o)),
    }
}
13497
13498// ---------------------------------------------------------------------------
13499// Public API -- f64 broadcast ops
13500// ---------------------------------------------------------------------------
13501
/// Broadcast addition (f64): `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
///
/// `a_shape` and `b_shape` are each mapped onto `out_shape` with `broadcast_strides`.
/// The output dimensions are handed to the kernel as `u32`, and the output holds
/// `product(out_shape)` elements. When the launch helper yields no GPU result (`None`),
/// the host broadcast fallback is used.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    const KERNEL: &str = "broadcast_add_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let a_strides = broadcast_strides(a_shape, out_shape);
    let b_strides = broadcast_strides(b_shape, out_shape);
    let out_dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let out_numel = out_shape.iter().product::<usize>();

    let source = get_f64_ptx(&F64_PTX, BROADCAST_ADD_PTX, "broadcast_add_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &a_strides, &b_strides, &out_dims, out_numel, device, source, KERNEL,
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| {
                x + y
            })
        }
    }
}
13535
/// Broadcast subtraction (f64): `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
///
/// `a_shape` and `b_shape` are each mapped onto `out_shape` with `broadcast_strides`.
/// The output dimensions are handed to the kernel as `u32`, and the output holds
/// `product(out_shape)` elements. When the launch helper yields no GPU result (`None`),
/// the host broadcast fallback is used.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    const KERNEL: &str = "broadcast_sub_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let a_strides = broadcast_strides(a_shape, out_shape);
    let b_strides = broadcast_strides(b_shape, out_shape);
    let out_dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let out_numel = out_shape.iter().product::<usize>();

    let source = get_f64_ptx(&F64_PTX, BROADCAST_SUB_PTX, "broadcast_sub_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &a_strides, &b_strides, &out_dims, out_numel, device, source, KERNEL,
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| {
                x - y
            })
        }
    }
}
13569
/// Broadcast multiplication (f64): `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
///
/// `a_shape` and `b_shape` are each mapped onto `out_shape` with `broadcast_strides`.
/// The output dimensions are handed to the kernel as `u32`, and the output holds
/// `product(out_shape)` elements. When the launch helper yields no GPU result (`None`),
/// the host broadcast fallback is used.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    const KERNEL: &str = "broadcast_mul_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let a_strides = broadcast_strides(a_shape, out_shape);
    let b_strides = broadcast_strides(b_shape, out_shape);
    let out_dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let out_numel = out_shape.iter().product::<usize>();

    let source = get_f64_ptx(&F64_PTX, BROADCAST_MUL_PTX, "broadcast_mul_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &a_strides, &b_strides, &out_dims, out_numel, device, source, KERNEL,
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| {
                x * y
            })
        }
    }
}
13603
/// Broadcast division (f64): `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
///
/// `a_shape` and `b_shape` are each mapped onto `out_shape` with `broadcast_strides`.
/// The output dimensions are handed to the kernel as `u32`, and the output holds
/// `product(out_shape)` elements. When the launch helper yields no GPU result (`None`),
/// the host broadcast fallback is used.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    const KERNEL: &str = "broadcast_div_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let a_strides = broadcast_strides(a_shape, out_shape);
    let b_strides = broadcast_strides(b_shape, out_shape);
    let out_dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let out_numel = out_shape.iter().product::<usize>();

    let source = get_f64_ptx(&F64_PTX, BROADCAST_DIV_PTX, "broadcast_div_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &a_strides, &b_strides, &out_dims, out_numel, device, source, KERNEL,
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| {
                x / y
            })
        }
    }
}
13637
13638// ---------------------------------------------------------------------------
13639// Public API -- f64 reduction ops
13640// ---------------------------------------------------------------------------
13641
/// Full f64 reduction: returns a one-element GPU buffer holding the sum of all elements of `a`.
///
/// Strategy:
/// 1. An empty input short-circuits to a buffer containing `0.0`.
/// 2. If the f64 reduce kernel cannot be obtained from the module cache, the whole input
///    is copied to the host and summed there.
/// 3. Otherwise a single launch of at most 1024 blocks × 256 threads fills a buffer of
///    `blocks` partial sums. One partial is returned as-is; 2..=256 partials are finished
///    on the host; more than that are reduced again by a recursive call.
///
/// NOTE(review): because the grid is capped at 1024 blocks, covering inputs longer than
/// `1024 * 256` relies on the kernel looping over its range — confirm against
/// `REDUCE_SUM_PTX`. The element count is passed to the kernel as `u32`.
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    const KERNEL: &str = "reduce_sum_f64_kernel";
    const THREADS_PER_BLOCK: u32 = 256;
    const MAX_BLOCKS: u32 = 1024;
    const HOST_FINISH_MAX: u32 = 256;
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = a.len();
    if len == 0 {
        return cpu_to_gpu(&[0.0f64], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let source = get_f64_ptx(&F64_PTX, REDUCE_SUM_PTX, "reduce_sum_kernel", KERNEL);
    let Ok(func) =
        crate::module_cache::get_or_compile(ctx, source, KERNEL, device.ordinal() as u32)
    else {
        let total: f64 = gpu_to_cpu(a, device)?.iter().sum();
        return cpu_to_gpu(&[total], device);
    };

    let blocks = ((len as u32).saturating_add(THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK)
        .min(MAX_BLOCKS);
    let mut partials = alloc_zeros_f64(blocks as usize, device)?;
    let len_u32 = len as u32;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (blocks.max(1), 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };

    // SAFETY: the argument order (input, partials, element count) is assumed to match
    // the reduce-sum kernel's parameter list; `partials` was allocated with `blocks`
    // elements, one slot per launched block.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    match blocks {
        0 | 1 => Ok(partials),
        2..=HOST_FINISH_MAX => {
            let total: f64 = gpu_to_cpu(&partials, device)?.iter().sum();
            cpu_to_gpu(&[total], device)
        }
        _ => gpu_reduce_sum_f64(&partials, device),
    }
}
13708
/// Sum a contiguous f64 tensor along one axis.
///
/// `a` is viewed as `[outer, axis_size, inner]` and the result has `outer * inner`
/// elements, with
/// `out[o * inner + i] = sum_k a[o * axis_size * inner + k * inner + i]`.
/// If the f64 kernel cannot be obtained from the module cache, the same sum is computed
/// on the host.
///
/// # Errors
/// - Propagates any error from `validate_device` when `a` does not belong to `device`.
/// - `GpuError::LengthMismatch { a: a.len(), b: outer * axis_size * inner }` when `a`
///   is too short for the requested view.
/// - Transfer, allocation, launch-configuration, and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis_f64(
    a: &CudaBuffer<f64>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Same device guard the sibling shape ops use before touching the buffer.
    validate_device(a, device)?;

    // Outputs read across the whole `[outer, axis_size, inner]` view. Reject an
    // undersized input up front rather than reading past the end of `a` on the device
    // (or panicking on an out-of-range index in the host fallback).
    let required = outer * axis_size * inner;
    if a.len() < required {
        return Err(GpuError::LengthMismatch { a: a.len(), b: required });
    }

    let total_output = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SUM_AXIS_PTX, "sum_axis_kernel", "sum_axis_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "sum_axis_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f64; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let mut sum = 0.0f64;
                for k in 0..axis_size {
                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
                }
                *out = sum;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    // SAFETY: the argument order (input, output, outer, axis_size, inner, total) is
    // assumed to match the sum-axis kernel's parameter list. `a` was checked above to
    // cover the full view, and `out` holds `total_output` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13770
13771#[cfg(not(feature = "cuda"))]
13772pub fn gpu_reduce_sum_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13773#[cfg(not(feature = "cuda"))]
13774pub fn gpu_sum_axis_f64(_a: &CudaBuffer<f64>, _outer: usize, _axis_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13775
13776// ---------------------------------------------------------------------------
13777// Public API -- f64 shape ops
13778// ---------------------------------------------------------------------------
13779
/// Transpose a row-major `[M, N]` f64 matrix into a new `[N, M]` GPU buffer.
///
/// If the f64 kernel cannot be obtained from the module cache, the transpose is done on
/// the host and the result uploaded.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device`, as well as transfer,
/// allocation, launch-configuration, and launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_f64(
    input: &CudaBuffer<f64>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    const KERNEL: &str = "transpose_2d_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let source = get_f64_ptx(&F64_PTX, TRANSPOSE_2D_PTX, "transpose_2d_kernel", KERNEL);
    let Ok(func) =
        crate::module_cache::get_or_compile(ctx, source, KERNEL, device.ordinal() as u32)
    else {
        // Host fallback: fill the `[n, m]` result in its own order. Destination slot
        // `col * m + row` takes input element `row * n + col`.
        let host = gpu_to_cpu(input, device)?;
        let mut transposed = vec![0.0f64; total];
        for (dst, slot) in transposed.iter_mut().enumerate() {
            let (col, row) = (dst / m, dst % m);
            *slot = host[row * n + col];
        }
        return cpu_to_gpu(&transposed, device);
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let (rows, cols, count) = (m as u32, n as u32, total as u32);

    // SAFETY: the argument order (input, output, m, n, total) is assumed to match the
    // transpose kernel's parameter list; `out` holds `m * n` elements.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows)
            .arg(&cols)
            .arg(&count)
            .launch(cfg)?;
    }

    Ok(out)
}
13836
/// Permute a contiguous 4-D f64 tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]`
/// (swapping the two middle axes) into a new GPU buffer.
///
/// If the f64 kernel cannot be obtained from the module cache, the permutation is done
/// on the host and the result uploaded.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device`, as well as transfer,
/// allocation, launch-configuration, and launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_f64(
    input: &CudaBuffer<f64>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    const KERNEL: &str = "permute_0213_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let source = get_f64_ptx(&F64_PTX, PERMUTE_0213_PTX, "permute_0213_kernel", KERNEL);
    let Ok(func) =
        crate::module_cache::get_or_compile(ctx, source, KERNEL, device.ordinal() as u32)
    else {
        // Host fallback: walk the output in its own (d0, d2, d1, d3) order. Each
        // innermost run of `d3` values is contiguous in the input, so it is appended
        // with one slice copy.
        let host = gpu_to_cpu(input, device)?;
        let mut permuted = Vec::with_capacity(total);
        for i0 in 0..d0 {
            for i2 in 0..d2 {
                for i1 in 0..d1 {
                    let start = ((i0 * d1 + i1) * d2 + i2) * d3;
                    permuted.extend_from_slice(&host[start..start + d3]);
                }
            }
        }
        return cpu_to_gpu(&permuted, device);
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let dims = [d0 as u32, d1 as u32, d2 as u32, d3 as u32];
    let count = total as u32;

    // SAFETY: the argument order (input, output, d0, d1, d2, d3, total) is assumed to
    // match the permute kernel's parameter list; `out` holds `total` elements.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&dims[0])
            .arg(&dims[1])
            .arg(&dims[2])
            .arg(&dims[3])
            .arg(&count)
            .launch(cfg)?;
    }

    Ok(out)
}
13905
/// Extract one split of a contiguous f64 tensor along an axis (a strided read) into a
/// new GPU buffer of `n` elements.
///
/// The input is treated as `[outer, total_along_axis, inner_size]`. The split covers
/// axis positions `split_offset .. split_offset + split_size`, so output element `i`
/// reads from outer slice `i / (split_size * inner_size)`. If the f64 kernel cannot be
/// obtained from the module cache, the gather is done on the host.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device`, as well as transfer,
/// allocation, launch-configuration, and launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_split_f64(
    input: &CudaBuffer<f64>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    const KERNEL: &str = "strided_split_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let source = get_f64_ptx(&F64_PTX, STRIDED_SPLIT_PTX, "strided_split_kernel", KERNEL);
    let Ok(func) =
        crate::module_cache::get_or_compile(ctx, source, KERNEL, device.ordinal() as u32)
    else {
        let host = gpu_to_cpu(input, device)?;
        let pieces: Vec<f64> = (0..n)
            .map(|i| {
                let run = split_size * inner_size;
                let (outer_idx, within) = (i / run, i % run);
                host[outer_idx * total_along_axis * inner_size + split_offset * inner_size + within]
            })
            .collect();
        return cpu_to_gpu(&pieces, device);
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let axis_len = total_along_axis as u32;
    let offset = split_offset as u32;
    let split_len = split_size as u32;
    let inner_len = inner_size as u32;
    let count = n as u32;

    // SAFETY: the argument order (input, output, total_along_axis, offset, split_size,
    // inner_size, n) is assumed to match the strided-split kernel's parameter list;
    // `out` holds `n` elements.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&axis_len)
            .arg(&offset)
            .arg(&split_len)
            .arg(&inner_len)
            .arg(&count)
            .launch(cfg)?;
    }

    Ok(out)
}
13970
/// Scatter an f64 part (`n` elements) into a larger `output` at an axis offset on the GPU.
///
/// `output` is treated as `[outer, total_along_axis, inner_size]`. The part fills axis
/// positions `cat_offset .. cat_offset + part_size`. The kernel writes into `output` in
/// place. If the f64 kernel cannot be obtained from the module cache, the scatter is done
/// on the host and `output` is replaced by a freshly uploaded buffer.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device` (checked for `input`), as
/// well as transfer, launch-configuration, and launch errors.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat_f64(
    input: &CudaBuffer<f64>,
    output: &mut CudaBuffer<f64>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    const KERNEL: &str = "strided_cat_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let source = get_f64_ptx(&F64_PTX, STRIDED_CAT_PTX, "strided_cat_kernel", KERNEL);
    let Ok(func) =
        crate::module_cache::get_or_compile(ctx, source, KERNEL, device.ordinal() as u32)
    else {
        let host_in = gpu_to_cpu(input, device)?;
        let mut host_out = gpu_to_cpu(output, device)?;
        for (i, &val) in host_in.iter().take(n).enumerate() {
            let run = part_size * inner_size;
            let dst = (i / run) * total_along_axis * inner_size + cat_offset * inner_size + i % run;
            host_out[dst] = val;
        }
        *output = cpu_to_gpu(&host_out, device)?;
        return Ok(());
    };

    let cfg = launch_cfg(n)?;
    let axis_len = total_along_axis as u32;
    let offset = cat_offset as u32;
    let part_len = part_size as u32;
    let inner_len = inner_size as u32;
    let count = n as u32;

    // SAFETY: the argument order (input, output, total_along_axis, offset, part_size,
    // inner_size, n) is assumed to match the strided-cat kernel's parameter list; the
    // caller is trusted to size `output` for the destination indices.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&axis_len)
            .arg(&offset)
            .arg(&part_len)
            .arg(&inner_len)
            .arg(&count)
            .launch(cfg)?;
    }

    Ok(())
}
14037
14038// ---------------------------------------------------------------------------
14039// Public API -- f64 indexing ops
14040// ---------------------------------------------------------------------------
14041
/// Gather f64 elements by f32-encoded indices: `out[i] = input[indices[i]]`.
///
/// The output has `indices.len()` elements. If the f64 kernel cannot be obtained from the
/// module cache, the gather runs on the host, where each f32 index is converted with an
/// `as usize` cast.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device` (checked for `input`), as
/// well as transfer, allocation, launch-configuration, and launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d_f64(
    input: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    const KERNEL: &str = "index_select_1d_f64_kernel";
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let count = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    let source = get_f64_ptx(&F64_PTX, INDEX_SELECT_1D_PTX, "index_select_1d_kernel", KERNEL);
    let Ok(func) =
        crate::module_cache::get_or_compile(ctx, source, KERNEL, device.ordinal() as u32)
    else {
        let input_host = gpu_to_cpu(input, device)?;
        let gathered: Vec<f64> = gpu_to_cpu(indices, device)?
            .into_iter()
            .map(|idx| input_host[idx as usize])
            .collect();
        return cpu_to_gpu(&gathered, device);
    };

    let mut out = alloc_zeros_f64(count, device)?;
    let cfg = launch_cfg(count)?;
    let count_u32 = count as u32;

    // SAFETY: the argument order (input, indices, output, n) is assumed to match the
    // index-select kernel's parameter list; `out` holds `indices.len()` elements.
    // Index values are not range-checked here.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&count_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14093
/// Scatter-add f64 `grad_output` back into a zeroed buffer using f32-encoded `indices`.
///
/// Output: `out = zeros(input_len); for i: out[indices[i]] += grad_output[i]`.
/// If the f64 kernel cannot be obtained from the module cache, the accumulation runs on
/// the host.
///
/// # Errors
/// - Propagates device-mismatch errors from `validate_device` (checked for `grad_output`).
/// - `GpuError::LengthMismatch { a: grad_output.len(), b: indices.len() }` when the two
///   buffers differ in length. Each gradient element must pair with exactly one index.
/// - Transfer, allocation, launch-configuration, and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad_output, device)?;

    let n = grad_output.len();
    // The kernel runs over `n = grad_output.len()` elements and reads `indices[i]` for
    // each one, so a shorter index buffer would be read out of bounds on the device.
    // The host fallback would instead panic on a longer one. Require equal lengths, as
    // the binary backward ops do.
    if indices.len() != n {
        return Err(GpuError::LengthMismatch { a: n, b: indices.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_1D_PTX, "scatter_add_1d_kernel", "scatter_add_1d_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_1d_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f64; input_len];
            for (&idx_f, &g) in idx_host.iter().zip(go_host.iter()) {
                result[idx_f as usize] += g;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: the argument order (grad_output, indices, output, n) is assumed to match
    // the scatter-add kernel's parameter list. Both inputs hold `n` elements (checked
    // above). Index values are not range-checked against `input_len` here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14148
/// Fill f64 elements with `value` where the u8 `mask` is nonzero.
///
/// `mask` is a GPU buffer of u8 values (nonzero = true) with one entry per element of
/// `input`. Output: `out[i] = mask[i] != 0 ? value : input[i]`. If the f64 kernel cannot
/// be obtained from the module cache, the fill runs on the host.
///
/// # Errors
/// - Propagates device-mismatch errors from `validate_device` (checked for `input`).
/// - `GpuError::LengthMismatch { a: input.len(), b: mask.len() }` when the mask does not
///   have exactly one entry per input element.
/// - Transfer, allocation, launch-configuration, and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill_f64(
    input: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    value: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n = input.len();
    // The kernel is launched over `n = input.len()` elements and consults `mask[i]` for
    // each, so a shorter mask would be read out of bounds on the device. The host
    // fallback's `zip` would also silently truncate the output.
    if mask.len() != n {
        return Err(GpuError::LengthMismatch { a: n, b: mask.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, MASKED_FILL_PTX, "masked_fill_kernel", "masked_fill_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_fill_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f64> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m != 0 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: the argument order (input, mask, output, value, n) is assumed to match the
    // masked-fill kernel's parameter list. `input`, `mask`, and `out` all hold `n`
    // elements (mask length checked above).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14206
/// Zero out f64 gradient entries where the u8 `mask` is nonzero.
///
/// `mask` has one u8 entry per element of `grad` (nonzero = true).
/// Output: `out[i] = mask[i] != 0 ? 0.0 : grad[i]`. If the f64 kernel cannot be obtained
/// from the module cache, the masking runs on the host.
///
/// # Errors
/// - Propagates device-mismatch errors from `validate_device` (checked for `grad`).
/// - `GpuError::LengthMismatch { a: grad.len(), b: mask.len() }` when the mask does not
///   have exactly one entry per gradient element.
/// - Transfer, allocation, launch-configuration, and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero_f64(
    grad: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad, device)?;

    let n = grad.len();
    // The kernel is launched over `n = grad.len()` elements and consults `mask[i]` for
    // each, so a shorter mask would be read out of bounds on the device. The host
    // fallback's `zip` would also silently truncate the output.
    if mask.len() != n {
        return Err(GpuError::LengthMismatch { a: n, b: mask.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, MASKED_ZERO_PTX, "masked_zero_kernel", "masked_zero_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_zero_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f64> = grad_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&g, &m)| if m != 0 { 0.0 } else { g })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: the argument order (grad, mask, output, n) is assumed to match the
    // masked-zero kernel's parameter list. `grad`, `mask`, and `out` all hold `n`
    // elements (mask length checked above).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14261
/// Write f64 `src` of shape `[N, D]` into row `pos` of every batch of `dst`, which has
/// shape `[N, max_len, D]`: `dst[b, pos, :] = src[b, :]` for each `b < n_batch`.
///
/// The kernel updates `dst` in place. If the f64 kernel cannot be obtained from the
/// module cache, the write is done on the host and `dst` is replaced by a freshly
/// uploaded buffer holding the updated contents.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device` for both `src` and `dst`, as
/// well as transfer, launch-configuration, and launch errors.
///
/// NOTE(review): `pos < max_len`, `src.len() >= n_batch * d` and
/// `dst.len() >= n_batch * max_len * d` are not checked here; callers must guarantee them.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_f64(
    src: &CudaBuffer<f64>,
    dst: &mut CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Both buffers are handed to a kernel launched on `device`'s stream, so guard them
    // the same way the strided shape ops guard their inputs.
    validate_device(src, device)?;
    validate_device(&*dst, device)?;

    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SLICE_WRITE_PTX, "slice_write_kernel", "slice_write_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_write_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // SAFETY: the argument order (src, dst, n_batch * d, d, max_len, pos) is assumed to
    // match the slice-write kernel's parameter list; buffer sizes and `pos < max_len`
    // are the caller's responsibility (see the NOTE above).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14322
/// Read the first `len` rows from each batch of an f64 `[N, max_len, D]` buffer into a
/// new `[N, len, D]` buffer: `out[b, r, :] = src[b, r, :]` for `r < len`.
///
/// If the f64 kernel cannot be obtained from the module cache, the copy is done on the
/// host and the result uploaded.
///
/// # Errors
/// Propagates device-mismatch errors from `validate_device` for `src`, as well as
/// transfer, allocation, launch-configuration, and launch errors.
///
/// NOTE(review): `len <= max_len` and `src.len() >= n_batch * max_len * d` are not
/// checked here; callers must guarantee them.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_f64(
    src: &CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // `src` is read by a kernel launched on `device`'s stream; guard it like the other
    // shape ops guard their inputs.
    validate_device(src, device)?;

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SLICE_READ_PTX, "slice_read_kernel", "slice_read_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_read_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f64; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // SAFETY: the argument order (src, output, total, d, len, max_len) is assumed to
    // match the slice-read kernel's parameter list; `out` holds `n_batch * len * d`
    // elements, and the size of `src` is the caller's responsibility (see the NOTE above).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14383
14384// ---------------------------------------------------------------------------
14385// Public API -- f64 embedding ops
14386// ---------------------------------------------------------------------------
14387
/// Look up one row of an f64 embedding table: `output[j] = weight[token_id * d + j]`.
///
/// `token_id` is taken from `idx[0]` (stored as `f32`, truncated to `usize`)
/// and the `d` values of that row are returned in a freshly allocated device
/// buffer. If the f64 PTX module cannot be compiled, the row is copied on the
/// host instead.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_f64(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static F64_PTX: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &F64_PTX,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        "embed_lookup_f64_kernel",
    );
    let ordinal = device.ordinal() as u32;
    let compiled =
        crate::module_cache::get_or_compile(ctx, ptx, "embed_lookup_f64_kernel", ordinal);

    let Ok(kernel) = compiled else {
        // Host fallback: only the first entry of `idx` is consulted.
        let ids = gpu_to_cpu(idx, device)?;
        let table = gpu_to_cpu(weight, device)?;
        let first = (ids[0] as usize) * d;
        let row = table[first..first + d].to_vec();
        return cpu_to_gpu(&row, device);
    };

    let d_u32 = d as u32;
    let mut out = alloc_zeros_f64(d, device)?;
    let cfg = launch_cfg(d)?;

    // SAFETY: relies on the pushed arguments (idx, weight, out, d) matching the
    // PTX kernel's parameter list; buffer extents are not checked here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14436
/// Batch f64 embedding lookup: gather `n` rows from a `[V, D]` weight table
/// into a new `[n, D]` buffer.
///
/// Output row `i` is `weight[indices[i] * d .. indices[i] * d + d]`, where the
/// indices are stored as `f32` and truncated to `usize`. Only the first `n`
/// entries of `indices` are used. If the f64 PTX module cannot be compiled, the
/// gather runs on the host instead.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `indices` holds fewer than `n` entries
///   (`a` = `n`, `b` = `indices.len()`).
/// - Any error from allocation, host/device transfer, or the kernel launch.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch_f64(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f64(0, device);
    }

    // The launch covers `n * d` output elements, so at least `n` indices are
    // required for every output row to have a source row.
    if indices.len() < n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: indices.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel", "embed_lookup_batch_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_batch_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            // Use exactly the first `n` indices so the fallback produces the
            // same `[n, d]` result as the GPU path, even if `indices` is longer.
            for &idx_f in idx_host.iter().take(n) {
                let start = (idx_f as usize) * d;
                out.extend_from_slice(&weight_host[start..start + d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14496
/// Scatter-add f64 rows for embedding backward.
///
/// Returns a new zero-initialised `[num_embeddings, d]` gradient table in which,
/// for every `i` in `0..indices.len()`,
/// `grad_weight[indices[i], :] += grad_output[i, :]`.
/// Indices are stored as `f32` and truncated to `usize`. Repeated indices
/// accumulate into the same row. The GPU kernel is presumed to use atomic adds
/// for this; confirm against `SCATTER_ADD_ROWS_PTX`. If the f64 PTX module
/// cannot be compiled, the accumulation runs on the host instead.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `grad_output` holds fewer than
///   `indices.len() * d` elements (`a` = required, `b` = actual).
/// - Any error from allocation, host/device transfer, or the kernel launch.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let n = indices.len();
    let total = n * d;

    // No rows to scatter: the gradient table is all zeros.
    if total == 0 {
        return alloc_zeros_f64(num_embeddings * d, device);
    }

    // Row `i` of `grad_output` is read for every index `i`.
    if grad_output.len() < total {
        return Err(GpuError::LengthMismatch {
            a: total,
            b: grad_output.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel", "scatter_add_rows_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_rows_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f64; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                let row = idx_f as usize;
                for j in 0..d {
                    result[row * d + j] += go_host[i * d + j];
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14561
14562// ---------------------------------------------------------------------------
14563// Public API -- fused Adam optimizer step
14564// ---------------------------------------------------------------------------
14565
/// Fused Adam optimizer step: updates `param`, `exp_avg`, and `exp_avg_sq`
/// in place with a single kernel launch.
///
/// Per element, with `g = grad + weight_decay * param` when `weight_decay > 0`
/// (plain `grad` otherwise):
///
/// ```text
/// exp_avg    = beta1 * exp_avg    + (1 - beta1) * g
/// exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
/// param     -= lr * (exp_avg / bc1) / (sqrt(exp_avg_sq / bc2) + eps)
/// ```
///
/// `bc1`/`bc2` are bias-correction divisors supplied by the caller (presumably
/// `1 - beta1^t` and `1 - beta2^t`). `grad` is read-only. If the PTX module
/// cannot be compiled, the update is computed on the host and the three state
/// buffers are replaced with freshly uploaded copies.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if any of `grad`, `exp_avg`, or `exp_avg_sq`
///   differs in length from `param` (`a` = `param.len()`, `b` = the first
///   mismatching length).
/// - Any error from host/device transfer or the kernel launch.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)] // mirrors the kernel's scalar hyper-parameters
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    // Report the length that actually disagrees, not always `grad.len()`.
    if let Some(bad_len) = [grad.len(), exp_avg.len(), exp_avg_sq.len()]
        .iter()
        .copied()
        .find(|&len| len != n)
    {
        return Err(GpuError::LengthMismatch { a: n, b: bad_len });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: download, compute, upload.
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

            for (((p, &g_raw), m), v) in p_host
                .iter_mut()
                .zip(&g_host)
                .zip(m_host.iter_mut())
                .zip(v_host.iter_mut())
            {
                // Weight decay is folded into the gradient using the
                // pre-update parameter value.
                let g = if weight_decay > 0.0 {
                    g_raw + weight_decay * *p
                } else {
                    g_raw
                };
                *m = beta1 * *m + (1.0 - beta1) * g;
                *v = beta2 * *v + (1.0 - beta2) * g * g;
                let m_hat = *m / bc1;
                let v_hat = *v / bc2;
                *p -= lr * m_hat / (v_hat.sqrt() + eps);
            }

            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14656
14657/// Stub -- always returns [`GpuError::NoCudaFeature`].
14658#[cfg(not(feature = "cuda"))]
14659#[allow(clippy::too_many_arguments)]
14660pub fn gpu_fused_adam(
14661    _param: &mut CudaBuffer<f32>,
14662    _grad: &CudaBuffer<f32>,
14663    _exp_avg: &mut CudaBuffer<f32>,
14664    _exp_avg_sq: &mut CudaBuffer<f32>,
14665    _beta1: f32,
14666    _beta2: f32,
14667    _lr: f32,
14668    _eps: f32,
14669    _bc1: f32,
14670    _bc2: f32,
14671    _weight_decay: f32,
14672    _device: &GpuDevice,
14673) -> GpuResult<()> {
14674    Err(GpuError::NoCudaFeature)
14675}
14676
14677// ---------------------------------------------------------------------------
14678// Public API -- fused GRU cell
14679// ---------------------------------------------------------------------------
14680
/// Fused GRU cell forward: takes pre-computed gate matrices and produces
/// the new hidden state plus a workspace for backward.
///
/// Inputs:
/// - `input_gates`: `[batch, 3*hsz]` — result of `x @ W_ih^T`
/// - `hidden_gates`: `[batch, 3*hsz]` — result of `h @ W_hh^T`
/// - `bias_ih`: `[3*hsz]` — input bias
/// - `bias_hh`: `[3*hsz]` — hidden bias
/// - `hx`: `[batch, hsz]` — previous hidden state; `batch` is derived as
///   `hx.len() / hsz`
///
/// Outputs:
/// - `hy`: `[batch, hsz]` — new hidden state
/// - `workspace`: `[batch, 5*hsz]` — saved for backward (per the kernel's
///   design: r, z, n, hx, hn+b2n)
///
/// There is no CPU fallback for this op.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `hsz` is zero or does not evenly divide
///   `hx.len()` (`a` = `hx.len()`, `b` = `hsz`), or if a gate or bias buffer is
///   shorter than its documented shape (`a` = required, `b` = actual).
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
/// - Any error from allocation or the kernel launch.
#[cfg(feature = "cuda")]
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    let total = hx.len(); // batch * hsz
    // `hsz == 0` would divide by zero below, and a hidden state that is not a
    // whole number of rows means the caller's shapes disagree.
    if hsz == 0 || total % hsz != 0 {
        return Err(GpuError::LengthMismatch { a: total, b: hsz });
    }
    let batch = total / hsz;

    // Reject buffers too short for their documented shapes so nothing is read
    // past the end of a device allocation.
    let gates_len = batch * 3 * hsz;
    let bias_len = 3 * hsz;
    for (needed, have) in [
        (gates_len, input_gates.len()),
        (gates_len, hidden_gates.len()),
        (bias_len, bias_ih.len()),
        (bias_len, bias_hh.len()),
    ] {
        if have < needed {
            return Err(GpuError::LengthMismatch { a: needed, b: have });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "fused_gru_forward_kernel",
            });
        }
    };

    let mut hy = alloc_zeros_f32(total, device)?;
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
14750
14751/// Stub -- always returns [`GpuError::NoCudaFeature`].
14752#[cfg(not(feature = "cuda"))]
14753pub fn gpu_fused_gru_forward(
14754    _input_gates: &CudaBuffer<f32>,
14755    _hidden_gates: &CudaBuffer<f32>,
14756    _bias_ih: &CudaBuffer<f32>,
14757    _bias_hh: &CudaBuffer<f32>,
14758    _hx: &CudaBuffer<f32>,
14759    _hsz: usize,
14760    _device: &GpuDevice,
14761) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
14762    Err(GpuError::NoCudaFeature)
14763}
14764
14765// ---------------------------------------------------------------------------
14766// Public API -- MaxPool2d / AvgPool2d
14767// ---------------------------------------------------------------------------
14768
/// MaxPool2d forward on GPU. One thread per output element.
///
/// The output spatial size follows the floor formula
/// `h_out = (h_in + 2*ph - kh) / sh + 1` (likewise for the width). Returns the
/// output buffer together with its shape `[batch, channels, h_out, w_out]`.
/// The input is presumably laid out NCHW (`[batch, channels, h_in, w_in]`);
/// confirm against `MAXPOOL2D_PTX`. There is no CPU fallback.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if a kernel or stride dimension is zero, or if
///   the kernel is larger than the padded input. `expected` holds
///   `[kh, kw, sh, sw]` and `got` holds the padded extent
///   `[h_in + 2*ph, w_in + 2*pw]`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `batch * channels * h_in * w_in` elements.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
/// - Any error from allocation or the kernel launch.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)] // one parameter per pooling hyper-parameter
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Output extent along one spatial axis. `None` when the kernel or stride is
    // zero or the window is larger than the padded input: the unchecked formula
    // would then divide by zero or underflow `usize`.
    let out_extent = |size: usize, k: usize, s: usize, p: usize| -> Option<usize> {
        if k == 0 || s == 0 {
            return None;
        }
        (size + 2 * p).checked_sub(k).map(|span| span / s + 1)
    };
    let (Some(h_out), Some(w_out)) =
        (out_extent(h_in, kh, sh, ph), out_extent(w_in, kw, sw, pw))
    else {
        return Err(GpuError::ShapeMismatch {
            op: "maxpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in + 2 * ph, w_in + 2 * pw],
        });
    };

    let in_len = batch * channels * h_in * w_in;
    if input.len() < in_len {
        return Err(GpuError::LengthMismatch {
            a: in_len,
            b: input.len(),
        });
    }

    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
14829
14830/// Stub.
14831#[cfg(not(feature = "cuda"))]
14832#[allow(clippy::too_many_arguments)]
14833pub fn gpu_maxpool2d(
14834    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
14835    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
14836    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
14837    _device: &GpuDevice,
14838) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
14839    Err(GpuError::NoCudaFeature)
14840}
14841
/// AvgPool2d forward on GPU. One thread per output element.
///
/// The output spatial size follows the floor formula
/// `h_out = (h_in + 2*ph - kh) / sh + 1` (likewise for the width). Returns the
/// output buffer together with its shape `[batch, channels, h_out, w_out]`.
/// The input is presumably laid out NCHW (`[batch, channels, h_in, w_in]`);
/// confirm against `AVGPOOL2D_PTX`. There is no CPU fallback.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if a kernel or stride dimension is zero, or if
///   the kernel is larger than the padded input. `expected` holds
///   `[kh, kw, sh, sw]` and `got` holds the padded extent
///   `[h_in + 2*ph, w_in + 2*pw]`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `batch * channels * h_in * w_in` elements.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
/// - Any error from allocation or the kernel launch.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)] // one parameter per pooling hyper-parameter
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Output extent along one spatial axis. `None` when the kernel or stride is
    // zero or the window is larger than the padded input: the unchecked formula
    // would then divide by zero or underflow `usize`.
    let out_extent = |size: usize, k: usize, s: usize, p: usize| -> Option<usize> {
        if k == 0 || s == 0 {
            return None;
        }
        (size + 2 * p).checked_sub(k).map(|span| span / s + 1)
    };
    let (Some(h_out), Some(w_out)) =
        (out_extent(h_in, kh, sh, ph), out_extent(w_in, kw, sw, pw))
    else {
        return Err(GpuError::ShapeMismatch {
            op: "avgpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in + 2 * ph, w_in + 2 * pw],
        });
    };

    let in_len = batch * channels * h_in * w_in;
    if input.len() < in_len {
        return Err(GpuError::LengthMismatch {
            a: in_len,
            b: input.len(),
        });
    }

    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
14902
14903/// Stub.
14904#[cfg(not(feature = "cuda"))]
14905#[allow(clippy::too_many_arguments)]
14906pub fn gpu_avgpool2d(
14907    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
14908    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
14909    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
14910    _device: &GpuDevice,
14911) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
14912    Err(GpuError::NoCudaFeature)
14913}
14914
14915// ---------------------------------------------------------------------------
14916// Public API -- BatchNorm2d
14917// ---------------------------------------------------------------------------
14918
/// BatchNorm2d forward on GPU. Not implemented yet.
///
/// The kernel's pass-1 indexing still needs work, so the kernel is never
/// launched. This function asks `module_cache::get_or_compile` for
/// `batchnorm_forward_kernel` and discards the outcome, success or failure.
/// It then always returns a placeholder [`GpuError::ShapeMismatch`]
/// (`op: "batchnorm_forward"`, `expected: [0]`, `got: [1]`) so that callers
/// take their CPU path. Every argument other than `device` is ignored.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)] // matches the eventual full BatchNorm signature
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    let ctx = device.context();
    // Compile only for the side effect of exercising the PTX; the result is
    // intentionally unused.
    let _compile_attempt = crate::module_cache::get_or_compile(
        ctx,
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );

    // Placeholder error until the kernel implementation is finished.
    let placeholder = GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    };
    Err(placeholder)
}
14952
14953/// Stub.
14954#[cfg(not(feature = "cuda"))]
14955#[allow(clippy::too_many_arguments)]
14956pub fn gpu_batchnorm_forward(
14957    _input: &CudaBuffer<f32>,
14958    _weight: &CudaBuffer<f32>,
14959    _bias: &CudaBuffer<f32>,
14960    _running_mean: &mut CudaBuffer<f32>,
14961    _running_var: &mut CudaBuffer<f32>,
14962    _channels: usize,
14963    _spatial: usize,
14964    _eps: f32,
14965    _momentum: f32,
14966    _training: bool,
14967    _device: &GpuDevice,
14968) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
14969    Err(GpuError::NoCudaFeature)
14970}
14971
14972// ---------------------------------------------------------------------------
14973// Public API -- LayerNorm
14974// ---------------------------------------------------------------------------
14975
/// Row-wise layer normalization on GPU.
///
/// `input`: `[rows * cols]`, `weight`/`bias`: `[cols]`.
/// Output: `[rows * cols]`, where each row `x` becomes
/// `weight * (x - mean(x)) / sqrt(var(x) + eps) + bias`. The variance is the
/// biased one, i.e. divided by `cols`.
///
/// The GPU path launches one 256-thread block per row. If the PTX cannot be
/// compiled, this function:
/// - prints a diagnostic to stderr,
/// - dumps the PTX source to `layernorm_debug.ptx` in the system temp directory,
/// - computes the result on the host instead.
///
/// # Errors
///
/// - Any error from `validate_unary` on `input`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements or `weight`/`bias` fewer than `cols` (`a` = required,
///   `b` = actual).
/// - Any error from allocation, host/device transfer, or the kernel launch.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = rows * cols;
    // A short buffer would make the host fallback panic while slicing and
    // (presumably) make the kernel read past the allocation, so reject it here.
    for (needed, have) in [(total, input.len()), (cols, weight.len()), (cols, bias.len())] {
        if have < needed {
            return Err(GpuError::LengthMismatch { a: needed, b: have });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            // Use the platform temp dir rather than a hard-coded `/tmp`, and
            // only report a dump when the write actually succeeded.
            let dump_path = std::env::temp_dir().join("layernorm_debug.ptx");
            match std::fs::write(&dump_path, LAYERNORM_PTX) {
                Ok(()) => eprintln!(
                    "ferrotorch-gpu: dumped PTX to {} ({} bytes)",
                    dump_path.display(),
                    LAYERNORM_PTX.len()
                ),
                Err(io_err) => eprintln!(
                    "ferrotorch-gpu: failed to dump PTX to {} ({io_err})",
                    dump_path.display()
                ),
            }
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
                let var: f32 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row (at least one block even when `rows == 0`), 256
    // threads per block, with 256 f32 slots of shared memory.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15056
15057// ---------------------------------------------------------------------------
15058// Public API -- LayerNorm backward
15059// ---------------------------------------------------------------------------
15060
/// LayerNorm backward pass on GPU.
///
/// Computes `grad_input`, `grad_weight`, and `grad_bias`. The GPU path launches
/// one 256-thread block per row. `grad_weight` and `grad_bias` are summed over
/// all rows; the kernel is designed to accumulate them with atomic adds
/// (confirm against `LAYERNORM_BACKWARD_PTX`). If the PTX cannot be compiled,
/// everything is computed on the host instead.
///
/// Per row, with `x_hat = (x - mean) / sqrt(var + eps)` and
/// `dl = grad_output * weight`:
/// `grad_input = (dl - mean(dl) - x_hat * mean(dl * x_hat)) / sqrt(var + eps)`.
///
/// `input`: `[rows * cols]`, `grad_output`: `[rows * cols]`, `weight`: `[cols]`.
/// Returns: `(grad_input [rows * cols], grad_weight [cols], grad_bias [cols])`.
///
/// # Errors
///
/// - Any error from `validate_unary` on `input`.
/// - [`GpuError::LengthMismatch`] if `input` or `grad_output` holds fewer than
///   `rows * cols` elements, or `weight` fewer than `cols` (`a` = required,
///   `b` = actual).
/// - Any error from allocation, host/device transfer, or the kernel launch.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = rows * cols;
    // A short buffer would make the host fallback panic while slicing and
    // (presumably) make the kernel read past the allocation, so reject it here.
    for (needed, have) in [
        (total, input.len()),
        (total, grad_output.len()),
        (cols, weight.len()),
    ] {
        if have < needed {
            return Err(GpuError::LengthMismatch { a: needed, b: have });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; total];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
                let var: f32 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // First pass: row reductions plus the parameter gradients.
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                // Second pass: input gradient using the row means.
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(total, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
15168
15169/// Stub -- always returns [`GpuError::NoCudaFeature`].
15170#[cfg(not(feature = "cuda"))]
15171pub fn gpu_layernorm_backward(
15172    _input: &CudaBuffer<f32>,
15173    _grad_output: &CudaBuffer<f32>,
15174    _weight: &CudaBuffer<f32>,
15175    _rows: usize,
15176    _cols: usize,
15177    _eps: f32,
15178    _device: &GpuDevice,
15179) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
15180    Err(GpuError::NoCudaFeature)
15181}
15182
15183// ---------------------------------------------------------------------------
15184// Public API -- RMSNorm
15185// ---------------------------------------------------------------------------
15186
/// Row-wise RMS normalization on GPU.
///
/// `input`: `[rows * cols]`, `weight`: `[cols]`.
/// Output: normalized and scaled `[rows * cols]`.
///
/// Computes `out[j] = x[j] * rsqrt(mean(x^2) + eps) * weight[j]` per row.
/// There is no bias and no mean centering (unlike LayerNorm).
///
/// The GPU path launches one 256-thread block per row. If the PTX cannot be
/// compiled, this function:
/// - prints a diagnostic to stderr,
/// - dumps the PTX source to `rmsnorm_debug.ptx` in the system temp directory,
/// - computes the result on the host instead.
///
/// # Errors
///
/// - Any error from `validate_unary` on `input`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements or `weight` fewer than `cols` (`a` = required, `b` = actual).
/// - Any error from allocation, host/device transfer, or the kernel launch.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = rows * cols;
    // A short buffer would make the host fallback panic while slicing and
    // (presumably) make the kernel read past the allocation, so reject it here.
    for (needed, have) in [(total, input.len()), (cols, weight.len())] {
        if have < needed {
            return Err(GpuError::LengthMismatch { a: needed, b: have });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_PTX,
        "rmsnorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("ferrotorch-gpu: RMSNorm PTX compilation failed ({e:?}), CPU fallback");
            // Use the platform temp dir rather than a hard-coded `/tmp`, and
            // only report a dump when the write actually succeeded.
            let dump_path = std::env::temp_dir().join("rmsnorm_debug.ptx");
            match std::fs::write(&dump_path, RMSNORM_PTX) {
                Ok(()) => eprintln!(
                    "ferrotorch-gpu: dumped PTX to {} ({} bytes)",
                    dump_path.display(),
                    RMSNORM_PTX.len()
                ),
                Err(io_err) => eprintln!(
                    "ferrotorch-gpu: failed to dump PTX to {} ({io_err})",
                    dump_path.display()
                ),
            }
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f32 =
                    slice.iter().map(|&x| x * x).sum::<f32>() / cols as f32;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row (at least one block even when `rows == 0`), 256
    // threads per block, with 256 f32 slots of shared memory.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15265
15266// ---------------------------------------------------------------------------
15267// Public API -- RMSNorm backward
15268// ---------------------------------------------------------------------------
15269
/// RMSNorm backward pass on GPU.
///
/// Computes grad_input and grad_weight entirely on GPU.
/// One block per batch element (row), 256 threads per block.
/// grad_weight is accumulated across batches via atomicAdd.
///
/// `input`: `[rows * cols]`, `grad_output`: `[rows * cols]`, `weight`: `[cols]`.
/// Returns: `(grad_input [rows * cols], grad_weight [cols])`.
///
/// With `inv_rms = 1 / sqrt(mean(x^2) + eps)` per row:
/// - `grad_input[c] = inv_rms * w[c] * go[c] - x[c] * inv_rms^3 * dot / cols`,
///   where `dot = sum_k go[k] * x[k] * w[k]`;
/// - `grad_weight[c] = sum_rows go[c] * x[c] * inv_rms`.
///
/// If the PTX module cannot be compiled, a warning is printed to stderr and
/// the same formulas are evaluated on the host.
///
/// # Errors
///
/// - Any error from `validate_unary` for `input`, `grad_output` or `weight`.
/// - [`GpuError::ShapeMismatch`] if `input` or `grad_output` holds fewer
///   than `rows * cols` elements, or `weight` holds fewer than `cols`.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;
    validate_unary(grad_output, device)?;
    validate_unary(weight, device)?;

    let numel = rows * cols;
    if input.len() < numel {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_backward",
            expected: vec![numel],
            got: vec![input.len()],
        });
    }
    if grad_output.len() < numel {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_backward",
            expected: vec![numel],
            got: vec![grad_output.len()],
        });
    }
    if weight.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_backward",
            expected: vec![cols],
            got: vec![weight.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_BACKWARD_PTX,
        "rmsnorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // CPU fallback (reported like the forward pass does).
            eprintln!(
                "ferrotorch-gpu: RMSNorm backward PTX compilation failed ({e:?}), CPU fallback"
            );
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let w_row = &h_w[..cols];
            let mut grad_input = vec![0.0f32; numel];
            let mut grad_weight = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_row = &h_in[base..base + cols];
                let go_row = &h_go[base..base + cols];
                let sq_mean = x_row.iter().map(|&x| x * x).sum::<f32>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += go_row[c] * x_row[c] * w_row[c];
                    grad_weight[c] += go_row[c] * x_row[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] = inv_rms * w_row[c] * go_row[c] - x_row[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    // grad_w must start at zero: the kernel accumulates into it across rows.
    let mut grad_in = alloc_zeros_f32(numel, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block. `max(1)` keeps the grid
    // non-empty when `rows == 0`.
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: the argument order must match the `rmsnorm_backward_kernel` PTX
    // signature. The checks above guarantee every buffer covers the element
    // count the kernel is told about via `rows` / `cols`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
15363
15364/// Stub -- always returns [`GpuError::NoCudaFeature`].
15365#[cfg(not(feature = "cuda"))]
15366pub fn gpu_rmsnorm(
15367    _input: &CudaBuffer<f32>,
15368    _weight: &CudaBuffer<f32>,
15369    _rows: usize,
15370    _cols: usize,
15371    _eps: f32,
15372    _device: &GpuDevice,
15373) -> GpuResult<CudaBuffer<f32>> {
15374    Err(GpuError::NoCudaFeature)
15375}
15376
15377/// Stub -- always returns [`GpuError::NoCudaFeature`].
15378#[cfg(not(feature = "cuda"))]
15379pub fn gpu_rmsnorm_backward(
15380    _input: &CudaBuffer<f32>,
15381    _grad_output: &CudaBuffer<f32>,
15382    _weight: &CudaBuffer<f32>,
15383    _rows: usize,
15384    _cols: usize,
15385    _eps: f32,
15386    _device: &GpuDevice,
15387) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
15388    Err(GpuError::NoCudaFeature)
15389}
15390
15391// ===========================================================================
15392// _into variants — write to pre-allocated output buffers (zero allocation)
15393//
15394// These are used for CUDA graph capture, where all buffer addresses must be
15395// fixed at capture time. The PTX kernels are identical — only the Rust
15396// wrapper skips allocation.
15397// ===========================================================================
15398
/// Writes the elementwise sum `out[i] = a[i] + b[i]` into a caller-provided
/// buffer without allocating (suitable for CUDA graph capture).
///
/// # Errors
///
/// - Any error from `validate_binary` on `a` / `b`.
/// - [`GpuError::ShapeMismatch`] when `out` holds fewer elements than `a`.
/// - [`GpuError::PtxCompileFailed`] when the launch helper reports that
///   `add_kernel` was not launched.
#[cfg(feature = "cuda")]
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;

    let (needed, available) = (a.len(), out.len());
    if available < needed {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![needed],
            got: vec![available],
        });
    }

    let launched = try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")?;
    if launched {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        })
    }
}
15422
/// Writes the elementwise product `out[i] = a[i] * b[i]` into a
/// caller-provided buffer without allocating (suitable for CUDA graph capture).
///
/// # Errors
///
/// - Any error from `validate_binary` on `a` / `b`.
/// - [`GpuError::ShapeMismatch`] when `out` holds fewer elements than `a`.
/// - [`GpuError::PtxCompileFailed`] when the launch helper reports that
///   `mul_kernel` was not launched.
#[cfg(feature = "cuda")]
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;

    let (needed, available) = (a.len(), out.len());
    if available < needed {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![needed],
            got: vec![available],
        });
    }

    let launched = try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")?;
    if launched {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        })
    }
}
15446
/// Scalar multiply into pre-allocated output: `out[i] = a[i] * scalar`.
///
/// `out` must hold at least `a.len()` elements. No allocation is performed,
/// so this can be used during CUDA graph capture.
///
/// # Errors
///
/// - Any error from `validate_unary` for `a`.
/// - [`GpuError::ShapeMismatch`] if `out` is shorter than `a`. The kernel is
///   launched over `a.len()` elements and would otherwise overrun `out`; this
///   matches `gpu_add_into` / `gpu_mul_into`.
/// - [`GpuError::PtxCompileFailed`] if `scale_kernel` cannot be loaded.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    let n = a.len();
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // SAFETY: argument order (a, out, scalar, n) must match the `scale_kernel`
    // PTX signature. Both buffers cover `n` elements, as checked above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15482
/// Reports whether a GPU buffer holds at least one non-finite value (inf or NaN).
///
/// Implementation: the whole buffer is copied to the host and scanned there.
/// This is correct for any size and needs no custom reduction kernel. A
/// device-side reduction producing a single flag would avoid the download
/// and is a possible future optimization.
///
/// An empty buffer returns `Ok(false)` before any validation or transfer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    let len = a.len();
    if len == 0 {
        return Ok(false);
    }

    validate_unary(a, device)?;

    let values: Vec<f32> = gpu_to_cpu(a, device)?;
    let all_finite = values.iter().all(|x| x.is_finite());
    Ok(!all_finite)
}
15511
15512/// Stub -- always returns [`GpuError::NoCudaFeature`].
15513#[cfg(not(feature = "cuda"))]
15514pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
15515    Err(GpuError::NoCudaFeature)
15516}
15517
/// GELU into pre-allocated output: applies `gelu_kernel` elementwise to `a`
/// and writes the result into `out`.
///
/// `out` must hold at least `a.len()` elements. No allocation is performed,
/// so this can be used during CUDA graph capture.
///
/// # Errors
///
/// - Any error from `validate_unary` for `a`.
/// - [`GpuError::ShapeMismatch`] if `out` is shorter than `a`, matching the
///   check in `gpu_add_into` / `gpu_mul_into`.
/// - [`GpuError::PtxCompileFailed`] if the launch helper reports that
///   `gelu_kernel` was not launched.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;
    if out.len() < a.len() {
        return Err(GpuError::ShapeMismatch {
            op: "gelu_into",
            expected: vec![a.len()],
            got: vec![out.len()],
        });
    }
    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
        return Ok(());
    }
    Err(GpuError::PtxCompileFailed {
        kernel: "gelu_kernel",
    })
}
15533
/// Embedding lookup into a caller-provided buffer.
///
/// Launches `embed_lookup_kernel` over `d` elements with `idx`, `weight` and
/// `out`. Presumably `idx` holds a single row index stored as `f32`, and
/// `out` receives that `d`-wide row of `weight` (TODO confirm against
/// `EMBED_LOOKUP_PTX`). Buffer lengths and devices are not validated here.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;

    let config = launch_cfg(d)?;
    let width = d as u32;

    // SAFETY: argument order (idx, weight, out, d) must match the
    // `embed_lookup_kernel` PTX signature; the caller is responsible for
    // buffer sizes.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&width)
            .launch(config)?;
    }
    Ok(())
}
15568
15569// ---------------------------------------------------------------------------
15570// Public API -- Batch embedding lookup (GPU-native)
15571// ---------------------------------------------------------------------------
15572
/// GPU batch embedding lookup: given `indices` (N f32 values on GPU) and
/// `weight` `[V, D]`, gather N rows to produce output `[N, D]`.
/// Entire operation stays on GPU -- no CPU roundtrip.
///
/// Only the first `n` entries of `indices` are used, on both the GPU path
/// and the CPU fallback, so the output always has `n * d` elements. Indices
/// are stored as `f32` and converted with a saturating `as usize` cast in the
/// fallback (negative or NaN values map to row 0).
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `indices` holds fewer than `n` values.
/// - [`GpuError::Driver`] on CUDA runtime errors.
///
/// # Panics
///
/// The CPU fallback panics if an index addresses a row outside `weight`.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f32(0, device);
    }

    // The kernel is launched over `n * d` outputs and presumably reads one
    // index per output row, so `indices` must hold at least `n` entries.
    if indices.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "embed_lookup_batch",
            expected: vec![n],
            got: vec![indices.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather the first `n` rows so the result has the
            // same `n * d` length as the GPU path.
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            for &idx_f in idx_host.iter().take(n) {
                let row = idx_f as usize;
                let start = row * d;
                out.extend_from_slice(&weight_host[start..start + d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order (indices, weight, out, d, total) must match the
    // `embed_lookup_batch_kernel` PTX signature. `out` holds `total` elements
    // and `indices` holds at least `n`, as checked above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
15633
15634// ---------------------------------------------------------------------------
15635// Public API -- Scatter-add rows (for embedding backward, GPU-native)
15636// ---------------------------------------------------------------------------
15637
/// GPU scatter-add rows: given `grad_output` `[N, D]` and `indices` `[N]` (f32),
/// accumulate into a freshly allocated, zero-initialized `grad_weight`
/// `[V, D]` (with `V = num_embeddings`) and return it:
///   `grad_weight[indices[i], :] += grad_output[i, :]`
///
/// `N` is taken from `indices.len()`. Duplicate indices accumulate correctly
/// via atomic adds on the GPU path and via serial accumulation in the CPU
/// fallback.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `grad_output` holds fewer than `N * D`
///   elements.
/// - [`GpuError::Driver`] on CUDA runtime errors.
///
/// # Panics
///
/// The CPU fallback panics if an index is `>= num_embeddings`.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        return alloc_zeros_f32(num_embeddings * d, device);
    }

    // The kernel is launched over `n * d` gradient elements. A shorter buffer
    // would be read out of bounds on the device.
    if grad_output.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "scatter_add_rows",
            expected: vec![total],
            got: vec![grad_output.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: serial accumulation, so duplicates need no atomics.
            // `d > 0` here because `total != 0`, so `chunks_exact(d)` is valid.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; num_embeddings * d];
            for (&idx_f, grad_row) in idx_host.iter().zip(go_host.chunks_exact(d)) {
                let row = idx_f as usize;
                let start = row * d;
                let dst = &mut result[start..start + d];
                for (acc, &g) in dst.iter_mut().zip(grad_row) {
                    *acc += g;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order (grad_output, indices, out, d, total) must match
    // the `scatter_add_rows_kernel` PTX signature. `grad_output` covers
    // `total` elements, as checked above. Index range is not validated here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
15703
/// 2D transpose of an `m x n` matrix from `a` into the caller-provided `out`.
///
/// Launches `transpose_2d_kernel` over `m * n` elements. Buffer lengths and
/// devices are not validated. Both buffers presumably need at least `m * n`
/// elements; the caller must ensure this.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_into(
    a: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let numel = m * n;
    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "transpose_2d_kernel",
    })?;

    let config = launch_cfg(numel)?;
    let (rows_arg, cols_arg, numel_arg) = (m as u32, n as u32, numel as u32);

    // SAFETY: argument order (a, out, m, n, total) must match the
    // `transpose_2d_kernel` PTX signature; buffer sizes are the caller's
    // responsibility.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_arg)
            .arg(&cols_arg)
            .arg(&numel_arg)
            .launch(config)?;
    }
    Ok(())
}
15742
/// Permutes a 4-D tensor with axis order `(0, 2, 1, 3)` into `out`.
///
/// The source is presumably contiguous `[d0, d1, d2, d3]` and the result
/// `[d0, d2, d1, d3]` (TODO confirm against `PERMUTE_0213_PTX`). Buffer
/// lengths are not validated; both must cover `d0 * d1 * d2 * d3` elements.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_into(
    a: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let numel = d0 * d1 * d2 * d3;
    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "permute_0213_kernel",
    })?;

    let config = launch_cfg(numel)?;
    // Kernel takes 32-bit dimensions followed by the total element count.
    let dims = [d0 as u32, d1 as u32, d2 as u32, d3 as u32];
    let numel_arg = numel as u32;

    // SAFETY: argument order (a, out, d0, d1, d2, d3, total) must match the
    // `permute_0213_kernel` PTX signature; buffer sizes are the caller's
    // responsibility.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&dims[0])
            .arg(&dims[1])
            .arg(&dims[2])
            .arg(&dims[3])
            .arg(&numel_arg)
            .launch(config)?;
    }
    Ok(())
}
15783
/// Softmax into pre-allocated output (row-wise).
///
/// `a` and `out` are `[rows * cols]`. One 256-thread block is launched per
/// row, with `cols` f32 values of dynamic shared memory.
///
/// An empty input (`rows == 0` or `cols == 0`) is a no-op. Without this
/// check, `rows == 0` would request a zero-sized grid, which is not a valid
/// launch configuration.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `a` or `out` holds fewer than
///   `rows * cols` elements.
/// - [`GpuError::PtxCompileFailed`] if `softmax_kernel` cannot be loaded.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_into(
    a: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let numel = rows * cols;
    if a.len() < numel {
        return Err(GpuError::ShapeMismatch {
            op: "softmax_into",
            expected: vec![numel],
            got: vec![a.len()],
        });
    }
    if out.len() < numel {
        return Err(GpuError::ShapeMismatch {
            op: "softmax_into",
            expected: vec![numel],
            got: vec![out.len()],
        });
    }
    if numel == 0 {
        return Ok(());
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "softmax_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // SAFETY: argument order (a, out, rows, cols) must match the
    // `softmax_kernel` PTX signature. Both buffers cover `rows * cols`
    // elements, as checked above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15825
/// LayerNorm into pre-allocated output.
///
/// `input` and `out` are `[rows * cols]`. `weight` and `bias` are per-column
/// (`[cols]`). One 256-thread block is launched per row, with `cols` f32
/// values of dynamic shared memory.
///
/// An empty input (`rows == 0` or `cols == 0`) is a no-op. Without this
/// check, `rows == 0` would request a zero-sized grid, which is not a valid
/// launch configuration.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `input`/`out` hold fewer than
///   `rows * cols` elements, or `weight`/`bias` hold fewer than `cols`.
/// - [`GpuError::PtxCompileFailed`] if `layernorm_kernel` cannot be loaded.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let numel = rows * cols;
    for (have, need) in [
        (input.len(), numel),
        (out.len(), numel),
        (weight.len(), cols),
        (bias.len(), cols),
    ] {
        if have < need {
            return Err(GpuError::ShapeMismatch {
                op: "layernorm_into",
                expected: vec![need],
                got: vec![have],
            });
        }
    }
    if numel == 0 {
        return Ok(());
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "layernorm_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    // SAFETY: argument order (input, out, weight, bias, rows, cols, eps) must
    // match the `layernorm_kernel` PTX signature. All buffers were
    // length-checked above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(())
}
15874
/// Copies the first `len` rows of each batch entry from `src`
/// (`[n_batch, max_len, d]`) into `out` (`[n_batch, len, d]`).
///
/// Launches `slice_read_kernel` over `n_batch * len * d` elements. Buffer
/// lengths and devices are not validated here.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_into(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let numel = n_batch * len * d;
    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_read_kernel",
    })?;

    let config = launch_cfg(numel)?;
    let (numel_arg, width_arg, len_arg, max_len_arg) =
        (numel as u32, d as u32, len as u32, max_len as u32);

    // SAFETY: argument order (src, out, total, d, len, max_len) must match the
    // `slice_read_kernel` PTX signature; buffer sizes are the caller's
    // responsibility.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&numel_arg)
            .arg(&width_arg)
            .arg(&len_arg)
            .arg(&max_len_arg)
            .launch(config)?;
    }
    Ok(())
}
15918
/// Small matrix multiply via a custom PTX kernel, written into `out`.
///
/// Launches `small_matmul_kernel` over `m * n` output elements. Presumably
/// `a` is `[m, k]`, `b` is `[k, n]` and `out` is `[m, n]` (TODO confirm
/// against `SMALL_MATMUL_PTX`). Buffer lengths are not validated here.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let out_numel = m * n;
    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "small_matmul_kernel",
    })?;

    let config = launch_cfg(out_numel)?;
    let m_arg = m as u32;
    let k_arg = k as u32;
    let n_arg = n as u32;
    let numel_arg = out_numel as u32;

    // SAFETY: argument order (a, b, out, m, k, n, total) must match the
    // `small_matmul_kernel` PTX signature; buffer sizes are the caller's
    // responsibility.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&m_arg)
            .arg(&k_arg)
            .arg(&n_arg)
            .arg(&numel_arg)
            .launch(config)?;
    }
    Ok(())
}
15959
15960// ===========================================================================
15961// Indirect-parameter kernels for CUDA graph capture
15962// ===========================================================================
15963
/// Slice write whose target position is read from device memory, so the
/// launch can be recorded once in a CUDA graph and replayed for any position.
///
/// Writes `src` (`[n_batch, d]`) into row `*pos_ptr` of `dst`
/// (`[n_batch, max_len, d]`). The kernel is launched over `n_batch * d`
/// elements. Buffer lengths and the position value are not validated here.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_indirect(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos_ptr: &cudarc::driver::CudaSlice<u32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let numel = n_batch * d;
    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        SLICE_WRITE_INDIRECT_PTX,
        "slice_write_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_write_indirect_kernel",
    })?;

    let config = launch_cfg(numel)?;
    let (numel_arg, width_arg, max_len_arg) = (numel as u32, d as u32, max_len as u32);

    // SAFETY: argument order (src, dst, total, d, max_len, pos_ptr) must match
    // the `slice_write_indirect_kernel` PTX signature; buffer sizes and the
    // stored position are the caller's responsibility.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&numel_arg)
            .arg(&width_arg)
            .arg(&max_len_arg)
            .arg(pos_ptr)
            .launch(config)?;
    }
    Ok(())
}
16006
/// Builds a causal attention mask whose valid length is read from device
/// memory, so the launch can be replayed inside a captured CUDA graph.
///
/// Writes `out[h, col] = 0.0` if `col < *total_len_ptr`, else `-1e9`.
/// Output shape: `[n_head, max_pos]` (n_head rows, each max_pos wide).
/// The kernel is launched over `n_head * max_pos` elements; `out`'s length
/// is not validated here.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be loaded.
/// - Any error from `launch_cfg` or the CUDA launch.
#[cfg(feature = "cuda")]
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let numel = n_head * max_pos;
    let (context, launch_stream) = (device.context(), device.stream());
    let kernel = crate::module_cache::get_or_compile(
        context,
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;

    let config = launch_cfg(numel)?;
    let (row_width_arg, numel_arg) = (max_pos as u32, numel as u32);

    // SAFETY: argument order (total_len_ptr, out, max_pos, total) must match
    // the `causal_mask_indirect_kernel` PTX signature; `out` sizing is the
    // caller's responsibility.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&row_width_arg)
            .arg(&numel_arg)
            .launch(config)?;
    }
    Ok(())
}
16045
16046// ===========================================================================
16047// Pre-compilation of all decode-path PTX modules
16048// ===========================================================================
16049
/// Pre-compile all PTX kernels used by the decode pass into the module cache.
/// Call this before CUDA graph capture to ensure no `cuModuleLoadData` calls
/// occur during capture (which is not a capturable operation).
///
/// The list covers the kernels behind the `_into` / indirect variants. This
/// includes `transpose_2d_kernel`, which `gpu_transpose_2d_into` launches. A
/// few training kernels are also covered. Modules are compiled in list order
/// and the first failure is returned.
///
/// # Errors
///
/// - [`GpuError::Driver`] if binding the context to the thread or compiling
///   any module fails.
#[cfg(feature = "cuda")]
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    let ctx = device.context();
    ctx.bind_to_thread()?;
    let ord = device.ordinal() as u32;

    let kernels: &[(&'static str, &'static str)] = &[
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
        // Used by `gpu_transpose_2d_into`. It was previously missing, so the
        // first transpose inside a capture would have loaded a module.
        (TRANSPOSE_2D_PTX, "transpose_2d_kernel"),
    ];

    for &(ptx, name) in kernels {
        crate::module_cache::get_or_compile(ctx, ptx, name, ord).map_err(GpuError::Driver)?;
    }
    Ok(())
}
16081
16082/// Stub — no-op without cuda.
16083#[cfg(not(feature = "cuda"))]
16084pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
16085    Err(GpuError::NoCudaFeature)
16086}
16087
16088// ---------------------------------------------------------------------------
16089// Stubs when `cuda` feature is disabled
16090// ---------------------------------------------------------------------------
16091
16092/// Stub -- always returns [`GpuError::NoCudaFeature`].
16093#[cfg(not(feature = "cuda"))]
16094pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16095    Err(GpuError::NoCudaFeature)
16096}
16097
16098/// Stub -- always returns [`GpuError::NoCudaFeature`].
16099#[cfg(not(feature = "cuda"))]
16100pub fn gpu_gelu_tanh(
16101    _input: &CudaBuffer<f32>,
16102    _device: &GpuDevice,
16103) -> GpuResult<CudaBuffer<f32>> {
16104    Err(GpuError::NoCudaFeature)
16105}
16106
16107/// Stub -- always returns [`GpuError::NoCudaFeature`].
16108#[cfg(not(feature = "cuda"))]
16109pub fn gpu_gelu_erf(
16110    _input: &CudaBuffer<f32>,
16111    _device: &GpuDevice,
16112) -> GpuResult<CudaBuffer<f32>> {
16113    Err(GpuError::NoCudaFeature)
16114}
16115
16116/// Stub -- always returns [`GpuError::NoCudaFeature`].
16117#[cfg(not(feature = "cuda"))]
16118pub fn gpu_gelu_backward_tanh(
16119    _grad: &CudaBuffer<f32>,
16120    _input: &CudaBuffer<f32>,
16121    _device: &GpuDevice,
16122) -> GpuResult<CudaBuffer<f32>> {
16123    Err(GpuError::NoCudaFeature)
16124}
16125
16126/// Stub -- always returns [`GpuError::NoCudaFeature`].
16127#[cfg(not(feature = "cuda"))]
16128pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16129    Err(GpuError::NoCudaFeature)
16130}
16131
16132/// Stub -- always returns [`GpuError::NoCudaFeature`].
16133#[cfg(not(feature = "cuda"))]
16134pub fn gpu_silu_backward(
16135    _grad: &CudaBuffer<f32>,
16136    _input: &CudaBuffer<f32>,
16137    _device: &GpuDevice,
16138) -> GpuResult<CudaBuffer<f32>> {
16139    Err(GpuError::NoCudaFeature)
16140}
16141
16142/// Stub -- always returns [`GpuError::NoCudaFeature`].
16143#[cfg(not(feature = "cuda"))]
16144pub fn gpu_elu(
16145    _input: &CudaBuffer<f32>,
16146    _alpha: f32,
16147    _device: &GpuDevice,
16148) -> GpuResult<CudaBuffer<f32>> {
16149    Err(GpuError::NoCudaFeature)
16150}
16151
16152/// Stub -- always returns [`GpuError::NoCudaFeature`].
16153#[cfg(not(feature = "cuda"))]
16154pub fn gpu_elu_backward(
16155    _grad: &CudaBuffer<f32>,
16156    _input: &CudaBuffer<f32>,
16157    _alpha: f32,
16158    _device: &GpuDevice,
16159) -> GpuResult<CudaBuffer<f32>> {
16160    Err(GpuError::NoCudaFeature)
16161}
16162
16163/// Stub -- always returns [`GpuError::NoCudaFeature`].
16164#[cfg(not(feature = "cuda"))]
16165pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16166    Err(GpuError::NoCudaFeature)
16167}
16168
16169/// Stub -- always returns [`GpuError::NoCudaFeature`].
16170#[cfg(not(feature = "cuda"))]
16171pub fn gpu_mish_backward(
16172    _grad: &CudaBuffer<f32>,
16173    _input: &CudaBuffer<f32>,
16174    _device: &GpuDevice,
16175) -> GpuResult<CudaBuffer<f32>> {
16176    Err(GpuError::NoCudaFeature)
16177}
16178
16179/// Stub -- always returns [`GpuError::NoCudaFeature`].
16180#[cfg(not(feature = "cuda"))]
16181pub fn gpu_clamp(
16182    _input: &CudaBuffer<f32>,
16183    _min_val: f32,
16184    _max_val: f32,
16185    _device: &GpuDevice,
16186) -> GpuResult<CudaBuffer<f32>> {
16187    Err(GpuError::NoCudaFeature)
16188}
16189
16190/// Stub -- always returns [`GpuError::NoCudaFeature`].
16191#[cfg(not(feature = "cuda"))]
16192pub fn gpu_div(
16193    _a: &CudaBuffer<f32>,
16194    _b: &CudaBuffer<f32>,
16195    _device: &GpuDevice,
16196) -> GpuResult<CudaBuffer<f32>> {
16197    Err(GpuError::NoCudaFeature)
16198}
16199
16200/// Stub -- always returns [`GpuError::NoCudaFeature`].
16201#[cfg(not(feature = "cuda"))]
16202pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16203    Err(GpuError::NoCudaFeature)
16204}
16205
16206/// Stub -- always returns [`GpuError::NoCudaFeature`].
16207#[cfg(not(feature = "cuda"))]
16208pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16209    Err(GpuError::NoCudaFeature)
16210}
16211
16212/// Stub -- always returns [`GpuError::NoCudaFeature`].
16213#[cfg(not(feature = "cuda"))]
16214pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16215    Err(GpuError::NoCudaFeature)
16216}
16217
16218/// Stub -- always returns [`GpuError::NoCudaFeature`].
16219#[cfg(not(feature = "cuda"))]
16220pub fn gpu_pow(
16221    _a: &CudaBuffer<f32>,
16222    _exponent: f32,
16223    _device: &GpuDevice,
16224) -> GpuResult<CudaBuffer<f32>> {
16225    Err(GpuError::NoCudaFeature)
16226}
16227
16228/// Stub -- always returns [`GpuError::NoCudaFeature`].
16229#[cfg(not(feature = "cuda"))]
16230pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16231    Err(GpuError::NoCudaFeature)
16232}
16233
16234/// Stub -- always returns [`GpuError::NoCudaFeature`].
16235#[cfg(not(feature = "cuda"))]
16236pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16237    Err(GpuError::NoCudaFeature)
16238}
16239
16240/// Stub -- always returns [`GpuError::NoCudaFeature`].
16241#[cfg(not(feature = "cuda"))]
16242pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16243    Err(GpuError::NoCudaFeature)
16244}
16245
16246/// Stub -- always returns [`GpuError::NoCudaFeature`].
16247#[cfg(not(feature = "cuda"))]
16248pub fn gpu_layernorm(
16249    _input: &CudaBuffer<f32>,
16250    _weight: &CudaBuffer<f32>,
16251    _bias: &CudaBuffer<f32>,
16252    _rows: usize,
16253    _cols: usize,
16254    _eps: f32,
16255    _device: &GpuDevice,
16256) -> GpuResult<CudaBuffer<f32>> {
16257    Err(GpuError::NoCudaFeature)
16258}
16259
16260/// Stub -- always returns [`GpuError::NoCudaFeature`].
16261#[cfg(not(feature = "cuda"))]
16262pub fn gpu_transpose_2d(
16263    _input: &CudaBuffer<f32>,
16264    _m: usize,
16265    _n: usize,
16266    _device: &GpuDevice,
16267) -> GpuResult<CudaBuffer<f32>> {
16268    Err(GpuError::NoCudaFeature)
16269}
16270
16271/// Stub -- always returns [`GpuError::NoCudaFeature`].
16272#[cfg(not(feature = "cuda"))]
16273pub fn gpu_add(
16274    _a: &CudaBuffer<f32>,
16275    _b: &CudaBuffer<f32>,
16276    _device: &GpuDevice,
16277) -> GpuResult<CudaBuffer<f32>> {
16278    Err(GpuError::NoCudaFeature)
16279}
16280
16281/// Stub -- always returns [`GpuError::NoCudaFeature`].
16282#[cfg(not(feature = "cuda"))]
16283pub fn gpu_sub(
16284    _a: &CudaBuffer<f32>,
16285    _b: &CudaBuffer<f32>,
16286    _device: &GpuDevice,
16287) -> GpuResult<CudaBuffer<f32>> {
16288    Err(GpuError::NoCudaFeature)
16289}
16290
16291/// Stub -- always returns [`GpuError::NoCudaFeature`].
16292#[cfg(not(feature = "cuda"))]
16293pub fn gpu_mul(
16294    _a: &CudaBuffer<f32>,
16295    _b: &CudaBuffer<f32>,
16296    _device: &GpuDevice,
16297) -> GpuResult<CudaBuffer<f32>> {
16298    Err(GpuError::NoCudaFeature)
16299}
16300
16301/// Stub -- always returns [`GpuError::NoCudaFeature`].
16302#[cfg(not(feature = "cuda"))]
16303pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16304    Err(GpuError::NoCudaFeature)
16305}
16306
16307/// Stub -- always returns [`GpuError::NoCudaFeature`].
16308#[cfg(not(feature = "cuda"))]
16309pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16310    Err(GpuError::NoCudaFeature)
16311}
16312
16313/// Stub -- always returns [`GpuError::NoCudaFeature`].
16314#[cfg(not(feature = "cuda"))]
16315pub fn gpu_scale(
16316    _a: &CudaBuffer<f32>,
16317    _scalar: f32,
16318    _device: &GpuDevice,
16319) -> GpuResult<CudaBuffer<f32>> {
16320    Err(GpuError::NoCudaFeature)
16321}
16322
16323/// Stub -- always returns [`GpuError::NoCudaFeature`].
16324#[cfg(not(feature = "cuda"))]
16325pub fn gpu_broadcast_add(
16326    _a: &CudaBuffer<f32>,
16327    _b: &CudaBuffer<f32>,
16328    _a_shape: &[usize],
16329    _b_shape: &[usize],
16330    _out_shape: &[usize],
16331    _device: &GpuDevice,
16332) -> GpuResult<CudaBuffer<f32>> {
16333    Err(GpuError::NoCudaFeature)
16334}
16335
16336/// Stub -- always returns [`GpuError::NoCudaFeature`].
16337#[cfg(not(feature = "cuda"))]
16338pub fn gpu_broadcast_sub(
16339    _a: &CudaBuffer<f32>,
16340    _b: &CudaBuffer<f32>,
16341    _a_shape: &[usize],
16342    _b_shape: &[usize],
16343    _out_shape: &[usize],
16344    _device: &GpuDevice,
16345) -> GpuResult<CudaBuffer<f32>> {
16346    Err(GpuError::NoCudaFeature)
16347}
16348
16349/// Stub -- always returns [`GpuError::NoCudaFeature`].
16350#[cfg(not(feature = "cuda"))]
16351pub fn gpu_broadcast_mul(
16352    _a: &CudaBuffer<f32>,
16353    _b: &CudaBuffer<f32>,
16354    _a_shape: &[usize],
16355    _b_shape: &[usize],
16356    _out_shape: &[usize],
16357    _device: &GpuDevice,
16358) -> GpuResult<CudaBuffer<f32>> {
16359    Err(GpuError::NoCudaFeature)
16360}
16361
16362/// Stub -- always returns [`GpuError::NoCudaFeature`].
16363#[cfg(not(feature = "cuda"))]
16364pub fn gpu_softmax(
16365    _input: &CudaBuffer<f32>,
16366    _rows: usize,
16367    _cols: usize,
16368    _device: &GpuDevice,
16369) -> GpuResult<CudaBuffer<f32>> {
16370    Err(GpuError::NoCudaFeature)
16371}
16372
16373/// Stub -- always returns [`GpuError::NoCudaFeature`].
16374#[cfg(not(feature = "cuda"))]
16375pub fn gpu_dropout(
16376    _input: &CudaBuffer<f32>,
16377    _threshold: u32,
16378    _scale: f32,
16379    _seed: u32,
16380    _device: &GpuDevice,
16381) -> GpuResult<CudaBuffer<f32>> {
16382    Err(GpuError::NoCudaFeature)
16383}
16384
16385/// Stub -- always returns [`GpuError::NoCudaFeature`].
16386#[cfg(not(feature = "cuda"))]
16387pub fn gpu_permute_0213(
16388    _input: &CudaBuffer<f32>,
16389    _d0: usize,
16390    _d1: usize,
16391    _d2: usize,
16392    _d3: usize,
16393    _device: &GpuDevice,
16394) -> GpuResult<CudaBuffer<f32>> {
16395    Err(GpuError::NoCudaFeature)
16396}
16397
16398/// Stub -- always returns [`GpuError::NoCudaFeature`].
16399#[cfg(not(feature = "cuda"))]
16400pub fn gpu_slice_write(
16401    _src: &CudaBuffer<f32>,
16402    _dst: &mut CudaBuffer<f32>,
16403    _n_batch: usize,
16404    _d: usize,
16405    _max_len: usize,
16406    _pos: usize,
16407    _device: &GpuDevice,
16408) -> GpuResult<()> {
16409    Err(GpuError::NoCudaFeature)
16410}
16411
16412/// Stub -- always returns [`GpuError::NoCudaFeature`].
16413#[cfg(not(feature = "cuda"))]
16414pub fn gpu_slice_read(
16415    _src: &CudaBuffer<f32>,
16416    _n_batch: usize,
16417    _d: usize,
16418    _len: usize,
16419    _max_len: usize,
16420    _device: &GpuDevice,
16421) -> GpuResult<CudaBuffer<f32>> {
16422    Err(GpuError::NoCudaFeature)
16423}
16424
16425/// Stub -- always returns [`GpuError::NoCudaFeature`].
16426#[cfg(not(feature = "cuda"))]
16427pub fn gpu_embed_lookup(
16428    _idx: &CudaBuffer<f32>,
16429    _weight: &CudaBuffer<f32>,
16430    _d: usize,
16431    _device: &GpuDevice,
16432) -> GpuResult<CudaBuffer<f32>> {
16433    Err(GpuError::NoCudaFeature)
16434}
16435
16436/// Stub -- always returns [`GpuError::NoCudaFeature`].
16437#[cfg(not(feature = "cuda"))]
16438pub fn gpu_embed_lookup_batch(
16439    _indices: &CudaBuffer<f32>,
16440    _weight: &CudaBuffer<f32>,
16441    _n: usize,
16442    _d: usize,
16443    _device: &GpuDevice,
16444) -> GpuResult<CudaBuffer<f32>> {
16445    Err(GpuError::NoCudaFeature)
16446}
16447
16448/// Stub -- always returns [`GpuError::NoCudaFeature`].
16449#[cfg(not(feature = "cuda"))]
16450pub fn gpu_scatter_add_rows(
16451    _grad_output: &CudaBuffer<f32>,
16452    _indices: &CudaBuffer<f32>,
16453    _num_embeddings: usize,
16454    _d: usize,
16455    _device: &GpuDevice,
16456) -> GpuResult<CudaBuffer<f32>> {
16457    Err(GpuError::NoCudaFeature)
16458}
16459
16460/// Stub -- always returns [`GpuError::NoCudaFeature`].
16461#[cfg(not(feature = "cuda"))]
16462pub fn gpu_relu_backward(
16463    _grad: &CudaBuffer<f32>,
16464    _input: &CudaBuffer<f32>,
16465    _device: &GpuDevice,
16466) -> GpuResult<CudaBuffer<f32>> {
16467    Err(GpuError::NoCudaFeature)
16468}
16469
16470/// Stub -- always returns [`GpuError::NoCudaFeature`].
16471#[cfg(not(feature = "cuda"))]
16472pub fn gpu_gelu_backward(
16473    _grad: &CudaBuffer<f32>,
16474    _input: &CudaBuffer<f32>,
16475    _device: &GpuDevice,
16476) -> GpuResult<CudaBuffer<f32>> {
16477    Err(GpuError::NoCudaFeature)
16478}
16479
16480/// Stub -- always returns [`GpuError::NoCudaFeature`].
16481#[cfg(not(feature = "cuda"))]
16482pub fn gpu_index_select_1d(
16483    _input: &CudaBuffer<f32>,
16484    _indices: &CudaBuffer<f32>,
16485    _device: &GpuDevice,
16486) -> GpuResult<CudaBuffer<f32>> {
16487    Err(GpuError::NoCudaFeature)
16488}
16489
16490/// Stub -- always returns [`GpuError::NoCudaFeature`].
16491#[cfg(not(feature = "cuda"))]
16492pub fn gpu_scatter_add_1d(
16493    _grad_output: &CudaBuffer<f32>,
16494    _indices: &CudaBuffer<f32>,
16495    _input_len: usize,
16496    _device: &GpuDevice,
16497) -> GpuResult<CudaBuffer<f32>> {
16498    Err(GpuError::NoCudaFeature)
16499}
16500
16501/// Stub -- always returns [`GpuError::NoCudaFeature`].
16502#[cfg(not(feature = "cuda"))]
16503pub fn gpu_masked_fill(
16504    _input: &CudaBuffer<f32>,
16505    _mask: &CudaBuffer<f32>,
16506    _value: f32,
16507    _device: &GpuDevice,
16508) -> GpuResult<CudaBuffer<f32>> {
16509    Err(GpuError::NoCudaFeature)
16510}
16511
16512/// Stub -- always returns [`GpuError::NoCudaFeature`].
16513#[cfg(not(feature = "cuda"))]
16514pub fn gpu_masked_zero(
16515    _grad: &CudaBuffer<f32>,
16516    _mask: &CudaBuffer<f32>,
16517    _device: &GpuDevice,
16518) -> GpuResult<CudaBuffer<f32>> {
16519    Err(GpuError::NoCudaFeature)
16520}
16521
16522/// Stub -- always returns [`GpuError::NoCudaFeature`].
16523#[cfg(not(feature = "cuda"))]
16524pub fn gpu_sigmoid_backward(
16525    _grad: &CudaBuffer<f32>,
16526    _output: &CudaBuffer<f32>,
16527    _device: &GpuDevice,
16528) -> GpuResult<CudaBuffer<f32>> {
16529    Err(GpuError::NoCudaFeature)
16530}
16531
16532/// Stub -- always returns [`GpuError::NoCudaFeature`].
16533#[cfg(not(feature = "cuda"))]
16534pub fn gpu_tanh_backward(
16535    _grad: &CudaBuffer<f32>,
16536    _output: &CudaBuffer<f32>,
16537    _device: &GpuDevice,
16538) -> GpuResult<CudaBuffer<f32>> {
16539    Err(GpuError::NoCudaFeature)
16540}
16541
16542/// Stub -- always returns [`GpuError::NoCudaFeature`].
16543#[cfg(not(feature = "cuda"))]
16544pub fn gpu_softmax_backward(
16545    _grad: &CudaBuffer<f32>,
16546    _output: &CudaBuffer<f32>,
16547    _cols: usize,
16548    _device: &GpuDevice,
16549) -> GpuResult<CudaBuffer<f32>> {
16550    Err(GpuError::NoCudaFeature)
16551}
16552
16553/// Stub -- always returns [`GpuError::NoCudaFeature`].
16554#[cfg(not(feature = "cuda"))]
16555pub fn gpu_log_softmax(
16556    _input: &CudaBuffer<f32>,
16557    _cols: usize,
16558    _device: &GpuDevice,
16559) -> GpuResult<CudaBuffer<f32>> {
16560    Err(GpuError::NoCudaFeature)
16561}
16562
16563/// Stub -- always returns [`GpuError::NoCudaFeature`].
16564#[cfg(not(feature = "cuda"))]
16565pub fn gpu_log_softmax_backward(
16566    _grad: &CudaBuffer<f32>,
16567    _output: &CudaBuffer<f32>,
16568    _cols: usize,
16569    _device: &GpuDevice,
16570) -> GpuResult<CudaBuffer<f32>> {
16571    Err(GpuError::NoCudaFeature)
16572}
16573
16574/// Stub -- always returns [`GpuError::NoCudaFeature`].
16575#[cfg(not(feature = "cuda"))]
16576pub fn gpu_sum_axis(
16577    _a: &CudaBuffer<f32>,
16578    _outer: usize,
16579    _axis_size: usize,
16580    _inner: usize,
16581    _device: &GpuDevice,
16582) -> GpuResult<CudaBuffer<f32>> {
16583    Err(GpuError::NoCudaFeature)
16584}
16585
16586/// Stub -- always returns [`GpuError::NoCudaFeature`].
16587#[cfg(not(feature = "cuda"))]
16588pub fn gpu_cumsum(
16589    _input: &CudaBuffer<f32>,
16590    _outer: usize,
16591    _dim_size: usize,
16592    _inner: usize,
16593    _device: &GpuDevice,
16594) -> GpuResult<CudaBuffer<f32>> {
16595    Err(GpuError::NoCudaFeature)
16596}
16597
16598/// Stub -- always returns [`GpuError::NoCudaFeature`].
16599#[cfg(not(feature = "cuda"))]
16600pub fn gpu_cumprod(
16601    _input: &CudaBuffer<f32>,
16602    _outer: usize,
16603    _dim_size: usize,
16604    _inner: usize,
16605    _device: &GpuDevice,
16606) -> GpuResult<CudaBuffer<f32>> {
16607    Err(GpuError::NoCudaFeature)
16608}
16609
16610/// Stub -- always returns [`GpuError::NoCudaFeature`].
16611#[cfg(not(feature = "cuda"))]
16612pub fn gpu_cummax(
16613    _input: &CudaBuffer<f32>,
16614    _outer: usize,
16615    _dim_size: usize,
16616    _inner: usize,
16617    _device: &GpuDevice,
16618) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
16619    Err(GpuError::NoCudaFeature)
16620}
16621
16622/// Stub -- always returns [`GpuError::NoCudaFeature`].
16623#[cfg(not(feature = "cuda"))]
16624pub fn gpu_cummin(
16625    _input: &CudaBuffer<f32>,
16626    _outer: usize,
16627    _dim_size: usize,
16628    _inner: usize,
16629    _device: &GpuDevice,
16630) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
16631    Err(GpuError::NoCudaFeature)
16632}
16633
16634/// Stub -- always returns [`GpuError::NoCudaFeature`].
16635#[cfg(not(feature = "cuda"))]
16636pub fn gpu_logcumsumexp(
16637    _input: &CudaBuffer<f32>,
16638    _outer: usize,
16639    _dim_size: usize,
16640    _inner: usize,
16641    _device: &GpuDevice,
16642) -> GpuResult<CudaBuffer<f32>> {
16643    Err(GpuError::NoCudaFeature)
16644}
16645
16646/// Stub -- always returns [`GpuError::NoCudaFeature`].
16647#[cfg(not(feature = "cuda"))]
16648pub fn gpu_strided_split(
16649    _input: &CudaBuffer<f32>,
16650    _total_along_axis: usize,
16651    _split_offset: usize,
16652    _split_size: usize,
16653    _inner_size: usize,
16654    _n: usize,
16655    _device: &GpuDevice,
16656) -> GpuResult<CudaBuffer<f32>> {
16657    Err(GpuError::NoCudaFeature)
16658}
16659
16660/// Stub -- always returns [`GpuError::NoCudaFeature`].
16661#[cfg(not(feature = "cuda"))]
16662pub fn gpu_strided_cat(
16663    _input: &CudaBuffer<f32>,
16664    _output: &mut CudaBuffer<f32>,
16665    _total_along_axis: usize,
16666    _cat_offset: usize,
16667    _part_size: usize,
16668    _inner_size: usize,
16669    _n: usize,
16670    _device: &GpuDevice,
16671) -> GpuResult<()> {
16672    Err(GpuError::NoCudaFeature)
16673}
16674
#[cfg(not(feature = "cuda"))]
/// Rank limit for strided copies, defined for builds without the `cuda`
/// feature. Must stay equal to the `cuda`-enabled definition of this constant.
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
16679
16680/// Stub -- always returns [`GpuError::NoCudaFeature`].
16681#[cfg(not(feature = "cuda"))]
16682pub fn gpu_strided_copy(
16683    _input: &CudaBuffer<f32>,
16684    _out_shape: &[usize],
16685    _src_strides: &[isize],
16686    _src_offset: usize,
16687    _device: &GpuDevice,
16688) -> GpuResult<CudaBuffer<f32>> {
16689    Err(GpuError::NoCudaFeature)
16690}
16691
16692/// Stub -- always returns [`GpuError::NoCudaFeature`].
16693#[cfg(not(feature = "cuda"))]
16694pub fn gpu_strided_copy_f64(
16695    _input: &CudaBuffer<f64>,
16696    _out_shape: &[usize],
16697    _src_strides: &[isize],
16698    _src_offset: usize,
16699    _device: &GpuDevice,
16700) -> GpuResult<CudaBuffer<f64>> {
16701    Err(GpuError::NoCudaFeature)
16702}
16703
16704// ---------------------------------------------------------------------------
16705// f32-to-f16 GPU conversion
16706// ---------------------------------------------------------------------------
16707
16708/// Convert an f32 GPU buffer to f16 (represented as `CudaSlice<u16>`).
16709///
16710/// Each element is converted using IEEE 754 round-to-nearest-even via the
16711/// PTX `cvt.rn.f16.f32` instruction. The output is a `CudaSlice<u16>` where
16712/// each `u16` holds the bit pattern of an IEEE 754 half-precision float.
16713///
16714/// # Errors
16715///
16716/// - [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be compiled
16717///   (e.g., GPU architecture too old to support f16 conversion instructions).
16718/// - [`GpuError::Driver`] on CUDA launch errors.
16719#[cfg(feature = "cuda")]
16720pub(crate) fn gpu_f32_to_f16(
16721    input: &CudaBuffer<f32>,
16722    device: &GpuDevice,
16723) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
16724    use cudarc::driver::PushKernelArg;
16725
16726    let n = input.len();
16727    if n == 0 {
16728        let empty = device.stream().alloc_zeros::<u16>(0)?;
16729        return Ok(empty);
16730    }
16731
16732    let ctx = device.context();
16733    let stream = device.stream();
16734
16735    let f = crate::module_cache::get_or_compile(
16736        ctx,
16737        F32_TO_F16_PTX,
16738        "f32_to_f16_kernel",
16739        device.ordinal() as u32,
16740    )
16741    .map_err(|_| GpuError::PtxCompileFailed {
16742        kernel: "f32_to_f16_kernel",
16743    })?;
16744
16745    let mut out = stream.alloc_zeros::<u16>(n)?;
16746    let cfg = launch_cfg(n)?;
16747    let n_u32 = n as u32;
16748
16749    // SAFETY: The kernel reads `n` f32 values from `input` and writes `n`
16750    // u16 values (f16 bit patterns) to `out`. Both buffers are device-resident
16751    // and correctly sized. The grid is configured to cover exactly `n` threads.
16752    unsafe {
16753        stream
16754            .launch_builder(&f)
16755            .arg(input.inner())
16756            .arg(&mut out)
16757            .arg(&n_u32)
16758            .launch(cfg)?;
16759    }
16760
16761    Ok(out)
16762}
16763
16764/// Stub -- always returns [`GpuError::NoCudaFeature`].
16765#[cfg(not(feature = "cuda"))]
16766pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
16767    Err(GpuError::NoCudaFeature)
16768}
16769
16770/// Convert f32 GPU buffer to bf16 (stored as u16) on-device.
16771///
16772/// Uses bit manipulation for round-to-nearest-even bf16 conversion.
16773/// Works on sm_52+ (no special bf16 hardware required).
16774#[cfg(feature = "cuda")]
16775pub(crate) fn gpu_f32_to_bf16(
16776    input: &CudaBuffer<f32>,
16777    device: &GpuDevice,
16778) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
16779    use cudarc::driver::PushKernelArg;
16780
16781    let n = input.len();
16782    if n == 0 {
16783        let empty = device.stream().alloc_zeros::<u16>(0)?;
16784        return Ok(empty);
16785    }
16786
16787    let ctx = device.context();
16788    let stream = device.stream();
16789
16790    let f = crate::module_cache::get_or_compile(
16791        ctx,
16792        F32_TO_BF16_PTX,
16793        "f32_to_bf16_kernel",
16794        device.ordinal() as u32,
16795    )
16796    .map_err(|_| GpuError::PtxCompileFailed {
16797        kernel: "f32_to_bf16_kernel",
16798    })?;
16799
16800    let mut out = stream.alloc_zeros::<u16>(n)?;
16801    let cfg = launch_cfg(n)?;
16802    let n_u32 = n as u32;
16803
16804    unsafe {
16805        stream
16806            .launch_builder(&f)
16807            .arg(input.inner())
16808            .arg(&mut out)
16809            .arg(&n_u32)
16810            .launch(cfg)?;
16811    }
16812
16813    Ok(out)
16814}
16815
16816/// Stub -- always returns [`GpuError::NoCudaFeature`].
16817#[cfg(not(feature = "cuda"))]
16818pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
16819    Err(GpuError::NoCudaFeature)
16820}
16821
16822// ---------------------------------------------------------------------------
16823// Non-CUDA stubs -- f64 ops
16824// ---------------------------------------------------------------------------
16825
16826#[cfg(not(feature = "cuda"))]
16827pub fn gpu_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16828#[cfg(not(feature = "cuda"))]
16829pub fn gpu_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16830#[cfg(not(feature = "cuda"))]
16831pub fn gpu_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16832#[cfg(not(feature = "cuda"))]
16833pub fn gpu_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16834#[cfg(not(feature = "cuda"))]
16835pub fn gpu_neg_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16836#[cfg(not(feature = "cuda"))]
16837pub fn gpu_relu_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16838#[cfg(not(feature = "cuda"))]
16839pub fn gpu_scale_f64(_a: &CudaBuffer<f64>, _scalar: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16840#[cfg(not(feature = "cuda"))]
16841pub fn gpu_exp_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16842#[cfg(not(feature = "cuda"))]
16843pub fn gpu_log_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16844#[cfg(not(feature = "cuda"))]
16845pub fn gpu_sqrt_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16846#[cfg(not(feature = "cuda"))]
16847pub fn gpu_pow_f64(_a: &CudaBuffer<f64>, _exponent: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16848#[cfg(not(feature = "cuda"))]
16849pub fn gpu_abs_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16850#[cfg(not(feature = "cuda"))]
16851pub fn gpu_sigmoid_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16852#[cfg(not(feature = "cuda"))]
16853pub fn gpu_tanh_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16854#[cfg(not(feature = "cuda"))]
16855pub fn gpu_relu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16856#[cfg(not(feature = "cuda"))]
16857pub fn gpu_sigmoid_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16858#[cfg(not(feature = "cuda"))]
16859pub fn gpu_tanh_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16860#[cfg(not(feature = "cuda"))]
16861pub fn gpu_broadcast_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16862#[cfg(not(feature = "cuda"))]
16863pub fn gpu_broadcast_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16864#[cfg(not(feature = "cuda"))]
16865pub fn gpu_broadcast_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16866#[cfg(not(feature = "cuda"))]
16867pub fn gpu_broadcast_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16868#[cfg(not(feature = "cuda"))]
16869pub fn gpu_transpose_2d_f64(_input: &CudaBuffer<f64>, _m: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16870#[cfg(not(feature = "cuda"))]
16871pub fn gpu_permute_0213_f64(_input: &CudaBuffer<f64>, _d0: usize, _d1: usize, _d2: usize, _d3: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16872#[cfg(not(feature = "cuda"))]
16873pub fn gpu_strided_split_f64(_input: &CudaBuffer<f64>, _total_along_axis: usize, _split_offset: usize, _split_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16874#[cfg(not(feature = "cuda"))]
16875pub fn gpu_strided_cat_f64(_input: &CudaBuffer<f64>, _output: &mut CudaBuffer<f64>, _total_along_axis: usize, _cat_offset: usize, _part_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
16876#[cfg(not(feature = "cuda"))]
16877pub fn gpu_index_select_1d_f64(_input: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16878#[cfg(not(feature = "cuda"))]
16879pub fn gpu_scatter_add_1d_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _input_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16880#[cfg(not(feature = "cuda"))]
16881pub fn gpu_masked_fill_f64(_input: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _value: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16882#[cfg(not(feature = "cuda"))]
16883pub fn gpu_masked_zero_f64(_grad: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16884#[cfg(not(feature = "cuda"))]
16885pub fn gpu_slice_write_f64(_src: &CudaBuffer<f64>, _dst: &mut CudaBuffer<f64>, _n_batch: usize, _d: usize, _max_len: usize, _pos: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
16886#[cfg(not(feature = "cuda"))]
16887pub fn gpu_slice_read_f64(_src: &CudaBuffer<f64>, _n_batch: usize, _d: usize, _len: usize, _max_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16888#[cfg(not(feature = "cuda"))]
16889pub fn gpu_embed_lookup_f64(_idx: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16890#[cfg(not(feature = "cuda"))]
16891pub fn gpu_embed_lookup_batch_f64(_indices: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _n: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16892#[cfg(not(feature = "cuda"))]
16893pub fn gpu_scatter_add_rows_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _num_embeddings: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16894
16895
16896// ---------------------------------------------------------------------------
16897// Public API -- f64 activation, normalization, scan, and pooling launchers
16898// ---------------------------------------------------------------------------
16899
/// GELU for `f64` using the sigmoid approximation `x * sigmoid(1.702 * x)`.
///
/// Tries `gelu_f64_kernel` first; when the launch helper yields no buffer
/// (e.g. the PTX module could not be loaded), the same formula is evaluated
/// on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_F64_PTX, "gelu_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let sigmoid = 1.0 / (1.0 + (-1.702 * x).exp());
            x * sigmoid
        }),
    }
}
16908
/// GELU for `f64` using the tanh approximation
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Tries `gelu_tanh_f64_kernel` first; when the launch helper yields no
/// buffer, the same formula is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh_f64(
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let launched =
        try_launch_unary_f64(input, device, GELU_TANH_F64_PTX, "gelu_tanh_f64_kernel")?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
            let cubic = x + 0.044715 * x * x * x;
            0.5 * x * (1.0 + (sqrt_2_over_pi * cubic).tanh())
        }),
    }
}
16920
16921/// GELU (exact erf) for f64.
16922#[cfg(feature = "cuda")]
16923pub fn gpu_gelu_erf_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
16924    if let Some(out) = try_launch_unary_f64(input, device, GELU_ERF_F64_PTX, "gelu_erf_f64_kernel")? {
16925        return Ok(out);
16926    }
16927    cpu_fallback_unary_f64(input, device, |x| {
16928        // Approximate erf via Abramowitz & Stegun
16929        let z = x * std::f64::consts::FRAC_1_SQRT_2;
16930        let az = z.abs();
16931        let t = 1.0 / (1.0 + 0.3275911 * az);
16932        let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
16933        let erf_abs = 1.0 - poly * (-az * az).exp();
16934        let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
16935        x * 0.5 * (1.0 + erf_val)
16936    })
16937}
16938
16939/// GELU backward (sigmoid-approx) for f64.
16940#[cfg(feature = "cuda")]
16941pub fn gpu_gelu_backward_f64(
16942    grad: &CudaBuffer<f64>,
16943    input: &CudaBuffer<f64>,
16944    device: &GpuDevice,
16945) -> GpuResult<CudaBuffer<f64>> {
16946    if grad.len() != input.len() {
16947        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
16948    }
16949    if let Some(out) = try_launch_binary_f64(grad, input, device, GELU_BACKWARD_F64_PTX, "gelu_backward_f64_kernel")? {
16950        return Ok(out);
16951    }
16952    cpu_fallback_binary_f64(grad, input, device, |g, x| {
16953        let sig = 1.0 / (1.0 + (-1.702 * x).exp());
16954        g * (sig + 1.702 * x * sig * (1.0 - sig))
16955    })
16956}
16957
/// Gradient of the tanh-approximated GELU for f64.
///
/// With `u = sqrt(2/pi) * (x + 0.044715 * x^3)` and `t = tanh(u)`, each output
/// element is `grad * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2))`.
/// The `gelu_backward_tanh_f64_kernel` PTX kernel is tried first; when
/// `try_launch_binary_f64` yields `None`, the formula runs on the host.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when `grad` and `input` differ in
/// length, and propagates any error from the launch or fallback helpers.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }
    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_F64_PTX,
        "gelu_backward_tanh_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, v| {
            const C: f64 = 0.044715;
            let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
            let th = (sqrt_2_over_pi * (v + C * v * v * v)).tanh();
            // Product rule: d/dx [0.5 * x * (1 + tanh(u))].
            let outer_term = 0.5 * (1.0 + th);
            let inner_term =
                0.5 * v * (1.0 - th * th) * sqrt_2_over_pi * (1.0 + 3.0 * C * v * v);
            upstream * (outer_term + inner_term)
        }),
    }
}
16980
/// Gradient of the erf-based GELU (`x * Phi(x)`) for f64.
///
/// Each output element is `grad * (Phi(x) + x * phi(x))`, where `Phi` is the
/// standard normal CDF and `phi` its density. The
/// `gelu_backward_erf_f64_kernel` PTX kernel is tried first. When
/// `try_launch_binary_f64` yields `None`, the host fallback evaluates `Phi`
/// using an Abramowitz & Stegun polynomial approximation of `erf`.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when `grad` and `input` differ in
/// length, and propagates any error from the launch or fallback helpers.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }
    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_F64_PTX,
        "gelu_backward_erf_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, v| {
            // erf(|z|) via the Abramowitz & Stegun approximation; sign restored below.
            let z = v * std::f64::consts::FRAC_1_SQRT_2;
            let az = z.abs();
            let t = 1.0 / (1.0 + 0.3275911 * az);
            let tail = 1.421413741 + t * (-1.453152027 + t * 1.061405429);
            let poly = t * (0.254829592 + t * (-0.284496736 + t * tail));
            let erf_mag = 1.0 - poly * (-az * az).exp();
            let erf_z = if z >= 0.0 { erf_mag } else { -erf_mag };
            let cdf = 0.5 * (1.0 + erf_z);
            let pdf = (-v * v / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt();
            upstream * (cdf + v * pdf)
        }),
    }
}
17006
/// SiLU (swish) for f64: `out[i] = x[i] * sigmoid(x[i])`.
///
/// Tries the `silu_f64_kernel` PTX kernel via `try_launch_unary_f64`. When
/// that yields `None`, the activation is computed on the host as
/// `x / (1 + e^-x)`.
#[cfg(feature = "cuda")]
pub fn gpu_silu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, SILU_F64_PTX, "silu_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(input, device, |v| v / (1.0 + (-v).exp())),
    }
}
17015
/// Gradient of SiLU for f64: `grad * (s + x * s * (1 - s))` with `s = sigmoid(x)`.
///
/// Tries the `silu_backward_f64_kernel` PTX kernel first. When
/// `try_launch_binary_f64` yields `None`, the formula is evaluated on the host.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when `grad` and `input` differ in
/// length, and propagates any error from the launch or fallback helpers.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }
    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        SILU_BACKWARD_F64_PTX,
        "silu_backward_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, v| {
            let s = 1.0 / (1.0 + (-v).exp());
            let slope = s + v * s * (1.0 - s);
            upstream * slope
        }),
    }
}
17034
/// ELU for f64: `x` where `x > 0`, otherwise `alpha * (exp(x) - 1)`.
///
/// Launches `elu_f64_kernel` when its PTX module compiles for `device`. If
/// compilation fails, the activation is computed on the host and uploaded.
/// The host path uses [`f64::exp_m1`] for `exp(x) - 1`, which avoids the
/// catastrophic cancellation of `x.exp() - 1.0` for inputs close to zero.
///
/// An empty `input` returns an empty buffer without launching anything.
///
/// # Errors
///
/// Propagates allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_elu_f64(
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    let n = input.len();
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    let ctx = device.context();
    let stream = device.stream();
    // A compile failure is not fatal: execution falls through to the host path.
    if let Ok(f) = crate::module_cache::get_or_compile(
        ctx,
        ELU_F64_PTX,
        "elu_f64_kernel",
        device.ordinal() as u32,
    ) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: the pushed arguments (input ptr, output ptr, u32 length, f64 alpha)
        // must match the kernel's PTX parameter list in order and type; `out`
        // was allocated with exactly `n` elements.
        unsafe {
            stream
                .launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&alpha)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    let host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = host
        .iter()
        .map(|&x| if x > 0.0 { x } else { alpha * x.exp_m1() })
        .collect();
    cpu_to_gpu(&result, device)
}
17065
/// Gradient of ELU for f64.
///
/// Each element passes `grad` through unchanged where `input > 0`, and scales
/// it by `alpha * exp(input)` elsewhere. The `elu_backward_f64_kernel` PTX
/// kernel is used when its module compiles for `device`. On a compile failure,
/// the gradient is computed on the host instead. Empty inputs return an empty
/// buffer without launching anything.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    let len = grad.len();
    if len != input.len() {
        return Err(GpuError::LengthMismatch { a: len, b: input.len() });
    }
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_F64_PTX,
        "elu_backward_f64_kernel",
        device.ordinal() as u32,
    );
    match compiled {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let len_u32 = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: argument order and types (grad ptr, input ptr, output ptr,
            // u32 length, f64 alpha) must match the kernel's PTX parameters;
            // `out` holds `len` elements.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(grad.inner())
                    .arg(input.inner())
                    .arg(out.inner_mut())
                    .arg(&len_u32)
                    .arg(&alpha)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            let upstream = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = upstream
                .iter()
                .zip(xs.iter())
                .map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
                .collect();
            cpu_to_gpu(&result, device)
        }
    }
}
17102
/// Mish for f64: `x * tanh(softplus(x))`, with `softplus(x) = ln(1 + e^x)`.
///
/// Tries the `mish_f64_kernel` PTX kernel via `try_launch_unary_f64`. When
/// that yields `None`, the activation is computed on the host.
///
/// The host path evaluates softplus as `e^x.ln_1p()`. With the naive
/// `(1 + e^x).ln()`, strongly negative inputs (where `e^x` is below
/// `f64::EPSILON`) round `1 + e^x` to exactly `1` and give an output of exactly
/// zero. `ln_1p` keeps the small true value instead.
#[cfg(feature = "cuda")]
pub fn gpu_mish_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    if let Some(out) = try_launch_unary_f64(input, device, MISH_F64_PTX, "mish_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_unary_f64(input, device, |x| x * x.exp().ln_1p().tanh())
}
17111
/// Gradient of Mish for f64.
///
/// With `sp = softplus(x)` and `t = tanh(sp)`, each element is
/// `grad * (t + x * sigmoid(x) * (1 - t^2))`. This uses `d sp/dx = sigmoid(x)`
/// and `d tanh(u)/du = 1 - tanh(u)^2`.
///
/// Tries the `mish_backward_f64_kernel` PTX kernel first. When
/// `try_launch_binary_f64` yields `None`, the host fallback evaluates softplus
/// with [`f64::ln_1p`], which keeps precision for strongly negative inputs.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when `grad` and `input` differ in
/// length, and propagates any error from the launch or fallback helpers.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, MISH_BACKWARD_F64_PTX, "mish_backward_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let sp = x.exp().ln_1p();
        let t = sp.tanh();
        let sig = 1.0 / (1.0 + (-x).exp());
        g * (t + x * sig * (1.0 - t * t))
    })
}
17132
/// Clamp every element of `input` into `[min_val, max_val]` (f64).
///
/// The f64 kernel is derived from `CLAMP_PTX` through `get_f64_ptx`, with the
/// converted text cached in a function-local `OnceLock`. If the module fails
/// to compile, the host computes `x.max(min_val).min(max_val)` instead. Empty
/// inputs return an empty buffer without launching anything.
///
/// # Errors
///
/// Propagates allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_clamp_f64(
    input: &CudaBuffer<f64>,
    min_val: f64,
    max_val: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let len = input.len();
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CLAMP_PTX, "clamp_kernel", "clamp_f64_kernel");
    let compiled =
        crate::module_cache::get_or_compile(ctx, ptx, "clamp_f64_kernel", device.ordinal() as u32);
    match compiled {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let len_u32 = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: argument order and types (input ptr, output ptr, u32 length,
            // f64 min, f64 max) must match the kernel's PTX parameters; `out`
            // holds `len` elements.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(input.inner())
                    .arg(out.inner_mut())
                    .arg(&len_u32)
                    .arg(&min_val)
                    .arg(&max_val)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            let clamped: Vec<f64> = gpu_to_cpu(input, device)?
                .iter()
                .map(|&v| v.max(min_val).min(max_val))
                .collect();
            cpu_to_gpu(&clamped, device)
        }
    }
}
17167
/// Cumulative sum for f64 along the middle axis of an `[outer, dim_size, inner]` tensor.
///
/// Produces `outer * dim_size * inner` elements. The GPU kernel is launched
/// over `outer * inner` work items. If the PTX module cannot be compiled, the
/// scan runs on the host instead, matching the module-level contract that
/// every operation has a CPU fallback. The host path treats element
/// `(o, d, i)` as flat index `(o * dim_size + d) * inner + i`, i.e. a
/// row-major layout (assumed to match the kernel's indexing).
///
/// An empty shape returns an empty buffer without launching anything.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Without this check the kernel would
///   read past the end of the buffer.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_cumsum_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMSUM_PTX, "cumsum_kernel", "cumsum_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "cumsum_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: `input` holds at least `n` elements (checked above) and `out`
        // holds exactly `n`; argument order must match the kernel's PTX parameters.
        unsafe {
            stream
                .launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Host fallback: a sequential running sum over `dim_size` for each (outer, inner) pair.
    let host = gpu_to_cpu(input, device)?;
    let mut out = vec![0.0f64; n];
    for o in 0..outer {
        for i in 0..inner {
            let mut acc = 0.0f64;
            for d in 0..dim_size {
                let idx = (o * dim_size + d) * inner + i;
                acc += host[idx];
                out[idx] = acc;
            }
        }
    }
    cpu_to_gpu(&out, device)
}
17203
/// Cumulative product for f64 along the middle axis of an `[outer, dim_size, inner]` tensor.
///
/// Produces `outer * dim_size * inner` elements. The GPU kernel is launched
/// over `outer * inner` work items. If the PTX module cannot be compiled, the
/// scan runs on the host instead, matching the module-level contract that
/// every operation has a CPU fallback. The host path treats element
/// `(o, d, i)` as flat index `(o * dim_size + d) * inner + i`, i.e. a
/// row-major layout (assumed to match the kernel's indexing).
///
/// An empty shape returns an empty buffer without launching anything.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Without this check the kernel would
///   read past the end of the buffer.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_cumprod_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMPROD_PTX, "cumprod_kernel", "cumprod_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "cumprod_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: `input` holds at least `n` elements (checked above) and `out`
        // holds exactly `n`; argument order must match the kernel's PTX parameters.
        unsafe {
            stream
                .launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // Host fallback: a sequential running product over `dim_size` for each (outer, inner) pair.
    let host = gpu_to_cpu(input, device)?;
    let mut out = vec![0.0f64; n];
    for o in 0..outer {
        for i in 0..inner {
            let mut acc = 1.0f64;
            for d in 0..dim_size {
                let idx = (o * dim_size + d) * inner + i;
                acc *= host[idx];
                out[idx] = acc;
            }
        }
    }
    cpu_to_gpu(&out, device)
}
17239
/// Cumulative max for f64 along the middle axis of an `[outer, dim_size, inner]` tensor.
///
/// Returns `(values, indices)`, each holding `outer * dim_size * inner`
/// elements. The indices are stored as `f64`. There is no host fallback, so a
/// PTX compile failure is reported as an error. An empty shape returns two
/// empty buffers.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Without this check the kernel would
///   read past the end of the buffer.
/// * [`GpuError::PtxCompileFailed`] if `cummax_f64_kernel` cannot be built.
/// * Allocation and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_cummax_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMAX_PTX, "cummax_kernel", "cummax_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummax_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummax_f64_kernel" })?;
    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above); `out` and
    // `ind` hold exactly `n`. Argument order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
17279
/// Cumulative min for f64 along the middle axis of an `[outer, dim_size, inner]` tensor.
///
/// Returns `(values, indices)`, each holding `outer * dim_size * inner`
/// elements. The indices are stored as `f64`. There is no host fallback, so a
/// PTX compile failure is reported as an error. An empty shape returns two
/// empty buffers.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Without this check the kernel would
///   read past the end of the buffer.
/// * [`GpuError::PtxCompileFailed`] if `cummin_f64_kernel` cannot be built.
/// * Allocation and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_cummin_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMIN_PTX, "cummin_kernel", "cummin_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummin_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummin_f64_kernel" })?;
    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above); `out` and
    // `ind` hold exactly `n`. Argument order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
17319
/// Log-cumulative-sum-exp for f64 along the middle axis of an
/// `[outer, dim_size, inner]` tensor.
///
/// Along that axis, `out[d] = ln(sum_{k <= d} exp(x[k]))`. Produces
/// `outer * dim_size * inner` elements. There is no host fallback, so a PTX
/// compile failure is reported as an error. An empty shape returns an empty
/// buffer.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Without this check the kernel would
///   read past the end of the buffer.
/// * [`GpuError::PtxCompileFailed`] if `logcumsumexp_f64_kernel` cannot be built.
/// * Allocation and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, LOGCUMSUMEXP_F64_PTX, "logcumsumexp_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: `input` holds at least `n` elements (checked above) and `out`
        // holds exactly `n`; argument order must match the kernel's PTX parameters.
        unsafe {
            stream
                .launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    Err(GpuError::PtxCompileFailed { kernel: "logcumsumexp_f64_kernel" })
}
17353
17354// ---------------------------------------------------------------------------
17355// Public API -- f64 softmax / log-softmax / layernorm / rmsnorm launchers
17356// ---------------------------------------------------------------------------
17357
/// Row-wise softmax for f64 on GPU.
///
/// For each of the `rows` rows of width `cols`:
/// `out[j] = exp(x[j] - max(x)) / sum(exp(x - max(x)))`.
/// The kernel runs one block per row with 256 threads per block and reserves
/// `256 * 8` bytes of shared memory for its reductions. If
/// `softmax_f64_kernel` cannot be compiled, the same computation runs on the
/// host.
///
/// The result always holds exactly `rows * cols` elements on both paths. An
/// empty shape returns an empty buffer without launching anything, following
/// the other f64 launchers in this module.
///
/// # Errors
///
/// * Any error reported by `validate_device` for `input`.
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements. Otherwise the host path would panic and the kernel would read
///   out of bounds.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_f64(
    input: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let n = rows * cols;
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_F64_PTX,
        "softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            // Size by `rows * cols` (not `host.len()`) so both paths agree.
            let mut out = vec![0.0f64; n];
            for r in 0..rows {
                let base = r * cols;
                // Subtract the row max before exponentiating to avoid overflow.
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f64;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8, // sdata[256] f64
    };

    // SAFETY: `input` holds at least `rows * cols` elements (checked above) and
    // `out` exactly that many; argument order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17429
/// Row-wise softmax backward for f64 on GPU.
///
/// For each row: `out[j] = output[j] * (grad[j] - dot(grad_row, output_row))`.
/// The row count is `grad.len() / cols`. Any trailing `grad.len() % cols`
/// elements are left as zero in the result.
///
/// The f64 kernel is derived from `SOFTMAX_BACKWARD_PTX` via `get_f64_ptx`.
/// If it cannot be compiled, the computation runs on the host. An empty `grad`
/// returns an empty buffer without launching anything.
///
/// # Errors
///
/// * Any error reported by `validate_device` for `grad` or `output`.
/// * [`GpuError::LengthMismatch`] if `grad` and `output` differ in length, or
///   if `cols == 0` while the input is non-empty. Without the `cols` check the
///   row computation `total / cols` would panic with a division by zero.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    validate_device(output, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    if cols == 0 {
        return Err(GpuError::LengthMismatch { a: total, b: 0 });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, SOFTMAX_BACKWARD_PTX, "softmax_backward_kernel", "softmax_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `grad`, `output`, and `out` all hold `total >= rows * cols`
    // elements; argument order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17503
/// Row-wise log-softmax for f64 on GPU.
///
/// For each row: `out[j] = x[j] - (max(x) + log(sum(exp(x - max(x)))))`.
/// The row count is `input.len() / cols`. Any trailing `input.len() % cols`
/// elements are left as zero in the result.
///
/// If `log_softmax_f64_kernel` cannot be compiled, the computation runs on the
/// host. An empty `input` returns an empty buffer without launching anything.
///
/// # Errors
///
/// * Any error reported by `validate_device` for `input`.
/// * [`GpuError::LengthMismatch`] if `cols == 0` while `input` is non-empty.
///   Without this check `total / cols` would panic with a division by zero.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_f64(
    input: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let total = input.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    if cols == 0 {
        return Err(GpuError::LengthMismatch { a: total, b: 0 });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_F64_PTX,
        "log_softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f64;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `input` and `out` hold `total >= rows * cols` elements; argument
    // order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17574
/// Row-wise log-softmax backward for f64 on GPU.
///
/// For each row:
///   `sum_grad = sum(grad[j])`
///   `out[j] = grad[j] - exp(output[j]) * sum_grad`
///
/// The row count is `grad.len() / cols`. Any trailing `grad.len() % cols`
/// elements are left as zero in the result. If
/// `log_softmax_backward_f64_kernel` cannot be compiled, the computation runs
/// on the host. An empty `grad` returns an empty buffer without launching
/// anything.
///
/// # Errors
///
/// * Any error reported by `validate_device` for `grad` or `output`.
/// * [`GpuError::LengthMismatch`] if `grad` and `output` differ in length, or
///   if `cols == 0` while the input is non-empty. Without the `cols` check
///   `total / cols` would panic with a division by zero.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    validate_device(output, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    if cols == 0 {
        return Err(GpuError::LengthMismatch { a: total, b: 0 });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_F64_PTX,
        "log_softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f64;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `grad`, `output`, and `out` all hold `total >= rows * cols`
    // elements; argument order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17649
/// Row-wise LayerNorm for f64 on GPU.
///
/// `input`: `[rows * cols]`, `weight`: `[cols]`, `bias`: `[cols]`.
/// `out[j] = weight[j] * (x[j] - mean) / sqrt(var + eps) + bias[j]`, where
/// `mean` and the (biased) `var` are taken over each row.
///
/// The f64 kernel is derived from `LAYERNORM_PTX` via `get_f64_ptx`. If it
/// cannot be compiled, the computation runs on the host. An empty shape
/// returns an empty buffer without launching anything.
///
/// # Errors
///
/// * Any error reported by `validate_device` for `input`, `weight`, or `bias`.
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements, or `weight`/`bias` hold fewer than `cols`. Otherwise the host
///   path would panic on slicing and the kernel would read out of bounds.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    bias: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(weight, device)?;
    validate_device(bias, device)?;

    let n = rows * cols;
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }
    if bias.len() < cols {
        return Err(GpuError::LengthMismatch { a: bias.len(), b: cols });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_PTX, "layernorm_kernel", "layernorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f64; n];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f64 = slice.iter().sum::<f64>() / cols as f64;
                let var: f64 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / cols as f64;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: buffer sizes were checked above (input >= rows * cols,
    // weight/bias >= cols, out == rows * cols); argument order must match the
    // kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
17726
/// LayerNorm backward for f64 on GPU.
///
/// `input` and `grad_output` are `[rows * cols]`; `weight` is `[cols]`.
/// Returns `(grad_input [rows * cols], grad_weight [cols], grad_bias [cols])`.
/// `grad_weight` and `grad_bias` accumulate over all rows.
///
/// The f64 kernel is derived from `LAYERNORM_BACKWARD_PTX` via `get_f64_ptx`.
/// If it cannot be compiled, the computation runs on the host. When
/// `rows * cols == 0`, no kernel is launched: `grad_input` is empty, and
/// `grad_weight`/`grad_bias` are `cols` zeros.
///
/// # Errors
///
/// * Any error reported by `validate_device` for `input`, `grad_output`, or `weight`.
/// * [`GpuError::LengthMismatch`] if `input` or `grad_output` hold fewer than
///   `rows * cols` elements, or `weight` holds fewer than `cols`. Otherwise
///   the host path would panic on slicing and the kernel would read out of bounds.
/// * Allocation, launch, and transfer errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(grad_output, device)?;
    validate_device(weight, device)?;

    let n = rows * cols;
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if grad_output.len() < n {
        return Err(GpuError::LengthMismatch { a: grad_output.len(), b: n });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }
    if n == 0 {
        // Nothing to normalise: empty input gradient, all-zero parameter gradients.
        let empty: &[f64] = &[];
        let zeros: Vec<f64> = vec![0.0f64; cols];
        return Ok((
            cpu_to_gpu(empty, device)?,
            cpu_to_gpu(&zeros, device)?,
            cpu_to_gpu(&zeros, device)?,
        ));
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_BACKWARD_PTX, "layernorm_backward_kernel", "layernorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; n];
            let mut grad_weight = vec![0.0f64; cols];
            let mut grad_bias = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f64 = x_slice.iter().sum::<f64>() / n_f;
                let var: f64 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f64>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // sum1 = sum(dl), sum2 = sum(dl * x_hat), where dl = dL/dy * weight.
                let mut sum1 = 0.0f64;
                let mut sum2 = 0.0f64;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f64(n, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let mut grad_b = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: buffer sizes were checked above (input/grad_output >= rows * cols,
    // weight >= cols) and the outputs are allocated at their exact sizes;
    // argument order must match the kernel's PTX parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
17829
/// Row-wise RMS normalization for f64 on GPU.
///
/// `input`: `[rows * cols]`, `weight`: `[cols]`.
/// `out[j] = x[j] * rsqrt(mean(x^2) + eps) * weight[j]`.
///
/// Launches the f64 variant of `RMSNORM_PTX` with one 256-thread block per
/// row. If that module cannot be compiled/loaded, the buffers are copied to
/// the host, normalized on the CPU, and the result is uploaded again.
///
/// # Errors
///
/// - Any error from `validate_device` if `input` or `weight` does not belong
///   to `device`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements or `weight` holds fewer than `cols` elements.
/// - Allocation, transfer, and kernel-launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(weight, device)?;

    // Neither the kernel nor the host fallback sees the real buffer lengths
    // (only `rows`/`cols` are passed), so undersized buffers must be rejected
    // here: otherwise the kernel reads out of bounds and the fallback panics.
    let len = rows * cols;
    if input.len() < len {
        return Err(GpuError::LengthMismatch {
            a: input.len(),
            b: len,
        });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch {
            a: weight.len(),
            b: cols,
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_PTX, "rmsnorm_kernel", "rmsnorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: same formula, computed row by row on the CPU.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f64; len];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f64 =
                    slice.iter().map(|&x| x * x).sum::<f64>() / cols as f64;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(len, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; shared memory holds one f64 partial per thread.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: all buffers live on `device` (validated above), `input` holds at
    // least `rows * cols` elements, `weight` at least `cols`, and `out` was
    // allocated with exactly `rows * cols` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
17901
/// RMSNorm backward for f64 on GPU.
///
/// Returns `(grad_input [rows * cols], grad_weight [cols])`.
///
/// With `r = rsqrt(mean(x^2) + eps)` computed per row, the host fallback
/// produces (and the kernel is expected to match):
///
/// - `grad_input[j] = r * w[j] * go[j] - x[j] * r^3 / cols * sum_k(go[k] * x[k] * w[k])`
/// - `grad_weight[j] = sum over rows of go[j] * x[j] * r`
///
/// # Errors
///
/// - Any error from `validate_device` if `input`, `grad_output`, or `weight`
///   does not belong to `device`.
/// - [`GpuError::LengthMismatch`] if `input` or `grad_output` holds fewer than
///   `rows * cols` elements, or `weight` holds fewer than `cols` elements.
/// - Allocation, transfer, and kernel-launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(grad_output, device)?;
    validate_device(weight, device)?;

    // The kernel and the host fallback only receive `rows`/`cols`, never the
    // buffer lengths, so undersized buffers would be read out of bounds (GPU)
    // or cause a slice panic (host). Reject them up front.
    let len = rows * cols;
    if input.len() < len {
        return Err(GpuError::LengthMismatch {
            a: input.len(),
            b: len,
        });
    }
    if grad_output.len() < len {
        return Err(GpuError::LengthMismatch {
            a: grad_output.len(),
            b: len,
        });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch {
            a: weight.len(),
            b: cols,
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_BACKWARD_PTX, "rmsnorm_backward_kernel", "rmsnorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; len];
            let mut grad_weight = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f64 =
                    x_slice.iter().map(|&x| x * x).sum::<f64>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                // `dot` = sum_k go[k] * x[k] * w[k]; it couples every element
                // of the row through the shared normalizer `inv_rms`.
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    let mut grad_in = alloc_zeros_f64(len, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; shared memory holds one f64 partial per thread.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: all buffers live on `device` (validated above); `input` and
    // `grad_output` hold at least `rows * cols` elements and `weight` at least
    // `cols`; the outputs were allocated with exactly those sizes.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
17990
17991// ---------------------------------------------------------------------------
17992// Non-cuda stubs for softmax/layernorm/rmsnorm f64
17993// ---------------------------------------------------------------------------
17994
17995#[cfg(not(feature = "cuda"))]
17996pub fn gpu_softmax_f64(_input: &CudaBuffer<f64>, _rows: usize, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17997#[cfg(not(feature = "cuda"))]
17998pub fn gpu_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17999#[cfg(not(feature = "cuda"))]
18000pub fn gpu_log_softmax_f64(_input: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18001#[cfg(not(feature = "cuda"))]
18002pub fn gpu_log_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18003#[cfg(not(feature = "cuda"))]
18004pub fn gpu_layernorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _bias: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18005#[cfg(not(feature = "cuda"))]
18006pub fn gpu_layernorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18007#[cfg(not(feature = "cuda"))]
18008pub fn gpu_rmsnorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18009#[cfg(not(feature = "cuda"))]
18010pub fn gpu_rmsnorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18011
18012// ---------------------------------------------------------------------------
// Non-cuda stubs for f64 activation (GELU/SiLU/ELU/Mish), clamp, and cumulative-scan ops
18014// ---------------------------------------------------------------------------
18015
18016#[cfg(not(feature = "cuda"))]
18017pub fn gpu_gelu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18018#[cfg(not(feature = "cuda"))]
18019pub fn gpu_gelu_tanh_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18020#[cfg(not(feature = "cuda"))]
18021pub fn gpu_gelu_erf_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18022#[cfg(not(feature = "cuda"))]
18023pub fn gpu_gelu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18024#[cfg(not(feature = "cuda"))]
18025pub fn gpu_gelu_backward_tanh_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18026#[cfg(not(feature = "cuda"))]
18027pub fn gpu_gelu_backward_erf_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18028#[cfg(not(feature = "cuda"))]
18029pub fn gpu_silu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18030#[cfg(not(feature = "cuda"))]
18031pub fn gpu_silu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18032#[cfg(not(feature = "cuda"))]
18033pub fn gpu_elu_f64(_input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18034#[cfg(not(feature = "cuda"))]
18035pub fn gpu_elu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18036#[cfg(not(feature = "cuda"))]
18037pub fn gpu_mish_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18038#[cfg(not(feature = "cuda"))]
18039pub fn gpu_mish_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18040#[cfg(not(feature = "cuda"))]
18041pub fn gpu_clamp_f64(_input: &CudaBuffer<f64>, _min: f64, _max: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18042#[cfg(not(feature = "cuda"))]
18043pub fn gpu_cumsum_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18044#[cfg(not(feature = "cuda"))]
18045pub fn gpu_cumprod_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18046#[cfg(not(feature = "cuda"))]
18047pub fn gpu_cummax_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18048#[cfg(not(feature = "cuda"))]
18049pub fn gpu_cummin_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18050#[cfg(not(feature = "cuda"))]
18051pub fn gpu_logcumsumexp_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18052
18053// ---------------------------------------------------------------------------
18054// Tests -- require a real CUDA GPU
18055// ---------------------------------------------------------------------------
18056
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    /// Helper: open CUDA device 0 and upload `data` to it.
    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
        (dev, buf)
    }

    /// Round-trip helper: download a GPU buffer and compare it element-wise
    /// against the expected CPU output with an absolute tolerance of `1e-6`.
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(host.len(), expected.len(), "length mismatch");
        for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    /// Compare two f64 slices element-wise with the mixed tolerance
    /// `tol * (1 + |expected|)`: absolute near zero, relative for larger
    /// magnitudes. NaN in `got` always fails.
    fn assert_close_f64(got: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(got.len(), expected.len(), "length mismatch");
        for (i, (&g, &e)) in got.iter().zip(expected.iter()).enumerate() {
            assert!(
                (g - e).abs() <= tol * (1.0 + e.abs()),
                "element {i}: got {g}, expected {e}",
            );
        }
    }

    // -- gpu_add -------------------------------------------------------------

    #[test]
    fn add_basic() {
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn add_empty() {
        let (dev, a) = setup(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn add_large() {
        let n = 100_000;
        let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn add_length_mismatch() {
        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        let err = gpu_add(&a, &b, &dev).unwrap_err();
        match err {
            GpuError::LengthMismatch { a: 3, b: 2 } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // -- gpu_sub -------------------------------------------------------------

    #[test]
    fn sub_basic() {
        let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn sub_negative_result() {
        let a_data = vec![1.0f32, 2.0];
        let b_data = vec![5.0f32, 10.0];
        let expected: Vec<f32> = vec![-4.0, -8.0];

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    // -- gpu_mul -------------------------------------------------------------

    #[test]
    fn mul_basic() {
        let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
        let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn mul_by_zero() {
        let a_data = vec![1.0f32, 2.0, 3.0];
        let b_data = vec![0.0f32, 0.0, 0.0];
        let expected = vec![0.0f32, 0.0, 0.0];

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    // -- gpu_neg -------------------------------------------------------------

    #[test]
    fn neg_basic() {
        let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
        let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();

        let (dev, a) = setup(&a_data);
        let out = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn neg_double_negation() {
        let a_data = vec![1.0f32, -2.0, 3.0];
        let (dev, a) = setup(&a_data);
        let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
        assert_buf_eq(&neg2, &dev, &a_data);
    }

    // -- gpu_relu ------------------------------------------------------------

    #[test]
    fn relu_basic() {
        let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
        let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];

        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn relu_all_negative() {
        let a_data = vec![-5.0f32, -0.1, -100.0];
        let expected = vec![0.0f32, 0.0, 0.0];

        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn relu_all_positive() {
        let a_data = vec![0.1f32, 1.0, 100.0];

        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &a_data);
    }

    #[test]
    fn relu_empty() {
        let (dev, a) = setup(&[]);
        let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn small_matmul_2x2() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
        // C = A@B = [[19, 22], [43, 50]]
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
        let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
        assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn small_matmul_1xk_kxn() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        // A = [1, 2, 3] (1x3), B = [[1, 0], [0, 1], [1, 1]] (3x2)
        // C = [4, 5] (1x2)
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
        assert_buf_eq(&c, &dev, &[4.0, 5.0]);
    }

    #[test]
    fn small_matmul_vs_cublas() {
        // Compare our small matmul against cuBLAS for a realistic decode-step size.
        // Linear layer: [1, 64] @ [64, 64] = [1, 64]
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let m = 1;
        let k = 64;
        let n = 64;

        // Deterministic data.
        let a_data: Vec<f32> = (0..m * k)
            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
            .collect();
        let b_data: Vec<f32> = (0..k * n)
            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
            .collect();

        let a = cpu_to_gpu(&a_data, &dev).unwrap();
        let b = cpu_to_gpu(&b_data, &dev).unwrap();

        // cuBLAS reference.
        let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
        let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();

        // Our kernel.
        let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
        let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();

        assert_eq!(cublas_result.len(), our_result.len());
        for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
            assert!(
                (cb - ours).abs() < 0.1,
                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
                (cb - ours).abs()
            );
        }
    }

    // -- gpu_strided_copy (CL-496) -------------------------------------

    #[test]
    fn strided_copy_identity_contiguous_2d() {
        // 2x3 contiguous — source strides are C-contiguous.
        // Source: [0, 1, 2, 3, 4, 5]
        // Expected output == source (identity copy).
        let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &[2, 3], &[3, 1], 0, &dev)
            .expect("strided_copy identity");
        assert_buf_eq(&out, &dev, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn strided_copy_transpose_2d() {
        // Source 2x3 contiguous:
        //   [[0, 1, 2],
        //    [3, 4, 5]]
        // Transposed view shape [3, 2] with strides [1, 3]:
        //   out[i, j] = src[j, i]
        //   Expected: [[0, 3], [1, 4], [2, 5]] flat = [0, 3, 1, 4, 2, 5]
        let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy transpose");
        assert_buf_eq(&out, &dev, &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn strided_copy_sliced_column() {
        // Source 3x4 contiguous:
        //   [[0, 1, 2, 3],
        //    [4, 5, 6, 7],
        //    [8, 9, 10, 11]]
        // Select column 2 via src_offset=2, shape=[3], stride=[4]:
        //   Expected: [2, 6, 10]
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out = gpu_strided_copy(&input, &[3], &[4], 2, &dev)
            .expect("strided_copy col slice");
        assert_buf_eq(&out, &dev, &[2.0, 6.0, 10.0]);
    }

    #[test]
    fn strided_copy_3d_permute() {
        // Source [2, 3, 4] contiguous, C-strides [12, 4, 1].
        // Permute (0, 2, 1) → view shape [2, 4, 3] with strides [12, 1, 4].
        //
        // out[b, i, j] = src[b, j, i]
        //
        // Build expected by doing the permute on the host.
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        let out =
            gpu_strided_copy(&input, &[2, 4, 3], &[12, 1, 4], 0, &dev).expect("strided_copy 3d");

        let mut expected = vec![0.0f32; 24];
        for b in 0..2 {
            for i in 0..4 {
                for j in 0..3 {
                    let dst = b * 12 + i * 3 + j;
                    let src = b * 12 + j * 4 + i;
                    expected[dst] = data[src];
                }
            }
        }
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn strided_copy_4d_identity() {
        // Rank-4 identity copy (the rank limit itself is 8; see
        // `strided_copy_rejects_too_many_dims`).
        let shape = [2usize, 3, 2, 2];
        let n: usize = shape.iter().product();
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let (dev, input) = setup(&data);
        // C-contiguous strides: [12, 4, 2, 1]
        let out = gpu_strided_copy(&input, &shape, &[12, 4, 2, 1], 0, &dev)
            .expect("strided_copy 4d");
        assert_buf_eq(&out, &dev, &data);
    }

    #[test]
    fn strided_copy_rejects_too_many_dims() {
        let (dev, input) = setup(&[0.0f32; 16]);
        // 9 dims > STRIDED_COPY_MAX_DIMS (8)
        let result = gpu_strided_copy(
            &input,
            &[1, 1, 1, 1, 1, 1, 1, 1, 16],
            &[1; 9],
            0,
            &dev,
        );
        assert!(result.is_err());
    }

    #[test]
    fn strided_copy_rejects_shape_stride_length_mismatch() {
        let (dev, input) = setup(&[0.0f32; 12]);
        let result = gpu_strided_copy(&input, &[3, 4], &[4, 1, 1], 0, &dev);
        assert!(result.is_err());
    }

    #[test]
    fn strided_copy_rejects_negative_stride() {
        let (dev, input) = setup(&[0.0f32; 6]);
        let result = gpu_strided_copy(&input, &[2, 3], &[3, -1], 0, &dev);
        assert!(result.is_err());
    }

    #[test]
    fn strided_copy_empty_output() {
        let (dev, input) = setup(&[1.0f32, 2.0, 3.0]);
        let out = gpu_strided_copy(&input, &[0, 3], &[3, 1], 0, &dev)
            .expect("strided_copy empty");
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn strided_copy_f64_transpose_matches_f32() {
        // Same transpose test as the f32 version, using f64.
        let data: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let input = cpu_to_gpu(&data, &dev).expect("cpu_to_gpu f64");
        let out = gpu_strided_copy_f64(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy_f64 transpose");
        let host = gpu_to_cpu(&out, &dev).expect("gpu_to_cpu f64");
        assert_eq!(host, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    // -- gpu_rmsnorm_f64 / gpu_rmsnorm_backward_f64 ---------------------

    /// Deterministic `(x, weight, grad_output)` fixture for a 3x5 RMSNorm.
    /// Rows are chosen so every row has a non-zero mean square.
    fn rmsnorm_fixture() -> (Vec<f64>, Vec<f64>, Vec<f64>) {
        let x: Vec<f64> = (0..15).map(|i| (i as f64 - 7.0) * 0.25).collect();
        let w = vec![0.5f64, 1.0, 1.5, 2.0, -1.0];
        let go: Vec<f64> = (0..15usize)
            .map(|i| ((i * 3 % 7) as f64 - 3.0) * 0.1)
            .collect();
        (x, w, go)
    }

    #[test]
    fn rmsnorm_f64_matches_host_reference() {
        let (rows, cols, eps) = (3usize, 5usize, 1e-5f64);
        let (x, w, _) = rmsnorm_fixture();

        // Host reference: y = x * rsqrt(mean(x^2) + eps) * w, per row.
        let mut expected = vec![0.0f64; rows * cols];
        for r in 0..rows {
            let row = &x[r * cols..(r + 1) * cols];
            let mean_sq = row.iter().map(|v| v * v).sum::<f64>() / cols as f64;
            let inv_rms = 1.0 / (mean_sq + eps).sqrt();
            for c in 0..cols {
                expected[r * cols + c] = row[c] * inv_rms * w[c];
            }
        }

        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let x_gpu = cpu_to_gpu(&x, &dev).expect("cpu_to_gpu x");
        let w_gpu = cpu_to_gpu(&w, &dev).expect("cpu_to_gpu w");
        let out = gpu_rmsnorm_f64(&x_gpu, &w_gpu, rows, cols, eps, &dev).expect("rmsnorm_f64");
        let host = gpu_to_cpu(&out, &dev).expect("gpu_to_cpu out");
        assert_close_f64(&host, &expected, 1e-5);
    }

    #[test]
    fn rmsnorm_backward_f64_matches_host_reference() {
        let (rows, cols, eps) = (3usize, 5usize, 1e-5f64);
        let (x, w, go) = rmsnorm_fixture();
        let n = cols as f64;

        // Analytic gradients of y_k = x_k * r * w_k with
        // r = (mean(x^2) + eps)^(-1/2), so dr/dx_j = -x_j * r^3 / n.
        let mut exp_gi = vec![0.0f64; rows * cols];
        let mut exp_gw = vec![0.0f64; cols];
        for r in 0..rows {
            let xr = &x[r * cols..(r + 1) * cols];
            let gr = &go[r * cols..(r + 1) * cols];
            let mean_sq = xr.iter().map(|v| v * v).sum::<f64>() / n;
            let inv_rms = 1.0 / (mean_sq + eps).sqrt();
            let dot: f64 = (0..cols).map(|c| gr[c] * w[c] * xr[c]).sum();
            for c in 0..cols {
                exp_gi[r * cols + c] =
                    inv_rms * w[c] * gr[c] - xr[c] * dot * inv_rms.powi(3) / n;
                exp_gw[c] += gr[c] * xr[c] * inv_rms;
            }
        }

        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let x_gpu = cpu_to_gpu(&x, &dev).expect("cpu_to_gpu x");
        let go_gpu = cpu_to_gpu(&go, &dev).expect("cpu_to_gpu go");
        let w_gpu = cpu_to_gpu(&w, &dev).expect("cpu_to_gpu w");
        let (gi, gw) = gpu_rmsnorm_backward_f64(&x_gpu, &go_gpu, &w_gpu, rows, cols, eps, &dev)
            .expect("rmsnorm_backward_f64");
        let gi_host = gpu_to_cpu(&gi, &dev).expect("gpu_to_cpu grad_input");
        let gw_host = gpu_to_cpu(&gw, &dev).expect("gpu_to_cpu grad_weight");
        assert_close_f64(&gi_host, &exp_gi, 1e-5);
        assert_close_f64(&gw_host, &exp_gw, 1e-5);
    }
}