Skip to main content

ferrotorch_gpu/
kernels.rs

1//! Custom PTX CUDA kernels for elementwise GPU operations.
2//!
3//! Each operation has two code paths:
4//!
5//! 1. **PTX kernel** -- a hand-written PTX string that is loaded into the CUDA
6//!    driver at runtime via [`cudarc`]. This is the fast path and runs entirely
7//!    on the GPU.
8//! 2. **CPU fallback** -- copies data to the host, performs the operation with
9//!    standard Rust iterators, and copies the result back. Correct but slow;
10//!    used when the PTX module cannot be loaded (e.g. architecture mismatch).
11//!
12//! # Supported operations
13//!
14//! | Function | Formula |
15//! |----------|---------|
16//! | [`gpu_add`] | `out[i] = a[i] + b[i]` |
17//! | [`gpu_sub`] | `out[i] = a[i] - b[i]` |
18//! | [`gpu_mul`] | `out[i] = a[i] * b[i]` |
19//! | [`gpu_neg`] | `out[i] = -a[i]` |
20//! | [`gpu_relu`] | `out[i] = max(a[i], 0.0)` |
21
22#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, cpu_to_gpu, gpu_to_cpu};
30
31// ---------------------------------------------------------------------------
32// f32 → f64 PTX auto-conversion
33// ---------------------------------------------------------------------------
34
/// Derive the f64 variant of a "simple" f32 PTX kernel by plain text substitution.
///
/// The entry point is renamed first (`f32_kernel_name` -> `f64_kernel_name`).
/// Then each `(from, to)` pair in the local `RULES` table is applied with
/// [`str::replace`], strictly in table order. The rules cover:
/// - register, param and shared-memory declarations;
/// - global, shared and param loads/stores;
/// - scalar arithmetic, comparisons and int->float `cvt`;
/// - `mov.b32`;
/// - a handful of common hex float literals.
///
/// Limitations visible in the table:
/// - Transcendentals such as `ex2.approx.f32` / `lg2.approx.f32` are not
///   rewritten; those kernels need hand-written f64 versions.
/// - The 4-byte -> 8-byte element stride is only fixed for the exact text
///   `shl.b64 %off, %off, 2`. Offsets computed in other registers
///   (e.g. `%off_in`) keep their f32 stride.
/// - Vector accesses (`ld.global.v4.f32`, ...) are left unchanged.
/// - Literals are matched in upper-case hex only.
#[cfg(feature = "cuda")]
pub(crate) fn ptx_f32_to_f64(f32_ptx: &str, f32_kernel_name: &str, f64_kernel_name: &str) -> String {
    // Ordered substitution table; applied one rule at a time, exactly like a
    // chain of `.replace` calls.
    const RULES: &[(&str, &str)] = &[
        // Register declarations
        (".reg .f32", ".reg .f64"),
        // Memory operations
        ("ld.global.f32", "ld.global.f64"),
        ("st.global.f32", "st.global.f64"),
        ("ld.shared.f32", "ld.shared.f64"),
        ("st.shared.f32", "st.shared.f64"),
        ("ld.param.f32", "ld.param.f64"),
        (".param .f32", ".param .f64"),
        // Shared memory declarations
        (".shared .align 4 .f32", ".shared .align 8 .f64"),
        // Arithmetic (note: "add.f32" also rewrites "atom.global.add.f32")
        ("add.f32", "add.f64"),
        ("sub.f32", "sub.f64"),
        ("mul.f32", "mul.f64"),
        ("div.rn.f32", "div.rn.f64"),
        ("div.f32", "div.f64"),
        ("neg.f32", "neg.f64"),
        ("abs.f32", "abs.f64"),
        ("max.f32", "max.f64"),
        ("min.f32", "min.f64"),
        ("sqrt.rn.f32", "sqrt.rn.f64"),
        ("sqrt.f32", "sqrt.f64"),
        ("fma.rn.f32", "fma.rn.f64"),
        ("mov.f32", "mov.f64"),
        // Comparisons
        ("setp.gt.f32", "setp.gt.f64"),
        ("setp.ge.f32", "setp.ge.f64"),
        ("setp.lt.f32", "setp.lt.f64"),
        ("setp.le.f32", "setp.le.f64"),
        ("setp.eq.f32", "setp.eq.f64"),
        ("setp.ne.f32", "setp.ne.f64"),
        // Conversions
        ("cvt.rn.f32.u32", "cvt.rn.f64.u32"),
        ("cvt.rn.f32.s32", "cvt.rn.f64.s32"),
        // Bit reinterpretation (for NaN/inf checks)
        ("mov.b32", "mov.b64"),
        // Byte offset: 4 bytes per f32 -> 8 bytes per f64 (only `%off`)
        ("shl.b64 %off, %off, 2", "shl.b64 %off, %off, 3"),
        // Atomics (already handled by the "add.f32" rule; kept for clarity)
        ("atom.global.add.f32", "atom.global.add.f64"),
        // Common float hex literals
        ("0f00000000", "0d0000000000000000"), // 0.0
        ("0f3F800000", "0d3FF0000000000000"), // 1.0
        ("0fBF800000", "0dBFF0000000000000"), // -1.0
        ("0f40000000", "0d4000000000000000"), // 2.0
        ("0f3F000000", "0d3FE0000000000000"), // 0.5
        ("0fFF800000", "0dFFF0000000000000"), // -inf
        ("0f7F800000", "0d7FF0000000000000"), // +inf
        ("0f3FB8AA3B", "0d3FF71547652B82FE"), // log2(e)
        ("0f3F317218", "0d3FE62E42FEFA39EF"), // ln(2)
    ];

    let renamed = f32_ptx.replace(f32_kernel_name, f64_kernel_name);
    RULES
        .iter()
        .fold(renamed, |ptx, &(from, to)| ptx.replace(from, to))
}
99
/// Return the f64 PTX derived from `f32_ptx`, running the conversion at most once.
///
/// On the first call, `cache` is filled with
/// `ptx_f32_to_f64(f32_ptx, f32_name, f64_name)`. Every later call returns the
/// stored text and ignores `f32_ptx`, `f32_name` and `f64_name`. Each kernel
/// therefore needs its own dedicated `OnceLock`.
///
/// The returned slice borrows from `cache`. It lives for the whole program
/// only when `cache` is itself a `static`.
#[cfg(feature = "cuda")]
pub(crate) fn get_f64_ptx<'a>(
    cache: &'a std::sync::OnceLock<String>,
    f32_ptx: &str,
    f32_name: &str,
    f64_name: &str,
) -> &'a str {
    let converted: &'a String =
        cache.get_or_init(|| ptx_f32_to_f64(f32_ptx, f32_name, f64_name));
    converted.as_str()
}
113
114// ---------------------------------------------------------------------------
115// PTX kernel source strings
116// ---------------------------------------------------------------------------
117
/// PTX source for `add_kernel`: `out[i] = a[i] + b[i]`.
///
/// Parameters:
/// - `a_ptr`, `b_ptr`, `out_ptr`: device addresses of f32 arrays.
/// - `n`: the element count.
///
/// Each thread computes its global index `ctaid.x * ntid.x + tid.x` (32-bit)
/// and handles at most one element; there is no loop. The launch must
/// therefore provide at least `n` threads. Threads whose index is `>= n`
/// exit without touching memory.
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    add.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
165
166
/// PTX source for `add_vec4_kernel`: vectorized add, 4 elements per thread.
///
/// Uses `ld.global.v4.f32` / `st.global.v4.f32` (128-bit accesses), so each
/// thread issues one wide load per input instead of four scalar ones.
/// Thread i processes elements [i*4 .. i*4+3].
///
/// `n4` is the number of 4-element groups. Threads with index `>= n4` exit,
/// so elements past `4 * n4` are never touched.
/// NOTE(review): presumably the caller handles the tail separately — confirm.
///
/// PTX requires vector accesses to be naturally aligned, so `a_ptr`, `b_ptr`
/// and `out_ptr` must be 16-byte aligned.
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    // Byte offset = tid * 16 (4 floats × 4 bytes)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    add.f32 %r0, %a0, %b0;
    add.f32 %r1, %a1, %b1;
    add.f32 %r2, %a2, %b2;
    add.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
223
/// PTX source for `mul_vec4_kernel`: vectorized multiply, 4 elements per thread.
///
/// Thread i computes `out[i*4 + k] = a[i*4 + k] * b[i*4 + k]` for k in 0..4,
/// using 128-bit `v4.f32` loads and stores.
///
/// `n4` is the number of 4-element groups; elements past `4 * n4` are not
/// touched. All three pointers must be 16-byte aligned, as PTX requires for
/// vector accesses.
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    mul.f32 %r0, %a0, %b0;
    mul.f32 %r1, %a1, %b1;
    mul.f32 %r2, %a2, %b2;
    mul.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
276
/// PTX source for `sub_kernel`: `out[i] = a[i] - b[i]`.
///
/// Same launch contract as the other scalar elementwise kernels:
/// - one f32 element per thread, addressed by the 32-bit global index
///   `ctaid.x * ntid.x + tid.x`;
/// - threads with index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    sub.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
324
325
/// PTX source for `mul_kernel`: `out[i] = a[i] * b[i]`.
///
/// One f32 element per thread (32-bit global index
/// `ctaid.x * ntid.x + tid.x`). Threads with index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    mul.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
373
374
/// PTX source for `neg_kernel`: `out[i] = -a[i]`.
///
/// Unary kernel: only `a_ptr` and `out_ptr` plus the element count `n`.
/// One f32 element per thread; threads with index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    neg.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
418
419
/// PTX source for `relu_kernel`: `out[i] = max(a[i], 0.0)`.
///
/// Uses `max.f32` against a zero register. Per the PTX ISA, `max.f32` ignores
/// a single NaN operand and returns the other one. A NaN input therefore
/// yields `0.0` rather than propagating NaN.
///
/// One f32 element per thread; threads with index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %zero;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mov.f32 %zero, 0f00000000;
    max.f32 %vr, %va, %zero;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
464
465
/// PTX source for `scale_kernel`: `out[i] = a[i] * scalar`.
///
/// `scalar` is passed by value as an `.f32` kernel parameter (not through a
/// device pointer). One f32 element per thread; threads with index `>= n`
/// exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %s;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %s, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mul.f32 %vr, %va, %s;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
511
512
/// PTX for 2D matrix transpose: `out[j * M + i] = in[i * N + j]`.
/// Thread `tid` maps to output index; computes the corresponding input index.
///
/// The input is a row-major `[M, N]` f32 matrix and the output is row-major
/// `[N, M]`. Each thread writes exactly one output element: it splits `tid`
/// into `(tid / M, tid % M)` and reads the transposed input element.
/// `total` bounds the thread count and is presumably `M * N` — TODO confirm
/// at the call site. Reads are strided (not coalesced); there is no
/// shared-memory tiling.
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 M,\n\
    .param .u32 N,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
    .reg .u32 %out_row, %out_col, %in_idx;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %M_reg, [M];\n\
    ld.param.u32 %N_reg, [N];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
    div.u32 %out_row, %r_tid, %M_reg;\n\
    rem.u32 %out_col, %r_tid, %M_reg;\n\
    // Input index: out_col * N + out_row (transposed).\n\
    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
568
569
570// ---------------------------------------------------------------------------
571// 4D permute (0,2,1,3) PTX kernel — swap dims 1 and 2
572// ---------------------------------------------------------------------------
573// Input:  [d0, d1, d2, d3]
574// Output: [d0, d2, d1, d3]
575// Thread i computes output[i] by mapping to the transposed input index.
576
/// PTX source for `permute_0213_kernel`: swaps axes 1 and 2 of a row-major
/// f32 tensor, `[d0, d1, d2, d3] -> [d0, d2, d1, d3]`.
///
/// Thread `tid` owns output element `tid`:
/// - It decomposes `tid` into `(i0, i2, i1, i3)` using the output strides
///   `(d2*d1*d3, d1*d3, d3, 1)`.
/// - It then reads input element `(i0*d1 + i1) * (d2*d3) + i2*d3 + i3`.
///
/// `total` bounds the thread count and is presumably `d0*d1*d2*d3` — TODO
/// confirm. The `d0` parameter is loaded but not used in the index math.
/// All index arithmetic is 32-bit (`mul.lo.u32`), so element counts at or
/// above `2^32` would wrap.
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 d0,\n\
    .param .u32 d1,\n\
    .param .u32 d2,\n\
    .param .u32 d3,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %d0r, [d0];\n\
    ld.param.u32 %d1r, [d1];\n\
    ld.param.u32 %d2r, [d2];\n\
    ld.param.u32 %d3r, [d3];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape: [d0, d2, d1, d3]\n\
    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
    div.u32 %i0, %r_tid, %s_out1;\n\
    rem.u32 %rem, %r_tid, %s_out1;\n\
    div.u32 %i2, %rem, %s_out2;\n\
    rem.u32 %rem, %rem, %s_out2;\n\
    div.u32 %i1, %rem, %d3r;\n\
    rem.u32 %i3, %rem, %d3r;\n\
\n\
    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
    mul.lo.u32 %in_idx, %i0, %d1r;\n\
    add.u32 %in_idx, %in_idx, %i1;\n\
    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
    add.u32 %in_idx, %in_idx, %i3;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
650
651
652// ---------------------------------------------------------------------------
653// f32-to-f16 conversion PTX kernel: out_f16[i] = float2half(in_f32[i])
654// ---------------------------------------------------------------------------
655// Used by gpu_matmul_f16 to cast f32 inputs to f16 on-GPU before calling
656// cublasGemmEx. The output is stored as u16 (IEEE 754 half-precision bits).
657
/// PTX source for `f32_to_f16_kernel`: `out[i] = f16(in[i])`.
///
/// Reads `n` f32 values (4-byte stride) and writes `n` IEEE-754
/// half-precision bit patterns (2-byte stride). The conversion is
/// `cvt.rn.f16.f32`, i.e. round-to-nearest-even. One element per thread;
/// threads with index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .b16 %vh;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Compute input offset: i * 4 (f32 = 4 bytes)
    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    // Compute output offset: i * 2 (f16 = 2 bytes)
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32, convert to f16 (round-to-nearest-even), store as u16
    ld.global.f32 %vf, [%in];
    cvt.rn.f16.f32 %vh, %vf;
    st.global.b16 [%out], %vh;

DONE:
    ret;
}
";
706
/// PTX source for `f32_to_bf16_kernel`: convert f32 → bf16 (stored as u16).
///
/// BF16 is the top 16 bits of f32 with round-to-nearest-even. We do this
/// with integer bit ops: add rounding bias 0x7FFF + bit 16 of the value,
/// then shift right 16. This works on sm_52+ (no special bf16 instructions
/// needed).
///
/// NaN inputs must bypass that rounding add. Otherwise the carry out of the
/// low 16 bits spills into the exponent or the sign:
/// - `0x7F800001` would become `0x7F80` (+inf);
/// - `0x7FFFFFFF` would become `0x8000` (-0.0).
///
/// NaNs (`|bits| > 0x7F800000`) are instead truncated, and the bf16 quiet bit
/// (`0x0040`) is forced on. The result is always a NaN that keeps the input's
/// sign. Infinities and finite values (including overflow to inf) still go
/// through the normal rounding path.
///
/// Parameters:
/// - `in_ptr`: device address of `n` f32 values.
/// - `out_ptr`: device address of `n` u16 slots.
/// - `n`: element count.
///
/// One element per thread; threads with index `>= n` exit immediately.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .u32 %bits, %round, %lsb, %result, %abs, %nan_bits;
    .reg .pred %p, %is_nan;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32 as raw bits
    ld.global.u32 %bits, [%in];

    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
    shr.u32 %lsb, %bits, 16;
    and.b32 %lsb, %lsb, 1;
    add.u32 %round, %bits, 0x7FFF;
    add.u32 %round, %round, %lsb;
    shr.u32 %result, %round, 16;

    // NaN (|bits| > 0x7F800000): the rounding carry could turn it into inf
    // or -0.0, so truncate instead and force the bf16 quiet bit.
    and.b32 %abs, %bits, 0x7FFFFFFF;
    setp.gt.u32 %is_nan, %abs, 0x7F800000;
    shr.u32 %nan_bits, %bits, 16;
    or.b32 %nan_bits, %nan_bits, 0x40;
    selp.u32 %result, %nan_bits, %result, %is_nan;

    // Store as u16
    st.global.u16 [%out], %result;

DONE:
    ret;
}
";
767
768// ---------------------------------------------------------------------------
769// Small matmul PTX kernel: C = A @ B, one thread per output element
770// ---------------------------------------------------------------------------
771// For small matrices where cuBLAS JIT compilation overhead > compute time.
772// Compiles once via module_cache, never JIT-recompiles for different sizes.
773
/// PTX source for `small_matmul_kernel`: `C = A @ B`, with one thread per
/// output element and no shared-memory tiling.
///
/// The indexing implies row-major f32 `A: [M, K]`, `B: [K, N]`, `C: [M, N]`.
/// Thread `tid` computes `C[tid / N][tid % N]` as a sequential
/// `fma.rn.f32` dot product over `K`. Threads with `tid >= total` exit;
/// `total` is presumably `M * N` — TODO confirm at the call site.
///
/// The `M` parameter is loaded into `%M_reg` but never used; the row bound
/// comes only from `total`.
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 K,
    .param .u32 N,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
    .reg .u32 %row, %col, %p, %idx;
    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
    .reg .f32 %sum, %va, %vb;
    .reg .pred %bounds_p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %c, [c_ptr];
    ld.param.u32 %M_reg, [M];
    ld.param.u32 %K_reg, [K];
    ld.param.u32 %N_reg, [N];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    div.u32 %row, %r_tid, %N_reg;
    rem.u32 %col, %r_tid, %N_reg;

    mov.f32 %sum, 0f00000000;
    mov.u32 %p, 0;
DOT:
    setp.ge.u32 %loop_p, %p, %K_reg;
    @%loop_p bra DOT_DONE;

    mad.lo.u32 %idx, %row, %K_reg, %p;
    cvt.u64.u32 %a_off, %idx;
    shl.b64 %a_off, %a_off, 2;
    add.u64 %a_off, %a, %a_off;
    ld.global.f32 %va, [%a_off];

    mad.lo.u32 %idx, %p, %N_reg, %col;
    cvt.u64.u32 %b_off, %idx;
    shl.b64 %b_off, %b_off, 2;
    add.u64 %b_off, %b, %b_off;
    ld.global.f32 %vb, [%b_off];

    fma.rn.f32 %sum, %va, %vb, %sum;
    add.u32 %p, %p, 1;
    bra DOT;
DOT_DONE:

    cvt.u64.u32 %c_off, %r_tid;
    shl.b64 %c_off, %c_off, 2;
    add.u64 %c_off, %c, %c_off;
    st.global.f32 [%c_off], %sum;

DONE:
    ret;
}
";
846
847// ---------------------------------------------------------------------------
848// Slice-write PTX kernel: copy [N, D] into row `pos` of [N, max_len, D]
849// ---------------------------------------------------------------------------
850
/// PTX source for `slice_write_kernel`: scatters a row-major `[N, D]` f32
/// source into position `pos` along the middle axis of a row-major
/// `[N, max_len, D]` destination.
///
/// Thread `tid < n` copies `src[tid]` to
/// `dst[((tid / D) * max_len + pos) * D + tid % D]`. `n` is the number of
/// source elements (presumably `N * D` — TODO confirm).
///
/// `pos` is not checked against `max_len`; the caller must guarantee
/// `pos < max_len`.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u32 pos
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u32 %pos_reg, [pos];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
906
907
/// PTX for `slice_write_indirect_kernel`: same as `slice_write_kernel` but
/// reads `pos` from a device pointer. This enables CUDA graph capture — the
/// graph records the pointer address (fixed), and we update the u32 value
/// at that address before each graph replay.
///
/// `pos_ptr` must point to a single u32 in global memory. Every thread,
/// including those past `n`, loads it before the bounds check. As in the
/// direct variant, `pos` is not validated against `max_len`.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u64 pos_ptr
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u64 %pos_p, [pos_ptr];
    ld.global.u32 %pos_reg, [%pos_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
968
/// PTX for `causal_mask_indirect_kernel`: builds an attention mask where
/// `out[h, col] = 0.0` for `col < total_len` and `-1e9` for `col >= total_len`.
/// `total_len` is read from a device pointer (for CUDA graph capture).
/// Output shape: `[n_head, max_pos]` — one mask row per head (all identical).
/// Thread `tid` maps to flat index; column = `tid % max_pos`.
///
/// Parameters:
/// - `total_len_ptr`: device address of a single u32.
/// - `out_ptr`: f32 output.
/// - `max_pos`: row length.
/// - `total`: number of output elements to write (presumably
///   `n_head * max_pos` — TODO confirm).
///
/// Every thread loads `total_len` before the bounds check. The masked value
/// is the f32 literal `0fCE6E6B28`, which is exactly `-1.0e9`.
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
    .param .u64 total_len_ptr,
    .param .u64 out_ptr,
    .param .u32 max_pos,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
    .reg .u64 %out, %off, %tl_p;
    .reg .f32 %val;
    .reg .pred %bounds_p, %mask_p;

    ld.param.u64 %tl_p, [total_len_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %max_pos_reg, [max_pos];
    ld.param.u32 %total_reg, [total];

    ld.global.u32 %tlen, [%tl_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    rem.u32 %col, %r_tid, %max_pos_reg;
    setp.lt.u32 %mask_p, %col, %tlen;
    @%mask_p bra WRITE_ZERO;

    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
    // to effectively zero out masked positions after softmax.
    mov.f32 %val, 0fCE6E6B28;
    bra WRITE;

WRITE_ZERO:
    mov.f32 %val, 0f00000000;

WRITE:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %val;

DONE:
    ret;
}
";
1028
1029// ---------------------------------------------------------------------------
1030// Embedding lookup PTX kernel: output[d] = weight[token_id * D + d]
1031// ---------------------------------------------------------------------------
1032
/// PTX source for `embed_lookup_kernel`: copies one embedding row,
/// `out[d] = weight[row * D + d]`.
///
/// The row index is the single f32 stored at `idx_ptr[0]`. It is converted
/// with `cvt.rzi.u32.f32` (truncate toward zero; per the PTX ISA,
/// float->integer `cvt` saturates to the u32 range).
///
/// Launch with at least `D` threads; threads with index `>= D` exit.
///
/// The row is not checked against the vocabulary size, so an out-of-range
/// index reads past the weight matrix. `row * D + d` is computed in 32 bits.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %D_reg;
    @%p bra DONE;

    ld.global.f32 %idx_f, [%idx_addr];
    cvt.rzi.u32.f32 %row, %idx_f;

    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1081
1082
1083// ---------------------------------------------------------------------------
1084// Batch embedding lookup PTX kernel
1085// ---------------------------------------------------------------------------
1086// Given N f32 indices and a weight matrix [V, D], gather N rows into [N, D].
1087// Thread `tid` computes one element: row = tid / D, col = tid % D.
1088// out[tid] = weight[indices[row] * D + col]
1089
/// PTX source for `embed_lookup_batch_kernel`: gathers one embedding row per
/// token index for a whole batch.
///
/// Kernel parameters, in launch order:
/// - `idx_ptr`: `N` token indices stored as `f32` (truncated toward zero).
/// - `weight_ptr`: row-major `f32` weight matrix with `D` columns.
/// - `out_ptr`: `N * D` `f32` outputs.
/// - `D`: embedding width.
/// - `total`: number of output elements (`N * D`); threads with global id
///   `>= total` exit immediately.
///
/// Global thread `i` writes `out[i] = weight[indices[i / D] * D + i % D]`.
///
/// The global thread index lives in `%r_tid`, as in every other kernel in
/// this module. `%tid` is PTX's predefined special-register vector, so it
/// cannot also be declared as a scalar `.reg` and then be read as `%tid.x`.
///
/// NOTE(review): indices are not bounds-checked against the vocabulary size,
/// and element offsets are computed in 32 bits. Vocabulary size times `D` is
/// assumed to fit in a `u32`; confirm at the call site.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // row = r_tid / D, col = r_tid % D
    div.u32 %row, %r_tid, %D_reg;
    rem.u32 %col, %r_tid, %D_reg;

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %src_idx, %idx_f;

    // src_idx = indices[row] * D + col
    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    // Write to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1151
1152
1153// ---------------------------------------------------------------------------
1154// Scatter-add rows PTX kernel (for embedding backward)
1155// ---------------------------------------------------------------------------
1156// Given grad_output [N, D] and indices [N] (f32), atomically accumulate:
1157//   grad_weight[indices[row], col] += grad_output[row * D + col]
1158// Thread `tid` handles one element: row = tid / D, col = tid % D.
1159
/// PTX source for `scatter_add_rows_kernel`: the row scatter-add used by the
/// embedding backward pass.
///
/// Kernel parameters, in launch order:
/// - `grad_output_ptr`: `N * D` `f32` upstream gradients, row-major.
/// - `indices_ptr`: `N` token indices stored as `f32` (truncated toward zero).
/// - `grad_weight_ptr`: `f32` gradient matrix with `D` columns, accumulated
///   into in place.
/// - `D`: embedding width.
/// - `total`: `N * D`; threads with global id `>= total` exit immediately.
///
/// Global thread `i` performs
/// `grad_weight[indices[i / D] * D + i % D] += grad_output[i]` via
/// `atom.global.add.f32`. Repeated indices therefore accumulate correctly,
/// although the floating-point summation order is not deterministic.
///
/// The global thread index lives in `%r_tid`, as in every other kernel in
/// this module. `%tid` is PTX's predefined special-register vector, so it
/// cannot also be declared as a scalar `.reg` and then be read as `%tid.x`.
///
/// NOTE(review): indices are not bounds-checked, and element offsets are
/// computed in 32 bits. The caller presumably guarantees valid indices and a
/// `u32`-addressable gradient table; confirm.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_weight_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %dst_idx;
    .reg .u64 %go, %idx_addr, %gw, %off;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %idx_addr, [indices_ptr];
    ld.param.u64 %gw, [grad_weight_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // row = r_tid / D, col = r_tid % D
    div.u32 %row, %r_tid, %D_reg;
    rem.u32 %col, %r_tid, %D_reg;

    // Read grad_output[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %go, %off;
    ld.global.f32 %grad_val, [%off];

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %dst_idx, %idx_f;

    // dst_idx = indices[row] * D + col
    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %gw, %off;
    atom.global.add.f32 %dummy, [%off], %grad_val;

DONE:
    ret;
}
";
1221
1222
1223// ---------------------------------------------------------------------------
1224// Slice-read PTX kernel: read first `len` rows from [N, max_len, D]
1225// ---------------------------------------------------------------------------
1226// Thread i writes: dst[i] = src[batch_idx * max_len * D + (i % (len*D))]
1227// where batch_idx = i / (len * D)
1228
/// PTX source for `slice_read_kernel`: copies the first `len` rows of each
/// `[max_len, D]` slab of `src` into a densely packed `dst`.
///
/// Kernel parameters, in launch order:
/// - `src_ptr`, `dst_ptr`: `f32` buffers.
/// - `total`: number of elements written; threads with global id
///   `>= total` exit immediately.
/// - `D`: row width.
/// - `len`: rows copied per slab.
/// - `max_len`: rows per slab in `src`.
///
/// Thread `i` computes:
/// - `b = i / (len * D)`
/// - `within = i % (len * D)`
/// - `row = within / D`
/// - `col = within % D`
///
/// It then writes `dst[i] = src[(b * max_len + row) * D + col]`.
///
/// NOTE(review): all index arithmetic is 32-bit, and `len * D` is used as a
/// divisor. Presumably callers never launch with `len == 0` or `D == 0`
/// while `total > 0`; confirm.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 total,
    .param .u32 D,
    .param .u32 len,
    .param .u32 max_len
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
    .reg .u32 %len_d;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %total_reg, [total];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %len_reg, [len];
    ld.param.u32 %max_len_reg, [max_len];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // dst index = r_tid
    // batch_idx = r_tid / (len * D)
    // within = r_tid % (len * D)
    // row = within / D
    // col = within % D
    // src_idx = batch_idx * max_len * D + row * D + col
    mul.lo.u32 %len_d, %len_reg, %D_reg;
    div.u32 %batch_idx, %r_tid, %len_d;
    rem.u32 %within, %r_tid, %len_d;
    div.u32 %row, %within, %D_reg;
    rem.u32 %col, %within, %D_reg;

    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
    add.u32 %src_idx, %src_idx, %row;
    mul.lo.u32 %src_idx, %src_idx, %D_reg;
    add.u32 %src_idx, %src_idx, %col;

    cvt.u64.u32 %src_off, %src_idx;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src_off, %src, %src_off;
    ld.global.f32 %val, [%src_off];

    cvt.u64.u32 %dst_off, %r_tid;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst_off, %dst, %dst_off;
    st.global.f32 [%dst_off], %val;

DONE:
    ret;
}
";
1296
1297
1298// ---------------------------------------------------------------------------
// GELU PTX kernel (sigmoid approximation): gelu(x) ≈ x * sigmoid(1.702 * x)
1300//
1301// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
1302// for performance. These have reduced precision (~2^-22 relative error)
1303// compared to the full-precision variants, which is acceptable for neural
1304// network training/inference where f32 precision is already limited.
1305// ---------------------------------------------------------------------------
1306
/// PTX source for `gelu_kernel`: sigmoid-approximated GELU (f32),
/// `out[i] = x * sigmoid(1.702 * x)`.
///
/// Kernel parameters, in launch order:
/// - `in_ptr`: input buffer.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// One element per thread.
///
/// `sigmoid(k * x)` is evaluated as
/// `rcp.approx(1 + ex2.approx(-k * x * log2(e)))`. `k` is encoded as
/// `0f3FD9DB23`, the f32 nearest to 1.702.
#[cfg(feature = "cuda")]
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // k = 1.702 (0x3FD9DB23 is the nearest f32)
    mov.f32 %k, 0f3FD9DB23;
    mul.f32 %neg_kx, %k, %x;
    neg.f32 %neg_kx, %neg_kx;
    // exp(-k*x) = 2^(-k*x * log2(e))
    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;
    mul.f32 %result, %x, %sig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1357
/// PTX source for `gelu_f64_kernel`: sigmoid-approximated GELU (f64),
/// `out[i] = x * sigmoid(1.702 * x)`.
///
/// Kernel parameters, in launch order:
/// - `in_ptr`: input buffer.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// One element per thread.
///
/// `exp` is evaluated entirely in f64, with no f32 downcast:
/// 1. Cody-Waite range reduction `a = n*ln2 + r`, with `ln2` split into
///    hi/lo parts.
/// 2. A degree-12 Taylor polynomial for `exp(r)` (coefficients
///    `1/12!` … `1/0!`) in Horner form.
/// 3. Scaling by `2^n`, built directly from exponent bits.
///
/// The exp argument is clamped to `[-708, 708]` first so that `n + 1023`
/// always stays in the normal-exponent range. Without the clamp, the
/// bit-built `2^n` wraps into garbage for `|x| > ~416`. At those magnitudes
/// the sigmoid has already saturated, so the clamp does not change the
/// result.
#[cfg(feature = "cuda")]
pub(crate) const GELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;

    // k = 1.702 (correctly rounded f64)
    mov.f64 %k, 0d3FFB3B645A1CAC08;
    mul.f64 %neg_kx, %k, %x;
    neg.f64 %neg_kx, %neg_kx;

    // --- exp(%neg_kx) via Cody-Waite + degree-12 Taylor (Horner) ---
    // Clamp to [-708, 708] so that the 2^n scale stays a normal double.
    max.f64 %neg_kx, %neg_kx, 0dC086200000000000;
    min.f64 %neg_kx, %neg_kx, 0d4086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Coefficients 1/12!, 1/11!, ..., 1/3!, 1/2!, 1/1!, 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg, %exp_neg, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %exp_neg;
    div.rn.f64 %sig, %one, %denom;
    mul.f64 %result, %x, %sig;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1439
/// PTX source for `gelu_tanh_kernel`: tanh approximation of GELU (f32).
/// `out[i] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`
///
/// Kernel parameters, in launch order:
/// - `in_ptr`: input buffer.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// One element per thread.
///
/// `tanh(y)` is computed as `(e^(2y) - 1) / (e^(2y) + 1)`, using
/// `ex2.approx.f32` for the exponential and `rcp.approx.f32` for the
/// division.
///
/// `y` is first clamped to `[-10, 10]`. `tanh(±10)` already rounds to `±1`
/// in f32, and the clamp keeps `ex2` finite. Without it, `ex2` overflows for
/// `x > ~10.4` and the quotient becomes `inf * 0 = NaN`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
    // sqrt(2/pi) = 0.7978845608 = 0x3F4C422A
    // 0.044715 = 0x3D372713
    mul.f32 %x3, %x, %x;
    mul.f32 %x3, %x3, %x;
    mov.f32 %c, 0f3D372713;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mul.f32 %y, %sqrt2pi, %inner;

    // Clamp y to [-10, 10]. tanh(+-10) already rounds to +-1 in f32, and
    // without the clamp ex2 overflows to +inf once y > ~44 (x > ~10.4),
    // turning (inf - 1) * rcp(inf + 1) into inf * 0 = NaN.
    max.f32 %y, %y, 0fC1200000;
    min.f32 %y, %y, 0f41200000;

    // tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
    // exp(2y) = 2^(2y * log2(e))
    mov.f32 %log2e, 0f3FB8AA3B;
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    mov.f32 %one, 0f3F800000;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f32 %th, %one, %th;
    mov.f32 %half, 0f3F000000;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %th;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1515
/// PTX source for `gelu_tanh_f64_kernel`: tanh approximation of GELU (f64).
/// `out[i] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`
///
/// Kernel parameters, in launch order:
/// - `in_ptr`: input buffer.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// One element per thread.
///
/// The constants `sqrt(2/π)` and `0.044715` are correctly rounded f64.
/// `tanh(y)` is computed as `(e^(2y) - 1) / (e^(2y) + 1)`, where `e^(2y)`
/// is evaluated in f64:
/// 1. Cody-Waite range reduction.
/// 2. A degree-12 Taylor polynomial in Horner form.
/// 3. `2^n` scaling built from exponent bits.
///
/// The exp argument is clamped to `[-708, 708]` so the bit-built `2^n`
/// stays a normal double. Without the clamp it wraps for `|x| > ~21`. At the
/// clamp the quotient already rounds to `±1`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f64 %e2y_m1, %e2y_p1, %th, %one, %half, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;

    // inner = sqrt(2/pi) * (x + 0.044715 * x^3), constants correctly rounded
    mul.f64 %x3, %x, %x;
    mul.f64 %x3, %x3, %x;
    mov.f64 %c, 0d3FA6E4E26D4801F7;
    mul.f64 %x3, %c, %x3;
    add.f64 %inner, %x, %x3;
    mov.f64 %sqrt2pi, 0d3FE9884533D43651;
    mul.f64 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y)-1)/(exp(2y)+1), exp(2y) in full f64
    add.f64 %two_y, %y, %y;

    // --- exp(%two_y) via Cody-Waite + degree-12 Taylor (Horner) ---
    // Clamp to [-708, 708] so that the 2^n scale stays a normal double.
    max.f64 %two_y, %two_y, 0dC086200000000000;
    min.f64 %two_y, %two_y, 0d4086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Coefficients 1/12!, 1/11!, ..., 1/3!, 1/2!, 1/1!, 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2y, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2y, %e2y, %e_nf;
    // --- end exp ---

    sub.f64 %e2y_m1, %e2y, %one;
    add.f64 %e2y_p1, %e2y, %one;
    div.rn.f64 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f64 %th, %one, %th;
    mov.f64 %half, 0d3FE0000000000000;
    mul.f64 %result, %half, %x;
    mul.f64 %result, %result, %th;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1611
/// PTX source for `gelu_erf_kernel`: erf-based ("exact") GELU (f32).
/// `out[i] = x * 0.5 * (1 + erf(x / sqrt(2)))`
///
/// Kernel parameters, in launch order:
/// - `in_ptr`: input buffer.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// One element per thread.
///
/// `erf` uses Abramowitz & Stegun formula 7.1.26:
/// `erf(z) = 1 - t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5))))·e^(-z²)`, with
/// `t = 1/(1 + p·|z|)` and odd symmetry for negative `z`.
///
/// Coefficient values, encoded as the correctly rounded f32 hex in the PTX:
/// - `p = 0.3275911`
/// - `a1 = 0.254829592`
/// - `a2 = -0.284496736`
/// - `a3 = 1.421413741`
/// - `a4 = -1.453152027`
/// - `a5 = 1.061405429`
///
/// The formula's own error is below 1.5×10⁻⁷. `rcp.approx.f32` and
/// `ex2.approx.f32` add a few f32 ulps on top.
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %z, %ax, %one, %half, %log2e;
    .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // A&S 7.1.26: a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741,
    //             a4 = -1.453152027, a5 = 1.061405429 (nearest f32)
    mov.f32 %a5, 0f3F87DC22;
    mov.f32 %a4, 0fBFBA00E3;
    mov.f32 %a3, 0f3FB5F0E3;
    mov.f32 %a2, 0fBE91A98E;
    mov.f32 %a1, 0f3E827906;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // out = x * 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %erf_val, %one, %erf_val;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %erf_val;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1710
/// PTX source for `gelu_erf_f64_kernel`: erf-based GELU (f64).
/// `out[i] = x * 0.5 * (1 + erf(x / sqrt(2)))`
///
/// Kernel parameters, in launch order:
/// - `in_ptr`: input buffer.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// One element per thread.
///
/// `erf` uses Abramowitz & Stegun formula 7.1.26, with every constant
/// correctly rounded to f64. The constants are `1/sqrt(2)`, `p`, and
/// `a1..a5`.
///
/// `exp(-z²)` is computed in full f64:
/// 1. Cody-Waite range reduction.
/// 2. A degree-12 Taylor polynomial in Horner form.
/// 3. `2^n` scaling built from exponent bits.
///
/// The exp argument is clamped to `[-708, 708]` so the bit-built `2^n`
/// stays a normal double. Without the clamp it wraps for `|x| > ~37`. At
/// the clamp, `erf` has already rounded to `±1`.
///
/// Overall accuracy is bounded by the A&S formula (|error in erf| < 1.5e-7),
/// not by f64 precision.
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %z, %ax, %one, %half;
    .reg .f64 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f64 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;

    // z = x / sqrt(2) = x * (1/sqrt(2)), correctly rounded
    mov.f64 %z, 0d3FE6A09E667F3BCD;
    mul.f64 %z, %x, %z;

    abs.f64 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f64 %p, 0d3FD4F740A93D7B8C;
    mul.f64 %t, %p, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // A&S 7.1.26: a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741,
    //             a4 = -1.453152027, a5 = 1.061405429 (nearest f64)
    mov.f64 %a5, 0d3FF0FB844255A12D;
    mov.f64 %a4, 0dBFF7401C57014C39;
    mov.f64 %a3, 0d3FF6BE1C55BAE157;
    mov.f64 %a2, 0dBFD23531CC3C1469;
    mov.f64 %a1, 0d3FD04F20C6EC5A7E;

    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // exp(-z^2) in full f64
    mul.f64 %z2, %ax, %ax;
    neg.f64 %neg_z2, %z2;

    // --- exp(%neg_z2) via Cody-Waite + degree-12 Taylor (Horner) ---
    // Clamp to [-708, 708] so that the 2^n scale stays a normal double.
    max.f64 %neg_z2, %neg_z2, 0dC086200000000000;
    min.f64 %neg_z2, %neg_z2, 0d4086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Coefficients 1/12!, 1/11!, ..., 1/3!, 1/2!, 1/1!, 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
    // --- end exp ---

    mul.f64 %erf_val, %pt, %exp_neg_z2;
    sub.f64 %erf_val, %one, %erf_val;

    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    add.f64 %erf_val, %one, %erf_val;
    mul.f64 %result, %half, %x;
    mul.f64 %result, %result, %erf_val;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1829
/// PTX source for `gelu_backward_tanh_kernel`: backward pass of the tanh
/// approximation of GELU (f32).
///
/// Let `u = sqrt(2/π) * (x + 0.044715 * x³)` and `t = tanh(u)`. Then
/// `d/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t²) * sqrt(2/π) * (1 + 3*0.044715*x²)`
/// and `out[i] = grad[i] * d/dx`.
///
/// Kernel parameters, in launch order:
/// - `grad_ptr`: upstream gradients.
/// - `input_ptr`: forward inputs `x`.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// `3*0.044715` is encoded as `0f3E095D4F`, the nearest f32 to 0.134145.
///
/// `u` is clamped to `[-10, 10]` before `tanh`. `tanh(±10)` already rounds
/// to `±1` in f32, and without the clamp `ex2.approx` overflows for
/// `x > ~10.4`, making `t` (and the gradient) NaN.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mov.f32 %c, 0f3D372713;
    // 3 * 0.044715 = 0.134145 = 0x3E095D4F
    mov.f32 %c3, 0f3E095D4F;

    // u = sqrt(2/pi) * (x + 0.044715 * x^3)
    mul.f32 %x2, %x, %x;
    mul.f32 %x3, %x2, %x;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mul.f32 %y, %sqrt2pi, %inner;

    // Clamp u to [-10, 10]: tanh(+-10) already rounds to +-1 in f32, and
    // the clamp keeps ex2 finite (otherwise inf * rcp(inf) = NaN).
    max.f32 %y, %y, 0fC1200000;
    min.f32 %y, %y, 0f41200000;

    // tanh(y) via exp
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh^2)*sqrt(2/pi)*(1+3*0.044715*x^2)
    // term1 = 0.5 * (1 + th)
    add.f32 %term1, %one, %th;
    mul.f32 %term1, %half, %term1;

    // (1 - th^2)
    mul.f32 %th2, %th, %th;
    sub.f32 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/pi) * (1 + 3*0.044715*x^2)
    mul.f32 %d_inner, %c3, %x2;
    add.f32 %d_inner, %one, %d_inner;
    mul.f32 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th^2) * d_inner
    mul.f32 %term2, %half, %x;
    mul.f32 %term2, %term2, %one_m_th2;
    mul.f32 %term2, %term2, %d_inner;

    add.f32 %d_gelu, %term1, %term2;
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1927
/// PTX source for `gelu_backward_tanh_f64_kernel`: backward pass of the tanh
/// approximation of GELU (f64).
///
/// Let `u = sqrt(2/π) * (x + 0.044715 * x³)` and `t = tanh(u)`. Then
/// `out[i] = grad[i] * (0.5*(1+t) + 0.5*x*(1-t²)*sqrt(2/π)*(1 + 3*0.044715*x²))`.
///
/// Kernel parameters, in launch order:
/// - `grad_ptr`: upstream gradients.
/// - `input_ptr`: forward inputs `x`.
/// - `out_ptr`: output buffer.
/// - `n`: element count; threads with global id `>= n` exit immediately.
///
/// All constants are correctly rounded f64: `sqrt(2/π)`, `0.044715`, and
/// `0.134145 = 3*0.044715`.
///
/// `tanh(u) = (e^(2u)-1)/(e^(2u)+1)`, where `e^(2u)` is evaluated in f64:
/// 1. Cody-Waite range reduction.
/// 2. A degree-12 Taylor polynomial in Horner form.
/// 3. `2^n` scaling built from exponent bits.
///
/// The exp argument is clamped to `[-708, 708]` so the bit-built `2^n`
/// stays a normal double. Without the clamp it wraps for `|x| > ~21`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f64 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half;
    .reg .f64 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;
    mov.f64 %sqrt2pi, 0d3FE9884533D43651;
    mov.f64 %c, 0d3FA6E4E26D4801F7;
    // 3 * 0.044715 = 0.134145
    mov.f64 %c3, 0d3FC12BA9D1F60179;

    mul.f64 %x2, %x, %x;
    mul.f64 %x3, %x2, %x;
    mul.f64 %x3, %c, %x3;
    add.f64 %inner, %x, %x3;
    mul.f64 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y)-1)/(exp(2y)+1) in full f64
    add.f64 %two_y, %y, %y;

    // --- exp(%two_y) via Cody-Waite + degree-12 Taylor (Horner) ---
    // Clamp to [-708, 708] so that the 2^n scale stays a normal double.
    max.f64 %two_y, %two_y, 0dC086200000000000;
    min.f64 %two_y, %two_y, 0d4086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Coefficients 1/12!, 1/11!, ..., 1/3!, 1/2!, 1/1!, 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2y, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2y, %e2y, %e_nf;
    // --- end exp ---

    sub.f64 %e2y_m1, %e2y, %one;
    add.f64 %e2y_p1, %e2y, %one;
    div.rn.f64 %th, %e2y_m1, %e2y_p1;

    add.f64 %term1, %one, %th;
    mul.f64 %term1, %half, %term1;

    mul.f64 %th2, %th, %th;
    sub.f64 %one_m_th2, %one, %th2;

    mul.f64 %d_inner, %c3, %x2;
    add.f64 %d_inner, %one, %d_inner;
    mul.f64 %d_inner, %sqrt2pi, %d_inner;

    mul.f64 %term2, %half, %x;
    mul.f64 %term2, %term2, %one_m_th2;
    mul.f64 %term2, %term2, %d_inner;

    add.f64 %d_gelu, %term1, %term2;
    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
2042
2043// ---------------------------------------------------------------------------
2044// SiLU / ELU / Mish activation kernels (forward + backward)
2045// ---------------------------------------------------------------------------
2046
/// PTX source for `silu_kernel`: `out[i] = x * sigmoid(x)` for f32 data.
/// SiLU (Sigmoid Linear Unit), also known as Swish-1.
///
/// `sigmoid(x)` is evaluated as `1 / (1 + 2^(-x * log2(e)))` using
/// `ex2.approx.f32` and `rcp.approx.f32`, so results are approximate rather than
/// correctly rounded (they need not match a CPU libm bit-for-bit).
///
/// Kernel parameters, in order: `a_ptr` (device address of the f32 input),
/// `out_ptr` (device address of the f32 output), `n` (element count, u32).
/// Each thread handles the element at global index `ctaid.x * ntid.x + tid.x`
/// (computed in 32 bits, byte offset `4 * index`); threads whose index is
/// `>= n` exit without touching memory. Only the x dimension of the launch
/// grid is used.
#[cfg(feature = "cuda")]
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    // exp(-x) = 2^(-x * log2(e))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;
    // silu(x) = x * sigmoid(x)
    mul.f32 %vr, %x, %sig;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2101
/// PTX source for `silu_f64_kernel`: `out[i] = x * sigmoid(x)` (f64).
///
/// `sigmoid(x) = 1 / (1 + exp(-x))`, with `exp(-x)` evaluated in f64:
/// Cody-Waite reduction `exp(a) = 2^k * exp(r)` (`k = round(a / ln 2)`,
/// `|r| <= ln(2)/2`, `ln 2` split into hi/lo parts), followed by a degree-12
/// Taylor polynomial in Horner form whose truncation term `r^13 / 13!` is below
/// `2^-52`. The division uses `div.rn.f64` (correctly rounded).
///
/// The exp argument is clamped to `[-708, 708]` first so the `2^k` scale factor
/// is always a normal double; without the clamp the exponent bits wrap for
/// `|x| > ~709` and the output is garbage. The clamp uses `setp`/`selp` rather
/// than `min`/`max` so NaN inputs still propagate to NaN outputs.
///
/// Kernel parameters, in order: `a_ptr` (f64 input), `out_ptr` (f64 output),
/// `n` (element count, u32). One thread per element at global index
/// `ctaid.x * ntid.x + tid.x`; threads with index `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const SILU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %neg_x, %e, %denom, %sig, %vr, %one;
    .reg .f64 %e_a, %e_lim, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %e_clamp;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    neg.f64 %neg_x, %x;

    // Clamp the exp argument to [-708, 708] so 2^k stays a normal double.
    // setp/selp (unlike min/max) lets NaN propagate.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %neg_x, %e_lim;
    selp.f64 %e_a, %e_lim, %neg_x, %e_clamp;
    mov.f64 %e_lim, 0d4086200000000000;
    setp.gt.f64 %e_clamp, %e_a, %e_lim;
    selp.f64 %e_a, %e_lim, %e_a, %e_clamp;

    // --- exp(%e_a) via Cody-Waite + degree-12 Taylor (Horner) ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e, %e, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %e;
    div.rn.f64 %sig, %one, %denom;
    mul.f64 %vr, %x, %sig;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2179
/// PTX source for `silu_backward_kernel`:
/// `out[i] = grad[i] * (sig + x * sig * (1 - sig))` where `sig = sigmoid(input[i])`.
///
/// The derivative is recomputed from the original forward input `x` (read via
/// `input_ptr`), not from the forward output. `sigmoid` uses `ex2.approx.f32`
/// and `rcp.approx.f32`, so results are approximate rather than correctly
/// rounded.
///
/// Kernel parameters, in order: `grad_ptr` (f32 upstream gradient),
/// `input_ptr` (f32 forward input), `out_ptr` (f32 output gradient), `n`
/// (element count, u32). Each thread handles the element at global index
/// `ctaid.x * ntid.x + tid.x` (byte offset `4 * index`); threads whose index is
/// `>= n` exit without touching memory.
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
    .reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(x) = 1 / (1 + exp(-x))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f32 %one_m_sig, %one, %sig;
    mul.f32 %x_sig_omsig, %x, %sig;
    mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f32 %deriv, %sig, %x_sig_omsig;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2243
/// PTX source for `silu_backward_f64_kernel` (f64):
/// `out[i] = grad[i] * (sig + x * sig * (1 - sig))` where `sig = sigmoid(input[i])`.
///
/// `exp(-x)` is evaluated in f64 with a Cody-Waite reduction
/// `exp(a) = 2^k * exp(r)` (`k = round(a / ln 2)`, `|r| <= ln(2)/2`) followed by a
/// degree-12 Taylor polynomial in Horner form whose truncation term
/// `r^13 / 13!` is below `2^-52`.
///
/// The exp argument is clamped to `[-708, 708]` first so the `2^k` scale factor
/// is always a normal double; without the clamp the exponent bits wrap for
/// `|x| > ~709` and the gradient is garbage. The clamp uses `setp`/`selp` so NaN
/// inputs still produce NaN.
///
/// Kernel parameters, in order: `grad_ptr` (f64 upstream gradient),
/// `input_ptr` (f64 forward input), `out_ptr` (f64 output gradient), `n`
/// (element count, u32). One thread per element at global index
/// `ctaid.x * ntid.x + tid.x`; threads with index `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %neg_x, %e, %denom, %sig, %one;
    .reg .f64 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .f64 %e_a, %e_lim, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %e_clamp;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    neg.f64 %neg_x, %x;

    // Clamp the exp argument to [-708, 708] so 2^k stays a normal double.
    // setp/selp (unlike min/max) lets NaN propagate.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %neg_x, %e_lim;
    selp.f64 %e_a, %e_lim, %neg_x, %e_clamp;
    mov.f64 %e_lim, 0d4086200000000000;
    setp.gt.f64 %e_clamp, %e_a, %e_lim;
    selp.f64 %e_a, %e_lim, %e_a, %e_clamp;

    // --- exp(%e_a) via Cody-Waite + degree-12 Taylor (Horner) ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e, %e, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %e;
    div.rn.f64 %sig, %one, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f64 %one_m_sig, %one, %sig;
    mul.f64 %x_sig_omsig, %x, %sig;
    mul.f64 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f64 %deriv, %sig, %x_sig_omsig;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
2332
/// PTX source for `elu_kernel`: `out[i] = x > 0 ? x : alpha * (exp(x) - 1)`.
/// Takes `alpha` as an extra `.param .f32` parameter.
///
/// Both branches are computed for every element and the result is chosen with
/// `selp` on `x > 0`. A NaN input fails that comparison, so it takes the
/// negative branch, which is itself NaN. `exp(x)` uses `ex2.approx.f32`, and
/// `exp(x) - 1` is a plain subtraction rather than an `expm1`.
/// NOTE(review): relative accuracy of the negative branch is therefore limited
/// for small |x|.
///
/// Kernel parameters, in order: `a_ptr` (f32 input), `out_ptr` (f32 output),
/// `n` (element count, u32), `alpha` (f32). Each thread handles the element at
/// global index `ctaid.x * ntid.x + tid.x` (byte offset `4 * index`); threads
/// whose index is `>= n` exit without touching memory.
#[cfg(feature = "cuda")]
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    sub.f32 %em1, %ex, %one;
    mul.f32 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha*(exp(x)-1)
    mov.f32 %vr, 0f00000000;
    setp.gt.f32 %pos, %x, %vr;
    selp.f32 %vr, %x, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2391
/// PTX source for `elu_f64_kernel`: `out[i] = x > 0 ? x : alpha * (exp(x) - 1)` (f64).
///
/// `exp(x)` is evaluated in f64 with a Cody-Waite reduction
/// `exp(a) = 2^k * exp(r)` (`k = round(a / ln 2)`, `|r| <= ln(2)/2`) followed by a
/// degree-12 Taylor polynomial in Horner form whose truncation term
/// `r^13 / 13!` is below `2^-52`.
///
/// The exp argument is clamped to `[-708, 708]` first so the `2^k` scale factor
/// is always a normal double. Without the clamp the exponent bits wrap for
/// `|x| > ~709`, and the negative branch returns garbage instead of `-alpha`.
/// The clamp uses `setp`/`selp` so a NaN input still yields NaN.
///
/// Kernel parameters, in order: `a_ptr` (f64 input), `out_ptr` (f64 output),
/// `n` (element count, u32), `alpha` (f64). One thread per element at global
/// index `ctaid.x * ntid.x + tid.x`; threads with index `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const ELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f64 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %alpha_r, %one, %ex, %em1, %neg_branch, %vr;
    .reg .f64 %e_a, %e_lim, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %pos, %e_clamp;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f64 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // Clamp the exp argument to [-708, 708] so 2^k stays a normal double.
    // setp/selp (unlike min/max) lets NaN propagate.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %x, %e_lim;
    selp.f64 %e_a, %e_lim, %x, %e_clamp;
    mov.f64 %e_lim, 0d4086200000000000;
    setp.gt.f64 %e_clamp, %e_a, %e_lim;
    selp.f64 %e_a, %e_lim, %e_a, %e_clamp;

    // --- exp(%e_a) via Cody-Waite + degree-12 Taylor (Horner) ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // --- end exp ---

    sub.f64 %em1, %ex, %one;
    mul.f64 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha*(exp(x)-1)
    mov.f64 %vr, 0d0000000000000000;
    setp.gt.f64 %pos, %x, %vr;
    selp.f64 %vr, %x, %neg_branch, %pos;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2473
/// PTX source for `elu_backward_kernel`:
/// `out[i] = x > 0 ? grad[i] : grad[i] * alpha * exp(x)`.
/// Takes `alpha` as an extra `.param .f32` parameter.
///
/// The derivative is recomputed from the original forward input `x` (read via
/// `input_ptr`). Both branches are computed for every element and selected with
/// `selp` on `x > 0`; `exp(x)` uses `ex2.approx.f32`, so results are approximate
/// rather than correctly rounded.
///
/// Kernel parameters, in order: `grad_ptr` (f32 upstream gradient),
/// `input_ptr` (f32 forward input), `out_ptr` (f32 output gradient), `n`
/// (element count, u32), `alpha` (f32). Each thread handles the element at
/// global index `ctaid.x * ntid.x + tid.x` (byte offset `4 * index`); threads
/// whose index is `>= n` exit without touching memory.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %lg2e, 0f3FB8AA3B;
    mov.f32 %zero, 0f00000000;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    // negative branch: grad * alpha * exp(x)
    mul.f32 %neg_branch, %vg, %alpha_r;
    mul.f32 %neg_branch, %neg_branch, %ex;

    // x > 0 ? grad : grad * alpha * exp(x)
    setp.gt.f32 %pos, %x, %zero;
    selp.f32 %vr, %vg, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2537
/// PTX source for `elu_backward_f64_kernel` (f64):
/// `out[i] = x > 0 ? grad[i] : grad[i] * alpha * exp(x)`.
///
/// `exp(x)` is evaluated in f64 with a Cody-Waite reduction
/// `exp(a) = 2^k * exp(r)` (`k = round(a / ln 2)`, `|r| <= ln(2)/2`) followed by a
/// degree-12 Taylor polynomial in Horner form whose truncation term
/// `r^13 / 13!` is below `2^-52`.
///
/// The exp argument is clamped to `[-708, 708]` first so the `2^k` scale factor
/// is always a normal double; without the clamp the exponent bits wrap for
/// `|x| > ~709` and the negative branch returns garbage instead of (almost) 0.
/// The clamp uses `setp`/`selp` so a NaN input still yields NaN.
///
/// Kernel parameters, in order: `grad_ptr` (f64 upstream gradient),
/// `input_ptr` (f64 forward input), `out_ptr` (f64 output gradient), `n`
/// (element count, u32), `alpha` (f64). One thread per element at global index
/// `ctaid.x * ntid.x + tid.x`; threads with index `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f64 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %alpha_r, %ex, %neg_branch, %vr, %zero, %one;
    .reg .f64 %e_a, %e_lim, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %pos, %e_clamp;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f64 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %zero, 0d0000000000000000;
    mov.f64 %one, 0d3FF0000000000000;

    // Clamp the exp argument to [-708, 708] so 2^k stays a normal double.
    // setp/selp (unlike min/max) lets NaN propagate.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %x, %e_lim;
    selp.f64 %e_a, %e_lim, %x, %e_clamp;
    mov.f64 %e_lim, 0d4086200000000000;
    setp.gt.f64 %e_clamp, %e_a, %e_lim;
    selp.f64 %e_a, %e_lim, %e_a, %e_clamp;

    // --- exp(%e_a) via Cody-Waite + degree-12 Taylor (Horner) ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // --- end exp ---

    // negative branch: grad * alpha * exp(x)
    mul.f64 %neg_branch, %vg, %alpha_r;
    mul.f64 %neg_branch, %neg_branch, %ex;

    // x > 0 ? grad : grad * alpha * exp(x)
    setp.gt.f64 %pos, %x, %zero;
    selp.f64 %vr, %vg, %neg_branch, %pos;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2624
/// PTX source for `mish_kernel`: `out[i] = x * tanh(softplus(x))` for f32 data.
///
/// `softplus(x) = ln(1 + exp(x))` and `tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)`,
/// evaluated with `ex2.approx.f32`, `lg2.approx.f32` and `rcp.approx.f32`, so
/// results are approximate rather than correctly rounded.
///
/// For `x > 20`, `tanh(softplus(x))` rounds to exactly `1.0f`, so the kernel
/// stores `x` unchanged. Evaluating the exp-based tanh there overflows
/// `exp(2x)` to `inf` for `x > ~44`, and `inf * rcp(inf)` yields NaN.
///
/// Kernel parameters, in order: `a_ptr` (f32 input), `out_ptr` (f32 output),
/// `n` (element count, u32). Each thread handles the element at global index
/// `ctaid.x * ntid.x + tid.x` (byte offset `4 * index`); threads whose index is
/// `>= n` exit without touching memory.
#[cfg(feature = "cuda")]
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0 = 0x41A00000
    mov.f32 %threshold, 0f41A00000;

    // softplus(x) = ln(1 + exp(x))
    // For large x (> 20), tanh(softplus(x)) == 1.0f, handled at LARGE_X
    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    // ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
    lg2.approx.f32 %lg_ep1, %ep1;
    // 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %th, %e2sp_m1, %e2sp_p1;

    mul.f32 %vr, %x, %th;
    st.global.f32 [%out], %vr;
    bra DONE;

LARGE_X:
    // tanh(softplus(x)) rounds to 1.0f for x > 20, so mish(x) == x.
    // Computing tanh from exp(2x) here overflows to inf*0 = NaN for x > ~44.
    st.global.f32 [%out], %x;

DONE:
    ret;
}
";
2714
/// PTX source for `mish_f64_kernel`: `out[i] = x * tanh(softplus(x))` (f64).
///
/// Uses the exact identity `tanh(ln(1 + u)) = u*(u + 2) / (u*(u + 2) + 2)` with
/// `u = exp(x)`. This needs no logarithm and no second exponential, and it
/// avoids the cancellation in `exp(2*sp) - 1` when `softplus(x)` is tiny.
///
/// `exp(x)` is evaluated in f64 with a Cody-Waite reduction
/// `exp(a) = 2^k * exp(r)` (`k = round(a / ln 2)`, `|r| <= ln(2)/2`) followed by a
/// degree-12 Taylor polynomial in Horner form whose truncation term
/// `r^13 / 13!` is below `2^-52`. The argument is clamped below at `-708` so
/// `2^k` stays a normal double for very negative `x`. The clamp uses
/// `setp`/`selp`, so NaN still propagates.
///
/// For `x > 20`, `tanh(softplus(x))` rounds to exactly `1.0` in f64, so the
/// kernel stores `x` unchanged (this also covers `x = +inf`).
///
/// Kernel parameters, in order: `a_ptr` (f64 input), `out_ptr` (f64 output),
/// `n` (element count, u32). One thread per element at global index
/// `ctaid.x * ntid.x + tid.x`; threads with index `>= n` exit.
#[cfg(feature = "cuda")]
pub(crate) const MISH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %one, %two, %ex, %num, %den, %th, %vr;
    .reg .f64 %threshold;
    // exp subroutine regs
    .reg .f64 %e_a, %e_lim, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %large, %e_clamp;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;
    mov.f64 %threshold, 0d4034000000000000;

    // For x > 20, tanh(softplus(x)) rounds to 1.0 in f64, so mish(x) == x.
    setp.gt.f64 %large, %x, %threshold;
    @%large bra LARGE_X;

    // Clamp the exp argument below at -708 so 2^k stays a normal double
    // (x <= 20 here, so no upper clamp is needed). setp/selp keeps NaN.
    mov.f64 %e_lim, 0dC086200000000000;
    setp.lt.f64 %e_clamp, %x, %e_lim;
    selp.f64 %e_a, %e_lim, %x, %e_clamp;

    // --- u = exp(%e_a) via Cody-Waite + degree-12 Taylor (Horner) ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_a, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_a;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // --- end exp ---

    // tanh(softplus(x)) = tanh(ln(1 + u)) = u*(u + 2) / (u*(u + 2) + 2)
    // Exact identity: no log, no cancellation in exp(2*sp) - 1.
    add.f64 %num, %ex, %two;
    mul.f64 %num, %num, %ex;
    add.f64 %den, %num, %two;
    div.rn.f64 %th, %num, %den;

    mul.f64 %vr, %x, %th;
    st.global.f64 [%out], %vr;
    bra DONE;

LARGE_X:
    st.global.f64 [%out], %x;

DONE:
    ret;
}
";
2892
/// PTX source for `mish_backward_kernel`:
/// ```text
/// sp = ln(1 + exp(x))        // softplus
/// t  = tanh(sp)
/// sig = sigmoid(x) = 1/(1+exp(-x))
/// out[i] = grad[i] * (t + x * sig * (1 - t*t))
/// ```
/// The derivative is recomputed from the original forward input `x` (read via
/// `input_ptr`) using `ex2.approx.f32`, `lg2.approx.f32` and `rcp.approx.f32`,
/// so results are approximate rather than correctly rounded.
///
/// For `x > 20`, both `t` and `sig` round to exactly `1.0f`, so the derivative
/// is exactly 1 and the kernel stores `grad[i]` unchanged. Computing `tanh`
/// from `exp(2x)` there overflows to `inf * rcp(inf)` = NaN for `x > ~44`.
///
/// Kernel parameters, in order: `grad_ptr` (f32 upstream gradient),
/// `input_ptr` (f32 forward input), `out_ptr` (f32 output gradient), `n`
/// (element count, u32). Each thread handles the element at global index
/// `ctaid.x * ntid.x + tid.x` (byte offset `4 * index`); threads whose index is
/// `>= n` exit without touching memory.
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
    .reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0
    mov.f32 %threshold, 0f41A00000;

    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // --- Normal path ---
    // softplus: sp = ln(1 + exp(x))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    lg2.approx.f32 %lg_ep1, %ep1;
    // ln(2) = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %t, %e2sp_m1, %e2sp_p1;

    // sig = sigmoid(x) = 1/(1+exp(-x))
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %en, %neg;
    add.f32 %denom, %one, %en;
    rcp.approx.f32 %sig, %denom;

    // deriv = t + x * sig * (1 - t*t)
    mul.f32 %t2, %t, %t;
    sub.f32 %one_m_t2, %one, %t2;
    mul.f32 %x_sig_omt2, %x, %sig;
    mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
    add.f32 %deriv, %t, %x_sig_omt2;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;
    bra DONE;

LARGE_X:
    // For x > 20: t and sig both round to 1.0f, so deriv == 1 and out = grad.
    // Computing tanh from exp(2x) here overflows to inf*0 = NaN for x > ~44.
    st.global.f32 [%out], %vg;

DONE:
    ret;
}
";
3009
/// PTX source for `mish_backward_f64_kernel` (f64):
/// `out[i] = grad[i] * (t + x * sigmoid(x) * (1 - t*t))` where `x = input[i]`
/// and `t = tanh(softplus(x))`.
///
/// Both `t` and `sigmoid(x)` are derived from a single `u = exp(x)` through the
/// exact identities `tanh(ln(1 + u)) = u(u + 2) / (u(u + 2) + 2)` and
/// `sigmoid(x) = u / (1 + u)`, so no logarithm (and no second/third `exp`) is
/// needed. `exp` uses Cody-Waite range reduction plus a complete degree-12
/// Taylor polynomial in Horner form. Its argument is clamped to `[-708, 708]`
/// so the `2^n` scale factor assembled from raw exponent bits always stays a
/// normal double. Without the clamp, very negative inputs produce garbage.
///
/// For `x > 20`, `t` rounds to exactly `1.0` in f64, so `1 - t*t == 0` and the
/// derivative equals `1`. The kernel stores `grad[i]` directly on that path,
/// which also avoids overflowing `u(u + 2)`.
///
/// Parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (f64 buffers) and
/// the u32 element count `n`.
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %one, %two, %threshold;
    .reg .f64 %u, %num, %den, %t, %t2, %one_m_t2, %sig, %x_sig_omt2, %deriv, %result;
    // exp subroutine regs
    .reg .f64 %e_x, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %large;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;
    mov.f64 %threshold, 0d4034000000000000;

    setp.gt.f64 %large, %x, %threshold;
    @%large bra LARGE_X;

    // === u = exp(x): Cody-Waite reduction + degree-12 Horner ===
    // Clamp to [-708, 708] so the 2^n factor below stays a normal double.
    max.f64 %e_x, %x, 0dC086200000000000;
    min.f64 %e_x, %e_x, 0d4086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    mul.f64 %e_nf, %e_x, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // coefficients 1/12! down to 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %u, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %u, %u, %e_nf;

    // === t = tanh(softplus(x)) = u(u+2) / (u(u+2) + 2) ===
    add.f64 %num, %u, %two;
    mul.f64 %num, %num, %u;
    add.f64 %den, %num, %two;
    div.rn.f64 %t, %num, %den;

    // === sig = sigmoid(x) = u / (1 + u) ===
    add.f64 %den, %u, %one;
    div.rn.f64 %sig, %u, %den;

    // deriv = t + x * sig * (1 - t*t)
    mul.f64 %t2, %t, %t;
    sub.f64 %one_m_t2, %one, %t2;
    mul.f64 %x_sig_omt2, %x, %sig;
    mul.f64 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
    add.f64 %deriv, %t, %x_sig_omt2;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;
    bra DONE;

LARGE_X:
    // x > 20: t == 1.0 exactly in f64, so 1 - t*t == 0 and deriv == 1.
    st.global.f64 [%out], %vg;

DONE:
    ret;
}
";
3231
/// PTX source for `clamp_kernel`: `out[i] = min(max(x[i], min_val), max_val)`.
///
/// Takes two extra f32 params after `n`: `min_val`, `max_val`. The lower bound
/// is applied first (`max.f32`) and the upper bound second (`min.f32`). So if
/// `min_val > max_val`, every element is written as `max_val`.
///
/// NOTE(review): PTX `min`/`max` return the non-NaN operand when exactly one
/// input is NaN. A NaN element is therefore replaced by a bound instead of
/// being propagated. Confirm this matches the CPU fallback.
#[cfg(feature = "cuda")]
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry clamp_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 min_val,
    .param .f32 max_val
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %mn, %mx, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %mn, [min_val];
    ld.param.f32 %mx, [max_val];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    max.f32 %result, %x, %mn;
    min.f32 %result, %result, %mx;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3280
3281
3282// ---------------------------------------------------------------------------
3283// Backward activation kernels
3284// ---------------------------------------------------------------------------
3285
/// PTX source for `fill_f32_kernel`: `out[i] = scalar` for all `i < n`.
///
/// Parameters, in order: `out_ptr` (f32 buffer), `scalar` (f32), `n` (u32
/// element count). Each thread writes at most one element. Threads whose
/// global index is `>= n` exit without writing.
///
/// Used by sum/mean backward to produce a GPU-resident tensor filled
/// with a constant, without the CPU → GPU round-trip the legacy path
/// incurred (`vec![go; numel].to(device)`).
#[cfg(feature = "cuda")]
pub(crate) const FILL_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fill_f32_kernel(
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %out, %off;
    .reg .f32 %v;
    .reg .pred %p;

    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %v, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %v;

DONE:
    ret;
}
";
3328
3329
/// PTX source for `abs_backward_kernel`:
/// `out[i] = input[i] > 0 ? grad[i] : (input[i] < 0 ? -grad[i] : 0)`.
///
/// Implements the derivative of `|x|`: `sign(x)` with the convention
/// that `sign(0) = 0`. Takes `grad` (upstream) and `input` (forward
/// activation input) as its two tensor parameters. These are followed by the
/// output pointer and the u32 element count `n`.
///
/// Both tests use ordered comparisons (`setp.lt` / `setp.gt`). A NaN input
/// fails both and produces `0` rather than NaN.
#[cfg(feature = "cuda")]
pub(crate) const ABS_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %neg_vg, %tmp, %vr;
    .reg .pred %p, %pos, %neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;

    neg.f32 %neg_vg, %vg;

    // tmp = (vi < 0) ? -vg : 0
    setp.lt.f32 %neg, %vi, %zero;
    selp.f32 %tmp, %neg_vg, %zero, %neg;
    // vr = (vi > 0) ? vg : tmp
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %tmp, %pos;

    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3392
3393
/// PTX source for `relu_backward_kernel`: `out[i] = (input[i] > 0) ? grad[i] : 0`.
/// Takes two inputs: grad (upstream gradient) and input (forward activation
/// input). These are followed by the output pointer and the u32 element
/// count `n`.
///
/// The gradient at exactly `input[i] == 0` is taken as `0`. A NaN input also
/// yields `0`, because `setp.gt` is an ordered comparison.
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %zero, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3444
3445
/// PTX source for `gelu_backward_kernel`:
/// `out[i] = grad[i] * (sig + k * x * sig * (1 - sig))`
/// where `sig = sigmoid(k * x)`.
/// This is the exact derivative of `gelu(x) = x * sigmoid(k * x)`, the
/// sigmoid approximation of GELU with nominal `k = 1.702`.
///
/// NOTE(review): the literal loaded into `%k`, `0f3FDA2720`, decodes to
/// 1.704315185546875. By contrast, 1.702 rounded to f32 is `0f3FD9DB23`.
/// Keep this in sync with the forward kernel and the CPU fallback before
/// changing either. TODO: confirm which value the forward pass uses.
///
/// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
/// for performance. These have reduced precision (~2^-22 relative error)
/// compared to the full-precision variants, which is acceptable for neural
/// network training/inference where f32 precision is already limited.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(1.702 * x)
    mov.f32 %k, 0f3FDA2720;
    mul.f32 %kx, %k, %x;
    neg.f32 %neg_kx, %kx;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_kx, %neg_kx, %log2e;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;

    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
    sub.f32 %one_minus_sig, %one, %sig;
    mul.f32 %kx_sig_oms, %kx, %sig;
    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f32 %dsig, %sig, %kx_sig_oms;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %dsig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3521
/// PTX source for `gelu_backward_f64_kernel`: sigmoid-approx GELU backward (f64).
/// `out[i] = grad[i] * (sig + k * x * sig * (1 - sig))` with `sig = sigmoid(k * x)`.
///
/// `exp(-k*x)` is evaluated in f64 via Cody-Waite range reduction and a
/// complete degree-12 Taylor polynomial in Horner form (coefficients 1/12!
/// down to 1/0!). The exponent argument is clamped to `[-708, 708]` so the
/// `2^n` scale factor assembled from raw exponent bits stays a normal double.
/// Outside that range `sig` is already saturated at 0 or 1 to f64 precision.
///
/// NOTE(review): `k` is `0d3FFB44E400000000` = 1.704315185546875, the same value
/// as the f32 kernel's `0f3FDA2720` widened. It is not 1.702 in f64
/// (`0d3FFB3B645A1CAC08`). It is left as-is so forward and backward stay
/// consistent. TODO: confirm against the forward kernel.
///
/// Parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (f64 buffers) and
/// the u32 element count `n`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %k, %kx, %neg_kx, %exp_neg, %one, %denom, %sig;
    .reg .f64 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .f64 %e_arg, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %k, 0d3FFB44E400000000;
    mul.f64 %kx, %k, %x;
    neg.f64 %neg_kx, %kx;

    // --- exp(%neg_kx) via Cody-Waite + degree-12 Horner ---
    // Clamp to [-708, 708] so the 2^n factor below stays a normal double.
    max.f64 %e_arg, %neg_kx, 0dC086200000000000;
    min.f64 %e_arg, %e_arg, 0d4086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // coefficients 1/12! down to 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg, %exp_neg, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %exp_neg;
    div.rn.f64 %sig, %one, %denom;

    sub.f64 %one_minus_sig, %one, %sig;
    mul.f64 %kx_sig_oms, %kx, %sig;
    mul.f64 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f64 %dsig, %sig, %kx_sig_oms;

    mul.f64 %result, %vg, %dsig;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3613
/// PTX source for `gelu_backward_erf_kernel`:
/// Exact GELU backward using erf: `d/dx gelu(x) = Φ(x) + x·φ(x)`
/// where `Φ(x) = 0.5·(1 + erf(x/√2))` and `φ(x) = exp(-x²/2) / √(2π)`.
///
/// Uses Abramowitz & Stegun formula 7.1.26 for erf (|ε| < 1.5×10⁻⁷):
///   `erf(x) = 1 - (a₁t + a₂t² + a₃t³ + a₄t⁴ + a₅t⁵) · exp(-x²)`
///   where `t = 1/(1 + 0.3275911·|x|)` and
///   `a₁ = 0.254829592`, `a₂ = -0.284496736`, `a₃ = 1.421413741`,
///   `a₄ = -1.453152027`, `a₅ = 1.061405429` (stored as the nearest f32 values).
///
/// The exponentials and the reciprocal use `ex2.approx.f32` / `rcp.approx.f32`.
/// The achieved error is therefore bounded by those instructions as well as by
/// the formula itself.
///
/// Parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (f32 buffers) and
/// the u32 element count `n`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f32 %d_gelu, %result;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // A&S 7.1.26 coefficients (nearest f32):
    //   a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741,
    //   a4 = -1.453152027, a5 = 1.061405429
    mov.f32 %a5, 0f3F87DC22;
    mov.f32 %a4, 0fBFBA00E3;
    mov.f32 %a3, 0f3FB5F0E3;
    mov.f32 %a2, 0fBE91A98E;
    mov.f32 %a1, 0f3E827906;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // Phi(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %cdf, %one, %erf_val;
    mul.f32 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi)
    // exp(-x^2/2):
    mul.f32 %neg_x2h, %x, %x;
    mul.f32 %neg_x2h, %neg_x2h, %half;
    neg.f32 %neg_x2h, %neg_x2h;
    mul.f32 %neg_x2h, %neg_x2h, %log2e;
    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;

    // 1/sqrt(2*pi) = 0.39894228
    mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Phi(x) + x * phi(x)
    mul.f32 %x_pdf, %x, %pdf;
    add.f32 %d_gelu, %cdf, %x_pdf;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3741
/// PTX source for `gelu_backward_erf_f64_kernel`: erf-based GELU backward (f64).
/// `out[i] = grad[i] * (Φ(x) + x·φ(x))` with `Φ(x) = 0.5·(1 + erf(x/√2))` and
/// `φ(x) = exp(-x²/2) / √(2π)`.
///
/// erf uses Abramowitz & Stegun 7.1.26 (`p = 0.3275911`, `a₁ = 0.254829592`,
/// `a₂ = -0.284496736`, `a₃ = 1.421413741`, `a₄ = -1.453152027`,
/// `a₅ = 1.061405429`), whose own error bound is about 1.5×10⁻⁷. Those
/// coefficients and the `1/√2` and `1/√(2π)` constants are f32-rounded values
/// widened to f64.
///
/// Both exponentials (`exp(-z²)` and `exp(-x²/2)`) are computed in f64 via
/// Cody-Waite range reduction and a complete degree-12 Taylor polynomial in
/// Horner form. Their (non-positive) argument is clamped to `>= -708` so the
/// `2^n` scale factor assembled from raw exponent bits stays a normal double.
/// Without the clamp, |x| above ~37.6 produced garbage. Beyond that point the
/// exponential is negligible anyway.
///
/// Parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (f64 buffers) and
/// the u32 element count `n`.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f64 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f64 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f64 %d_gelu, %result;
    .reg .f64 %p_coef, %a1, %a2, %a3, %a4, %a5;
    .reg .f64 %e_arg, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;

    mov.f64 %z, 0d3FE6A09E60000000;
    mul.f64 %z, %x, %z;
    abs.f64 %ax, %z;

    mov.f64 %p_coef, 0d3FD4F740A0000000;
    mul.f64 %t, %p_coef, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    // A&S 7.1.26 coefficients (f32-rounded, widened to f64):
    //   a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741,
    //   a4 = -1.453152027, a5 = 1.061405429
    mov.f64 %a5, 0d3FF0FB8440000000;
    mov.f64 %a4, 0dBFF7401C60000000;
    mov.f64 %a3, 0d3FF6BE1C60000000;
    mov.f64 %a2, 0dBFD23531C0000000;
    mov.f64 %a1, 0d3FD04F20C0000000;

    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // exp(-z^2) in full f64
    mul.f64 %z2, %ax, %ax;
    neg.f64 %neg_z2, %z2;

    // --- exp(%neg_z2), argument clamped to >= -708 ---
    max.f64 %e_arg, %neg_z2, 0dC086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // coefficients 1/12! down to 1/0!
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
    // --- end exp ---

    mul.f64 %erf_val, %pt, %exp_neg_z2;
    sub.f64 %erf_val, %one, %erf_val;

    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    add.f64 %cdf, %one, %erf_val;
    mul.f64 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi)
    mul.f64 %neg_x2h, %x, %x;
    mul.f64 %neg_x2h, %neg_x2h, %half;
    neg.f64 %neg_x2h, %neg_x2h;

    // --- exp(%neg_x2h), argument clamped to >= -708 ---
    max.f64 %e_arg, %neg_x2h, 0dC086200000000000;
    fma.rn.f64 %e_nf, %e_arg, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_arg;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_x2h, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_x2h, %exp_neg_x2h, %e_nf;
    // --- end exp ---

    // 1/sqrt(2*pi) = 0.39894228
    mov.f64 %inv_sqrt_2pi, 0d3FD9884440000000;
    mul.f64 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    mul.f64 %x_pdf, %x, %pdf;
    add.f64 %d_gelu, %cdf, %x_pdf;

    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3901
// ---------------------------------------------------------------------------
// Index-select (1-D gather) PTX kernel
// ---------------------------------------------------------------------------
// Thread i: output[i] = input[indices[i]]
// Indices are stored as f32 on the GPU and converted to u32 with
// round-toward-zero (`cvt.rzi.u32.f32`).

/// PTX source for `index_select_1d_kernel`: `out[i] = input[indices[i]]` for
/// every `i < n_indices`.
///
/// Parameters, in order: `input_ptr` (f32 buffer), `indices_ptr` (f32 buffer
/// holding whole-number indices), `out_ptr` (f32 buffer), `n_indices` (u32).
/// Because indices travel as f32, they are exact only up to 2^24.
///
/// NOTE(review): the gathered index is not bounds-checked against the length
/// of `input`, since the kernel receives no such length. Presumably the
/// launcher validates indices before launch; confirm.
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
3961
3962
3963// ---------------------------------------------------------------------------
3964// Scatter-add (1-D) PTX kernel — backward of index_select
3965// ---------------------------------------------------------------------------
3966// Thread i: atomicAdd(grad_input[indices[i]], grad_output[i])
3967// The output buffer (grad_input) must be pre-zeroed.
3968// Uses atom.global.add.f32 for safe concurrent accumulation when
3969// duplicate indices map multiple threads to the same output slot.
3970
/// PTX source for `scatter_add_1d_kernel`: `grad_input[indices[i]] += grad_output[i]` (f32).
///
/// Parameters, in launch order: `grad_output_ptr`, `indices_ptr`, `grad_input_ptr`,
/// `n_indices`. There is one thread per index, and threads with id `>= n_indices` exit.
///
/// Accumulation uses `atom.global.add.f32`, so duplicate indices are summed rather than
/// overwritten. The order in which threads add is not fixed, however, so the low bits of
/// a slot hit by several indices can differ between runs. The kernel only adds into
/// `grad_input` and never initializes it, so the buffer must be zeroed beforehand.
///
/// Indices are loaded as `f32` and converted with `cvt.rzi.u32.f32` (round toward zero,
/// clamp to `u32`). NOTE(review): there is no bounds check against the length of
/// `grad_input`; callers must guarantee indices are in range.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_input_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %go, %indices, %gi, %off, %addr;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %gi, [grad_input_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read grad_output[tid]
    add.u64 %addr, %go, %off;
    ld.global.f32 %grad_val, [%addr];

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Atomic add: grad_input[idx] += grad_val
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %gi, %addr;
    atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
    ret;
}
";
4024
4025
4026// ---------------------------------------------------------------------------
4027// Masked-fill PTX kernel
4028// ---------------------------------------------------------------------------
4029// Thread i: output[i] = mask[i] >= 0.5 ? fill_value : input[i]
4030// Mask is stored as f32 (1.0 = true, 0.0 = false).
4031
/// PTX source for `masked_fill_kernel`: `out[i] = mask[i] >= 0.5 ? fill_value : input[i]` (f32).
///
/// Parameters, in launch order: `input_ptr`, `mask_ptr`, `out_ptr`, `fill_value` (an f32
/// passed by value) and `n`. There is one thread per element.
///
/// The mask is an f32 buffer, and any entry `>= 0.5` counts as set. The test is an ordered
/// `setp.ge.f32`, so a NaN mask entry compares false and keeps `input[i]`.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
    .param .u64 input_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .f32 fill_value,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %input, %mask, %out, %off;
    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %fill, [fill_value];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %input, %input, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %in_val, [%input];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %fill, %in_val, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4082
4083
4084// ---------------------------------------------------------------------------
4085// Masked-zero PTX kernel — backward of masked_fill
4086// ---------------------------------------------------------------------------
4087// Thread i: output[i] = mask[i] >= 0.5 ? 0.0 : grad_output[i]
4088// Zeroes gradient at positions where the forward mask was true.
4089
/// PTX source for `masked_zero_kernel`: `out[i] = mask[i] >= 0.5 ? 0.0 : grad[i]` (f32).
///
/// Parameters, in launch order: `grad_ptr`, `mask_ptr`, `out_ptr`, `n`. There is one
/// thread per element. The mask threshold and NaN handling match `masked_fill_kernel`:
/// the ordered `>= 0.5` compare treats a NaN mask entry as unset.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_zero_kernel(
    .param .u64 grad_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %mask, %out, %off;
    .reg .f32 %vg, %mask_val, %zero, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %zero, 0f00000000;
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %zero, %vg, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4139
4140
4141// ---------------------------------------------------------------------------
4142// Sigmoid backward PTX kernel: out[i] = grad[i] * output[i] * (1 - output[i])
4143// ---------------------------------------------------------------------------
4144
/// PTX source for `sigmoid_backward_kernel`: `out[i] = grad[i] * output[i] * (1 - output[i])` (f32).
///
/// Parameters, in launch order: `grad_ptr`, `output_ptr`, `out_ptr`, `n`. There is one
/// thread per element. `output` is expected to hold the forward sigmoid values, which lets
/// the derivative be formed without recomputing an exponential.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    sub.f32 %one_minus_o, %one, %vo;
    mul.f32 %result, %vo, %one_minus_o;
    mul.f32 %result, %vg, %result;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4194
4195
4196// ---------------------------------------------------------------------------
4197// Tanh backward PTX kernel: out[i] = grad[i] * (1 - output[i]^2)
4198// ---------------------------------------------------------------------------
4199
/// PTX source for `tanh_backward_kernel`: `out[i] = grad[i] * (1 - output[i]^2)` (f32).
///
/// Parameters, in launch order: `grad_ptr`, `output_ptr`, `out_ptr`, `n`. There is one
/// thread per element. `output` is expected to hold the forward tanh values.
#[cfg(feature = "cuda")]
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    mul.f32 %o_sq, %vo, %vo;
    sub.f32 %one_minus_sq, %one, %o_sq;
    mul.f32 %result, %vg, %one_minus_sq;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4249
4250
4251// ---------------------------------------------------------------------------
4252// Softmax backward PTX kernel (row-wise, shared-memory dot product)
4253// ---------------------------------------------------------------------------
4254// For each row of length `cols`:
4255//   dot = sum(grad[row] * output[row])
4256//   out[i] = output[i] * (grad[i] - dot)
4257// One block per row, 256 threads per block.
4258
/// PTX source for `softmax_backward_kernel` (f32, row-major `[rows, cols]` buffers).
///
/// Parameters, in launch order: `grad_ptr`, `output_ptr`, `out_ptr`, `rows`, `cols`.
/// Launch one block per row: `ctaid.x` selects the row, and blocks with
/// `ctaid.x >= rows` exit. Each thread walks its row with stride `ntid.x`.
///
/// The kernel runs in two phases:
/// 1. Each thread accumulates a partial `dot = sum(grad * output)` with `fma`. The
///    partials are then combined by a shared-memory tree reduction and broadcast from
///    `sdata[0]`.
/// 2. Each thread writes `out[j] = output[j] * (grad[j] - dot)`.
///
/// `output` is expected to be the forward softmax result.
///
/// NOTE(review): `sdata` holds 256 floats indexed by `tid.x`, and the reduction halves
/// `ntid.x` each step, so the block size must be a power of two no larger than 256.
/// Confirm this at the launch site.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
    mov.f32 %dot, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra DOT_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
    add.u32 %j, %j, %bdim;\n\
    bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
    // Store partial dot into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %dot;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra DOT_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra DOT_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %dot, [%saddr];\n\
    add.f32 %dot, %dot, %other_val;\n\
    st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
    // Broadcast dot to all threads\n\
    ld.shared.f32 %dot, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    sub.f32 %diff, %vg, %dot;\n\
    mul.f32 %result, %vo, %diff;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4378
4379
4380// ---------------------------------------------------------------------------
4381// LogSoftmax forward PTX kernel (row-wise, shared-memory max + log-sum-exp)
4382// ---------------------------------------------------------------------------
4383// For each row of length `cols`:
4384//   m = max(x[j])
4385//   log_sum_exp = m + log(sum(exp(x[j] - m)))
4386//   out[j] = x[j] - log_sum_exp
4387// One block per row, 256 threads per block.
4388
/// PTX source for `log_softmax_kernel` (f32, row-major `[rows, cols]` buffers).
///
/// Parameters, in launch order: `input_ptr`, `output_ptr`, `rows`, `cols`. Launch one
/// block per row: `ctaid.x` selects the row, and blocks with `ctaid.x >= rows` exit.
///
/// The kernel makes three passes over the row, each striding by `ntid.x`:
/// 1. Row max: per-thread partial maxima, then a shared-memory tree reduction.
/// 2. `sum = sum(exp(x - max))`, using `ex2.approx.f32(y * log2(e))` for the
///    exponential. Subtracting the row max first keeps every exponent argument `<= 0`
///    for non-NaN inputs, so the sum cannot overflow.
/// 3. `out[j] = x[j] - (max + ln(sum))`, with `ln` computed as `lg2.approx.f32 * ln(2)`.
///
/// NOTE: the embedded PTX comment "grid-stride over columns" actually describes a
/// block-stride loop; the stride is `ntid.x`, not the grid size.
/// NOTE(review): `sdata` holds 256 floats indexed by `tid.x`, and the reductions halve
/// `ntid.x` each step, so the block size must be a power of two no larger than 256.
/// Confirm this at the launch site.
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: find max across row (grid-stride over columns)\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    // Shared-memory tree reduction for max\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    // Broadcast max to all threads\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: compute partial sum of exp(x[j] - max)\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    // exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    // Shared-memory tree reduction for sum\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
    // ln(2) = 0x3F317218\n\
    lg2.approx.f32 %log_sum_exp, %sum_val;\n\
    mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
    add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    // Phase 3: out[j] = x[j] - log_sum_exp\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %val, [%saddr];\n\
    sub.f32 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4565
/// PTX source for `log_softmax_f64_kernel`: row-wise log-softmax (f64).
///
/// Parameters, in launch order: `input_ptr`, `output_ptr`, `rows`, `cols`. The buffers
/// are row-major `[rows, cols]` f64. Launch one block per row: `ctaid.x` selects the row,
/// and blocks with `ctaid.x >= rows` exit. Each thread walks its row with stride `ntid.x`.
///
/// PTX has no f64 `ex2`/`lg2`, so both transcendentals are computed in software. The
/// kernel makes three passes:
/// 1. Row max: per-thread partial maxima, then a shared-memory tree reduction.
/// 2. `sum = sum(exp(x - max))`. For each argument `y`:
///    - Compute `n = rint(y * log2(e))`.
///    - Apply Cody-Waite reduction: `r = y - n*ln2_hi - n*ln2_lo`.
///    - Evaluate the full degree-12 Taylor polynomial `sum(r^k / k!)`.
///    - Scale by `2^n`, built directly from the exponent bits.
///
///    Arguments below -708 contribute exactly 0. At that point `n + 1023 <= 0`, so the
///    exponent-bit trick would produce a garbage (huge or negative) value. Such a term is
///    also far below one ulp of the sum, which is at least `exp(0) = 1` from the row
///    maximum.
/// 3. `lse = max + ln(sum)`. The sum is split as `sum = m * 2^e`, `m` is normalized into
///    `[sqrt(2)/2, sqrt(2))`, and `ln(m) = 2 * atanh((m - 1) / (m + 1))` is evaluated by
///    its odd series up to the `f^11` term. Then `out[j] = x[j] - lse`.
///
/// A NaN anywhere in a row makes every output of that row NaN, matching the f32 kernel.
/// The underflow guard uses an ordered compare, which is false for NaN. The bit-level ln
/// would otherwise map a NaN sum to a finite value, so NaN is propagated explicitly.
///
/// NOTE(review): truncating the ln series leaves an absolute error on the order of 1e-11
/// in `lse`, so results are not correctly rounded f64. `sdata` holds 256 doubles indexed
/// by `tid.x`, and the reductions halve `ntid.x` each step, so the block size must be a
/// power of two no larger than 256. Confirm this at the launch site.
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .pred %e_under;\n\
    .reg .u64 %l_xbits, %l_mbits, %l_bias;\n\
    .reg .s64 %l_exp64;\n\
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;\n\
    .reg .pred %l_big, %l_nan;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    // Below -708 the 2^n bit trick leaves the normal range; such terms are zeroed.\n\
    setp.lt.f64 %e_under, %val, 0dC086200000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    // Horner evaluation of sum(r^k / k!) for k = 12 down to 0.\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    selp.f64 %exp_val, 0d0000000000000000, %exp_val, %e_under;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.b64 %l_xbits, %sum_val;\n\
    shr.u64 %l_exp64, %l_xbits, 52;\n\
    and.b64 %l_exp64, %l_exp64, 2047;\n\
    sub.s64 %l_exp64, %l_exp64, 1023;\n\
    cvt.rn.f64.s64 %l_nf, %l_exp64;\n\
    mov.u64 %l_bias, 0x3FF0000000000000;\n\
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;\n\
    or.b64 %l_mbits, %l_mbits, %l_bias;\n\
    mov.b64 %l_m, %l_mbits;\n\
    // Normalise m into [sqrt(2)/2, sqrt(2)) so |f| <= 0.1716 in the series below.\n\
    setp.gt.f64 %l_big, %l_m, 0d3FF6A09E667F3BCD;\n\
    @%l_big mul.f64 %l_m, %l_m, 0d3FE0000000000000;\n\
    @%l_big add.f64 %l_nf, %l_nf, %e_one;\n\
    sub.f64 %l_f, %l_m, %e_one;\n\
    add.f64 %l_s, %l_m, %e_one;\n\
    div.rn.f64 %l_f, %l_f, %l_s;\n\
    mul.f64 %l_f2, %l_f, %l_f;\n\
    mov.f64 %l_p, 0d3FB745D1745D1746;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;\n\
    mul.f64 %l_p, %l_p, %l_f;\n\
    add.f64 %l_p, %l_p, %l_p;\n\
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;\n\
    fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2, %l_p;\n\
    add.f64 %log_sum_exp, %max_val, %log_sum_exp;\n\
    // The bit-level ln maps NaN to a finite value; propagate NaN explicitly.\n\
    setp.nan.f64 %l_nan, %sum_val, %sum_val;\n\
    selp.f64 %log_sum_exp, %sum_val, %log_sum_exp, %l_nan;\n\
\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %val, [%saddr];\n\
    sub.f64 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4780
4781// ---------------------------------------------------------------------------
4782// LogSoftmax backward PTX kernel (row-wise, shared-memory sum reduction)
4783// ---------------------------------------------------------------------------
4784// For each row of length `cols`:
4785//   sum_grad = sum(grad[j])
4786//   out[j] = grad[j] - exp(output[j]) * sum_grad
4787// where output[j] is the log-softmax output, so exp(output[j]) = softmax[j].
4788// One block per row, 256 threads per block.
4789
/// PTX source for `log_softmax_backward_kernel` (f32, row-major `[rows, cols]` buffers).
///
/// Parameters, in launch order: `grad_ptr`, `output_ptr`, `out_ptr`, `rows`, `cols`.
/// Launch one block per row: `ctaid.x` selects the row, and blocks with
/// `ctaid.x >= rows` exit. Each thread walks its row with stride `ntid.x`.
///
/// The kernel runs in two phases:
/// 1. Reduce `sum_grad = sum(grad[row, :])` through a shared-memory tree reduction and
///    broadcast it from `sdata[0]`.
/// 2. Write `out[j] = grad[j] - exp(output[j]) * sum_grad`, where `exp` is computed as
///    `ex2.approx.f32(output[j] * log2(e))`.
///
/// `output` is expected to be the forward log-softmax result, so `exp(output[j])` is the
/// softmax probability.
///
/// NOTE(review): `sdata` holds 256 floats indexed by `tid.x`, and the reduction halves
/// `ntid.x` each step, so the block size must be a power of two no larger than 256.
/// Confirm this at the launch site.
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
    mov.f32 %sum_grad, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Store partial sum into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_grad, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum_grad to all threads\n\
    ld.shared.f32 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    // exp(log_softmax_output) = softmax probability\n\
    mul.f32 %vo, %vo, 0f3FB8AA3B;\n\
    ex2.approx.f32 %softmax_j, %vo;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f32 %result, %softmax_j, %sum_grad;\n\
    sub.f32 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4910
/// PTX source for `log_softmax_backward_f64_kernel` (f64).
///
/// For each row of a contiguous `[rows, cols]` matrix this computes
/// `out[j] = grad[j] - exp(output[j]) * sum_k grad[k]`, where `output` holds the
/// forward log-softmax values, so `exp(output[j])` is the softmax probability.
///
/// Launch layout: one block per row (`blockIdx.x` is the row; blocks with
/// `blockIdx.x >= rows` return immediately). Threads stride over the row's
/// columns by `blockDim.x`. They then combine their partial sums with a
/// shared-memory tree reduction that halves `blockDim.x` each step over a
/// 256-entry buffer. `blockDim.x` is therefore assumed to be a power of two no
/// larger than 256 (NOTE(review): confirm at the launch site).
///
/// `exp` is evaluated inline in f64 arithmetic in three steps:
/// 1. Cody-Waite range reduction `x = n*ln2 + r`, with ln2 split into hi/lo
///    parts.
/// 2. The degree-12 Taylor polynomial for `exp(r)`.
/// 3. `2^n` written directly into the exponent bits.
///
/// Inputs below -708 (including `-inf`) flush to `0.0`. This keeps `2^n` a
/// normal double and avoids the `inf - inf` NaN that `-inf` would otherwise
/// produce in the range reduction. Inputs are expected to be log-probabilities
/// (`<= 0`), so overflow is not handled.
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_BACKWARD_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_f64_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p, %e_uf;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 8 bytes\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    // Phase 1: strided per-thread partial sums of grad over the row\n\
    mov.f64 %sum_grad, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Tree-reduce the partial sums in shared memory\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_grad, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast the row total to every thread\n\
    ld.shared.f64 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vo, [%saddr];\n\
    // softmax_j = exp(vo), inline f64. vo < -708 (incl. -inf) flushes to 0\n\
    // so that 2^n stays a normal double and -inf cannot create a NaN.\n\
    setp.lt.f64 %e_uf, %vo, 0dC086200000000000;\n\
    selp.f64 %vo, 0d0000000000000000, %vo, %e_uf;\n\
    // n = rint(vo * log2(e)); r = vo - n*ln2 (ln2 split into hi/lo parts)\n\
    mul.f64 %e_nf, %vo, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %vo;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    // exp(r) ~= sum_{k=0}^{12} r^k / k!  (Horner, starting at 1/12!)\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %softmax_j, %e_p, %e_r, %e_one;\n\
    // scale by 2^n, assembled directly in the exponent field\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %softmax_j, %softmax_j, %e_nf;\n\
    selp.f64 %softmax_j, 0d0000000000000000, %softmax_j, %e_uf;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f64 %result, %softmax_j, %sum_grad;\n\
    sub.f64 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
5051
5052// ---------------------------------------------------------------------------
5053// Sum-axis PTX kernel: reduce along one axis of a tensor
5054// ---------------------------------------------------------------------------
5055// Parameters: input_ptr, output_ptr, outer_size, axis_size, inner_size, total_output
/// PTX source for `reduce_sum_kernel`: parallel block-level sum reduction.
///
/// Each thread first accumulates a private sum with a grid-stride loop:
/// `idx = blockIdx.x * blockDim.x + threadIdx.x`, stride
/// `blockDim.x * gridDim.x`. A block's inputs are therefore interleaved with
/// those of other blocks rather than forming a contiguous chunk.
///
/// The per-thread sums are stored to shared memory and combined with a tree
/// reduction. Thread 0 then writes the block's partial sum to
/// `output[blockIdx.x]`.
///
/// The tree reduction starts at `blockDim.x / 2` and halves each step over a
/// 256-entry shared buffer. `blockDim.x` must therefore be a power of two no
/// larger than 256.
///
/// Shared-memory addresses are formed in a register (`sdata` base plus byte
/// offset), because PTX has no `[var + reg]` address form. The thread index
/// lives in `%r_tid` so that it does not shadow the `%tid` special register.
///
/// For a full reduction, launch once to get partial sums, then launch
/// again on the partial sums (or reduce on CPU if few blocks).
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off, %sbase, %saddr;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %r_tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;
    mov.u64 %sbase, sdata;

    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %r_tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;

GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;

GRID_DONE:
    // Write thread's partial sum to sdata[tid].
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum;
    bar.sync 0;

    // Tree reduction in shared memory: half = bdim/2, bdim/4, ..., 1.
    mov.u32 %half, %bdim;
TREE_LOOP:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %p, %half, 0;
    @%p bra TREE_DONE;

    setp.ge.u32 %ptid, %r_tid, %half;
    @%ptid bra TREE_SKIP;

    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %r_tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other, [%saddr];
    // Accumulate into own slot sdata[tid].
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum, [%saddr];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%saddr], %sum;

TREE_SKIP:
    bar.sync 0;
    bra TREE_LOOP;

TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %r_tid, 0;
    @%ptid bra END;

    ld.shared.f32 %sum, [sdata];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;

END:
    ret;
}
";
5159
5160
/// PTX source for `sum_axis_kernel`: sums a tensor along one axis.
///
/// The input is indexed as a contiguous `[outer_size, axis_size, inner_size]`
/// array. One thread computes one output element; its global index is
/// `i = blockIdx.x * blockDim.x + threadIdx.x`, and threads with
/// `i >= total_output` exit immediately:
///
/// `output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]`
/// where `outer_idx = i / inner_size`, `inner_idx = i % inner_size`.
///
/// Each thread sums its `axis_size` elements sequentially in f32.
///
/// Element offsets are computed with 32-bit integer arithmetic (`mul.lo.u32`)
/// before being widened, so they wrap if the element count exceeds `u32::MAX`.
///
/// The `outer_size` parameter is loaded but not otherwise used; the bound
/// check relies on `total_output`.
#[cfg(feature = "cuda")]
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;

    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;

    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;

    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:

    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
";
5239
5240// ---------------------------------------------------------------------------
5241// Cumulative scan PTX kernels
5242//
5243// One thread per (outer_idx, inner_idx) pair. Each thread does a sequential
5244// scan along dim_size elements. Parallelism comes from outer*inner threads.
5245// ---------------------------------------------------------------------------
5246
/// PTX source for `cumsum_kernel`: inclusive prefix sum along an axis.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Global thread `i` (`blockIdx.x * blockDim.x + threadIdx.x`, with
/// `i < outer_size * inner_size`) owns one line along the scanned axis. That
/// line starts at `base = (i / inner_size) * dim_size * inner_size + i % inner_size`.
/// The thread walks it sequentially, writing the running f32 sum at every step:
///
/// `output[base + k*inner] = sum_{j=0}^{k} input[base + j*inner]`
///
/// Element offsets use 32-bit integer arithmetic before being widened. The
/// `total` parameter is loaded but the bound check uses
/// `outer_size * inner_size` instead.
#[cfg(feature = "cuda")]
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumsum_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    // total threads = outer * inner
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    add.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5324
5325
/// PTX source for `cumprod_kernel`: inclusive prefix product along an axis.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Global thread `i` (`blockIdx.x * blockDim.x + threadIdx.x`, with
/// `i < outer_size * inner_size`) owns one line along the scanned axis. That
/// line starts at `base = (i / inner_size) * dim_size * inner_size + i % inner_size`.
/// The thread walks it sequentially from an accumulator of `1.0`, writing the
/// running f32 product at every step:
///
/// `output[base + k*inner] = prod_{j=0}^{k} input[base + j*inner]`
///
/// Element offsets use 32-bit integer arithmetic before being widened. The
/// `total` parameter is loaded but the bound check uses
/// `outer_size * inner_size` instead.
#[cfg(feature = "cuda")]
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumprod_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = 1.0
    mov.f32 %acc, 0f3F800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    mul.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5401
5402
/// PTX source for `cummax_kernel`: running maximum along an axis.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Global thread `i` (`blockIdx.x * blockDim.x + threadIdx.x`, with
/// `i < outer_size * inner_size`) scans the line starting at
/// `base = (i / inner_size) * dim_size * inner_size + i % inner_size`.
/// For each step `k` (`idx = base + k * inner_size`) it writes:
///
/// - `values[idx] = max_{j=0}^{k} input[base + j*inner]`
/// - `indices[idx] = argmax_{j=0}^{k} input[base + j*inner]`, stored as f32
///   (exact for positions up to 2^24) so both outputs use f32 buffers.
///
/// The running maximum starts at `-inf`. The index only advances on a
/// strictly greater value (`setp.gt`), so ties report the earliest position.
///
/// NaN inputs are skipped rather than propagated: a NaN never compares
/// greater, and `max.f32` returns its non-NaN operand.
///
/// NOTE(review): PyTorch's `cummax` propagates NaN and reports the last index
/// among ties. Confirm which semantics callers rely on.
///
/// The `total` parameter is loaded but the bound check uses
/// `outer_size * inner_size`.
#[cfg(feature = "cuda")]
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummax_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_max;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0xFF800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.gt.f32 %is_new_max, %val, %acc;
    @%is_new_max mov.u32 %best_k, %k;
    max.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5488
5489
/// PTX source for `cummin_kernel`: running minimum along an axis.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Global thread `i` (`blockIdx.x * blockDim.x + threadIdx.x`, with
/// `i < outer_size * inner_size`) scans the line starting at
/// `base = (i / inner_size) * dim_size * inner_size + i % inner_size`.
/// For each step `k` (`idx = base + k * inner_size`) it writes:
///
/// - `values[idx] = min_{j=0}^{k} input[base + j*inner]`
/// - `indices[idx] = argmin_{j=0}^{k} input[base + j*inner]`, stored as f32
///   (exact for positions up to 2^24) so both outputs use f32 buffers.
///
/// The running minimum starts at `+inf`. The index only advances on a
/// strictly smaller value (`setp.lt`), so ties report the earliest position.
///
/// NaN inputs are skipped rather than propagated: a NaN never compares less,
/// and `min.f32` returns its non-NaN operand.
///
/// NOTE(review): PyTorch's `cummin` propagates NaN and reports the last index
/// among ties. Confirm which semantics callers rely on.
///
/// The `total` parameter is loaded but the bound check uses
/// `outer_size * inner_size`.
#[cfg(feature = "cuda")]
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummin_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_min;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0x7F800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.lt.f32 %is_new_min, %val, %acc;
    @%is_new_min mov.u32 %best_k, %k;
    min.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5573
5574
/// PTX source for `logcumsumexp_kernel`: numerically stable log-cumulative-sum-exp.
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Global thread `i` (`blockIdx.x * blockDim.x + threadIdx.x`, with
/// `i < outer_size * inner_size`) scans the line starting at
/// `base = (i / inner_size) * dim_size * inner_size + i % inner_size`.
/// The scan is sequential from `acc = -inf` and writes the running value at
/// every step.
///
/// Each step computes `log(exp(acc) + exp(x))` as `hi + log(1 + exp(lo - hi))`,
/// where `hi`/`lo` are the larger/smaller of `acc` and `x`:
/// - The single exponential has a non-positive argument, so it cannot
///   overflow.
/// - When `lo == hi` the difference is forced to `0`, giving `hi + ln 2`.
///   That is exact for equal finite values. It also keeps equal infinities
///   (e.g. a leading run of `-inf` from masking) from becoming
///   `inf - inf = NaN`.
/// - `hi`/`lo` are picked with a comparison (false for NaN) instead of
///   `max.f32`/`min.f32`, which return the non-NaN operand. A NaN in either
///   operand therefore propagates to the output.
///
/// Uses `ex2.approx.f32` for exp and `lg2.approx.f32` for log with
/// log2(e) and ln(2) conversion constants.
#[cfg(feature = "cuda")]
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc, %hi, %lo, %d, %s, %ls, %log2e, %ln2;
    .reg .pred %p, %lp, %p_gt, %p_eq;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    // log2(e) = 1.4426950408...  -> 0x3FB8AA3B
    mov.b32 %log2e, 0x3FB8AA3B;
    // ln(2) = 0.6931471805... -> 0x3F317218
    mov.b32 %ln2, 0x3F317218;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b32 %acc, 0xFF800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // hi/lo = larger/smaller of (acc, x). setp is false for NaN, so a NaN
    // operand lands in hi or lo and propagates.
    setp.gt.f32 %p_gt, %val, %acc;
    selp.f32 %hi, %val, %acc, %p_gt;
    selp.f32 %lo, %acc, %val, %p_gt;
    // d = lo - hi <= 0; forced to 0 when lo == hi so equal infinities
    // do not produce inf - inf = NaN.
    sub.f32 %d, %lo, %hi;
    setp.eq.f32 %p_eq, %lo, %hi;
    selp.f32 %d, 0f00000000, %d, %p_eq;
    // exp(d) = ex2(d * log2(e))
    mul.f32 %d, %d, %log2e;
    ex2.approx.f32 %d, %d;
    // log(1 + exp(d)) = lg2(1 + exp(d)) * ln(2)
    add.f32 %s, %d, 0f3F800000;
    lg2.approx.f32 %ls, %s;
    mul.f32 %ls, %ls, %ln2;
    // acc = hi + log(1 + exp(lo - hi))
    add.f32 %acc, %hi, %ls;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5675
/// PTX source for `logcumsumexp_f64_kernel`: numerically stable
/// log-cumulative-sum-exp (f64).
///
/// The input is indexed as a contiguous `[outer_size, dim_size, inner_size]`
/// array. Global thread `i` (`blockIdx.x * blockDim.x + threadIdx.x`, with
/// `i < outer_size * inner_size`) scans the line starting at
/// `base = (i / inner_size) * dim_size * inner_size + i % inner_size`.
/// The scan is sequential from `acc = -inf` and writes the running value at
/// every step.
///
/// Each step computes `log(exp(acc) + exp(x))` as `hi + ln(1 + exp(lo - hi))`,
/// where `hi`/`lo` are the larger/smaller of `acc` and `x`:
/// - The single exponential has a non-positive argument, so it cannot
///   overflow.
/// - When `lo == hi` the difference is forced to `0`, so equal infinities
///   (e.g. a leading run of `-inf`) give `hi` instead of `inf - inf = NaN`.
/// - `hi`/`lo` are chosen with a comparison (false for NaN) rather than
///   `max`/`min`, so a NaN operand propagates.
///
/// `exp` and `ln` are evaluated inline in f64 arithmetic:
/// - `exp`: Cody-Waite reduction `d = n*ln2 + r` (ln2 split into hi/lo parts),
///   the degree-12 Taylor polynomial for `exp(r)`, and `2^n` assembled in the
///   exponent bits. Arguments below -708, including `-inf`, flush to `0.0`.
///   This keeps `2^n` a normal double; without it, the very first step
///   (`acc = -inf`) would turn into NaN.
/// - `ln`: exponent/mantissa split, then `ln(m) = 2*atanh((m-1)/(m+1))` with
///   the series truncated after the `f^11/11` term. NOTE(review): this is
///   less accurate than a correctly rounded f64 log, and the error grows as
///   `m` approaches 2. The bit-level split maps NaN to a finite value, so NaN
///   is restored explicitly afterwards.
#[cfg(feature = "cuda")]
pub(crate) const LOGCUMSUMEXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_f64_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f64 %val, %acc, %hi, %lo, %d, %s, %ls;
    .reg .pred %p, %lp, %p_gt, %p_eq, %p_nan, %e_uf;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // Loop-invariant constants for the inline exp/ln.
    mov.f64 %e_one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    mov.u64 %l_bias, 0x3FF0000000000000;

    // acc = -inf
    mov.b64 %acc, 0xFFF0000000000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 3;
    add.u64 %addr, %in, %off;
    ld.global.f64 %val, [%addr];

    // hi/lo = larger/smaller of (acc, x). setp is false for NaN, so a NaN
    // operand lands in hi or lo and propagates.
    setp.gt.f64 %p_gt, %val, %acc;
    selp.f64 %hi, %val, %acc, %p_gt;
    selp.f64 %lo, %acc, %val, %p_gt;
    // d = lo - hi <= 0; forced to 0 when lo == hi so equal infinities
    // do not produce inf - inf = NaN.
    sub.f64 %d, %lo, %hi;
    setp.eq.f64 %p_eq, %lo, %hi;
    selp.f64 %d, 0d0000000000000000, %d, %p_eq;

    // --- inline exp(d) -> %s ---
    // d < -708 (incl. -inf) flushes to 0 so that 2^n stays a normal double.
    setp.lt.f64 %e_uf, %d, 0dC086200000000000;
    selp.f64 %d, 0d0000000000000000, %d, %e_uf;
    // n = rint(d * log2(e)); r = d - n*ln2 (ln2 split into hi/lo parts)
    mul.f64 %e_nf, %d, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %d;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // exp(r) ~= sum_{k=0}^{12} r^k / k!  (Horner, starting at 1/12!)
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
    fma.rn.f64 %s, %e_p, %e_r, %e_one;
    // scale by 2^n, assembled directly in the exponent field
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %s, %s, %e_nf;
    selp.f64 %s, 0d0000000000000000, %s, %e_uf;
    // s = 1 + exp(lo - hi), in [1, 2] unless NaN
    add.f64 %s, %s, %e_one;

    // --- inline ln(%s) -> %ls ---
    mov.b64 %l_xbits, %s;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;
    sub.f64 %l_f, %l_m, %e_one;
    add.f64 %l_s, %l_m, %e_one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    fma.rn.f64 %ls, %l_nf, %l_ln2, %l_p;
    // The bit-level split above maps NaN to a finite value; restore it.
    setp.nan.f64 %p_nan, %s, %s;
    selp.f64 %ls, %s, %ls, %p_nan;
    // acc = hi + ln(1 + exp(lo - hi))
    add.f64 %acc, %hi, %ls;

    add.u64 %addr, %out, %off;
    st.global.f64 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5830
// ---------------------------------------------------------------------------
// LayerNorm PTX kernel (row-wise: mean, var, normalize+affine)
//
// Entry point: `layernorm_kernel(in_ptr, out_ptr, w_ptr, b_ptr, rows, cols, eps)`.
//
// Layout: `in` and `out` are row-major f32 buffers of `rows * cols` elements
// (the row byte offset is `ctaid.x * cols * 4`). `w` and `b` are indexed by
// column only, so each holds `cols` f32 values.
//
// One thread block per row: `%ctaid.x` selects the row, and blocks with
// `ctaid.x >= rows` exit immediately. Within a row, thread `tid` visits
// columns `tid, tid + ntid.x, tid + 2*ntid.x, ...`. Phases, by PTX label:
//   SM/SMD - per-thread partial sum of x, stored to `sdata[tid]`
//   MR/MRD - shared-memory tree reduction; mean = sum / cols
//   SV/SVD - per-thread partial sum of (x - mean)^2, stored to `sdata[tid]`
//   VR/VRD - tree reduction; var = sum / cols (biased: divides by cols),
//            inv_std = 1 / sqrt(var + eps)
//   NM/NMD - out[j] = w[j] * ((x[j] - mean) * inv_std) + b[j]
// The `bar.sync 0` that follows each read of the reduced value from
// `sdata[0]` stops any thread from overwriting `sdata` in the next phase
// before every thread has read that value.
//
// Launch requirements implied by the code (the kernel does not check them):
//   - `sdata` holds 256 floats and is indexed by `%tid.x`, so blockDim.x
//     must be <= 256.
//   - The reduction halves `ntid.x` each step, so blockDim.x must be a power
//     of two. Otherwise some partial sums are silently dropped.
//
// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
// `rcp.approx.f32`) for performance. These have reduced precision (~2^-22
// relative error) compared to the full-precision variants, which is
// acceptable for neural network training/inference.
//
// NOTE(review): in the MR and VR loops, the store address `sbase + tid*4` is
// recomputed (`add.u64 %saddr, %sbase, %off`) even though `%saddr` already
// holds it. This is harmless but redundant.
// ---------------------------------------------------------------------------

/// PTX source for the f32 row-wise LayerNorm forward kernel `layernorm_kernel`.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u64 b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra SM;
SMD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
MRS:
    bar.sync 0;
    bra MR;
MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra SV;
SVD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
VRS:
    bar.sync 0;
    bra VR;
VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %normed, %val, %mean;
    mul.f32 %normed, %normed, %inv_std;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %b, %off;
    ld.global.f32 %bv, [%off];
    fma.rn.f32 %result, %wv, %normed, %bv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
6014
6015
// ---------------------------------------------------------------------------
// LayerNorm backward PTX kernel
// ---------------------------------------------------------------------------
//
// Entry point: `layernorm_backward_kernel(in_ptr, grad_out_ptr, w_ptr,
// grad_in_ptr, grad_w_ptr, grad_b_ptr, rows, cols, eps)`.
//
// One block per batch element (row). Blocks with `ctaid.x >= rows` exit
// immediately. Within a row, thread `tid` visits columns `tid, tid + ntid.x,
// ...`. Each block:
//   1. Recomputes mean and biased variance from the input (LNB_SM/LNB_MR,
//      LNB_SV/LNB_VR), then inv_std = 1 / sqrt(var + eps).
//   2. Computes x_hat = (x - mean) * inv_std.
//   3. Computes dl_dx_hat = grad_output * weight.
//   4. Reduces sum1 = sum(dl_dx_hat) and sum2 = sum(dl_dx_hat * x_hat) across
//      the normalized dimension (LNB_S12, then LNB_R1 and LNB_R2).
//   5. Computes grad_input = inv_std * (dl_dx_hat - sum1/cols - x_hat * sum2/cols)
//      (LNB_GI). x_hat and dl_dx_hat are recomputed from global memory here
//      rather than cached from step 4.
//   6. During step 4, adds grad_output * x_hat into grad_weight[j] and
//      grad_output into grad_bias[j] with `atom.global.add.f32`.
//
// Accumulation contract: grad_weight and grad_bias are added into, not
// overwritten. They must therefore hold zeros (or a running total) before
// launch. Float atomics from different rows land in an unspecified order, so
// their low-order bits can vary between runs.
//
// Shared memory: `sdata` holds 256 floats indexed by `%tid.x`, and each tree
// reduction halves `ntid.x` per step. blockDim.x must therefore be <= 256 and
// a power of two. The kernel is presumably launched with 256 threads per
// block; confirm at the launch site.
//
// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
// `rcp.approx.f32`), with reduced precision relative to the full variants.
//
// Parameters:
//   in_ptr       - pointer to input f32 buffer [rows * cols]
//   grad_out_ptr - pointer to grad_output f32 buffer [rows * cols]
//   w_ptr        - pointer to weight f32 buffer [cols]
//   grad_in_ptr  - pointer to grad_input f32 output buffer [rows * cols]
//   grad_w_ptr   - pointer to grad_weight f32 output buffer [cols] (atomicAdd)
//   grad_b_ptr   - pointer to grad_bias f32 output buffer [cols] (atomicAdd)
//   rows         - number of batch elements
//   cols         - normalized dimension size
//   eps          - epsilon for numerical stability

/// PTX source for the f32 LayerNorm backward kernel `layernorm_backward_kernel`.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u64 grad_b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u64 %gb, [grad_b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra LNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute mean =====
    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SM;
LNB_SMD:
    // Shared memory reduce for mean
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    st.shared.f32 [%saddr], %mean;
LNB_MRS:
    bar.sync 0;
    bra LNB_MR;
LNB_MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    // ===== Phase 2: Compute variance =====
    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SV;
LNB_SVD:
    // Shared memory reduce for variance
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    st.shared.f32 [%saddr], %var;
LNB_VRS:
    bar.sync 0;
    bra LNB_VR;
LNB_VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
    // Also accumulate grad_weight and grad_bias via atomicAdd
    mov.f32 %sum1, 0f00000000;
    mov.f32 %sum2, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_S12:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_S12D;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // x_hat = (val - mean) * inv_std
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dl_dx_hat = grad_output * weight
    mul.f32 %dl_dx_hat, %gov, %wv;
    // Accumulate sums
    add.f32 %sum1, %sum1, %dl_dx_hat;
    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
    // atomicAdd grad_weight[j] += grad_output * x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %x_hat;
    atom.global.add.f32 %result, [%addr], %result;
    // atomicAdd grad_bias[j] += grad_output
    add.u64 %addr, %gb, %off;
    atom.global.add.f32 %result, [%addr], %gov;
    add.u32 %j, %j, %r_bdim;
    bra LNB_S12;
LNB_S12D:
    // Reduce sum1 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum1;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R1:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R1D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R1S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum1, [%saddr];
    add.f32 %sum1, %sum1, %other_val;
    st.shared.f32 [%saddr], %sum1;
LNB_R1S:
    bar.sync 0;
    bra LNB_R1;
LNB_R1D:
    ld.shared.f32 %sum1, [%sbase];
    // mean1 = sum1 / n
    div.approx.f32 %mean1, %sum1, %n_f;
    bar.sync 0;

    // Reduce sum2 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum2;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R2:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R2D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R2S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum2, [%saddr];
    add.f32 %sum2, %sum2, %other_val;
    st.shared.f32 [%saddr], %sum2;
LNB_R2S:
    bar.sync 0;
    bra LNB_R2;
LNB_R2D:
    ld.shared.f32 %sum2, [%sbase];
    // mean2 = sum2 / n
    div.approx.f32 %mean2, %sum2, %n_f;
    bar.sync 0;

    // ===== Phase 4: Compute grad_input =====
    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
    mov.u32 %j, %r_tid;
LNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_GID;
    // Reload input to recompute x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Reload grad_output and weight to recompute dl_dx_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    mul.f32 %dl_dx_hat, %gov, %wv;
    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
    sub.f32 %result, %dl_dx_hat, %mean1;
    mul.f32 %diff, %x_hat, %mean2;
    sub.f32 %result, %result, %diff;
    mul.f32 %result, %inv_std, %result;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra LNB_GI;
LNB_GID:

LNB_DONE:
    ret;
}
";
6344
6345
// ---------------------------------------------------------------------------
// RMSNorm PTX kernel (row-wise: rms, normalize+scale)
//
// Like LayerNorm but without mean centering or bias:
//   out[j] = x[j] * rsqrt(mean(x^2) + eps) * weight[j]
//
// Entry point: `rmsnorm_kernel(in_ptr, out_ptr, w_ptr, rows, cols, eps)`.
// `in` and `out` are row-major f32 buffers of `rows * cols` elements (the row
// byte offset is `ctaid.x * cols * 4`). `w` is indexed by column only
// (`cols` elements).
//
// One thread block per row. Blocks with `ctaid.x >= rows` exit immediately.
// Thread `tid` visits columns `tid, tid + ntid.x, ...`. Phases, by PTX label:
//   SS/SSD - per-thread partial sum of x^2, stored to `sdata[tid]`
//   SR/SRD - shared-memory tree reduction;
//            inv_rms = 1 / sqrt(sum / cols + eps)
//   NM/NMD - out[j] = (x[j] * inv_rms) * w[j]
//
// Launch requirements implied by the code (the kernel does not check them):
// `sdata` holds 256 floats indexed by `%tid.x`, and the reduction halves
// `ntid.x` each step. blockDim.x must therefore be <= 256 and a power of
// two; otherwise partial sums are dropped.
//
// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
// `rcp.approx.f32`) for performance. These have reduced precision (~2^-22
// relative error) compared to the full-precision variants, which is
// acceptable for neural network training/inference.
//
// NOTE(review): in the SR loop, `add.u64 %saddr, %sbase, %off` is repeated
// right before the store although `%saddr` already holds that address.
// This is harmless but redundant.
// ---------------------------------------------------------------------------

/// PTX source for the f32 row-wise RMSNorm forward kernel `rmsnorm_kernel`.
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra SS;
SSD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
SRS:
    bar.sync 0;
    bra SR;
SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Normalize and scale =====
    // out[j] = x[j] * inv_rms * weight[j]
    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    mul.f32 %result, %val, %inv_rms;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    mul.f32 %result, %result, %wv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
6481
6482
// ---------------------------------------------------------------------------
// RMSNorm backward PTX kernel
// ---------------------------------------------------------------------------
//
// Entry point: `rmsnorm_backward_kernel(in_ptr, grad_out_ptr, w_ptr,
// grad_in_ptr, grad_w_ptr, rows, cols, eps)`.
//
// One block per batch element (row). Blocks with `ctaid.x >= rows` exit
// immediately. Thread `tid` visits columns `tid, tid + ntid.x, ...`. Each
// block:
//   1. Recomputes inv_rms = 1/sqrt(mean(x^2) + eps) (RNB_SS, RNB_SR), and
//      inv_rms^3 from it.
//   2. Computes dot = sum(grad_output[j] * x[j] * weight[j]) (RNB_DOT,
//      reduced in RNB_DR).
//   3. Computes grad_input[j] = inv_rms * weight[j] * go[j]
//                               - x[j] * inv_rms^3 * dot / cols   (RNB_GI)
//   4. During step 2, adds go[j] * x[j] * inv_rms into grad_weight[j] with
//      `atom.global.add.f32`.
//
// Accumulation contract: grad_weight is added into, not overwritten, so it
// must hold zeros (or a running total) before launch. Float atomics from
// different rows land in an unspecified order, so its low-order bits can
// vary between runs.
//
// Shared memory: `sdata` holds 256 floats indexed by `%tid.x`, and each tree
// reduction halves `ntid.x` per step. blockDim.x must therefore be <= 256 and
// a power of two. The kernel is presumably launched with 256 threads per
// block; confirm at the launch site.
// No grad_bias (RMSNorm has no bias parameter).
//
// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
// `rcp.approx.f32`), with reduced precision relative to the full variants.
//
// Parameters:
//   in_ptr       - pointer to input f32 buffer [rows * cols]
//   grad_out_ptr - pointer to grad_output f32 buffer [rows * cols]
//   w_ptr        - pointer to weight f32 buffer [cols]
//   grad_in_ptr  - pointer to grad_input f32 output buffer [rows * cols]
//   grad_w_ptr   - pointer to grad_weight f32 output buffer [cols] (atomicAdd)
//   rows         - number of batch elements
//   cols         - normalized dimension size
//   eps          - epsilon for numerical stability

/// PTX source for the f32 RMSNorm backward kernel `rmsnorm_backward_kernel`.
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
    .reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra RNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) -> inv_rms =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra RNB_SS;
RNB_SSD:
    // Shared memory reduce for sum(x^2)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    st.shared.f32 [%saddr], %sq_sum;
RNB_SRS:
    bar.sync 0;
    bra RNB_SR;
RNB_SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    // inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
    mul.f32 %inv_rms3, %inv_rms, %inv_rms;
    mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
    // Also accumulate grad_weight via atomicAdd
    mov.f32 %dot, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_DOT:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_DOTD;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dot += go * x * w
    mul.f32 %tmp, %gov, %val;
    fma.rn.f32 %dot, %tmp, %wv, %dot;
    // atomicAdd grad_weight[j] += go * x * inv_rms
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %val;
    mul.f32 %result, %result, %inv_rms;
    atom.global.add.f32 %result, [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_DOT;
RNB_DOTD:
    // Reduce dot in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %dot;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_DR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_DRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_DRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %dot, [%saddr];
    add.f32 %dot, %dot, %other_val;
    st.shared.f32 [%saddr], %dot;
RNB_DRS:
    bar.sync 0;
    bra RNB_DR;
RNB_DRD:
    ld.shared.f32 %dot, [%sbase];
    // coeff = dot * inv_rms3 / n
    mul.f32 %coeff, %dot, %inv_rms3;
    div.approx.f32 %coeff, %coeff, %n_f;
    bar.sync 0;

    // ===== Phase 3: Compute grad_input =====
    // grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
    mov.u32 %j, %r_tid;
RNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_GID;
    // Reload input
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Reload grad_output and weight
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // result = inv_rms * w * go - x * coeff
    mul.f32 %result, %inv_rms, %wv;
    mul.f32 %result, %result, %gov;
    mul.f32 %tmp, %val, %coeff;
    sub.f32 %result, %result, %tmp;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_GI;
RNB_GID:

RNB_DONE:
    ret;
}
";
6720
6721
// ---------------------------------------------------------------------------
// Softmax PTX kernel (row-wise, numerically stable)
// ---------------------------------------------------------------------------
//
// NOTE(review): this header documents a softmax kernel, but the definition
// that immediately follows it is `BATCHNORM_FORWARD_PTX`. Confirm where the
// softmax PTX described here is defined and move this header next to it.
//
// One thread block per row. Each block:
//   1. Finds the row maximum with a shared-memory reduction (subtracting it
//      keeps exp() from overflowing)
//   2. Computes exp(x - max) and sums the results with a shared-memory reduction
//   3. Divides each exp(x - max) by that sum
//
// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
// for performance. These have reduced precision (~2^-22 relative error)
// compared to the full-precision variants, which is acceptable for neural
// network training/inference.
//
// Parameters:
//   input_ptr  - pointer to input f32 buffer
//   output_ptr - pointer to output f32 buffer
//   rows       - number of rows (outer dimension)
//   cols       - number of columns (softmax dimension, = last_dim)
6741
6742/// PTX kernel for BatchNorm2d forward: per-channel normalize + affine.
6743///
6744/// Input layout: [B*C*spatial] flattened, where spatial = H*W.
6745/// One block per channel. Each block computes mean + variance for its
6746/// channel across all batch elements and spatial positions, then
6747/// normalizes in a second pass.
6748///
6749/// Parameters:
6750///   input[B*C*S], output[B*C*S], weight[C], bias[C],
6751///   running_mean[C], running_var[C], save_mean[C], save_invstd[C],
6752///   channels, spatial, eps, momentum, total_per_channel (= B*S),
6753///   training (0 or 1)
///
/// Launch with one 1-D block per channel (`grid.x >= channels`; surplus
/// blocks exit immediately). `input`/`output` are contiguous `[B, C, S]`, so
/// element `i` of channel `c` (`0 <= i < B*S`) lives at
/// `(i / S * C + c) * S + i % S`. Any block size works: at most 256 threads
/// (the size of the shared scratch arrays) take part in the reduction, and
/// every thread takes part in normalization.
///
/// * `training != 0`: mean and biased variance are reduced over the channel's
///   `B*S` elements. They are accumulated relative to the channel's first
///   element so that `E[d^2] - E[d]^2` does not cancel catastrophically. The
///   running statistics are then updated in place as
///   `r = (1 - momentum) * r + momentum * s`, using the unbiased variance
///   (`var * N / (N - 1)` when `N > 1`) for `running_var`.
/// * `training == 0`: `running_mean`/`running_var` are used as the statistics
///   and are not modified.
///
/// In both modes thread 0 writes the statistics actually used to
/// `save_mean[c]` and `save_invstd[c] = 1 / sqrt(var + eps)`. Every element is
/// written as `output = (input - mean) * invstd * weight[c] + bias[c]`.
/// Channels with `B*S == 0` are skipped entirely.
#[cfg(feature = "cuda")]
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared scratch for the block reduction (one slot per participating thread).
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];

.visible .entry batchnorm_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 rmean_ptr,
    .param .u64 rvar_ptr,
    .param .u64 save_mean_ptr,
    .param .u64 save_invstd_ptr,
    .param .u32 channels,
    .param .u32 spatial,
    .param .f32 eps,
    .param .f32 momentum,
    .param .u32 total_per_ch,
    .param .u32 training
) {
    .reg .u32 %r_tid, %bdim, %nthr, %ch, %n_ch, %sp, %tpc, %train;
    .reg .u32 %idx, %bi, %si, %elem, %half, %peer;
    .reg .u64 %in, %out, %w, %b, %rm, %rv, %smp, %sip;
    .reg .u64 %ch_off, %off64, %addr, %my_sum, %my_sq;
    .reg .f32 %shift, %val, %dev, %sum, %sqsum, %other, %run;
    .reg .f32 %n_f, %mean_d, %mean, %var, %ubias, %invstd;
    .reg .f32 %gamma, %beta, %scale, %eps_reg, %mom, %one;
    .reg .pred %p, %ptrain, %pact, %ptid0;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %b, [bias_ptr];
    ld.param.u64 %rm, [rmean_ptr];
    ld.param.u64 %rv, [rvar_ptr];
    ld.param.u64 %smp, [save_mean_ptr];
    ld.param.u64 %sip, [save_invstd_ptr];
    ld.param.u32 %n_ch, [channels];
    ld.param.u32 %sp, [spatial];
    ld.param.f32 %eps_reg, [eps];
    ld.param.f32 %mom, [momentum];
    ld.param.u32 %tpc, [total_per_ch];
    ld.param.u32 %train, [training];

    mov.u32 %ch, %ctaid.x;
    mov.u32 %r_tid, %tid.x;
    mov.u32 %bdim, %ntid.x;
    mov.f32 %one, 0f3F800000;

    // Both exits are uniform across the block, so no barrier below is
    // reached by only part of it.
    setp.ge.u32 %p, %ch, %n_ch;
    @%p bra END;
    setp.eq.u32 %p, %tpc, 0;
    @%p bra END;

    min.u32 %nthr, %bdim, 256;
    setp.lt.u32 %pact, %r_tid, %nthr;
    setp.eq.u32 %ptid0, %r_tid, 0;
    setp.ne.u32 %ptrain, %train, 0;

    // Byte offset of this channel's entry in the per-channel arrays.
    cvt.u64.u32 %ch_off, %ch;
    shl.b64 %ch_off, %ch_off, 2;

    // Shared-memory slots owned by this thread (only used when %pact).
    cvt.u64.u32 %off64, %r_tid;
    shl.b64 %off64, %off64, 2;
    mov.u64 %my_sum, smem_sum;
    add.u64 %my_sum, %my_sum, %off64;
    mov.u64 %my_sq, smem_sq;
    add.u64 %my_sq, %my_sq, %off64;

    @!%ptrain bra EVAL_STATS;

    // ---- Training pass 1: per-thread sums of dev = x - shift and dev^2 ----
    // shift = first element of the channel (batch 0, spatial 0).
    mul.lo.u32 %elem, %ch, %sp;
    cvt.u64.u32 %off64, %elem;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr, %in, %off64;
    ld.global.f32 %shift, [%addr];

    mov.f32 %sum, 0f00000000;
    mov.f32 %sqsum, 0f00000000;
    @!%pact bra PASS1_DONE;
    mov.u32 %idx, %r_tid;
PASS1_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra PASS1_STORE;
    // elem = (idx / sp * n_ch + ch) * sp + idx % sp
    div.u32 %bi, %idx, %sp;
    mul.lo.u32 %si, %bi, %sp;
    sub.u32 %si, %idx, %si;
    mad.lo.u32 %elem, %bi, %n_ch, %ch;
    mad.lo.u32 %elem, %elem, %sp, %si;
    cvt.u64.u32 %off64, %elem;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr, %in, %off64;
    ld.global.f32 %val, [%addr];
    sub.f32 %dev, %val, %shift;
    add.f32 %sum, %sum, %dev;
    fma.rn.f32 %sqsum, %dev, %dev, %sqsum;
    add.u32 %idx, %idx, %nthr;
    bra PASS1_LOOP;
PASS1_STORE:
    st.shared.f32 [%my_sum], %sum;
    st.shared.f32 [%my_sq], %sqsum;
PASS1_DONE:
    bar.sync 0;

    // ---- Tree reduction over slots [0, nthr), valid for any nthr <= 256 ----
    // Only thread t writes slot t, so %sum/%sqsum always mirror it.
    mov.u32 %half, 128;
REDUCE_LOOP:
    setp.eq.u32 %p, %half, 0;
    @%p bra REDUCE_DONE;
    setp.ge.u32 %p, %r_tid, %half;
    @%p bra REDUCE_SKIP;
    add.u32 %peer, %r_tid, %half;
    setp.ge.u32 %p, %peer, %nthr;
    @%p bra REDUCE_SKIP;
    cvt.u64.u32 %off64, %half;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr, %my_sum, %off64;
    ld.shared.f32 %other, [%addr];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%my_sum], %sum;
    add.u64 %addr, %my_sq, %off64;
    ld.shared.f32 %other, [%addr];
    add.f32 %sqsum, %sqsum, %other;
    st.shared.f32 [%my_sq], %sqsum;
REDUCE_SKIP:
    bar.sync 0;
    shr.u32 %half, %half, 1;
    bra REDUCE_LOOP;
REDUCE_DONE:
    @!%ptid0 bra WAIT_STATS;

    // ---- Thread 0: batch statistics and running-stat update ----
    ld.shared.f32 %sum, [smem_sum];
    ld.shared.f32 %sqsum, [smem_sq];
    cvt.rn.f32.u32 %n_f, %tpc;
    div.rn.f32 %mean_d, %sum, %n_f;
    div.rn.f32 %var, %sqsum, %n_f;
    neg.f32 %other, %mean_d;
    fma.rn.f32 %var, %other, %mean_d, %var;
    // Rounding can leave a tiny negative variance; clamp it (NaN passes through).
    setp.lt.f32 %p, %var, 0f00000000;
    @%p mov.f32 %var, 0f00000000;
    add.f32 %mean, %mean_d, %shift;

    // running_mean = (1 - momentum) * running_mean + momentum * mean
    add.u64 %addr, %rm, %ch_off;
    ld.global.f32 %run, [%addr];
    sub.f32 %other, %mean, %run;
    fma.rn.f32 %run, %mom, %other, %run;
    st.global.f32 [%addr], %run;

    // running_var tracks the unbiased variance var * N / (N - 1) when N > 1.
    mov.f32 %ubias, %var;
    setp.gt.u32 %p, %tpc, 1;
    @%p sub.f32 %other, %n_f, %one;
    @%p div.rn.f32 %other, %n_f, %other;
    @%p mul.f32 %ubias, %var, %other;
    add.u64 %addr, %rv, %ch_off;
    ld.global.f32 %run, [%addr];
    sub.f32 %other, %ubias, %run;
    fma.rn.f32 %run, %mom, %other, %run;
    st.global.f32 [%addr], %run;
    bra FINISH_STATS;

EVAL_STATS:
    // ---- Inference: use the running statistics unchanged ----
    @!%ptid0 bra WAIT_STATS;
    add.u64 %addr, %rm, %ch_off;
    ld.global.f32 %mean, [%addr];
    add.u64 %addr, %rv, %ch_off;
    ld.global.f32 %var, [%addr];

FINISH_STATS:
    // Thread 0 only: invstd = 1 / sqrt(var + eps); publish the statistics.
    add.f32 %other, %var, %eps_reg;
    sqrt.rn.f32 %other, %other;
    div.rn.f32 %invstd, %one, %other;
    add.u64 %addr, %smp, %ch_off;
    st.global.f32 [%addr], %mean;
    add.u64 %addr, %sip, %ch_off;
    st.global.f32 [%addr], %invstd;
    st.shared.f32 [smem_sum], %mean;
    st.shared.f32 [smem_sq], %invstd;

WAIT_STATS:
    bar.sync 0;
    ld.shared.f32 %mean, [smem_sum];
    ld.shared.f32 %invstd, [smem_sq];

    // ---- Pass 2: output = (x - mean) * (invstd * gamma) + beta ----
    add.u64 %addr, %w, %ch_off;
    ld.global.f32 %gamma, [%addr];
    add.u64 %addr, %b, %ch_off;
    ld.global.f32 %beta, [%addr];
    mul.f32 %scale, %invstd, %gamma;

    mov.u32 %idx, %r_tid;
PASS2_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra END;
    div.u32 %bi, %idx, %sp;
    mul.lo.u32 %si, %bi, %sp;
    sub.u32 %si, %idx, %si;
    mad.lo.u32 %elem, %bi, %n_ch, %ch;
    mad.lo.u32 %elem, %elem, %sp, %si;
    cvt.u64.u32 %off64, %elem;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr, %in, %off64;
    ld.global.f32 %val, [%addr];
    sub.f32 %dev, %val, %mean;
    fma.rn.f32 %val, %dev, %scale, %beta;
    add.u64 %addr, %out, %off64;
    st.global.f32 [%addr], %val;
    add.u32 %idx, %idx, %bdim;
    bra PASS2_LOOP;

END:
    ret;
}
";
6940
6941
/// PTX kernel for MaxPool2d forward: sliding window max.
///
/// Each output element is handled by one thread. Threads walk the flat
/// output index space `[0, total)` with a grid-stride loop; `total` is
/// presumably `batch * channels * h_out * w_out`. Each index is decomposed
/// into `(b, c, oh, ow)`. The code then scans the `kh x kw` window whose
/// top-left corner is `(oh * sh - ph, ow * sw - pw)` in the contiguous
/// `[batch, channels, h_in, w_in]` input.
///
/// Padded positions are skipped, not read. `ih`/`iw` are computed with
/// unsigned arithmetic, so a position above/left of the input wraps to a huge
/// value and fails the `< h_in` / `< w_in` bounds check. The running max
/// starts at `-inf`, so a window lying entirely in padding produces `-inf`.
///
/// The `batch` parameter is loaded but not otherwise used; the loop bound is
/// `total`.
///
/// NOTE(review): `max.f32` without the `.NaN` modifier returns the non-NaN
/// operand when one input is NaN. NaN inputs are therefore dropped rather
/// than propagated as in PyTorch — confirm this matches the CPU fallback.
#[cfg(feature = "cuda")]
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry maxpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %max_val, %cur_val, %neg_inf;
    .reg .pred %p, %p_bounds, %p_gt;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

    // -inf for max initialization
    mov.f32 %neg_inf, 0fFF800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow)
    mov.u32 %rem, %idx;
    div.u32 %b_idx, %rem, %ch_reg;
    // Actually need: idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow
    // So decompose from the right:
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %max_val, %neg_inf;

    // Slide the kernel window
    mov.u32 %i, 0;
KH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra KH_DONE;

    mov.u32 %j, 0;
KW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra KW_DONE;

    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
    // Since unsigned, just check < h_in and < w_in
    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra KW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra KW_NEXT;

    // input_offset = b * C * H * W + c * H * W + ih * W + iw
    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    max.f32 %max_val, %max_val, %cur_val;

KW_NEXT:
    add.u32 %j, %j, 1;
    bra KW_LOOP;

KW_DONE:
    add.u32 %i, %i, 1;
    bra KH_LOOP;

KH_DONE:
    // Store output
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %max_val;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
7083
7084
/// PTX kernel for AvgPool2d forward: sliding window average.
///
/// Each output element is handled by one thread. Indexing, the grid-stride
/// loop over `[0, total)` and padding handling match `MAXPOOL2D_PTX`: padded
/// positions wrap in unsigned arithmetic and fail the bounds check. The
/// in-bounds window values are summed and divided by the number of in-bounds
/// positions. This is `count_include_pad = false` semantics: padded positions
/// add neither to the sum nor to the count.
///
/// The `batch` parameter is loaded but not otherwise used; the loop bound is
/// `total`.
///
/// NOTE(review): if a window lies entirely in padding, `count` is 0 and the
/// output is `0 / 0 = NaN`. Callers presumably keep `ph <= kh / 2` and
/// `pw <= kw / 2` so that cannot happen — confirm.
#[cfg(feature = "cuda")]
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry avgpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %sum_val, 0f00000000;
    mov.u32 %count, 0;

    mov.u32 %i, 0;
AKH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra AKH_DONE;

    mov.u32 %j, 0;
AKW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra AKW_DONE;

    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra AKW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra AKW_NEXT;

    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    add.f32 %sum_val, %sum_val, %cur_val;
    add.u32 %count, %count, 1;

AKW_NEXT:
    add.u32 %j, %j, 1;
    bra AKW_LOOP;

AKW_DONE:
    add.u32 %i, %i, 1;
    bra AKH_LOOP;

AKH_DONE:
    // avg = sum / count (count_include_pad = false behavior)
    cvt.rn.f32.u32 %count_f, %count;
    div.rn.f32 %avg, %sum_val, %count_f;

    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %avg;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
7220
7221
/// PTX source for `softmax_kernel`: row-wise softmax (f32).
///
/// Works on a contiguous `[rows, cols]` matrix with one block per row
/// (`grid.x >= rows`; surplus blocks exit immediately). The block's threads
/// stride over the row's columns in three passes:
///
/// 1. row max, reduced through the shared `sdata` array;
/// 2. `exp(x - max)`, computed as `ex2.approx.f32((x - max) * log2(e))`.
///    These values are written to `output` as scratch while the per-thread
///    sums are reduced;
/// 3. `output *= 1 / sum`, with the reciprocal from `rcp.approx.f32`.
///
/// Both reductions halve `blockDim.x` each step over the 256-entry `sdata`.
/// The block size must therefore be a power of two no larger than 256;
/// otherwise partial sums are lost or `sdata` is overrun.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f32 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    rcp.approx.f32 %sum_val, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    mul.f32 %result, %val, %sum_val;\n\
    st.global.f32 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7386
/// PTX source for `softmax_f64_kernel`: row-wise softmax (f64).
///
/// Works on a contiguous `[rows, cols]` matrix with one block per row
/// (`grid.x >= rows`; surplus blocks exit immediately). The block's threads
/// stride over the row's columns. The two tree reductions halve `blockDim.x`
/// each step over the 256-entry `sdata` scratch, so the block size must be a
/// power of two no larger than 256. Between the second and third passes the
/// output buffer holds the unnormalized `exp` values.
///
/// `exp(x - max)` is evaluated in software:
/// * the argument is split as `n * ln2 + r`, using a hi/lo pair for `ln2`;
/// * `exp(r)` is a degree-12 Taylor polynomial with coefficients
///   `1/12!` … `1/1!`;
/// * `2^n` is built directly in the exponent field.
///
/// Arguments below `-708`, including `-inf` from masked logits, yield exactly
/// `0`. Beyond that point `n + 1023` would leave the normal exponent range and
/// corrupt the scale factor. NaN arguments still propagate.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %result, %one;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .pred %e_under;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
    mov.f64 %one, 0d3FF0000000000000;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    setp.lt.f64 %e_under, %val, 0dC086200000000000;\n\
    @%e_under mov.f64 %val, 0d0000000000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    @%e_under mov.f64 %exp_val, 0d0000000000000000;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f64 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    div.rn.f64 %sum_val, %one, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    mul.f64 %result, %val, %sum_val;\n\
    st.global.f64 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7576
7577// ---------------------------------------------------------------------------
7578// Dropout PTX kernel (inverted dropout with xorshift RNG)
7579// ---------------------------------------------------------------------------
7580
/// PTX source for `dropout_kernel`: inverted dropout with a stateless hash RNG.
///
/// One thread per element with no grid-stride loop, so the launch must cover
/// at least `n` threads. Each element's random word depends only on its flat
/// index `i` and `seed`. It starts as `rng = (i * 2654435761) ^ seed`, with
/// 32-bit wrapping multiply, followed by one xorshift32 round (`<< 13`,
/// `>> 17`, `<< 5`). The element is zeroed when `rng < threshold`
/// (unsigned compare) and otherwise multiplied by `scale`. The mask is
/// therefore reproducible from the same `seed`.
///
/// Presumably the host passes `threshold ~= p * 2^32` and
/// `scale = 1 / (1 - p)` for drop probability `p` — confirm against the
/// launcher.
///
/// NOTE(review): xorshift maps 0 to 0. The index (if any) with
/// `i * 2654435761 == seed (mod 2^32)` therefore always gets `rng == 0` and is
/// dropped whenever `threshold > 0`.
#[cfg(feature = "cuda")]
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 n,\n\
    .param .u32 threshold,\n\
    .param .f32 scale,\n\
    .param .u32 seed\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
    .reg .u64 %in, %out, %off;\n\
    .reg .f32 %val, %scale_reg, %zero;\n\
    .reg .pred %p, %drop_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %n_reg, [n];\n\
    ld.param.u32 %thresh, [threshold];\n\
    ld.param.f32 %scale_reg, [scale];\n\
    ld.param.u32 %seed_reg, [seed];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %n_reg;\n\
    @%p bra DONE;\n\
\n\
    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
    xor.b32 %rng, %rng, %seed_reg;\n\
    shl.b32 %tmp, %rng, 13;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shr.b32 %tmp, %rng, 17;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shl.b32 %tmp, %rng, 5;\n\
    xor.b32 %rng, %rng, %tmp;\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %in, %in, %off;\n\
    add.u64 %out, %out, %off;\n\
    ld.global.f32 %val, [%in];\n\
\n\
    setp.lo.u32 %drop_p, %rng, %thresh;\n\
    mov.f32 %zero, 0f00000000;\n\
    @%drop_p mov.f32 %val, %zero;\n\
    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
    st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7641
7642
7643// ---------------------------------------------------------------------------
7644// General N-dimensional broadcast binary PTX kernels
7645// ---------------------------------------------------------------------------
7646//
7647// Each thread computes one output element. The kernel decomposes the flat
7648// output index into N-dimensional coordinates, maps each coordinate through
7649// broadcast strides for A and B, and loads from the correct flat position.
7650//
7651// Parameters:
7652//   a_ptr         - pointer to A's device buffer
7653//   b_ptr         - pointer to B's device buffer
7654//   out_ptr       - pointer to output device buffer
7655//   a_strides_ptr - pointer to u32[ndim] broadcast strides for A
7656//   b_strides_ptr - pointer to u32[ndim] broadcast strides for B
7657//   out_shape_ptr - pointer to u32[ndim] output shape
7658//   n             - total output elements
7659//   ndim          - number of dimensions
7660//
// Broadcast strides: for each output dimension d, an operand's stride is its
// own C-contiguous stride along d when its size there is > 1, or 0 when its
// size there is 1 (the value is broadcast along d).
7663
/// PTX for general broadcast add: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
///
/// One thread per output element with no grid-stride loop, so the launch
/// must cover at least `n` threads. The flat output index is decomposed from
/// the last dimension to the first using `out_shape`. Each coordinate is then
/// multiplied by the operand's broadcast stride (0 along broadcast
/// dimensions) and accumulated into `a_idx` / `b_idx`. `a_strides`,
/// `b_strides` and `out_shape` are `u32[ndim]` arrays read from device
/// memory. All index arithmetic is 32-bit.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    // Global thread index.
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Decompose flat index into N-d coordinates and compute A/B indices.
    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;

LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;

    sub.u32 %d, %d, 1;

    // Byte offset for dimension d: d * 4.
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;

    // Load out_shape[d].
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];

    // Load a_strides[d] and b_strides[d].
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];

    // coord = remaining % shape_d; remaining /= shape_d.
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;

    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;

    bra LOOP;
END_LOOP:

    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];

    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    // Operation: add.
    add.f32 %vr, %va, %vb;

    // Store to out[tid].
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;

DONE:
    ret;
}
";
7768
7769
7770/// PTX for general broadcast sub: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
///
/// Thread `i` (global index `ctaid.x * ntid.x + tid.x`; threads with `i >= n`
/// exit) decomposes `i` into coordinates over `out_shape`, walking dimensions
/// from `ndim - 1` down to `0`, and accumulates
/// `a_idx = sum(coord[d] * a_strides[d])` and `b_idx = sum(coord[d] * b_strides[d])`.
/// It then stores `a[a_idx] - b[b_idx]` to `out[i]`.
///
/// `a_strides_ptr`, `b_strides_ptr` and `out_shape_ptr` are read as device arrays
/// of `u32` (4 bytes per dimension); `a`, `b` and `out` hold `f32`. All index
/// arithmetic is 32-bit (`.u32`), so every computed index must fit in `u32`.
/// Presumably the host passes stride 0 for broadcast dimensions so that operand
/// does not advance along them — TODO confirm against the launcher.
7771#[cfg(feature = "cuda")]
7772pub(crate) const BROADCAST_SUB_PTX: &str = "\
7773.version 7.0
7774.target sm_52
7775.address_size 64
7776
7777.visible .entry broadcast_sub_kernel(
7778    .param .u64 a_ptr,
7779    .param .u64 b_ptr,
7780    .param .u64 out_ptr,
7781    .param .u64 a_strides_ptr,
7782    .param .u64 b_strides_ptr,
7783    .param .u64 out_shape_ptr,
7784    .param .u32 n,
7785    .param .u32 ndim
7786) {
7787    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
7788    .reg .u32 %remaining, %a_idx, %b_idx, %d;
7789    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
7790    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
7791    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
7792    .reg .f32 %va, %vb, %vr;
7793    .reg .pred %p, %loop_p;
7794
7795    ld.param.u64 %a, [a_ptr];
7796    ld.param.u64 %b, [b_ptr];
7797    ld.param.u64 %out, [out_ptr];
7798    ld.param.u64 %a_str, [a_strides_ptr];
7799    ld.param.u64 %b_str, [b_strides_ptr];
7800    ld.param.u64 %oshape, [out_shape_ptr];
7801    ld.param.u32 %n_reg, [n];
7802    ld.param.u32 %ndim_reg, [ndim];
7803
7804    mov.u32 %bid, %ctaid.x;
7805    mov.u32 %bdim, %ntid.x;
7806    mov.u32 %r_tid, %tid.x;
7807    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
7808    setp.ge.u32 %p, %r_tid, %n_reg;
7809    @%p bra DONE;
7810
7811    mov.u32 %remaining, %r_tid;
7812    mov.u32 %a_idx, 0;
7813    mov.u32 %b_idx, 0;
7814    mov.u32 %d, %ndim_reg;
7815LOOP:
7816    setp.eq.u32 %loop_p, %d, 0;
7817    @%loop_p bra END_LOOP;
7818    sub.u32 %d, %d, 1;
7819    cvt.u64.u32 %d64, %d;
7820    shl.b64 %d64, %d64, 2;
7821    add.u64 %tmp, %oshape, %d64;
7822    ld.global.u32 %shape_d, [%tmp];
7823    add.u64 %tmp, %a_str, %d64;
7824    ld.global.u32 %a_str_d, [%tmp];
7825    add.u64 %tmp, %b_str, %d64;
7826    ld.global.u32 %b_str_d, [%tmp];
7827    rem.u32 %coord, %remaining, %shape_d;
7828    div.u32 %remaining, %remaining, %shape_d;
7829    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
7830    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
7831    bra LOOP;
7832END_LOOP:
7833
7834    cvt.u64.u32 %off_a, %a_idx;
7835    shl.b64 %off_a, %off_a, 2;
7836    add.u64 %off_a, %a, %off_a;
7837    ld.global.f32 %va, [%off_a];
7838    cvt.u64.u32 %off_b, %b_idx;
7839    shl.b64 %off_b, %off_b, 2;
7840    add.u64 %off_b, %b, %off_b;
7841    ld.global.f32 %vb, [%off_b];
7842
7843    sub.f32 %vr, %va, %vb;
7844
7845    cvt.u64.u32 %off_out, %r_tid;
7846    shl.b64 %off_out, %off_out, 2;
7847    add.u64 %off_out, %out, %off_out;
7848    st.global.f32 [%off_out], %vr;
7849DONE:
7850    ret;
7851}
7852";
7853
7854
7855/// PTX for general broadcast mul: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
///
/// Same index computation as the other broadcast kernels. Thread `i` (threads
/// with `i >= n` exit) decomposes `i` over `out_shape`, from the last dimension
/// to the first. It accumulates `a_idx`/`b_idx` from the per-dimension
/// `a_strides`/`b_strides`, then stores `a[a_idx] * b[b_idx]` to `out[i]`.
///
/// The shape and stride arrays are device arrays of `u32`, and the data is
/// `f32`. All index arithmetic is 32-bit (`.u32`). Presumably broadcast
/// dimensions carry stride 0 — TODO confirm against the launcher.
7856#[cfg(feature = "cuda")]
7857pub(crate) const BROADCAST_MUL_PTX: &str = "\
7858.version 7.0
7859.target sm_52
7860.address_size 64
7861
7862.visible .entry broadcast_mul_kernel(
7863    .param .u64 a_ptr,
7864    .param .u64 b_ptr,
7865    .param .u64 out_ptr,
7866    .param .u64 a_strides_ptr,
7867    .param .u64 b_strides_ptr,
7868    .param .u64 out_shape_ptr,
7869    .param .u32 n,
7870    .param .u32 ndim
7871) {
7872    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
7873    .reg .u32 %remaining, %a_idx, %b_idx, %d;
7874    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
7875    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
7876    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
7877    .reg .f32 %va, %vb, %vr;
7878    .reg .pred %p, %loop_p;
7879
7880    ld.param.u64 %a, [a_ptr];
7881    ld.param.u64 %b, [b_ptr];
7882    ld.param.u64 %out, [out_ptr];
7883    ld.param.u64 %a_str, [a_strides_ptr];
7884    ld.param.u64 %b_str, [b_strides_ptr];
7885    ld.param.u64 %oshape, [out_shape_ptr];
7886    ld.param.u32 %n_reg, [n];
7887    ld.param.u32 %ndim_reg, [ndim];
7888
7889    mov.u32 %bid, %ctaid.x;
7890    mov.u32 %bdim, %ntid.x;
7891    mov.u32 %r_tid, %tid.x;
7892    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
7893    setp.ge.u32 %p, %r_tid, %n_reg;
7894    @%p bra DONE;
7895
7896    mov.u32 %remaining, %r_tid;
7897    mov.u32 %a_idx, 0;
7898    mov.u32 %b_idx, 0;
7899    mov.u32 %d, %ndim_reg;
7900LOOP:
7901    setp.eq.u32 %loop_p, %d, 0;
7902    @%loop_p bra END_LOOP;
7903    sub.u32 %d, %d, 1;
7904    cvt.u64.u32 %d64, %d;
7905    shl.b64 %d64, %d64, 2;
7906    add.u64 %tmp, %oshape, %d64;
7907    ld.global.u32 %shape_d, [%tmp];
7908    add.u64 %tmp, %a_str, %d64;
7909    ld.global.u32 %a_str_d, [%tmp];
7910    add.u64 %tmp, %b_str, %d64;
7911    ld.global.u32 %b_str_d, [%tmp];
7912    rem.u32 %coord, %remaining, %shape_d;
7913    div.u32 %remaining, %remaining, %shape_d;
7914    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
7915    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
7916    bra LOOP;
7917END_LOOP:
7918
7919    cvt.u64.u32 %off_a, %a_idx;
7920    shl.b64 %off_a, %off_a, 2;
7921    add.u64 %off_a, %a, %off_a;
7922    ld.global.f32 %va, [%off_a];
7923    cvt.u64.u32 %off_b, %b_idx;
7924    shl.b64 %off_b, %off_b, 2;
7925    add.u64 %off_b, %b, %off_b;
7926    ld.global.f32 %vb, [%off_b];
7927
7928    mul.f32 %vr, %va, %vb;
7929
7930    cvt.u64.u32 %off_out, %r_tid;
7931    shl.b64 %off_out, %off_out, 2;
7932    add.u64 %off_out, %out, %off_out;
7933    st.global.f32 [%off_out], %vr;
7934DONE:
7935    ret;
7936}
7937";
7938
7939
/// PTX source for `broadcast_div_kernel`: `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
///
/// This kernel has the same structure as `broadcast_mul_kernel`:
///
/// - Thread `i` decomposes its flat output index over `out_shape`, last
///   dimension first. Threads with `i >= n` exit.
/// - It accumulates per-operand element indices from the `u32` stride arrays.
/// - It stores the `f32` quotient to `out[i]`.
///
/// The division uses `div.rn.f32` (IEEE-754 round-to-nearest), matching
/// `div_kernel`. Division by zero therefore yields `±inf` or NaN.
///
/// A bare `div.f32` must not be used here. For PTX ISA 1.4 and later, `div`
/// on floating-point types requires one of `.approx`, `.full` or a rounding
/// modifier, and the module would otherwise fail to load.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    div.rn.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
8024
8025
8026/// PTX source for `strided_split_kernel`: extract a sub-tensor along a given axis.
8027///
8028/// Thread `i` computes:
8029///   `outer_idx = i / (split_size * inner_size)`
8030///   `within    = i % (split_size * inner_size)`
8031///   `src_idx   = outer_idx * total_along_axis * inner_size + (split_offset * inner_size) + within`
8032///   `out[i]    = in[src_idx]`
///
/// Threads with `i >= n` exit immediately. `n` is presumably the element count
/// of the extracted slice (`outer * split_size * inner_size`) — confirm against
/// the launcher. Elements are `f32`, and all index arithmetic is 32-bit
/// (`.u32`), so every computed source index must fit in `u32`.
8033#[cfg(feature = "cuda")]
8034pub(crate) const STRIDED_SPLIT_PTX: &str = "\
8035.version 7.0
8036.target sm_52
8037.address_size 64
8038
8039.visible .entry strided_split_kernel(
8040    .param .u64 input_ptr,
8041    .param .u64 output_ptr,
8042    .param .u32 total_along_axis,
8043    .param .u32 split_offset,
8044    .param .u32 split_size,
8045    .param .u32 inner_size,
8046    .param .u32 n
8047) {
8048    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8049    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
8050    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
8051    .reg .u64 %in, %out, %off;
8052    .reg .f32 %val;
8053    .reg .pred %p;
8054
8055    ld.param.u64 %in, [input_ptr];
8056    ld.param.u64 %out, [output_ptr];
8057    ld.param.u32 %total_ax, [total_along_axis];
8058    ld.param.u32 %sp_off, [split_offset];
8059    ld.param.u32 %sp_sz, [split_size];
8060    ld.param.u32 %inner_sz, [inner_size];
8061    ld.param.u32 %n_reg, [n];
8062
8063    mov.u32 %bid, %ctaid.x;
8064    mov.u32 %bdim, %ntid.x;
8065    mov.u32 %r_tid, %tid.x;
8066    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8067
8068    setp.ge.u32 %p, %r_tid, %n_reg;
8069    @%p bra DONE;
8070
8071    // chunk_stride = split_size * inner_size
8072    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;
8073
8074    // outer_idx = r_tid / chunk_stride
8075    div.u32 %outer_idx, %r_tid, %chunk_stride;
8076
8077    // within = r_tid % chunk_stride
8078    rem.u32 %within, %r_tid, %chunk_stride;
8079
8080    // base_off = split_offset * inner_size
8081    mul.lo.u32 %base_off, %sp_off, %inner_sz;
8082
8083    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
8084    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
8085    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
8086    add.u32 %src_idx, %src_idx, %base_off;
8087    add.u32 %src_idx, %src_idx, %within;
8088
8089    // Load from in[src_idx]
8090    cvt.u64.u32 %off, %src_idx;
8091    shl.b64 %off, %off, 2;
8092    add.u64 %off, %in, %off;
8093    ld.global.f32 %val, [%off];
8094
8095    // Store to out[r_tid]
8096    cvt.u64.u32 %off, %r_tid;
8097    shl.b64 %off, %off, 2;
8098    add.u64 %off, %out, %off;
8099    st.global.f32 [%off], %val;
8100
8101DONE:
8102    ret;
8103}
8104";
8105
8106
8107/// PTX source for `strided_cat_kernel`: write a sub-tensor into a larger tensor
8108/// at an offset along an axis.
8109///
8110/// Thread `i` computes:
8111///   `outer_idx = i / (part_size * inner_size)`
8112///   `within    = i % (part_size * inner_size)`
8113///   `dst_idx   = outer_idx * total_along_axis * inner_size + (cat_offset * inner_size) + within`
8114///   `out[dst_idx] = in[i]`
///
/// Each thread `i < n` reads `in[i]` and writes exactly one destination element.
/// Only the `[cat_offset, cat_offset + part_size)` band along the axis of each
/// outer slice is written; the rest of `out` is left untouched. Presumably the
/// caller launches this once per input part to assemble the full concatenation
/// — confirm against the launcher. Elements are `f32`, and all index
/// arithmetic is 32-bit (`.u32`).
8115#[cfg(feature = "cuda")]
8116pub(crate) const STRIDED_CAT_PTX: &str = "\
8117.version 7.0
8118.target sm_52
8119.address_size 64
8120
8121.visible .entry strided_cat_kernel(
8122    .param .u64 input_ptr,
8123    .param .u64 output_ptr,
8124    .param .u32 total_along_axis,
8125    .param .u32 cat_offset,
8126    .param .u32 part_size,
8127    .param .u32 inner_size,
8128    .param .u32 n
8129) {
8130    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8131    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
8132    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
8133    .reg .u64 %in, %out, %off;
8134    .reg .f32 %val;
8135    .reg .pred %p;
8136
8137    ld.param.u64 %in, [input_ptr];
8138    ld.param.u64 %out, [output_ptr];
8139    ld.param.u32 %total_ax, [total_along_axis];
8140    ld.param.u32 %cat_off, [cat_offset];
8141    ld.param.u32 %part_sz, [part_size];
8142    ld.param.u32 %inner_sz, [inner_size];
8143    ld.param.u32 %n_reg, [n];
8144
8145    mov.u32 %bid, %ctaid.x;
8146    mov.u32 %bdim, %ntid.x;
8147    mov.u32 %r_tid, %tid.x;
8148    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8149
8150    setp.ge.u32 %p, %r_tid, %n_reg;
8151    @%p bra DONE;
8152
8153    // chunk_stride = part_size * inner_size
8154    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;
8155
8156    // outer_idx = r_tid / chunk_stride
8157    div.u32 %outer_idx, %r_tid, %chunk_stride;
8158
8159    // within = r_tid % chunk_stride
8160    rem.u32 %within, %r_tid, %chunk_stride;
8161
8162    // base_off = cat_offset * inner_size
8163    mul.lo.u32 %base_off, %cat_off, %inner_sz;
8164
8165    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
8166    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
8167    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
8168    add.u32 %dst_idx, %dst_idx, %base_off;
8169    add.u32 %dst_idx, %dst_idx, %within;
8170
8171    // Load from in[r_tid]
8172    cvt.u64.u32 %off, %r_tid;
8173    shl.b64 %off, %off, 2;
8174    add.u64 %off, %in, %off;
8175    ld.global.f32 %val, [%off];
8176
8177    // Store to out[dst_idx]
8178    cvt.u64.u32 %off, %dst_idx;
8179    shl.b64 %off, %off, 2;
8180    add.u64 %off, %out, %off;
8181    st.global.f32 [%off], %val;
8182
8183DONE:
8184    ret;
8185}
8186";
8187
8188
8189/// PTX source for `strided_copy_kernel`: general strided→contiguous
8190/// gather with up to 8 dimensions. CL-496.
8191///
8192/// Thread `i` computes:
8193///   flat = i
8194///   src = src_offset_base
8195///   for d in 0..8:
8196///       coord = flat / out_stride[d]
8197///       flat  = flat % out_stride[d]
8198///       src  += coord * src_stride[d]
8199///   out[i] = in[src]
8200///
8201/// For tensors with fewer than 8 dims, unused positions must be
8202/// padded with `out_stride[d] = n + 1` (so `flat / out_stride[d] = 0`)
8203/// and `src_stride[d] = 0` (so the contribution is zero).
8204///
8205/// Each stride is passed as an individual u32 kernel parameter to
8206/// avoid needing a device-side stride array. 20 params total is well
8207/// within the ~4KB param limit.
///
/// Each thread `i < n` loads one `f32` and stores one `f32`. The remainder is
/// computed as `flat - coord * out_stride[d]` rather than with `rem`. All index
/// arithmetic is 32-bit (`.u32`), so source offsets must fit in `u32`.
/// The `n + 1` padding convention assumes `n < u32::MAX`; otherwise it wraps
/// to `0` and the `div.u32` would divide by zero.
8208#[cfg(feature = "cuda")]
8209pub(crate) const STRIDED_COPY_PTX: &str = "\
8210.version 7.0
8211.target sm_52
8212.address_size 64
8213
8214.visible .entry strided_copy_kernel(
8215    .param .u64 input_ptr,
8216    .param .u64 output_ptr,
8217    .param .u32 src_offset_base,
8218    .param .u32 n,
8219    .param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
8220    .param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
8221    .param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
8222    .param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
8223) {
8224    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8225    .reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
8226    .reg .u64 %in, %out, %off;
8227    .reg .f32 %val;
8228    .reg .pred %p;
8229
8230    ld.param.u64 %in, [input_ptr];
8231    ld.param.u64 %out, [output_ptr];
8232    ld.param.u32 %src_idx, [src_offset_base];
8233    ld.param.u32 %n_reg, [n];
8234
8235    mov.u32 %bid, %ctaid.x;
8236    mov.u32 %bdim, %ntid.x;
8237    mov.u32 %r_tid, %tid.x;
8238    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8239
8240    setp.ge.u32 %p, %r_tid, %n_reg;
8241    @%p bra DONE;
8242
8243    mov.u32 %flat, %r_tid;
8244
8245    // Dim 0
8246    ld.param.u32 %os, [os0];
8247    ld.param.u32 %ss, [ss0];
8248    div.u32 %coord, %flat, %os;
8249    mul.lo.u32 %tmp, %coord, %os;
8250    sub.u32 %flat, %flat, %tmp;
8251    mul.lo.u32 %tmp, %coord, %ss;
8252    add.u32 %src_idx, %src_idx, %tmp;
8253
8254    // Dim 1
8255    ld.param.u32 %os, [os1];
8256    ld.param.u32 %ss, [ss1];
8257    div.u32 %coord, %flat, %os;
8258    mul.lo.u32 %tmp, %coord, %os;
8259    sub.u32 %flat, %flat, %tmp;
8260    mul.lo.u32 %tmp, %coord, %ss;
8261    add.u32 %src_idx, %src_idx, %tmp;
8262
8263    // Dim 2
8264    ld.param.u32 %os, [os2];
8265    ld.param.u32 %ss, [ss2];
8266    div.u32 %coord, %flat, %os;
8267    mul.lo.u32 %tmp, %coord, %os;
8268    sub.u32 %flat, %flat, %tmp;
8269    mul.lo.u32 %tmp, %coord, %ss;
8270    add.u32 %src_idx, %src_idx, %tmp;
8271
8272    // Dim 3
8273    ld.param.u32 %os, [os3];
8274    ld.param.u32 %ss, [ss3];
8275    div.u32 %coord, %flat, %os;
8276    mul.lo.u32 %tmp, %coord, %os;
8277    sub.u32 %flat, %flat, %tmp;
8278    mul.lo.u32 %tmp, %coord, %ss;
8279    add.u32 %src_idx, %src_idx, %tmp;
8280
8281    // Dim 4
8282    ld.param.u32 %os, [os4];
8283    ld.param.u32 %ss, [ss4];
8284    div.u32 %coord, %flat, %os;
8285    mul.lo.u32 %tmp, %coord, %os;
8286    sub.u32 %flat, %flat, %tmp;
8287    mul.lo.u32 %tmp, %coord, %ss;
8288    add.u32 %src_idx, %src_idx, %tmp;
8289
8290    // Dim 5
8291    ld.param.u32 %os, [os5];
8292    ld.param.u32 %ss, [ss5];
8293    div.u32 %coord, %flat, %os;
8294    mul.lo.u32 %tmp, %coord, %os;
8295    sub.u32 %flat, %flat, %tmp;
8296    mul.lo.u32 %tmp, %coord, %ss;
8297    add.u32 %src_idx, %src_idx, %tmp;
8298
8299    // Dim 6
8300    ld.param.u32 %os, [os6];
8301    ld.param.u32 %ss, [ss6];
8302    div.u32 %coord, %flat, %os;
8303    mul.lo.u32 %tmp, %coord, %os;
8304    sub.u32 %flat, %flat, %tmp;
8305    mul.lo.u32 %tmp, %coord, %ss;
8306    add.u32 %src_idx, %src_idx, %tmp;
8307
8308    // Dim 7
8309    ld.param.u32 %os, [os7];
8310    ld.param.u32 %ss, [ss7];
8311    div.u32 %coord, %flat, %os;
8312    mul.lo.u32 %tmp, %coord, %os;
8313    sub.u32 %flat, %flat, %tmp;
8314    mul.lo.u32 %tmp, %coord, %ss;
8315    add.u32 %src_idx, %src_idx, %tmp;
8316
8317    // Load from in[src_idx]
8318    cvt.u64.u32 %off, %src_idx;
8319    shl.b64 %off, %off, 2;
8320    add.u64 %off, %in, %off;
8321    ld.global.f32 %val, [%off];
8322
8323    // Store to out[r_tid]
8324    cvt.u64.u32 %off, %r_tid;
8325    shl.b64 %off, %off, 2;
8326    add.u64 %off, %out, %off;
8327    st.global.f32 [%off], %val;
8328
8329DONE:
8330    ret;
8331}
8332";
8333
8334
8335/// PTX source for `div_kernel`: `out[i] = a[i] / b[i]`.
///
/// This is element-wise `f32` division using `div.rn.f32` (IEEE-754
/// round-to-nearest), so division by zero follows IEEE rules (`±inf` or NaN).
/// Threads with `i >= n` exit without touching memory. Each pointer is offset
/// by `i * 4` bytes.
8336#[cfg(feature = "cuda")]
8337pub(crate) const DIV_PTX: &str = "\
8338.version 7.0
8339.target sm_52
8340.address_size 64
8341
8342.visible .entry div_kernel(
8343    .param .u64 a_ptr,
8344    .param .u64 b_ptr,
8345    .param .u64 out_ptr,
8346    .param .u32 n
8347) {
8348    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8349    .reg .u64 %a, %b, %out, %off;
8350    .reg .f32 %va, %vb, %vr;
8351    .reg .pred %p;
8352
8353    ld.param.u64 %a, [a_ptr];
8354    ld.param.u64 %b, [b_ptr];
8355    ld.param.u64 %out, [out_ptr];
8356    ld.param.u32 %n_reg, [n];
8357
8358    mov.u32 %bid, %ctaid.x;
8359    mov.u32 %bdim, %ntid.x;
8360    mov.u32 %r_tid, %tid.x;
8361    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8362
8363    setp.ge.u32 %p, %r_tid, %n_reg;
8364    @%p bra DONE;
8365
8366    cvt.u64.u32 %off, %r_tid;
8367    shl.b64 %off, %off, 2;
8368
8369    add.u64 %a, %a, %off;
8370    add.u64 %b, %b, %off;
8371    add.u64 %out, %out, %off;
8372
8373    ld.global.f32 %va, [%a];
8374    ld.global.f32 %vb, [%b];
8375    div.rn.f32 %vr, %va, %vb;
8376    st.global.f32 [%out], %vr;
8377
8378DONE:
8379    ret;
8380}
8381";
8382
8383
8384/// PTX source for `exp_kernel`: `out[i] = exp(a[i])`.
///
/// Computes `exp(x)` as `2^(x * log2(e))` with the hardware `ex2.approx.f32`
/// instruction, so the result is an approximation. The rounding of the
/// `x * log2(e)` product is amplified for large `|x|`. `log2(e)` is encoded as
/// the f32 literal `0f3FB8AA3B`.
///
/// Because this kernel uses `ex2.approx`, it cannot be converted with
/// `ptx_f32_to_f64`; `EXP_F64_PTX` is the separate f64 kernel.
8385#[cfg(feature = "cuda")]
8386pub(crate) const EXP_PTX: &str = "\
8387.version 7.0
8388.target sm_52
8389.address_size 64
8390
8391.visible .entry exp_kernel(
8392    .param .u64 a_ptr,
8393    .param .u64 out_ptr,
8394    .param .u32 n
8395) {
8396    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8397    .reg .u64 %a, %out, %off;
8398    .reg .f32 %va, %vr;
8399    .reg .pred %p;
8400
8401    ld.param.u64 %a, [a_ptr];
8402    ld.param.u64 %out, [out_ptr];
8403    ld.param.u32 %n_reg, [n];
8404
8405    mov.u32 %bid, %ctaid.x;
8406    mov.u32 %bdim, %ntid.x;
8407    mov.u32 %r_tid, %tid.x;
8408    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8409
8410    setp.ge.u32 %p, %r_tid, %n_reg;
8411    @%p bra DONE;
8412
8413    cvt.u64.u32 %off, %r_tid;
8414    shl.b64 %off, %off, 2;
8415
8416    add.u64 %a, %a, %off;
8417    add.u64 %out, %out, %off;
8418
8419    ld.global.f32 %va, [%a];
8420    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
8421    // log2(e) = 1.4426950408889634
8422    mul.f32 %va, %va, 0f3FB8AA3B;
8423    ex2.approx.f32 %vr, %va;
8424    st.global.f32 [%out], %vr;
8425
8426DONE:
8427    ret;
8428}
8429";
8430
/// PTX source for `exp_f64_kernel`: `out[i] = exp(a[i])` in full f64 precision.
///
/// Algorithm (Cody-Waite range reduction + Taylor polynomial):
///
/// 1. Compute `n = round(x * log2(e))` and `r = x - n*ln2_hi - n*ln2_lo`.
///    This gives `|r| <= ~ln(2)/2`. `ln2_hi` has trailing zero bits, so
///    `n*ln2_hi` is exact over the reduced range.
/// 2. Evaluate `exp(r) = 1 + r + r^2/2! + ... + r^13/13!` with Horner's rule
///    using `fma`. On `|r| <= ln(2)/2` the truncation error is below `1e-17`,
///    so the result is dominated by rounding. Expected within a couple of ULP;
///    not exhaustively verified.
/// 3. Compute `exp(x) = exp(r) * 2^n`. The scale is applied as two factors,
///    `2^(n >> 1) * 2^(n - (n >> 1))`, so each factor is a normal f64 for every
///    in-range `n`. Overflow to `+inf` and gradual underflow to `0` then happen
///    in the final multiply.
///
/// Special cases:
///
/// - `x > 710`, including `+inf`, yields `+inf`.
/// - `x < -746`, including `-inf`, yields `0`.
/// - NaN propagates.
///
/// This kernel is written by hand because `ptx_f32_to_f64` cannot convert the
/// f32 `exp_kernel`, which uses `ex2.approx.f32`.
#[cfg(feature = "cuda")]
pub(crate) const EXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %vr, %nf, %r, %p, %s1, %s2;
    .reg .s32 %ni, %n1, %n2;
    .reg .s64 %bits;
    .reg .pred %p_tid, %p_big, %p_small;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];

    // Out-of-range inputs are fixed up at the end (NaN compares false here).
    setp.gt.f64 %p_big, %x, 0d4086300000000000;    // x > 710  -> +inf
    setp.lt.f64 %p_small, %x, 0dC087500000000000;  // x < -746 -> 0

    // n = round(x * log2(e))
    mul.rn.f64 %nf, %x, 0d3FF71547652B82FE;        // log2(e)
    cvt.rni.f64.f64 %nf, %nf;
    cvt.rni.s32.f64 %ni, %nf;

    // r = x - n*ln2_hi - n*ln2_lo (Cody-Waite two-step reduction)
    fma.rn.f64 %r, %nf, 0dBFE62E42FEFA3800, %x;    // -ln2_hi
    fma.rn.f64 %r, %nf, 0dBD2EF35793C76730, %r;    // -ln2_lo

    // Horner: p = 1/2! + r*(1/3! + r*(... + r*(1/13!)))
    mov.f64 %p, 0d3DE6124613A86D09;                // 1/13!
    fma.rn.f64 %p, %p, %r, 0d3E21EED8EFF8D898;     // 1/12!
    fma.rn.f64 %p, %p, %r, 0d3E5AE64567F544E4;     // 1/11!
    fma.rn.f64 %p, %p, %r, 0d3E927E4FB7789F5C;     // 1/10!
    fma.rn.f64 %p, %p, %r, 0d3EC71DE3A556C734;     // 1/9!
    fma.rn.f64 %p, %p, %r, 0d3EFA01A01A01A01A;     // 1/8!
    fma.rn.f64 %p, %p, %r, 0d3F2A01A01A01A01A;     // 1/7!
    fma.rn.f64 %p, %p, %r, 0d3F56C16C16C16C17;     // 1/6!
    fma.rn.f64 %p, %p, %r, 0d3F81111111111111;     // 1/5!
    fma.rn.f64 %p, %p, %r, 0d3FA5555555555555;     // 1/4!
    fma.rn.f64 %p, %p, %r, 0d3FC5555555555555;     // 1/3!
    fma.rn.f64 %p, %p, %r, 0d3FE0000000000000;     // 1/2!

    // exp(r) = 1 + r*(1 + r*p)
    fma.rn.f64 %p, %p, %r, 0d3FF0000000000000;
    fma.rn.f64 %vr, %p, %r, 0d3FF0000000000000;

    // Scale by 2^n = 2^n1 * 2^n2 with n1 = n >> 1, n2 = n - n1, so that both
    // biased exponents (n1 + 1023, n2 + 1023) stay inside the normal range.
    shr.s32 %n1, %ni, 1;
    sub.s32 %n2, %ni, %n1;
    cvt.s64.s32 %bits, %n1;
    add.s64 %bits, %bits, 1023;
    shl.b64 %bits, %bits, 52;
    mov.b64 %s1, %bits;
    cvt.s64.s32 %bits, %n2;
    add.s64 %bits, %bits, 1023;
    shl.b64 %bits, %bits, 52;
    mov.b64 %s2, %bits;
    mul.rn.f64 %vr, %vr, %s1;                      // exact: result stays normal
    mul.rn.f64 %vr, %vr, %s2;                      // single rounding (may over/underflow)

    @%p_big mov.f64 %vr, 0d7FF0000000000000;       // +inf
    @%p_small mov.f64 %vr, 0d0000000000000000;     // +0

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8529
8530/// PTX source for `log_kernel`: `out[i] = ln(a[i])`.
///
/// Computes `ln(x)` as `log2(x) * ln(2)` using the hardware `lg2.approx.f32`
/// instruction, so the result is approximate. `ln(2)` is encoded as the f32
/// literal `0f3F317218`.
///
/// Because this kernel uses `lg2.approx`, it cannot be converted with
/// `ptx_f32_to_f64`; `LOG_F64_PTX` is the separate f64 kernel.
8531#[cfg(feature = "cuda")]
8532pub(crate) const LOG_PTX: &str = "\
8533.version 7.0
8534.target sm_52
8535.address_size 64
8536
8537.visible .entry log_kernel(
8538    .param .u64 a_ptr,
8539    .param .u64 out_ptr,
8540    .param .u32 n
8541) {
8542    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8543    .reg .u64 %a, %out, %off;
8544    .reg .f32 %va, %vr;
8545    .reg .pred %p;
8546
8547    ld.param.u64 %a, [a_ptr];
8548    ld.param.u64 %out, [out_ptr];
8549    ld.param.u32 %n_reg, [n];
8550
8551    mov.u32 %bid, %ctaid.x;
8552    mov.u32 %bdim, %ntid.x;
8553    mov.u32 %r_tid, %tid.x;
8554    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8555
8556    setp.ge.u32 %p, %r_tid, %n_reg;
8557    @%p bra DONE;
8558
8559    cvt.u64.u32 %off, %r_tid;
8560    shl.b64 %off, %off, 2;
8561
8562    add.u64 %a, %a, %off;
8563    add.u64 %out, %out, %off;
8564
8565    ld.global.f32 %va, [%a];
8566    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
8567    // 1/log2(e) = ln(2) = 0.6931471805599453
8568    lg2.approx.f32 %vr, %va;
8569    mul.f32 %vr, %vr, 0f3F317218;
8570    st.global.f32 [%out], %vr;
8571
8572DONE:
8573    ret;
8574}
8575";
8576
/// PTX source for `log_f64_kernel`: `out[i] = ln(a[i])` in full f64 precision.
///
/// Algorithm:
///
/// 1. Subnormal inputs are first scaled by `2^54`, and the scaling is subtracted
///    from the exponent.
/// 2. Decompose `x = 2^k * m` and re-center `m` into `[sqrt(2)/2, sqrt(2)]`.
///    If the raw mantissa exceeds `sqrt(2)`, halve it and increment `k`.
/// 3. With `f = (m - 1) / (m + 1)` (so `|f| <= 0.1716`),
///    `ln(m) = 2f + 2f * (f^2/3 + f^4/5 + ... + f^18/19)`. The first omitted
///    term, `2 f^21 / 21`, is below `1e-17` on this interval.
/// 4. Compute `ln(x) = k*ln2_hi + (k*ln2_lo + ln(m))`. `ln2_hi` has 21 trailing
///    zero bits, so `k*ln2_hi` is exact.
///
/// The result is expected within a couple of ULP; this is not exhaustively
/// verified.
///
/// Special cases follow IEEE-754 `log`:
///
/// - `ln(±0) = -inf`
/// - `ln(x < 0) = NaN`
/// - `ln(+inf) = +inf`
/// - `ln(NaN) = NaN`
///
/// This kernel is written by hand because `ptx_f32_to_f64` cannot convert the
/// f32 `log_kernel`, which uses `lg2.approx.f32`.
#[cfg(feature = "cuda")]
pub(crate) const LOG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .u64 %xbits, %ebits, %mbits;
    .reg .s64 %k, %adj;
    .reg .f64 %x, %xs, %vr, %m, %f, %f2, %s, %q, %tf, %nf;
    .reg .pred %p_tid, %p_sub, %p_big, %p_zero, %p_inf, %p_nan;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];

    // Scale subnormals into the normal range. Also taken for x <= 0, whose
    // results are overwritten by the special-case handling below.
    mov.f64 %xs, %x;
    mov.s64 %adj, 0;
    setp.lt.f64 %p_sub, %x, 0d0010000000000000;     // x < 2^-1022
    @%p_sub mul.rn.f64 %xs, %x, 0d4350000000000000; // x * 2^54
    @%p_sub mov.s64 %adj, 54;

    // k = biased_exponent - 1023 - adj
    mov.b64 %xbits, %xs;
    shr.u64 %ebits, %xbits, 52;
    and.b64 %ebits, %ebits, 2047;
    mov.b64 %k, %ebits;
    sub.s64 %k, %k, 1023;
    sub.s64 %k, %k, %adj;

    // m = mantissa with the exponent field forced to 1023, i.e. m in [1, 2)
    and.b64 %mbits, %xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %mbits, %mbits, 0x3FF0000000000000;
    mov.b64 %m, %mbits;

    // Re-center m into [sqrt(2)/2, sqrt(2)] to keep |f| small.
    setp.gt.f64 %p_big, %m, 0d3FF6A09E667F3BCD;     // sqrt(2)
    @%p_big mul.rn.f64 %m, %m, 0d3FE0000000000000;  // m / 2 (exact)
    @%p_big add.s64 %k, %k, 1;
    cvt.rn.f64.s64 %nf, %k;

    // f = (m - 1) / (m + 1)
    sub.rn.f64 %f, %m, 0d3FF0000000000000;
    add.rn.f64 %s, %m, 0d3FF0000000000000;
    div.rn.f64 %f, %f, %s;
    mul.rn.f64 %f2, %f, %f;

    // q = f^2 * (1/3 + f^2*(1/5 + ... + f^2*(1/19)))
    mov.f64 %q, 0d3FAAF286BCA1AF28;                 // 1/19
    fma.rn.f64 %q, %q, %f2, 0d3FAE1E1E1E1E1E1E;     // 1/17
    fma.rn.f64 %q, %q, %f2, 0d3FB1111111111111;     // 1/15
    fma.rn.f64 %q, %q, %f2, 0d3FB3B13B13B13B14;     // 1/13
    fma.rn.f64 %q, %q, %f2, 0d3FB745D1745D1746;     // 1/11
    fma.rn.f64 %q, %q, %f2, 0d3FBC71C71C71C71C;     // 1/9
    fma.rn.f64 %q, %q, %f2, 0d3FC2492492492492;     // 1/7
    fma.rn.f64 %q, %q, %f2, 0d3FC999999999999A;     // 1/5
    fma.rn.f64 %q, %q, %f2, 0d3FD5555555555555;     // 1/3
    mul.rn.f64 %q, %q, %f2;

    // ln(m) = 2f + 2f*q  (2f is exact; the small correction is added with one rounding)
    add.rn.f64 %tf, %f, %f;
    fma.rn.f64 %vr, %tf, %q, %tf;

    // ln(x) = k*ln2_hi + (k*ln2_lo + ln(m))
    fma.rn.f64 %vr, %nf, 0d3DEA39EF35793C76, %vr;   // ln2_lo
    fma.rn.f64 %vr, %nf, 0d3FE62E42FEE00000, %vr;   // ln2_hi

    // Special cases (the predicates are mutually exclusive).
    setp.eq.f64 %p_zero, %x, 0d0000000000000000;    // +0 or -0
    @%p_zero mov.f64 %vr, 0dFFF0000000000000;       // -inf
    setp.eq.f64 %p_inf, %x, 0d7FF0000000000000;
    @%p_inf mov.f64 %vr, 0d7FF0000000000000;        // +inf
    setp.ltu.f64 %p_nan, %x, 0d0000000000000000;    // x < 0 or NaN
    @%p_nan mov.f64 %vr, 0d7FF8000000000000;        // NaN

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8680
8681/// PTX source for `sqrt_kernel`: `out[i] = sqrt(a[i])`.
///
/// Uses `sqrt.rn.f32`, the IEEE-754 correctly rounded square root, so negative
/// inputs yield NaN. Threads with `i >= n` exit without touching memory. Each
/// pointer is offset by `i * 4` bytes.
8682#[cfg(feature = "cuda")]
8683pub(crate) const SQRT_PTX: &str = "\
8684.version 7.0
8685.target sm_52
8686.address_size 64
8687
8688.visible .entry sqrt_kernel(
8689    .param .u64 a_ptr,
8690    .param .u64 out_ptr,
8691    .param .u32 n
8692) {
8693    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8694    .reg .u64 %a, %out, %off;
8695    .reg .f32 %va, %vr;
8696    .reg .pred %p;
8697
8698    ld.param.u64 %a, [a_ptr];
8699    ld.param.u64 %out, [out_ptr];
8700    ld.param.u32 %n_reg, [n];
8701
8702    mov.u32 %bid, %ctaid.x;
8703    mov.u32 %bdim, %ntid.x;
8704    mov.u32 %r_tid, %tid.x;
8705    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8706
8707    setp.ge.u32 %p, %r_tid, %n_reg;
8708    @%p bra DONE;
8709
8710    cvt.u64.u32 %off, %r_tid;
8711    shl.b64 %off, %off, 2;
8712
8713    add.u64 %a, %a, %off;
8714    add.u64 %out, %out, %off;
8715
8716    ld.global.f32 %va, [%a];
8717    sqrt.rn.f32 %vr, %va;
8718    st.global.f32 [%out], %vr;
8719
8720DONE:
8721    ret;
8722}
8723";
8724
8725
/// PTX source for `pow_kernel`: `out[i] = a[i] ^ exponent` (f32).
///
/// The magnitude is computed as `|x|^e = 2^(e * log2(|x|))` with the hardware
/// `lg2.approx` and `ex2.approx` instructions. The kernel then fixes up the
/// sign and special cases following the C/IEEE `pow` conventions:
///
/// - negative base, odd integer exponent: negative result (`(-2)^3 = -8`);
/// - negative base, even integer exponent: positive result (`(-2)^2 = 4`);
/// - negative base, non-integer exponent: NaN;
/// - `exponent == 0`: `1` for every base, including `0` and NaN.
///
/// The previous form, `2^(e * log2(x))`, returned NaN for every negative base.
/// That included `x^2` of negative values, which is a common case.
///
/// Accuracy is limited by the approximate transcendental instructions. The
/// relative error grows with `|e * log2(|x|)|`.
#[cfg(feature = "cuda")]
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %exp, %lg, %ax, %e_int, %e_half, %e_half_int;
    .reg .pred %p, %p_int, %p_odd, %p_neg, %p_flip, %p_nan, %p_e0;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %exp, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];

    // |x|^e = 2^(e * log2(|x|))
    abs.f32 %ax, %va;
    lg2.approx.f32 %lg, %ax;
    mul.f32 %lg, %lg, %exp;
    ex2.approx.f32 %vr, %lg;

    // Classify the exponent: is it an integer, and if so, is it odd?
    cvt.rzi.f32.f32 %e_int, %exp;
    setp.eq.f32 %p_int, %e_int, %exp;            // false for NaN
    mul.f32 %e_half, %exp, 0f3F000000;           // e / 2
    cvt.rzi.f32.f32 %e_half_int, %e_half;
    setp.neu.f32 %p_odd, %e_half_int, %e_half;   // e / 2 is not integral
    and.pred %p_odd, %p_odd, %p_int;             // => e is an odd integer

    setp.lt.f32 %p_neg, %va, 0f00000000;

    // Negative base with an odd integer exponent: the result is negative.
    and.pred %p_flip, %p_neg, %p_odd;
    @%p_flip neg.f32 %vr, %vr;

    // Negative base with a non-integer exponent: NaN.
    not.pred %p_nan, %p_int;
    and.pred %p_nan, %p_nan, %p_neg;
    @%p_nan mov.f32 %vr, 0f7FFFFFFF;

    // x^0 == 1 for every x.
    setp.eq.f32 %p_e0, %exp, 0f00000000;
    @%p_e0 mov.f32 %vr, 0f3F800000;

    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8775
/// PTX source for `pow_f64_kernel`: `out[i] = a[i] ^ exponent` (f64).
///
/// Computed as `|x|^e = exp(e * ln|x|)`:
/// - `ln` reduces `|x| = 2^k * m` with `m` in `[sqrt(2)/2, sqrt(2)]`
///   (subnormal inputs are pre-scaled by 2^54). It then evaluates
///   `ln(m) = 2*atanh(f)` with `f = (m-1)/(m+1)`, using the odd series
///   through `f^21/21`.
/// - `exp` uses Cody-Waite range reduction and a degree-12 Taylor polynomial.
///   It scales by `2 * 2^(n-1)` so the whole range up to `ln(DBL_MAX)` is
///   representable. Arguments below -708 are flushed to 0, so there are no
///   subnormal outputs; arguments above `ln(DBL_MAX)` give `+inf`.
///
/// Accuracy is limited by the rounding of `e * ln|x|`. The relative error
/// therefore grows roughly like `|e * ln|x|| * 2^-53`.
///
/// Special cases are intended to match `f64::powf` (C `pow`), except for the
/// subnormal flush:
/// - `x^0 == 1`, `1^e == 1` and `(-1)^(±inf) == 1`;
/// - zero and infinite bases give `0`/`inf`;
/// - a base with the sign bit set and an odd integer exponent gives a
///   negative result;
/// - a finite negative base with a non-integer exponent gives NaN.
#[cfg(feature = "cuda")]
pub(crate) const POW_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f64 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %exp64, %one, %ax, %t, %t_int;
    // log registers
    .reg .u64 %l_xbits, %l_mbits;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_lnx;
    // exp registers
    .reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half, %e_s;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %p_sign, %p_sub, %p_big, %p_zero, %p_inf, %p_nan;
    .reg .pred %p_of, %p_uf, %p_neg, %p_frac, %p_odd, %p_unit;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f64 %exp64, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;

    // Sign bit of the base (also set for -0.0 and -inf), used for odd exponents
    mov.b64 %l_xbits, %va;
    setp.lt.s64 %p_sign, %l_xbits, 0;
    abs.f64 %ax, %va;
    setp.eq.f64 %p_zero, %ax, 0d0000000000000000;
    setp.eq.f64 %p_inf, %ax, 0d7FF0000000000000;
    testp.notanumber.f64 %p_nan, %ax;

    // === ln|x| via argument reduction ===
    // Subnormals are scaled by 2^54 first so the exponent field is exact
    mov.f64 %l_nf, 0d0000000000000000;
    setp.lt.f64 %p_sub, %ax, 0d0010000000000000;
    @%p_sub mul.f64 %ax, %ax, 0d4350000000000000;
    @%p_sub mov.f64 %l_nf, 0dC04B000000000000;

    // |x| = 2^k * m, m in [1,2)
    mov.b64 %l_xbits, %ax;
    shr.b64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_s, %l_exp64;
    add.f64 %l_nf, %l_nf, %l_s;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, 0x3FF0000000000000;
    mov.b64 %l_m, %l_mbits;

    // Move m into [sqrt(2)/2, sqrt(2)] so that |f| <= 0.1716 below
    setp.gt.f64 %p_big, %l_m, 0d3FF6A09E667F3BCD;
    @%p_big mul.f64 %l_m, %l_m, %e_half;
    @%p_big add.f64 %l_nf, %l_nf, %one;

    // ln(m) = 2*atanh(f), f = (m-1)/(m+1):
    // 2*f*(1 + f^2/3 + f^4/5 + ... + f^20/21)
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FA8618618618618;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAAF286BCA1AF28;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FAE1E1E1E1E1E1E;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB1111111111111;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;

    // ln|x| = k*ln2 + ln(m), with ln2 split into hi + lo parts
    fma.rn.f64 %l_lnx, %l_nf, 0d3D2EF35793C76730, %l_p;
    fma.rn.f64 %l_lnx, %l_nf, 0d3FE62E42FEFA3800, %l_lnx;

    // ln(0) = -inf, ln(inf) = +inf, ln(NaN) = NaN
    @%p_zero mov.f64 %l_lnx, 0dFFF0000000000000;
    @%p_inf mov.f64 %l_lnx, 0d7FF0000000000000;
    @%p_nan mov.f64 %l_lnx, 0d7FF8000000000000;

    // === exp(exponent * ln|x|) ===
    mul.f64 %e_z, %exp64, %l_lnx;
    // Range flags on the unclamped argument, then clamp to [-708, ln(DBL_MAX)]
    setp.gt.f64 %p_of, %e_z, 0d40862E42FEFA39EF;
    setp.lt.f64 %p_uf, %e_z, 0dC086200000000000;
    testp.notanumber.f64 %p_nan, %e_z;
    max.f64 %e_z, %e_z, 0dC086200000000000;
    min.f64 %e_z, %e_z, 0d40862E42FEFA39EF;

    // n = round(z/ln2), r = z - n*ln2 (Cody-Waite)
    fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // e^r - 1 = r*(1 + r/2! + r^2/3! + ... + r^11/12!)
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    mul.f64 %e_p, %e_p, %e_r;
    // e^z = 2 * 2^(n-1) * e^r. With n in [-1021, 1024], n-1 is a valid
    // normal exponent and only the final doubling can overflow
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1022;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_s, %e_bits;
    fma.rn.f64 %vr, %e_p, %e_s, %e_s;
    add.f64 %vr, %vr, %vr;

    @%p_uf mov.f64 %vr, 0d0000000000000000;
    @%p_of mov.f64 %vr, 0d7FF0000000000000;
    @%p_nan mov.f64 %vr, 0d7FF8000000000000;

    // Classify the exponent: %p_frac = not an integer (or NaN),
    // %p_odd = odd integer
    cvt.rzi.f64.f64 %t_int, %exp64;
    setp.neu.f64 %p_frac, %t_int, %exp64;
    mul.f64 %t, %exp64, %e_half;
    cvt.rzi.f64.f64 %t_int, %t;
    setp.neu.and.f64 %p_odd, %t_int, %t, !%p_frac;

    // Sign-bit-set base with odd integer exponent: negate (covers -0, -inf)
    and.pred %p_odd, %p_odd, %p_sign;
    @%p_odd neg.f64 %vr, %vr;
    // Finite negative base with non-integer exponent: no real result
    setp.lt.and.f64 %p_neg, %va, 0d0000000000000000, !%p_inf;
    and.pred %p_frac, %p_frac, %p_neg;
    @%p_frac mov.f64 %vr, 0d7FF8000000000000;

    // x^0 == 1 and 1^e == 1 for every x / e, including NaN
    setp.eq.f64 %p_unit, %exp64, 0d0000000000000000;
    @%p_unit mov.f64 %vr, %one;
    setp.eq.f64 %p_unit, %va, %one;
    @%p_unit mov.f64 %vr, %one;
    // (-1)^(+-inf) == 1
    abs.f64 %t, %exp64;
    setp.eq.f64 %p_unit, %t, 0d7FF0000000000000;
    setp.eq.and.f64 %p_unit, %va, 0dBFF0000000000000, %p_unit;
    @%p_unit mov.f64 %vr, %one;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8896
/// PTX source for `abs_kernel`: `out[i] = |a[i]|`.
///
/// Operates on f32 elements. Kernel parameters, in launch order:
/// - `a_ptr`: input buffer;
/// - `out_ptr`: output buffer;
/// - `n`: `u32` element count.
///
/// Each thread computes its global index as `ctaid.x * ntid.x + tid.x`,
/// exits immediately when that index is `>= n`, and otherwise handles one
/// element. `abs.f32` clears the sign, so `-0.0` maps to `+0.0`.
#[cfg(feature = "cuda")]
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    abs.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8940
8941
/// PTX source for `sigmoid_kernel`: `out[i] = 1 / (1 + exp(-a[i]))`.
///
/// Operates on f32 elements, one thread per element (threads with a global
/// index `>= n` exit immediately). Kernel parameters, in launch order:
/// `a_ptr`, `out_ptr`, and the `u32` element count `n`.
///
/// `exp(-x)` is computed as `ex2.approx.f32(-x * log2(e))`, with
/// `0f3FB8AA3B` being `log2(e)`. Accuracy is therefore bounded by the
/// approximate hardware `ex2`. For large `|x|` the intermediate overflows to
/// `+inf` or underflows to 0, so the result saturates to `0` or `1` rather
/// than producing NaN.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f32 %neg, %va;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %vr, %one, %denom;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8992
/// PTX source for `sigmoid_f64_kernel`: `out[i] = 1 / (1 + exp(-a[i]))` (f64).
///
/// Uses `e = exp(-|x|)`, which lies in `(0, 1]` and therefore cannot
/// overflow. It then evaluates `1 / (1 + e)` for `x >= 0` and `e / (1 + e)`
/// for `x < 0`.
///
/// `exp` uses Cody-Waite range reduction and a degree-12 Taylor polynomial.
/// Arguments below -708, where `2^n` would leave the normal exponent range,
/// are flushed to `e = 0`. As a result `sigmoid(+inf) == 1` and
/// `sigmoid(-inf) == 0`. NaN inputs propagate to the output.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %num, %denom, %one;
    .reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half, %e_s;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %p_uf, %p_neg, %p_nan;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;

    // z = -|x| <= 0; flag and clamp the range where exp(z) < DBL_MIN
    abs.f64 %e_z, %va;
    neg.f64 %e_z, %e_z;
    setp.lt.f64 %p_uf, %e_z, 0dC086200000000000;
    max.f64 %e_z, %e_z, 0dC086200000000000;

    // --- exp(z): n = round(z/ln2), r = z - n*ln2 (Cody-Waite) ---
    fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // e^r - 1 = r*(1 + r/2! + r^2/3! + ... + r^11/12!)
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    mul.f64 %e_p, %e_p, %e_r;
    // s = 2^n from the exponent bits (n is in [-1021, 0] after the clamp)
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_s, %e_bits;
    fma.rn.f64 %e64, %e_p, %e_s, %e_s;
    @%p_uf mov.f64 %e64, 0d0000000000000000;
    // --- end exp ---

    // sigmoid(x) = 1/(1+e) for x >= 0, e/(1+e) for x < 0, with e = exp(-|x|)
    add.f64 %denom, %one, %e64;
    setp.lt.f64 %p_neg, %va, 0d0000000000000000;
    selp.f64 %num, %e64, %one, %p_neg;
    div.rn.f64 %vr, %num, %denom;

    // NaN in, NaN out (the clamp above would otherwise hide it)
    testp.notanumber.f64 %p_nan, %va;
    @%p_nan mov.f64 %vr, %va;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
9073
/// PTX source for `tanh_kernel`: `out[i] = tanh(a[i])`.
/// Uses the identity: tanh(x) = 2*sigmoid(2x) - 1.
///
/// Operates on f32 elements, one thread per element (threads with a global
/// index `>= n` exit immediately). Kernel parameters, in launch order:
/// `a_ptr`, `out_ptr`, and the `u32` element count `n`.
///
/// `exp(-2x)` is computed as `ex2.approx.f32(-2x * log2(e))`. For large `|x|`
/// the intermediate overflows to `+inf` or underflows to 0, giving exactly
/// `-1` or `+1`.
///
/// Near zero the final `2*s - 1` subtraction cancels, so relative accuracy
/// degrades as `|x|` shrinks. Very small inputs can return `0` instead of
/// approximately `x`.
#[cfg(feature = "cuda")]
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // tanh(x) = 2*sigmoid(2x) - 1
    mov.f32 %two, 0f40000000;
    mul.f32 %neg2x, %va, %two;
    neg.f32 %neg2x, %neg2x;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg2x, %neg2x, %lg2e;
    ex2.approx.f32 %e, %neg2x;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %sig, %one, %denom;
    mul.f32 %vr, %two, %sig;
    sub.f32 %vr, %vr, %one;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
9129
/// PTX source for `tanh_f64_kernel`: `out[i] = tanh(a[i])` (f64).
///
/// Evaluated as `tanh(x) = sign(x) * -m / (2 + m)` with `m = expm1(-2|x|)`.
/// - Working from `expm1` avoids the cancellation in `1 - exp(-2x)` for small
///   `|x|`.
/// - Because the exp argument is never positive, it cannot overflow.
/// - `expm1` uses Cody-Waite range reduction and a degree-12 Taylor
///   polynomial.
/// - The argument is clamped at -708 so the `2^n` scale factor always has a
///   valid exponent; `tanh` is already `±1` to double precision long before
///   that point.
///
/// Signed zeros and `±inf` map to `±0` and `±1`; NaN inputs propagate.
#[cfg(feature = "cuda")]
pub(crate) const TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %em1, %num, %denom, %one;
    .reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half, %e_s;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %p_nan;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;

    // z = -2|x| <= 0, clamped so that n below stays in [-1021, 0]
    abs.f64 %e_z, %va;
    mul.f64 %e_z, %e_z, 0dC000000000000000;
    max.f64 %e_z, %e_z, 0dC086200000000000;

    // --- expm1(z): n = round(z/ln2), r = z - n*ln2 (Cody-Waite) ---
    fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // e^r - 1 = r*(1 + r/2! + r^2/3! + ... + r^11/12!)
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    mul.f64 %e_p, %e_p, %e_r;
    // s = 2^n from the exponent bits; expm1(z) = s*(e^r - 1) + (s - 1)
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_s, %e_bits;
    sub.f64 %em1, %e_s, %one;
    fma.rn.f64 %em1, %e_p, %e_s, %em1;
    // --- end expm1 ---

    // tanh(|x|) = -expm1(z) / (2 + expm1(z)), then restore the sign of x
    neg.f64 %num, %em1;
    add.f64 %denom, %em1, 0d4000000000000000;
    div.rn.f64 %vr, %num, %denom;
    copysign.f64 %vr, %va, %vr;

    // NaN in, NaN out (the clamp above would otherwise hide it)
    testp.notanumber.f64 %p_nan, %va;
    @%p_nan mov.f64 %vr, %va;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
9213
/// PTX source for `fused_adam_kernel`: in-place Adam optimizer update.
///
/// For each element i:
///   g = grad[i] + weight_decay * param[i]  (if wd > 0)
///   exp_avg[i] = beta1 * exp_avg[i] + (1-beta1) * g
///   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1-beta2) * g * g
///   m_hat = exp_avg[i] / bc1
///   v_hat = exp_avg_sq[i] / bc2
///   param[i] = param[i] - lr * m_hat / (sqrt(v_hat) + eps)
///
/// Kernel parameters, in launch order:
/// - `param_ptr`, `grad_ptr`, `exp_avg_ptr`, `exp_avg_sq_ptr`: f32 buffers,
///   each indexed `0..n`;
/// - the f32 scalars `beta1`, `beta2`, `lr`, `eps`, `bc1`, `bc2` and
///   `weight_decay`;
/// - the `u32` element count `n`.
///
/// One thread per element; threads with a global index `>= n` exit
/// immediately. `param`, `exp_avg` and `exp_avg_sq` are overwritten in place.
/// `grad` is only read: weight decay is folded into a register copy of the
/// gradient.
///
/// Weight decay is applied only when `weight_decay > 0.0` (strict
/// comparison). `bc1` and `bc2` are plain divisors supplied by the host.
/// NOTE(review): presumably they are the bias corrections `1 - beta^t`;
/// confirm at the call site.
///
/// Implementation detail: the `%one` register is briefly loaded with `0.0`
/// for the weight-decay comparison, and only afterwards set to `1.0`.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_adam_kernel(
    .param .u64 param_ptr,
    .param .u64 grad_ptr,
    .param .u64 exp_avg_ptr,
    .param .u64 exp_avg_sq_ptr,
    .param .f32 beta1,
    .param .f32 beta2,
    .param .f32 lr,
    .param .f32 eps,
    .param .f32 bc1,
    .param .f32 bc2,
    .param .f32 weight_decay,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %p, %g, %m, %v, %off;
    .reg .f32 %vp, %vg, %vm, %vv;
    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
    .reg .f32 %one;
    .reg .pred %p_bound, %p_wd;

    ld.param.u64 %p, [param_ptr];
    ld.param.u64 %g, [grad_ptr];
    ld.param.u64 %m, [exp_avg_ptr];
    ld.param.u64 %v, [exp_avg_sq_ptr];
    ld.param.f32 %b1, [beta1];
    ld.param.f32 %b2, [beta2];
    ld.param.f32 %f_lr, [lr];
    ld.param.f32 %f_eps, [eps];
    ld.param.f32 %f_bc1, [bc1];
    ld.param.f32 %f_bc2, [bc2];
    ld.param.f32 %f_wd, [weight_decay];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_bound, %r_tid, %n_reg;
    @%p_bound bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %p, %p, %off;
    add.u64 %g, %g, %off;
    add.u64 %m, %m, %off;
    add.u64 %v, %v, %off;

    ld.global.f32 %vp, [%p];
    ld.global.f32 %vg, [%g];
    ld.global.f32 %vm, [%m];
    ld.global.f32 %vv, [%v];

    // L2 weight decay: g = g + wd * p
    mov.f32 %one, 0f00000000;
    setp.gt.f32 %p_wd, %f_wd, %one;
    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;

    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
    mov.f32 %one, 0f3F800000;
    sub.f32 %t1, %one, %b1;
    mul.f32 %vm, %vm, %b1;
    fma.rn.f32 %vm, %t1, %vg, %vm;

    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    sub.f32 %t2, %one, %b2;
    mul.f32 %vv, %vv, %b2;
    mul.f32 %t1, %vg, %vg;
    fma.rn.f32 %vv, %t2, %t1, %vv;

    // m_hat = exp_avg / bc1
    div.rn.f32 %m_hat, %vm, %f_bc1;

    // v_hat = exp_avg_sq / bc2
    div.rn.f32 %v_hat, %vv, %f_bc2;

    // denom = sqrt(v_hat) + eps
    sqrt.rn.f32 %denom, %v_hat;
    add.f32 %denom, %denom, %f_eps;

    // param = param - lr * m_hat / denom
    div.rn.f32 %update, %m_hat, %denom;
    mul.f32 %update, %update, %f_lr;
    sub.f32 %vp, %vp, %update;

    st.global.f32 [%p], %vp;
    st.global.f32 [%m], %vm;
    st.global.f32 [%v], %vv;

DONE:
    ret;
}
";
9325
/// PTX source for fused GRU cell forward kernel.
///
/// Takes pre-computed input_gates [B, 3*H] and hidden_gates [B, 3*H]
/// (from cuBLAS GEMMs), biases, and previous hidden state. Computes all
/// gate activations and the new hidden state in a single kernel launch.
///
/// Work is distributed with a grid-stride loop over `total` elements
/// (presumably `B * H`; confirm at the call site). Each iteration handles one
/// (batch, hidden-unit) pair:
/// - reads 3 values from input_gates and 3 from hidden_gates;
/// - applies sigmoid/tanh and computes the GRU update;
/// - writes `hy` plus 5 workspace values `[r, z, n, hx, hn + b_hn]`, laid out
///   as `[B, 5*H]`, for the backward pass.
///
/// The global thread index lives in `%r_tid`, as in the other kernels of this
/// file. A user register named `%tid` would collide with the predefined
/// `%tid` special register.
///
/// Matches PyTorch's _thnn_fused_gru_cell kernel from RNN.cu.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_gru_forward_kernel(
    .param .u64 input_gates_ptr,
    .param .u64 hidden_gates_ptr,
    .param .u64 bias_ih_ptr,
    .param .u64 bias_hh_ptr,
    .param .u64 hx_ptr,
    .param .u64 hy_ptr,
    .param .u64 workspace_ptr,
    .param .u32 hsz,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
    .reg .u64 %off64, %tmp64;
    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
    .reg .pred %p;

    ld.param.u64 %ig, [input_gates_ptr];
    ld.param.u64 %hg, [hidden_gates_ptr];
    ld.param.u64 %b1, [bias_ih_ptr];
    ld.param.u64 %b2, [bias_hh_ptr];
    ld.param.u64 %hx, [hx_ptr];
    ld.param.u64 %hy, [hy_ptr];
    ld.param.u64 %ws, [workspace_ptr];
    ld.param.u32 %hsz_reg, [hsz];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %r_tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %one, 0f3F800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // offset3 = (idx/hsz)*3*hsz + idx%hsz  (into [B, 3*H] gates tensor)
    div.u32 %batch_idx, %idx, %hsz_reg;
    rem.u32 %hmod, %idx, %hsz_reg;
    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset3, %offset3, 3;
    add.u32 %offset3, %offset3, %hmod;

    // Load input gate components: ir, ii, in
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ig, %off64;
    ld.global.f32 %ir, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %ii, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %in, [%tmp64];

    // Load hidden gate components: hr, hi, hn
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hg, %off64;
    ld.global.f32 %hr, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hi, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hn, [%tmp64];

    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b1, %off64;
    ld.global.f32 %b1r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1n, [%tmp64];

    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b2, %off64;
    ld.global.f32 %b2r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2n, [%tmp64];

    // Load hx[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hx, %off64;
    ld.global.f32 %hx_val, [%tmp64];

    // r = sigmoid(ir + hr + b1r + b2r)
    add.f32 %rg, %ir, %hr;
    add.f32 %rg, %rg, %b1r;
    add.f32 %rg, %rg, %b2r;
    neg.f32 %tmp, %rg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %rg, %one, %denom;

    // z = sigmoid(ii + hi + b1i + b2i)
    add.f32 %zg, %ii, %hi;
    add.f32 %zg, %zg, %b1i;
    add.f32 %zg, %zg, %b2i;
    neg.f32 %tmp, %zg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %zg, %one, %denom;

    // n = tanh(in + b1n + r*(hn + b2n))
    add.f32 %tmp, %hn, %b2n;
    fma.rn.f32 %ng, %rg, %tmp, %in;
    add.f32 %ng, %ng, %b1n;
    // tanh via 2*sigmoid(2x)-1
    mul.f32 %tmp, %ng, 0f40000000;
    neg.f32 %tmp, %tmp;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %ng, %one, %denom;
    mul.f32 %ng, %ng, 0f40000000;
    sub.f32 %ng, %ng, %one;

    // hy = n + z * (hx - n)
    sub.f32 %tmp, %hx_val, %ng;
    fma.rn.f32 %hy_val, %zg, %tmp, %ng;

    // Store hy[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hy, %off64;
    st.global.f32 [%tmp64], %hy_val;

    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset5, %offset5, 5;
    add.u32 %offset5, %offset5, %hmod;

    cvt.u64.u32 %off64, %offset5;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ws, %off64;
    st.global.f32 [%tmp64], %rg;
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %zg;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %ng;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %hx_val;
    add.u64 %tmp64, %tmp64, %off64;
    add.f32 %tmp, %hn, %b2n;
    st.global.f32 [%tmp64], %tmp;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
9518
9519// ---------------------------------------------------------------------------
9520// Launch configuration helper
9521// ---------------------------------------------------------------------------
9522
/// Standard 1-D launch config for `n` elements.
///
/// Uses 256 threads per block, which is a good default for elementwise ops
/// on all modern NVIDIA architectures. The grid always has at least one
/// block, so `n == 0` still yields a valid launch in which every thread
/// exits at its bounds check.
///
/// # Errors
///
/// Returns [`GpuError::ShapeMismatch`] if `n` exceeds `u32::MAX`. The kernels
/// compute 32-bit global thread indices and receive `n` as a `u32`, so a
/// larger count would silently truncate.
#[cfg(feature = "cuda")]
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    const BLOCK: u32 = 256;

    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    // Round up in `usize`. Rounding in `u32` with a saturating add under-sizes
    // the grid for n > u32::MAX - 255 and drops the final partial block. The
    // quotient is at most 2^24, so the narrowing cast is lossless.
    let grid = n.div_ceil(BLOCK as usize) as u32;
    Ok(LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
9549
9550// ---------------------------------------------------------------------------
9551// Validation helpers
9552// ---------------------------------------------------------------------------
9553
/// Check that `a` and `b` both live on `device` and have equal lengths.
///
/// The buffers are checked in order, `a` then `b`. The first device mismatch
/// is reported with that buffer's ordinal as `expected` and the device's
/// ordinal as `got`. Lengths are compared only after both device checks pass.
#[cfg(feature = "cuda")]
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let target = device.ordinal();
    for buf in [a, b] {
        let home = buf.device_ordinal();
        if home != target {
            return Err(GpuError::DeviceMismatch {
                expected: home,
                got: target,
            });
        }
    }
    match (a.len(), b.len()) {
        (len_a, len_b) if len_a == len_b => Ok(()),
        (len_a, len_b) => Err(GpuError::LengthMismatch { a: len_a, b: len_b }),
    }
}
9577
/// Check that the single input buffer of a unary op lives on `device`.
///
/// A mismatch is reported with the buffer's ordinal as `expected` and the
/// device's ordinal as `got`.
#[cfg(feature = "cuda")]
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    match (a.device_ordinal(), device.ordinal()) {
        (home, target) if home != target => Err(GpuError::DeviceMismatch {
            expected: home,
            got: target,
        }),
        _ => Ok(()),
    }
}
9589
/// Check that a `CudaBuffer` of any element type lives on `device`.
///
/// A mismatch is reported with the buffer's ordinal as `expected` and the
/// device's ordinal as `got`.
#[cfg(feature = "cuda")]
fn validate_device<T>(a: &CudaBuffer<T>, device: &GpuDevice) -> GpuResult<()> {
    let home = a.device_ordinal();
    let target = device.ordinal();
    if home == target {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: home,
            got: target,
        })
    }
}
9601
9602// ---------------------------------------------------------------------------
9603// PTX kernel launch helpers
9604// ---------------------------------------------------------------------------
9605
/// Try to launch a binary f32 PTX kernel taking `(a, b, out, n: u32)` over
/// `n = a.len()` elements.
///
/// Returns `Ok(Some(buf))` with a freshly allocated `n`-element output on
/// success. Returns `Ok(None)` if the PTX module failed to compile/load, in
/// which case the caller should fall back to CPU.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` is shorter than `a`. Launching anyway
///   would make the kernel read past the end of `b`.
/// - [`GpuError::ShapeMismatch`] if `n` exceeds `u32::MAX`.
/// - Any allocation or launch error reported by the driver.
#[cfg(feature = "cuda")]
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if b.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: b.len() });
    }
    let ctx = device.context();
    let stream = device.stream();

    // Attempt to load the kernel (cached after first compilation).
    // If it fails (e.g. unsupported arch), return None so the caller
    // can use the CPU fallback.
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    // Lossless: launch_cfg has already rejected n > u32::MAX.
    let n_u32 = n as u32;

    // SAFETY: The kernel reads `n` f32 values from `a` and `b` and writes `n`
    // f32 values to `out`. All three buffers are device-resident: `a` has
    // exactly `n` elements, `b` at least `n` (checked above), and `out` was
    // just allocated with `n` elements. Threads at index >= n do no work.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9655
/// Try to launch a vectorized (vec4) binary f32 PTX kernel taking
/// `(a, b, out, n4: u32)`, where `n4 = a.len() / 4`.
///
/// Each thread processes 4 elements using 128-bit loads/stores, so `a.len()`
/// must be divisible by 4. This is enforced only by a debug assertion.
/// NOTE(review): the kernel is told about `n4` groups of 4, so in release
/// builds a non-multiple-of-4 tail would presumably stay zero. 128-bit
/// accesses also assume 16-byte-aligned device pointers; confirm for any
/// sub-buffer views.
///
/// Returns `Ok(None)` if compilation fails so the caller can fall back.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` is shorter than `a`.
/// - [`GpuError::ShapeMismatch`] if `n4` exceeds `u32::MAX`.
/// - Any allocation or launch error reported by the driver.
#[cfg(feature = "cuda")]
fn try_launch_binary_vec4(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    debug_assert_eq!(n % 4, 0, "try_launch_binary_vec4 requires a length divisible by 4");
    if b.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: b.len() });
    }
    let n4 = n / 4;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;
    // Validate before narrowing: casting first would let an oversized n4
    // wrap around and slip past the bounds check.
    let cfg = launch_cfg(n4)?;
    let n4_u32 = n4 as u32;

    // SAFETY: `a` holds `n` elements, `b` at least `n` (checked above), and
    // `out` was just allocated with `n` elements. By this helper's contract
    // each of the `n4` active threads touches 4 consecutive elements, i.e.
    // indices in [0, 4 * n4) with 4 * n4 <= n.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n4_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9700
/// Launch a unary f32 PTX kernel taking `(a, out, n: u32)` into a newly
/// allocated `a.len()`-element buffer.
///
/// Yields `Ok(Some(out))` once the kernel has been enqueued. Yields
/// `Ok(None)` when the PTX module cannot be compiled/loaded, so the caller
/// can take its CPU path. Allocation and launch failures are returned as
/// `Err`.
#[cfg(feature = "cuda")]
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    // Compiled modules are cached; a compile/load failure is not fatal.
    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        return Ok(None);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the kernel reads `len` f32 values from `a` and writes `len` f32
    // values to `out`; both are device buffers with at least `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9744
9745// ---------------------------------------------------------------------------
9746// _into helpers — write to pre-allocated output buffer (no allocation)
9747// ---------------------------------------------------------------------------
9748
/// Launch a binary f32 PTX kernel taking `(a, b, out, n: u32)` into a
/// pre-allocated output buffer, with `n = a.len()`.
///
/// Returns `Ok(true)` once the kernel is enqueued. Returns `Ok(false)` if the
/// PTX module failed to load, in which case the caller should fall back to
/// CPU. Elements of `out` beyond index `n` are left untouched.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `b` or `out` is shorter than `a`.
///   Launching anyway would read or write past the end of that buffer.
/// - [`GpuError::ShapeMismatch`] if `n` exceeds `u32::MAX`.
/// - Any launch error reported by the driver.
#[cfg(feature = "cuda")]
fn try_launch_binary_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    for len in [b.len(), out.len()] {
        if len < n {
            return Err(GpuError::LengthMismatch { a: n, b: len });
        }
    }
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    // Lossless: launch_cfg has already rejected n > u32::MAX.
    let n_u32 = n as u32;

    // SAFETY: the kernel reads `n` f32 values from `a` and `b` and writes `n`
    // f32 values to `out`. All three buffers were checked above to hold at
    // least `n` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
9791
/// Launch a unary PTX kernel into a pre-allocated output buffer.
///
/// The element count `n` is taken from `a`; the kernel is launched with the
/// arguments `(a, out, n)` and writes its result into `out` in place, so no
/// device memory is allocated here.
///
/// Returns `Ok(true)` on success, `Ok(false)` if the PTX module failed to load
/// (the caller is expected to fall back), and an error if computing the launch
/// configuration or the launch itself fails.
///
/// # Panics
///
/// Panics if `out` holds fewer than `a.len()` elements: launching in that
/// state would write past the end of a device allocation.
#[cfg(feature = "cuda")]
fn try_launch_unary_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // `out` is supplied by the caller rather than allocated here, so check the
    // size the kernel relies on before the `unsafe` launch below.
    assert!(
        out.len() >= n,
        "try_launch_unary_into: output buffer too short (input: {}, out: {})",
        n,
        out.len(),
    );

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: the kernel reads `n` f32 values from `a` and writes `n` f32
    // values to `out`; the assert above guarantees `out` holds at least `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
9832
9833// ---------------------------------------------------------------------------
9834// f64 launch helpers
9835// ---------------------------------------------------------------------------
9836
/// Launch a two-input f64 PTX kernel into a freshly allocated output buffer.
///
/// The output has `a.len()` elements and the kernel receives
/// `(a, b, out, n)`. Yields `Ok(None)` when the PTX module cannot be
/// compiled/loaded so the caller can take a CPU path; allocation and launch
/// errors are propagated.
#[cfg(feature = "cuda")]
fn try_launch_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let launch_stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        return Ok(None);
    };

    let mut result = alloc_zeros_f64(len, device)?;
    let config = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the kernel touches `len` f64 elements of `a`, `b` and `result`;
    // `result` was just allocated with `len` elements and callers are expected
    // to pass `b` with at least `a.len()` elements.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .launch(config)?;
    }
    Ok(Some(result))
}
9874
/// Launch a single-input f64 PTX kernel into a freshly allocated output buffer.
///
/// The output has `a.len()` elements and the kernel receives `(a, out, n)`.
/// Yields `Ok(None)` when the PTX module cannot be compiled/loaded; allocation
/// and launch errors are propagated.
#[cfg(feature = "cuda")]
fn try_launch_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let launch_stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        return Ok(None);
    };

    let mut result = alloc_zeros_f64(len, device)?;
    let config = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the kernel reads `len` f64 values from `a` and writes `len` f64
    // values to `result`, which was just allocated with exactly `len` elements.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .launch(config)?;
    }
    Ok(Some(result))
}
9910
/// Host-side fallback for two-input f64 ops.
///
/// Downloads `a` and `b`, applies `op` pairwise (stopping at the shorter of
/// the two), and uploads the result as a new device buffer.
#[cfg(feature = "cuda")]
fn cpu_fallback_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut combined = Vec::with_capacity(lhs.len().min(rhs.len()));
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        combined.push(op(x, y));
    }
    cpu_to_gpu(&combined, device)
}
9924
/// Host-side fallback for single-input f64 ops: download `a`, map every
/// element through `op`, and upload the result as a new device buffer.
#[cfg(feature = "cuda")]
fn cpu_fallback_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let mapped: Vec<f64> = gpu_to_cpu(a, device)?
        .iter()
        .copied()
        .map(op)
        .collect();
    cpu_to_gpu(&mapped, device)
}
9936
/// Launch an N-dimensional broadcasting binary f64 PTX kernel.
///
/// The f64 counterpart of [`try_launch_broadcast_binary`]. `a_strides`,
/// `b_strides` and `out_shape` (all of length `ndim`) are copied to small
/// device buffers and passed to the kernel along with `out_numel` and the
/// rank. The output buffer has `out_numel` elements.
///
/// Yields `Ok(None)` when the PTX module cannot be compiled/loaded.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let launch_stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        return Ok(None);
    };

    // Index metadata the kernel uses to map output positions back to inputs.
    let a_strides_dev = cpu_to_gpu(a_strides, device)?;
    let b_strides_dev = cpu_to_gpu(b_strides, device)?;
    let out_shape_dev = cpu_to_gpu(out_shape, device)?;

    let mut result = alloc_zeros_f64(out_numel, device)?;
    let config = launch_cfg(out_numel)?;
    let (numel_u32, rank_u32) = (out_numel as u32, out_shape.len() as u32);

    // SAFETY: `result` holds `out_numel` elements. `a` and `b` are read at
    // offsets derived from the stride/shape arrays; callers must supply
    // strides that stay within those buffers.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(a_strides_dev.inner())
            .arg(b_strides_dev.inner())
            .arg(out_shape_dev.inner())
            .arg(&numel_u32)
            .arg(&rank_u32)
            .launch(config)?;
    }

    Ok(Some(result))
}
9995
/// Host-side fallback for f64 broadcasting binary ops.
///
/// Downloads both inputs, walks every flat output index, decomposes it into
/// per-dimension coordinates of `out_shape` (innermost first), maps those
/// through the broadcast strides of `a_shape` / `b_shape`, applies `op`, and
/// uploads the result.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let total = out_shape.iter().product::<usize>();

    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    let combined: Vec<f64> = (0..total)
        .map(|flat| {
            let (mut rest, mut lhs_at, mut rhs_at) = (flat, 0usize, 0usize);
            for ((&extent, &ls), &rs) in out_shape
                .iter()
                .zip(lhs_strides.iter())
                .zip(rhs_strides.iter())
                .rev()
            {
                let coord = rest % extent;
                rest /= extent;
                lhs_at += coord * ls as usize;
                rhs_at += coord * rs as usize;
            }
            op(lhs[lhs_at], rhs[rhs_at])
        })
        .collect();
    cpu_to_gpu(&combined, device)
}
10029
/// Launch an N-dimensional broadcasting binary f32 PTX kernel.
///
/// `a_strides` and `b_strides` hold one stride per output dimension: the
/// input's C-contiguous stride, or 0 where that input is broadcast (size-1 or
/// missing dims). `out_shape` is the broadcast-resolved output shape. All three
/// have length `ndim`; they are copied to small device buffers and passed to
/// the kernel along with `out_numel` and the rank.
///
/// Yields `Ok(None)` when the PTX module cannot be compiled/loaded.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let launch_stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        return Ok(None);
    };

    // Index metadata the kernel uses to map output positions back to inputs.
    let a_strides_dev = cpu_to_gpu(a_strides, device)?;
    let b_strides_dev = cpu_to_gpu(b_strides, device)?;
    let out_shape_dev = cpu_to_gpu(out_shape, device)?;

    let mut result = alloc_zeros_f32(out_numel, device)?;
    let config = launch_cfg(out_numel)?;
    let (numel_u32, rank_u32) = (out_numel as u32, out_shape.len() as u32);

    // SAFETY: `result` holds `out_numel` elements. `a` and `b` are read at
    // offsets computed from the stride/shape buffers; callers must supply
    // strides that stay within those buffers.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(a_strides_dev.inner())
            .arg(b_strides_dev.inner())
            .arg(out_shape_dev.inner())
            .arg(&numel_u32)
            .arg(&rank_u32)
            .launch(config)?;
    }

    Ok(Some(result))
}
10093
10094/// Compute broadcast strides for a tensor shape relative to an output shape.
10095///
10096/// For each dimension, the stride is the normal C-contiguous stride if the
10097/// dimension size matches the output, or 0 if the dimension size is 1
10098/// (broadcast). Missing leading dimensions (when input has fewer dims) are
10099/// treated as size-1.
10100#[cfg(feature = "cuda")]
10101fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
10102    let ndim = out_shape.len();
10103    let in_ndim = in_shape.len();
10104    let mut strides = vec![0u32; ndim];
10105
10106    // C-contiguous strides for the input shape.
10107    let mut stride: u32 = 1;
10108    for d in (0..ndim).rev() {
10109        let in_d = if d + in_ndim >= ndim {
10110            d + in_ndim - ndim
10111        } else {
10112            // Leading dimension not present in input — broadcast.
10113            strides[d] = 0;
10114            continue;
10115        };
10116
10117        if in_shape[in_d] == 1 {
10118            strides[d] = 0; // Broadcast dimension.
10119        } else {
10120            strides[d] = stride;
10121        }
10122        stride *= in_shape[in_d] as u32;
10123    }
10124
10125    strides
10126}
10127
10128// ---------------------------------------------------------------------------
10129// CPU fallback helpers
10130// ---------------------------------------------------------------------------
10131
/// Host-side fallback for two-input f32 ops.
///
/// Downloads `a` and `b`, applies `op` pairwise (stopping at the shorter of
/// the two), and uploads the result as a new device buffer.
#[cfg(feature = "cuda")]
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let combined: Vec<f32> = lhs
        .iter()
        .copied()
        .zip(rhs.iter().copied())
        .map(|(x, y)| op(x, y))
        .collect();
    cpu_to_gpu(&combined, device)
}
10150
/// Host-side fallback for single-input f32 ops: download `a`, map every
/// element through `op`, and upload the result as a new device buffer.
#[cfg(feature = "cuda")]
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let source = gpu_to_cpu(a, device)?;
    let mut mapped = Vec::with_capacity(source.len());
    mapped.extend(source.iter().map(|&x| op(x)));
    cpu_to_gpu(&mapped, device)
}
10162
10163// ---------------------------------------------------------------------------
10164// Public API -- binary ops
10165// ---------------------------------------------------------------------------
10166
/// Elementwise addition: `out[i] = a[i] + b[i]`.
///
/// Tries, in order:
/// 1. the float4 (`add_vec4_kernel`) PTX kernel, when the length is at least
///    16 and a multiple of 4;
/// 2. the scalar `add_kernel` PTX kernel;
/// 3. a CPU round-trip, if neither PTX module can be loaded.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // 128-bit loads need the buffer to split evenly into 4-wide lanes.
    let len = a.len();
    let vec4_eligible = len >= 16 && len % 4 == 0;
    if vec4_eligible {
        let vectorized =
            try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?;
        if let Some(out) = vectorized {
            return Ok(out);
        }
    }

    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
10202
/// Elementwise subtraction: `out[i] = a[i] - b[i]`.
///
/// Runs the `sub_kernel` PTX kernel when its module loads; otherwise computes
/// the difference on the host and uploads it.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
10228
/// Elementwise multiplication: `out[i] = a[i] * b[i]`.
///
/// Tries, in order:
/// 1. the float4 (`mul_vec4_kernel`) PTX kernel, when the length is at least
///    16 and a multiple of 4;
/// 2. the scalar `mul_kernel` PTX kernel;
/// 3. a CPU round-trip, if neither PTX module can be loaded.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // 128-bit loads need the buffer to split evenly into 4-wide lanes.
    let len = a.len();
    let vec4_eligible = len >= 16 && len % 4 == 0;
    if vec4_eligible {
        let vectorized =
            try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?;
        if let Some(out) = vectorized {
            return Ok(out);
        }
    }

    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
10263
10264// ---------------------------------------------------------------------------
10265// Public API -- broadcast binary ops
10266// ---------------------------------------------------------------------------
10267
/// Broadcast addition: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
///
/// Handles arbitrary N-dimensional broadcasting on the GPU: the kernel
/// decomposes each output index into coordinates and maps them through
/// per-input broadcast strides to load from `a` and `b`.
///
/// `a_shape` and `b_shape` are the inputs' C-contiguous shapes. `out_shape` is
/// the broadcast-resolved output shape, supplied by the caller rather than
/// computed here. No device/length validation of `a` or `b` is done in this
/// function. Falls back to a host computation if the PTX module cannot be
/// loaded.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let out_numel = out_shape.iter().product::<usize>();
    let out_dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let (lhs_strides, rhs_strides) = (
        broadcast_strides(a_shape, out_shape),
        broadcast_strides(b_shape, out_shape),
    );

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &out_dims,
        out_numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
        }
    }
}
10307
/// Broadcast subtraction: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
///
/// Same shape conventions as `gpu_broadcast_add`: `out_shape` is the
/// caller-resolved broadcast shape, and a host computation is used if the PTX
/// module cannot be loaded.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let out_numel = out_shape.iter().product::<usize>();
    let out_dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let (lhs_strides, rhs_strides) = (
        broadcast_strides(a_shape, out_shape),
        broadcast_strides(b_shape, out_shape),
    );

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &out_dims,
        out_numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
        }
    }
}
10339
/// Broadcast multiplication: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
///
/// Same shape conventions as `gpu_broadcast_add`: `out_shape` is the
/// caller-resolved broadcast shape, and a host computation is used if the PTX
/// module cannot be loaded.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let out_numel = out_shape.iter().product::<usize>();
    let out_dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let (lhs_strides, rhs_strides) = (
        broadcast_strides(a_shape, out_shape),
        broadcast_strides(b_shape, out_shape),
    );

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &out_dims,
        out_numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
        }
    }
}
10371
/// Broadcast division: `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
///
/// Same shape conventions as `gpu_broadcast_add`: `out_shape` is the
/// caller-resolved broadcast shape, and a host computation is used if the PTX
/// module cannot be loaded. Division by zero follows IEEE-754 f32 semantics
/// in the host path (inf/NaN, no error).
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let out_numel = out_shape.iter().product::<usize>();
    let out_dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let (lhs_strides, rhs_strides) = (
        broadcast_strides(a_shape, out_shape),
        broadcast_strides(b_shape, out_shape),
    );

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &out_dims,
        out_numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
        }
    }
}
10403
/// Host-side fallback for f32 broadcasting binary ops.
///
/// Downloads both inputs, walks every flat output index, decomposes it into
/// per-dimension coordinates of `out_shape` (innermost first), maps those
/// through the broadcast strides of `a_shape` / `b_shape`, applies `op`, and
/// uploads the result.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let total = out_shape.iter().product::<usize>();

    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    let combined: Vec<f32> = (0..total)
        .map(|flat| {
            let (mut rest, mut lhs_at, mut rhs_at) = (flat, 0usize, 0usize);
            for ((&extent, &ls), &rs) in out_shape
                .iter()
                .zip(lhs_strides.iter())
                .zip(rhs_strides.iter())
                .rev()
            {
                let coord = rest % extent;
                rest /= extent;
                lhs_at += coord * ls as usize;
                rhs_at += coord * rs as usize;
            }
            op(lhs[lhs_at], rhs[rhs_at])
        })
        .collect();
    cpu_to_gpu(&combined, device)
}
10438
10439// ---------------------------------------------------------------------------
10440// Public API -- unary ops
10441// ---------------------------------------------------------------------------
10442
/// Elementwise negation: `out[i] = -a[i]`.
///
/// Runs the `neg_kernel` PTX kernel when its module loads; otherwise negates
/// on the host and uploads the result.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
///   CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
10463
/// Elementwise ReLU: `out[i] = max(a[i], 0.0)`.
///
/// Runs the `relu_kernel` PTX kernel when its module loads; otherwise clamps
/// on the host with `f32::max`, which maps NaN inputs to `0.0`.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
///   CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
10484
/// ReLU backward: `out[i] = (input[i] > 0) ? grad[i] : 0`.
///
/// `grad` is the upstream gradient and `input` the forward-pass input. Uses
/// the `relu_backward_kernel` PTX kernel when available, otherwise masks on
/// the host (NaN inputs yield a zero gradient there, since `NaN > 0` is false).
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    let launched = try_launch_binary(
        grad,
        input,
        device,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback: pass the gradient through only where the input was positive.
    let upstream = gpu_to_cpu(grad, device)?;
    let forward_in = gpu_to_cpu(input, device)?;
    let mut masked = Vec::with_capacity(upstream.len());
    for (&g, &x) in upstream.iter().zip(forward_in.iter()) {
        masked.push(if x > 0.0 { g } else { 0.0 });
    }
    cpu_to_gpu(&masked, device)
}
10514
/// Elementwise backward for `|x|`: `out[i] = grad[i] * sign(input[i])`
/// with the convention `sign(0) = 0`. Drives `AbsBackward` on GPU.
///
/// Uses the `abs_backward_kernel` PTX kernel when available. In the host
/// fallback, both zeros and NaN inputs produce a zero gradient.
#[cfg(feature = "cuda")]
pub fn gpu_abs_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    let launched = try_launch_binary(
        grad,
        input,
        device,
        ABS_BACKWARD_PTX,
        "abs_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let upstream = gpu_to_cpu(grad, device)?;
    let forward_in = gpu_to_cpu(input, device)?;
    let signed: Vec<f32> = upstream
        .iter()
        .zip(forward_in.iter())
        .map(|(&g, &x)| match x.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => g,
            Some(std::cmp::Ordering::Less) => -g,
            // Equal (±0.0) or unordered (NaN).
            _ => 0.0,
        })
        .collect();
    cpu_to_gpu(&signed, device)
}
10553
/// GELU backward: `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
/// where `sig = sigmoid(1.702 * x)`.
///
/// This is the derivative of the sigmoid-approximated GELU `x * sigmoid(1.702 x)`.
/// Uses the `gelu_backward_kernel` PTX kernel when available, otherwise
/// evaluates the same formula on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    const SLOPE: f32 = 1.702;

    validate_binary(grad, input, device)?;

    let launched = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_PTX,
        "gelu_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let upstream = gpu_to_cpu(grad, device)?;
    let forward_in = gpu_to_cpu(input, device)?;
    let mut grads = Vec::with_capacity(upstream.len());
    for (&g, &x) in upstream.iter().zip(forward_in.iter()) {
        let s = 1.0 / (1.0 + (-SLOPE * x).exp());
        grads.push(g * (s + SLOPE * x * s * (1.0 - s)));
    }
    cpu_to_gpu(&grads, device)
}
10588
/// GELU backward (exact erf mode):
/// `out[i] = grad[i] * (Φ(x) + x·φ(x))` with `x = input[i]`,
/// where Φ is the standard normal CDF and φ the standard normal PDF.
///
/// Runs the `gelu_backward_erf_kernel` PTX kernel when its module loads;
/// otherwise evaluates the formula on the host, approximating `erf` with
/// Abramowitz & Stegun formula 7.1.26.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad`, `input`, or `device` refer to
///   different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad` and `input` have different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    // NOTE(review): GELU_BACKWARD_ERF_PTX is defined elsewhere in this file;
    // confirm it uses the same A&S coefficients as the host fallback below
    // (in particular a5 = 1.0614054, not p = 0.3275911).
    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_PTX,
        "gelu_backward_erf_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback — Abramowitz & Stegun 7.1.26 (|ε| ≤ 1.5e-7 for z ≥ 0,
    // extended to z < 0 by the odd symmetry of erf):
    //   t = 1 / (1 + p·|z|)
    //   erf(|z|) ≈ 1 - t·(a1 + t·(a2 + t·(a3 + t·(a4 + t·a5))))·exp(-z²)
    // with p = 0.3275911, a1 = 0.2548296, a2 = -0.2844967, a3 = 1.4214137,
    // a4 = -1.453152, a5 = 1.0614054. The innermost coefficient is a5, not p:
    // a1..a5 sum to (approximately) 1, which is what makes erf(0) evaluate to 0.
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let z = x * inv_sqrt_2;
            let az = z.abs();
            let t = 1.0 / (1.0 + 0.3275911 * az);
            let poly = t
                * (0.2548296
                    + t * (-0.2844967 + t * (1.4214137 + t * (-1.453_152 + t * 1.061_405_4))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
            let cdf = 0.5 * (1.0 + erf_val);
            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
            g * (cdf + x * pdf)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
10632
10633// ---------------------------------------------------------------------------
10634// Public API -- Index-select 1-D (gather)
10635// ---------------------------------------------------------------------------
10636
/// Gather: `out[i] = input[indices[i]]` for every position `i` of `indices`.
///
/// `indices` is a GPU buffer of f32 values encoding integer positions; the
/// output has `indices.len()` elements. Uses the `index_select_1d_kernel` PTX
/// kernel when its module loads, otherwise gathers on the host. On the host,
/// indices are converted with an `as usize` cast, and an out-of-range index
/// panics.
/// NOTE(review): the PTX kernel's handling of out-of-range indices is not
/// visible here — confirm whether it bounds-checks.
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let count = indices.len();
    let launch_stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // PTX unavailable: gather on the host and upload the result.
        let source = gpu_to_cpu(input, device)?;
        let positions = gpu_to_cpu(indices, device)?;
        let gathered: Vec<f32> = positions.iter().map(|&p| source[p as usize]).collect();
        return cpu_to_gpu(&gathered, device);
    };

    let mut gathered = alloc_zeros_f32(count, device)?;
    let config = launch_cfg(count)?;
    let count_u32 = count as u32;

    // SAFETY: the kernel reads `count` entries of `indices` and writes `count`
    // entries of `gathered`, which was allocated with exactly `count` elements.
    // Reads from `input` are at positions taken from `indices`.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(gathered.inner_mut())
            .arg(&count_u32)
            .launch(config)?;
    }

    Ok(gathered)
}
10690
10691// ---------------------------------------------------------------------------
10692// Public API -- Scatter-add 1-D (backward of index_select)
10693// ---------------------------------------------------------------------------
10694
/// Scatter-add `grad_output` back into an output buffer of `input_len` elements,
/// using positions from `indices` (the backward of `gpu_index_select_1d`).
///
/// `indices` is a GPU buffer of f32 values encoding integer indices, with one
/// entry per `grad_output` element.
/// Output: `out = zeros(input_len); for i: out[indices[i]] += grad_output[i]`
///
/// The GPU kernel is intended to accumulate with atomic adds so repeated
/// indices sum correctly; the host fallback accumulates sequentially.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `grad_output`, `indices`, or `device`
///   refer to different CUDA devices.
/// - [`GpuError::LengthMismatch`] if `grad_output` and `indices` have
///   different lengths.
/// - [`GpuError::Driver`] on CUDA runtime errors.
///
/// # Panics
///
/// The host fallback panics if an index is `>= input_len`.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // The kernel walks `grad_output.len()` entries of *both* buffers, so
    // `indices` must match that length and live on the same device; checking
    // only `grad_output` allowed an out-of-bounds device read of `indices`.
    validate_binary(grad_output, indices, device)?;

    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; input_len];
            for (&g, &idx_f) in go_host.iter().zip(idx_host.iter()) {
                result[idx_f as usize] += g;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `grad_output` and `indices` both hold `n` elements (checked by
    // `validate_binary` above); `out` holds `input_len` zero-initialised
    // elements and is written at positions taken from `indices`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10752
10753// ---------------------------------------------------------------------------
10754// Public API -- Masked fill
10755// ---------------------------------------------------------------------------
10756
/// Overwrite `input` with the scalar `value` at every position flagged in `mask`.
///
/// `mask` stores f32 flags on the GPU (1.0 = flagged, 0.0 = not flagged); a
/// position counts as flagged when its mask value is at least 0.5:
/// `out[i] = if mask[i] >= 0.5 { value } else { input[i] }`.
///
/// Runs `masked_fill_kernel` when it can be loaded, otherwise computes the
/// result on the host and uploads it.
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(input, mask, device)?;

    let len = input.len();
    let context = device.context();
    let launch_stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        context,
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(kernel) => kernel,
        Err(_) => {
            // Kernel unavailable: fill on the host and upload the result.
            let values = gpu_to_cpu(input, device)?;
            let flags = gpu_to_cpu(mask, device)?;
            let mut filled = Vec::with_capacity(values.len());
            for (&x, &m) in values.iter().zip(flags.iter()) {
                filled.push(if m >= 0.5 { value } else { x });
            }
            return cpu_to_gpu(&filled, device);
        }
    };

    let mut out_buf = alloc_zeros_f32(len, device)?;
    let config = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: argument order and types must match `masked_fill_kernel` in
    // `MASKED_FILL_PTX`; `out_buf` was allocated with `len` elements.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out_buf.inner_mut())
            .arg(&value)
            .arg(&len_u32)
            .launch(config)?;
    }

    Ok(out_buf)
}
10813
10814// ---------------------------------------------------------------------------
10815// Public API -- Masked zero (backward of masked_fill)
10816// ---------------------------------------------------------------------------
10817
/// Backward of `masked_fill`: let `grad` pass through except where `mask` is
/// flagged, where the gradient becomes zero.
///
/// `mask` stores f32 flags (1.0 = flagged, 0.0 = not flagged); a position is
/// flagged when `mask[i] >= 0.5`:
/// `out[i] = if mask[i] >= 0.5 { 0.0 } else { grad[i] }`.
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;

    let launched = try_launch_binary(
        grad,
        mask,
        device,
        MASKED_ZERO_PTX,
        "masked_zero_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => {
            // Kernel unavailable: compute on the host instead.
            let g_host = gpu_to_cpu(grad, device)?;
            let m_host = gpu_to_cpu(mask, device)?;
            let mut zeroed = Vec::with_capacity(g_host.len());
            for (&g, &m) in g_host.iter().zip(m_host.iter()) {
                zeroed.push(if m >= 0.5 { 0.0 } else { g });
            }
            cpu_to_gpu(&zeroed, device)
        }
    }
}
10845
10846// ---------------------------------------------------------------------------
10847// Public API -- Sigmoid backward
10848// ---------------------------------------------------------------------------
10849
/// Gradient of the sigmoid, expressed through its forward result:
/// `out[i] = grad[i] * output[i] * (1 - output[i])`.
///
/// `grad` is the upstream gradient and `output` the sigmoid forward output;
/// both must have the same length and reside on `device`. Runs
/// `sigmoid_backward_kernel` when it can be loaded, otherwise computes the
/// result on the host.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    let gpu_result = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )?;
    if let Some(out) = gpu_result {
        return Ok(out);
    }

    // Host path: pull both operands back and apply the formula elementwise.
    let upstream = gpu_to_cpu(grad, device)?;
    let sig = gpu_to_cpu(output, device)?;
    let mut local_grad = Vec::with_capacity(upstream.len());
    for (&g, &o) in upstream.iter().zip(sig.iter()) {
        local_grad.push(g * o * (1.0 - o));
    }
    cpu_to_gpu(&local_grad, device)
}
10881
10882// ---------------------------------------------------------------------------
10883// Public API -- Tanh backward
10884// ---------------------------------------------------------------------------
10885
/// Gradient of `tanh`, expressed through its forward result:
/// `out[i] = grad[i] * (1 - output[i]^2)`.
///
/// `grad` is the upstream gradient and `output` the tanh forward output;
/// both must have the same length and reside on `device`. Uses
/// `tanh_backward_kernel` when available, otherwise a host computation.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    match try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )? {
        Some(out) => Ok(out),
        None => {
            // Host path.
            let upstream = gpu_to_cpu(grad, device)?;
            let th = gpu_to_cpu(output, device)?;
            let derived: Vec<f32> = upstream
                .iter()
                .zip(th.iter())
                .map(|(g, t)| *g * (1.0 - *t * *t))
                .collect();
            cpu_to_gpu(&derived, device)
        }
    }
}
10917
10918// ---------------------------------------------------------------------------
10919// Public API -- Softmax backward
10920// ---------------------------------------------------------------------------
10921
/// Softmax backward (row-wise): one block per row, shared-memory dot reduction.
///
/// For each row of length `cols`:
///   `dot = sum(grad[row] * output[row])`
///   `out[i] = output[i] * (grad[i] - dot)`
///
/// `output` is the softmax forward output. Both `grad` and `output` hold
/// `rows * cols` elements with `rows = grad.len() / cols`; elements after the
/// last full row (if the length is not a multiple of `cols`) are left as zero.
///
/// # Panics
///
/// Panics if `cols == 0` while the input is non-empty. An empty input with
/// `cols == 0` returns an empty buffer.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Threads per block; dynamic shared memory is sized at one f32 per thread.
    const BLOCK: u32 = 256;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    // Guard the `total / cols` division: a zero-width row only makes sense for
    // an empty tensor, which has nothing to differentiate.
    if cols == 0 {
        assert!(
            total == 0,
            "gpu_softmax_backward: cols is 0 but the input has {} elements",
            total
        );
        return alloc_zeros_f32(total, device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: one full row at a time; a trailing partial row
            // stays zero.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for ((out_row, g_row), o_row) in result
                .chunks_exact_mut(cols)
                .zip(grad_host.chunks_exact(cols))
                .zip(output_host.chunks_exact(cols))
            {
                let dot = g_row
                    .iter()
                    .zip(o_row)
                    .fold(0.0f32, |acc, (&g, &o)| acc + g * o);
                for ((dst, &g), &o) in out_row.iter_mut().zip(g_row).zip(o_row) {
                    *dst = o * (g - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; `max(1)` keeps the grid non-empty when there are no
    // full rows.
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: BLOCK * 4,
    };

    // SAFETY: the argument list must match `softmax_backward_kernel` in
    // `SOFTMAX_BACKWARD_PTX`; `out` was allocated with `total` elements and
    // `rows * cols <= total`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10996
10997// ---------------------------------------------------------------------------
10998// Public API -- LogSoftmax forward & backward
10999// ---------------------------------------------------------------------------
11000
/// Row-wise log-softmax on GPU.
///
/// For each row: `out[j] = x[j] - max(x) - log(sum_k exp(x[k] - max(x)))`,
/// i.e. `x[j]` minus the row's log-sum-exp, computed with the max shift for
/// numerical stability.
///
/// One block per row, 256 threads per block, shared-memory reductions for max
/// and sum-exp. `rows = input.len() / cols`; elements after the last full row
/// (if the length is not a multiple of `cols`) are left as zero.
///
/// # Panics
///
/// Panics if `cols == 0` while the input is non-empty. An empty input with
/// `cols == 0` returns an empty buffer.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax(
    input: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Threads per block; dynamic shared memory is sized at one f32 per thread.
    const BLOCK: u32 = 256;

    validate_unary(input, device)?;

    let total = input.len();
    // Guard the `total / cols` division: a zero-width row only makes sense for
    // an empty tensor.
    if cols == 0 {
        assert!(
            total == 0,
            "gpu_log_softmax: cols is 0 but the input has {} elements",
            total
        );
        return alloc_zeros_f32(total, device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_PTX,
        "log_softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: one full row at a time; a trailing partial row
            // stays zero.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for (out_row, row) in out.chunks_exact_mut(cols).zip(host.chunks_exact(cols)) {
                let max_v = row.iter().fold(f32::NEG_INFINITY, |m, &x| m.max(x));
                let sum_exp = row.iter().fold(0.0f32, |s, &x| s + (x - max_v).exp());
                let log_sum_exp = max_v + sum_exp.ln();
                for (dst, &x) in out_row.iter_mut().zip(row) {
                    *dst = x - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; `max(1)` keeps the grid non-empty when there are no
    // full rows.
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: BLOCK * 4,
    };

    // SAFETY: the argument list must match `log_softmax_kernel` in
    // `LOG_SOFTMAX_PTX`; `out` was allocated with `total` elements and
    // `rows * cols <= total`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11076
/// Row-wise log-softmax backward on GPU.
///
/// For each row:
///   `sum_grad = sum(grad[j])`
///   `out[j] = grad[j] - exp(output[j]) * sum_grad`
///
/// where `output` is the log-softmax forward output. `rows = grad.len() / cols`;
/// elements after the last full row (if the length is not a multiple of
/// `cols`) are left as zero.
///
/// # Panics
///
/// Panics if `cols == 0` while the input is non-empty. An empty input with
/// `cols == 0` returns an empty buffer.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Threads per block; dynamic shared memory is sized at one f32 per thread.
    const BLOCK: u32 = 256;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    // Guard the `total / cols` division: a zero-width row only makes sense for
    // an empty tensor.
    if cols == 0 {
        assert!(
            total == 0,
            "gpu_log_softmax_backward: cols is 0 but the input has {} elements",
            total
        );
        return alloc_zeros_f32(total, device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_PTX,
        "log_softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: one full row at a time; a trailing partial row
            // stays zero.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for ((out_row, g_row), o_row) in result
                .chunks_exact_mut(cols)
                .zip(grad_host.chunks_exact(cols))
                .zip(output_host.chunks_exact(cols))
            {
                let sum_grad = g_row.iter().fold(0.0f32, |s, &g| s + g);
                for ((dst, &g), &o) in out_row.iter_mut().zip(g_row).zip(o_row) {
                    *dst = g - o.exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; `max(1)` keeps the grid non-empty when there are no
    // full rows.
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: BLOCK * 4,
    };

    // SAFETY: the argument list must match `log_softmax_backward_kernel` in
    // `LOG_SOFTMAX_BACKWARD_PTX`; `out` was allocated with `total` elements and
    // `rows * cols <= total`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11152
11153// ---------------------------------------------------------------------------
// Public API -- Sum reductions (full sum and sum along an axis)
11155// ---------------------------------------------------------------------------
11156
/// Full parallel sum reduction on GPU, returning a one-element buffer.
///
/// Uses a two-pass approach. The first pass reduces `n` elements to
/// `num_blocks` partial sums via `reduce_sum_kernel`, using at most 1024
/// blocks of 256 threads. The second pass reduces the partial sums to a
/// single scalar. When there are at most 256 partial sums, the second pass
/// runs on the CPU to avoid another kernel launch. Otherwise the function
/// recurses on the partials buffer.
///
/// An empty input yields `[0.0]`. If the PTX module cannot be loaded, the sum
/// is computed entirely on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Threads per block for `reduce_sum_kernel`.
    const BLOCK: u32 = 256;
    // Cap on first-pass blocks, bounding the number of partial sums.
    const MAX_BLOCKS: u32 = 1024;
    // Partial-sum counts up to this are finished on the host instead of
    // launching another reduction pass.
    const HOST_FINISH_MAX: u32 = 256;

    let n = a.len();
    if n == 0 {
        return cpu_to_gpu(&[0.0f32], device);
    }

    // Like the other entry points, make sure `a` belongs to `device` before
    // launching on `device`'s stream.
    validate_unary(a, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    // Pass 1: reduce to partial sums (one per block).
    // NOTE(review): `n as u32` truncates for inputs longer than u32::MAX, and
    // capping the grid at MAX_BLOCKS relies on reduce_sum_kernel covering all
    // `n` elements with fewer blocks (e.g. a strided loop) -- TODO confirm
    // against REDUCE_SUM_PTX.
    let n_u32 = n as u32;
    let num_blocks = (n_u32.saturating_add(BLOCK - 1) / BLOCK).min(MAX_BLOCKS);

    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;

    let cfg = LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0, // Statically allocated in PTX
    };

    // SAFETY: the argument list must match `reduce_sum_kernel` in
    // `REDUCE_SUM_PTX`; `partials` holds one slot per launched block.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    // Pass 2: reduce partial sums.
    if num_blocks <= 1 {
        return Ok(partials);
    }

    // Few partials: finishing on the host is cheaper than another launch.
    if num_blocks <= HOST_FINISH_MAX {
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // Many partials: reduce them with another kernel pass.
    gpu_reduce_sum(&partials, device)
}
11235
11236/// Stub -- always returns [`GpuError::NoCudaFeature`].
11237#[cfg(not(feature = "cuda"))]
11238pub fn gpu_reduce_sum(
11239    _a: &CudaBuffer<f32>,
11240    _device: &GpuDevice,
11241) -> GpuResult<CudaBuffer<f32>> {
11242    Err(GpuError::NoCudaFeature)
11243}
11244
/// Sum a tensor along one axis, viewing it as `[outer, axis_size, inner]`.
///
/// Produces `outer * inner` values; thread `i` computes
///
///   `output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner + k * inner + inner_idx]`
///
/// where `outer_idx = i / inner`, `inner_idx = i % inner`.
///
/// # Panics
///
/// Panics if `a` holds fewer than `outer * axis_size * inner` elements; the
/// formula above would otherwise index past the end of the buffer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    // Both the kernel launch (inside `unsafe`) and the CPU fallback index `a`
    // up to `outer * axis_size * inner - 1`; reject undersized buffers here.
    let total_input = outer * axis_size * inner;
    assert!(
        a.len() >= total_input,
        "gpu_sum_axis: input has {} elements, expected at least outer * axis_size * inner = {}",
        a.len(),
        total_input
    );

    let total_output = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: one output element per iteration.
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f32; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let base = (i / inner) * axis_size * inner + i % inner;
                *out = (0..axis_size).fold(0.0f32, |sum, k| sum + host[base + k * inner]);
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    // SAFETY: the argument list must match `sum_axis_kernel` in `SUM_AXIS_PTX`;
    // `a` holds at least `outer * axis_size * inner` elements (asserted above)
    // and `out` was allocated with `outer * inner` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11309
11310// ---------------------------------------------------------------------------
11311// Public API -- Cumulative scan operations
11312// ---------------------------------------------------------------------------
11313
/// Cumulative sum (prefix sum) along an axis on GPU.
///
/// `output[base + k*inner] = sum_{j=0}^{k} input[base + j*inner]`
/// where `base = outer_idx * dim_size * inner + inner_idx`.
///
/// One thread per (outer_idx, inner_idx) pair; each thread does a sequential
/// scan along `dim_size` elements.
///
/// # Panics
///
/// Panics if `input` holds fewer than `outer * dim_size * inner` elements; the
/// scan would otherwise read past the end of the buffer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_cumsum(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    // The scan (per the formula above) reads `input` up to index `total - 1`
    // on both the GPU and CPU paths; reject undersized buffers up front.
    assert!(
        input.len() >= total,
        "gpu_cumsum: input has {} elements, expected at least outer * dim_size * inner = {}",
        input.len(),
        total
    );
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMSUM_PTX,
        "cumsum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = 0.0f32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc += host[idx];
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: the argument list must match `cumsum_kernel` in `CUMSUM_PTX`;
    // `input` holds at least `total` elements (asserted above) and `out` was
    // allocated with `total` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11390
/// Cumulative product (prefix product) along an axis on GPU.
///
/// `output[base + k*inner] = prod_{j=0}^{k} input[base + j*inner]`
/// where `base = outer_idx * dim_size * inner + inner_idx`.
///
/// # Panics
///
/// Panics if `input` holds fewer than `outer * dim_size * inner` elements; the
/// scan would otherwise read past the end of the buffer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_cumprod(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    // The scan (per the formula above) reads `input` up to index `total - 1`
    // on both the GPU and CPU paths; reject undersized buffers up front.
    assert!(
        input.len() >= total,
        "gpu_cumprod: input has {} elements, expected at least outer * dim_size * inner = {}",
        input.len(),
        total
    );
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMPROD_PTX,
        "cumprod_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = 1.0f32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc *= host[idx];
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: the argument list must match `cumprod_kernel` in `CUMPROD_PTX`;
    // `input` holds at least `total` elements (asserted above) and `out` was
    // allocated with `total` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11464
/// Cumulative maximum (running max) along an axis on GPU.
///
/// `output[base + k*inner] = max_{j=0}^{k} input[base + j*inner]`
/// where `base = outer_idx * dim_size * inner + inner_idx`.
///
/// Returns `(values, indices)`, both with `outer * dim_size * inner` elements.
/// `indices` stores, as f32, the position along the axis at which the running
/// maximum was attained. The CPU fallback keeps the earliest position on ties.
///
/// # Panics
///
/// Panics if `input` holds fewer than `outer * dim_size * inner` elements; the
/// scan would otherwise read past the end of the buffer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_cummax(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    // The scan (per the formula above) reads `input` up to index `total - 1`
    // on both the GPU and CPU paths; reject undersized buffers up front.
    assert!(
        input.len() >= total,
        "gpu_cummax: input has {} elements, expected at least outer * dim_size * inner = {}",
        input.len(),
        total
    );
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMMAX_PTX,
        "cummax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::NEG_INFINITY;
                let mut best = 0u32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    // Strict `>`: ties keep the earlier position.
                    if host[idx] > acc {
                        acc = host[idx];
                        best = k as u32;
                    }
                    vals[idx] = acc;
                    idxs[idx] = best as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: the argument list must match `cummax_kernel` in `CUMMAX_PTX`;
    // `input` holds at least `total` elements (asserted above) and both
    // outputs were allocated with `total` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_idx))
}
11545
/// Cumulative minimum (running min) along an axis on GPU.
///
/// `output[base + k*inner] = min_{j=0}^{k} input[base + j*inner]`
/// where `base = outer_idx * dim_size * inner + inner_idx`.
///
/// Returns `(values, indices)`, both with `outer * dim_size * inner` elements.
/// `indices` stores, as f32, the position along the axis at which the running
/// minimum was attained. The CPU fallback keeps the earliest position on ties.
///
/// # Panics
///
/// Panics if `input` holds fewer than `outer * dim_size * inner` elements; the
/// scan would otherwise read past the end of the buffer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_cummin(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    // The scan (per the formula above) reads `input` up to index `total - 1`
    // on both the GPU and CPU paths; reject undersized buffers up front.
    assert!(
        input.len() >= total,
        "gpu_cummin: input has {} elements, expected at least outer * dim_size * inner = {}",
        input.len(),
        total
    );
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMMIN_PTX,
        "cummin_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::INFINITY;
                let mut best = 0u32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    // Strict `<`: ties keep the earlier position.
                    if host[idx] < acc {
                        acc = host[idx];
                        best = k as u32;
                    }
                    vals[idx] = acc;
                    idxs[idx] = best as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: the argument list must match `cummin_kernel` in `CUMMIN_PTX`;
    // `input` holds at least `total` elements (asserted above) and both
    // outputs were allocated with `total` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_idx))
}
11626
/// Numerically stable log-cumulative-sum-exp along an axis on GPU.
///
/// `acc = log(exp(acc) + exp(x))` computed as `m + log(exp(acc-m) + exp(x-m))`
/// where `m = max(acc, x)` for numerical stability.
///
/// When `m` is infinite and neither operand is NaN, the CPU fallback returns
/// `m` directly: `-inf` if both operands are `-inf`, `+inf` if either is
/// `+inf`. The shifted formula would compute `inf - inf = NaN` in that case.
///
/// # Panics
///
/// Panics if `input` holds fewer than `outer * dim_size * inner` elements; the
/// scan would otherwise read past the end of the buffer.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // `ln(exp(a) + exp(b))` via the max-shifted formula above. When the max is
    // infinite (and no operand is NaN), the shift yields `inf - inf = NaN`,
    // e.g. for a leading run of `-inf` inputs, while the exact result is the
    // max itself. NaN operands still propagate through the formula.
    fn log_add_exp(a: f32, b: f32) -> f32 {
        let m = a.max(b);
        if m.is_infinite() && !a.is_nan() && !b.is_nan() {
            m
        } else {
            m + ((a - m).exp() + (b - m).exp()).ln()
        }
    }

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    // The scan reads `input` up to index `total - 1` on both the GPU and CPU
    // paths; reject undersized buffers up front.
    assert!(
        input.len() >= total,
        "gpu_logcumsumexp: input has {} elements, expected at least outer * dim_size * inner = {}",
        input.len(),
        total
    );
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOGCUMSUMEXP_PTX,
        "logcumsumexp_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::NEG_INFINITY;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc = log_add_exp(acc, host[idx]);
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // NOTE(review): LOGCUMSUMEXP_PTX is not visible here; confirm the kernel
    // handles infinite running maxima the same way `log_add_exp` does.
    // SAFETY: the argument list must match `logcumsumexp_kernel`; `input`
    // holds at least `total` elements (asserted above) and `out` was allocated
    // with `total` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11702
11703// ---------------------------------------------------------------------------
11704// Public API -- Strided split
11705// ---------------------------------------------------------------------------
11706
/// Extract a sub-tensor along one axis entirely on GPU.
///
/// `input` is viewed as `[outer, total_along_axis, inner_size]`. The result is
/// the contiguous `[outer, split_size, inner_size]` slice
/// `input[:, split_offset .. split_offset + split_size, :]`.
///
/// - `inner_size` = product of dimensions after the split axis.
/// - `n` = total number of output elements (outer * split_size * inner_size).
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::ShapeMismatch`] if `split_offset + split_size` exceeds
///   `total_along_axis`, if `split_size * inner_size` is zero while `n` is not,
///   or if the slice reaches past the end of `input` or beyond the kernel's
///   `u32` index range.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    // Empty split: nothing to gather and no kernel to launch.
    if n == 0 {
        return alloc_zeros_f32(0, device);
    }

    // Output elements per outer slice.
    let chunk = split_size * inner_size;
    let range_fits = split_offset
        .checked_add(split_size)
        .is_some_and(|end| end <= total_along_axis);
    if chunk == 0 || !range_fits {
        return Err(GpuError::ShapeMismatch {
            op: "strided_split_range",
            expected: vec![total_along_axis],
            got: vec![split_offset, split_size],
        });
    }

    // Because the range fits inside the axis, source indices grow with the
    // output index. The last output element therefore reads the largest
    // source index. The kernel gets no input length and indexes in u32, so
    // bound that index here.
    let last = n - 1;
    let last_src = (last / chunk)
        .checked_mul(total_along_axis)
        .and_then(|v| v.checked_mul(inner_size))
        .and_then(|v| split_offset.checked_mul(inner_size).and_then(|o| v.checked_add(o)))
        .and_then(|v| v.checked_add(last % chunk));
    match last_src {
        Some(idx) if idx < input.len() && idx < u32::MAX as usize => {}
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "strided_split_input",
                expected: vec![last_src.map_or(usize::MAX, |idx| idx.saturating_add(1))],
                got: vec![input.len()],
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather each output element from its source slot.
            let host = gpu_to_cpu(input, device)?;
            let src_outer_stride = total_along_axis * inner_size;
            let src_base = split_offset * inner_size;
            let result: Vec<f32> = (0..n)
                .map(|i| host[(i / chunk) * src_outer_stride + src_base + i % chunk])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: `out` holds `n` elements. Assuming the kernel uses the same
    // index mapping as the CPU fallback, every source index it reads is at
    // most `last_src`, which was checked against `input.len()` and u32 range
    // above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11784
11785// ---------------------------------------------------------------------------
11786// Public API -- Strided cat
11787// ---------------------------------------------------------------------------
11788
/// Write a sub-tensor into a larger output buffer at an offset along one axis,
/// entirely on GPU.
///
/// `input` is viewed as `[outer, part_size, inner_size]` and `output` as
/// `[outer, total_along_axis, inner_size]`. The input is written into
/// `output[:, cat_offset .. cat_offset + part_size, :]`, and all other elements
/// of `output` are left untouched.
///
/// - `inner_size` = product of dimensions after the cat axis.
/// - `n` = total number of input elements (outer * part_size * inner_size).
///
/// Several chunks can be assembled into one `output` by successive calls. The
/// caller is responsible for keeping their `cat_offset .. cat_offset + part_size`
/// ranges disjoint.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` or `output` is on a different
///   device than `device`.
/// - [`GpuError::ShapeMismatch`] if `cat_offset + part_size` exceeds
///   `total_along_axis`, if `part_size * inner_size` is zero while `n` is not,
///   if `input` holds fewer than `n` elements, or if the written region reaches
///   past the end of `output` or beyond the kernel's `u32` index range.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;
    // The kernel writes through `output`, so it must live on `device` too.
    validate_unary(&*output, device)?;

    // Nothing to write.
    if n == 0 {
        return Ok(());
    }

    // Input elements per outer slice.
    let chunk = part_size * inner_size;
    let region_fits = cat_offset
        .checked_add(part_size)
        .is_some_and(|end| end <= total_along_axis);
    if chunk == 0 || !region_fits {
        return Err(GpuError::ShapeMismatch {
            op: "strided_cat_region",
            expected: vec![total_along_axis],
            got: vec![cat_offset, part_size],
        });
    }
    if input.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "strided_cat_input",
            expected: vec![n],
            got: vec![input.len()],
        });
    }

    // Because the region fits inside the axis, destination indices grow with
    // the input index. The last input element therefore lands on the largest
    // destination index. The kernel gets no output length and indexes in u32,
    // so bound that index here.
    let last = n - 1;
    let last_dst = (last / chunk)
        .checked_mul(total_along_axis)
        .and_then(|v| v.checked_mul(inner_size))
        .and_then(|v| cat_offset.checked_mul(inner_size).and_then(|o| v.checked_add(o)))
        .and_then(|v| v.checked_add(last % chunk));
    match last_dst {
        Some(idx) if idx < output.len() && idx < u32::MAX as usize => {}
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "strided_cat_output",
                expected: vec![last_dst.map_or(usize::MAX, |idx| idx.saturating_add(1))],
                got: vec![output.len()],
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: round-trip `output` through the host so elements
            // outside the written region keep their current values.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            let dst_outer_stride = total_along_axis * inner_size;
            let dst_base = cat_offset * inner_size;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                host_out[(i / chunk) * dst_outer_stride + dst_base + i % chunk] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: `input` holds at least `n` elements. Assuming the kernel uses
    // the same index mapping as the CPU fallback, every destination index it
    // writes is at most `last_dst`, which was checked against `output.len()`
    // and u32 range above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
11872
11873// ---------------------------------------------------------------------------
11874// Public API -- Strided copy (general N-d gather) -- CL-496
11875// ---------------------------------------------------------------------------
11876
/// Maximum rank supported by [`gpu_strided_copy`] and [`gpu_strided_copy_f64`].
///
/// Matches the unrolled PTX kernel's dimension count. The kernel takes eight
/// output-stride and eight source-stride arguments. Views of higher rank are
/// rejected with [`GpuError::ShapeMismatch`] during parameter padding, and
/// views of lower rank are padded up to this many dimensions.
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
11880
/// Pad-and-validate the (out_shape, src_strides) pair for the
/// strided-copy kernel.
///
/// Returns a fixed-size `[MAX_DIMS]` pair of arrays where:
/// - `out_stride[d]` is the contiguous output stride (in elements)
///   for that dim. Unused trailing dims are filled with `n + 1`, so that
///   `flat / out_stride[d] == 0` in the kernel (no contribution).
/// - `src_stride[d]` is the source stride (in elements) for that
///   dim. Unused trailing dims are filled with 0, so the source-offset
///   contribution is zero.
///
/// `out_shape` and `src_strides` must have the same length, at most
/// `STRIDED_COPY_MAX_DIMS`. `n` is the product of `out_shape` and must fit in
/// `u32`, because callers hand it to the kernel as `u32` and flat output
/// indices are decoded in `u32`.
///
/// # Errors
///
/// [`GpuError::ShapeMismatch`] on any of the following:
/// - rank mismatch;
/// - rank above `STRIDED_COPY_MAX_DIMS`;
/// - a negative source stride;
/// - `n > u32::MAX`;
/// - an output or source stride above `u32::MAX`.
#[cfg(feature = "cuda")]
fn pad_strided_copy_params(
    out_shape: &[usize],
    src_strides: &[isize],
    n: usize,
) -> GpuResult<([u32; STRIDED_COPY_MAX_DIMS], [u32; STRIDED_COPY_MAX_DIMS])> {
    if out_shape.len() != src_strides.len() {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![out_shape.len()],
            got: vec![src_strides.len()],
        });
    }
    if out_shape.len() > STRIDED_COPY_MAX_DIMS {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![STRIDED_COPY_MAX_DIMS],
            got: vec![out_shape.len()],
        });
    }
    // Reject negative source strides. The kernel treats them as u32, which
    // would wrap around and produce garbage indices.
    for &s in src_strides {
        if s < 0 {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_pad_negative_stride",
                expected: vec![0],
                got: vec![s.unsigned_abs()],
            });
        }
    }
    // A larger element count would silently truncate in the callers'
    // `n as u32` and in the padding value below.
    let Ok(n_u32) = u32::try_from(n) else {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_numel_overflow",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    };

    let rank = out_shape.len();
    // Compute contiguous output strides: stride[rank-1] = 1,
    // stride[d] = stride[d+1] * shape[d+1].
    let mut out_stride = [0u32; STRIDED_COPY_MAX_DIMS];
    let mut acc: usize = 1;
    for d in (0..rank).rev() {
        out_stride[d] = u32::try_from(acc).map_err(|_| GpuError::ShapeMismatch {
            op: "strided_copy_stride_overflow",
            expected: vec![u32::MAX as usize],
            got: vec![acc],
        })?;
        acc = acc.saturating_mul(out_shape[d]);
    }

    // Pad unused dims with `n + 1`, so `flat / out_stride[d] == 0` in the
    // kernel: any flat < n is strictly less than n + 1. With n == u32::MAX
    // the pad saturates to u32::MAX, which still exceeds every flat index.
    out_stride[rank..].fill(n_u32.saturating_add(1));

    // src_stride with 0 fill for unused dims (no contribution).
    let mut src_stride_out = [0u32; STRIDED_COPY_MAX_DIMS];
    for (slot, &s) in src_stride_out.iter_mut().zip(src_strides) {
        *slot = u32::try_from(s).map_err(|_| GpuError::ShapeMismatch {
            op: "strided_copy_src_stride_overflow",
            expected: vec![u32::MAX as usize],
            got: vec![s.unsigned_abs()],
        })?;
    }

    Ok((out_stride, src_stride_out))
}
11966
/// Gather a non-contiguous strided view of `input` into a new
/// contiguous output buffer, entirely on GPU. CL-496.
///
/// Output element `i` (row-major over `out_shape`) with coordinates `c` is
/// read from `input[src_offset + sum_d c[d] * src_strides[d]]`.
///
/// # Arguments
///
/// * `input`      — the storage backing the strided view. Must be
///   on `device`.
/// * `out_shape`  — shape of the contiguous output (and of the
///   logical view). `out_shape.len() <= STRIDED_COPY_MAX_DIMS`.
/// * `src_strides` — source element strides per dim, aligned with
///   `out_shape`. Must be non-negative (no reverse views yet).
/// * `src_offset`  — base element offset into `input` for the view.
/// * `device`     — CUDA device.
///
/// # Returns
///
/// A contiguous `CudaBuffer<f32>` with `product(out_shape)` elements.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` differ.
/// - [`GpuError::ShapeMismatch`] on any of the following:
///   - rank mismatch or too many dims;
///   - negative strides;
///   - a stride or element count above `u32::MAX`;
///   - a view that reaches past the end of `input` or beyond the kernel's
///     `u32` index range.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy(
    input: &CudaBuffer<f32>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        return alloc_zeros_f32(0, device);
    }

    // Largest source index the view touches. With n > 0 every dim is >= 1,
    // and strides are non-negative (negative ones were rejected above), so
    // this index is reached at the last coordinate of every dim. The kernel
    // receives no input length and accumulates indices in u32, so reject
    // views that would read past the end of `input` or past u32 range.
    let max_src = out_shape
        .iter()
        .zip(src_strides)
        .try_fold(src_offset, |acc, (&dim, &stride)| {
            let stride = usize::try_from(stride).ok()?;
            (dim - 1).checked_mul(stride)?.checked_add(acc)
        });
    match max_src {
        Some(idx) if idx < input.len() && idx <= u32::MAX as usize => {}
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_src_out_of_bounds",
                expected: vec![input.len()],
                got: vec![max_src.map_or(usize::MAX, |idx| idx.saturating_add(1))],
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: decode indices on the host, exactly as the kernel
            // does. No u32 overflow is possible, because every partial source
            // index is <= max_src <= u32::MAX.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for (i, slot) in result.iter_mut().enumerate() {
                let mut flat = i as u32;
                let mut src_idx = src_offset as u32;
                for d in 0..STRIDED_COPY_MAX_DIMS {
                    let os = out_stride[d];
                    let ss = src_stride[d];
                    let coord = flat / os;
                    flat -= coord * os;
                    src_idx += coord * ss;
                }
                *slot = host[src_idx as usize];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;

    // SAFETY: `out` holds `n` elements, and every source index the view can
    // produce is <= max_src, which is < input.len() and fits in u32.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }

    Ok(out)
}
12073
/// f64 variant of [`gpu_strided_copy`]. CL-496.
///
/// Same arguments and index mapping as the f32 version. The kernel is derived
/// from the f32 PTX via `get_f64_ptx`, with a host-side gather as the fallback.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` differ.
/// - [`GpuError::ShapeMismatch`] on any of the following:
///   - rank mismatch or too many dims;
///   - negative strides;
///   - a stride or element count above `u32::MAX`;
///   - a view that reaches past the end of `input` or beyond the kernel's
///     `u32` index range.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy_f64(
    input: &CudaBuffer<f64>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        return alloc_zeros_f64(0, device);
    }

    // Largest source index the view touches. Every dim is >= 1 because n > 0,
    // and strides are non-negative (negative ones were rejected above). The
    // kernel receives no input length and accumulates indices in u32, so
    // bound this index here.
    let max_src = out_shape
        .iter()
        .zip(src_strides)
        .try_fold(src_offset, |acc, (&dim, &stride)| {
            let stride = usize::try_from(stride).ok()?;
            (dim - 1).checked_mul(stride)?.checked_add(acc)
        });
    match max_src {
        Some(idx) if idx < input.len() && idx <= u32::MAX as usize => {}
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_src_out_of_bounds",
                expected: vec![input.len()],
                got: vec![max_src.map_or(usize::MAX, |idx| idx.saturating_add(1))],
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        "strided_copy_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_copy_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: decode indices on the host, exactly as the kernel
            // does. Partial indices never exceed max_src <= u32::MAX.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f64; n];
            for (i, slot) in result.iter_mut().enumerate() {
                let mut flat = i as u32;
                let mut src_idx = src_offset as u32;
                for d in 0..STRIDED_COPY_MAX_DIMS {
                    let os = out_stride[d];
                    let ss = src_stride[d];
                    let coord = flat / os;
                    flat -= coord * os;
                    src_idx += coord * ss;
                }
                *slot = host[src_idx as usize];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;

    // SAFETY: `out` holds `n` elements, and every source index the view can
    // produce is <= max_src, which is < input.len() and fits in u32.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }

    Ok(out)
}
12163
/// Scalar multiply: `out[i] = a[i] * scalar`.
///
/// Returns a new buffer of `a.len()` elements, each equal to the matching
/// element of `a` multiplied by `scalar`. If the PTX module cannot be loaded,
/// the product is computed on the host and uploaded instead.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let (ctx, stream) = (device.context(), device.stream());

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => kernel,
        Err(_) => {
            // Host fallback: scale every element in place, then upload.
            let mut host = gpu_to_cpu(a, device)?;
            host.iter_mut().for_each(|x| *x *= scalar);
            return cpu_to_gpu(&host, device);
        }
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `out` was allocated with `len` elements, matching `a`.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12217
12218// ---------------------------------------------------------------------------
12219// Public API -- softmax
12220// ---------------------------------------------------------------------------
12221
/// Row-wise softmax on GPU: one thread block per row, shared-memory reduction.
///
/// `rows` = product of all dims except the last. `cols` = last dim size.
/// The result has `rows * cols` elements, and row `r` holds
/// `exp(x[r, c] - max_r) / sum_c exp(x[r, c] - max_r)`.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::ShapeMismatch`] if `input` holds fewer than `rows * cols` elements.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = rows * cols;
    // The kernel is given `rows` and `cols` but not the buffer length, so a
    // short buffer would be read out of bounds on the device.
    if input.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "softmax",
            expected: vec![total],
            got: vec![input.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback. Sized by `rows * cols` like the GPU path, so both
            // paths return the same number of elements.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let src = &host[base..base + cols];
                let dst = &mut out[base..base + cols];
                let max_v = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for (d, &x) in dst.iter_mut().zip(src) {
                    let e = (x - max_v).exp();
                    *d = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for d in dst.iter_mut() {
                    *d *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4, // sdata[256] f32
    };

    // SAFETY: `input` holds at least `rows * cols` elements (checked above),
    // and `out` was allocated with exactly `rows * cols` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12294
12295// ---------------------------------------------------------------------------
12296// Public API -- dropout
12297// ---------------------------------------------------------------------------
12298
/// Inverted dropout on GPU: `out[i] = input[i] * scale` or `0` with probability `p`.
///
/// - `threshold` = `(p * u32::MAX as f64) as u32`, the RNG cutoff. An element
///   is zeroed when its hash value falls below it.
/// - `scale` = `1.0 / (1.0 - p)`.
/// - `seed` = random seed for the RNG.
///
/// **Known limitation**: This kernel uses a simple per-element hash
/// (`tid * 2654435761 ^ seed` with xorshift mixing), not the full
/// Philox 4x32-10 counter-based RNG that PyTorch uses. A proper Philox
/// dropout kernel would generate the mask via `philox_uniform_kernel`
/// and then threshold. That would give higher-quality randomness and exact
/// reproducibility across CPU/GPU. The current hash is sufficient for
/// training but should be upgraded for research requiring strict
/// statistical properties.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => kernel,
        Err(_) => {
            // Host fallback. The keep/drop decision is a pure function of
            // (element index, seed), using the same hash the kernel is
            // documented to use, so no RNG state carries between elements.
            let keep = |idx: usize| -> bool {
                let mut h = (idx as u32).wrapping_mul(2654435761) ^ seed;
                h ^= h << 13;
                h ^= h >> 17;
                h ^= h << 5;
                h >= threshold
            };
            let host = gpu_to_cpu(input, device)?;
            let masked: Vec<f32> = host
                .iter()
                .enumerate()
                .map(|(idx, &x)| if keep(idx) { x * scale } else { 0.0 })
                .collect();
            return cpu_to_gpu(&masked, device);
        }
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `out` was allocated with the same element count as `input`.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
12375
/// Elementwise inverted dropout for f64 tensors:
/// `out[i] = input[i] * scale`, or `0` when the element's hash value falls
/// below `threshold`.
///
/// This is the f64 counterpart of [`gpu_dropout`]. It uses the same
/// `threshold` / `seed` semantics and the same per-element hash. The kernel is
/// generated from the f32 dropout PTX via `get_f64_ptx`, and if that module
/// cannot be loaded the mask is applied on the host.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_dropout_f64(
    input: &CudaBuffer<f64>,
    threshold: u32,
    scale: f64,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Reject a buffer from another device before reading from it or
    // launching on it, as the other f64 entry points do.
    validate_device(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, DROPOUT_PTX, "dropout_kernel", "dropout_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "dropout_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: stateless per-element hash of (index, seed), the
            // same mixing as the f32 path.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `out` was allocated with the same element count as `input`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
12432
12433#[cfg(not(feature = "cuda"))]
12434pub fn gpu_dropout_f64(_input: &CudaBuffer<f64>, _threshold: u32, _scale: f64, _seed: u32, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
12435
12436// ---------------------------------------------------------------------------
12437// Public API -- 2D transpose
12438// ---------------------------------------------------------------------------
12439
/// 2D matrix transpose on GPU: `[M, N]` -> `[N, M]`.
///
/// Input element `(i, j)`, stored at flat index `i * n + j`, is written to flat
/// index `j * m + i` of a new `m * n`-element buffer. If the PTX module cannot
/// be loaded, the transpose runs on the host instead.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => kernel,
        Err(_) => {
            // Host fallback: walk the input in storage order and scatter each
            // element to its transposed position. `n > 0` whenever the loop
            // body runs, because `total == m * n`.
            let host = gpu_to_cpu(input, device)?;
            let mut transposed = vec![0.0f32; total];
            for flat in 0..total {
                let (row, col) = (flat / n, flat % n);
                transposed[col * m + row] = host[flat];
            }
            return cpu_to_gpu(&transposed, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let (m_u32, n_u32, total_u32) = (m as u32, n as u32, total as u32);

    // SAFETY: `out` was allocated with `m * n` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12495
12496// ---------------------------------------------------------------------------
12497// Public API -- 4D permute (0,2,1,3)
12498// ---------------------------------------------------------------------------
12499
/// Permute a 4D tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]` on GPU.
/// Used for attention head reshaping: `[B, S, H, D_h]` -> `[B, H, S, D_h]`.
///
/// Both layouts are row-major, and the result is a new
/// `d0 * d1 * d2 * d3`-element buffer. If the PTX module cannot be loaded,
/// the permutation runs on the host instead.
///
/// # Errors
///
/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
/// - [`GpuError::Driver`] on CUDA runtime errors.
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(kernel) => kernel,
        Err(_) => {
            // Host fallback: visit the input in storage order, recover its
            // (i0, i1, i2, i3) coordinates, and scatter into the slot with
            // axes 1 and 2 swapped. All dims are >= 1 whenever the loop runs.
            let host = gpu_to_cpu(input, device)?;
            let mut permuted = vec![0.0f32; total];
            for flat in 0..total {
                let i3 = flat % d3;
                let rest = flat / d3;
                let i2 = rest % d2;
                let rest = rest / d2;
                let i1 = rest % d1;
                let i0 = rest / d1;
                permuted[((i0 * d2 + i2) * d1 + i1) * d3 + i3] = host[flat];
            }
            return cpu_to_gpu(&permuted, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let dims_u32 = [d0 as u32, d1 as u32, d2 as u32, d3 as u32];
    let total_u32 = total as u32;

    // SAFETY: `out` was allocated with `d0 * d1 * d2 * d3` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&dims_u32[0])
            .arg(&dims_u32[1])
            .arg(&dims_u32[2])
            .arg(&dims_u32[3])
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12568
12569// ---------------------------------------------------------------------------
12570// Public API -- Small matmul (bypasses cuBLAS JIT)
12571// ---------------------------------------------------------------------------
12572
/// Small matrix multiply using our own PTX kernel. Avoids cuBLAS JIT
/// compilation overhead for tiny matrices where JIT cost > compute cost.
///
/// `a`: `[M, K]`, `b`: `[K, N]` → `c`: `[M, N]`.
///
/// The output buffer holds `m * n` elements and the launch configuration is
/// sized by `m * n`. If `small_matmul_kernel` cannot be compiled/loaded, the
/// call is delegated to `crate::blas::gpu_matmul_f32`.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `a` holds fewer than `m * k` elements or
///   `b` holds fewer than `k * n` elements (`a` = actual buffer length,
///   `b` = required length). The kernel receives only the dimensions, never
///   the buffer lengths, so this check is the only guard against reading past
///   the end of either input.
/// - Allocation, launch-configuration and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Reject undersized operands before anything touches device memory.
    let a_required = m * k;
    if a.len() < a_required {
        return Err(GpuError::LengthMismatch { a: a.len(), b: a_required });
    }
    let b_required = k * n;
    if b.len() < b_required {
        return Err(GpuError::LengthMismatch { a: b.len(), b: b_required });
    }

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Fall back to cuBLAS if our kernel can't compile.
            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
        }
    };

    let mut c = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    // NOTE(review): dimensions are narrowed with `as`; values above
    // `u32::MAX` would be silently truncated.
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: `a` and `b` were checked above to hold at least `m * k` and
    // `k * n` elements, and `c` was just allocated with `m * n` elements.
    // Argument order/types must match the `small_matmul_kernel` PTX entry
    // point (a, b, c, m, k, n, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
12627
/// Small batched matmul: `C[i] = A[i] @ B[i]` for `i in 0..batch`, where each
/// `A[i]` is `[M, K]`, each `B[i]` is `[K, N]` and each `C[i]` is `[M, N]`.
///
/// Dispatch:
/// - `batch == 1`: routed to [`gpu_small_matmul`], the custom PTX kernel that
///   avoids cuBLAS JIT overhead. This is the single-matrix decode case, e.g.
///   attention scores for one sequence.
/// - any other `batch`: delegated to `crate::blas::gpu_bmm_f32`. There is no
///   batched PTX kernel yet.
///
/// # Errors
///
/// Propagates whatever error the selected backend returns.
#[cfg(feature = "cuda")]
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // For batch=1, just use the single matmul kernel.
    if batch == 1 {
        return gpu_small_matmul(a, b, m, k, n, device);
    }
    // Batched case: cuBLAS. A batched PTX kernel would index
    // (batch_i, row, col) from a flat thread id, but has not been written.
    crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device)
}
12654
12655// ---------------------------------------------------------------------------
12656// Public API -- Embedding lookup (GPU-native)
12657// ---------------------------------------------------------------------------
12658
/// GPU embedding lookup: reads a token id from `idx` (a single f32 on the
/// GPU), gathers that row from `weight` `[V, D]`, and returns it as a new
/// `[D]` buffer.
///
/// On the kernel path the id is read on the device, so there is no host
/// round-trip. If `embed_lookup_kernel` cannot be compiled/loaded, both
/// buffers are copied to the host and the row is gathered there. The host
/// path converts the id with `as usize`, which truncates toward zero and maps
/// negative values and NaN to row 0.
///
/// # Errors
///
/// On the host fallback path, returns [`GpuError::LengthMismatch`] instead of
/// panicking when `idx` is empty (`a` = 0, `b` = 1), or when the selected row
/// runs past the end of `weight` (`a` = `weight.len()`, `b` = row end
/// offset). Transfer, allocation and launch errors are propagated.
///
/// NOTE(review): the kernel path performs no host-side bounds check on the
/// token id; confirm `embed_lookup_kernel` guards against rows beyond `V`.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: validate the id instead of panicking on a bad index.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let Some(&token) = idx_host.first() else {
                return Err(GpuError::LengthMismatch { a: 0, b: 1 });
            };
            // Saturating arithmetic so a huge id cannot overflow; the range
            // lookup below then rejects it.
            let start = (token as usize).saturating_mul(d);
            let end = start.saturating_add(d);
            let Some(row) = weight_host.get(start..end) else {
                return Err(GpuError::LengthMismatch {
                    a: weight_host.len(),
                    b: end,
                });
            };
            let out = row.to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // SAFETY: `out` was just allocated with `d` elements. Argument order/types
    // must match the `embed_lookup_kernel` PTX entry point (idx, weight, out, d).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12708
12709// ---------------------------------------------------------------------------
12710// Public API -- Slice write (for KV cache)
12711// ---------------------------------------------------------------------------
12712
/// Write `src` of shape `[N, D]` into row `pos` of `dst` of shape `[N, max_len, D]`.
///
/// For each batch `b < n_batch` and column `di < d`:
/// `dst[b * max_len * d + pos * d + di] = src[b * d + di]`. All other
/// elements of `dst` are left as they were.
///
/// On the GPU path `dst` is modified in place. If `slice_write_kernel` cannot
/// be compiled/loaded, the update is performed on the host and `*dst` is
/// replaced by a freshly uploaded buffer with the updated contents.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] without touching `dst` when:
/// - `pos >= max_len` (`a` = `pos`, `b` = `max_len`), because the row would
///   land in the next batch's slot or past the end of `dst`;
/// - `src.len() < n_batch * d` or `dst.len() < n_batch * max_len * d`
///   (`a` = actual length, `b` = required length).
///
/// Transfer, allocation and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    // The kernel receives only sizes, never buffer lengths, so validate here.
    // A bad `pos` or an undersized buffer would otherwise read or write out of
    // bounds on the device, or panic or corrupt data on the host path.
    if pos >= max_len {
        return Err(GpuError::LengthMismatch { a: pos, b: max_len });
    }
    let total = n_batch * d;
    if src.len() < total {
        return Err(GpuError::LengthMismatch { a: src.len(), b: total });
    }
    let dst_required = n_batch * max_len * d;
    if dst.len() < dst_required {
        return Err(GpuError::LengthMismatch {
            a: dst.len(),
            b: dst_required,
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy each batch's `d`-wide row into slot `pos`.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                let dst_start = b * max_len * d + pos * d;
                dst_host[dst_start..dst_start + d]
                    .copy_from_slice(&src_host[b * d..(b + 1) * d]);
            }
            *dst = cpu_to_gpu(&dst_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    // The first size argument is the total element count `n_batch * d`.
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // SAFETY: `src` holds at least `n_batch * d` elements, `dst` holds at least
    // `n_batch * max_len * d`, and `pos < max_len`, all checked above.
    // Argument order/types must match the `slice_write_kernel` PTX entry point.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
12773
12774// ---------------------------------------------------------------------------
12775// Public API -- Slice read (for KV cache)
12776// ---------------------------------------------------------------------------
12777
/// Read first `len` rows from each batch of `[N, max_len, D]` → `[N, len, D]`.
///
/// For each batch `b < n_batch`, row `r < len` and column `di < d`:
/// `out[b * len * d + r * d + di] = src[b * max_len * d + r * d + di]`.
///
/// If `slice_read_kernel` cannot be compiled/loaded, the copy is done on the
/// host and the result is uploaded.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] when:
/// - `len > max_len` (`a` = `len`, `b` = `max_len`), because rows past
///   `max_len` belong to the next batch;
/// - `src.len() < n_batch * max_len * d` (`a` = actual length,
///   `b` = required length).
///
/// Transfer, allocation and launch errors are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // The kernel receives only sizes, never the length of `src`, so validate
    // here to rule out out-of-bounds device reads (or host-side panics).
    if len > max_len {
        return Err(GpuError::LengthMismatch { a: len, b: max_len });
    }
    let src_required = n_batch * max_len * d;
    if src.len() < src_required {
        return Err(GpuError::LengthMismatch {
            a: src.len(),
            b: src_required,
        });
    }

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: the first `len` rows of a batch are one contiguous
            // run of `len * d` elements, so copy each run in a single call.
            let host = gpu_to_cpu(src, device)?;
            let run = len * d;
            let mut out: Vec<f32> = Vec::with_capacity(total);
            for b in 0..n_batch {
                let start = b * max_len * d;
                out.extend_from_slice(&host[start..start + run]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // SAFETY: `src` holds at least `n_batch * max_len * d` elements and
    // `len <= max_len` (checked above); `out` was just allocated with `total`
    // elements. Argument order/types must match the `slice_read_kernel` PTX.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12836
12837// ---------------------------------------------------------------------------
12838// Public API -- GELU
12839// ---------------------------------------------------------------------------
12840
/// GELU activation, sigmoid approximation: `gelu(x) ≈ x * sigmoid(1.702 * x)`.
///
/// Launches `gelu_kernel` from `GELU_PTX`. If that launch produces no
/// output, the same formula is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let gate = 1.0 / (1.0 + (-1.702 * x).exp());
            x * gate
        }),
    }
}
12853
/// GELU activation, tanh approximation:
/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
///
/// Equivalent to PyTorch `nn.GELU(approximate="tanh")`. Launches
/// `gelu_tanh_kernel`; if that produces no output, the expression is
/// evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const CUBIC: f32 = 0.044715;

    validate_unary(input, device)?;
    let launched = try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let cubic_term = CUBIC * x * x * x;
            let u = SQRT_2_OVER_PI * (x + cubic_term);
            0.5 * x * (1.0 + u.tanh())
        }),
    }
}
12871
/// GELU activation with erf: `gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2)))`.
///
/// Corresponds to PyTorch `nn.GELU(approximate="none")` (the default). The
/// host fallback computes `erf` with the Abramowitz & Stegun 7.1.26 rational
/// approximation rather than an exact erf.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    let Some(out) = try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? else {
        return cpu_fallback_unary(input, device, |x| {
            // Abramowitz & Stegun 7.1.26 erf approximation (matches PTX kernel).
            const P: f32 = 0.3275911;
            // a1..a5, innermost Horner coefficient last.
            const A: [f32; 5] = [
                0.254_829_6,
                -0.284_496_72,
                1.421_413_8,
                -1.453_152_1,
                1.061_405_4,
            ];
            let z = x * std::f32::consts::FRAC_1_SQRT_2;
            let az = z.abs();
            let t = 1.0 / (1.0 + P * az);
            // Horner: a4 + t*a5, then a3 + t*(..), ..., a1 + t*(..).
            let nested = A[..4].iter().rev().fold(A[4], |acc, &coef| coef + t * acc);
            let poly = t * nested;
            let erf_abs = 1.0 - poly * (-az * az).exp();
            let erf_val = if z < 0.0 { -erf_abs } else { erf_abs };
            x * 0.5 * (1.0 + erf_val)
        });
    };
    Ok(out)
}
12893
/// Backward pass of the tanh-approximated GELU.
///
/// With `u = sqrt(2/π) * (x + 0.044715 * x³)` and `t = tanh(u)`:
/// `d/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t²) * sqrt(2/π) * (1 + 3*0.044715*x²)`.
/// Each output element is `grad[i] * d/dx` evaluated at `input[i]`.
///
/// Launches `gelu_backward_tanh_kernel`. If that produces no output, both
/// buffers are downloaded and the formula is applied on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // sqrt(2/π), the cubic coefficient, and three times the cubic coefficient.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const CUBIC: f32 = 0.044715;
    const CUBIC_X3: f32 = 0.134145;

    validate_binary(grad, input, device)?;
    let launched = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let grads = gpu_to_cpu(grad, device)?;
    let xs = gpu_to_cpu(input, device)?;
    let mut dx: Vec<f32> = Vec::with_capacity(grads.len());
    for (&g, &x) in grads.iter().zip(xs.iter()) {
        let t = (SQRT_2_OVER_PI * (x + CUBIC * x * x * x)).tanh();
        let sech_sq = 1.0 - t * t;
        let du_dx = SQRT_2_OVER_PI * (1.0 + CUBIC_X3 * x * x);
        dx.push(g * (0.5 * (1.0 + t) + 0.5 * x * sech_sq * du_dx));
    }
    cpu_to_gpu(&dx, device)
}
12932
12933// ---------------------------------------------------------------------------
12934// Public API -- SiLU (Swish)
12935// ---------------------------------------------------------------------------
12936
/// SiLU (Swish) activation: `silu(x) = x * sigmoid(x)`.
///
/// Launches `silu_kernel`. If that produces no output, the same expression
/// runs on the host.
#[cfg(feature = "cuda")]
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    try_launch_unary(input, device, SILU_PTX, "silu_kernel")?.map_or_else(
        || cpu_fallback_unary(input, device, |x| x * (1.0 / (1.0 + (-x).exp()))),
        Ok,
    )
}
12949
/// Gradient of SiLU: `out[i] = grad[i] * (s + x * s * (1 - s))`, where
/// `x = input[i]` and `s = sigmoid(x)`.
///
/// Launches `silu_backward_kernel`. If that produces no output, both buffers
/// are downloaded and the expression is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    match try_launch_binary(grad, input, device, SILU_BACKWARD_PTX, "silu_backward_kernel")? {
        Some(out) => Ok(out),
        None => {
            // Host fallback.
            let grads = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let mut dx: Vec<f32> = Vec::with_capacity(grads.len());
            for (&g, &x) in grads.iter().zip(xs.iter()) {
                let s = 1.0 / (1.0 + (-x).exp());
                dx.push(g * (s + x * s * (1.0 - s)));
            }
            cpu_to_gpu(&dx, device)
        }
    }
}
12983
12984// ---------------------------------------------------------------------------
12985// Public API -- ELU
12986// ---------------------------------------------------------------------------
12987
/// ELU activation: `elu(x) = x` for `x > 0`, otherwise `alpha * (exp(x) - 1)`.
///
/// `elu_kernel` takes an extra `alpha` argument, so this function performs
/// its own launch instead of using the shared unary helper. If the kernel
/// cannot be compiled/loaded, the activation is computed on the host.
#[cfg(feature = "cuda")]
pub fn gpu_elu(
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ELU_PTX,
        "elu_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let elu = |x: f32| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) };
        let values = gpu_to_cpu(input, device)?;
        let activated: Vec<f32> = values.iter().map(|&x| elu(x)).collect();
        return cpu_to_gpu(&activated, device);
    };

    let mut output = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `output` was just allocated with `len` elements. Argument
    // order/types must match the `elu_kernel` PTX entry point
    // (input, out, n, alpha).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(output)
}
13038
/// Gradient of ELU: `out[i] = grad[i]` where `input[i] > 0`, otherwise
/// `grad[i] * alpha * exp(input[i])`.
///
/// `elu_backward_kernel` takes an extra `alpha` argument, so this function
/// performs its own launch. If the kernel cannot be compiled/loaded, the
/// gradient is computed on the host.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, input, device)?;

    let len = grad.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_PTX,
        "elu_backward_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let grads = gpu_to_cpu(grad, device)?;
        let xs = gpu_to_cpu(input, device)?;
        let mut dx: Vec<f32> = Vec::with_capacity(grads.len());
        for (&g, &x) in grads.iter().zip(xs.iter()) {
            dx.push(if x > 0.0 { g } else { g * alpha * x.exp() });
        }
        return cpu_to_gpu(&dx, device);
    };

    let mut output = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `output` was just allocated with `len` elements. Argument
    // order/types must match the `elu_backward_kernel` PTX entry point
    // (grad, input, out, n, alpha).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad.inner())
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(output)
}
13093
13094// ---------------------------------------------------------------------------
13095// Public API -- Mish
13096// ---------------------------------------------------------------------------
13097
/// Mish activation: `mish(x) = x * tanh(softplus(x))`.
///
/// Launches `mish_kernel`. If that produces no output, the host computes
/// `softplus(x)` as `ln(1 + exp(x))`, or as `x` itself once `x > 20`.
#[cfg(feature = "cuda")]
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, MISH_PTX, "mish_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
            x * softplus.tanh()
        }),
    }
}
13110
/// Gradient of Mish:
/// `out[i] = grad[i] * (tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2))`,
/// where `x = input[i]` and `sp = softplus(x) = ln(1 + exp(x))`. On the host
/// path `sp` is taken to be `x` when `x > 20`.
///
/// Launches `mish_backward_kernel`. If that produces no output, both buffers
/// are downloaded and the formula is applied on the host.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    let launched = try_launch_binary(
        grad,
        input,
        device,
        MISH_BACKWARD_PTX,
        "mish_backward_kernel",
    )?;
    if let Some(out) = launched {
        return Ok(out);
    }

    // Host fallback.
    let grads = gpu_to_cpu(grad, device)?;
    let xs = gpu_to_cpu(input, device)?;
    let mut dx: Vec<f32> = Vec::with_capacity(grads.len());
    for (&g, &x) in grads.iter().zip(xs.iter()) {
        let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
        let tsp = softplus.tanh();
        let sigmoid = 1.0 / (1.0 + (-x).exp());
        dx.push(g * (tsp + x * sigmoid * (1.0 - tsp * tsp)));
    }
    cpu_to_gpu(&dx, device)
}
13147
/// Elementwise clamp: `out[i] = min(max(x[i], min_val), max_val)`.
///
/// The lower bound is applied first and the upper bound second, which is
/// exactly what the host fallback computes with `f32::max` then `f32::min`.
/// Consequences:
/// - if `min_val > max_val`, every element becomes `max_val`, the same
///   convention as PyTorch `torch.clamp`;
/// - `f32::max` returns the non-NaN operand, so on the host path a NaN
///   element becomes `min(min_val, max_val)` rather than propagating.
///
/// NOTE(review): confirm `clamp_kernel` applies the bounds in the same order
/// and with the same NaN behaviour, so the GPU and host paths agree.
///
/// Uses a custom launch because the kernel takes two extra f32 parameters.
#[cfg(feature = "cuda")]
pub fn gpu_clamp(
    input: &CudaBuffer<f32>,
    min_val: f32,
    max_val: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CLAMP_PTX,
        "clamp_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: lower bound first, then upper bound.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f32> = host
                .iter()
                .map(|&x| x.max(min_val).min(max_val))
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `out` was just allocated with `n` elements. Argument order/types
    // must match the `clamp_kernel` PTX entry point (input, out, n, min, max).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&min_val)
            .arg(&max_val)
            .launch(cfg)?;
    }

    Ok(out)
}
13200
13201// ---------------------------------------------------------------------------
13202// Public API -- elementwise transcendentals & math ops
13203// ---------------------------------------------------------------------------
13204
/// Elementwise quotient of two equally sized buffers: `out[i] = a[i] / b[i]`.
///
/// Launches `div_kernel`. If that produces no output, both operands are
/// downloaded and divided on the host.
#[cfg(feature = "cuda")]
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        Some(out) => Ok(out),
        None => {
            // Host fallback.
            let numerators = gpu_to_cpu(a, device)?;
            let denominators = gpu_to_cpu(b, device)?;
            let quotients: Vec<f32> = numerators
                .iter()
                .zip(denominators.iter())
                .map(|(&num, &den)| num / den)
                .collect();
            cpu_to_gpu(&quotients, device)
        }
    }
}
13228
/// Elementwise natural exponential: `out[i] = e^a[i]`.
///
/// Launches `exp_kernel`, otherwise falls back to `f32::exp` on the host.
#[cfg(feature = "cuda")]
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |v| v.exp()),
    }
}
13238
/// Elementwise natural logarithm: `out[i] = ln(a[i])`.
///
/// Launches `log_kernel`, otherwise falls back to `f32::ln` on the host.
#[cfg(feature = "cuda")]
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    try_launch_unary(a, device, LOG_PTX, "log_kernel")?
        .map_or_else(|| cpu_fallback_unary(a, device, |v| v.ln()), Ok)
}
13248
/// Elementwise square root: `out[i] = √a[i]`.
///
/// Launches `sqrt_kernel`, otherwise falls back to `f32::sqrt` on the host.
#[cfg(feature = "cuda")]
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    let launched = try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |v| v.sqrt()),
    }
}
13258
/// Raise every element to a fixed power: `out[i] = a[i] ^ exponent`.
///
/// `pow_kernel` takes the scalar `exponent` as an extra argument, so this
/// function performs its own launch. If the kernel cannot be compiled/loaded,
/// `f32::powf` is applied on the host.
#[cfg(feature = "cuda")]
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let bases = gpu_to_cpu(a, device)?;
        let powers: Vec<f32> = bases.iter().map(|&x| x.powf(exponent)).collect();
        return cpu_to_gpu(&powers, device);
    };

    let mut output = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `output` was just allocated with `len` elements. Argument
    // order/types must match the `pow_kernel` PTX entry point; note that the
    // exponent precedes the length.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(output.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(output)
}
13304
/// Elementwise magnitude: `out[i] = |a[i]|`.
///
/// Launches `abs_kernel`, otherwise falls back to `f32::abs` on the host.
#[cfg(feature = "cuda")]
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    try_launch_unary(a, device, ABS_PTX, "abs_kernel")?
        .map_or_else(|| cpu_fallback_unary(a, device, |v| v.abs()), Ok)
}
13314
/// Elementwise logistic function: `out[i] = 1 / (1 + e^(-a[i]))`.
///
/// Launches `sigmoid_kernel`. If that produces no output, the same formula
/// is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |v| {
            let denom = 1.0 + (-v).exp();
            1.0 / denom
        }),
    }
}
13324
/// Elementwise hyperbolic tangent: `out[i] = tanh(a[i])`.
///
/// Launches `tanh_kernel`, otherwise falls back to `f32::tanh` on the host.
#[cfg(feature = "cuda")]
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    let Some(out) = try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? else {
        return cpu_fallback_unary(a, device, |v| v.tanh());
    };
    Ok(out)
}
13334
13335// ---------------------------------------------------------------------------
13336// Public API -- f64 elementwise ops
13337// ---------------------------------------------------------------------------
13338
/// Elementwise sum of two f64 buffers: `out[i] = a[i] + b[i]`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `ADD_PTX` source,
/// using the function-local `OnceLock` as its cache. If the launch produces
/// no output, the sum is computed on the host.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] when `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static ADD_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }
    let ptx = get_f64_ptx(&ADD_F64_SRC, ADD_PTX, "add_kernel", "add_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "add_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x + y),
    }
}
13356
/// Elementwise difference of two f64 buffers: `out[i] = a[i] - b[i]`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `SUB_PTX` source,
/// using the function-local `OnceLock` as its cache. If the launch produces
/// no output, the difference is computed on the host.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] when `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static SUB_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }
    let ptx = get_f64_ptx(&SUB_F64_SRC, SUB_PTX, "sub_kernel", "sub_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "sub_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x - y),
    }
}
13374
/// Elementwise product of two f64 buffers: `out[i] = a[i] * b[i]`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `MUL_PTX` source,
/// using the function-local `OnceLock` as its cache. If the launch produces
/// no output, the product is computed on the host.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] when `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static MUL_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }
    let ptx = get_f64_ptx(&MUL_F64_SRC, MUL_PTX, "mul_kernel", "mul_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "mul_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x * y),
    }
}
13392
/// Elementwise quotient of two f64 buffers: `out[i] = a[i] / b[i]`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `DIV_PTX` source,
/// using the function-local `OnceLock` as its cache. If the launch produces
/// no output, the quotient is computed on the host.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] when `a` and `b` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static DIV_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }
    let ptx = get_f64_ptx(&DIV_F64_SRC, DIV_PTX, "div_kernel", "div_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "div_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x / y),
    }
}
13410
/// Elementwise negation of an f64 buffer: `out[i] = -a[i]`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `NEG_PTX` source,
/// using the function-local `OnceLock` as its cache; falls back to the host.
#[cfg(feature = "cuda")]
pub fn gpu_neg_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static NEG_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&NEG_F64_SRC, NEG_PTX, "neg_kernel", "neg_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "neg_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| -x),
    }
}
13421
/// Rectified linear unit on an f64 buffer: `out[i] = max(a[i], 0.0)`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `RELU_PTX` source,
/// using the function-local `OnceLock` as its cache; falls back to the host.
#[cfg(feature = "cuda")]
pub fn gpu_relu_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static RELU_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&RELU_F64_SRC, RELU_PTX, "relu_kernel", "relu_f64_kernel");
    let Some(out) = try_launch_unary_f64(a, device, ptx, "relu_f64_kernel")? else {
        return cpu_fallback_unary_f64(a, device, |x| x.max(0.0));
    };
    Ok(out)
}
13432
/// Multiply every element of an f64 buffer by `scalar`: `out[i] = a[i] * scalar`.
///
/// The f64 PTX is produced by `get_f64_ptx` from the f32 `SCALE_PTX` source,
/// using the function-local `OnceLock` as its cache. If the kernel cannot be
/// compiled/loaded, the product is computed on the host instead.
#[cfg(feature = "cuda")]
pub fn gpu_scale_f64(
    a: &CudaBuffer<f64>,
    scalar: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static SCALE_F64_SRC: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&SCALE_F64_SRC, SCALE_PTX, "scale_kernel", "scale_f64_kernel");
    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scale_f64_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let values = gpu_to_cpu(a, device)?;
        let scaled: Vec<f64> = values.iter().map(|&x| x * scalar).collect();
        return cpu_to_gpu(&scaled, device);
    };

    let mut output = alloc_zeros_f64(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `output` was just allocated with `len` elements. Argument
    // order/types must match the generated `scale_f64_kernel` entry point
    // (input, out, scalar, n).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(output.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }
    Ok(output)
}
13471
/// Elementwise f64 exp: `out[i] = exp(a[i])`.
///
/// Uses the dedicated `EXP_F64_PTX` module rather than a mechanically
/// converted f32 kernel. Presumably this is because the f32 version relies on
/// transcendental approximations that `ptx_f32_to_f64` cannot translate.
/// When `try_launch_unary_f64` returns `None`, the host computes [`f64::exp`].
#[cfg(feature = "cuda")]
pub fn gpu_exp_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, EXP_F64_PTX, "exp_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| x.exp()),
    }
}
13480
/// Elementwise f64 log: `out[i] = ln(a[i])`.
///
/// Uses the dedicated `LOG_F64_PTX` module. Presumably this is because the f32
/// kernel's transcendental instructions cannot be converted by
/// `ptx_f32_to_f64`. When `try_launch_unary_f64` returns `None`, the host
/// computes [`f64::ln`].
#[cfg(feature = "cuda")]
pub fn gpu_log_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, LOG_F64_PTX, "log_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| x.ln()),
    }
}
13489
/// Elementwise f64 sqrt: `out[i] = sqrt(a[i])`.
///
/// The f64 kernel is derived from the f32 `SQRT_PTX` by `get_f64_ptx`, using
/// this function's `CACHE`. When `try_launch_unary_f64` returns `None`, the
/// host computes [`f64::sqrt`].
#[cfg(feature = "cuda")]
pub fn gpu_sqrt_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "sqrt_f64_kernel";

    let ptx = get_f64_ptx(&CACHE, SQRT_PTX, "sqrt_kernel", KERNEL);
    match try_launch_unary_f64(a, device, ptx, KERNEL)? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| x.sqrt()),
    }
}
13500
/// Elementwise f64 pow: `out[i] = a[i] ^ exponent`.
///
/// Runs the dedicated `POW_F64_PTX` kernel (`pow_f64_kernel`), passing
/// `exponent` as a scalar kernel argument. If `module_cache::get_or_compile`
/// cannot provide the kernel, the data is round-tripped through the host and
/// raised with [`f64::powf`].
///
/// # Errors
///
/// Propagates allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_pow_f64(
    a: &CudaBuffer<f64>,
    exponent: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        POW_F64_PTX,
        "pow_f64_kernel",
        device.ordinal() as u32,
    );

    let Ok(kernel) = compiled else {
        let powered: Vec<f64> = gpu_to_cpu(a, device)?
            .iter()
            .map(|&x| x.powf(exponent))
            .collect();
        return cpu_to_gpu(&powered, device);
    };

    let stream = device.stream();
    let mut result = alloc_zeros_f64(len, device)?;
    let launch = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `a` and `result` both hold `len` elements, and the pushed
    // arguments follow the kernel's parameter order
    // (input, output, exponent, count).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(result.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(launch)?;
    }
    Ok(result)
}
13537
/// Elementwise f64 abs: `out[i] = |a[i]|`.
///
/// The f64 kernel is derived from the f32 `ABS_PTX` by `get_f64_ptx`, using
/// this function's `CACHE`. When `try_launch_unary_f64` returns `None`, the
/// host computes [`f64::abs`].
#[cfg(feature = "cuda")]
pub fn gpu_abs_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "abs_f64_kernel";

    let ptx = get_f64_ptx(&CACHE, ABS_PTX, "abs_kernel", KERNEL);
    match try_launch_unary_f64(a, device, ptx, KERNEL)? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| x.abs()),
    }
}
13548
/// Elementwise f64 sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
///
/// Uses the dedicated `SIGMOID_F64_PTX` module. When `try_launch_unary_f64`
/// returns `None`, the host evaluates the same formula.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, SIGMOID_F64_PTX, "sigmoid_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
13557
/// Elementwise f64 tanh: `out[i] = tanh(a[i])`.
///
/// Uses the dedicated `TANH_F64_PTX` module. When `try_launch_unary_f64`
/// returns `None`, the host computes [`f64::tanh`].
#[cfg(feature = "cuda")]
pub fn gpu_tanh_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, TANH_F64_PTX, "tanh_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| x.tanh()),
    }
}
13566
13567// ---------------------------------------------------------------------------
13568// Public API -- f64 backward ops
13569// ---------------------------------------------------------------------------
13570
/// ReLU backward (f64): `out[i] = (input[i] > 0) ? grad[i] : 0`.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] (`a` = `grad.len()`, `b` =
/// `input.len()`) when the buffers differ in length. Otherwise, errors from
/// the kernel launch or the host fallback are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "relu_backward_f64_kernel";

    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }

    let ptx = get_f64_ptx(&CACHE, RELU_BACKWARD_PTX, "relu_backward_kernel", KERNEL);
    match try_launch_binary_f64(grad, input, device, ptx, KERNEL)? {
        Some(out) => Ok(out),
        None => {
            cpu_fallback_binary_f64(grad, input, device, |g, x| if x > 0.0 { g } else { 0.0 })
        }
    }
}
13594
/// Sigmoid backward (f64): `out[i] = grad[i] * output[i] * (1 - output[i])`.
///
/// `output` is the forward sigmoid result, so no exponential is recomputed.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] (`a` = `grad.len()`, `b` =
/// `output.len()`) when the buffers differ in length. Otherwise, errors from
/// the kernel launch or the host fallback are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "sigmoid_backward_f64_kernel";

    let (grad_len, output_len) = (grad.len(), output.len());
    if grad_len != output_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: output_len });
    }

    let ptx = get_f64_ptx(&CACHE, SIGMOID_BACKWARD_PTX, "sigmoid_backward_kernel", KERNEL);
    match try_launch_binary_f64(grad, output, device, ptx, KERNEL)? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * o * (1.0 - o)),
    }
}
13618
/// Tanh backward (f64): `out[i] = grad[i] * (1 - output[i]^2)`.
///
/// `output` is the forward tanh result.
///
/// # Errors
///
/// Returns [`GpuError::LengthMismatch`] (`a` = `grad.len()`, `b` =
/// `output.len()`) when the buffers differ in length. Otherwise, errors from
/// the kernel launch or the host fallback are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "tanh_backward_f64_kernel";

    let (grad_len, output_len) = (grad.len(), output.len());
    if grad_len != output_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: output_len });
    }

    let ptx = get_f64_ptx(&CACHE, TANH_BACKWARD_PTX, "tanh_backward_kernel", KERNEL);
    match try_launch_binary_f64(grad, output, device, ptx, KERNEL)? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * (1.0 - o * o)),
    }
}
13642
13643// ---------------------------------------------------------------------------
13644// Public API -- f64 broadcast ops
13645// ---------------------------------------------------------------------------
13646
/// Broadcast addition (f64): `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
///
/// Input strides come from `broadcast_strides` against `out_shape`. The result
/// holds `out_shape.iter().product()` elements. When
/// `try_launch_broadcast_binary_f64` returns `None`, the host computes the
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "broadcast_add_f64_kernel";

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_ADD_PTX, "broadcast_add_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &strides_a, &strides_b, &dims, numel, device, ptx, KERNEL,
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x + y,
        ),
    }
}
13680
/// Broadcast subtraction (f64): `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
///
/// Input strides come from `broadcast_strides` against `out_shape`. The result
/// holds `out_shape.iter().product()` elements. When
/// `try_launch_broadcast_binary_f64` returns `None`, the host computes the
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "broadcast_sub_f64_kernel";

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_SUB_PTX, "broadcast_sub_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &strides_a, &strides_b, &dims, numel, device, ptx, KERNEL,
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x - y,
        ),
    }
}
13714
/// Broadcast multiplication (f64): `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
///
/// Input strides come from `broadcast_strides` against `out_shape`. The result
/// holds `out_shape.iter().product()` elements. When
/// `try_launch_broadcast_binary_f64` returns `None`, the host computes the
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "broadcast_mul_f64_kernel";

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_MUL_PTX, "broadcast_mul_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &strides_a, &strides_b, &dims, numel, device, ptx, KERNEL,
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x * y,
        ),
    }
}
13748
/// Broadcast division (f64): `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
///
/// Input strides come from `broadcast_strides` against `out_shape`. The result
/// holds `out_shape.iter().product()` elements. When
/// `try_launch_broadcast_binary_f64` returns `None`, the host computes the
/// result.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    const KERNEL: &str = "broadcast_div_f64_kernel";

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_DIV_PTX, "broadcast_div_kernel", KERNEL);
    let launched = try_launch_broadcast_binary_f64(
        a, b, &strides_a, &strides_b, &dims, numel, device, ptx, KERNEL,
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x / y,
        ),
    }
}
13782
13783// ---------------------------------------------------------------------------
13784// Public API -- f64 reduction ops
13785// ---------------------------------------------------------------------------
13786
/// Full reduce-sum for f64: returns a 1-element buffer containing the sum of
/// all elements of `a` (`[0.0]` for an empty buffer).
///
/// Strategy:
/// 1. Launch `reduce_sum_f64_kernel` (derived from `REDUCE_SUM_PTX`) with
///    blocks of 256 threads. The grid is capped at 1024 blocks, and the result
///    buffer holds one partial sum per block. The kernel presumably loops over
///    the input so a capped grid still covers all `n` elements — TODO confirm
///    against `REDUCE_SUM_PTX`.
/// 2. With a single partial, it is returned directly. With up to 256
///    partials, they are summed on the host. With more, the function recurses
///    on the partials buffer.
///
/// The host path (download, sum, upload) is used when the kernel cannot be
/// compiled or loaded. It is also used when `a.len()` exceeds `u32::MAX`,
/// because the kernel receives its element count as a `u32`.
///
/// # Errors
///
/// Propagates device-validation, allocation, transfer, and launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    const BLOCK: u32 = 256;
    const MAX_BLOCKS: u32 = 1024;
    // Partial counts up to this limit are finished on the host.
    const HOST_FINISH_LIMIT: u32 = 256;

    let n = a.len();
    if n == 0 {
        return cpu_to_gpu(&[0.0f64], device);
    }
    validate_device(a, device)?;

    // Host-side reduction: download everything, sum, and upload one element.
    let host_sum = || -> GpuResult<CudaBuffer<f64>> {
        let host = gpu_to_cpu(a, device)?;
        let total: f64 = host.iter().sum();
        cpu_to_gpu(&[total], device)
    };

    // The kernel's element count is a `u32`. `n as u32` would silently wrap
    // larger lengths and drop elements from the sum, so reduce such buffers on
    // the host instead.
    if n > u32::MAX as usize {
        return host_sum();
    }
    let n_u32 = n as u32;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, REDUCE_SUM_PTX, "reduce_sum_kernel", "reduce_sum_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "reduce_sum_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return host_sum(),
    };

    let num_blocks = (n_u32.saturating_add(BLOCK - 1) / BLOCK).min(MAX_BLOCKS);
    let mut partials = alloc_zeros_f64(num_blocks as usize, device)?;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };

    // SAFETY: `a` holds `n_u32` elements and `partials` holds one slot per
    // launched block. The pushed arguments follow the kernel's parameter order
    // (input, partials, count).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    if num_blocks <= 1 {
        return Ok(partials);
    }

    if num_blocks <= HOST_FINISH_LIMIT {
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f64 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    gpu_reduce_sum_f64(&partials, device)
}
13853
/// Sum along an axis for f64.
///
/// `a` is read as a contiguous row-major `[outer, axis_size, inner]` tensor
/// and reduced over the middle dimension, producing `outer * inner` values:
/// `out[o * inner + i] = sum_{k < axis_size} a[(o * axis_size + k) * inner + i]`.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device`.
/// * [`GpuError::LengthMismatch`] (`a` = `a.len()`, `b` = the required length)
///   when `a` holds fewer than `outer * axis_size * inner` elements.
/// * Allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis_f64(
    a: &CudaBuffer<f64>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(a, device)?;

    // The host fallback indexes every element below this bound, and the kernel
    // receives the same dimensions. A shorter buffer would be read out of
    // bounds on the GPU or panic on the host, so reject it up front.
    let required = outer * axis_size * inner;
    if a.len() < required {
        return Err(GpuError::LengthMismatch { a: a.len(), b: required });
    }

    let total_output = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SUM_AXIS_PTX, "sum_axis_kernel", "sum_axis_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "sum_axis_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f64; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let mut sum = 0.0f64;
                for k in 0..axis_size {
                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
                }
                *out = sum;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    // SAFETY: `a` was checked to hold at least `outer * axis_size * inner`
    // elements, and `out` holds `outer * inner`. The pushed arguments follow
    // the kernel's parameter order.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13915
13916#[cfg(not(feature = "cuda"))]
13917pub fn gpu_reduce_sum_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13918#[cfg(not(feature = "cuda"))]
13919pub fn gpu_sum_axis_f64(_a: &CudaBuffer<f64>, _outer: usize, _axis_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13920
13921// ---------------------------------------------------------------------------
13922// Public API -- f64 shape ops
13923// ---------------------------------------------------------------------------
13924
/// Transpose an `[M, N]` f64 matrix to `[N, M]` on GPU.
///
/// Both matrices are contiguous row-major: `out[j * m + i] = input[i * n + j]`.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device`.
/// * [`GpuError::LengthMismatch`] (`a` = `input.len()`, `b` = `m * n`) when
///   `input` holds fewer than `m * n` elements. Every one of those elements is
///   read, so a short buffer would be read out of bounds.
/// * Allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_f64(
    input: &CudaBuffer<f64>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let total = m * n;
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, TRANSPOSE_2D_PTX, "transpose_2d_kernel", "transpose_2d_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "transpose_2d_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for i in 0..m {
                for j in 0..n {
                    out[j * m + i] = host[i * n + j];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: `input` was checked to hold at least `m * n` elements and `out`
    // holds exactly `m * n`. The pushed arguments follow the kernel's
    // parameter order.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13981
/// Permute a 4D f64 tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]` on GPU.
///
/// Both layouts are contiguous row-major: element `(i0, i1, i2, i3)` of the
/// input lands at position `(i0, i2, i1, i3)` of the output.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device`.
/// * [`GpuError::LengthMismatch`] (`a` = `input.len()`, `b` = `d0*d1*d2*d3`)
///   when `input` is shorter than the described tensor. Every element is
///   read, so a short buffer would be read out of bounds.
/// * Allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_f64(
    input: &CudaBuffer<f64>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let total = d0 * d1 * d2 * d3;
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, PERMUTE_0213_PTX, "permute_0213_kernel", "permute_0213_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "permute_0213_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for i0 in 0..d0 {
                for i1 in 0..d1 {
                    for i2 in 0..d2 {
                        for i3 in 0..d3 {
                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                            out[out_idx] = host[in_idx];
                        }
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    // SAFETY: `input` was checked to hold at least `total` elements and `out`
    // holds exactly `total`. The pushed arguments follow the kernel's
    // parameter order.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14050
/// Split a contiguous f64 tensor along an axis (strided read) on GPU.
///
/// The input is viewed as `[outer, total_along_axis, inner_size]`. The
/// `n`-element result gathers axis positions
/// `split_offset .. split_offset + split_size` from every outer slice:
/// with `chunk = split_size * inner_size`,
/// `out[i] = input[(i / chunk) * total_along_axis * inner_size
///                 + split_offset * inner_size + i % chunk]`.
///
/// # Errors
///
/// Propagates device-validation, allocation, launch-configuration, launch,
/// and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_strided_split_f64(
    input: &CudaBuffer<f64>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, STRIDED_SPLIT_PTX, "strided_split_kernel", "strided_split_f64_kernel");
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        ptx,
        "strided_split_f64_kernel",
        device.ordinal() as u32,
    );

    let Ok(kernel) = compiled else {
        let host = gpu_to_cpu(input, device)?;
        let chunk = split_size * inner_size;
        let stride = total_along_axis * inner_size;
        let base = split_offset * inner_size;
        let gathered: Vec<f64> = (0..n)
            .map(|i| host[(i / chunk) * stride + base + i % chunk])
            .collect();
        return cpu_to_gpu(&gathered, device);
    };

    let mut result = alloc_zeros_f64(n, device)?;
    let launch = launch_cfg(n)?;
    let (total_ax_u32, offset_u32, split_sz_u32, inner_u32, n_u32) = (
        total_along_axis as u32,
        split_offset as u32,
        split_size as u32,
        inner_size as u32,
        n as u32,
    );

    // SAFETY: `result` holds `n` elements, and the pushed arguments follow the
    // kernel's parameter order. The caller must describe a layout that lies
    // inside `input`.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(result.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(launch)?;
    }

    Ok(result)
}
14115
/// Concatenate an f64 sub-tensor into a larger output at an axis offset on GPU.
///
/// `output` is viewed as `[outer, total_along_axis, inner_size]`. The first
/// `n` elements of `input` are written into axis positions
/// `cat_offset .. cat_offset + part_size` of every outer slice:
/// with `chunk = part_size * inner_size`,
/// `output[(i / chunk) * total_along_axis * inner_size
///         + cat_offset * inner_size + i % chunk] = input[i]`.
/// Other elements of `output` are left unchanged. On the host fallback path,
/// `*output` is replaced by a freshly uploaded buffer holding the updated data.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device` for both `input` and
///   `output`. The kernel writes into `output`, so it must live on `device`
///   just like `input`.
/// * Launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat_f64(
    input: &CudaBuffer<f64>,
    output: &mut CudaBuffer<f64>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;
    validate_device(&*output, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, STRIDED_CAT_PTX, "strided_cat_kernel", "strided_cat_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_cat_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: both buffers were validated to be on `device`, and the pushed
    // arguments follow the kernel's parameter order. The caller must describe
    // a layout that lies inside both buffers.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14182
14183// ---------------------------------------------------------------------------
14184// Public API -- f64 indexing ops
14185// ---------------------------------------------------------------------------
14186
/// Gather f64 elements by f32 index: `out[i] = input[indices[i]]`.
///
/// The result has `indices.len()` elements. Indices are stored as `f32`; the
/// host fallback converts each one with `as usize`, which truncates toward
/// zero and maps negative values and NaN to 0.
///
/// # Panics
///
/// The host fallback panics if a converted index is out of range for `input`.
///
/// # Errors
///
/// Propagates device-validation, allocation, launch-configuration, launch,
/// and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d_f64(
    input: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let count = indices.len();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, INDEX_SELECT_1D_PTX, "index_select_1d_kernel", "index_select_1d_f64_kernel");
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        ptx,
        "index_select_1d_f64_kernel",
        device.ordinal() as u32,
    );

    let Ok(kernel) = compiled else {
        let source = gpu_to_cpu(input, device)?;
        let positions = gpu_to_cpu(indices, device)?;
        let mut gathered: Vec<f64> = Vec::with_capacity(positions.len());
        for &pos in positions.iter() {
            gathered.push(source[pos as usize]);
        }
        return cpu_to_gpu(&gathered, device);
    };

    let mut result = alloc_zeros_f64(count, device)?;
    let launch = launch_cfg(count)?;
    let count_u32 = count as u32;

    // SAFETY: `result` holds one slot per index, and the pushed arguments
    // follow the kernel's parameter order (input, indices, output, count).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(result.inner_mut())
            .arg(&count_u32)
            .launch(launch)?;
    }

    Ok(result)
}
14238
/// Scatter-add f64 `grad_output` back using f32 `indices`.
///
/// Output: `out = zeros(input_len); for i: out[indices[i]] += grad_output[i]`
///
/// One index is consumed per element of `grad_output`. `indices` must
/// therefore hold at least `grad_output.len()` entries; any extra entries are
/// ignored on both the GPU and the host path. Indices are stored as `f32` and
/// converted with `as usize` on the host path.
///
/// # Panics
///
/// The host fallback panics if a converted index is `>= input_len`.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device`.
/// * [`GpuError::LengthMismatch`] (`a` = `grad_output.len()`, `b` =
///   `indices.len()`) when `indices` is shorter than `grad_output`. The kernel
///   is launched over `grad_output.len()` elements and would otherwise read
///   past the end of `indices`.
/// * Allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad_output, device)?;

    let n = grad_output.len();
    if indices.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: indices.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_1D_PTX, "scatter_add_1d_kernel", "scatter_add_1d_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_1d_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f64; input_len];
            // Zip so that exactly the first `n` indices are used, matching the
            // GPU launch over `grad_output.len()` elements.
            for (&g, &idx_f) in go_host.iter().zip(idx_host.iter()) {
                result[idx_f as usize] += g;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `indices` was checked to hold at least `n` entries, and `out`
    // holds `input_len` elements. The pushed arguments follow the kernel's
    // parameter order. In-range index values are the caller's responsibility.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14293
/// Fill f64 elements with `value` where u8 `mask` is nonzero.
///
/// `mask` is a GPU buffer of u8 values (nonzero = true) with at least one
/// entry per element of `input`.
/// Output: `out[i] = mask[i] != 0 ? value : input[i]`, for `i < input.len()`.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device`.
/// * [`GpuError::LengthMismatch`] (`a` = `input.len()`, `b` = `mask.len()`)
///   when `mask` is shorter than `input`. Without this check, the kernel
///   (launched over `input.len()` elements) would read past the end of the
///   mask, and the host fallback would silently return a truncated result.
/// * Allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill_f64(
    input: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    value: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n = input.len();
    if mask.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: mask.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, MASKED_FILL_PTX, "masked_fill_kernel", "masked_fill_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_fill_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f64> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m != 0 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `input` and `out` hold `n` elements and `mask` holds at least
    // `n`. The pushed arguments follow the kernel's parameter order.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14351
/// Zero out f64 gradient where u8 `mask` is nonzero.
///
/// Output: `out[i] = mask[i] != 0 ? 0.0 : grad[i]`, for `i < grad.len()`.
///
/// # Errors
///
/// * Device-mismatch errors from `validate_device`.
/// * [`GpuError::LengthMismatch`] (`a` = `grad.len()`, `b` = `mask.len()`)
///   when `mask` is shorter than `grad`. Without this check, the kernel
///   (launched over `grad.len()` elements) would read past the end of the
///   mask, and the host fallback would silently return a truncated result.
/// * Allocation, launch-configuration, launch, and transfer errors.
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero_f64(
    grad: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad, device)?;

    let n = grad.len();
    if mask.len() < n {
        return Err(GpuError::LengthMismatch { a: n, b: mask.len() });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, MASKED_ZERO_PTX, "masked_zero_kernel", "masked_zero_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_zero_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f64> = grad_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&g, &m)| if m != 0 { 0.0 } else { g })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `grad` and `out` hold `n` elements and `mask` holds at least
    // `n`. The pushed arguments follow the kernel's parameter order.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14406
/// Write f64 `src` of shape `[N, D]` into row `pos` of `dst` of shape `[N, max_len, D]`.
///
/// For every batch `b < n_batch` and feature `di < d`:
/// `dst[b * max_len * d + pos * d + di] = src[b * d + di]`.
/// Every other element of `dst` keeps its value. The GPU path writes into
/// `dst` in place. The CPU fallback, used when the PTX module cannot be
/// compiled, replaces `*dst` with a freshly uploaded buffer.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `pos >= max_len` (`expected` = `[max_len]`,
///   `got` = `[pos]`). Such a write would spill into the next batch or past
///   the end of `dst`.
/// - [`GpuError::LengthMismatch`] if `src` holds fewer than `n_batch * d`
///   elements, or `dst` is too short to contain row `pos` of the last batch
///   (`a` = required length, `b` = actual length).
/// - Any error from transfer or launch.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_f64(
    src: &CudaBuffer<f64>,
    dst: &mut CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n_batch * d;
    // Nothing to write.
    if total == 0 {
        return Ok(());
    }
    if pos >= max_len {
        return Err(GpuError::ShapeMismatch {
            op: "slice_write_f64",
            expected: vec![max_len],
            got: vec![pos],
        });
    }
    if src.len() < total {
        return Err(GpuError::LengthMismatch {
            a: total,
            b: src.len(),
        });
    }
    // The highest element touched is the last feature of row `pos` in the last batch.
    let dst_needed = (n_batch - 1) * max_len * d + (pos + 1) * d;
    if dst.len() < dst_needed {
        return Err(GpuError::LengthMismatch {
            a: dst_needed,
            b: dst.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        "slice_write_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_write_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy each `[D]` source row into its slot in `dst`.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for (b, src_row) in src_host.chunks_exact(d).take(n_batch).enumerate() {
                let start = b * max_len * d + pos * d;
                dst_host[start..start + d].copy_from_slice(src_row);
            }
            *dst = cpu_to_gpu(&dst_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14467
/// Read first `len` rows from each batch of f64 `[N, max_len, D]` -> `[N, len, D]`.
///
/// For every batch `b`, row `r < len` and feature `di < d`:
/// `out[b * len * d + r * d + di] = src[b * max_len * d + r * d + di]`.
/// Returns a new buffer of `n_batch * len * d` elements.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `len > max_len` (`expected` = `[max_len]`,
///   `got` = `[len]`). Such a read would pull rows from the next batch or
///   past the end of `src`.
/// - [`GpuError::LengthMismatch`] if `src` is too short to contain the first
///   `len` rows of the last batch (`a` = required length, `b` = actual length).
/// - Any error from allocation, transfer, or launch.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_f64(
    src: &CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n_batch * len * d;
    // Nothing to read: skip the zero-sized launch.
    if total == 0 {
        return alloc_zeros_f64(0, device);
    }
    if len > max_len {
        return Err(GpuError::ShapeMismatch {
            op: "slice_read_f64",
            expected: vec![max_len],
            got: vec![len],
        });
    }
    // The highest element read is the last feature of row `len - 1` in the last batch.
    let src_needed = (n_batch - 1) * max_len * d + len * d;
    if src.len() < src_needed {
        return Err(GpuError::LengthMismatch {
            a: src_needed,
            b: src.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SLICE_READ_PTX,
        "slice_read_kernel",
        "slice_read_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_read_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: within one batch the first `len` rows are contiguous.
            let host = gpu_to_cpu(src, device)?;
            let span = len * d;
            let mut out = Vec::with_capacity(total);
            for b in 0..n_batch {
                let start = b * max_len * d;
                out.extend_from_slice(&host[start..start + span]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14528
14529// ---------------------------------------------------------------------------
14530// Public API -- f64 embedding ops
14531// ---------------------------------------------------------------------------
14532
/// Single f64 embedding lookup: `output[j] = weight[token_id * D + j]` for `j < D`.
///
/// `token_id` is read from `idx[0]`, which is stored as `f32` and truncated to
/// an integer row index. Returns a new buffer of length `d`.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `idx` is empty (`expected` = `[1]`, `got` = `[0]`).
/// - On the CPU fallback path only, [`GpuError::ShapeMismatch`] if the row is
///   out of range for `weight` (`expected` = number of rows, `got` = requested
///   row). The GPU path cannot check the index without a device-to-host copy,
///   so callers must keep token ids in range.
/// - Any error from allocation, transfer, or launch.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_f64(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // The kernel and the fallback both read `idx[0]`.
    if idx.len() == 0 {
        return Err(GpuError::ShapeMismatch {
            op: "embed_lookup_f64",
            expected: vec![1],
            got: vec![0],
        });
    }
    // Zero-width embedding: nothing to gather.
    if d == 0 {
        return alloc_zeros_f64(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        "embed_lookup_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            // `as usize` saturates: negative and NaN ids map to row 0.
            let row = idx_host[0] as usize;
            let num_rows = weight_host.len() / d;
            if row >= num_rows {
                return Err(GpuError::ShapeMismatch {
                    op: "embed_lookup_f64",
                    expected: vec![num_rows],
                    got: vec![row],
                });
            }
            let out = weight_host[row * d..(row + 1) * d].to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14581
/// Batch f64 embedding lookup: gather N rows from `[V, D]` weight into `[N, D]`.
///
/// Row `i` of the output is `weight[indices[i], :]` for `i < n`. Indices are
/// stored as `f32` and truncated to integer row ids. Only the first `n`
/// entries of `indices` are used.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `indices` holds fewer than `n` entries
///   (`a` = `n`, `b` = `indices.len()`).
/// - On the CPU fallback path only, [`GpuError::ShapeMismatch`] if an index is
///   out of range for `weight` (`expected` = number of rows, `got` = requested
///   row). The GPU path does not check index values.
/// - Any error from allocation, transfer, or launch.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch_f64(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f64(0, device);
    }
    // The kernel reads `indices[i]` for every output row `i < n`.
    if indices.len() < n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: indices.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        "embed_lookup_batch_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_batch_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let num_rows = weight_host.len() / d;
            let mut out = Vec::with_capacity(total);
            // Only the first `n` indices, matching the GPU output size.
            for &idx_f in idx_host.iter().take(n) {
                let row = idx_f as usize;
                if row >= num_rows {
                    return Err(GpuError::ShapeMismatch {
                        op: "embed_lookup_batch_f64",
                        expected: vec![num_rows],
                        got: vec![row],
                    });
                }
                out.extend_from_slice(&weight_host[row * d..(row + 1) * d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14641
/// Scatter-add f64 rows for embedding backward.
///
/// Returns a zero-initialised `[num_embeddings, D]` gradient `grad_weight`
/// after accumulating `grad_weight[indices[i], :] += grad_output[i, :]` for
/// every `i < indices.len()`. Repeated indices sum their contributions.
/// Indices are stored as `f32` and truncated to integer row ids. On the GPU
/// path, concurrent updates to the same row rely on the kernel performing
/// atomic adds.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `grad_output` holds fewer than
///   `indices.len() * d` elements (`a` = required, `b` = actual).
/// - On the CPU fallback path only, [`GpuError::ShapeMismatch`] if an index is
///   `>= num_embeddings` (`expected` = `[num_embeddings]`, `got` = `[row]`).
///   The GPU path does not check index values.
/// - Any error from allocation, transfer, or launch.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        return alloc_zeros_f64(num_embeddings * d, device);
    }
    // The kernel reads `grad_output[i * d + j]` for every index `i`.
    if grad_output.len() < total {
        return Err(GpuError::LengthMismatch {
            a: total,
            b: grad_output.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        "scatter_add_rows_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_rows_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f64; num_embeddings * d];
            for (&idx_f, go_row) in idx_host.iter().zip(go_host.chunks_exact(d)) {
                let row = idx_f as usize;
                if row >= num_embeddings {
                    return Err(GpuError::ShapeMismatch {
                        op: "scatter_add_rows_f64",
                        expected: vec![num_embeddings],
                        got: vec![row],
                    });
                }
                let dst = &mut result[row * d..(row + 1) * d];
                for (acc, &g) in dst.iter_mut().zip(go_row) {
                    *acc += g;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14706
14707// ---------------------------------------------------------------------------
14708// Public API -- fused Adam optimizer step
14709// ---------------------------------------------------------------------------
14710
/// Fused Adam optimizer step: updates param, exp_avg, and exp_avg_sq in-place
/// in a single kernel launch.
///
/// All four buffers must have the same length `n`. `param`, `exp_avg`, and
/// `exp_avg_sq` are modified in-place. `grad` is read-only.
///
/// Per element, with `g = grad[i] + weight_decay * param[i]` when
/// `weight_decay > 0` (otherwise `g = grad[i]`):
///
/// ```text
/// exp_avg[i]    = beta1 * exp_avg[i]    + (1 - beta1) * g
/// exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g
/// param[i]     -= lr * (exp_avg[i] / bc1) / (sqrt(exp_avg_sq[i] / bc2) + eps)
/// ```
///
/// `bc1` and `bc2` are bias-correction divisors supplied by the caller.
/// Empty buffers are a no-op. If the PTX module cannot be compiled, the step
/// runs on the host and the three mutable buffers are replaced with freshly
/// uploaded ones.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] with `a = param.len()` and `b` set to the
///   length of the first buffer (in the order `grad`, `exp_avg`,
///   `exp_avg_sq`) that differs from it.
/// - Any error from transfer or launch.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    // Report the buffer that actually mismatches, not always `grad`.
    for len in [grad.len(), exp_avg.len(), exp_avg_sq.len()] {
        if len != n {
            return Err(GpuError::LengthMismatch { a: n, b: len });
        }
    }
    // Nothing to update: skip the zero-sized launch.
    if n == 0 {
        return Ok(());
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: download, compute, upload.
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

            for (((p, m), v), &g_raw) in p_host
                .iter_mut()
                .zip(m_host.iter_mut())
                .zip(v_host.iter_mut())
                .zip(g_host.iter())
            {
                // L2 weight decay folded into the gradient (classic Adam, not AdamW).
                let g = if weight_decay > 0.0 {
                    g_raw + weight_decay * *p
                } else {
                    g_raw
                };
                *m = beta1 * *m + (1.0 - beta1) * g;
                *v = beta2 * *v + (1.0 - beta2) * g * g;
                let m_hat = *m / bc1;
                let v_hat = *v / bc2;
                *p -= lr * m_hat / (v_hat.sqrt() + eps);
            }

            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14801
14802/// Stub -- always returns [`GpuError::NoCudaFeature`].
14803#[cfg(not(feature = "cuda"))]
14804#[allow(clippy::too_many_arguments)]
14805pub fn gpu_fused_adam(
14806    _param: &mut CudaBuffer<f32>,
14807    _grad: &CudaBuffer<f32>,
14808    _exp_avg: &mut CudaBuffer<f32>,
14809    _exp_avg_sq: &mut CudaBuffer<f32>,
14810    _beta1: f32,
14811    _beta2: f32,
14812    _lr: f32,
14813    _eps: f32,
14814    _bc1: f32,
14815    _bc2: f32,
14816    _weight_decay: f32,
14817    _device: &GpuDevice,
14818) -> GpuResult<()> {
14819    Err(GpuError::NoCudaFeature)
14820}
14821
14822// ---------------------------------------------------------------------------
14823// Public API -- fused GRU cell
14824// ---------------------------------------------------------------------------
14825
/// Fused GRU cell forward: takes pre-computed gate matrices and produces
/// new hidden state + workspace for backward.
///
/// Inputs:
/// - `input_gates`: `[batch, 3*hsz]` — result of `x @ W_ih^T`
/// - `hidden_gates`: `[batch, 3*hsz]` — result of `h @ W_hh^T`
/// - `bias_ih`: `[3*hsz]` — input bias
/// - `bias_hh`: `[3*hsz]` — hidden bias
/// - `hx`: `[batch, hsz]` — previous hidden state
///
/// `batch` is derived as `hx.len() / hsz`.
///
/// Outputs:
/// - `hy`: `[batch, hsz]` — new hidden state
/// - `workspace`: `[batch, 5*hsz]` — saved for backward (r, z, n, hx, hn+b2n)
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if `hsz == 0` or `hx.len()` is not a multiple
///   of `hsz` (`expected` = `[hsz]`, `got` = `[hx.len()]`).
/// - [`GpuError::ShapeMismatch`] if a gate or bias buffer is shorter than its
///   documented shape (`expected` = `[required len]`, `got` = `[actual len]`).
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled. There is
///   no CPU fallback.
/// - Any error from allocation or launch.
#[cfg(feature = "cuda")]
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    let total = hx.len(); // batch * hsz
    // `hsz == 0` would make `total / hsz` panic, and a ragged `hx` has no
    // well-defined batch size.
    if hsz == 0 || total % hsz != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "fused_gru_forward",
            expected: vec![hsz],
            got: vec![total],
        });
    }
    let batch = total / hsz;

    // Reject buffers the kernel would index past their end.
    let gates_len = batch * 3 * hsz;
    for (need, got) in [
        (gates_len, input_gates.len()),
        (gates_len, hidden_gates.len()),
        (3 * hsz, bias_ih.len()),
        (3 * hsz, bias_hh.len()),
    ] {
        if got < need {
            return Err(GpuError::ShapeMismatch {
                op: "fused_gru_forward",
                expected: vec![need],
                got: vec![got],
            });
        }
    }

    // Empty batch: nothing to compute, skip the zero-sized launch.
    if total == 0 {
        return Ok((alloc_zeros_f32(0, device)?, alloc_zeros_f32(0, device)?));
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "fused_gru_forward_kernel",
    })?;

    let mut hy = alloc_zeros_f32(total, device)?;
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
14895
14896/// Stub -- always returns [`GpuError::NoCudaFeature`].
14897#[cfg(not(feature = "cuda"))]
14898pub fn gpu_fused_gru_forward(
14899    _input_gates: &CudaBuffer<f32>,
14900    _hidden_gates: &CudaBuffer<f32>,
14901    _bias_ih: &CudaBuffer<f32>,
14902    _bias_hh: &CudaBuffer<f32>,
14903    _hx: &CudaBuffer<f32>,
14904    _hsz: usize,
14905    _device: &GpuDevice,
14906) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
14907    Err(GpuError::NoCudaFeature)
14908}
14909
14910// ---------------------------------------------------------------------------
14911// Public API -- MaxPool2d / AvgPool2d
14912// ---------------------------------------------------------------------------
14913
/// MaxPool2d forward on GPU. One thread per output element.
///
/// `input` is `[batch, channels, h_in, w_in]`. Output spatial sizes are
/// `h_out = (h_in + 2*ph - kh) / sh + 1` and `w_out = (w_in + 2*pw - kw) / sw + 1`.
/// Returns the output buffer together with its shape
/// `[batch, channels, h_out, w_out]`.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if a kernel size or stride is zero, or the
///   kernel is larger than the padded input (`expected` = `[kh, kw, sh, sw]`,
///   `got` = `[h_in + 2*ph, w_in + 2*pw]`). The original formula would
///   underflow or divide by zero in these cases.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `batch * channels * h_in * w_in` elements (`a` = required, `b` = actual).
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
/// - Any error from allocation or launch.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Output extent along one axis. Returns `None` for a zero window or
    // stride, or a window larger than the padded input (where the
    // subtraction would underflow).
    let out_dim = |size: usize, pad: usize, k: usize, stride: usize| -> Option<usize> {
        if k == 0 || stride == 0 {
            return None;
        }
        (size + 2 * pad).checked_sub(k).map(|span| span / stride + 1)
    };
    let (Some(h_out), Some(w_out)) = (out_dim(h_in, ph, kh, sh), out_dim(w_in, pw, kw, sw))
    else {
        return Err(GpuError::ShapeMismatch {
            op: "maxpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in + 2 * ph, w_in + 2 * pw],
        });
    };

    let in_len = batch * channels * h_in * w_in;
    if input.len() < in_len {
        return Err(GpuError::LengthMismatch {
            a: in_len,
            b: input.len(),
        });
    }

    let total = batch * channels * h_out * w_out;
    // Empty output: skip the zero-sized launch.
    if total == 0 {
        return Ok((alloc_zeros_f32(0, device)?, [batch, channels, h_out, w_out]));
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = crate::module_cache::get_or_compile(
        ctx,
        MAXPOOL2D_PTX,
        "maxpool2d_forward_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "maxpool2d_forward_kernel",
    })?;

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32)
            .arg(&ch_u32)
            .arg(&h_in_u32)
            .arg(&w_in_u32)
            .arg(&h_out_u32)
            .arg(&w_out_u32)
            .arg(&kh_u32)
            .arg(&kw_u32)
            .arg(&sh_u32)
            .arg(&sw_u32)
            .arg(&ph_u32)
            .arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
14974
14975/// Stub.
14976#[cfg(not(feature = "cuda"))]
14977#[allow(clippy::too_many_arguments)]
14978pub fn gpu_maxpool2d(
14979    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
14980    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
14981    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
14982    _device: &GpuDevice,
14983) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
14984    Err(GpuError::NoCudaFeature)
14985}
14986
/// AvgPool2d forward on GPU. One thread per output element.
///
/// `input` is `[batch, channels, h_in, w_in]`. Output spatial sizes are
/// `h_out = (h_in + 2*ph - kh) / sh + 1` and `w_out = (w_in + 2*pw - kw) / sw + 1`.
/// Returns the output buffer together with its shape
/// `[batch, channels, h_out, w_out]`.
///
/// # Errors
///
/// - [`GpuError::ShapeMismatch`] if a kernel size or stride is zero, or the
///   kernel is larger than the padded input (`expected` = `[kh, kw, sh, sw]`,
///   `got` = `[h_in + 2*ph, w_in + 2*pw]`). The original formula would
///   underflow or divide by zero in these cases.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `batch * channels * h_in * w_in` elements (`a` = required, `b` = actual).
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
/// - Any error from allocation or launch.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Output extent along one axis. Returns `None` for a zero window or
    // stride, or a window larger than the padded input (where the
    // subtraction would underflow).
    let out_dim = |size: usize, pad: usize, k: usize, stride: usize| -> Option<usize> {
        if k == 0 || stride == 0 {
            return None;
        }
        (size + 2 * pad).checked_sub(k).map(|span| span / stride + 1)
    };
    let (Some(h_out), Some(w_out)) = (out_dim(h_in, ph, kh, sh), out_dim(w_in, pw, kw, sw))
    else {
        return Err(GpuError::ShapeMismatch {
            op: "avgpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in + 2 * ph, w_in + 2 * pw],
        });
    };

    let in_len = batch * channels * h_in * w_in;
    if input.len() < in_len {
        return Err(GpuError::LengthMismatch {
            a: in_len,
            b: input.len(),
        });
    }

    let total = batch * channels * h_out * w_out;
    // Empty output: skip the zero-sized launch.
    if total == 0 {
        return Ok((alloc_zeros_f32(0, device)?, [batch, channels, h_out, w_out]));
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = crate::module_cache::get_or_compile(
        ctx,
        AVGPOOL2D_PTX,
        "avgpool2d_forward_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "avgpool2d_forward_kernel",
    })?;

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32)
            .arg(&ch_u32)
            .arg(&h_in_u32)
            .arg(&w_in_u32)
            .arg(&h_out_u32)
            .arg(&w_out_u32)
            .arg(&kh_u32)
            .arg(&kw_u32)
            .arg(&sh_u32)
            .arg(&sw_u32)
            .arg(&ph_u32)
            .arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
15047
15048/// Stub.
15049#[cfg(not(feature = "cuda"))]
15050#[allow(clippy::too_many_arguments)]
15051pub fn gpu_avgpool2d(
15052    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
15053    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
15054    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
15055    _device: &GpuDevice,
15056) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
15057    Err(GpuError::NoCudaFeature)
15058}
15059
15060// ---------------------------------------------------------------------------
15061// Public API -- BatchNorm2d
15062// ---------------------------------------------------------------------------
15063
/// BatchNorm2d forward on GPU — placeholder, not implemented yet.
///
/// The kernel's first pass still needs its loop indexing fixed, so it is never
/// launched. This function asks the module cache to compile
/// `BATCHNORM_FORWARD_PTX` and discards the outcome, including any compile
/// error. It then unconditionally returns [`GpuError::ShapeMismatch`] with
/// placeholder shapes (`expected: [0]`, `got: [1]`), which signals callers to
/// take their CPU path. All tensor and hyper-parameter arguments are unused.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    // Compile through the module cache; the result is deliberately ignored.
    let _compiled = crate::module_cache::get_or_compile(
        device.context(),
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );

    // TODO: launch the kernel once pass-1 indexing is fixed.
    Err(GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    })
}
15097
15098/// Stub.
15099#[cfg(not(feature = "cuda"))]
15100#[allow(clippy::too_many_arguments)]
15101pub fn gpu_batchnorm_forward(
15102    _input: &CudaBuffer<f32>,
15103    _weight: &CudaBuffer<f32>,
15104    _bias: &CudaBuffer<f32>,
15105    _running_mean: &mut CudaBuffer<f32>,
15106    _running_var: &mut CudaBuffer<f32>,
15107    _channels: usize,
15108    _spatial: usize,
15109    _eps: f32,
15110    _momentum: f32,
15111    _training: bool,
15112    _device: &GpuDevice,
15113) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
15114    Err(GpuError::NoCudaFeature)
15115}
15116
15117// ---------------------------------------------------------------------------
15118// Public API -- LayerNorm
15119// ---------------------------------------------------------------------------
15120
/// Row-wise layer normalization on GPU.
///
/// `input`: `[rows * cols]`, `weight`/`bias`: `[cols]`.
/// Output: normalized and affine-transformed `[rows * cols]`:
/// `out[r, c] = weight[c] * (input[r, c] - mean_r) / sqrt(var_r + eps) + bias[c]`,
/// where `mean_r` and `var_r` are the mean and the divide-by-`cols` variance
/// of row `r`.
///
/// Launches one block of 256 threads per row with `256 * 4` bytes of shared
/// memory. If the PTX module cannot be compiled, a message is printed to
/// stderr and the computation runs on the host.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements, or `weight`/`bias` hold fewer than `cols`
///   (`a` = required, `b` = actual).
/// - Any error from device validation, allocation, transfer, or launch.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = rows * cols;
    // Empty tensor: nothing to normalize.
    if len == 0 {
        return alloc_zeros_f32(0, device);
    }
    if input.len() < len {
        return Err(GpuError::LengthMismatch {
            a: len,
            b: input.len(),
        });
    }
    for affine_len in [weight.len(), bias.len()] {
        if affine_len < cols {
            return Err(GpuError::LengthMismatch {
                a: cols,
                b: affine_len,
            });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // Note: no longer dumps the PTX to a fixed path in /tmp (a debug
            // leftover that wrote to a shared, world-writable location).
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let n_f = cols as f32;
            let mut out = vec![0.0f32; len];
            for (src, dst) in h_in.chunks_exact(cols).zip(out.chunks_exact_mut(cols)) {
                let mean: f32 = src.iter().sum::<f32>() / n_f;
                let var: f32 = src.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                for (((o, &x), &w), &b) in dst.iter_mut().zip(src).zip(&h_w).zip(&h_b) {
                    let normed = (x - mean) * inv_std;
                    *o = w * normed + b;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; shared memory holds one f32 per thread for the reductions.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15201
15202// ---------------------------------------------------------------------------
15203// Public API -- LayerNorm backward
15204// ---------------------------------------------------------------------------
15205
/// LayerNorm backward pass on GPU.
///
/// Computes grad_input, grad_weight, and grad_bias entirely on GPU.
/// One block per batch element (row), 256 threads per block.
/// grad_weight and grad_bias are accumulated across batches via atomicAdd.
///
/// `input`: `[rows * cols]`, `grad_output`: `[rows * cols]`, `weight`: `[cols]`.
/// Returns: `(grad_input [rows * cols], grad_weight [cols], grad_bias [cols])`.
///
/// If the PTX module cannot be compiled, the same math runs on the host. That
/// path recomputes each row's mean and inverse standard deviation from
/// `input`.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` or `grad_output` holds fewer than
///   `rows * cols` elements, or `weight` holds fewer than `cols`
///   (`a` = required, `b` = actual).
/// - Any error from device validation, allocation, transfer, or launch.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = rows * cols;
    // Empty tensor: all gradients are zero; skip the launch.
    if len == 0 {
        return Ok((
            alloc_zeros_f32(0, device)?,
            alloc_zeros_f32(cols, device)?,
            alloc_zeros_f32(cols, device)?,
        ));
    }
    for have in [input.len(), grad_output.len()] {
        if have < len {
            return Err(GpuError::LengthMismatch { a: len, b: have });
        }
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch {
            a: cols,
            b: weight.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; len];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for ((x_row, go_row), gi_row) in h_in
                .chunks_exact(cols)
                .zip(h_go.chunks_exact(cols))
                .zip(grad_input.chunks_exact_mut(cols))
            {
                let mean: f32 = x_row.iter().sum::<f32>() / n_f;
                let var: f32 = x_row
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // Pass 1: row reductions needed by grad_input, plus the
                // per-column accumulation of grad_weight / grad_bias.
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_row[c] - mean) * inv_std;
                    let dl = go_row[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_row[c] * x_hat;
                    grad_bias[c] += go_row[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                // Pass 2: dx = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat)).
                for c in 0..cols {
                    let x_hat = (x_row[c] - mean) * inv_std;
                    let dl = go_row[c] * h_w[c];
                    gi_row[c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(len, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
15313
15314/// Stub -- always returns [`GpuError::NoCudaFeature`].
15315#[cfg(not(feature = "cuda"))]
15316pub fn gpu_layernorm_backward(
15317    _input: &CudaBuffer<f32>,
15318    _grad_output: &CudaBuffer<f32>,
15319    _weight: &CudaBuffer<f32>,
15320    _rows: usize,
15321    _cols: usize,
15322    _eps: f32,
15323    _device: &GpuDevice,
15324) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
15325    Err(GpuError::NoCudaFeature)
15326}
15327
15328// ---------------------------------------------------------------------------
15329// Public API -- RMSNorm
15330// ---------------------------------------------------------------------------
15331
15332/// Row-wise RMS normalization on GPU.
15333///
15334/// `input`: `[rows * cols]`, `weight`: `[cols]`.
15335/// Output: normalized and scaled `[rows * cols]`.
15336///
15337/// Computes `out[j] = x[j] * rsqrt(mean(x^2) + eps) * weight[j]`.
15338/// No bias, no mean centering (unlike LayerNorm).
15339#[cfg(feature = "cuda")]
15340pub fn gpu_rmsnorm(
15341    input: &CudaBuffer<f32>,
15342    weight: &CudaBuffer<f32>,
15343    rows: usize,
15344    cols: usize,
15345    eps: f32,
15346    device: &GpuDevice,
15347) -> GpuResult<CudaBuffer<f32>> {
15348    use cudarc::driver::PushKernelArg;
15349
15350    validate_unary(input, device)?;
15351
15352    let ctx = device.context();
15353    let stream = device.stream();
15354
15355    let f = match crate::module_cache::get_or_compile(
15356        ctx,
15357        RMSNORM_PTX,
15358        "rmsnorm_kernel",
15359        device.ordinal() as u32,
15360    ) {
15361        Ok(f) => f,
15362        Err(e) => {
15363            eprintln!("ferrotorch-gpu: RMSNorm PTX compilation failed ({e:?}), CPU fallback");
15364            std::fs::write("/tmp/rmsnorm_debug.ptx", RMSNORM_PTX).ok();
15365            eprintln!(
15366                "ferrotorch-gpu: dumped PTX to /tmp/rmsnorm_debug.ptx ({} bytes)",
15367                RMSNORM_PTX.len()
15368            );
15369            let h_in = gpu_to_cpu(input, device)?;
15370            let h_w = gpu_to_cpu(weight, device)?;
15371            let mut out = vec![0.0f32; rows * cols];
15372            for r in 0..rows {
15373                let base = r * cols;
15374                let slice = &h_in[base..base + cols];
15375                let sq_mean: f32 =
15376                    slice.iter().map(|&x| x * x).sum::<f32>() / cols as f32;
15377                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
15378                for c in 0..cols {
15379                    out[base + c] = slice[c] * inv_rms * h_w[c];
15380                }
15381            }
15382            return cpu_to_gpu(&out, device);
15383        }
15384    };
15385
15386    let mut out = alloc_zeros_f32(rows * cols, device)?;
15387    let rows_u32 = rows as u32;
15388    let cols_u32 = cols as u32;
15389
15390    let cfg = LaunchConfig {
15391        grid_dim: ((rows as u32).max(1), 1, 1),
15392        block_dim: (256, 1, 1),
15393        shared_mem_bytes: 256 * 4,
15394    };
15395
15396    unsafe {
15397        stream
15398            .launch_builder(&f)
15399            .arg(input.inner())
15400            .arg(out.inner_mut())
15401            .arg(weight.inner())
15402            .arg(&rows_u32)
15403            .arg(&cols_u32)
15404            .arg(&eps)
15405            .launch(cfg)?;
15406    }
15407
15408    Ok(out)
15409}
15410
15411// ---------------------------------------------------------------------------
15412// Public API -- RMSNorm backward
15413// ---------------------------------------------------------------------------
15414
15415/// RMSNorm backward pass on GPU.
15416///
15417/// Computes grad_input and grad_weight entirely on GPU.
15418/// One block per batch element (row), 256 threads per block.
15419/// grad_weight is accumulated across batches via atomicAdd.
15420///
15421/// `input`: `[rows * cols]`, `grad_output`: `[rows * cols]`, `weight`: `[cols]`.
15422/// Returns: `(grad_input [rows * cols], grad_weight [cols])`.
15423#[cfg(feature = "cuda")]
15424pub fn gpu_rmsnorm_backward(
15425    input: &CudaBuffer<f32>,
15426    grad_output: &CudaBuffer<f32>,
15427    weight: &CudaBuffer<f32>,
15428    rows: usize,
15429    cols: usize,
15430    eps: f32,
15431    device: &GpuDevice,
15432) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
15433    use cudarc::driver::PushKernelArg;
15434
15435    validate_unary(input, device)?;
15436
15437    let ctx = device.context();
15438    let stream = device.stream();
15439
15440    let f = match crate::module_cache::get_or_compile(
15441        ctx,
15442        RMSNORM_BACKWARD_PTX,
15443        "rmsnorm_backward_kernel",
15444        device.ordinal() as u32,
15445    ) {
15446        Ok(f) => f,
15447        Err(_) => {
15448            // CPU fallback
15449            let h_in = gpu_to_cpu(input, device)?;
15450            let h_go = gpu_to_cpu(grad_output, device)?;
15451            let h_w = gpu_to_cpu(weight, device)?;
15452            let mut grad_input = vec![0.0f32; rows * cols];
15453            let mut grad_weight = vec![0.0f32; cols];
15454            let n_f = cols as f32;
15455            for r in 0..rows {
15456                let base = r * cols;
15457                let x_slice = &h_in[base..base + cols];
15458                let go_slice = &h_go[base..base + cols];
15459                let sq_mean: f32 =
15460                    x_slice.iter().map(|&x| x * x).sum::<f32>() / n_f;
15461                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
15462                let inv_rms3 = inv_rms * inv_rms * inv_rms;
15463                let mut dot = 0.0f32;
15464                for c in 0..cols {
15465                    dot += go_slice[c] * x_slice[c] * h_w[c];
15466                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
15467                }
15468                let coeff = dot * inv_rms3 / n_f;
15469                for c in 0..cols {
15470                    grad_input[base + c] =
15471                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
15472                }
15473            }
15474            let gi = cpu_to_gpu(&grad_input, device)?;
15475            let gw = cpu_to_gpu(&grad_weight, device)?;
15476            return Ok((gi, gw));
15477        }
15478    };
15479
15480    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
15481    let mut grad_w = alloc_zeros_f32(cols, device)?;
15482    let rows_u32 = rows as u32;
15483    let cols_u32 = cols as u32;
15484
15485    // One block per row, 256 threads per block.
15486    let cfg = LaunchConfig {
15487        grid_dim: ((rows as u32).max(1), 1, 1),
15488        block_dim: (256, 1, 1),
15489        shared_mem_bytes: 256 * 4,
15490    };
15491
15492    unsafe {
15493        stream
15494            .launch_builder(&f)
15495            .arg(input.inner())
15496            .arg(grad_output.inner())
15497            .arg(weight.inner())
15498            .arg(grad_in.inner_mut())
15499            .arg(grad_w.inner_mut())
15500            .arg(&rows_u32)
15501            .arg(&cols_u32)
15502            .arg(&eps)
15503            .launch(cfg)?;
15504    }
15505
15506    Ok((grad_in, grad_w))
15507}
15508
15509/// Stub -- always returns [`GpuError::NoCudaFeature`].
15510#[cfg(not(feature = "cuda"))]
15511pub fn gpu_rmsnorm(
15512    _input: &CudaBuffer<f32>,
15513    _weight: &CudaBuffer<f32>,
15514    _rows: usize,
15515    _cols: usize,
15516    _eps: f32,
15517    _device: &GpuDevice,
15518) -> GpuResult<CudaBuffer<f32>> {
15519    Err(GpuError::NoCudaFeature)
15520}
15521
15522/// Stub -- always returns [`GpuError::NoCudaFeature`].
15523#[cfg(not(feature = "cuda"))]
15524pub fn gpu_rmsnorm_backward(
15525    _input: &CudaBuffer<f32>,
15526    _grad_output: &CudaBuffer<f32>,
15527    _weight: &CudaBuffer<f32>,
15528    _rows: usize,
15529    _cols: usize,
15530    _eps: f32,
15531    _device: &GpuDevice,
15532) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
15533    Err(GpuError::NoCudaFeature)
15534}
15535
15536// ===========================================================================
15537// _into variants — write to pre-allocated output buffers (zero allocation)
15538//
15539// These are used for CUDA graph capture, where all buffer addresses must be
15540// fixed at capture time. The PTX kernels are identical — only the Rust
15541// wrapper skips allocation.
15542// ===========================================================================
15543
15544/// Elementwise add into pre-allocated output: `out[i] = a[i] + b[i]`.
15545#[cfg(feature = "cuda")]
15546pub fn gpu_add_into(
15547    a: &CudaBuffer<f32>,
15548    b: &CudaBuffer<f32>,
15549    out: &mut CudaBuffer<f32>,
15550    device: &GpuDevice,
15551) -> GpuResult<()> {
15552    validate_binary(a, b, device)?;
15553    if out.len() < a.len() {
15554        return Err(GpuError::ShapeMismatch {
15555            op: "add_into",
15556            expected: vec![a.len()],
15557            got: vec![out.len()],
15558        });
15559    }
15560    if try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
15561        return Ok(());
15562    }
15563    Err(GpuError::PtxCompileFailed {
15564        kernel: "add_kernel",
15565    })
15566}
15567
15568/// Elementwise mul into pre-allocated output: `out[i] = a[i] * b[i]`.
15569#[cfg(feature = "cuda")]
15570pub fn gpu_mul_into(
15571    a: &CudaBuffer<f32>,
15572    b: &CudaBuffer<f32>,
15573    out: &mut CudaBuffer<f32>,
15574    device: &GpuDevice,
15575) -> GpuResult<()> {
15576    validate_binary(a, b, device)?;
15577    if out.len() < a.len() {
15578        return Err(GpuError::ShapeMismatch {
15579            op: "mul_into",
15580            expected: vec![a.len()],
15581            got: vec![out.len()],
15582        });
15583    }
15584    if try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
15585        return Ok(());
15586    }
15587    Err(GpuError::PtxCompileFailed {
15588        kernel: "mul_kernel",
15589    })
15590}
15591
15592/// Scalar multiply into pre-allocated output: `out[i] = a[i] * scalar`.
15593#[cfg(feature = "cuda")]
15594pub fn gpu_scale_into(
15595    a: &CudaBuffer<f32>,
15596    scalar: f32,
15597    out: &mut CudaBuffer<f32>,
15598    device: &GpuDevice,
15599) -> GpuResult<()> {
15600    use cudarc::driver::PushKernelArg;
15601    validate_unary(a, device)?;
15602    let n = a.len();
15603    let ctx = device.context();
15604    let stream = device.stream();
15605    let f = crate::module_cache::get_or_compile(
15606        ctx,
15607        SCALE_PTX,
15608        "scale_kernel",
15609        device.ordinal() as u32,
15610    )
15611    .map_err(|_| GpuError::PtxCompileFailed {
15612        kernel: "scale_kernel",
15613    })?;
15614    let cfg = launch_cfg(n)?;
15615    let n_u32 = n as u32;
15616    unsafe {
15617        stream
15618            .launch_builder(&f)
15619            .arg(a.inner())
15620            .arg(out.inner_mut())
15621            .arg(&scalar)
15622            .arg(&n_u32)
15623            .launch(cfg)?;
15624    }
15625    Ok(())
15626}
15627
15628/// Allocate an `n`-element f32 buffer on `device` filled with `scalar`.
15629///
15630/// Entirely on-device: no CPU→GPU upload beyond the single f32 scalar
15631/// passed as a kernel argument. Used by sum/mean backward to produce
15632/// the constant gradient tensor without the legacy `vec![go;
15633/// numel].to(device)` round-trip.
15634#[cfg(feature = "cuda")]
15635pub fn gpu_fill_f32(
15636    n: usize,
15637    scalar: f32,
15638    device: &GpuDevice,
15639) -> GpuResult<CudaBuffer<f32>> {
15640    use cudarc::driver::PushKernelArg;
15641
15642    let ctx = device.context();
15643    let stream = device.stream();
15644    let f = crate::module_cache::get_or_compile(
15645        ctx,
15646        FILL_F32_PTX,
15647        "fill_f32_kernel",
15648        device.ordinal() as u32,
15649    )
15650    .map_err(|_| GpuError::PtxCompileFailed {
15651        kernel: "fill_f32_kernel",
15652    })?;
15653
15654    let mut out = alloc_zeros_f32(n, device)?;
15655    if n == 0 {
15656        return Ok(out);
15657    }
15658    let cfg = launch_cfg(n)?;
15659    let n_u32 = n as u32;
15660    unsafe {
15661        stream
15662            .launch_builder(&f)
15663            .arg(out.inner_mut())
15664            .arg(&scalar)
15665            .arg(&n_u32)
15666            .launch(cfg)?;
15667    }
15668    Ok(out)
15669}
15670
15671/// Check whether a GPU buffer contains any inf or NaN values.
15672///
15673/// Downloads the buffer contents to the host and scans for non-finite
15674/// values. This is correct for any buffer size and requires no custom
15675/// reduction kernel.
15676///
15677/// For a future optimization, a dedicated GPU reduction kernel could be
15678/// used to produce a single boolean flag on device, avoiding the full
15679/// download. The current approach is already much faster than the old
15680/// per-element CPU loop in `unscale_()` because the scaling itself
15681/// runs on GPU — only the inf/NaN check touches the host.
15682///
15683/// # Errors
15684///
15685/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
15686/// - [`GpuError::Driver`] on CUDA runtime errors.
15687#[cfg(feature = "cuda")]
15688pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
15689    let n = a.len();
15690    if n == 0 {
15691        return Ok(false);
15692    }
15693
15694    validate_unary(a, device)?;
15695
15696    let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
15697    Ok(host.iter().any(|v| !v.is_finite()))
15698}
15699
15700/// Stub -- always returns [`GpuError::NoCudaFeature`].
15701#[cfg(not(feature = "cuda"))]
15702pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
15703    Err(GpuError::NoCudaFeature)
15704}
15705
15706/// GELU into pre-allocated output.
15707#[cfg(feature = "cuda")]
15708pub fn gpu_gelu_into(
15709    a: &CudaBuffer<f32>,
15710    out: &mut CudaBuffer<f32>,
15711    device: &GpuDevice,
15712) -> GpuResult<()> {
15713    validate_unary(a, device)?;
15714    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
15715        return Ok(());
15716    }
15717    Err(GpuError::PtxCompileFailed {
15718        kernel: "gelu_kernel",
15719    })
15720}
15721
15722/// Embedding lookup into pre-allocated output.
15723#[cfg(feature = "cuda")]
15724pub fn gpu_embed_lookup_into(
15725    idx: &CudaBuffer<f32>,
15726    weight: &CudaBuffer<f32>,
15727    d: usize,
15728    out: &mut CudaBuffer<f32>,
15729    device: &GpuDevice,
15730) -> GpuResult<()> {
15731    use cudarc::driver::PushKernelArg;
15732    let ctx = device.context();
15733    let stream = device.stream();
15734    let f = crate::module_cache::get_or_compile(
15735        ctx,
15736        EMBED_LOOKUP_PTX,
15737        "embed_lookup_kernel",
15738        device.ordinal() as u32,
15739    )
15740    .map_err(|_| GpuError::PtxCompileFailed {
15741        kernel: "embed_lookup_kernel",
15742    })?;
15743    let cfg = launch_cfg(d)?;
15744    let d_u32 = d as u32;
15745    unsafe {
15746        stream
15747            .launch_builder(&f)
15748            .arg(idx.inner())
15749            .arg(weight.inner())
15750            .arg(out.inner_mut())
15751            .arg(&d_u32)
15752            .launch(cfg)?;
15753    }
15754    Ok(())
15755}
15756
15757// ---------------------------------------------------------------------------
15758// Public API -- Batch embedding lookup (GPU-native)
15759// ---------------------------------------------------------------------------
15760
15761/// GPU batch embedding lookup: given `indices` (N f32 values on GPU) and
15762/// `weight` `[V, D]`, gather N rows to produce output `[N, D]`.
15763/// Entire operation stays on GPU -- no CPU roundtrip.
15764#[cfg(feature = "cuda")]
15765pub fn gpu_embed_lookup_batch(
15766    indices: &CudaBuffer<f32>,
15767    weight: &CudaBuffer<f32>,
15768    n: usize,
15769    d: usize,
15770    device: &GpuDevice,
15771) -> GpuResult<CudaBuffer<f32>> {
15772    use cudarc::driver::PushKernelArg;
15773
15774    let total = n * d;
15775    if total == 0 {
15776        return alloc_zeros_f32(0, device);
15777    }
15778
15779    let ctx = device.context();
15780    let stream = device.stream();
15781
15782    let f = match crate::module_cache::get_or_compile(
15783        ctx,
15784        EMBED_LOOKUP_BATCH_PTX,
15785        "embed_lookup_batch_kernel",
15786        device.ordinal() as u32,
15787    ) {
15788        Ok(f) => f,
15789        Err(_) => {
15790            // CPU fallback.
15791            let idx_host = gpu_to_cpu(indices, device)?;
15792            let weight_host = gpu_to_cpu(weight, device)?;
15793            let mut out = Vec::with_capacity(total);
15794            for &idx_f in &idx_host {
15795                let row = idx_f as usize;
15796                let start = row * d;
15797                out.extend_from_slice(&weight_host[start..start + d]);
15798            }
15799            return cpu_to_gpu(&out, device);
15800        }
15801    };
15802
15803    let mut out = alloc_zeros_f32(total, device)?;
15804    let cfg = launch_cfg(total)?;
15805    let d_u32 = d as u32;
15806    let total_u32 = total as u32;
15807
15808    unsafe {
15809        stream
15810            .launch_builder(&f)
15811            .arg(indices.inner())
15812            .arg(weight.inner())
15813            .arg(out.inner_mut())
15814            .arg(&d_u32)
15815            .arg(&total_u32)
15816            .launch(cfg)?;
15817    }
15818
15819    Ok(out)
15820}
15821
15822// ---------------------------------------------------------------------------
15823// Public API -- Scatter-add rows (for embedding backward, GPU-native)
15824// ---------------------------------------------------------------------------
15825
15826/// GPU scatter-add rows: given `grad_output` `[N, D]` and `indices` `[N]` (f32),
15827/// atomically accumulate into `grad_weight` `[V, D]` (pre-zeroed):
15828///   `grad_weight[indices[i], :] += grad_output[i, :]`
15829///
15830/// Duplicate indices accumulate correctly via atomic adds.
15831#[cfg(feature = "cuda")]
15832pub fn gpu_scatter_add_rows(
15833    grad_output: &CudaBuffer<f32>,
15834    indices: &CudaBuffer<f32>,
15835    num_embeddings: usize,
15836    d: usize,
15837    device: &GpuDevice,
15838) -> GpuResult<CudaBuffer<f32>> {
15839    use cudarc::driver::PushKernelArg;
15840
15841    let n = indices.len();
15842    let total = n * d;
15843
15844    if total == 0 {
15845        return alloc_zeros_f32(num_embeddings * d, device);
15846    }
15847
15848    let ctx = device.context();
15849    let stream = device.stream();
15850
15851    let f = match crate::module_cache::get_or_compile(
15852        ctx,
15853        SCATTER_ADD_ROWS_PTX,
15854        "scatter_add_rows_kernel",
15855        device.ordinal() as u32,
15856    ) {
15857        Ok(f) => f,
15858        Err(_) => {
15859            // CPU fallback.
15860            let go_host = gpu_to_cpu(grad_output, device)?;
15861            let idx_host = gpu_to_cpu(indices, device)?;
15862            let mut result = vec![0.0f32; num_embeddings * d];
15863            for (i, &idx_f) in idx_host.iter().enumerate() {
15864                let row = idx_f as usize;
15865                for j in 0..d {
15866                    result[row * d + j] += go_host[i * d + j];
15867                }
15868            }
15869            return cpu_to_gpu(&result, device);
15870        }
15871    };
15872
15873    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
15874    let cfg = launch_cfg(total)?;
15875    let d_u32 = d as u32;
15876    let total_u32 = total as u32;
15877
15878    unsafe {
15879        stream
15880            .launch_builder(&f)
15881            .arg(grad_output.inner())
15882            .arg(indices.inner())
15883            .arg(out.inner_mut())
15884            .arg(&d_u32)
15885            .arg(&total_u32)
15886            .launch(cfg)?;
15887    }
15888
15889    Ok(out)
15890}
15891
15892/// 2D transpose into pre-allocated output.
15893#[cfg(feature = "cuda")]
15894pub fn gpu_transpose_2d_into(
15895    a: &CudaBuffer<f32>,
15896    m: usize,
15897    n: usize,
15898    out: &mut CudaBuffer<f32>,
15899    device: &GpuDevice,
15900) -> GpuResult<()> {
15901    use cudarc::driver::PushKernelArg;
15902    let total = m * n;
15903    let ctx = device.context();
15904    let stream = device.stream();
15905    let f = crate::module_cache::get_or_compile(
15906        ctx,
15907        TRANSPOSE_2D_PTX,
15908        "transpose_2d_kernel",
15909        device.ordinal() as u32,
15910    )
15911    .map_err(|_| GpuError::PtxCompileFailed {
15912        kernel: "transpose_2d_kernel",
15913    })?;
15914    let cfg = launch_cfg(total)?;
15915    let m_u32 = m as u32;
15916    let n_u32 = n as u32;
15917    let total_u32 = total as u32;
15918    unsafe {
15919        stream
15920            .launch_builder(&f)
15921            .arg(a.inner())
15922            .arg(out.inner_mut())
15923            .arg(&m_u32)
15924            .arg(&n_u32)
15925            .arg(&total_u32)
15926            .launch(cfg)?;
15927    }
15928    Ok(())
15929}
15930
15931/// Permute (0,2,1,3) into pre-allocated output.
15932#[cfg(feature = "cuda")]
15933pub fn gpu_permute_0213_into(
15934    a: &CudaBuffer<f32>,
15935    d0: usize,
15936    d1: usize,
15937    d2: usize,
15938    d3: usize,
15939    out: &mut CudaBuffer<f32>,
15940    device: &GpuDevice,
15941) -> GpuResult<()> {
15942    use cudarc::driver::PushKernelArg;
15943    let total = d0 * d1 * d2 * d3;
15944    let ctx = device.context();
15945    let stream = device.stream();
15946    let f = crate::module_cache::get_or_compile(
15947        ctx,
15948        PERMUTE_0213_PTX,
15949        "permute_0213_kernel",
15950        device.ordinal() as u32,
15951    )
15952    .map_err(|_| GpuError::PtxCompileFailed {
15953        kernel: "permute_0213_kernel",
15954    })?;
15955    let cfg = launch_cfg(total)?;
15956    let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);
15957    unsafe {
15958        stream
15959            .launch_builder(&f)
15960            .arg(a.inner())
15961            .arg(out.inner_mut())
15962            .arg(&d0u)
15963            .arg(&d1u)
15964            .arg(&d2u)
15965            .arg(&d3u)
15966            .arg(&tu)
15967            .launch(cfg)?;
15968    }
15969    Ok(())
15970}
15971
15972/// Softmax into pre-allocated output (row-wise).
15973#[cfg(feature = "cuda")]
15974pub fn gpu_softmax_into(
15975    a: &CudaBuffer<f32>,
15976    rows: usize,
15977    cols: usize,
15978    out: &mut CudaBuffer<f32>,
15979    device: &GpuDevice,
15980) -> GpuResult<()> {
15981    use cudarc::driver::PushKernelArg;
15982    let ctx = device.context();
15983    let stream = device.stream();
15984    let f = crate::module_cache::get_or_compile(
15985        ctx,
15986        SOFTMAX_PTX,
15987        "softmax_kernel",
15988        device.ordinal() as u32,
15989    )
15990    .map_err(|_| GpuError::PtxCompileFailed {
15991        kernel: "softmax_kernel",
15992    })?;
15993    let block_size = 256u32;
15994    let grid_size = rows as u32;
15995    let cfg = LaunchConfig {
15996        grid_dim: (grid_size, 1, 1),
15997        block_dim: (block_size, 1, 1),
15998        shared_mem_bytes: (cols as u32) * 4,
15999    };
16000    let rows_u32 = rows as u32;
16001    let cols_u32 = cols as u32;
16002    unsafe {
16003        stream
16004            .launch_builder(&f)
16005            .arg(a.inner())
16006            .arg(out.inner_mut())
16007            .arg(&rows_u32)
16008            .arg(&cols_u32)
16009            .launch(cfg)?;
16010    }
16011    Ok(())
16012}
16013
16014/// LayerNorm into pre-allocated output.
16015#[cfg(feature = "cuda")]
16016#[allow(clippy::too_many_arguments)]
16017pub fn gpu_layernorm_into(
16018    input: &CudaBuffer<f32>,
16019    weight: &CudaBuffer<f32>,
16020    bias: &CudaBuffer<f32>,
16021    rows: usize,
16022    cols: usize,
16023    eps: f32,
16024    out: &mut CudaBuffer<f32>,
16025    device: &GpuDevice,
16026) -> GpuResult<()> {
16027    use cudarc::driver::PushKernelArg;
16028    let ctx = device.context();
16029    let stream = device.stream();
16030    let f = crate::module_cache::get_or_compile(
16031        ctx,
16032        LAYERNORM_PTX,
16033        "layernorm_kernel",
16034        device.ordinal() as u32,
16035    )
16036    .map_err(|_| GpuError::PtxCompileFailed {
16037        kernel: "layernorm_kernel",
16038    })?;
16039    let block_size = 256u32;
16040    let grid_size = rows as u32;
16041    let cfg = LaunchConfig {
16042        grid_dim: (grid_size, 1, 1),
16043        block_dim: (block_size, 1, 1),
16044        shared_mem_bytes: (cols as u32) * 4,
16045    };
16046    let rows_u32 = rows as u32;
16047    let cols_u32 = cols as u32;
16048    unsafe {
16049        stream
16050            .launch_builder(&f)
16051            .arg(input.inner())
16052            .arg(out.inner_mut())
16053            .arg(weight.inner())
16054            .arg(bias.inner())
16055            .arg(&rows_u32)
16056            .arg(&cols_u32)
16057            .arg(&eps)
16058            .launch(cfg)?;
16059    }
16060    Ok(())
16061}
16062
16063/// Slice read into pre-allocated output: read first `len` rows from
16064/// `[n_batch, max_len, d]` into out `[n_batch, len, d]`.
16065#[cfg(feature = "cuda")]
16066pub fn gpu_slice_read_into(
16067    src: &CudaBuffer<f32>,
16068    n_batch: usize,
16069    d: usize,
16070    len: usize,
16071    max_len: usize,
16072    out: &mut CudaBuffer<f32>,
16073    device: &GpuDevice,
16074) -> GpuResult<()> {
16075    use cudarc::driver::PushKernelArg;
16076    let total = n_batch * len * d;
16077    let ctx = device.context();
16078    let stream = device.stream();
16079    let f = crate::module_cache::get_or_compile(
16080        ctx,
16081        SLICE_READ_PTX,
16082        "slice_read_kernel",
16083        device.ordinal() as u32,
16084    )
16085    .map_err(|_| GpuError::PtxCompileFailed {
16086        kernel: "slice_read_kernel",
16087    })?;
16088    let cfg = launch_cfg(total)?;
16089    let total_u32 = total as u32;
16090    let d_u32 = d as u32;
16091    let len_u32 = len as u32;
16092    let max_len_u32 = max_len as u32;
16093    unsafe {
16094        stream
16095            .launch_builder(&f)
16096            .arg(src.inner())
16097            .arg(out.inner_mut())
16098            .arg(&total_u32)
16099            .arg(&d_u32)
16100            .arg(&len_u32)
16101            .arg(&max_len_u32)
16102            .launch(cfg)?;
16103    }
16104    Ok(())
16105}
16106
16107/// Small matmul (PTX kernel) into pre-allocated output.
16108#[cfg(feature = "cuda")]
16109pub fn gpu_small_matmul_into(
16110    a: &CudaBuffer<f32>,
16111    b: &CudaBuffer<f32>,
16112    m: usize,
16113    k: usize,
16114    n: usize,
16115    out: &mut CudaBuffer<f32>,
16116    device: &GpuDevice,
16117) -> GpuResult<()> {
16118    use cudarc::driver::PushKernelArg;
16119    let total = m * n;
16120    let ctx = device.context();
16121    let stream = device.stream();
16122    let f = crate::module_cache::get_or_compile(
16123        ctx,
16124        SMALL_MATMUL_PTX,
16125        "small_matmul_kernel",
16126        device.ordinal() as u32,
16127    )
16128    .map_err(|_| GpuError::PtxCompileFailed {
16129        kernel: "small_matmul_kernel",
16130    })?;
16131    let cfg = launch_cfg(total)?;
16132    let (m_u32, k_u32, n_u32, total_u32) = (m as u32, k as u32, n as u32, total as u32);
16133    unsafe {
16134        stream
16135            .launch_builder(&f)
16136            .arg(a.inner())
16137            .arg(b.inner())
16138            .arg(out.inner_mut())
16139            .arg(&m_u32)
16140            .arg(&k_u32)
16141            .arg(&n_u32)
16142            .arg(&total_u32)
16143            .launch(cfg)?;
16144    }
16145    Ok(())
16146}
16147
16148// ===========================================================================
16149// Indirect-parameter kernels for CUDA graph capture
16150// ===========================================================================
16151
16152/// Slice write with position read from device memory (for CUDA graph capture).
16153/// Writes `src [n_batch, d]` into row `*pos_ptr` of `dst [n_batch, max_len, d]`.
16154#[cfg(feature = "cuda")]
16155pub fn gpu_slice_write_indirect(
16156    src: &CudaBuffer<f32>,
16157    dst: &mut CudaBuffer<f32>,
16158    n_batch: usize,
16159    d: usize,
16160    max_len: usize,
16161    pos_ptr: &cudarc::driver::CudaSlice<u32>,
16162    device: &GpuDevice,
16163) -> GpuResult<()> {
16164    use cudarc::driver::PushKernelArg;
16165    let total = n_batch * d;
16166    let ctx = device.context();
16167    let stream = device.stream();
16168    let f = crate::module_cache::get_or_compile(
16169        ctx,
16170        SLICE_WRITE_INDIRECT_PTX,
16171        "slice_write_indirect_kernel",
16172        device.ordinal() as u32,
16173    )
16174    .map_err(|_| GpuError::PtxCompileFailed {
16175        kernel: "slice_write_indirect_kernel",
16176    })?;
16177    let cfg = launch_cfg(total)?;
16178    let n_u32 = total as u32;
16179    let d_u32 = d as u32;
16180    let max_len_u32 = max_len as u32;
16181    unsafe {
16182        stream
16183            .launch_builder(&f)
16184            .arg(src.inner())
16185            .arg(dst.inner_mut())
16186            .arg(&n_u32)
16187            .arg(&d_u32)
16188            .arg(&max_len_u32)
16189            .arg(pos_ptr)
16190            .launch(cfg)?;
16191    }
16192    Ok(())
16193}
16194
/// Build causal attention mask with total_len read from device memory.
///
/// Writes `out[h, col] = 0.0` if `col < *total_len_ptr`, else `-1e9`.
/// Output shape: `[n_head, max_pos]` (n_head rows, each max_pos wide).
/// Because the length is read through a device pointer, the host-side launch
/// arguments stay identical from call to call. This is what lets the launch be
/// recorded once under CUDA graph capture.
///
/// When `n_head * max_pos == 0` there is nothing to fill and the function
/// returns `Ok(())` without compiling or launching the kernel.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the kernel module cannot be compiled.
/// - Errors from `launch_cfg` or the kernel launch are propagated.
#[cfg(feature = "cuda")]
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_head * max_pos;
    // No work to do, and an empty grid is not a valid launch configuration.
    if total == 0 {
        return Ok(());
    }
    // The kernel fills `n_head * max_pos` elements of `out`; a smaller buffer
    // would be written out of bounds on the device.
    debug_assert!(
        out.len() >= total,
        "gpu_causal_mask_indirect: out holds {} elements, need {}",
        out.len(),
        total
    );
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let max_pos_u32 = max_pos as u32;
    let total_u32 = total as u32;
    // SAFETY: `total_len_ptr` and `out` are live device allocations borrowed for
    // the whole call, and `out` is checked (in debug builds) to hold the
    // `[n_head, max_pos]` result. The PTX entry point is assumed to take
    // (total_len_ptr, out, max_pos, total), which is the push order below.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&max_pos_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
16233
16234// ===========================================================================
16235// Pre-compilation of all decode-path PTX modules
16236// ===========================================================================
16237
/// Load every PTX module used by the decode pass into the module cache up front.
///
/// `cuModuleLoadData` is not a capturable operation. Call this before starting
/// CUDA graph capture so that no module load happens while capture is active.
///
/// # Errors
///
/// Propagates a failure to bind the device context to the current thread.
/// Returns [`GpuError::Driver`] for the first kernel that fails to compile;
/// kernels listed after it are not attempted.
#[cfg(feature = "cuda")]
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    // (PTX source, entry-point name), compiled in this order.
    let decode_kernels: [(&'static str, &'static str); 16] = [
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
    ];

    let ctx = device.context();
    ctx.bind_to_thread()?;
    let ordinal = device.ordinal() as u32;

    for &(ptx, name) in decode_kernels.iter() {
        // Only the side effect (populating the cache) matters here.
        crate::module_cache::get_or_compile(ctx, ptx, name, ordinal)
            .map(|_| ())
            .map_err(GpuError::Driver)?;
    }
    Ok(())
}
16269
16270/// Stub — no-op without cuda.
16271#[cfg(not(feature = "cuda"))]
16272pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
16273    Err(GpuError::NoCudaFeature)
16274}
16275
16276// ---------------------------------------------------------------------------
16277// Stubs when `cuda` feature is disabled
16278// ---------------------------------------------------------------------------
16279
16280/// Stub -- always returns [`GpuError::NoCudaFeature`].
16281#[cfg(not(feature = "cuda"))]
16282pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16283    Err(GpuError::NoCudaFeature)
16284}
16285
16286/// Stub -- always returns [`GpuError::NoCudaFeature`].
16287#[cfg(not(feature = "cuda"))]
16288pub fn gpu_gelu_tanh(
16289    _input: &CudaBuffer<f32>,
16290    _device: &GpuDevice,
16291) -> GpuResult<CudaBuffer<f32>> {
16292    Err(GpuError::NoCudaFeature)
16293}
16294
16295/// Stub -- always returns [`GpuError::NoCudaFeature`].
16296#[cfg(not(feature = "cuda"))]
16297pub fn gpu_gelu_erf(
16298    _input: &CudaBuffer<f32>,
16299    _device: &GpuDevice,
16300) -> GpuResult<CudaBuffer<f32>> {
16301    Err(GpuError::NoCudaFeature)
16302}
16303
16304/// Stub -- always returns [`GpuError::NoCudaFeature`].
16305#[cfg(not(feature = "cuda"))]
16306pub fn gpu_gelu_backward_tanh(
16307    _grad: &CudaBuffer<f32>,
16308    _input: &CudaBuffer<f32>,
16309    _device: &GpuDevice,
16310) -> GpuResult<CudaBuffer<f32>> {
16311    Err(GpuError::NoCudaFeature)
16312}
16313
16314/// Stub -- always returns [`GpuError::NoCudaFeature`].
16315#[cfg(not(feature = "cuda"))]
16316pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16317    Err(GpuError::NoCudaFeature)
16318}
16319
16320/// Stub -- always returns [`GpuError::NoCudaFeature`].
16321#[cfg(not(feature = "cuda"))]
16322pub fn gpu_silu_backward(
16323    _grad: &CudaBuffer<f32>,
16324    _input: &CudaBuffer<f32>,
16325    _device: &GpuDevice,
16326) -> GpuResult<CudaBuffer<f32>> {
16327    Err(GpuError::NoCudaFeature)
16328}
16329
16330/// Stub -- always returns [`GpuError::NoCudaFeature`].
16331#[cfg(not(feature = "cuda"))]
16332pub fn gpu_elu(
16333    _input: &CudaBuffer<f32>,
16334    _alpha: f32,
16335    _device: &GpuDevice,
16336) -> GpuResult<CudaBuffer<f32>> {
16337    Err(GpuError::NoCudaFeature)
16338}
16339
16340/// Stub -- always returns [`GpuError::NoCudaFeature`].
16341#[cfg(not(feature = "cuda"))]
16342pub fn gpu_elu_backward(
16343    _grad: &CudaBuffer<f32>,
16344    _input: &CudaBuffer<f32>,
16345    _alpha: f32,
16346    _device: &GpuDevice,
16347) -> GpuResult<CudaBuffer<f32>> {
16348    Err(GpuError::NoCudaFeature)
16349}
16350
16351/// Stub -- always returns [`GpuError::NoCudaFeature`].
16352#[cfg(not(feature = "cuda"))]
16353pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16354    Err(GpuError::NoCudaFeature)
16355}
16356
16357/// Stub -- always returns [`GpuError::NoCudaFeature`].
16358#[cfg(not(feature = "cuda"))]
16359pub fn gpu_mish_backward(
16360    _grad: &CudaBuffer<f32>,
16361    _input: &CudaBuffer<f32>,
16362    _device: &GpuDevice,
16363) -> GpuResult<CudaBuffer<f32>> {
16364    Err(GpuError::NoCudaFeature)
16365}
16366
16367/// Stub -- always returns [`GpuError::NoCudaFeature`].
16368#[cfg(not(feature = "cuda"))]
16369pub fn gpu_clamp(
16370    _input: &CudaBuffer<f32>,
16371    _min_val: f32,
16372    _max_val: f32,
16373    _device: &GpuDevice,
16374) -> GpuResult<CudaBuffer<f32>> {
16375    Err(GpuError::NoCudaFeature)
16376}
16377
16378/// Stub -- always returns [`GpuError::NoCudaFeature`].
16379#[cfg(not(feature = "cuda"))]
16380pub fn gpu_div(
16381    _a: &CudaBuffer<f32>,
16382    _b: &CudaBuffer<f32>,
16383    _device: &GpuDevice,
16384) -> GpuResult<CudaBuffer<f32>> {
16385    Err(GpuError::NoCudaFeature)
16386}
16387
16388/// Stub -- always returns [`GpuError::NoCudaFeature`].
16389#[cfg(not(feature = "cuda"))]
16390pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16391    Err(GpuError::NoCudaFeature)
16392}
16393
16394/// Stub -- always returns [`GpuError::NoCudaFeature`].
16395#[cfg(not(feature = "cuda"))]
16396pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16397    Err(GpuError::NoCudaFeature)
16398}
16399
16400/// Stub -- always returns [`GpuError::NoCudaFeature`].
16401#[cfg(not(feature = "cuda"))]
16402pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16403    Err(GpuError::NoCudaFeature)
16404}
16405
16406/// Stub -- always returns [`GpuError::NoCudaFeature`].
16407#[cfg(not(feature = "cuda"))]
16408pub fn gpu_pow(
16409    _a: &CudaBuffer<f32>,
16410    _exponent: f32,
16411    _device: &GpuDevice,
16412) -> GpuResult<CudaBuffer<f32>> {
16413    Err(GpuError::NoCudaFeature)
16414}
16415
16416/// Stub -- always returns [`GpuError::NoCudaFeature`].
16417#[cfg(not(feature = "cuda"))]
16418pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16419    Err(GpuError::NoCudaFeature)
16420}
16421
16422/// Stub -- always returns [`GpuError::NoCudaFeature`].
16423#[cfg(not(feature = "cuda"))]
16424pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16425    Err(GpuError::NoCudaFeature)
16426}
16427
16428/// Stub -- always returns [`GpuError::NoCudaFeature`].
16429#[cfg(not(feature = "cuda"))]
16430pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16431    Err(GpuError::NoCudaFeature)
16432}
16433
16434/// Stub -- always returns [`GpuError::NoCudaFeature`].
16435#[cfg(not(feature = "cuda"))]
16436pub fn gpu_layernorm(
16437    _input: &CudaBuffer<f32>,
16438    _weight: &CudaBuffer<f32>,
16439    _bias: &CudaBuffer<f32>,
16440    _rows: usize,
16441    _cols: usize,
16442    _eps: f32,
16443    _device: &GpuDevice,
16444) -> GpuResult<CudaBuffer<f32>> {
16445    Err(GpuError::NoCudaFeature)
16446}
16447
16448/// Stub -- always returns [`GpuError::NoCudaFeature`].
16449#[cfg(not(feature = "cuda"))]
16450pub fn gpu_transpose_2d(
16451    _input: &CudaBuffer<f32>,
16452    _m: usize,
16453    _n: usize,
16454    _device: &GpuDevice,
16455) -> GpuResult<CudaBuffer<f32>> {
16456    Err(GpuError::NoCudaFeature)
16457}
16458
16459/// Stub -- always returns [`GpuError::NoCudaFeature`].
16460#[cfg(not(feature = "cuda"))]
16461pub fn gpu_add(
16462    _a: &CudaBuffer<f32>,
16463    _b: &CudaBuffer<f32>,
16464    _device: &GpuDevice,
16465) -> GpuResult<CudaBuffer<f32>> {
16466    Err(GpuError::NoCudaFeature)
16467}
16468
16469/// Stub -- always returns [`GpuError::NoCudaFeature`].
16470#[cfg(not(feature = "cuda"))]
16471pub fn gpu_sub(
16472    _a: &CudaBuffer<f32>,
16473    _b: &CudaBuffer<f32>,
16474    _device: &GpuDevice,
16475) -> GpuResult<CudaBuffer<f32>> {
16476    Err(GpuError::NoCudaFeature)
16477}
16478
16479/// Stub -- always returns [`GpuError::NoCudaFeature`].
16480#[cfg(not(feature = "cuda"))]
16481pub fn gpu_mul(
16482    _a: &CudaBuffer<f32>,
16483    _b: &CudaBuffer<f32>,
16484    _device: &GpuDevice,
16485) -> GpuResult<CudaBuffer<f32>> {
16486    Err(GpuError::NoCudaFeature)
16487}
16488
16489/// Stub -- always returns [`GpuError::NoCudaFeature`].
16490#[cfg(not(feature = "cuda"))]
16491pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16492    Err(GpuError::NoCudaFeature)
16493}
16494
16495/// Stub -- always returns [`GpuError::NoCudaFeature`].
16496#[cfg(not(feature = "cuda"))]
16497pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16498    Err(GpuError::NoCudaFeature)
16499}
16500
16501/// Stub -- always returns [`GpuError::NoCudaFeature`].
16502#[cfg(not(feature = "cuda"))]
16503pub fn gpu_scale(
16504    _a: &CudaBuffer<f32>,
16505    _scalar: f32,
16506    _device: &GpuDevice,
16507) -> GpuResult<CudaBuffer<f32>> {
16508    Err(GpuError::NoCudaFeature)
16509}
16510
16511/// Stub -- always returns [`GpuError::NoCudaFeature`].
16512#[cfg(not(feature = "cuda"))]
16513pub fn gpu_broadcast_add(
16514    _a: &CudaBuffer<f32>,
16515    _b: &CudaBuffer<f32>,
16516    _a_shape: &[usize],
16517    _b_shape: &[usize],
16518    _out_shape: &[usize],
16519    _device: &GpuDevice,
16520) -> GpuResult<CudaBuffer<f32>> {
16521    Err(GpuError::NoCudaFeature)
16522}
16523
16524/// Stub -- always returns [`GpuError::NoCudaFeature`].
16525#[cfg(not(feature = "cuda"))]
16526pub fn gpu_broadcast_sub(
16527    _a: &CudaBuffer<f32>,
16528    _b: &CudaBuffer<f32>,
16529    _a_shape: &[usize],
16530    _b_shape: &[usize],
16531    _out_shape: &[usize],
16532    _device: &GpuDevice,
16533) -> GpuResult<CudaBuffer<f32>> {
16534    Err(GpuError::NoCudaFeature)
16535}
16536
16537/// Stub -- always returns [`GpuError::NoCudaFeature`].
16538#[cfg(not(feature = "cuda"))]
16539pub fn gpu_broadcast_mul(
16540    _a: &CudaBuffer<f32>,
16541    _b: &CudaBuffer<f32>,
16542    _a_shape: &[usize],
16543    _b_shape: &[usize],
16544    _out_shape: &[usize],
16545    _device: &GpuDevice,
16546) -> GpuResult<CudaBuffer<f32>> {
16547    Err(GpuError::NoCudaFeature)
16548}
16549
16550/// Stub -- always returns [`GpuError::NoCudaFeature`].
16551#[cfg(not(feature = "cuda"))]
16552pub fn gpu_softmax(
16553    _input: &CudaBuffer<f32>,
16554    _rows: usize,
16555    _cols: usize,
16556    _device: &GpuDevice,
16557) -> GpuResult<CudaBuffer<f32>> {
16558    Err(GpuError::NoCudaFeature)
16559}
16560
16561/// Stub -- always returns [`GpuError::NoCudaFeature`].
16562#[cfg(not(feature = "cuda"))]
16563pub fn gpu_dropout(
16564    _input: &CudaBuffer<f32>,
16565    _threshold: u32,
16566    _scale: f32,
16567    _seed: u32,
16568    _device: &GpuDevice,
16569) -> GpuResult<CudaBuffer<f32>> {
16570    Err(GpuError::NoCudaFeature)
16571}
16572
16573/// Stub -- always returns [`GpuError::NoCudaFeature`].
16574#[cfg(not(feature = "cuda"))]
16575pub fn gpu_permute_0213(
16576    _input: &CudaBuffer<f32>,
16577    _d0: usize,
16578    _d1: usize,
16579    _d2: usize,
16580    _d3: usize,
16581    _device: &GpuDevice,
16582) -> GpuResult<CudaBuffer<f32>> {
16583    Err(GpuError::NoCudaFeature)
16584}
16585
16586/// Stub -- always returns [`GpuError::NoCudaFeature`].
16587#[cfg(not(feature = "cuda"))]
16588pub fn gpu_slice_write(
16589    _src: &CudaBuffer<f32>,
16590    _dst: &mut CudaBuffer<f32>,
16591    _n_batch: usize,
16592    _d: usize,
16593    _max_len: usize,
16594    _pos: usize,
16595    _device: &GpuDevice,
16596) -> GpuResult<()> {
16597    Err(GpuError::NoCudaFeature)
16598}
16599
16600/// Stub -- always returns [`GpuError::NoCudaFeature`].
16601#[cfg(not(feature = "cuda"))]
16602pub fn gpu_slice_read(
16603    _src: &CudaBuffer<f32>,
16604    _n_batch: usize,
16605    _d: usize,
16606    _len: usize,
16607    _max_len: usize,
16608    _device: &GpuDevice,
16609) -> GpuResult<CudaBuffer<f32>> {
16610    Err(GpuError::NoCudaFeature)
16611}
16612
16613/// Stub -- always returns [`GpuError::NoCudaFeature`].
16614#[cfg(not(feature = "cuda"))]
16615pub fn gpu_embed_lookup(
16616    _idx: &CudaBuffer<f32>,
16617    _weight: &CudaBuffer<f32>,
16618    _d: usize,
16619    _device: &GpuDevice,
16620) -> GpuResult<CudaBuffer<f32>> {
16621    Err(GpuError::NoCudaFeature)
16622}
16623
16624/// Stub -- always returns [`GpuError::NoCudaFeature`].
16625#[cfg(not(feature = "cuda"))]
16626pub fn gpu_embed_lookup_batch(
16627    _indices: &CudaBuffer<f32>,
16628    _weight: &CudaBuffer<f32>,
16629    _n: usize,
16630    _d: usize,
16631    _device: &GpuDevice,
16632) -> GpuResult<CudaBuffer<f32>> {
16633    Err(GpuError::NoCudaFeature)
16634}
16635
16636/// Stub -- always returns [`GpuError::NoCudaFeature`].
16637#[cfg(not(feature = "cuda"))]
16638pub fn gpu_scatter_add_rows(
16639    _grad_output: &CudaBuffer<f32>,
16640    _indices: &CudaBuffer<f32>,
16641    _num_embeddings: usize,
16642    _d: usize,
16643    _device: &GpuDevice,
16644) -> GpuResult<CudaBuffer<f32>> {
16645    Err(GpuError::NoCudaFeature)
16646}
16647
16648/// Stub -- always returns [`GpuError::NoCudaFeature`].
16649#[cfg(not(feature = "cuda"))]
16650pub fn gpu_relu_backward(
16651    _grad: &CudaBuffer<f32>,
16652    _input: &CudaBuffer<f32>,
16653    _device: &GpuDevice,
16654) -> GpuResult<CudaBuffer<f32>> {
16655    Err(GpuError::NoCudaFeature)
16656}
16657
16658/// Stub -- always returns [`GpuError::NoCudaFeature`].
16659#[cfg(not(feature = "cuda"))]
16660pub fn gpu_abs_backward(
16661    _grad: &CudaBuffer<f32>,
16662    _input: &CudaBuffer<f32>,
16663    _device: &GpuDevice,
16664) -> GpuResult<CudaBuffer<f32>> {
16665    Err(GpuError::NoCudaFeature)
16666}
16667
16668/// Stub -- always returns [`GpuError::NoCudaFeature`].
16669#[cfg(not(feature = "cuda"))]
16670pub fn gpu_fill_f32(
16671    _n: usize,
16672    _scalar: f32,
16673    _device: &GpuDevice,
16674) -> GpuResult<CudaBuffer<f32>> {
16675    Err(GpuError::NoCudaFeature)
16676}
16677
16678/// Stub -- always returns [`GpuError::NoCudaFeature`].
16679#[cfg(not(feature = "cuda"))]
16680pub fn gpu_gelu_backward(
16681    _grad: &CudaBuffer<f32>,
16682    _input: &CudaBuffer<f32>,
16683    _device: &GpuDevice,
16684) -> GpuResult<CudaBuffer<f32>> {
16685    Err(GpuError::NoCudaFeature)
16686}
16687
16688/// Stub -- always returns [`GpuError::NoCudaFeature`].
16689#[cfg(not(feature = "cuda"))]
16690pub fn gpu_index_select_1d(
16691    _input: &CudaBuffer<f32>,
16692    _indices: &CudaBuffer<f32>,
16693    _device: &GpuDevice,
16694) -> GpuResult<CudaBuffer<f32>> {
16695    Err(GpuError::NoCudaFeature)
16696}
16697
16698/// Stub -- always returns [`GpuError::NoCudaFeature`].
16699#[cfg(not(feature = "cuda"))]
16700pub fn gpu_scatter_add_1d(
16701    _grad_output: &CudaBuffer<f32>,
16702    _indices: &CudaBuffer<f32>,
16703    _input_len: usize,
16704    _device: &GpuDevice,
16705) -> GpuResult<CudaBuffer<f32>> {
16706    Err(GpuError::NoCudaFeature)
16707}
16708
16709/// Stub -- always returns [`GpuError::NoCudaFeature`].
16710#[cfg(not(feature = "cuda"))]
16711pub fn gpu_masked_fill(
16712    _input: &CudaBuffer<f32>,
16713    _mask: &CudaBuffer<f32>,
16714    _value: f32,
16715    _device: &GpuDevice,
16716) -> GpuResult<CudaBuffer<f32>> {
16717    Err(GpuError::NoCudaFeature)
16718}
16719
16720/// Stub -- always returns [`GpuError::NoCudaFeature`].
16721#[cfg(not(feature = "cuda"))]
16722pub fn gpu_masked_zero(
16723    _grad: &CudaBuffer<f32>,
16724    _mask: &CudaBuffer<f32>,
16725    _device: &GpuDevice,
16726) -> GpuResult<CudaBuffer<f32>> {
16727    Err(GpuError::NoCudaFeature)
16728}
16729
16730/// Stub -- always returns [`GpuError::NoCudaFeature`].
16731#[cfg(not(feature = "cuda"))]
16732pub fn gpu_sigmoid_backward(
16733    _grad: &CudaBuffer<f32>,
16734    _output: &CudaBuffer<f32>,
16735    _device: &GpuDevice,
16736) -> GpuResult<CudaBuffer<f32>> {
16737    Err(GpuError::NoCudaFeature)
16738}
16739
16740/// Stub -- always returns [`GpuError::NoCudaFeature`].
16741#[cfg(not(feature = "cuda"))]
16742pub fn gpu_tanh_backward(
16743    _grad: &CudaBuffer<f32>,
16744    _output: &CudaBuffer<f32>,
16745    _device: &GpuDevice,
16746) -> GpuResult<CudaBuffer<f32>> {
16747    Err(GpuError::NoCudaFeature)
16748}
16749
16750/// Stub -- always returns [`GpuError::NoCudaFeature`].
16751#[cfg(not(feature = "cuda"))]
16752pub fn gpu_softmax_backward(
16753    _grad: &CudaBuffer<f32>,
16754    _output: &CudaBuffer<f32>,
16755    _cols: usize,
16756    _device: &GpuDevice,
16757) -> GpuResult<CudaBuffer<f32>> {
16758    Err(GpuError::NoCudaFeature)
16759}
16760
16761/// Stub -- always returns [`GpuError::NoCudaFeature`].
16762#[cfg(not(feature = "cuda"))]
16763pub fn gpu_log_softmax(
16764    _input: &CudaBuffer<f32>,
16765    _cols: usize,
16766    _device: &GpuDevice,
16767) -> GpuResult<CudaBuffer<f32>> {
16768    Err(GpuError::NoCudaFeature)
16769}
16770
16771/// Stub -- always returns [`GpuError::NoCudaFeature`].
16772#[cfg(not(feature = "cuda"))]
16773pub fn gpu_log_softmax_backward(
16774    _grad: &CudaBuffer<f32>,
16775    _output: &CudaBuffer<f32>,
16776    _cols: usize,
16777    _device: &GpuDevice,
16778) -> GpuResult<CudaBuffer<f32>> {
16779    Err(GpuError::NoCudaFeature)
16780}
16781
16782/// Stub -- always returns [`GpuError::NoCudaFeature`].
16783#[cfg(not(feature = "cuda"))]
16784pub fn gpu_sum_axis(
16785    _a: &CudaBuffer<f32>,
16786    _outer: usize,
16787    _axis_size: usize,
16788    _inner: usize,
16789    _device: &GpuDevice,
16790) -> GpuResult<CudaBuffer<f32>> {
16791    Err(GpuError::NoCudaFeature)
16792}
16793
16794/// Stub -- always returns [`GpuError::NoCudaFeature`].
16795#[cfg(not(feature = "cuda"))]
16796pub fn gpu_cumsum(
16797    _input: &CudaBuffer<f32>,
16798    _outer: usize,
16799    _dim_size: usize,
16800    _inner: usize,
16801    _device: &GpuDevice,
16802) -> GpuResult<CudaBuffer<f32>> {
16803    Err(GpuError::NoCudaFeature)
16804}
16805
16806/// Stub -- always returns [`GpuError::NoCudaFeature`].
16807#[cfg(not(feature = "cuda"))]
16808pub fn gpu_cumprod(
16809    _input: &CudaBuffer<f32>,
16810    _outer: usize,
16811    _dim_size: usize,
16812    _inner: usize,
16813    _device: &GpuDevice,
16814) -> GpuResult<CudaBuffer<f32>> {
16815    Err(GpuError::NoCudaFeature)
16816}
16817
16818/// Stub -- always returns [`GpuError::NoCudaFeature`].
16819#[cfg(not(feature = "cuda"))]
16820pub fn gpu_cummax(
16821    _input: &CudaBuffer<f32>,
16822    _outer: usize,
16823    _dim_size: usize,
16824    _inner: usize,
16825    _device: &GpuDevice,
16826) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
16827    Err(GpuError::NoCudaFeature)
16828}
16829
16830/// Stub -- always returns [`GpuError::NoCudaFeature`].
16831#[cfg(not(feature = "cuda"))]
16832pub fn gpu_cummin(
16833    _input: &CudaBuffer<f32>,
16834    _outer: usize,
16835    _dim_size: usize,
16836    _inner: usize,
16837    _device: &GpuDevice,
16838) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
16839    Err(GpuError::NoCudaFeature)
16840}
16841
16842/// Stub -- always returns [`GpuError::NoCudaFeature`].
16843#[cfg(not(feature = "cuda"))]
16844pub fn gpu_logcumsumexp(
16845    _input: &CudaBuffer<f32>,
16846    _outer: usize,
16847    _dim_size: usize,
16848    _inner: usize,
16849    _device: &GpuDevice,
16850) -> GpuResult<CudaBuffer<f32>> {
16851    Err(GpuError::NoCudaFeature)
16852}
16853
16854/// Stub -- always returns [`GpuError::NoCudaFeature`].
16855#[cfg(not(feature = "cuda"))]
16856pub fn gpu_strided_split(
16857    _input: &CudaBuffer<f32>,
16858    _total_along_axis: usize,
16859    _split_offset: usize,
16860    _split_size: usize,
16861    _inner_size: usize,
16862    _n: usize,
16863    _device: &GpuDevice,
16864) -> GpuResult<CudaBuffer<f32>> {
16865    Err(GpuError::NoCudaFeature)
16866}
16867
16868/// Stub -- always returns [`GpuError::NoCudaFeature`].
16869#[cfg(not(feature = "cuda"))]
16870pub fn gpu_strided_cat(
16871    _input: &CudaBuffer<f32>,
16872    _output: &mut CudaBuffer<f32>,
16873    _total_along_axis: usize,
16874    _cat_offset: usize,
16875    _part_size: usize,
16876    _inner_size: usize,
16877    _n: usize,
16878    _device: &GpuDevice,
16879) -> GpuResult<()> {
16880    Err(GpuError::NoCudaFeature)
16881}
16882
/// Maximum tensor rank for the strided-copy path, defined here as well so that
/// builds without the `cuda` feature still expose the constant. Its value must
/// stay equal to the `cuda`-enabled definition elsewhere in this file.
#[cfg(not(feature = "cuda"))]
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
16887
16888/// Stub -- always returns [`GpuError::NoCudaFeature`].
16889#[cfg(not(feature = "cuda"))]
16890pub fn gpu_strided_copy(
16891    _input: &CudaBuffer<f32>,
16892    _out_shape: &[usize],
16893    _src_strides: &[isize],
16894    _src_offset: usize,
16895    _device: &GpuDevice,
16896) -> GpuResult<CudaBuffer<f32>> {
16897    Err(GpuError::NoCudaFeature)
16898}
16899
16900/// Stub -- always returns [`GpuError::NoCudaFeature`].
16901#[cfg(not(feature = "cuda"))]
16902pub fn gpu_strided_copy_f64(
16903    _input: &CudaBuffer<f64>,
16904    _out_shape: &[usize],
16905    _src_strides: &[isize],
16906    _src_offset: usize,
16907    _device: &GpuDevice,
16908) -> GpuResult<CudaBuffer<f64>> {
16909    Err(GpuError::NoCudaFeature)
16910}
16911
16912// ---------------------------------------------------------------------------
16913// f32-to-f16 GPU conversion
16914// ---------------------------------------------------------------------------
16915
/// Convert an f32 GPU buffer to IEEE 754 half precision, entirely on the device.
///
/// The result is a `CudaSlice<u16>`: every `u16` is the raw bit pattern of an
/// f16 value produced by the PTX `cvt.rn.f16.f32` instruction, which rounds to
/// nearest-even. An empty input yields an empty slice without compiling or
/// launching the kernel.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be compiled
///   (e.g., GPU architecture too old to support f16 conversion instructions).
/// - [`GpuError::Driver`] on CUDA launch errors.
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_f16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let stream = device.stream();
    let len = input.len();
    if len == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }

    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_F16_PTX,
        "f32_to_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_f16_kernel",
    })?;

    let mut halves = stream.alloc_zeros::<u16>(len)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: `input` holds `len` f32 values and `halves` was just allocated
    // with room for `len` u16 values; both live in device memory for the whole
    // launch. The grid from `launch_cfg(len)` covers `len` elements, and the
    // kernel is given `len` as its bound.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut halves)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(halves)
}
16971
16972/// Stub -- always returns [`GpuError::NoCudaFeature`].
16973#[cfg(not(feature = "cuda"))]
16974pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
16975    Err(GpuError::NoCudaFeature)
16976}
16977
/// Convert an f32 GPU buffer to bfloat16 on the device.
///
/// Each output `u16` holds a bf16 bit pattern. The kernel derives it with
/// integer bit manipulation that rounds to nearest-even, so no dedicated bf16
/// hardware is required (works on sm_52+). An empty input yields an empty
/// slice without compiling or launching the kernel.
///
/// # Errors
///
/// - [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be compiled.
/// - Driver errors from allocation or launch are propagated.
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_bf16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let stream = device.stream();
    let count = input.len();
    if count == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }

    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_BF16_PTX,
        "f32_to_bf16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_bf16_kernel",
    })?;

    let mut bf16_bits = stream.alloc_zeros::<u16>(count)?;
    let cfg = launch_cfg(count)?;
    let count_u32 = count as u32;

    // SAFETY: `input` holds `count` f32 values and `bf16_bits` was just
    // allocated with room for `count` u16 values; both stay alive in device
    // memory for the launch. The grid from `launch_cfg(count)` covers `count`
    // elements, and the kernel receives `count` as its bound.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut bf16_bits)
            .arg(&count_u32)
            .launch(cfg)?;
    }

    Ok(bf16_bits)
}
17023
17024/// Stub -- always returns [`GpuError::NoCudaFeature`].
17025#[cfg(not(feature = "cuda"))]
17026pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
17027    Err(GpuError::NoCudaFeature)
17028}
17029
17030// ---------------------------------------------------------------------------
17031// Non-CUDA stubs -- f64 ops
17032// ---------------------------------------------------------------------------
17033
17034#[cfg(not(feature = "cuda"))]
17035pub fn gpu_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17036#[cfg(not(feature = "cuda"))]
17037pub fn gpu_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17038#[cfg(not(feature = "cuda"))]
17039pub fn gpu_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17040#[cfg(not(feature = "cuda"))]
17041pub fn gpu_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17042#[cfg(not(feature = "cuda"))]
17043pub fn gpu_neg_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17044#[cfg(not(feature = "cuda"))]
17045pub fn gpu_relu_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17046#[cfg(not(feature = "cuda"))]
17047pub fn gpu_scale_f64(_a: &CudaBuffer<f64>, _scalar: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17048#[cfg(not(feature = "cuda"))]
17049pub fn gpu_exp_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17050#[cfg(not(feature = "cuda"))]
17051pub fn gpu_log_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17052#[cfg(not(feature = "cuda"))]
17053pub fn gpu_sqrt_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17054#[cfg(not(feature = "cuda"))]
17055pub fn gpu_pow_f64(_a: &CudaBuffer<f64>, _exponent: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17056#[cfg(not(feature = "cuda"))]
17057pub fn gpu_abs_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17058#[cfg(not(feature = "cuda"))]
17059pub fn gpu_sigmoid_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17060#[cfg(not(feature = "cuda"))]
17061pub fn gpu_tanh_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17062#[cfg(not(feature = "cuda"))]
17063pub fn gpu_relu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17064#[cfg(not(feature = "cuda"))]
17065pub fn gpu_sigmoid_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17066#[cfg(not(feature = "cuda"))]
17067pub fn gpu_tanh_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17068#[cfg(not(feature = "cuda"))]
17069pub fn gpu_broadcast_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17070#[cfg(not(feature = "cuda"))]
17071pub fn gpu_broadcast_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17072#[cfg(not(feature = "cuda"))]
17073pub fn gpu_broadcast_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17074#[cfg(not(feature = "cuda"))]
17075pub fn gpu_broadcast_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17076#[cfg(not(feature = "cuda"))]
17077pub fn gpu_transpose_2d_f64(_input: &CudaBuffer<f64>, _m: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17078#[cfg(not(feature = "cuda"))]
17079pub fn gpu_permute_0213_f64(_input: &CudaBuffer<f64>, _d0: usize, _d1: usize, _d2: usize, _d3: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17080#[cfg(not(feature = "cuda"))]
17081pub fn gpu_strided_split_f64(_input: &CudaBuffer<f64>, _total_along_axis: usize, _split_offset: usize, _split_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17082#[cfg(not(feature = "cuda"))]
17083pub fn gpu_strided_cat_f64(_input: &CudaBuffer<f64>, _output: &mut CudaBuffer<f64>, _total_along_axis: usize, _cat_offset: usize, _part_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
17084#[cfg(not(feature = "cuda"))]
17085pub fn gpu_index_select_1d_f64(_input: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17086#[cfg(not(feature = "cuda"))]
17087pub fn gpu_scatter_add_1d_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _input_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17088#[cfg(not(feature = "cuda"))]
17089pub fn gpu_masked_fill_f64(_input: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _value: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17090#[cfg(not(feature = "cuda"))]
17091pub fn gpu_masked_zero_f64(_grad: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17092#[cfg(not(feature = "cuda"))]
17093pub fn gpu_slice_write_f64(_src: &CudaBuffer<f64>, _dst: &mut CudaBuffer<f64>, _n_batch: usize, _d: usize, _max_len: usize, _pos: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
17094#[cfg(not(feature = "cuda"))]
17095pub fn gpu_slice_read_f64(_src: &CudaBuffer<f64>, _n_batch: usize, _d: usize, _len: usize, _max_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17096#[cfg(not(feature = "cuda"))]
17097pub fn gpu_embed_lookup_f64(_idx: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17098#[cfg(not(feature = "cuda"))]
17099pub fn gpu_embed_lookup_batch_f64(_indices: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _n: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17100#[cfg(not(feature = "cuda"))]
17101pub fn gpu_scatter_add_rows_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _num_embeddings: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17102
17103
17104// ---------------------------------------------------------------------------
// Public API -- f64 activation, clamp, and cumulative-scan launchers
17106// ---------------------------------------------------------------------------
17107
/// GELU (sigmoid approximation) for f64: `x * sigmoid(1.702 * x)`.
///
/// Tries `gelu_f64_kernel` through `try_launch_unary_f64`. If that yields
/// `None`, the same formula is evaluated on the host by `cpu_fallback_unary_f64`.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    let launched = try_launch_unary_f64(input, device, GELU_F64_PTX, "gelu_f64_kernel")?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let gate = 1.0 / (1.0 + (-1.702 * x).exp());
            x * gate
        }),
    }
}
17116
/// GELU (tanh approximation) for f64:
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
///
/// Tries `gelu_tanh_f64_kernel` through `try_launch_unary_f64`. If that
/// yields `None`, the formula is evaluated on the host.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_TANH_F64_PTX, "gelu_tanh_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(input, device, |x| {
            let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
            let cubic_arg = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
            0.5 * x * (1.0 + cubic_arg.tanh())
        }),
    }
}
17128
/// GELU (erf form) for f64: `x * 0.5 * (1 + erf(x / sqrt(2)))`.
///
/// Tries `gelu_erf_f64_kernel` through `try_launch_unary_f64`. The host
/// fallback replaces `erf` with the Abramowitz & Stegun 7.1.26 approximation.
/// That approximation's absolute error is bounded by roughly 1.5e-7, so
/// fallback results are not accurate to full f64 precision.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    // Abramowitz & Stegun 7.1.26:
    // erf(z) ~= sign(z) * (1 - p(t) * exp(-z^2)), with t = 1 / (1 + 0.3275911 * |z|).
    fn erf_as(z: f64) -> f64 {
        let az = z.abs();
        let t = 1.0 / (1.0 + 0.3275911 * az);
        let p = t * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
        let magnitude = 1.0 - p * (-az * az).exp();
        if z >= 0.0 { magnitude } else { -magnitude }
    }

    match try_launch_unary_f64(input, device, GELU_ERF_F64_PTX, "gelu_erf_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(input, device, |x| {
            x * 0.5 * (1.0 + erf_as(x * std::f64::consts::FRAC_1_SQRT_2))
        }),
    }
}
17146
/// GELU backward (sigmoid approximation) for f64.
///
/// With `s = sigmoid(1.702 * x)`, computes `grad * (s + 1.702 * x * s * (1 - s))`.
/// This is the derivative of `x * sigmoid(1.702 * x)` scaled by the incoming
/// gradient.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (len_g, len_x) = (grad.len(), input.len());
    if len_g != len_x {
        return Err(GpuError::LengthMismatch { a: len_g, b: len_x });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_F64_PTX,
        "gelu_backward_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| {
            let s = 1.0 / (1.0 + (-1.702 * x).exp());
            let local_grad = s + 1.702 * x * s * (1.0 - s);
            g * local_grad
        }),
    }
}
17165
/// GELU backward (tanh approximation) for f64.
///
/// With `k = sqrt(2 / pi)`, `t = tanh(k * (x + 0.044715 * x^3))`, computes
/// `grad * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * k * (1 + 3 * 0.044715 * x^2))`.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (len_g, len_x) = (grad.len(), input.len());
    if len_g != len_x {
        return Err(GpuError::LengthMismatch { a: len_g, b: len_x });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_F64_PTX,
        "gelu_backward_tanh_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| {
            const C: f64 = 0.044715;
            let k = (2.0_f64 / std::f64::consts::PI).sqrt();
            let t = (k * (x + C * x * x * x)).tanh();
            let sech2 = 1.0 - t * t;
            let local_grad = 0.5 * (1.0 + t) + 0.5 * x * sech2 * k * (1.0 + 3.0 * C * x * x);
            g * local_grad
        }),
    }
}
17188
/// GELU backward (erf form) for f64: `grad * (Phi(x) + x * phi(x))`.
///
/// `Phi` is the standard normal CDF and `phi` its density. The host fallback
/// evaluates `Phi` with the Abramowitz & Stegun 7.1.26 `erf` approximation
/// (absolute error around 1.5e-7).
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // Abramowitz & Stegun 7.1.26 approximation of erf(z).
    fn erf_as(z: f64) -> f64 {
        let az = z.abs();
        let t = 1.0 / (1.0 + 0.3275911 * az);
        let p = t * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
        let magnitude = 1.0 - p * (-az * az).exp();
        if z >= 0.0 { magnitude } else { -magnitude }
    }

    let (len_g, len_x) = (grad.len(), input.len());
    if len_g != len_x {
        return Err(GpuError::LengthMismatch { a: len_g, b: len_x });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_F64_PTX,
        "gelu_backward_erf_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| {
            let cdf = 0.5 * (1.0 + erf_as(x * std::f64::consts::FRAC_1_SQRT_2));
            let pdf = (-x * x / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt();
            g * (cdf + x * pdf)
        }),
    }
}
17214
/// SiLU (swish) for f64: `x / (1 + exp(-x))`.
///
/// Tries `silu_f64_kernel` via `try_launch_unary_f64`. When that yields
/// `None`, the formula runs on the host.
#[cfg(feature = "cuda")]
pub fn gpu_silu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    try_launch_unary_f64(input, device, SILU_F64_PTX, "silu_f64_kernel")?.map_or_else(
        || cpu_fallback_unary_f64(input, device, |x| x / (1.0 + (-x).exp())),
        Ok,
    )
}
17223
/// SiLU backward for f64.
///
/// With `s = sigmoid(x)`, computes `grad * (s + x * s * (1 - s))`.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (len_g, len_x) = (grad.len(), input.len());
    if len_g != len_x {
        return Err(GpuError::LengthMismatch { a: len_g, b: len_x });
    }

    let Some(out) = try_launch_binary_f64(
        grad,
        input,
        device,
        SILU_BACKWARD_F64_PTX,
        "silu_backward_f64_kernel",
    )?
    else {
        return cpu_fallback_binary_f64(grad, input, device, |g, x| {
            let s = 1.0 / (1.0 + (-x).exp());
            g * (s + x * s * (1.0 - s))
        });
    };
    Ok(out)
}
17242
/// ELU for f64: each element is `x` when `x > 0.0`, otherwise
/// `alpha * (exp(x) - 1.0)`.
///
/// An empty `input` yields an empty buffer. If `elu_f64_kernel` cannot be
/// compiled or loaded, the values are computed on the host and uploaded.
#[cfg(feature = "cuda")]
pub fn gpu_elu_f64(
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let len = input.len();
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ELU_F64_PTX,
        "elu_f64_kernel",
        device.ordinal() as u32,
    );
    match compiled {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let count = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: `out` was just allocated with `len` elements, the same as
            // `input`; the argument order (input, output, n, alpha) must match
            // the `elu_f64_kernel` PTX parameter list.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(input.inner())
                    .arg(out.inner_mut())
                    .arg(&count)
                    .arg(&alpha)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = host
                .iter()
                .copied()
                .map(|x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
                .collect();
            cpu_to_gpu(&result, device)
        }
    }
}
17273
/// ELU backward for f64: the gradient passes through where `input > 0.0`.
/// Elsewhere it becomes `grad * alpha * exp(input)`.
///
/// Empty inputs yield an empty buffer. If `elu_backward_f64_kernel` cannot be
/// compiled or loaded, the gradient is computed on the host.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let (len, input_len) = (grad.len(), input.len());
    if len != input_len {
        return Err(GpuError::LengthMismatch { a: len, b: input_len });
    }
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_F64_PTX,
        "elu_backward_f64_kernel",
        device.ordinal() as u32,
    );
    match compiled {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let count = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: `grad`, `input` and `out` all hold `len` elements; the
            // argument order (grad, input, output, n, alpha) must match the
            // `elu_backward_f64_kernel` PTX parameter list.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(grad.inner())
                    .arg(input.inner())
                    .arg(out.inner_mut())
                    .arg(&count)
                    .arg(&alpha)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            let grads = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = grads
                .iter()
                .zip(xs.iter())
                .map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
                .collect();
            cpu_to_gpu(&result, device)
        }
    }
}
17310
/// Mish for f64: `x * tanh(ln(1 + exp(x)))`.
///
/// Tries `mish_f64_kernel` via `try_launch_unary_f64`. When that yields
/// `None`, the formula runs on the host.
#[cfg(feature = "cuda")]
pub fn gpu_mish_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    try_launch_unary_f64(input, device, MISH_F64_PTX, "mish_f64_kernel")?.map_or_else(
        || {
            cpu_fallback_unary_f64(input, device, |x| {
                let softplus = (1.0_f64 + x.exp()).ln();
                x * softplus.tanh()
            })
        },
        Ok,
    )
}
17319
/// Mish backward for f64.
///
/// With `t = tanh(ln(1 + exp(x)))` and `s = sigmoid(x)`, computes
/// `grad * (t + x * s * (1 - t^2))`.
///
/// # Errors
///
/// [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (len_g, len_x) = (grad.len(), input.len());
    if len_g != len_x {
        return Err(GpuError::LengthMismatch { a: len_g, b: len_x });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        MISH_BACKWARD_F64_PTX,
        "mish_backward_f64_kernel",
    )?;
    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| {
            let t = (1.0_f64 + x.exp()).ln().tanh();
            let s = 1.0 / (1.0 + (-x).exp());
            g * (t + x * s * (1.0 - t * t))
        }),
    }
}
17340
/// Clamp for f64: each element becomes `x.max(min_val).min(max_val)`.
///
/// The f64 kernel comes from `CLAMP_PTX` via `get_f64_ptx`, using a
/// function-local `OnceLock` cache. If it cannot be compiled or loaded, the
/// clamp runs on the host. An empty `input` yields an empty buffer.
///
/// NOTE(review): on the host path `f64::max`/`f64::min` return the non-NaN
/// operand, so a NaN input is mapped to a bound rather than propagated.
/// Confirm that the kernel behaves the same way.
#[cfg(feature = "cuda")]
pub fn gpu_clamp_f64(
    input: &CudaBuffer<f64>,
    min_val: f64,
    max_val: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = input.len();
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CLAMP_PTX, "clamp_kernel", "clamp_f64_kernel");
    match crate::module_cache::get_or_compile(ctx, ptx, "clamp_f64_kernel", device.ordinal() as u32) {
        Ok(kernel) => {
            let mut out = alloc_zeros_f64(len, device)?;
            let count = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: `out` holds `len` elements like `input`; the argument order
            // (input, output, n, min, max) must match the PTX parameter list.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(input.inner())
                    .arg(out.inner_mut())
                    .arg(&count)
                    .arg(&min_val)
                    .arg(&max_val)
                    .launch(cfg)?;
            }
            Ok(out)
        }
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let clamped: Vec<f64> = host
                .iter()
                .copied()
                .map(|x| x.max(min_val).min(max_val))
                .collect();
            cpu_to_gpu(&clamped, device)
        }
    }
}
17375
/// Cumulative sum for f64 along the `dim_size` axis.
///
/// `outer`, `dim_size` and `inner` describe an input of
/// `outer * dim_size * inner` elements. The kernel launch is sized for
/// `outer * inner` threads (presumably one scan per `(outer, inner)` pair;
/// confirm against `CUMSUM_PTX`). The f64 kernel comes from `CUMSUM_PTX` via
/// `get_f64_ptx`, using a function-local `OnceLock` cache. An empty shape
/// yields an empty buffer.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Launching would otherwise read past
///   the end of the buffer.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled or loaded.
///   This op has no host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_cumsum_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    // The launch describes an `n`-element input; a shorter buffer would be
    // read out of bounds by the kernel.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMSUM_PTX, "cumsum_kernel", "cumsum_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cumsum_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cumsum_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above) and `out`
    // exactly `n`; the argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok(out)
}
17411
/// Cumulative product for f64 along the `dim_size` axis.
///
/// `outer`, `dim_size` and `inner` describe an input of
/// `outer * dim_size * inner` elements. The kernel launch is sized for
/// `outer * inner` threads. The f64 kernel comes from `CUMPROD_PTX` via
/// `get_f64_ptx`, using a function-local `OnceLock` cache. An empty shape
/// yields an empty buffer.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Launching would otherwise read past
///   the end of the buffer.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled or loaded.
///   This op has no host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_cumprod_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    // The launch describes an `n`-element input; a shorter buffer would be
    // read out of bounds by the kernel.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMPROD_PTX, "cumprod_kernel", "cumprod_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cumprod_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cumprod_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above) and `out`
    // exactly `n`; the argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok(out)
}
17447
/// Cumulative maximum for f64 along the `dim_size` axis.
///
/// Returns `(values, indices)`, each holding `outer * dim_size * inner`
/// elements; the indices are stored in an f64 buffer. The kernel launch is
/// sized for `outer * inner` threads. An empty shape yields two empty buffers.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Launching would otherwise read past
///   the end of the buffer.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled or loaded.
///   This op has no host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_cummax_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    // The launch describes an `n`-element input; a shorter buffer would be
    // read out of bounds by the kernel.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMAX_PTX, "cummax_kernel", "cummax_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummax_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummax_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above); `out` and
    // `ind` hold exactly `n`. The argument order must match the PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
17487
/// Cumulative minimum for f64 along the `dim_size` axis.
///
/// Returns `(values, indices)`, each holding `outer * dim_size * inner`
/// elements; the indices are stored in an f64 buffer. The kernel launch is
/// sized for `outer * inner` threads. An empty shape yields two empty buffers.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Launching would otherwise read past
///   the end of the buffer.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled or loaded.
///   This op has no host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_cummin_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    // The launch describes an `n`-element input; a shorter buffer would be
    // read out of bounds by the kernel.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMIN_PTX, "cummin_kernel", "cummin_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummin_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummin_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above); `out` and
    // `ind` hold exactly `n`. The argument order must match the PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
17527
/// Log-cumulative-sum-exp for f64 along the `dim_size` axis.
///
/// Uses the hand-written `LOGCUMSUMEXP_F64_PTX` kernel. `outer`, `dim_size`
/// and `inner` describe an input of `outer * dim_size * inner` elements, and
/// the launch is sized for `outer * inner` threads. An empty shape yields an
/// empty buffer.
///
/// # Errors
///
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. Launching would otherwise read past
///   the end of the buffer.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled or loaded.
///   This op has no host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let total = outer * inner;
    let n = outer * dim_size * inner;
    // The launch describes an `n`-element input; a shorter buffer would be
    // read out of bounds by the kernel.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LOGCUMSUMEXP_F64_PTX,
        "logcumsumexp_f64_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed { kernel: "logcumsumexp_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: `input` holds at least `n` elements (checked above) and `out`
    // exactly `n`; the argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok(out)
}
17561
17562// ---------------------------------------------------------------------------
17563// Public API -- f64 softmax / log-softmax / layernorm / rmsnorm launchers
17564// ---------------------------------------------------------------------------
17565
/// Row-wise softmax for f64 on GPU.
///
/// For each row: `out[j] = exp(x[j] - max(x)) / sum(exp(x - max(x)))`.
/// One block per row, 256 threads per block, and `256 * 8` bytes of dynamic
/// shared memory for the reductions.
///
/// `input` is read as `rows` consecutive rows of `cols` elements. The result
/// always holds exactly `rows * cols` elements, both from the kernel and from
/// the host fallback used when the kernel cannot be compiled or loaded.
///
/// # Errors
///
/// - Any error from `validate_device` for `input`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements. Otherwise the kernel would read out of bounds and the host
///   fallback would panic.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_f64(
    input: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let n = rows * cols;
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_F64_PTX,
        "softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback. Sized to `rows * cols` to match the kernel path
            // even when `input` is longer.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; n];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f64;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8, // sdata[256] f64
    };

    // SAFETY: `input` holds at least `rows * cols` elements (checked above) and
    // `out` exactly `rows * cols`; argument order must match the PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17637
/// Row-wise softmax backward for f64 on GPU.
///
/// `grad` and `output` (the forward softmax result) are split into rows of
/// `cols` elements, so `rows = grad.len() / cols`. For each row:
/// `out[j] = output[j] * (grad[j] - dot(grad_row, output_row))`.
///
/// The f64 kernel comes from `SOFTMAX_BACKWARD_PTX` via `get_f64_ptx`. If it
/// cannot be compiled or loaded, the gradient is computed on the host. With
/// `cols == 0` and empty buffers, an empty buffer is returned.
///
/// # Errors
///
/// - Any error from `validate_device` for `grad` or `output`.
/// - [`GpuError::LengthMismatch`] if `grad` and `output` differ in length, or if
///   `cols == 0` while the buffers are non-empty.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad, device)?;
    validate_device(output, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    // Zero-width rows: the `total / cols` division below would panic.
    if cols == 0 {
        return if total == 0 {
            cpu_to_gpu(&[], device)
        } else {
            Err(GpuError::LengthMismatch { a: total, b: 0 })
        };
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SOFTMAX_BACKWARD_PTX, "softmax_backward_kernel", "softmax_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `grad`, `output` and `out` all hold `total >= rows * cols`
    // elements; argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17711
/// Row-wise log-softmax for f64 on GPU.
///
/// `input` is split into rows of `cols` elements, so `rows = input.len() / cols`.
/// For each row: `out[j] = x[j] - (max(x) + log(sum(exp(x - max(x)))))`,
/// i.e. `x[j] - logsumexp(x)`. One block per row, 256 threads per block.
///
/// If `log_softmax_f64_kernel` cannot be compiled or loaded, the same formula
/// is evaluated on the host. With `cols == 0` and an empty `input`, an empty
/// buffer is returned.
///
/// # Errors
///
/// - Any error from `validate_device` for `input`.
/// - [`GpuError::LengthMismatch`] if `cols == 0` while `input` is non-empty.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_f64(
    input: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let total = input.len();
    // Zero-width rows: the `total / cols` division below would panic.
    if cols == 0 {
        return if total == 0 {
            cpu_to_gpu(&[], device)
        } else {
            Err(GpuError::LengthMismatch { a: total, b: 0 })
        };
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_F64_PTX,
        "log_softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f64;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `input` and `out` both hold `total >= rows * cols` elements;
    // argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17782
/// Row-wise log-softmax backward for f64 on GPU.
///
/// `grad` and `output` (the forward log-softmax result) are split into rows of
/// `cols` elements. For each row:
///   `sum_grad = sum(grad[j])`
///   `out[j] = grad[j] - exp(output[j]) * sum_grad`
///
/// If `log_softmax_backward_f64_kernel` cannot be compiled or loaded, the
/// gradient is computed on the host. With `cols == 0` and empty buffers, an
/// empty buffer is returned.
///
/// # Errors
///
/// - Any error from `validate_device` for `grad` or `output`.
/// - [`GpuError::LengthMismatch`] if `grad` and `output` differ in length, or if
///   `cols == 0` while the buffers are non-empty.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    validate_device(output, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    // Zero-width rows: the `total / cols` division below would panic.
    if cols == 0 {
        return if total == 0 {
            cpu_to_gpu(&[], device)
        } else {
            Err(GpuError::LengthMismatch { a: total, b: 0 })
        };
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_F64_PTX,
        "log_softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f64;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `grad`, `output` and `out` all hold `total >= rows * cols`
    // elements; argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17857
/// Row-wise LayerNorm for f64 on GPU.
///
/// `input`: `[rows * cols]`, `weight`: `[cols]`, `bias`: `[cols]`.
/// `out[j] = weight[j] * (x[j] - mean) / sqrt(var + eps) + bias[j]`, where
/// `mean` and `var` are the per-row mean and population variance (divided by
/// `cols`, as in the host fallback).
///
/// The f64 kernel comes from `LAYERNORM_PTX` via `get_f64_ptx`. If it cannot
/// be compiled or loaded, the normalization runs on the host. One block per
/// row, 256 threads per block.
///
/// # Errors
///
/// - Any error from `validate_device` for `input`, `weight` or `bias`.
/// - [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements, or `weight`/`bias` fewer than `cols`. Such buffers would be
///   read out of bounds by the kernel and would panic in the host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    bias: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(weight, device)?;
    validate_device(bias, device)?;

    let n = rows * cols;
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }
    if bias.len() < cols {
        return Err(GpuError::LengthMismatch { a: bias.len(), b: cols });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_PTX, "layernorm_kernel", "layernorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f64; n];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f64 = slice.iter().sum::<f64>() / cols as f64;
                let var: f64 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / cols as f64;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: `input` holds at least `rows * cols` elements, `weight` and `bias`
    // at least `cols` (all checked above), and `out` exactly `rows * cols`;
    // argument order must match the kernel's PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
17934
/// LayerNorm backward for f64 on GPU.
///
/// Given the forward input `x` (row-major `[rows * cols]`), the upstream
/// gradient `grad_output` (same shape) and the per-column `weight`, returns
/// `(grad_input [rows * cols], grad_weight [cols], grad_bias [cols])`.
///
/// With `x_hat = (x - mean_r) * inv_std_r`, `inv_std_r = 1 / sqrt(var_r + eps)`
/// (biased row variance) and `dl = grad_output * weight`, the host fallback
/// below computes:
///
/// * `grad_input[r, j] = inv_std_r * (dl[r, j] - mean(dl[r, :]) - x_hat[r, j] * mean(dl[r, :] * x_hat[r, :]))`
/// * `grad_weight[j] = sum over r of grad_output[r, j] * x_hat[r, j]`
/// * `grad_bias[j] = sum over r of grad_output[r, j]`
///
/// The PTX kernel is assumed to compute the same quantities.
///
/// # Errors
///
/// * Any error from `validate_device` for `input`, `grad_output` or `weight`
///   (presumably a buffer that does not live on `device`).
/// * [`GpuError::LengthMismatch`] if `input` or `grad_output` holds fewer than
///   `rows * cols` elements, or `weight` fewer than `cols`; `a` is the buffer
///   length and `b` the required minimum.
/// * Allocation, transfer and kernel-launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(grad_output, device)?;
    validate_device(weight, device)?;

    // Reject short buffers before they reach the unsafe launch (device
    // out-of-bounds reads) or the host fallback (slice-index panics).
    let total = rows * cols;
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }
    if grad_output.len() < total {
        return Err(GpuError::LengthMismatch { a: grad_output.len(), b: total });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }

    // Empty batch or empty rows: each gradient is a sum over nothing, so
    // correctly sized zero buffers are the exact answer; skip the launch.
    if total == 0 {
        return Ok((
            alloc_zeros_f64(0, device)?,
            alloc_zeros_f64(cols, device)?,
            alloc_zeros_f64(cols, device)?,
        ));
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        "layernorm_backward_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // PTX path unavailable: compute the three gradients on the host.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; total];
            let mut grad_weight = vec![0.0f64; cols];
            let mut grad_bias = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f64 = x_slice.iter().sum::<f64>() / n_f;
                let var: f64 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f64>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // First pass: row means of dl and dl * x_hat, plus the
                // column-wise weight/bias gradient accumulation.
                let mut sum1 = 0.0f64;
                let mut sum2 = 0.0f64;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                // Second pass: input gradient for this row.
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    // Zero-initialized outputs; grad_w / grad_b presumably get accumulated
    // into by the kernel across rows.
    let mut grad_in = alloc_zeros_f64(total, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let mut grad_b = alloc_zeros_f64(cols, device)?;
    // NOTE(review): `as u32` silently truncates shapes above u32::MAX.
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // `rows` blocks of 256 threads with one f64 of dynamic shared memory per
    // thread (presumably reduction scratch).
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: every buffer passed `validate_device` above; input, grad_output
    // and grad_in hold at least `rows * cols` elements and weight, grad_w and
    // grad_b at least `cols`. The argument order (input, grad_output, weight,
    // grad_in, grad_w, grad_b, rows, cols, eps) is assumed to match the
    // `.entry` parameters of LAYERNORM_BACKWARD_PTX after f64 conversion.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
18037
/// Row-wise RMS normalization for f64 on GPU.
///
/// `input` is a row-major `[rows * cols]` buffer and `weight` holds one value
/// per column. For each row `r`:
///
/// `out[r, j] = x[r, j] * weight[j] / sqrt(mean_j(x[r, j]^2) + eps)`
///
/// (the host fallback below computes exactly this; the PTX kernel is assumed
/// to match). When the f64 PTX module cannot be compiled or loaded, the result
/// is computed on the host instead.
///
/// # Errors
///
/// * Any error from `validate_device` for `input` or `weight` (presumably a
///   buffer that does not live on `device`).
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than `rows * cols`
///   elements or `weight` fewer than `cols`; `a` is the buffer length and `b`
///   the required minimum.
/// * Allocation, transfer and kernel-launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(weight, device)?;

    // Reject short buffers before they reach the unsafe launch (device
    // out-of-bounds reads) or the host fallback (slice-index panics).
    let total = rows * cols;
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }

    // Nothing to normalize: skip compilation and the launch.
    if total == 0 {
        return alloc_zeros_f64(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_PTX, "rmsnorm_kernel", "rmsnorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // PTX path unavailable: compute on the host and upload the result.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let row = &h_in[base..base + cols];
                let sq_mean = row.iter().map(|&x| x * x).sum::<f64>() / cols as f64;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for (c, dst) in out[base..base + cols].iter_mut().enumerate() {
                    *dst = row[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    // NOTE(review): `as u32` silently truncates shapes above u32::MAX.
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // `rows` blocks of 256 threads with one f64 of dynamic shared memory per
    // thread (presumably the sum-of-squares reduction scratch).
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: both inputs passed `validate_device` above; input/out hold at
    // least `rows * cols` elements and weight at least `cols`. The argument
    // order (input, out, weight, rows, cols, eps) is assumed to match the
    // `.entry` parameters of RMSNORM_PTX after f64 conversion.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
18109
/// RMSNorm backward for f64 on GPU.
///
/// Given the forward input `x` (row-major `[rows * cols]`), the upstream
/// gradient `grad_output` (same shape) and the per-column `weight`, returns
/// `(grad_input [rows * cols], grad_weight [cols])`.
///
/// With `inv_rms_r = 1 / sqrt(mean_j(x[r, j]^2) + eps)`, the host fallback
/// below computes:
///
/// * `grad_input[r, j] = inv_rms_r * weight[j] * grad_output[r, j]
///    - x[r, j] * inv_rms_r^3 * (sum_k grad_output[r, k] * x[r, k] * weight[k]) / cols`
/// * `grad_weight[j] = sum over r of grad_output[r, j] * x[r, j] * inv_rms_r`
///
/// The PTX kernel is assumed to compute the same quantities.
///
/// # Errors
///
/// * Any error from `validate_device` for `input`, `grad_output` or `weight`
///   (presumably a buffer that does not live on `device`).
/// * [`GpuError::LengthMismatch`] if `input` or `grad_output` holds fewer than
///   `rows * cols` elements, or `weight` fewer than `cols`; `a` is the buffer
///   length and `b` the required minimum.
/// * Allocation, transfer and kernel-launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(grad_output, device)?;
    validate_device(weight, device)?;

    // Reject short buffers before they reach the unsafe launch (device
    // out-of-bounds reads) or the host fallback (slice-index panics).
    let total = rows * cols;
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }
    if grad_output.len() < total {
        return Err(GpuError::LengthMismatch { a: grad_output.len(), b: total });
    }
    if weight.len() < cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }

    // Empty batch or empty rows: the weight gradient is a sum over nothing,
    // so correctly sized zero buffers are the exact answer; skip the launch.
    if total == 0 {
        return Ok((alloc_zeros_f64(0, device)?, alloc_zeros_f64(cols, device)?));
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        RMSNORM_BACKWARD_PTX,
        "rmsnorm_backward_kernel",
        "rmsnorm_backward_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // PTX path unavailable: compute both gradients on the host.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; total];
            let mut grad_weight = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f64 = x_slice.iter().map(|&x| x * x).sum::<f64>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                // First pass: row dot product for the input gradient plus the
                // column-wise weight gradient accumulation.
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                // Second pass: input gradient for this row.
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    // Zero-initialized outputs; grad_w presumably gets accumulated into by the
    // kernel across rows.
    let mut grad_in = alloc_zeros_f64(total, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    // NOTE(review): `as u32` silently truncates shapes above u32::MAX.
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // `rows` blocks of 256 threads with one f64 of dynamic shared memory per
    // thread (presumably reduction scratch).
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: every buffer passed `validate_device` above; input, grad_output
    // and grad_in hold at least `rows * cols` elements and weight/grad_w at
    // least `cols`. The argument order (input, grad_output, weight, grad_in,
    // grad_w, rows, cols, eps) is assumed to match the `.entry` parameters of
    // RMSNORM_BACKWARD_PTX after f64 conversion.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
18198
18199// ---------------------------------------------------------------------------
18200// Non-cuda stubs for softmax/layernorm/rmsnorm f64
18201// ---------------------------------------------------------------------------
18202
18203#[cfg(not(feature = "cuda"))]
18204pub fn gpu_softmax_f64(_input: &CudaBuffer<f64>, _rows: usize, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18205#[cfg(not(feature = "cuda"))]
18206pub fn gpu_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18207#[cfg(not(feature = "cuda"))]
18208pub fn gpu_log_softmax_f64(_input: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18209#[cfg(not(feature = "cuda"))]
18210pub fn gpu_log_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18211#[cfg(not(feature = "cuda"))]
18212pub fn gpu_layernorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _bias: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18213#[cfg(not(feature = "cuda"))]
18214pub fn gpu_layernorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18215#[cfg(not(feature = "cuda"))]
18216pub fn gpu_rmsnorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18217#[cfg(not(feature = "cuda"))]
18218pub fn gpu_rmsnorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18219
18220// ---------------------------------------------------------------------------
18221// Non-cuda stubs for new f64 ops
18222// ---------------------------------------------------------------------------
18223
18224#[cfg(not(feature = "cuda"))]
18225pub fn gpu_gelu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18226#[cfg(not(feature = "cuda"))]
18227pub fn gpu_gelu_tanh_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18228#[cfg(not(feature = "cuda"))]
18229pub fn gpu_gelu_erf_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18230#[cfg(not(feature = "cuda"))]
18231pub fn gpu_gelu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18232#[cfg(not(feature = "cuda"))]
18233pub fn gpu_gelu_backward_tanh_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18234#[cfg(not(feature = "cuda"))]
18235pub fn gpu_gelu_backward_erf_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18236#[cfg(not(feature = "cuda"))]
18237pub fn gpu_silu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18238#[cfg(not(feature = "cuda"))]
18239pub fn gpu_silu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18240#[cfg(not(feature = "cuda"))]
18241pub fn gpu_elu_f64(_input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18242#[cfg(not(feature = "cuda"))]
18243pub fn gpu_elu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18244#[cfg(not(feature = "cuda"))]
18245pub fn gpu_mish_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18246#[cfg(not(feature = "cuda"))]
18247pub fn gpu_mish_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18248#[cfg(not(feature = "cuda"))]
18249pub fn gpu_clamp_f64(_input: &CudaBuffer<f64>, _min: f64, _max: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18250#[cfg(not(feature = "cuda"))]
18251pub fn gpu_cumsum_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18252#[cfg(not(feature = "cuda"))]
18253pub fn gpu_cumprod_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18254#[cfg(not(feature = "cuda"))]
18255pub fn gpu_cummax_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18256#[cfg(not(feature = "cuda"))]
18257pub fn gpu_cummin_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18258#[cfg(not(feature = "cuda"))]
18259pub fn gpu_logcumsumexp_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18260
18261// ---------------------------------------------------------------------------
18262// Tests -- require a real CUDA GPU
18263// ---------------------------------------------------------------------------
18264
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    /// Opens CUDA device 0 and uploads `data` to it.
    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let device = GpuDevice::new(0).expect("CUDA device 0");
        let uploaded = cpu_to_gpu(data, &device).expect("cpu_to_gpu");
        (device, uploaded)
    }

    /// Downloads `buf` and checks it against `expected` element by element,
    /// allowing an absolute error strictly below `1e-6`.
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let downloaded = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(downloaded.len(), expected.len(), "length mismatch");
        for i in 0..expected.len() {
            let (got, exp) = (downloaded[i], expected[i]);
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    // -- gpu_add -------------------------------------------------------------

    #[test]
    fn add_basic() {
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [10.0f32, 20.0, 30.0, 40.0];
        let want: Vec<f32> = (0..lhs.len()).map(|i| lhs[i] + rhs[i]).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_buf_eq(&sum, &dev, &want);
    }

    #[test]
    fn add_empty() {
        let (dev, a) = setup(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(sum.len(), 0);
    }

    #[test]
    fn add_large() {
        let n = 100_000;
        let lhs: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let rhs: Vec<f32> = lhs.iter().map(|&x| x * 0.5).collect();
        let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(&x, &y)| x + y).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_buf_eq(&sum, &dev, &want);
    }

    #[test]
    fn add_length_mismatch() {
        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        match gpu_add(&a, &b, &dev).unwrap_err() {
            GpuError::LengthMismatch { a: 3, b: 2 } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // -- gpu_sub -------------------------------------------------------------

    #[test]
    fn sub_basic() {
        let lhs = [10.0f32, 20.0, 30.0, 40.0];
        let rhs = [1.0f32, 2.0, 3.0, 4.0];
        let want: Vec<f32> = (0..lhs.len()).map(|i| lhs[i] - rhs[i]).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&diff, &dev, &want);
    }

    #[test]
    fn sub_negative_result() {
        let lhs = [1.0f32, 2.0];
        let rhs = [5.0f32, 10.0];

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&diff, &dev, &[-4.0, -8.0]);
    }

    // -- gpu_mul -------------------------------------------------------------

    #[test]
    fn mul_basic() {
        let lhs = [2.0f32, 3.0, 4.0, 5.0];
        let rhs = [10.0f32; 4];
        let want: Vec<f32> = (0..lhs.len()).map(|i| lhs[i] * rhs[i]).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let product = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&product, &dev, &want);
    }

    #[test]
    fn mul_by_zero() {
        let lhs = [1.0f32, 2.0, 3.0];
        let zeros = [0.0f32; 3];

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&zeros, &dev).expect("cpu_to_gpu b");
        let product = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&product, &dev, &zeros);
    }

    // -- gpu_neg -------------------------------------------------------------

    #[test]
    fn neg_basic() {
        let input = [1.0f32, -2.0, 3.0, 0.0, -5.5];
        let want: Vec<f32> = input.iter().map(|&v| -v).collect();

        let (dev, a) = setup(&input);
        let negated = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_buf_eq(&negated, &dev, &want);
    }

    #[test]
    fn neg_double_negation() {
        let input = [1.0f32, -2.0, 3.0];
        let (dev, a) = setup(&input);
        let once = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let twice = gpu_neg(&once, &dev).expect("gpu_neg 2");
        assert_buf_eq(&twice, &dev, &input);
    }

    // -- gpu_relu ------------------------------------------------------------

    #[test]
    fn relu_basic() {
        let (dev, a) = setup(&[-3.0f32, -1.0, 0.0, 1.0, 3.0]);
        let rectified = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&rectified, &dev, &[0.0, 0.0, 0.0, 1.0, 3.0]);
    }

    #[test]
    fn relu_all_negative() {
        let (dev, a) = setup(&[-5.0f32, -0.1, -100.0]);
        let rectified = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&rectified, &dev, &[0.0f32; 3]);
    }

    #[test]
    fn relu_all_positive() {
        let input = [0.1f32, 1.0, 100.0];
        let (dev, a) = setup(&input);
        let rectified = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&rectified, &dev, &input);
    }

    #[test]
    fn relu_empty() {
        let (dev, a) = setup(&[]);
        let rectified = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(rectified.len(), 0);
    }

    // -- gpu_small_matmul ----------------------------------------------------

    #[test]
    fn small_matmul_2x2() {
        // [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let lhs = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
        let rhs = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
        let product = gpu_small_matmul(&lhs, &rhs, 2, 2, 2, &dev).unwrap();
        assert_buf_eq(&product, &dev, &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn small_matmul_1xk_kxn() {
        // [1, 2, 3] (1x3) @ [[1, 0], [0, 1], [1, 1]] (3x2) = [4, 5] (1x2)
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let lhs = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let rhs = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
        let product = gpu_small_matmul(&lhs, &rhs, 1, 3, 2, &dev).unwrap();
        assert_buf_eq(&product, &dev, &[4.0, 5.0]);
    }

    #[test]
    fn small_matmul_vs_cublas() {
        // Decode-step sized linear layer, [1, 64] @ [64, 64] = [1, 64], checked
        // against cuBLAS as the reference implementation.
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let m = 1;
        let k = 64;
        let n = 64;

        // Deterministic values in [0, 1).
        let lhs_host: Vec<f32> = (0..m * k)
            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
            .collect();
        let rhs_host: Vec<f32> = (0..k * n)
            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
            .collect();

        let lhs = cpu_to_gpu(&lhs_host, &dev).unwrap();
        let rhs = cpu_to_gpu(&rhs_host, &dev).unwrap();

        let reference = crate::blas::gpu_matmul_f32(&lhs, &rhs, m, k, n, &dev).unwrap();
        let cublas_result = gpu_to_cpu(&reference, &dev).unwrap();

        let candidate = gpu_small_matmul(&lhs, &rhs, m, k, n, &dev).unwrap();
        let our_result = gpu_to_cpu(&candidate, &dev).unwrap();

        assert_eq!(cublas_result.len(), our_result.len());
        for (i, (cb, ours)) in cublas_result.into_iter().zip(our_result).enumerate() {
            let diff = (cb - ours).abs();
            assert!(
                diff < 0.1,
                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
                diff
            );
        }
    }

    // -- gpu_strided_copy (CL-496) -------------------------------------

    #[test]
    fn strided_copy_identity_contiguous_2d() {
        // A C-contiguous 2x3 view copies back out unchanged.
        let src: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let (dev, input) = setup(&src);
        let copied = gpu_strided_copy(&input, &[2, 3], &[3, 1], 0, &dev)
            .expect("strided_copy identity");
        assert_buf_eq(&copied, &dev, &src);
    }

    #[test]
    fn strided_copy_transpose_2d() {
        // Source [[0, 1, 2], [3, 4, 5]] viewed as shape [3, 2] with strides
        // [1, 3] is its transpose: out[i, j] = src[j, i].
        let src: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let (dev, input) = setup(&src);
        let transposed = gpu_strided_copy(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy transpose");
        assert_buf_eq(&transposed, &dev, &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    #[test]
    fn strided_copy_sliced_column() {
        // Column 2 of a contiguous 3x4 matrix: offset 2, shape [3], stride [4].
        let src: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let (dev, input) = setup(&src);
        let column = gpu_strided_copy(&input, &[3], &[4], 2, &dev)
            .expect("strided_copy col slice");
        assert_buf_eq(&column, &dev, &[2.0, 6.0, 10.0]);
    }

    #[test]
    fn strided_copy_3d_permute() {
        // Source [2, 3, 4] (strides [12, 4, 1]) permuted (0, 2, 1) gives the
        // view shape [2, 4, 3] with strides [12, 1, 4]: out[b, i, j] = src[b, j, i].
        let data: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let (dev, input) = setup(&data);
        let permuted =
            gpu_strided_copy(&input, &[2, 4, 3], &[12, 1, 4], 0, &dev).expect("strided_copy 3d");

        // Host reference: walk the output in row-major (b, i, j) order and read
        // the source element at [b, j, i].
        let src = &data;
        let expected: Vec<f32> = (0..2)
            .flat_map(move |b| {
                (0..4).flat_map(move |i| (0..3).map(move |j| src[b * 12 + j * 4 + i]))
            })
            .collect();
        assert_buf_eq(&permuted, &dev, &expected);
    }

    #[test]
    fn strided_copy_4d_max_rank_supported() {
        // Identity copy of a C-contiguous rank-4 tensor.
        let shape = [2usize, 3, 2, 2];
        let n: usize = shape.iter().product();
        let src: Vec<f32> = (0..n).map(|v| v as f32).collect();
        let (dev, input) = setup(&src);
        let copied = gpu_strided_copy(&input, &shape, &[12, 4, 2, 1], 0, &dev)
            .expect("strided_copy 4d");
        assert_buf_eq(&copied, &dev, &src);
    }

    #[test]
    fn strided_copy_rejects_too_many_dims() {
        // Nine dimensions exceeds STRIDED_COPY_MAX_DIMS (8).
        let (dev, input) = setup(&[0.0f32; 16]);
        let outcome = gpu_strided_copy(
            &input,
            &[1, 1, 1, 1, 1, 1, 1, 1, 16],
            &[1; 9],
            0,
            &dev,
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn strided_copy_rejects_shape_stride_length_mismatch() {
        let (dev, input) = setup(&[0.0f32; 12]);
        let outcome = gpu_strided_copy(&input, &[3, 4], &[4, 1, 1], 0, &dev);
        assert!(outcome.is_err());
    }

    #[test]
    fn strided_copy_rejects_negative_stride() {
        let (dev, input) = setup(&[0.0f32; 6]);
        let outcome = gpu_strided_copy(&input, &[2, 3], &[3, -1], 0, &dev);
        assert!(outcome.is_err());
    }

    #[test]
    fn strided_copy_empty_output() {
        let (dev, input) = setup(&[1.0f32, 2.0, 3.0]);
        let copied = gpu_strided_copy(&input, &[0, 3], &[3, 1], 0, &dev)
            .expect("strided_copy empty");
        assert_eq!(copied.len(), 0);
    }

    #[test]
    fn strided_copy_f64_transpose_matches_f32() {
        // The 2-D transpose case again, this time through the f64 entry point.
        let src: Vec<f64> = (0..6).map(|v| v as f64).collect();
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let input = cpu_to_gpu(&src, &dev).expect("cpu_to_gpu f64");
        let transposed = gpu_strided_copy_f64(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy_f64 transpose");
        let host = gpu_to_cpu(&transposed, &dev).expect("gpu_to_cpu f64");
        assert_eq!(host, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }
}
18637}