Skip to main content

ferrotorch_gpu/
kernels.rs

1//! Custom PTX CUDA kernels for elementwise GPU operations.
2//!
3//! Each operation has two code paths:
4//!
5//! 1. **PTX kernel** -- a hand-written PTX string that is loaded into the CUDA
6//!    driver at runtime via [`cudarc`]. This is the fast path and runs entirely
7//!    on the GPU.
8//! 2. **CPU fallback** -- copies data to the host, performs the operation with
9//!    standard Rust iterators, and copies the result back. Correct but slow;
10//!    used when the PTX module cannot be loaded (e.g. architecture mismatch).
11//!
12//! # Supported operations
13//!
14//! | Function | Formula |
15//! |----------|---------|
16//! | [`gpu_add`] | `out[i] = a[i] + b[i]` |
17//! | [`gpu_sub`] | `out[i] = a[i] - b[i]` |
18//! | [`gpu_mul`] | `out[i] = a[i] * b[i]` |
19//! | [`gpu_neg`] | `out[i] = -a[i]` |
20//! | [`gpu_relu`] | `out[i] = max(a[i], 0.0)` |
21
22#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, cpu_to_gpu, gpu_to_cpu};
30
31// ---------------------------------------------------------------------------
32// PTX kernel source strings
33// ---------------------------------------------------------------------------
34
35/// PTX source for `add_kernel`: `out[i] = a[i] + b[i]`.
36#[cfg(feature = "cuda")]
37pub(crate) const ADD_PTX: &str = "\
38.version 7.0
39.target sm_52
40.address_size 64
41
42.visible .entry add_kernel(
43    .param .u64 a_ptr,
44    .param .u64 b_ptr,
45    .param .u64 out_ptr,
46    .param .u32 n
47) {
48    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
49    .reg .u64 %a, %b, %out, %off;
50    .reg .f32 %va, %vb, %vr;
51    .reg .pred %p;
52
53    ld.param.u64 %a, [a_ptr];
54    ld.param.u64 %b, [b_ptr];
55    ld.param.u64 %out, [out_ptr];
56    ld.param.u32 %n_reg, [n];
57
58    mov.u32 %bid, %ctaid.x;
59    mov.u32 %bdim, %ntid.x;
60    mov.u32 %r_tid, %tid.x;
61    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
62
63    setp.ge.u32 %p, %r_tid, %n_reg;
64    @%p bra DONE;
65
66    cvt.u64.u32 %off, %r_tid;
67    shl.b64 %off, %off, 2;
68
69    add.u64 %a, %a, %off;
70    add.u64 %b, %b, %off;
71    add.u64 %out, %out, %off;
72
73    ld.global.f32 %va, [%a];
74    ld.global.f32 %vb, [%b];
75    add.f32 %vr, %va, %vb;
76    st.global.f32 [%out], %vr;
77
78DONE:
79    ret;
80}
81";
82
83/// PTX source for `add_vec4_kernel`: vectorized add, 4 elements per thread.
84///
85/// Uses `ld.global.v4.f32` (128-bit load) for 4x memory throughput vs scalar.
86/// Thread i processes elements [i*4 .. i*4+3].
87#[cfg(feature = "cuda")]
88pub(crate) const ADD_VEC4_PTX: &str = "\
89.version 7.0
90.target sm_52
91.address_size 64
92
93.visible .entry add_vec4_kernel(
94    .param .u64 a_ptr,
95    .param .u64 b_ptr,
96    .param .u64 out_ptr,
97    .param .u32 n4
98) {
99    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
100    .reg .u64 %a, %b, %out, %off;
101    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
102    .reg .pred %p;
103
104    ld.param.u64 %a, [a_ptr];
105    ld.param.u64 %b, [b_ptr];
106    ld.param.u64 %out, [out_ptr];
107    ld.param.u32 %n4_reg, [n4];
108
109    mov.u32 %bid, %ctaid.x;
110    mov.u32 %bdim, %ntid.x;
111    mov.u32 %r_tid, %tid.x;
112    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
113
114    setp.ge.u32 %p, %r_tid, %n4_reg;
115    @%p bra DONE;
116
117    // Byte offset = tid * 16 (4 floats × 4 bytes)
118    cvt.u64.u32 %off, %r_tid;
119    shl.b64 %off, %off, 4;
120
121    add.u64 %a, %a, %off;
122    add.u64 %b, %b, %off;
123    add.u64 %out, %out, %off;
124
125    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
126    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
127
128    add.f32 %r0, %a0, %b0;
129    add.f32 %r1, %a1, %b1;
130    add.f32 %r2, %a2, %b2;
131    add.f32 %r3, %a3, %b3;
132
133    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
134
135DONE:
136    ret;
137}
138";
139
140/// PTX source for `mul_vec4_kernel`: vectorized multiply, 4 elements per thread.
141#[cfg(feature = "cuda")]
142pub(crate) const MUL_VEC4_PTX: &str = "\
143.version 7.0
144.target sm_52
145.address_size 64
146
147.visible .entry mul_vec4_kernel(
148    .param .u64 a_ptr,
149    .param .u64 b_ptr,
150    .param .u64 out_ptr,
151    .param .u32 n4
152) {
153    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
154    .reg .u64 %a, %b, %out, %off;
155    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
156    .reg .pred %p;
157
158    ld.param.u64 %a, [a_ptr];
159    ld.param.u64 %b, [b_ptr];
160    ld.param.u64 %out, [out_ptr];
161    ld.param.u32 %n4_reg, [n4];
162
163    mov.u32 %bid, %ctaid.x;
164    mov.u32 %bdim, %ntid.x;
165    mov.u32 %r_tid, %tid.x;
166    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
167
168    setp.ge.u32 %p, %r_tid, %n4_reg;
169    @%p bra DONE;
170
171    cvt.u64.u32 %off, %r_tid;
172    shl.b64 %off, %off, 4;
173
174    add.u64 %a, %a, %off;
175    add.u64 %b, %b, %off;
176    add.u64 %out, %out, %off;
177
178    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
179    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
180
181    mul.f32 %r0, %a0, %b0;
182    mul.f32 %r1, %a1, %b1;
183    mul.f32 %r2, %a2, %b2;
184    mul.f32 %r3, %a3, %b3;
185
186    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
187
188DONE:
189    ret;
190}
191";
192
193/// PTX source for `sub_kernel`: `out[i] = a[i] - b[i]`.
194#[cfg(feature = "cuda")]
195pub(crate) const SUB_PTX: &str = "\
196.version 7.0
197.target sm_52
198.address_size 64
199
200.visible .entry sub_kernel(
201    .param .u64 a_ptr,
202    .param .u64 b_ptr,
203    .param .u64 out_ptr,
204    .param .u32 n
205) {
206    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
207    .reg .u64 %a, %b, %out, %off;
208    .reg .f32 %va, %vb, %vr;
209    .reg .pred %p;
210
211    ld.param.u64 %a, [a_ptr];
212    ld.param.u64 %b, [b_ptr];
213    ld.param.u64 %out, [out_ptr];
214    ld.param.u32 %n_reg, [n];
215
216    mov.u32 %bid, %ctaid.x;
217    mov.u32 %bdim, %ntid.x;
218    mov.u32 %r_tid, %tid.x;
219    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
220
221    setp.ge.u32 %p, %r_tid, %n_reg;
222    @%p bra DONE;
223
224    cvt.u64.u32 %off, %r_tid;
225    shl.b64 %off, %off, 2;
226
227    add.u64 %a, %a, %off;
228    add.u64 %b, %b, %off;
229    add.u64 %out, %out, %off;
230
231    ld.global.f32 %va, [%a];
232    ld.global.f32 %vb, [%b];
233    sub.f32 %vr, %va, %vb;
234    st.global.f32 [%out], %vr;
235
236DONE:
237    ret;
238}
239";
240
241/// PTX source for `mul_kernel`: `out[i] = a[i] * b[i]`.
242#[cfg(feature = "cuda")]
243pub(crate) const MUL_PTX: &str = "\
244.version 7.0
245.target sm_52
246.address_size 64
247
248.visible .entry mul_kernel(
249    .param .u64 a_ptr,
250    .param .u64 b_ptr,
251    .param .u64 out_ptr,
252    .param .u32 n
253) {
254    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
255    .reg .u64 %a, %b, %out, %off;
256    .reg .f32 %va, %vb, %vr;
257    .reg .pred %p;
258
259    ld.param.u64 %a, [a_ptr];
260    ld.param.u64 %b, [b_ptr];
261    ld.param.u64 %out, [out_ptr];
262    ld.param.u32 %n_reg, [n];
263
264    mov.u32 %bid, %ctaid.x;
265    mov.u32 %bdim, %ntid.x;
266    mov.u32 %r_tid, %tid.x;
267    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
268
269    setp.ge.u32 %p, %r_tid, %n_reg;
270    @%p bra DONE;
271
272    cvt.u64.u32 %off, %r_tid;
273    shl.b64 %off, %off, 2;
274
275    add.u64 %a, %a, %off;
276    add.u64 %b, %b, %off;
277    add.u64 %out, %out, %off;
278
279    ld.global.f32 %va, [%a];
280    ld.global.f32 %vb, [%b];
281    mul.f32 %vr, %va, %vb;
282    st.global.f32 [%out], %vr;
283
284DONE:
285    ret;
286}
287";
288
289/// PTX source for `neg_kernel`: `out[i] = -a[i]`.
290#[cfg(feature = "cuda")]
291pub(crate) const NEG_PTX: &str = "\
292.version 7.0
293.target sm_52
294.address_size 64
295
296.visible .entry neg_kernel(
297    .param .u64 a_ptr,
298    .param .u64 out_ptr,
299    .param .u32 n
300) {
301    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
302    .reg .u64 %a, %out, %off;
303    .reg .f32 %va, %vr;
304    .reg .pred %p;
305
306    ld.param.u64 %a, [a_ptr];
307    ld.param.u64 %out, [out_ptr];
308    ld.param.u32 %n_reg, [n];
309
310    mov.u32 %bid, %ctaid.x;
311    mov.u32 %bdim, %ntid.x;
312    mov.u32 %r_tid, %tid.x;
313    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
314
315    setp.ge.u32 %p, %r_tid, %n_reg;
316    @%p bra DONE;
317
318    cvt.u64.u32 %off, %r_tid;
319    shl.b64 %off, %off, 2;
320
321    add.u64 %a, %a, %off;
322    add.u64 %out, %out, %off;
323
324    ld.global.f32 %va, [%a];
325    neg.f32 %vr, %va;
326    st.global.f32 [%out], %vr;
327
328DONE:
329    ret;
330}
331";
332
333/// PTX source for `relu_kernel`: `out[i] = max(a[i], 0.0)`.
334#[cfg(feature = "cuda")]
335pub(crate) const RELU_PTX: &str = "\
336.version 7.0
337.target sm_52
338.address_size 64
339
340.visible .entry relu_kernel(
341    .param .u64 a_ptr,
342    .param .u64 out_ptr,
343    .param .u32 n
344) {
345    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
346    .reg .u64 %a, %out, %off;
347    .reg .f32 %va, %vr, %zero;
348    .reg .pred %p;
349
350    ld.param.u64 %a, [a_ptr];
351    ld.param.u64 %out, [out_ptr];
352    ld.param.u32 %n_reg, [n];
353
354    mov.u32 %bid, %ctaid.x;
355    mov.u32 %bdim, %ntid.x;
356    mov.u32 %r_tid, %tid.x;
357    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
358
359    setp.ge.u32 %p, %r_tid, %n_reg;
360    @%p bra DONE;
361
362    cvt.u64.u32 %off, %r_tid;
363    shl.b64 %off, %off, 2;
364
365    add.u64 %a, %a, %off;
366    add.u64 %out, %out, %off;
367
368    ld.global.f32 %va, [%a];
369    mov.f32 %zero, 0f00000000;
370    max.f32 %vr, %va, %zero;
371    st.global.f32 [%out], %vr;
372
373DONE:
374    ret;
375}
376";
377
378/// PTX source for `scale_kernel`: `out[i] = a[i] * scalar`.
379#[cfg(feature = "cuda")]
380pub(crate) const SCALE_PTX: &str = "\
381.version 7.0
382.target sm_52
383.address_size 64
384
385.visible .entry scale_kernel(
386    .param .u64 a_ptr,
387    .param .u64 out_ptr,
388    .param .f32 scalar,
389    .param .u32 n
390) {
391    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
392    .reg .u64 %a, %out, %off;
393    .reg .f32 %va, %vr, %s;
394    .reg .pred %p;
395
396    ld.param.u64 %a, [a_ptr];
397    ld.param.u64 %out, [out_ptr];
398    ld.param.f32 %s, [scalar];
399    ld.param.u32 %n_reg, [n];
400
401    mov.u32 %bid, %ctaid.x;
402    mov.u32 %bdim, %ntid.x;
403    mov.u32 %r_tid, %tid.x;
404    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
405
406    setp.ge.u32 %p, %r_tid, %n_reg;
407    @%p bra DONE;
408
409    cvt.u64.u32 %off, %r_tid;
410    shl.b64 %off, %off, 2;
411
412    add.u64 %a, %a, %off;
413    add.u64 %out, %out, %off;
414
415    ld.global.f32 %va, [%a];
416    mul.f32 %vr, %va, %s;
417    st.global.f32 [%out], %vr;
418
419DONE:
420    ret;
421}
422";
423
424/// PTX for 2D matrix transpose: `out[j * M + i] = in[i * N + j]`.
425/// Thread `tid` maps to output index; computes the corresponding input index.
426#[cfg(feature = "cuda")]
427pub(crate) const TRANSPOSE_2D_PTX: &str = "\
428.version 7.0\n\
429.target sm_52\n\
430.address_size 64\n\
431\n\
432.visible .entry transpose_2d_kernel(\n\
433    .param .u64 in_ptr,\n\
434    .param .u64 out_ptr,\n\
435    .param .u32 M,\n\
436    .param .u32 N,\n\
437    .param .u32 total\n\
438) {\n\
439    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
440    .reg .u32 %out_row, %out_col, %in_idx;\n\
441    .reg .u64 %in, %out, %off_in, %off_out;\n\
442    .reg .f32 %val;\n\
443    .reg .pred %p;\n\
444\n\
445    ld.param.u64 %in, [in_ptr];\n\
446    ld.param.u64 %out, [out_ptr];\n\
447    ld.param.u32 %M_reg, [M];\n\
448    ld.param.u32 %N_reg, [N];\n\
449    ld.param.u32 %total_reg, [total];\n\
450\n\
451    mov.u32 %bid, %ctaid.x;\n\
452    mov.u32 %bdim, %ntid.x;\n\
453    mov.u32 %r_tid, %tid.x;\n\
454    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
455\n\
456    setp.ge.u32 %p, %r_tid, %total_reg;\n\
457    @%p bra DONE;\n\
458\n\
459    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
460    div.u32 %out_row, %r_tid, %M_reg;\n\
461    rem.u32 %out_col, %r_tid, %M_reg;\n\
462    // Input index: out_col * N + out_row (transposed).\n\
463    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
464\n\
465    cvt.u64.u32 %off_in, %in_idx;\n\
466    shl.b64 %off_in, %off_in, 2;\n\
467    add.u64 %off_in, %in, %off_in;\n\
468    ld.global.f32 %val, [%off_in];\n\
469\n\
470    cvt.u64.u32 %off_out, %r_tid;\n\
471    shl.b64 %off_out, %off_out, 2;\n\
472    add.u64 %off_out, %out, %off_out;\n\
473    st.global.f32 [%off_out], %val;\n\
474\n\
475DONE:\n\
476    ret;\n\
477}\n\
478";
479
480// ---------------------------------------------------------------------------
481// 4D permute (0,2,1,3) PTX kernel — swap dims 1 and 2
482// ---------------------------------------------------------------------------
483// Input:  [d0, d1, d2, d3]
484// Output: [d0, d2, d1, d3]
485// Thread i computes output[i] by mapping to the transposed input index.
486
487#[cfg(feature = "cuda")]
488pub(crate) const PERMUTE_0213_PTX: &str = "\
489.version 7.0\n\
490.target sm_52\n\
491.address_size 64\n\
492\n\
493.visible .entry permute_0213_kernel(\n\
494    .param .u64 in_ptr,\n\
495    .param .u64 out_ptr,\n\
496    .param .u32 d0,\n\
497    .param .u32 d1,\n\
498    .param .u32 d2,\n\
499    .param .u32 d3,\n\
500    .param .u32 total\n\
501) {\n\
502    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
503    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
504    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
505    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
506    .reg .u64 %in, %out, %off_in, %off_out;\n\
507    .reg .f32 %val;\n\
508    .reg .pred %p;\n\
509\n\
510    ld.param.u64 %in, [in_ptr];\n\
511    ld.param.u64 %out, [out_ptr];\n\
512    ld.param.u32 %d0r, [d0];\n\
513    ld.param.u32 %d1r, [d1];\n\
514    ld.param.u32 %d2r, [d2];\n\
515    ld.param.u32 %d3r, [d3];\n\
516    ld.param.u32 %total_reg, [total];\n\
517\n\
518    mov.u32 %bid, %ctaid.x;\n\
519    mov.u32 %bdim, %ntid.x;\n\
520    mov.u32 %r_tid, %tid.x;\n\
521    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
522\n\
523    setp.ge.u32 %p, %r_tid, %total_reg;\n\
524    @%p bra DONE;\n\
525\n\
526    // Output shape: [d0, d2, d1, d3]\n\
527    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
528    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
529    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
530\n\
531    div.u32 %i0, %r_tid, %s_out1;\n\
532    rem.u32 %rem, %r_tid, %s_out1;\n\
533    div.u32 %i2, %rem, %s_out2;\n\
534    rem.u32 %rem, %rem, %s_out2;\n\
535    div.u32 %i1, %rem, %d3r;\n\
536    rem.u32 %i3, %rem, %d3r;\n\
537\n\
538    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
539    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
540    mul.lo.u32 %in_idx, %i0, %d1r;\n\
541    add.u32 %in_idx, %in_idx, %i1;\n\
542    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
543    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
544    add.u32 %in_idx, %in_idx, %i3;\n\
545\n\
546    cvt.u64.u32 %off_in, %in_idx;\n\
547    shl.b64 %off_in, %off_in, 2;\n\
548    add.u64 %off_in, %in, %off_in;\n\
549    ld.global.f32 %val, [%off_in];\n\
550\n\
551    cvt.u64.u32 %off_out, %r_tid;\n\
552    shl.b64 %off_out, %off_out, 2;\n\
553    add.u64 %off_out, %out, %off_out;\n\
554    st.global.f32 [%off_out], %val;\n\
555\n\
556DONE:\n\
557    ret;\n\
558}\n\
559";
560
561// ---------------------------------------------------------------------------
562// f32-to-f16 conversion PTX kernel: out_f16[i] = float2half(in_f32[i])
563// ---------------------------------------------------------------------------
564// Used by gpu_matmul_f16 to cast f32 inputs to f16 on-GPU before calling
565// cublasGemmEx. The output is stored as u16 (IEEE 754 half-precision bits).
566
567#[cfg(feature = "cuda")]
568pub(crate) const F32_TO_F16_PTX: &str = "\
569.version 7.0
570.target sm_52
571.address_size 64
572
573.visible .entry f32_to_f16_kernel(
574    .param .u64 in_ptr,
575    .param .u64 out_ptr,
576    .param .u32 n
577) {
578    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
579    .reg .u64 %in, %out, %off_in, %off_out;
580    .reg .f32 %vf;
581    .reg .b16 %vh;
582    .reg .pred %p;
583
584    ld.param.u64 %in, [in_ptr];
585    ld.param.u64 %out, [out_ptr];
586    ld.param.u32 %n_reg, [n];
587
588    mov.u32 %bid, %ctaid.x;
589    mov.u32 %bdim, %ntid.x;
590    mov.u32 %r_tid, %tid.x;
591    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
592
593    setp.ge.u32 %p, %r_tid, %n_reg;
594    @%p bra DONE;
595
596    // Compute input offset: i * 4 (f32 = 4 bytes)
597    cvt.u64.u32 %off_in, %r_tid;
598    shl.b64 %off_in, %off_in, 2;
599    add.u64 %in, %in, %off_in;
600
601    // Compute output offset: i * 2 (f16 = 2 bytes)
602    cvt.u64.u32 %off_out, %r_tid;
603    shl.b64 %off_out, %off_out, 1;
604    add.u64 %out, %out, %off_out;
605
606    // Load f32, convert to f16 (round-to-nearest-even), store as u16
607    ld.global.f32 %vf, [%in];
608    cvt.rn.f16.f32 %vh, %vf;
609    st.global.b16 [%out], %vh;
610
611DONE:
612    ret;
613}
614";
615
616/// PTX source for `f32_to_bf16_kernel`: convert f32 → bf16 (stored as u16).
617///
618/// BF16 is the top 16 bits of f32 with round-to-nearest-even. We do this
619/// with integer bit ops: add rounding bias 0x7FFF + bit 16 of the value,
620/// then shift right 16. This works on sm_52+ (no special bf16 instructions
621/// needed).
622#[cfg(feature = "cuda")]
623pub(crate) const F32_TO_BF16_PTX: &str = "\
624.version 7.0
625.target sm_52
626.address_size 64
627
628.visible .entry f32_to_bf16_kernel(
629    .param .u64 in_ptr,
630    .param .u64 out_ptr,
631    .param .u32 n
632) {
633    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
634    .reg .u64 %in, %out, %off_in, %off_out;
635    .reg .f32 %vf;
636    .reg .u32 %bits, %round, %lsb, %result;
637    .reg .pred %p;
638
639    ld.param.u64 %in, [in_ptr];
640    ld.param.u64 %out, [out_ptr];
641    ld.param.u32 %n_reg, [n];
642
643    mov.u32 %bid, %ctaid.x;
644    mov.u32 %bdim, %ntid.x;
645    mov.u32 %r_tid, %tid.x;
646    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
647
648    setp.ge.u32 %p, %r_tid, %n_reg;
649    @%p bra DONE;
650
651    cvt.u64.u32 %off_in, %r_tid;
652    shl.b64 %off_in, %off_in, 2;
653    add.u64 %in, %in, %off_in;
654
655    cvt.u64.u32 %off_out, %r_tid;
656    shl.b64 %off_out, %off_out, 1;
657    add.u64 %out, %out, %off_out;
658
659    // Load f32 as raw bits
660    ld.global.u32 %bits, [%in];
661
662    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
663    shr.u32 %lsb, %bits, 16;
664    and.b32 %lsb, %lsb, 1;
665    add.u32 %round, %bits, 0x7FFF;
666    add.u32 %round, %round, %lsb;
667    shr.u32 %result, %round, 16;
668
669    // Store as u16
670    st.global.u16 [%out], %result;
671
672DONE:
673    ret;
674}
675";
676
677// ---------------------------------------------------------------------------
678// Small matmul PTX kernel: C = A @ B, one thread per output element
679// ---------------------------------------------------------------------------
680// For small matrices where cuBLAS JIT compilation overhead > compute time.
681// Compiles once via module_cache, never JIT-recompiles for different sizes.
682
683#[cfg(feature = "cuda")]
684pub(crate) const SMALL_MATMUL_PTX: &str = "\
685.version 7.0
686.target sm_52
687.address_size 64
688
689.visible .entry small_matmul_kernel(
690    .param .u64 a_ptr,
691    .param .u64 b_ptr,
692    .param .u64 c_ptr,
693    .param .u32 M,
694    .param .u32 K,
695    .param .u32 N,
696    .param .u32 total
697) {
698    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
699    .reg .u32 %row, %col, %p, %idx;
700    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
701    .reg .f32 %sum, %va, %vb;
702    .reg .pred %bounds_p, %loop_p;
703
704    ld.param.u64 %a, [a_ptr];
705    ld.param.u64 %b, [b_ptr];
706    ld.param.u64 %c, [c_ptr];
707    ld.param.u32 %M_reg, [M];
708    ld.param.u32 %K_reg, [K];
709    ld.param.u32 %N_reg, [N];
710    ld.param.u32 %total_reg, [total];
711
712    mov.u32 %bid, %ctaid.x;
713    mov.u32 %bdim, %ntid.x;
714    mov.u32 %r_tid, %tid.x;
715    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
716
717    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
718    @%bounds_p bra DONE;
719
720    div.u32 %row, %r_tid, %N_reg;
721    rem.u32 %col, %r_tid, %N_reg;
722
723    mov.f32 %sum, 0f00000000;
724    mov.u32 %p, 0;
725DOT:
726    setp.ge.u32 %loop_p, %p, %K_reg;
727    @%loop_p bra DOT_DONE;
728
729    mad.lo.u32 %idx, %row, %K_reg, %p;
730    cvt.u64.u32 %a_off, %idx;
731    shl.b64 %a_off, %a_off, 2;
732    add.u64 %a_off, %a, %a_off;
733    ld.global.f32 %va, [%a_off];
734
735    mad.lo.u32 %idx, %p, %N_reg, %col;
736    cvt.u64.u32 %b_off, %idx;
737    shl.b64 %b_off, %b_off, 2;
738    add.u64 %b_off, %b, %b_off;
739    ld.global.f32 %vb, [%b_off];
740
741    fma.rn.f32 %sum, %va, %vb, %sum;
742    add.u32 %p, %p, 1;
743    bra DOT;
744DOT_DONE:
745
746    cvt.u64.u32 %c_off, %r_tid;
747    shl.b64 %c_off, %c_off, 2;
748    add.u64 %c_off, %c, %c_off;
749    st.global.f32 [%c_off], %sum;
750
751DONE:
752    ret;
753}
754";
755
756// ---------------------------------------------------------------------------
757// Slice-write PTX kernel: copy [N, D] into row `pos` of [N, max_len, D]
758// ---------------------------------------------------------------------------
759
760#[cfg(feature = "cuda")]
761pub(crate) const SLICE_WRITE_PTX: &str = "\
762.version 7.0
763.target sm_52
764.address_size 64
765
766.visible .entry slice_write_kernel(
767    .param .u64 src_ptr,
768    .param .u64 dst_ptr,
769    .param .u32 n,
770    .param .u32 D,
771    .param .u32 max_len,
772    .param .u32 pos
773) {
774    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
775    .reg .u32 %batch_idx, %d_idx, %dst_row;
776    .reg .u64 %src, %dst, %src_off, %dst_off;
777    .reg .f32 %val;
778    .reg .pred %p;
779
780    ld.param.u64 %src, [src_ptr];
781    ld.param.u64 %dst, [dst_ptr];
782    ld.param.u32 %n_reg, [n];
783    ld.param.u32 %D_reg, [D];
784    ld.param.u32 %max_len_reg, [max_len];
785    ld.param.u32 %pos_reg, [pos];
786
787    mov.u32 %bid, %ctaid.x;
788    mov.u32 %bdim, %ntid.x;
789    mov.u32 %r_tid, %tid.x;
790    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
791
792    setp.ge.u32 %p, %r_tid, %n_reg;
793    @%p bra DONE;
794
795    cvt.u64.u32 %src_off, %r_tid;
796    shl.b64 %src_off, %src_off, 2;
797    add.u64 %src, %src, %src_off;
798    ld.global.f32 %val, [%src];
799
800    div.u32 %batch_idx, %r_tid, %D_reg;
801    rem.u32 %d_idx, %r_tid, %D_reg;
802    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
803    add.u32 %dst_row, %dst_row, %pos_reg;
804    mul.lo.u32 %dst_row, %dst_row, %D_reg;
805    add.u32 %dst_row, %dst_row, %d_idx;
806    cvt.u64.u32 %dst_off, %dst_row;
807    shl.b64 %dst_off, %dst_off, 2;
808    add.u64 %dst, %dst, %dst_off;
809    st.global.f32 [%dst], %val;
810
811DONE:
812    ret;
813}
814";
815
816/// PTX for `slice_write_indirect_kernel`: same as `slice_write_kernel` but
817/// reads `pos` from a device pointer. This enables CUDA graph capture — the
818/// graph records the pointer address (fixed), and we update the u32 value
819/// at that address before each graph replay.
820#[cfg(feature = "cuda")]
821pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
822.version 7.0
823.target sm_52
824.address_size 64
825
826.visible .entry slice_write_indirect_kernel(
827    .param .u64 src_ptr,
828    .param .u64 dst_ptr,
829    .param .u32 n,
830    .param .u32 D,
831    .param .u32 max_len,
832    .param .u64 pos_ptr
833) {
834    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
835    .reg .u32 %batch_idx, %d_idx, %dst_row;
836    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
837    .reg .f32 %val;
838    .reg .pred %p;
839
840    ld.param.u64 %src, [src_ptr];
841    ld.param.u64 %dst, [dst_ptr];
842    ld.param.u32 %n_reg, [n];
843    ld.param.u32 %D_reg, [D];
844    ld.param.u32 %max_len_reg, [max_len];
845    ld.param.u64 %pos_p, [pos_ptr];
846    ld.global.u32 %pos_reg, [%pos_p];
847
848    mov.u32 %bid, %ctaid.x;
849    mov.u32 %bdim, %ntid.x;
850    mov.u32 %r_tid, %tid.x;
851    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
852
853    setp.ge.u32 %p, %r_tid, %n_reg;
854    @%p bra DONE;
855
856    cvt.u64.u32 %src_off, %r_tid;
857    shl.b64 %src_off, %src_off, 2;
858    add.u64 %src, %src, %src_off;
859    ld.global.f32 %val, [%src];
860
861    div.u32 %batch_idx, %r_tid, %D_reg;
862    rem.u32 %d_idx, %r_tid, %D_reg;
863    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
864    add.u32 %dst_row, %dst_row, %pos_reg;
865    mul.lo.u32 %dst_row, %dst_row, %D_reg;
866    add.u32 %dst_row, %dst_row, %d_idx;
867    cvt.u64.u32 %dst_off, %dst_row;
868    shl.b64 %dst_off, %dst_off, 2;
869    add.u64 %dst, %dst, %dst_off;
870    st.global.f32 [%dst], %val;
871
872DONE:
873    ret;
874}
875";
876
877/// PTX for `causal_mask_indirect_kernel`: builds an attention mask where
878/// `out[h, col] = 0.0` for `col < total_len` and `-1e9` for `col >= total_len`.
879/// `total_len` is read from a device pointer (for CUDA graph capture).
880/// Output shape: `[n_head, max_pos]` — one mask row per head (all identical).
881/// Thread `tid` maps to flat index; column = `tid % max_pos`.
882#[cfg(feature = "cuda")]
883pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
884.version 7.0
885.target sm_52
886.address_size 64
887
888.visible .entry causal_mask_indirect_kernel(
889    .param .u64 total_len_ptr,
890    .param .u64 out_ptr,
891    .param .u32 max_pos,
892    .param .u32 total
893) {
894    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
895    .reg .u64 %out, %off, %tl_p;
896    .reg .f32 %val;
897    .reg .pred %bounds_p, %mask_p;
898
899    ld.param.u64 %tl_p, [total_len_ptr];
900    ld.param.u64 %out, [out_ptr];
901    ld.param.u32 %max_pos_reg, [max_pos];
902    ld.param.u32 %total_reg, [total];
903
904    ld.global.u32 %tlen, [%tl_p];
905
906    mov.u32 %bid, %ctaid.x;
907    mov.u32 %bdim, %ntid.x;
908    mov.u32 %r_tid, %tid.x;
909    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
910
911    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
912    @%bounds_p bra DONE;
913
914    rem.u32 %col, %r_tid, %max_pos_reg;
915    setp.lt.u32 %mask_p, %col, %tlen;
916    @%mask_p bra WRITE_ZERO;
917
918    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
919    // to effectively zero out masked positions after softmax.
920    mov.f32 %val, 0fCE6E6B28;
921    bra WRITE;
922
923WRITE_ZERO:
924    mov.f32 %val, 0f00000000;
925
926WRITE:
927    cvt.u64.u32 %off, %r_tid;
928    shl.b64 %off, %off, 2;
929    add.u64 %out, %out, %off;
930    st.global.f32 [%out], %val;
931
932DONE:
933    ret;
934}
935";
936
937// ---------------------------------------------------------------------------
938// Embedding lookup PTX kernel: output[d] = weight[token_id * D + d]
939// ---------------------------------------------------------------------------
940
941#[cfg(feature = "cuda")]
942pub(crate) const EMBED_LOOKUP_PTX: &str = "\
943.version 7.0
944.target sm_52
945.address_size 64
946
947.visible .entry embed_lookup_kernel(
948    .param .u64 idx_ptr,
949    .param .u64 weight_ptr,
950    .param .u64 out_ptr,
951    .param .u32 D
952) {
953    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
954    .reg .u64 %idx_addr, %w, %out, %off;
955    .reg .f32 %idx_f, %val;
956    .reg .pred %p;
957
958    ld.param.u64 %idx_addr, [idx_ptr];
959    ld.param.u64 %w, [weight_ptr];
960    ld.param.u64 %out, [out_ptr];
961    ld.param.u32 %D_reg, [D];
962
963    mov.u32 %bid, %ctaid.x;
964    mov.u32 %bdim, %ntid.x;
965    mov.u32 %r_tid, %tid.x;
966    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
967
968    setp.ge.u32 %p, %r_tid, %D_reg;
969    @%p bra DONE;
970
971    ld.global.f32 %idx_f, [%idx_addr];
972    cvt.rzi.u32.f32 %row, %idx_f;
973
974    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
975    cvt.u64.u32 %off, %src_idx;
976    shl.b64 %off, %off, 2;
977    add.u64 %off, %w, %off;
978    ld.global.f32 %val, [%off];
979
980    cvt.u64.u32 %off, %r_tid;
981    shl.b64 %off, %off, 2;
982    add.u64 %off, %out, %off;
983    st.global.f32 [%off], %val;
984
985DONE:
986    ret;
987}
988";
989
990// ---------------------------------------------------------------------------
991// Batch embedding lookup PTX kernel
992// ---------------------------------------------------------------------------
993// Given N f32 indices and a weight matrix [V, D], gather N rows into [N, D].
994// Thread `tid` computes one element: row = tid / D, col = tid % D.
995// out[tid] = weight[indices[row] * D + col]
996
997#[cfg(feature = "cuda")]
998pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
999.version 7.0
1000.target sm_52
1001.address_size 64
1002
1003.visible .entry embed_lookup_batch_kernel(
1004    .param .u64 idx_ptr,
1005    .param .u64 weight_ptr,
1006    .param .u64 out_ptr,
1007    .param .u32 D,
1008    .param .u32 total
1009) {
1010    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
1011    .reg .u32 %row, %col, %src_idx;
1012    .reg .u64 %idx_addr, %w, %out, %off;
1013    .reg .f32 %idx_f, %val;
1014    .reg .pred %p;
1015
1016    ld.param.u64 %idx_addr, [idx_ptr];
1017    ld.param.u64 %w, [weight_ptr];
1018    ld.param.u64 %out, [out_ptr];
1019    ld.param.u32 %D_reg, [D];
1020    ld.param.u32 %total_reg, [total];
1021
1022    mov.u32 %bid, %ctaid.x;
1023    mov.u32 %bdim, %ntid.x;
1024    mov.u32 %tid, %tid.x;
1025    mad.lo.u32 %tid, %bid, %bdim, %tid;
1026
1027    setp.ge.u32 %p, %tid, %total_reg;
1028    @%p bra DONE;
1029
1030    // row = tid / D, col = tid % D
1031    div.u32 %row, %tid, %D_reg;
1032    rem.u32 %col, %tid, %D_reg;
1033
1034    // Read indices[row] (f32 -> u32)
1035    cvt.u64.u32 %off, %row;
1036    shl.b64 %off, %off, 2;
1037    add.u64 %off, %idx_addr, %off;
1038    ld.global.f32 %idx_f, [%off];
1039    cvt.rzi.u32.f32 %src_idx, %idx_f;
1040
1041    // src_idx = indices[row] * D + col
1042    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
1043    cvt.u64.u32 %off, %src_idx;
1044    shl.b64 %off, %off, 2;
1045    add.u64 %off, %w, %off;
1046    ld.global.f32 %val, [%off];
1047
1048    // Write to out[tid]
1049    cvt.u64.u32 %off, %tid;
1050    shl.b64 %off, %off, 2;
1051    add.u64 %off, %out, %off;
1052    st.global.f32 [%off], %val;
1053
1054DONE:
1055    ret;
1056}
1057";
1058
1059// ---------------------------------------------------------------------------
1060// Scatter-add rows PTX kernel (for embedding backward)
1061// ---------------------------------------------------------------------------
1062// Given grad_output [N, D] and indices [N] (f32), atomically accumulate:
1063//   grad_weight[indices[row], col] += grad_output[row * D + col]
1064// Thread `tid` handles one element: row = tid / D, col = tid % D.
1065
1066#[cfg(feature = "cuda")]
1067pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
1068.version 7.0
1069.target sm_52
1070.address_size 64
1071
1072.visible .entry scatter_add_rows_kernel(
1073    .param .u64 grad_output_ptr,
1074    .param .u64 indices_ptr,
1075    .param .u64 grad_weight_ptr,
1076    .param .u32 D,
1077    .param .u32 total
1078) {
1079    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
1080    .reg .u32 %row, %col, %dst_idx;
1081    .reg .u64 %go, %idx_addr, %gw, %off;
1082    .reg .f32 %idx_f, %grad_val, %dummy;
1083    .reg .pred %p;
1084
1085    ld.param.u64 %go, [grad_output_ptr];
1086    ld.param.u64 %idx_addr, [indices_ptr];
1087    ld.param.u64 %gw, [grad_weight_ptr];
1088    ld.param.u32 %D_reg, [D];
1089    ld.param.u32 %total_reg, [total];
1090
1091    mov.u32 %bid, %ctaid.x;
1092    mov.u32 %bdim, %ntid.x;
1093    mov.u32 %tid, %tid.x;
1094    mad.lo.u32 %tid, %bid, %bdim, %tid;
1095
1096    setp.ge.u32 %p, %tid, %total_reg;
1097    @%p bra DONE;
1098
1099    // row = tid / D, col = tid % D
1100    div.u32 %row, %tid, %D_reg;
1101    rem.u32 %col, %tid, %D_reg;
1102
1103    // Read grad_output[tid]
1104    cvt.u64.u32 %off, %tid;
1105    shl.b64 %off, %off, 2;
1106    add.u64 %off, %go, %off;
1107    ld.global.f32 %grad_val, [%off];
1108
1109    // Read indices[row] (f32 -> u32)
1110    cvt.u64.u32 %off, %row;
1111    shl.b64 %off, %off, 2;
1112    add.u64 %off, %idx_addr, %off;
1113    ld.global.f32 %idx_f, [%off];
1114    cvt.rzi.u32.f32 %dst_idx, %idx_f;
1115
1116    // dst_idx = indices[row] * D + col
1117    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
1118    cvt.u64.u32 %off, %dst_idx;
1119    shl.b64 %off, %off, 2;
1120    add.u64 %off, %gw, %off;
1121    atom.global.add.f32 %dummy, [%off], %grad_val;
1122
1123DONE:
1124    ret;
1125}
1126";
1127
1128// ---------------------------------------------------------------------------
1129// Slice-read PTX kernel: read first `len` rows from [N, max_len, D]
1130// ---------------------------------------------------------------------------
1131// Thread i writes: dst[i] = src[batch_idx * max_len * D + (i % (len*D))]
1132// where batch_idx = i / (len * D)
1133
1134#[cfg(feature = "cuda")]
1135pub(crate) const SLICE_READ_PTX: &str = "\
1136.version 7.0
1137.target sm_52
1138.address_size 64
1139
1140.visible .entry slice_read_kernel(
1141    .param .u64 src_ptr,
1142    .param .u64 dst_ptr,
1143    .param .u32 total,
1144    .param .u32 D,
1145    .param .u32 len,
1146    .param .u32 max_len
1147) {
1148    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
1149    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
1150    .reg .u32 %len_d;
1151    .reg .u64 %src, %dst, %src_off, %dst_off;
1152    .reg .f32 %val;
1153    .reg .pred %p;
1154
1155    ld.param.u64 %src, [src_ptr];
1156    ld.param.u64 %dst, [dst_ptr];
1157    ld.param.u32 %total_reg, [total];
1158    ld.param.u32 %D_reg, [D];
1159    ld.param.u32 %len_reg, [len];
1160    ld.param.u32 %max_len_reg, [max_len];
1161
1162    mov.u32 %bid, %ctaid.x;
1163    mov.u32 %bdim, %ntid.x;
1164    mov.u32 %r_tid, %tid.x;
1165    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1166
1167    setp.ge.u32 %p, %r_tid, %total_reg;
1168    @%p bra DONE;
1169
1170    // dst index = r_tid
1171    // batch_idx = r_tid / (len * D)
1172    // within = r_tid % (len * D)
1173    // row = within / D
1174    // col = within % D
1175    // src_idx = batch_idx * max_len * D + row * D + col
1176    mul.lo.u32 %len_d, %len_reg, %D_reg;
1177    div.u32 %batch_idx, %r_tid, %len_d;
1178    rem.u32 %within, %r_tid, %len_d;
1179    div.u32 %row, %within, %D_reg;
1180    rem.u32 %col, %within, %D_reg;
1181
1182    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
1183    add.u32 %src_idx, %src_idx, %row;
1184    mul.lo.u32 %src_idx, %src_idx, %D_reg;
1185    add.u32 %src_idx, %src_idx, %col;
1186
1187    cvt.u64.u32 %src_off, %src_idx;
1188    shl.b64 %src_off, %src_off, 2;
1189    add.u64 %src_off, %src, %src_off;
1190    ld.global.f32 %val, [%src_off];
1191
1192    cvt.u64.u32 %dst_off, %r_tid;
1193    shl.b64 %dst_off, %dst_off, 2;
1194    add.u64 %dst_off, %dst, %dst_off;
1195    st.global.f32 [%dst_off], %val;
1196
1197DONE:
1198    ret;
1199}
1200";
1201
1202// ---------------------------------------------------------------------------
1203// GELU PTX kernel: gelu(x) = x * sigmoid(1.702 * x)
1204//
1205// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
1206// for performance. These have reduced precision (~2^-22 relative error)
1207// compared to the full-precision variants, which is acceptable for neural
1208// network training/inference where f32 precision is already limited.
1209// ---------------------------------------------------------------------------
1210
1211#[cfg(feature = "cuda")]
1212pub(crate) const GELU_PTX: &str = "\
1213.version 7.0
1214.target sm_52
1215.address_size 64
1216
1217.visible .entry gelu_kernel(
1218    .param .u64 in_ptr,
1219    .param .u64 out_ptr,
1220    .param .u32 n
1221) {
1222    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1223    .reg .u64 %in, %out, %off;
1224    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
1225    .reg .pred %p;
1226
1227    ld.param.u64 %in, [in_ptr];
1228    ld.param.u64 %out, [out_ptr];
1229    ld.param.u32 %n_reg, [n];
1230
1231    mov.u32 %bid, %ctaid.x;
1232    mov.u32 %bdim, %ntid.x;
1233    mov.u32 %r_tid, %tid.x;
1234    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1235
1236    setp.ge.u32 %p, %r_tid, %n_reg;
1237    @%p bra DONE;
1238
1239    cvt.u64.u32 %off, %r_tid;
1240    shl.b64 %off, %off, 2;
1241    add.u64 %in, %in, %off;
1242    add.u64 %out, %out, %off;
1243
1244    ld.global.f32 %x, [%in];
1245
1246    mov.f32 %k, 0f3FDA2720;
1247    mul.f32 %neg_kx, %k, %x;
1248    neg.f32 %neg_kx, %neg_kx;
1249    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
1250    ex2.approx.f32 %exp_neg, %neg_kx;
1251    mov.f32 %one, 0f3F800000;
1252    add.f32 %denom, %one, %exp_neg;
1253    rcp.approx.f32 %sig, %denom;
1254    mul.f32 %result, %x, %sig;
1255    st.global.f32 [%out], %result;
1256
1257DONE:
1258    ret;
1259}
1260";
1261
1262/// PTX source for `gelu_tanh_kernel`: tanh approximation of GELU.
1263/// `out[i] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`
1264///
1265/// Uses `ex2.approx.f32` for exp and Horner-form tanh approximation via
1266/// `tanh(y) = (e^(2y) - 1) / (e^(2y) + 1)`.
1267#[cfg(feature = "cuda")]
1268pub(crate) const GELU_TANH_PTX: &str = "\
1269.version 7.0
1270.target sm_52
1271.address_size 64
1272
1273.visible .entry gelu_tanh_kernel(
1274    .param .u64 in_ptr,
1275    .param .u64 out_ptr,
1276    .param .u32 n
1277) {
1278    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1279    .reg .u64 %in, %out, %off;
1280    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
1281    .reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
1282    .reg .pred %p;
1283
1284    ld.param.u64 %in, [in_ptr];
1285    ld.param.u64 %out, [out_ptr];
1286    ld.param.u32 %n_reg, [n];
1287
1288    mov.u32 %bid, %ctaid.x;
1289    mov.u32 %bdim, %ntid.x;
1290    mov.u32 %r_tid, %tid.x;
1291    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1292
1293    setp.ge.u32 %p, %r_tid, %n_reg;
1294    @%p bra DONE;
1295
1296    cvt.u64.u32 %off, %r_tid;
1297    shl.b64 %off, %off, 2;
1298    add.u64 %in, %in, %off;
1299    add.u64 %out, %out, %off;
1300
1301    ld.global.f32 %x, [%in];
1302
1303    // inner = sqrt(2/π) * (x + 0.044715 * x³)
1304    // sqrt(2/π) = 0.7978845608 = 0x3F4C422A
1305    // 0.044715 = 0x3D372713
1306    mul.f32 %x3, %x, %x;
1307    mul.f32 %x3, %x3, %x;
1308    mov.f32 %c, 0f3D372713;
1309    mul.f32 %x3, %c, %x3;
1310    add.f32 %inner, %x, %x3;
1311    mov.f32 %sqrt2pi, 0f3F4C422A;
1312    mul.f32 %y, %sqrt2pi, %inner;
1313
1314    // tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
1315    // exp(2y) = 2^(2y * log2(e))
1316    mov.f32 %log2e, 0f3FB8AA3B;
1317    add.f32 %two_y, %y, %y;
1318    mul.f32 %two_y, %two_y, %log2e;
1319    ex2.approx.f32 %e2y, %two_y;
1320    mov.f32 %one, 0f3F800000;
1321    sub.f32 %e2y_m1, %e2y, %one;
1322    add.f32 %e2y_p1, %e2y, %one;
1323    rcp.approx.f32 %e2y_p1, %e2y_p1;
1324    mul.f32 %th, %e2y_m1, %e2y_p1;
1325
1326    // out = 0.5 * x * (1 + tanh)
1327    add.f32 %th, %one, %th;
1328    mov.f32 %half, 0f3F000000;
1329    mul.f32 %result, %half, %x;
1330    mul.f32 %result, %result, %th;
1331    st.global.f32 [%out], %result;
1332
1333DONE:
1334    ret;
1335}
1336";
1337
1338/// PTX source for `gelu_erf_kernel`: exact GELU using erf.
1339/// `out[i] = x * 0.5 * (1 + erf(x / sqrt(2)))`
1340///
1341/// Uses Abramowitz & Stegun formula 7.1.26 for erf (|ε| < 1.5×10⁻⁷).
1342#[cfg(feature = "cuda")]
1343pub(crate) const GELU_ERF_PTX: &str = "\
1344.version 7.0
1345.target sm_52
1346.address_size 64
1347
1348.visible .entry gelu_erf_kernel(
1349    .param .u64 in_ptr,
1350    .param .u64 out_ptr,
1351    .param .u32 n
1352) {
1353    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1354    .reg .u64 %in, %out, %off;
1355    .reg .f32 %x, %z, %ax, %one, %half, %log2e;
1356    .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
1357    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
1358    .reg .pred %pred_ge, %pred_neg;
1359
1360    ld.param.u64 %in, [in_ptr];
1361    ld.param.u64 %out, [out_ptr];
1362    ld.param.u32 %n_reg, [n];
1363
1364    mov.u32 %bid, %ctaid.x;
1365    mov.u32 %bdim, %ntid.x;
1366    mov.u32 %r_tid, %tid.x;
1367    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1368
1369    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
1370    @%pred_ge bra DONE;
1371
1372    cvt.u64.u32 %off, %r_tid;
1373    shl.b64 %off, %off, 2;
1374    add.u64 %in, %in, %off;
1375    add.u64 %out, %out, %off;
1376
1377    ld.global.f32 %x, [%in];
1378    mov.f32 %one, 0f3F800000;
1379    mov.f32 %half, 0f3F000000;
1380    mov.f32 %log2e, 0f3FB8AA3B;
1381
1382    // z = x / sqrt(2) = x * 0.70710678
1383    mov.f32 %z, 0f3F3504F3;
1384    mul.f32 %z, %x, %z;
1385
1386    // |z| for erf(|z|)
1387    abs.f32 %ax, %z;
1388
1389    // t = 1 / (1 + 0.3275911 * |z|)
1390    mov.f32 %p, 0f3EA7BA05;
1391    mul.f32 %t, %p, %ax;
1392    add.f32 %t, %one, %t;
1393    rcp.approx.f32 %t, %t;
1394
1395    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
1396    mov.f32 %a5, 0f3E0AAAAB;
1397    mov.f32 %a4, 0fBEB3A903;
1398    mov.f32 %a3, 0f3FB506DD;
1399    mov.f32 %a2, 0fBF03C1E1;
1400    mov.f32 %a1, 0f3EA0D6BB;
1401
1402    mul.f32 %pt, %t, %a5;
1403    add.f32 %pt, %pt, %a4;
1404    mul.f32 %pt, %pt, %t;
1405    add.f32 %pt, %pt, %a3;
1406    mul.f32 %pt, %pt, %t;
1407    add.f32 %pt, %pt, %a2;
1408    mul.f32 %pt, %pt, %t;
1409    add.f32 %pt, %pt, %a1;
1410    mul.f32 %pt, %pt, %t;
1411
1412    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
1413    mul.f32 %z2, %ax, %ax;
1414    neg.f32 %neg_z2, %z2;
1415    mul.f32 %neg_z2, %neg_z2, %log2e;
1416    ex2.approx.f32 %exp_neg_z2, %neg_z2;
1417
1418    // erf(|z|) = 1 - poly * exp(-z^2)
1419    mul.f32 %erf_val, %pt, %exp_neg_z2;
1420    sub.f32 %erf_val, %one, %erf_val;
1421
1422    // erf(-z) = -erf(z), so sign-correct
1423    setp.lt.f32 %pred_neg, %z, 0f00000000;
1424    @%pred_neg neg.f32 %erf_val, %erf_val;
1425
1426    // out = x * 0.5 * (1 + erf(x/sqrt(2)))
1427    add.f32 %erf_val, %one, %erf_val;
1428    mul.f32 %result, %half, %x;
1429    mul.f32 %result, %result, %erf_val;
1430    st.global.f32 [%out], %result;
1431
1432DONE:
1433    ret;
1434}
1435";
1436
1437/// PTX source for `gelu_backward_tanh_kernel`:
1438/// Backward for tanh approximation of GELU.
1439/// Let `u = sqrt(2/π) * (x + 0.044715 * x³)`, `t = tanh(u)`.
1440/// `d/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t²) * sqrt(2/π) * (1 + 3*0.044715*x²)`
1441/// `out[i] = grad[i] * d/dx`
1442#[cfg(feature = "cuda")]
1443pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
1444.version 7.0
1445.target sm_52
1446.address_size 64
1447
1448.visible .entry gelu_backward_tanh_kernel(
1449    .param .u64 grad_ptr,
1450    .param .u64 input_ptr,
1451    .param .u64 out_ptr,
1452    .param .u32 n
1453) {
1454    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1455    .reg .u64 %grad, %input, %out, %off;
1456    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
1457    .reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
1458    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
1459    .reg .pred %p;
1460
1461    ld.param.u64 %grad, [grad_ptr];
1462    ld.param.u64 %input, [input_ptr];
1463    ld.param.u64 %out, [out_ptr];
1464    ld.param.u32 %n_reg, [n];
1465
1466    mov.u32 %bid, %ctaid.x;
1467    mov.u32 %bdim, %ntid.x;
1468    mov.u32 %r_tid, %tid.x;
1469    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1470
1471    setp.ge.u32 %p, %r_tid, %n_reg;
1472    @%p bra DONE;
1473
1474    cvt.u64.u32 %off, %r_tid;
1475    shl.b64 %off, %off, 2;
1476    add.u64 %grad, %grad, %off;
1477    add.u64 %input, %input, %off;
1478    add.u64 %out, %out, %off;
1479
1480    ld.global.f32 %vg, [%grad];
1481    ld.global.f32 %x, [%input];
1482
1483    mov.f32 %one, 0f3F800000;
1484    mov.f32 %half, 0f3F000000;
1485    mov.f32 %log2e, 0f3FB8AA3B;
1486    mov.f32 %sqrt2pi, 0f3F4C422A;
1487    mov.f32 %c, 0f3D372713;
1488    // 3 * 0.044715 = 0.134145 = 0x3E096B8C
1489    mov.f32 %c3, 0f3E096B8C;
1490
1491    // u = sqrt(2/π) * (x + 0.044715 * x³)
1492    mul.f32 %x2, %x, %x;
1493    mul.f32 %x3, %x2, %x;
1494    mul.f32 %x3, %c, %x3;
1495    add.f32 %inner, %x, %x3;
1496    mul.f32 %y, %sqrt2pi, %inner;
1497
1498    // tanh(y) via exp
1499    add.f32 %two_y, %y, %y;
1500    mul.f32 %two_y, %two_y, %log2e;
1501    ex2.approx.f32 %e2y, %two_y;
1502    sub.f32 %e2y_m1, %e2y, %one;
1503    add.f32 %e2y_p1, %e2y, %one;
1504    rcp.approx.f32 %e2y_p1, %e2y_p1;
1505    mul.f32 %th, %e2y_m1, %e2y_p1;
1506
1507    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh²)*sqrt(2/π)*(1+3*0.044715*x²)
1508    // term1 = 0.5 * (1 + th)
1509    add.f32 %term1, %one, %th;
1510    mul.f32 %term1, %half, %term1;
1511
1512    // (1 - th²)
1513    mul.f32 %th2, %th, %th;
1514    sub.f32 %one_m_th2, %one, %th2;
1515
1516    // d_inner = sqrt(2/π) * (1 + 3*0.044715*x²)
1517    mul.f32 %d_inner, %c3, %x2;
1518    add.f32 %d_inner, %one, %d_inner;
1519    mul.f32 %d_inner, %sqrt2pi, %d_inner;
1520
1521    // term2 = 0.5 * x * (1-th²) * d_inner
1522    mul.f32 %term2, %half, %x;
1523    mul.f32 %term2, %term2, %one_m_th2;
1524    mul.f32 %term2, %term2, %d_inner;
1525
1526    add.f32 %d_gelu, %term1, %term2;
1527    mul.f32 %result, %vg, %d_gelu;
1528    st.global.f32 [%out], %result;
1529
1530DONE:
1531    ret;
1532}
1533";
1534
1535// ---------------------------------------------------------------------------
1536// Backward activation kernels
1537// ---------------------------------------------------------------------------
1538
1539/// PTX source for `relu_backward_kernel`: `out[i] = (input[i] > 0) ? grad[i] : 0`.
1540/// Takes two inputs: grad (upstream gradient) and input (forward activation input).
1541#[cfg(feature = "cuda")]
1542pub(crate) const RELU_BACKWARD_PTX: &str = "\
1543.version 7.0
1544.target sm_52
1545.address_size 64
1546
1547.visible .entry relu_backward_kernel(
1548    .param .u64 grad_ptr,
1549    .param .u64 input_ptr,
1550    .param .u64 out_ptr,
1551    .param .u32 n
1552) {
1553    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1554    .reg .u64 %grad, %input, %out, %off;
1555    .reg .f32 %vg, %vi, %zero, %vr;
1556    .reg .pred %p, %pos;
1557
1558    ld.param.u64 %grad, [grad_ptr];
1559    ld.param.u64 %input, [input_ptr];
1560    ld.param.u64 %out, [out_ptr];
1561    ld.param.u32 %n_reg, [n];
1562
1563    mov.u32 %bid, %ctaid.x;
1564    mov.u32 %bdim, %ntid.x;
1565    mov.u32 %r_tid, %tid.x;
1566    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1567
1568    setp.ge.u32 %p, %r_tid, %n_reg;
1569    @%p bra DONE;
1570
1571    cvt.u64.u32 %off, %r_tid;
1572    shl.b64 %off, %off, 2;
1573
1574    add.u64 %grad, %grad, %off;
1575    add.u64 %input, %input, %off;
1576    add.u64 %out, %out, %off;
1577
1578    ld.global.f32 %vg, [%grad];
1579    ld.global.f32 %vi, [%input];
1580    mov.f32 %zero, 0f00000000;
1581    setp.gt.f32 %pos, %vi, %zero;
1582    selp.f32 %vr, %vg, %zero, %pos;
1583    st.global.f32 [%out], %vr;
1584
1585DONE:
1586    ret;
1587}
1588";
1589
1590/// PTX source for `gelu_backward_kernel`:
1591/// `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
1592/// where `sig = sigmoid(1.702 * x)`.
1593/// This is the exact derivative of `gelu(x) = x * sigmoid(1.702 * x)`.
1594///
1595/// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
1596/// for performance. These have reduced precision (~2^-22 relative error)
1597/// compared to the full-precision variants, which is acceptable for neural
1598/// network training/inference where f32 precision is already limited.
1599#[cfg(feature = "cuda")]
1600pub(crate) const GELU_BACKWARD_PTX: &str = "\
1601.version 7.0
1602.target sm_52
1603.address_size 64
1604
1605.visible .entry gelu_backward_kernel(
1606    .param .u64 grad_ptr,
1607    .param .u64 input_ptr,
1608    .param .u64 out_ptr,
1609    .param .u32 n
1610) {
1611    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1612    .reg .u64 %grad, %input, %out, %off;
1613    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
1614    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
1615    .reg .pred %p;
1616
1617    ld.param.u64 %grad, [grad_ptr];
1618    ld.param.u64 %input, [input_ptr];
1619    ld.param.u64 %out, [out_ptr];
1620    ld.param.u32 %n_reg, [n];
1621
1622    mov.u32 %bid, %ctaid.x;
1623    mov.u32 %bdim, %ntid.x;
1624    mov.u32 %r_tid, %tid.x;
1625    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1626
1627    setp.ge.u32 %p, %r_tid, %n_reg;
1628    @%p bra DONE;
1629
1630    cvt.u64.u32 %off, %r_tid;
1631    shl.b64 %off, %off, 2;
1632
1633    add.u64 %grad, %grad, %off;
1634    add.u64 %input, %input, %off;
1635    add.u64 %out, %out, %off;
1636
1637    ld.global.f32 %vg, [%grad];
1638    ld.global.f32 %x, [%input];
1639
1640    // sig = sigmoid(1.702 * x)
1641    mov.f32 %k, 0f3FDA2720;
1642    mul.f32 %kx, %k, %x;
1643    neg.f32 %neg_kx, %kx;
1644    mov.f32 %log2e, 0f3FB8AA3B;
1645    mul.f32 %neg_kx, %neg_kx, %log2e;
1646    ex2.approx.f32 %exp_neg, %neg_kx;
1647    mov.f32 %one, 0f3F800000;
1648    add.f32 %denom, %one, %exp_neg;
1649    rcp.approx.f32 %sig, %denom;
1650
1651    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
1652    sub.f32 %one_minus_sig, %one, %sig;
1653    mul.f32 %kx_sig_oms, %kx, %sig;
1654    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
1655    add.f32 %dsig, %sig, %kx_sig_oms;
1656
1657    // out = grad * d_gelu
1658    mul.f32 %result, %vg, %dsig;
1659    st.global.f32 [%out], %result;
1660
1661DONE:
1662    ret;
1663}
1664";
1665
1666/// PTX source for `gelu_backward_erf_kernel`:
1667/// Exact GELU backward using erf: `d/dx gelu(x) = Φ(x) + x·φ(x)`
1668/// where `Φ(x) = 0.5·(1 + erf(x/√2))` and `φ(x) = exp(-x²/2) / √(2π)`.
1669///
1670/// Uses Abramowitz & Stegun formula 7.1.26 for erf (|ε| < 1.5×10⁻⁷):
1671///   `erf(x) = 1 - (a₁t + a₂t² + a₃t³ + a₄t⁴ + a₅t⁵) · exp(-x²)`
1672///   where `t = 1/(1 + 0.3275911·|x|)`
1673#[cfg(feature = "cuda")]
1674pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
1675.version 7.0
1676.target sm_52
1677.address_size 64
1678
1679.visible .entry gelu_backward_erf_kernel(
1680    .param .u64 grad_ptr,
1681    .param .u64 input_ptr,
1682    .param .u64 out_ptr,
1683    .param .u32 n
1684) {
1685    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1686    .reg .u64 %grad, %input, %out, %off;
1687    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
1688    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
1689    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
1690    .reg .f32 %d_gelu, %result;
1691    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
1692    .reg .pred %pred_ge, %pred_neg;
1693
1694    ld.param.u64 %grad, [grad_ptr];
1695    ld.param.u64 %input, [input_ptr];
1696    ld.param.u64 %out, [out_ptr];
1697    ld.param.u32 %n_reg, [n];
1698
1699    mov.u32 %bid, %ctaid.x;
1700    mov.u32 %bdim, %ntid.x;
1701    mov.u32 %r_tid, %tid.x;
1702    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1703
1704    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
1705    @%pred_ge bra DONE;
1706
1707    cvt.u64.u32 %off, %r_tid;
1708    shl.b64 %off, %off, 2;
1709
1710    add.u64 %grad, %grad, %off;
1711    add.u64 %input, %input, %off;
1712    add.u64 %out, %out, %off;
1713
1714    ld.global.f32 %vg, [%grad];
1715    ld.global.f32 %x, [%input];
1716
1717    mov.f32 %one, 0f3F800000;
1718    mov.f32 %half, 0f3F000000;
1719
1720    // z = x / sqrt(2) = x * 0.70710678
1721    mov.f32 %z, 0f3F3504F3;
1722    mul.f32 %z, %x, %z;
1723
1724    // |z| for erf(|z|)
1725    abs.f32 %ax, %z;
1726
1727    // t = 1 / (1 + 0.3275911 * |z|)
1728    mov.f32 %p, 0f3EA7BA05;
1729    mul.f32 %t, %p, %ax;
1730    add.f32 %t, %one, %t;
1731    rcp.approx.f32 %t, %t;
1732
1733    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
1734    mov.f32 %a5, 0f3E0AAAAB;
1735    mov.f32 %a4, 0fBEB3A903;
1736    mov.f32 %a3, 0f3FB506DD;
1737    mov.f32 %a2, 0fBF03C1E1;
1738    mov.f32 %a1, 0f3EA0D6BB;
1739
1740    mul.f32 %pt, %t, %a5;
1741    add.f32 %pt, %pt, %a4;
1742    mul.f32 %pt, %pt, %t;
1743    add.f32 %pt, %pt, %a3;
1744    mul.f32 %pt, %pt, %t;
1745    add.f32 %pt, %pt, %a2;
1746    mul.f32 %pt, %pt, %t;
1747    add.f32 %pt, %pt, %a1;
1748    mul.f32 %pt, %pt, %t;
1749
1750    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
1751    mul.f32 %z2, %ax, %ax;
1752    neg.f32 %neg_z2, %z2;
1753    mov.f32 %log2e, 0f3FB8AA3B;
1754    mul.f32 %neg_z2, %neg_z2, %log2e;
1755    ex2.approx.f32 %exp_neg_z2, %neg_z2;
1756
1757    // erf(|z|) = 1 - poly * exp(-z^2)
1758    mul.f32 %erf_val, %pt, %exp_neg_z2;
1759    sub.f32 %erf_val, %one, %erf_val;
1760
1761    // erf(-z) = -erf(z), so sign-correct
1762    setp.lt.f32 %pred_neg, %z, 0f00000000;
1763    @%pred_neg neg.f32 %erf_val, %erf_val;
1764
1765    // Φ(x) = 0.5 * (1 + erf(x/sqrt(2)))
1766    add.f32 %cdf, %one, %erf_val;
1767    mul.f32 %cdf, %half, %cdf;
1768
1769    // φ(x) = exp(-x²/2) / sqrt(2π)
1770    // exp(-x²/2):
1771    mul.f32 %neg_x2h, %x, %x;
1772    mul.f32 %neg_x2h, %neg_x2h, %half;
1773    neg.f32 %neg_x2h, %neg_x2h;
1774    mul.f32 %neg_x2h, %neg_x2h, %log2e;
1775    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;
1776
1777    // 1/sqrt(2π) = 0.39894228
1778    mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
1779    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;
1780
1781    // d/dx gelu(x) = Φ(x) + x * φ(x)
1782    mul.f32 %x_pdf, %x, %pdf;
1783    add.f32 %d_gelu, %cdf, %x_pdf;
1784
1785    // out = grad * d_gelu
1786    mul.f32 %result, %vg, %d_gelu;
1787    st.global.f32 [%out], %result;
1788
1789DONE:
1790    ret;
1791}
1792";
1793
1794// ---------------------------------------------------------------------------
1795// Index-select (1-D gather) PTX kernel
1796// ---------------------------------------------------------------------------
1797// Thread i: output[i] = input[indices[i]]
1798// Indices are stored as f32 on the GPU (cast to u32 via truncation).
1799
1800#[cfg(feature = "cuda")]
1801pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
1802.version 7.0
1803.target sm_52
1804.address_size 64
1805
1806.visible .entry index_select_1d_kernel(
1807    .param .u64 input_ptr,
1808    .param .u64 indices_ptr,
1809    .param .u64 out_ptr,
1810    .param .u32 n_indices
1811) {
1812    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
1813    .reg .u64 %input, %indices, %out, %off, %addr;
1814    .reg .f32 %idx_f, %val;
1815    .reg .pred %p;
1816
1817    ld.param.u64 %input, [input_ptr];
1818    ld.param.u64 %indices, [indices_ptr];
1819    ld.param.u64 %out, [out_ptr];
1820    ld.param.u32 %n_reg, [n_indices];
1821
1822    mov.u32 %bid, %ctaid.x;
1823    mov.u32 %bdim, %ntid.x;
1824    mov.u32 %r_tid, %tid.x;
1825    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1826
1827    setp.ge.u32 %p, %r_tid, %n_reg;
1828    @%p bra DONE;
1829
1830    // Byte offset for thread
1831    cvt.u64.u32 %off, %r_tid;
1832    shl.b64 %off, %off, 2;
1833
1834    // Read indices[tid] (f32 -> u32)
1835    add.u64 %addr, %indices, %off;
1836    ld.global.f32 %idx_f, [%addr];
1837    cvt.rzi.u32.f32 %idx, %idx_f;
1838
1839    // Read input[idx]
1840    cvt.u64.u32 %addr, %idx;
1841    shl.b64 %addr, %addr, 2;
1842    add.u64 %addr, %input, %addr;
1843    ld.global.f32 %val, [%addr];
1844
1845    // Write output[tid]
1846    add.u64 %addr, %out, %off;
1847    st.global.f32 [%addr], %val;
1848
1849DONE:
1850    ret;
1851}
1852";
1853
1854// ---------------------------------------------------------------------------
1855// Scatter-add (1-D) PTX kernel — backward of index_select
1856// ---------------------------------------------------------------------------
1857// Thread i: atomicAdd(grad_input[indices[i]], grad_output[i])
1858// The output buffer (grad_input) must be pre-zeroed.
1859// Uses atom.global.add.f32 for safe concurrent accumulation when
1860// duplicate indices map multiple threads to the same output slot.
1861
1862#[cfg(feature = "cuda")]
1863pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
1864.version 7.0
1865.target sm_52
1866.address_size 64
1867
1868.visible .entry scatter_add_1d_kernel(
1869    .param .u64 grad_output_ptr,
1870    .param .u64 indices_ptr,
1871    .param .u64 grad_input_ptr,
1872    .param .u32 n_indices
1873) {
1874    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
1875    .reg .u64 %go, %indices, %gi, %off, %addr;
1876    .reg .f32 %idx_f, %grad_val, %dummy;
1877    .reg .pred %p;
1878
1879    ld.param.u64 %go, [grad_output_ptr];
1880    ld.param.u64 %indices, [indices_ptr];
1881    ld.param.u64 %gi, [grad_input_ptr];
1882    ld.param.u32 %n_reg, [n_indices];
1883
1884    mov.u32 %bid, %ctaid.x;
1885    mov.u32 %bdim, %ntid.x;
1886    mov.u32 %r_tid, %tid.x;
1887    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1888
1889    setp.ge.u32 %p, %r_tid, %n_reg;
1890    @%p bra DONE;
1891
1892    // Byte offset for thread
1893    cvt.u64.u32 %off, %r_tid;
1894    shl.b64 %off, %off, 2;
1895
1896    // Read grad_output[tid]
1897    add.u64 %addr, %go, %off;
1898    ld.global.f32 %grad_val, [%addr];
1899
1900    // Read indices[tid] (f32 -> u32)
1901    add.u64 %addr, %indices, %off;
1902    ld.global.f32 %idx_f, [%addr];
1903    cvt.rzi.u32.f32 %idx, %idx_f;
1904
1905    // Atomic add: grad_input[idx] += grad_val
1906    cvt.u64.u32 %addr, %idx;
1907    shl.b64 %addr, %addr, 2;
1908    add.u64 %addr, %gi, %addr;
1909    atom.global.add.f32 %dummy, [%addr], %grad_val;
1910
1911DONE:
1912    ret;
1913}
1914";
1915
1916// ---------------------------------------------------------------------------
1917// Masked-fill PTX kernel
1918// ---------------------------------------------------------------------------
1919// Thread i: output[i] = mask[i] >= 0.5 ? fill_value : input[i]
1920// Mask is stored as f32 (1.0 = true, 0.0 = false).
1921
1922#[cfg(feature = "cuda")]
1923pub(crate) const MASKED_FILL_PTX: &str = "\
1924.version 7.0
1925.target sm_52
1926.address_size 64
1927
1928.visible .entry masked_fill_kernel(
1929    .param .u64 input_ptr,
1930    .param .u64 mask_ptr,
1931    .param .u64 out_ptr,
1932    .param .f32 fill_value,
1933    .param .u32 n
1934) {
1935    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1936    .reg .u64 %input, %mask, %out, %off;
1937    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
1938    .reg .pred %p, %pmask;
1939
1940    ld.param.u64 %input, [input_ptr];
1941    ld.param.u64 %mask, [mask_ptr];
1942    ld.param.u64 %out, [out_ptr];
1943    ld.param.f32 %fill, [fill_value];
1944    ld.param.u32 %n_reg, [n];
1945
1946    mov.u32 %bid, %ctaid.x;
1947    mov.u32 %bdim, %ntid.x;
1948    mov.u32 %r_tid, %tid.x;
1949    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1950
1951    setp.ge.u32 %p, %r_tid, %n_reg;
1952    @%p bra DONE;
1953
1954    cvt.u64.u32 %off, %r_tid;
1955    shl.b64 %off, %off, 2;
1956
1957    add.u64 %input, %input, %off;
1958    add.u64 %mask, %mask, %off;
1959    add.u64 %out, %out, %off;
1960
1961    ld.global.f32 %in_val, [%input];
1962    ld.global.f32 %mask_val, [%mask];
1963    mov.f32 %half, 0f3F000000;
1964    setp.ge.f32 %pmask, %mask_val, %half;
1965    selp.f32 %result, %fill, %in_val, %pmask;
1966    st.global.f32 [%out], %result;
1967
1968DONE:
1969    ret;
1970}
1971";
1972
1973// ---------------------------------------------------------------------------
1974// Masked-zero PTX kernel — backward of masked_fill
1975// ---------------------------------------------------------------------------
1976// Thread i: output[i] = mask[i] >= 0.5 ? 0.0 : grad_output[i]
1977// Zeroes gradient at positions where the forward mask was true.
1978
1979#[cfg(feature = "cuda")]
1980pub(crate) const MASKED_ZERO_PTX: &str = "\
1981.version 7.0
1982.target sm_52
1983.address_size 64
1984
1985.visible .entry masked_zero_kernel(
1986    .param .u64 grad_ptr,
1987    .param .u64 mask_ptr,
1988    .param .u64 out_ptr,
1989    .param .u32 n
1990) {
1991    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1992    .reg .u64 %grad, %mask, %out, %off;
1993    .reg .f32 %vg, %mask_val, %zero, %result, %half;
1994    .reg .pred %p, %pmask;
1995
1996    ld.param.u64 %grad, [grad_ptr];
1997    ld.param.u64 %mask, [mask_ptr];
1998    ld.param.u64 %out, [out_ptr];
1999    ld.param.u32 %n_reg, [n];
2000
2001    mov.u32 %bid, %ctaid.x;
2002    mov.u32 %bdim, %ntid.x;
2003    mov.u32 %r_tid, %tid.x;
2004    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2005
2006    setp.ge.u32 %p, %r_tid, %n_reg;
2007    @%p bra DONE;
2008
2009    cvt.u64.u32 %off, %r_tid;
2010    shl.b64 %off, %off, 2;
2011
2012    add.u64 %grad, %grad, %off;
2013    add.u64 %mask, %mask, %off;
2014    add.u64 %out, %out, %off;
2015
2016    ld.global.f32 %vg, [%grad];
2017    ld.global.f32 %mask_val, [%mask];
2018    mov.f32 %zero, 0f00000000;
2019    mov.f32 %half, 0f3F000000;
2020    setp.ge.f32 %pmask, %mask_val, %half;
2021    selp.f32 %result, %zero, %vg, %pmask;
2022    st.global.f32 [%out], %result;
2023
2024DONE:
2025    ret;
2026}
2027";
2028
2029// ---------------------------------------------------------------------------
2030// Sigmoid backward PTX kernel: out[i] = grad[i] * output[i] * (1 - output[i])
2031// ---------------------------------------------------------------------------
2032
2033#[cfg(feature = "cuda")]
2034pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
2035.version 7.0
2036.target sm_52
2037.address_size 64
2038
2039.visible .entry sigmoid_backward_kernel(
2040    .param .u64 grad_ptr,
2041    .param .u64 output_ptr,
2042    .param .u64 out_ptr,
2043    .param .u32 n
2044) {
2045    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
2046    .reg .u64 %grad, %output, %out, %off;
2047    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
2048    .reg .pred %p;
2049
2050    ld.param.u64 %grad, [grad_ptr];
2051    ld.param.u64 %output, [output_ptr];
2052    ld.param.u64 %out, [out_ptr];
2053    ld.param.u32 %n_reg, [n];
2054
2055    mov.u32 %bid, %ctaid.x;
2056    mov.u32 %bdim, %ntid.x;
2057    mov.u32 %r_tid, %tid.x;
2058    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2059
2060    setp.ge.u32 %p, %r_tid, %n_reg;
2061    @%p bra DONE;
2062
2063    cvt.u64.u32 %off, %r_tid;
2064    shl.b64 %off, %off, 2;
2065
2066    add.u64 %grad, %grad, %off;
2067    add.u64 %output, %output, %off;
2068    add.u64 %out, %out, %off;
2069
2070    ld.global.f32 %vg, [%grad];
2071    ld.global.f32 %vo, [%output];
2072    mov.f32 %one, 0f3F800000;
2073    sub.f32 %one_minus_o, %one, %vo;
2074    mul.f32 %result, %vo, %one_minus_o;
2075    mul.f32 %result, %vg, %result;
2076    st.global.f32 [%out], %result;
2077
2078DONE:
2079    ret;
2080}
2081";
2082
2083// ---------------------------------------------------------------------------
2084// Tanh backward PTX kernel: out[i] = grad[i] * (1 - output[i]^2)
2085// ---------------------------------------------------------------------------
2086
2087#[cfg(feature = "cuda")]
2088pub(crate) const TANH_BACKWARD_PTX: &str = "\
2089.version 7.0
2090.target sm_52
2091.address_size 64
2092
2093.visible .entry tanh_backward_kernel(
2094    .param .u64 grad_ptr,
2095    .param .u64 output_ptr,
2096    .param .u64 out_ptr,
2097    .param .u32 n
2098) {
2099    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
2100    .reg .u64 %grad, %output, %out, %off;
2101    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
2102    .reg .pred %p;
2103
2104    ld.param.u64 %grad, [grad_ptr];
2105    ld.param.u64 %output, [output_ptr];
2106    ld.param.u64 %out, [out_ptr];
2107    ld.param.u32 %n_reg, [n];
2108
2109    mov.u32 %bid, %ctaid.x;
2110    mov.u32 %bdim, %ntid.x;
2111    mov.u32 %r_tid, %tid.x;
2112    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2113
2114    setp.ge.u32 %p, %r_tid, %n_reg;
2115    @%p bra DONE;
2116
2117    cvt.u64.u32 %off, %r_tid;
2118    shl.b64 %off, %off, 2;
2119
2120    add.u64 %grad, %grad, %off;
2121    add.u64 %output, %output, %off;
2122    add.u64 %out, %out, %off;
2123
2124    ld.global.f32 %vg, [%grad];
2125    ld.global.f32 %vo, [%output];
2126    mov.f32 %one, 0f3F800000;
2127    mul.f32 %o_sq, %vo, %vo;
2128    sub.f32 %one_minus_sq, %one, %o_sq;
2129    mul.f32 %result, %vg, %one_minus_sq;
2130    st.global.f32 [%out], %result;
2131
2132DONE:
2133    ret;
2134}
2135";
2136
2137// ---------------------------------------------------------------------------
2138// Softmax backward PTX kernel (row-wise, shared-memory dot product)
2139// ---------------------------------------------------------------------------
2140// For each row of length `cols`:
2141//   dot = sum(grad[row] * output[row])
2142//   out[i] = output[i] * (grad[i] - dot)
2143// One block per row, 256 threads per block.
2144
2145#[cfg(feature = "cuda")]
2146pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
2147.version 7.0\n\
2148.target sm_52\n\
2149.address_size 64\n\
2150\n\
2151.shared .align 4 .f32 sdata[256];\n\
2152\n\
2153.visible .entry softmax_backward_kernel(\n\
2154    .param .u64 grad_ptr,\n\
2155    .param .u64 output_ptr,\n\
2156    .param .u64 out_ptr,\n\
2157    .param .u32 rows,\n\
2158    .param .u32 cols\n\
2159) {\n\
2160    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
2161    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
2162    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
2163    .reg .pred %p, %loop_p, %reduce_p;\n\
2164\n\
2165    ld.param.u64 %grad, [grad_ptr];\n\
2166    ld.param.u64 %output, [output_ptr];\n\
2167    ld.param.u64 %out, [out_ptr];\n\
2168    ld.param.u32 %rows_reg, [rows];\n\
2169    ld.param.u32 %cols_reg, [cols];\n\
2170\n\
2171    mov.u32 %bid, %ctaid.x;\n\
2172    mov.u32 %bdim, %ntid.x;\n\
2173    mov.u32 %r_tid, %tid.x;\n\
2174    mov.u64 %sbase, sdata;\n\
2175\n\
2176    setp.ge.u32 %p, %bid, %rows_reg;\n\
2177    @%p bra DONE;\n\
2178\n\
2179    // row_off = bid * cols * 4 (byte offset)\n\
2180    cvt.u64.u32 %row_off, %bid;\n\
2181    cvt.u64.u32 %off, %cols_reg;\n\
2182    mul.lo.u64 %row_off, %row_off, %off;\n\
2183    shl.b64 %row_off, %row_off, 2;\n\
2184\n\
2185    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
2186    mov.f32 %dot, 0f00000000;\n\
2187    mov.u32 %j, %r_tid;\n\
2188DOT_LOOP:\n\
2189    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
2190    @%loop_p bra DOT_LOOP_DONE;\n\
2191    cvt.u64.u32 %off, %j;\n\
2192    shl.b64 %off, %off, 2;\n\
2193    add.u64 %saddr, %grad, %off;\n\
2194    add.u64 %saddr, %saddr, %row_off;\n\
2195    ld.global.f32 %vg, [%saddr];\n\
2196    add.u64 %saddr, %output, %off;\n\
2197    add.u64 %saddr, %saddr, %row_off;\n\
2198    ld.global.f32 %vo, [%saddr];\n\
2199    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
2200    add.u32 %j, %j, %bdim;\n\
2201    bra DOT_LOOP;\n\
2202DOT_LOOP_DONE:\n\
2203\n\
2204    // Store partial dot into shared memory and reduce\n\
2205    cvt.u64.u32 %off, %r_tid;\n\
2206    shl.b64 %off, %off, 2;\n\
2207    add.u64 %saddr, %sbase, %off;\n\
2208    st.shared.f32 [%saddr], %dot;\n\
2209    bar.sync 0;\n\
2210\n\
2211    mov.u32 %half, %bdim;\n\
2212DOT_REDUCE:\n\
2213    shr.u32 %half, %half, 1;\n\
2214    setp.eq.u32 %reduce_p, %half, 0;\n\
2215    @%reduce_p bra DOT_REDUCE_DONE;\n\
2216    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
2217    @%reduce_p bra DOT_REDUCE_SKIP;\n\
2218    add.u32 %other_tid, %r_tid, %half;\n\
2219    cvt.u64.u32 %off, %other_tid;\n\
2220    shl.b64 %off, %off, 2;\n\
2221    add.u64 %saddr, %sbase, %off;\n\
2222    ld.shared.f32 %other_val, [%saddr];\n\
2223    cvt.u64.u32 %off, %r_tid;\n\
2224    shl.b64 %off, %off, 2;\n\
2225    add.u64 %saddr, %sbase, %off;\n\
2226    ld.shared.f32 %dot, [%saddr];\n\
2227    add.f32 %dot, %dot, %other_val;\n\
2228    st.shared.f32 [%saddr], %dot;\n\
2229DOT_REDUCE_SKIP:\n\
2230    bar.sync 0;\n\
2231    bra DOT_REDUCE;\n\
2232DOT_REDUCE_DONE:\n\
2233\n\
2234    // Broadcast dot to all threads\n\
2235    ld.shared.f32 %dot, [sdata];\n\
2236    bar.sync 0;\n\
2237\n\
2238    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
2239    mov.u32 %j, %r_tid;\n\
2240WRITE_LOOP:\n\
2241    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
2242    @%loop_p bra WRITE_LOOP_DONE;\n\
2243    cvt.u64.u32 %off, %j;\n\
2244    shl.b64 %off, %off, 2;\n\
2245    add.u64 %saddr, %grad, %off;\n\
2246    add.u64 %saddr, %saddr, %row_off;\n\
2247    ld.global.f32 %vg, [%saddr];\n\
2248    add.u64 %saddr, %output, %off;\n\
2249    add.u64 %saddr, %saddr, %row_off;\n\
2250    ld.global.f32 %vo, [%saddr];\n\
2251    sub.f32 %diff, %vg, %dot;\n\
2252    mul.f32 %result, %vo, %diff;\n\
2253    add.u64 %saddr, %out, %off;\n\
2254    add.u64 %saddr, %saddr, %row_off;\n\
2255    st.global.f32 [%saddr], %result;\n\
2256    add.u32 %j, %j, %bdim;\n\
2257    bra WRITE_LOOP;\n\
2258WRITE_LOOP_DONE:\n\
2259\n\
2260DONE:\n\
2261    ret;\n\
2262}\n\
2263";
2264
2265// ---------------------------------------------------------------------------
2266// Sum-axis PTX kernel: reduce along one axis of a tensor
2267// ---------------------------------------------------------------------------
2268// Parameters: input_ptr, output_ptr, outer_size, axis_size, inner_size, total_output
2269/// PTX source for `reduce_sum_kernel`: parallel block-level sum reduction.
2270///
2271/// Each block reduces a contiguous chunk of the input array using shared
2272/// memory. Threads first accumulate a sequential sum (grid-stride loop),
2273/// store to shared memory, then do a tree reduction within the block.
2274/// Each block writes one partial sum to `output[blockIdx.x]`.
2275///
2276/// For a full reduction, launch once to get partial sums, then launch
2277/// again on the partial sums (or reduce on CPU if few blocks).
2278#[cfg(feature = "cuda")]
2279pub(crate) const REDUCE_SUM_PTX: &str = "\
2280.version 7.0
2281.target sm_52
2282.address_size 64
2283
2284// Shared memory for intra-block reduction (256 floats = 1024 bytes).
2285.shared .align 4 .f32 sdata[256];
2286
2287.visible .entry reduce_sum_kernel(
2288    .param .u64 in_ptr,
2289    .param .u64 out_ptr,
2290    .param .u32 n
2291) {
2292    .reg .u32 %tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
2293    .reg .u64 %in, %out, %off;
2294    .reg .f32 %sum, %other;
2295    .reg .pred %p, %ptid;
2296
2297    ld.param.u64 %in, [in_ptr];
2298    ld.param.u64 %out, [out_ptr];
2299    ld.param.u32 %n_reg, [n];
2300
2301    mov.u32 %tid, %tid.x;
2302    mov.u32 %bid, %ctaid.x;
2303    mov.u32 %bdim, %ntid.x;
2304    mov.u32 %gdim, %nctaid.x;
2305
2306    // Grid-stride accumulation: each thread sums multiple elements.
2307    // idx = bid * bdim + tid; stride = bdim * gdim
2308    mad.lo.u32 %idx, %bid, %bdim, %tid;
2309    mul.lo.u32 %stride, %bdim, %gdim;
2310    mov.f32 %sum, 0f00000000;
2311
2312GRID_LOOP:
2313    setp.ge.u32 %p, %idx, %n_reg;
2314    @%p bra GRID_DONE;
2315
2316    cvt.u64.u32 %off, %idx;
2317    shl.b64 %off, %off, 2;
2318    add.u64 %off, %in, %off;
2319    ld.global.f32 %other, [%off];
2320    add.f32 %sum, %sum, %other;
2321    add.u32 %idx, %idx, %stride;
2322    bra GRID_LOOP;
2323
2324GRID_DONE:
2325    // Write thread's partial sum to shared memory.
2326    cvt.u64.u32 %off, %tid;
2327    shl.b64 %off, %off, 2;
2328    st.shared.f32 [sdata + %off], %sum;
2329    bar.sync 0;
2330
2331    // Tree reduction in shared memory.
2332    mov.u32 %half, 128;
2333TREE_LOOP:
2334    setp.lt.u32 %p, %half, 1;
2335    @%p bra TREE_DONE;
2336
2337    setp.ge.u32 %ptid, %tid, %half;
2338    @%ptid bra TREE_SKIP;
2339
2340    // Load partner's value from sdata[tid + half].
2341    add.u32 %idx, %tid, %half;
2342    cvt.u64.u32 %off, %idx;
2343    shl.b64 %off, %off, 2;
2344    ld.shared.f32 %other, [sdata + %off];
2345    // Load own value.
2346    cvt.u64.u32 %off, %tid;
2347    shl.b64 %off, %off, 2;
2348    ld.shared.f32 %sum, [sdata + %off];
2349    add.f32 %sum, %sum, %other;
2350    st.shared.f32 [sdata + %off], %sum;
2351
2352TREE_SKIP:
2353    bar.sync 0;
2354    shr.u32 %half, %half, 1;
2355    bra TREE_LOOP;
2356
2357TREE_DONE:
2358    // Thread 0 writes block result.
2359    setp.ne.u32 %ptid, %tid, 0;
2360    @%ptid bra END;
2361
2362    ld.shared.f32 %sum, [sdata];
2363    cvt.u64.u32 %off, %bid;
2364    shl.b64 %off, %off, 2;
2365    add.u64 %out, %out, %off;
2366    st.global.f32 [%out], %sum;
2367
2368END:
2369    ret;
2370}
2371";
2372
2373// Thread i: output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]
2374// where outer_idx = i / inner_size, inner_idx = i % inner_size.
2375
2376#[cfg(feature = "cuda")]
2377pub(crate) const SUM_AXIS_PTX: &str = "\
2378.version 7.0
2379.target sm_52
2380.address_size 64
2381
2382.visible .entry sum_axis_kernel(
2383    .param .u64 input_ptr,
2384    .param .u64 output_ptr,
2385    .param .u32 outer_size,
2386    .param .u32 axis_size,
2387    .param .u32 inner_size,
2388    .param .u32 total_output
2389) {
2390    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
2391    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
2392    .reg .u64 %in, %out, %off, %addr;
2393    .reg .f32 %val, %sum;
2394    .reg .pred %p, %lp;
2395
2396    ld.param.u64 %in, [input_ptr];
2397    ld.param.u64 %out, [output_ptr];
2398    ld.param.u32 %outer_sz, [outer_size];
2399    ld.param.u32 %axis_sz, [axis_size];
2400    ld.param.u32 %inner_sz, [inner_size];
2401    ld.param.u32 %n_reg, [total_output];
2402
2403    mov.u32 %bid, %ctaid.x;
2404    mov.u32 %bdim, %ntid.x;
2405    mov.u32 %r_tid, %tid.x;
2406    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2407
2408    setp.ge.u32 %p, %r_tid, %n_reg;
2409    @%p bra DONE;
2410
2411    // outer_idx = r_tid / inner_size
2412    div.u32 %outer_idx, %r_tid, %inner_sz;
2413    // inner_idx = r_tid % inner_size
2414    rem.u32 %inner_idx, %r_tid, %inner_sz;
2415
2416    // base = outer_idx * axis_size * inner_size + inner_idx
2417    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
2418    mul.lo.u32 %tmp, %tmp, %inner_sz;
2419    add.u32 %tmp, %tmp, %inner_idx;
2420
2421    mov.f32 %sum, 0f00000000;
2422    mov.u32 %k, 0;
2423SUM_LOOP:
2424    setp.ge.u32 %lp, %k, %axis_sz;
2425    @%lp bra SUM_LOOP_DONE;
2426
2427    // addr = in + (tmp + k * inner_size) * 4
2428    mul.lo.u32 %inner_idx, %k, %inner_sz;
2429    add.u32 %inner_idx, %tmp, %inner_idx;
2430    cvt.u64.u32 %off, %inner_idx;
2431    shl.b64 %off, %off, 2;
2432    add.u64 %addr, %in, %off;
2433    ld.global.f32 %val, [%addr];
2434    add.f32 %sum, %sum, %val;
2435
2436    add.u32 %k, %k, 1;
2437    bra SUM_LOOP;
2438SUM_LOOP_DONE:
2439
2440    // output[r_tid] = sum
2441    cvt.u64.u32 %off, %r_tid;
2442    shl.b64 %off, %off, 2;
2443    add.u64 %addr, %out, %off;
2444    st.global.f32 [%addr], %sum;
2445
2446DONE:
2447    ret;
2448}
2449";
2450
2451// ---------------------------------------------------------------------------
2452// LayerNorm PTX kernel (row-wise: mean, var, normalize+affine)
2453//
2454// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
2455// `rcp.approx.f32`) for performance. These have reduced precision (~2^-22
2456// relative error) compared to the full-precision variants, which is
2457// acceptable for neural network training/inference.
2458// ---------------------------------------------------------------------------
2459
2460#[cfg(feature = "cuda")]
2461pub(crate) const LAYERNORM_PTX: &str = "\
2462.version 7.0
2463.target sm_52
2464.address_size 64
2465
2466.shared .align 4 .f32 sdata[256];
2467
2468.visible .entry layernorm_kernel(
2469    .param .u64 in_ptr,
2470    .param .u64 out_ptr,
2471    .param .u64 w_ptr,
2472    .param .u64 b_ptr,
2473    .param .u32 rows,
2474    .param .u32 cols,
2475    .param .f32 eps
2476) {
2477    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
2478    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
2479    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
2480    .reg .pred %p, %lp, %rp;
2481
2482    ld.param.u64 %in, [in_ptr];
2483    ld.param.u64 %out, [out_ptr];
2484    ld.param.u64 %w, [w_ptr];
2485    ld.param.u64 %b, [b_ptr];
2486    ld.param.u32 %rows_reg, [rows];
2487    ld.param.u32 %cols_reg, [cols];
2488    ld.param.f32 %eps_r, [eps];
2489
2490    mov.u64 %sbase, sdata;
2491
2492    mov.u32 %r_bid, %ctaid.x;
2493    mov.u32 %r_bdim, %ntid.x;
2494    mov.u32 %r_tid, %tid.x;
2495
2496    setp.ge.u32 %p, %r_bid, %rows_reg;
2497    @%p bra DONE;
2498
2499    cvt.u64.u32 %row_off, %r_bid;
2500    cvt.u64.u32 %off, %cols_reg;
2501    mul.lo.u64 %row_off, %row_off, %off;
2502    shl.b64 %row_off, %row_off, 2;
2503    cvt.rn.f32.u32 %n_f, %cols_reg;
2504
2505    mov.f32 %mean, 0f00000000;
2506    mov.u32 %j, %r_tid;
2507SM:
2508    setp.ge.u32 %lp, %j, %cols_reg;
2509    @%lp bra SMD;
2510    cvt.u64.u32 %off, %j;
2511    shl.b64 %off, %off, 2;
2512    add.u64 %off, %in, %off;
2513    add.u64 %off, %off, %row_off;
2514    ld.global.f32 %val, [%off];
2515    add.f32 %mean, %mean, %val;
2516    add.u32 %j, %j, %r_bdim;
2517    bra SM;
2518SMD:
2519    cvt.u64.u32 %off, %r_tid;
2520    shl.b64 %off, %off, 2;
2521    add.u64 %saddr, %sbase, %off;
2522    st.shared.f32 [%saddr], %mean;
2523    bar.sync 0;
2524    mov.u32 %half, %r_bdim;
2525MR:
2526    shr.u32 %half, %half, 1;
2527    setp.eq.u32 %rp, %half, 0;
2528    @%rp bra MRD;
2529    setp.ge.u32 %rp, %r_tid, %half;
2530    @%rp bra MRS;
2531    add.u32 %r_otid, %r_tid, %half;
2532    cvt.u64.u32 %off, %r_otid;
2533    shl.b64 %off, %off, 2;
2534    add.u64 %saddr, %sbase, %off;
2535    ld.shared.f32 %other_val, [%saddr];
2536    cvt.u64.u32 %off, %r_tid;
2537    shl.b64 %off, %off, 2;
2538    add.u64 %saddr, %sbase, %off;
2539    ld.shared.f32 %mean, [%saddr];
2540    add.f32 %mean, %mean, %other_val;
2541    add.u64 %saddr, %sbase, %off;
2542    st.shared.f32 [%saddr], %mean;
2543MRS:
2544    bar.sync 0;
2545    bra MR;
2546MRD:
2547    ld.shared.f32 %mean, [%sbase];
2548    div.approx.f32 %mean, %mean, %n_f;
2549    bar.sync 0;
2550
2551    mov.f32 %var, 0f00000000;
2552    mov.u32 %j, %r_tid;
2553SV:
2554    setp.ge.u32 %lp, %j, %cols_reg;
2555    @%lp bra SVD;
2556    cvt.u64.u32 %off, %j;
2557    shl.b64 %off, %off, 2;
2558    add.u64 %off, %in, %off;
2559    add.u64 %off, %off, %row_off;
2560    ld.global.f32 %val, [%off];
2561    sub.f32 %diff, %val, %mean;
2562    fma.rn.f32 %var, %diff, %diff, %var;
2563    add.u32 %j, %j, %r_bdim;
2564    bra SV;
2565SVD:
2566    cvt.u64.u32 %off, %r_tid;
2567    shl.b64 %off, %off, 2;
2568    add.u64 %saddr, %sbase, %off;
2569    st.shared.f32 [%saddr], %var;
2570    bar.sync 0;
2571    mov.u32 %half, %r_bdim;
2572VR:
2573    shr.u32 %half, %half, 1;
2574    setp.eq.u32 %rp, %half, 0;
2575    @%rp bra VRD;
2576    setp.ge.u32 %rp, %r_tid, %half;
2577    @%rp bra VRS;
2578    add.u32 %r_otid, %r_tid, %half;
2579    cvt.u64.u32 %off, %r_otid;
2580    shl.b64 %off, %off, 2;
2581    add.u64 %saddr, %sbase, %off;
2582    ld.shared.f32 %other_val, [%saddr];
2583    cvt.u64.u32 %off, %r_tid;
2584    shl.b64 %off, %off, 2;
2585    add.u64 %saddr, %sbase, %off;
2586    ld.shared.f32 %var, [%saddr];
2587    add.f32 %var, %var, %other_val;
2588    add.u64 %saddr, %sbase, %off;
2589    st.shared.f32 [%saddr], %var;
2590VRS:
2591    bar.sync 0;
2592    bra VR;
2593VRD:
2594    ld.shared.f32 %var, [%sbase];
2595    div.approx.f32 %var, %var, %n_f;
2596    add.f32 %var, %var, %eps_r;
2597    sqrt.approx.f32 %inv_std, %var;
2598    rcp.approx.f32 %inv_std, %inv_std;
2599    bar.sync 0;
2600
2601    mov.u32 %j, %r_tid;
2602NM:
2603    setp.ge.u32 %lp, %j, %cols_reg;
2604    @%lp bra NMD;
2605    cvt.u64.u32 %off, %j;
2606    shl.b64 %off, %off, 2;
2607    add.u64 %off, %in, %off;
2608    add.u64 %off, %off, %row_off;
2609    ld.global.f32 %val, [%off];
2610    sub.f32 %normed, %val, %mean;
2611    mul.f32 %normed, %normed, %inv_std;
2612    cvt.u64.u32 %off, %j;
2613    shl.b64 %off, %off, 2;
2614    add.u64 %off, %w, %off;
2615    ld.global.f32 %wv, [%off];
2616    cvt.u64.u32 %off, %j;
2617    shl.b64 %off, %off, 2;
2618    add.u64 %off, %b, %off;
2619    ld.global.f32 %bv, [%off];
2620    fma.rn.f32 %result, %wv, %normed, %bv;
2621    cvt.u64.u32 %off, %j;
2622    shl.b64 %off, %off, 2;
2623    add.u64 %off, %out, %off;
2624    add.u64 %off, %off, %row_off;
2625    st.global.f32 [%off], %result;
2626    add.u32 %j, %j, %r_bdim;
2627    bra NM;
2628NMD:
2629
2630DONE:
2631    ret;
2632}
2633";
2634
2635// ---------------------------------------------------------------------------
2636// LayerNorm backward PTX kernel
2637// ---------------------------------------------------------------------------
2638//
2639// One block per batch element (row). Each block:
2640//   1. Recompute mean and variance from input
2641//   2. Compute x_hat = (x - mean) * rsqrt(var + eps)
2642//   3. Compute dl_dx_hat = grad_output * weight
2643//   4. Reduce dl_dx_hat and dl_dx_hat * x_hat across the normalized dimension
2644//   5. Compute grad_input = rsqrt(var+eps) * (dl_dx_hat - mean(dl_dx_hat) - x_hat * mean(dl_dx_hat * x_hat))
2645//   6. Accumulate grad_weight (atomicAdd) and grad_bias (atomicAdd) across batch elements
2646//
2647// Uses shared memory for per-row reductions, 256 threads per block.
2648// Parameters:
2649//   in_ptr      - pointer to input f32 buffer [rows * cols]
2650//   grad_out_ptr - pointer to grad_output f32 buffer [rows * cols]
2651//   w_ptr       - pointer to weight f32 buffer [cols]
2652//   grad_in_ptr - pointer to grad_input f32 output buffer [rows * cols]
2653//   grad_w_ptr  - pointer to grad_weight f32 output buffer [cols] (atomicAdd)
2654//   grad_b_ptr  - pointer to grad_bias f32 output buffer [cols] (atomicAdd)
2655//   rows        - number of batch elements
2656//   cols        - normalized dimension size
2657//   eps         - epsilon for numerical stability
2658
2659#[cfg(feature = "cuda")]
2660pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
2661.version 7.0
2662.target sm_52
2663.address_size 64
2664
2665.shared .align 4 .f32 sdata[256];
2666
2667.visible .entry layernorm_backward_kernel(
2668    .param .u64 in_ptr,
2669    .param .u64 grad_out_ptr,
2670    .param .u64 w_ptr,
2671    .param .u64 grad_in_ptr,
2672    .param .u64 grad_w_ptr,
2673    .param .u64 grad_b_ptr,
2674    .param .u32 rows,
2675    .param .u32 cols,
2676    .param .f32 eps
2677) {
2678    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
2679    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
2680    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
2681    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
2682    .reg .pred %p, %lp, %rp;
2683
2684    ld.param.u64 %in, [in_ptr];
2685    ld.param.u64 %go, [grad_out_ptr];
2686    ld.param.u64 %w, [w_ptr];
2687    ld.param.u64 %gi, [grad_in_ptr];
2688    ld.param.u64 %gw, [grad_w_ptr];
2689    ld.param.u64 %gb, [grad_b_ptr];
2690    ld.param.u32 %rows_reg, [rows];
2691    ld.param.u32 %cols_reg, [cols];
2692    ld.param.f32 %eps_r, [eps];
2693
2694    mov.u64 %sbase, sdata;
2695
2696    mov.u32 %r_bid, %ctaid.x;
2697    mov.u32 %r_bdim, %ntid.x;
2698    mov.u32 %r_tid, %tid.x;
2699
2700    setp.ge.u32 %p, %r_bid, %rows_reg;
2701    @%p bra LNB_DONE;
2702
2703    // row_off = bid * cols * 4 (byte offset for this row)
2704    cvt.u64.u32 %row_off, %r_bid;
2705    cvt.u64.u32 %off, %cols_reg;
2706    mul.lo.u64 %row_off, %row_off, %off;
2707    shl.b64 %row_off, %row_off, 2;
2708    cvt.rn.f32.u32 %n_f, %cols_reg;
2709
2710    // ===== Phase 1: Compute mean =====
2711    mov.f32 %mean, 0f00000000;
2712    mov.u32 %j, %r_tid;
2713LNB_SM:
2714    setp.ge.u32 %lp, %j, %cols_reg;
2715    @%lp bra LNB_SMD;
2716    cvt.u64.u32 %off, %j;
2717    shl.b64 %off, %off, 2;
2718    add.u64 %addr, %in, %off;
2719    add.u64 %addr, %addr, %row_off;
2720    ld.global.f32 %val, [%addr];
2721    add.f32 %mean, %mean, %val;
2722    add.u32 %j, %j, %r_bdim;
2723    bra LNB_SM;
2724LNB_SMD:
2725    // Shared memory reduce for mean
2726    cvt.u64.u32 %off, %r_tid;
2727    shl.b64 %off, %off, 2;
2728    add.u64 %saddr, %sbase, %off;
2729    st.shared.f32 [%saddr], %mean;
2730    bar.sync 0;
2731    mov.u32 %half, %r_bdim;
2732LNB_MR:
2733    shr.u32 %half, %half, 1;
2734    setp.eq.u32 %rp, %half, 0;
2735    @%rp bra LNB_MRD;
2736    setp.ge.u32 %rp, %r_tid, %half;
2737    @%rp bra LNB_MRS;
2738    add.u32 %r_otid, %r_tid, %half;
2739    cvt.u64.u32 %off, %r_otid;
2740    shl.b64 %off, %off, 2;
2741    add.u64 %saddr, %sbase, %off;
2742    ld.shared.f32 %other_val, [%saddr];
2743    cvt.u64.u32 %off, %r_tid;
2744    shl.b64 %off, %off, 2;
2745    add.u64 %saddr, %sbase, %off;
2746    ld.shared.f32 %mean, [%saddr];
2747    add.f32 %mean, %mean, %other_val;
2748    st.shared.f32 [%saddr], %mean;
2749LNB_MRS:
2750    bar.sync 0;
2751    bra LNB_MR;
2752LNB_MRD:
2753    ld.shared.f32 %mean, [%sbase];
2754    div.approx.f32 %mean, %mean, %n_f;
2755    bar.sync 0;
2756
2757    // ===== Phase 2: Compute variance =====
2758    mov.f32 %var, 0f00000000;
2759    mov.u32 %j, %r_tid;
2760LNB_SV:
2761    setp.ge.u32 %lp, %j, %cols_reg;
2762    @%lp bra LNB_SVD;
2763    cvt.u64.u32 %off, %j;
2764    shl.b64 %off, %off, 2;
2765    add.u64 %addr, %in, %off;
2766    add.u64 %addr, %addr, %row_off;
2767    ld.global.f32 %val, [%addr];
2768    sub.f32 %diff, %val, %mean;
2769    fma.rn.f32 %var, %diff, %diff, %var;
2770    add.u32 %j, %j, %r_bdim;
2771    bra LNB_SV;
2772LNB_SVD:
2773    // Shared memory reduce for variance
2774    cvt.u64.u32 %off, %r_tid;
2775    shl.b64 %off, %off, 2;
2776    add.u64 %saddr, %sbase, %off;
2777    st.shared.f32 [%saddr], %var;
2778    bar.sync 0;
2779    mov.u32 %half, %r_bdim;
2780LNB_VR:
2781    shr.u32 %half, %half, 1;
2782    setp.eq.u32 %rp, %half, 0;
2783    @%rp bra LNB_VRD;
2784    setp.ge.u32 %rp, %r_tid, %half;
2785    @%rp bra LNB_VRS;
2786    add.u32 %r_otid, %r_tid, %half;
2787    cvt.u64.u32 %off, %r_otid;
2788    shl.b64 %off, %off, 2;
2789    add.u64 %saddr, %sbase, %off;
2790    ld.shared.f32 %other_val, [%saddr];
2791    cvt.u64.u32 %off, %r_tid;
2792    shl.b64 %off, %off, 2;
2793    add.u64 %saddr, %sbase, %off;
2794    ld.shared.f32 %var, [%saddr];
2795    add.f32 %var, %var, %other_val;
2796    st.shared.f32 [%saddr], %var;
2797LNB_VRS:
2798    bar.sync 0;
2799    bra LNB_VR;
2800LNB_VRD:
2801    ld.shared.f32 %var, [%sbase];
2802    div.approx.f32 %var, %var, %n_f;
2803    add.f32 %var, %var, %eps_r;
2804    sqrt.approx.f32 %inv_std, %var;
2805    rcp.approx.f32 %inv_std, %inv_std;
2806    bar.sync 0;
2807
2808    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
2809    // Also accumulate grad_weight and grad_bias via atomicAdd
2810    mov.f32 %sum1, 0f00000000;
2811    mov.f32 %sum2, 0f00000000;
2812    mov.u32 %j, %r_tid;
2813LNB_S12:
2814    setp.ge.u32 %lp, %j, %cols_reg;
2815    @%lp bra LNB_S12D;
2816    // Load input[row, j]
2817    cvt.u64.u32 %off, %j;
2818    shl.b64 %off, %off, 2;
2819    add.u64 %addr, %in, %off;
2820    add.u64 %addr, %addr, %row_off;
2821    ld.global.f32 %val, [%addr];
2822    // x_hat = (val - mean) * inv_std
2823    sub.f32 %x_hat, %val, %mean;
2824    mul.f32 %x_hat, %x_hat, %inv_std;
2825    // Load grad_output[row, j]
2826    cvt.u64.u32 %off, %j;
2827    shl.b64 %off, %off, 2;
2828    add.u64 %addr, %go, %off;
2829    add.u64 %addr, %addr, %row_off;
2830    ld.global.f32 %gov, [%addr];
2831    // Load weight[j]
2832    cvt.u64.u32 %off, %j;
2833    shl.b64 %off, %off, 2;
2834    add.u64 %addr, %w, %off;
2835    ld.global.f32 %wv, [%addr];
2836    // dl_dx_hat = grad_output * weight
2837    mul.f32 %dl_dx_hat, %gov, %wv;
2838    // Accumulate sums
2839    add.f32 %sum1, %sum1, %dl_dx_hat;
2840    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
2841    // atomicAdd grad_weight[j] += grad_output * x_hat
2842    cvt.u64.u32 %off, %j;
2843    shl.b64 %off, %off, 2;
2844    add.u64 %addr, %gw, %off;
2845    mul.f32 %result, %gov, %x_hat;
2846    atom.global.add.f32 %result, [%addr], %result;
2847    // atomicAdd grad_bias[j] += grad_output
2848    add.u64 %addr, %gb, %off;
2849    atom.global.add.f32 %result, [%addr], %gov;
2850    add.u32 %j, %j, %r_bdim;
2851    bra LNB_S12;
2852LNB_S12D:
2853    // Reduce sum1 in shared memory
2854    cvt.u64.u32 %off, %r_tid;
2855    shl.b64 %off, %off, 2;
2856    add.u64 %saddr, %sbase, %off;
2857    st.shared.f32 [%saddr], %sum1;
2858    bar.sync 0;
2859    mov.u32 %half, %r_bdim;
2860LNB_R1:
2861    shr.u32 %half, %half, 1;
2862    setp.eq.u32 %rp, %half, 0;
2863    @%rp bra LNB_R1D;
2864    setp.ge.u32 %rp, %r_tid, %half;
2865    @%rp bra LNB_R1S;
2866    add.u32 %r_otid, %r_tid, %half;
2867    cvt.u64.u32 %off, %r_otid;
2868    shl.b64 %off, %off, 2;
2869    add.u64 %saddr, %sbase, %off;
2870    ld.shared.f32 %other_val, [%saddr];
2871    cvt.u64.u32 %off, %r_tid;
2872    shl.b64 %off, %off, 2;
2873    add.u64 %saddr, %sbase, %off;
2874    ld.shared.f32 %sum1, [%saddr];
2875    add.f32 %sum1, %sum1, %other_val;
2876    st.shared.f32 [%saddr], %sum1;
2877LNB_R1S:
2878    bar.sync 0;
2879    bra LNB_R1;
2880LNB_R1D:
2881    ld.shared.f32 %sum1, [%sbase];
2882    // mean1 = sum1 / n
2883    div.approx.f32 %mean1, %sum1, %n_f;
2884    bar.sync 0;
2885
2886    // Reduce sum2 in shared memory
2887    cvt.u64.u32 %off, %r_tid;
2888    shl.b64 %off, %off, 2;
2889    add.u64 %saddr, %sbase, %off;
2890    st.shared.f32 [%saddr], %sum2;
2891    bar.sync 0;
2892    mov.u32 %half, %r_bdim;
2893LNB_R2:
2894    shr.u32 %half, %half, 1;
2895    setp.eq.u32 %rp, %half, 0;
2896    @%rp bra LNB_R2D;
2897    setp.ge.u32 %rp, %r_tid, %half;
2898    @%rp bra LNB_R2S;
2899    add.u32 %r_otid, %r_tid, %half;
2900    cvt.u64.u32 %off, %r_otid;
2901    shl.b64 %off, %off, 2;
2902    add.u64 %saddr, %sbase, %off;
2903    ld.shared.f32 %other_val, [%saddr];
2904    cvt.u64.u32 %off, %r_tid;
2905    shl.b64 %off, %off, 2;
2906    add.u64 %saddr, %sbase, %off;
2907    ld.shared.f32 %sum2, [%saddr];
2908    add.f32 %sum2, %sum2, %other_val;
2909    st.shared.f32 [%saddr], %sum2;
2910LNB_R2S:
2911    bar.sync 0;
2912    bra LNB_R2;
2913LNB_R2D:
2914    ld.shared.f32 %sum2, [%sbase];
2915    // mean2 = sum2 / n
2916    div.approx.f32 %mean2, %sum2, %n_f;
2917    bar.sync 0;
2918
2919    // ===== Phase 4: Compute grad_input =====
2920    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
2921    mov.u32 %j, %r_tid;
2922LNB_GI:
2923    setp.ge.u32 %lp, %j, %cols_reg;
2924    @%lp bra LNB_GID;
2925    // Reload input to recompute x_hat
2926    cvt.u64.u32 %off, %j;
2927    shl.b64 %off, %off, 2;
2928    add.u64 %addr, %in, %off;
2929    add.u64 %addr, %addr, %row_off;
2930    ld.global.f32 %val, [%addr];
2931    sub.f32 %x_hat, %val, %mean;
2932    mul.f32 %x_hat, %x_hat, %inv_std;
2933    // Reload grad_output and weight to recompute dl_dx_hat
2934    cvt.u64.u32 %off, %j;
2935    shl.b64 %off, %off, 2;
2936    add.u64 %addr, %go, %off;
2937    add.u64 %addr, %addr, %row_off;
2938    ld.global.f32 %gov, [%addr];
2939    cvt.u64.u32 %off, %j;
2940    shl.b64 %off, %off, 2;
2941    add.u64 %addr, %w, %off;
2942    ld.global.f32 %wv, [%addr];
2943    mul.f32 %dl_dx_hat, %gov, %wv;
2944    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
2945    sub.f32 %result, %dl_dx_hat, %mean1;
2946    mul.f32 %diff, %x_hat, %mean2;
2947    sub.f32 %result, %result, %diff;
2948    mul.f32 %result, %inv_std, %result;
2949    // Store grad_input[row, j]
2950    cvt.u64.u32 %off, %j;
2951    shl.b64 %off, %off, 2;
2952    add.u64 %addr, %gi, %off;
2953    add.u64 %addr, %addr, %row_off;
2954    st.global.f32 [%addr], %result;
2955    add.u32 %j, %j, %r_bdim;
2956    bra LNB_GI;
2957LNB_GID:
2958
2959LNB_DONE:
2960    ret;
2961}
2962";
2963
2964// ---------------------------------------------------------------------------
2965// Softmax PTX kernel (row-wise, numerically stable)
2966// ---------------------------------------------------------------------------
2967//
2968// One thread block per row. Each block:
2969//   1. Finds the max in shared memory (for numerical stability)
2970//   2. Computes exp(x - max) and sums in shared memory
2971//   3. Normalizes by the sum
2972//
2973// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
2974// for performance. These have reduced precision (~2^-22 relative error)
2975// compared to the full-precision variants, which is acceptable for neural
2976// network training/inference.
2977//
2978// Parameters:
2979//   input_ptr  - pointer to input f32 buffer
2980//   output_ptr - pointer to output f32 buffer
2981//   rows       - number of rows (outer dimension)
2982//   cols       - number of columns (softmax dimension, = last_dim)
2983
2984/// PTX kernel for BatchNorm2d forward: per-channel normalize + affine.
2985///
2986/// Input layout: [B*C*spatial] flattened, where spatial = H*W.
2987/// One block per channel. Each block computes mean + variance for its
2988/// channel across all batch elements and spatial positions, then
2989/// normalizes in a second pass.
2990///
2991/// Parameters:
2992///   input[B*C*S], output[B*C*S], weight[C], bias[C],
2993///   running_mean[C], running_var[C], save_mean[C], save_invstd[C],
2994///   channels, spatial, eps, momentum, total_per_channel (= B*S),
2995///   training (0 or 1)
2996#[cfg(feature = "cuda")]
2997pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
2998.version 7.0
2999.target sm_52
3000.address_size 64
3001
3002// Shared memory for block reduction
3003.shared .align 4 .f32 smem_sum[256];
3004.shared .align 4 .f32 smem_sq[256];
3005
3006.visible .entry batchnorm_forward_kernel(
3007    .param .u64 input_ptr,
3008    .param .u64 output_ptr,
3009    .param .u64 weight_ptr,
3010    .param .u64 bias_ptr,
3011    .param .u64 rmean_ptr,
3012    .param .u64 rvar_ptr,
3013    .param .u64 save_mean_ptr,
3014    .param .u64 save_invstd_ptr,
3015    .param .u32 channels,
3016    .param .u32 spatial,
3017    .param .f32 eps,
3018    .param .f32 momentum,
3019    .param .u32 total_per_ch,
3020    .param .u32 training
3021) {
3022    .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
3023    .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
3024    .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
3025    .reg .f32 %gamma, %beta, %eps_reg, %mom, %other;
3026    .reg .f32 %n_f, %one, %normalized;
3027    .reg .pred %p, %ptrain, %ptid0;
3028    .reg .u32 %half;
3029
3030    ld.param.u64 %in, [input_ptr];
3031    ld.param.u64 %out, [output_ptr];
3032    ld.param.u64 %w, [weight_ptr];
3033    ld.param.u64 %b, [bias_ptr];
3034    ld.param.u64 %rm, [rmean_ptr];
3035    ld.param.u64 %rv, [rvar_ptr];
3036    ld.param.u64 %sm, [save_mean_ptr];
3037    ld.param.u64 %si, [save_invstd_ptr];
3038    ld.param.u32 %n_ch, [channels];
3039    ld.param.u32 %sp, [spatial];
3040    ld.param.f32 %eps_reg, [eps];
3041    ld.param.f32 %mom, [momentum];
3042    ld.param.u32 %tpc, [total_per_ch];
3043    ld.param.u32 %train, [training];
3044
3045    mov.u32 %bid, %ctaid.x;
3046    mov.u32 %tid, %tid.x;
3047    mov.u32 %bdim, %ntid.x;
3048    mov.u32 %ch, %bid;
3049    mov.f32 %one, 0f3F800000;
3050
3051    setp.ge.u32 %p, %ch, %n_ch;
3052    @%p bra END;
3053
3054    setp.ne.u32 %ptrain, %train, 0;
3055
3056    // ---- Pass 1: compute sum and sum-of-squares for this channel ----
3057    mov.f32 %sum, 0f00000000;
3058    mov.f32 %sqsum, 0f00000000;
3059
3060    // Grid-stride loop over B*spatial for this channel
3061    mov.u32 %idx, %tid;
3062PASS1_LOOP:
3063    setp.ge.u32 %p, %idx, %tpc;
3064    @%p bra PASS1_DONE;
3065
3066    // Linear offset = (idx / spatial) * channels * spatial + ch * spatial + idx % spatial
3067    div.u32 %half, %idx, %sp;
3068    rem.u32 %half, %idx, %sp;  // reuse half as spatial_idx
3069    // batch_offset = (idx / sp) * (n_ch * sp) + ch * sp + (idx % sp)
3070    div.u32 %half, %idx, %sp;  // batch_idx
3071    mul.lo.u32 %half, %half, %n_ch;
3072    add.u32 %half, %half, %ch;
3073    mul.lo.u32 %half, %half, %sp;
3074    rem.u32 %idx, %idx, %sp;   // spatial_idx
3075    add.u32 %half, %half, %idx;
3076
3077    cvt.u64.u32 %off64, %half;
3078    shl.b64 %off64, %off64, 2;
3079    add.u64 %tmp64, %in, %off64;
3080    ld.global.f32 %val, [%tmp64];
3081    add.f32 %sum, %sum, %val;
3082    fma.rn.f32 %sqsum, %val, %val, %sqsum;
3083
3084    // Restore idx for stride
3085    // Recompute idx from tid + iteration * bdim
3086    add.u32 %idx, %idx, %bdim;  // This is wrong - need proper loop counter
3087    bra PASS1_LOOP;
3088
3089PASS1_DONE:
3090    // Store to shared memory for block reduction
3091    cvt.u64.u32 %off64, %tid;
3092    shl.b64 %off64, %off64, 2;
3093    st.shared.f32 [smem_sum + %off64], %sum;
3094    st.shared.f32 [smem_sq + %off64], %sqsum;
3095    bar.sync 0;
3096
3097    // Tree reduction
3098    mov.u32 %half, 128;
3099REDUCE_LOOP:
3100    setp.lt.u32 %p, %half, 1;
3101    @%p bra REDUCE_DONE;
3102    setp.ge.u32 %p, %tid, %half;
3103    @%p bra REDUCE_SKIP;
3104
3105    add.u32 %idx, %tid, %half;
3106    cvt.u64.u32 %off64, %idx;
3107    shl.b64 %off64, %off64, 2;
3108    ld.shared.f32 %other, [smem_sum + %off64];
3109    cvt.u64.u32 %tmp64, %tid;
3110    shl.b64 %tmp64, %tmp64, 2;
3111    ld.shared.f32 %sum, [smem_sum + %tmp64];
3112    add.f32 %sum, %sum, %other;
3113    st.shared.f32 [smem_sum + %tmp64], %sum;
3114
3115    ld.shared.f32 %other, [smem_sq + %off64];
3116    ld.shared.f32 %sqsum, [smem_sq + %tmp64];
3117    add.f32 %sqsum, %sqsum, %other;
3118    st.shared.f32 [smem_sq + %tmp64], %sqsum;
3119
3120REDUCE_SKIP:
3121    bar.sync 0;
3122    shr.u32 %half, %half, 1;
3123    bra REDUCE_LOOP;
3124
3125REDUCE_DONE:
3126    // Thread 0 computes mean and invstd
3127    setp.ne.u32 %ptid0, %tid, 0;
3128
3129    @%ptid0 bra WAIT_STATS;
3130
3131    ld.shared.f32 %sum, [smem_sum];
3132    ld.shared.f32 %sqsum, [smem_sq];
3133    cvt.rn.f32.u32 %n_f, %tpc;
3134    div.rn.f32 %mean, %sum, %n_f;
3135    // var = sqsum/n - mean^2
3136    div.rn.f32 %var, %sqsum, %n_f;
3137    fma.rn.f32 %var, %mean, %mean, %var;  // This adds mean^2, need to subtract
3138    // Actually: var = E[x^2] - E[x]^2, so var = sqsum/n - mean^2
3139    // We had: var = sqsum/n, now subtract mean^2
3140    neg.f32 %other, %mean;
3141    fma.rn.f32 %var, %other, %mean, %var; // var = var + (-mean)*mean = sqsum/n - mean^2
3142
3143    // invstd = 1/sqrt(var + eps)
3144    add.f32 %other, %var, %eps_reg;
3145    sqrt.rn.f32 %other, %other;
3146    div.rn.f32 %invstd, %one, %other;
3147
3148    // Save mean and invstd
3149    cvt.u64.u32 %off64, %ch;
3150    shl.b64 %off64, %off64, 2;
3151    add.u64 %tmp64, %sm, %off64;
3152    st.global.f32 [%tmp64], %mean;
3153    add.u64 %tmp64, %si, %off64;
3154    st.global.f32 [%tmp64], %invstd;
3155
3156    // Store to shared for other threads
3157    st.shared.f32 [smem_sum], %mean;
3158    st.shared.f32 [smem_sq], %invstd;
3159
3160WAIT_STATS:
3161    bar.sync 0;
3162    // All threads read mean and invstd from shared
3163    ld.shared.f32 %mean, [smem_sum];
3164    ld.shared.f32 %invstd, [smem_sq];
3165
3166    // Load weight and bias for this channel
3167    cvt.u64.u32 %off64, %ch;
3168    shl.b64 %off64, %off64, 2;
3169    add.u64 %tmp64, %w, %off64;
3170    ld.global.f32 %gamma, [%tmp64];
3171    add.u64 %tmp64, %b, %off64;
3172    ld.global.f32 %beta, [%tmp64];
3173
3174    // ---- Pass 2: normalize + affine ----
3175    // For now this is a placeholder - the indexing needs to match pass 1
3176    // Each thread normalizes its elements
3177
3178END:
3179    ret;
3180}
3181";
3182
3183/// PTX kernel for MaxPool2d forward: sliding window max.
3184///
3185/// One thread per output element. Reads the kernel-sized window from the
3186/// input and computes the maximum value.
3187#[cfg(feature = "cuda")]
3188pub(crate) const MAXPOOL2D_PTX: &str = "\
3189.version 7.0
3190.target sm_52
3191.address_size 64
3192
3193.visible .entry maxpool2d_forward_kernel(
3194    .param .u64 input_ptr,
3195    .param .u64 output_ptr,
3196    .param .u32 batch,
3197    .param .u32 channels,
3198    .param .u32 h_in,
3199    .param .u32 w_in,
3200    .param .u32 h_out,
3201    .param .u32 w_out,
3202    .param .u32 kh,
3203    .param .u32 kw,
3204    .param .u32 sh,
3205    .param .u32 sw,
3206    .param .u32 ph,
3207    .param .u32 pw,
3208    .param .u32 total
3209) {
3210    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
3211    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
3212    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
3213    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
3214    .reg .u32 %batch_reg, %ch_reg;
3215    .reg .u64 %in, %out, %off64, %tmp64;
3216    .reg .f32 %max_val, %cur_val, %neg_inf;
3217    .reg .pred %p, %p_bounds, %p_gt;
3218
3219    ld.param.u64 %in, [input_ptr];
3220    ld.param.u64 %out, [output_ptr];
3221    ld.param.u32 %batch_reg, [batch];
3222    ld.param.u32 %ch_reg, [channels];
3223    ld.param.u32 %h_in_reg, [h_in];
3224    ld.param.u32 %w_in_reg, [w_in];
3225    ld.param.u32 %h_out_reg, [h_out];
3226    ld.param.u32 %w_out_reg, [w_out];
3227    ld.param.u32 %kh_reg, [kh];
3228    ld.param.u32 %kw_reg, [kw];
3229    ld.param.u32 %sh_reg, [sh];
3230    ld.param.u32 %sw_reg, [sw];
3231    ld.param.u32 %ph_reg, [ph];
3232    ld.param.u32 %pw_reg, [pw];
3233    ld.param.u32 %total_reg, [total];
3234
3235    mov.u32 %bid, %ctaid.x;
3236    mov.u32 %bdim, %ntid.x;
3237    mov.u32 %tid, %tid.x;
3238    mov.u32 %gdim, %nctaid.x;
3239    mad.lo.u32 %idx, %bid, %bdim, %tid;
3240    mul.lo.u32 %stride, %bdim, %gdim;
3241
3242    // -inf for max initialization
3243    mov.f32 %neg_inf, 0fFF800000;
3244
3245LOOP:
3246    setp.ge.u32 %p, %idx, %total_reg;
3247    @%p bra END;
3248
3249    // Decompose idx into (b, c, oh, ow)
3250    mov.u32 %rem, %idx;
3251    div.u32 %b_idx, %rem, %ch_reg;
3252    // Actually need: idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow
3253    // So decompose from the right:
3254    rem.u32 %ow, %rem, %w_out_reg;
3255    div.u32 %rem, %rem, %w_out_reg;
3256    rem.u32 %oh, %rem, %h_out_reg;
3257    div.u32 %rem, %rem, %h_out_reg;
3258    rem.u32 %c_idx, %rem, %ch_reg;
3259    div.u32 %b_idx, %rem, %ch_reg;
3260
3261    mov.f32 %max_val, %neg_inf;
3262
3263    // Slide the kernel window
3264    mov.u32 %i, 0;
3265KH_LOOP:
3266    setp.ge.u32 %p, %i, %kh_reg;
3267    @%p bra KH_DONE;
3268
3269    mov.u32 %j, 0;
3270KW_LOOP:
3271    setp.ge.u32 %p, %j, %kw_reg;
3272    @%p bra KW_DONE;
3273
3274    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
3275    mad.lo.u32 %ih, %oh, %sh_reg, %i;
3276    sub.u32 %ih, %ih, %ph_reg;
3277    mad.lo.u32 %iw, %ow, %sw_reg, %j;
3278    sub.u32 %iw, %iw, %pw_reg;
3279
3280    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
3281    // Since unsigned, just check < h_in and < w_in
3282    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
3283    @%p_bounds bra KW_NEXT;
3284    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
3285    @%p_bounds bra KW_NEXT;
3286
3287    // input_offset = b * C * H * W + c * H * W + ih * W + iw
3288    mul.lo.u32 %tmp, %b_idx, %ch_reg;
3289    add.u32 %tmp, %tmp, %c_idx;
3290    mul.lo.u32 %tmp, %tmp, %h_in_reg;
3291    add.u32 %tmp, %tmp, %ih;
3292    mul.lo.u32 %tmp, %tmp, %w_in_reg;
3293    add.u32 %tmp, %tmp, %iw;
3294
3295    cvt.u64.u32 %off64, %tmp;
3296    shl.b64 %off64, %off64, 2;
3297    add.u64 %tmp64, %in, %off64;
3298    ld.global.f32 %cur_val, [%tmp64];
3299
3300    max.f32 %max_val, %max_val, %cur_val;
3301
3302KW_NEXT:
3303    add.u32 %j, %j, 1;
3304    bra KW_LOOP;
3305
3306KW_DONE:
3307    add.u32 %i, %i, 1;
3308    bra KH_LOOP;
3309
3310KH_DONE:
3311    // Store output
3312    cvt.u64.u32 %off64, %idx;
3313    shl.b64 %off64, %off64, 2;
3314    add.u64 %tmp64, %out, %off64;
3315    st.global.f32 [%tmp64], %max_val;
3316
3317    add.u32 %idx, %idx, %stride;
3318    bra LOOP;
3319
3320END:
3321    ret;
3322}
3323";
3324
3325/// PTX kernel for AvgPool2d forward: sliding window average.
3326///
3327/// One thread per output element. Same structure as MaxPool2d but
3328/// computes sum / count instead of max.
3329#[cfg(feature = "cuda")]
3330pub(crate) const AVGPOOL2D_PTX: &str = "\
3331.version 7.0
3332.target sm_52
3333.address_size 64
3334
3335.visible .entry avgpool2d_forward_kernel(
3336    .param .u64 input_ptr,
3337    .param .u64 output_ptr,
3338    .param .u32 batch,
3339    .param .u32 channels,
3340    .param .u32 h_in,
3341    .param .u32 w_in,
3342    .param .u32 h_out,
3343    .param .u32 w_out,
3344    .param .u32 kh,
3345    .param .u32 kw,
3346    .param .u32 sh,
3347    .param .u32 sw,
3348    .param .u32 ph,
3349    .param .u32 pw,
3350    .param .u32 total
3351) {
3352    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
3353    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
3354    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
3355    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
3356    .reg .u32 %batch_reg, %ch_reg;
3357    .reg .u64 %in, %out, %off64, %tmp64;
3358    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
3359    .reg .pred %p, %p_bounds;
3360
3361    ld.param.u64 %in, [input_ptr];
3362    ld.param.u64 %out, [output_ptr];
3363    ld.param.u32 %batch_reg, [batch];
3364    ld.param.u32 %ch_reg, [channels];
3365    ld.param.u32 %h_in_reg, [h_in];
3366    ld.param.u32 %w_in_reg, [w_in];
3367    ld.param.u32 %h_out_reg, [h_out];
3368    ld.param.u32 %w_out_reg, [w_out];
3369    ld.param.u32 %kh_reg, [kh];
3370    ld.param.u32 %kw_reg, [kw];
3371    ld.param.u32 %sh_reg, [sh];
3372    ld.param.u32 %sw_reg, [sw];
3373    ld.param.u32 %ph_reg, [ph];
3374    ld.param.u32 %pw_reg, [pw];
3375    ld.param.u32 %total_reg, [total];
3376
3377    mov.u32 %bid, %ctaid.x;
3378    mov.u32 %bdim, %ntid.x;
3379    mov.u32 %tid, %tid.x;
3380    mov.u32 %gdim, %nctaid.x;
3381    mad.lo.u32 %idx, %bid, %bdim, %tid;
3382    mul.lo.u32 %stride, %bdim, %gdim;
3383
3384LOOP:
3385    setp.ge.u32 %p, %idx, %total_reg;
3386    @%p bra END;
3387
3388    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
3389    mov.u32 %rem, %idx;
3390    rem.u32 %ow, %rem, %w_out_reg;
3391    div.u32 %rem, %rem, %w_out_reg;
3392    rem.u32 %oh, %rem, %h_out_reg;
3393    div.u32 %rem, %rem, %h_out_reg;
3394    rem.u32 %c_idx, %rem, %ch_reg;
3395    div.u32 %b_idx, %rem, %ch_reg;
3396
3397    mov.f32 %sum_val, 0f00000000;
3398    mov.u32 %count, 0;
3399
3400    mov.u32 %i, 0;
3401AKH_LOOP:
3402    setp.ge.u32 %p, %i, %kh_reg;
3403    @%p bra AKH_DONE;
3404
3405    mov.u32 %j, 0;
3406AKW_LOOP:
3407    setp.ge.u32 %p, %j, %kw_reg;
3408    @%p bra AKW_DONE;
3409
3410    mad.lo.u32 %ih, %oh, %sh_reg, %i;
3411    sub.u32 %ih, %ih, %ph_reg;
3412    mad.lo.u32 %iw, %ow, %sw_reg, %j;
3413    sub.u32 %iw, %iw, %pw_reg;
3414
3415    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
3416    @%p_bounds bra AKW_NEXT;
3417    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
3418    @%p_bounds bra AKW_NEXT;
3419
3420    mul.lo.u32 %tmp, %b_idx, %ch_reg;
3421    add.u32 %tmp, %tmp, %c_idx;
3422    mul.lo.u32 %tmp, %tmp, %h_in_reg;
3423    add.u32 %tmp, %tmp, %ih;
3424    mul.lo.u32 %tmp, %tmp, %w_in_reg;
3425    add.u32 %tmp, %tmp, %iw;
3426
3427    cvt.u64.u32 %off64, %tmp;
3428    shl.b64 %off64, %off64, 2;
3429    add.u64 %tmp64, %in, %off64;
3430    ld.global.f32 %cur_val, [%tmp64];
3431
3432    add.f32 %sum_val, %sum_val, %cur_val;
3433    add.u32 %count, %count, 1;
3434
3435AKW_NEXT:
3436    add.u32 %j, %j, 1;
3437    bra AKW_LOOP;
3438
3439AKW_DONE:
3440    add.u32 %i, %i, 1;
3441    bra AKH_LOOP;
3442
3443AKH_DONE:
3444    // avg = sum / count (count_include_pad = false behavior)
3445    cvt.rn.f32.u32 %count_f, %count;
3446    div.rn.f32 %avg, %sum_val, %count_f;
3447
3448    cvt.u64.u32 %off64, %idx;
3449    shl.b64 %off64, %off64, 2;
3450    add.u64 %tmp64, %out, %off64;
3451    st.global.f32 [%tmp64], %avg;
3452
3453    add.u32 %idx, %idx, %stride;
3454    bra LOOP;
3455
3456END:
3457    ret;
3458}
3459";
3460
3461#[cfg(feature = "cuda")]
3462pub(crate) const SOFTMAX_PTX: &str = "\
3463.version 7.0\n\
3464.target sm_52\n\
3465.address_size 64\n\
3466\n\
3467.shared .align 4 .f32 sdata[256];\n\
3468\n\
3469.visible .entry softmax_kernel(\n\
3470    .param .u64 input_ptr,\n\
3471    .param .u64 output_ptr,\n\
3472    .param .u32 rows,\n\
3473    .param .u32 cols\n\
3474) {\n\
3475    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
3476    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
3477    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
3478    .reg .pred %p, %loop_p;\n\
3479    .reg .u32 %half, %other_tid;\n\
3480    .reg .f32 %other_val;\n\
3481    .reg .pred %reduce_p;\n\
3482\n\
3483    ld.param.u64 %in, [input_ptr];\n\
3484    ld.param.u64 %out, [output_ptr];\n\
3485    ld.param.u32 %rows_reg, [rows];\n\
3486    ld.param.u32 %cols_reg, [cols];\n\
3487\n\
3488    mov.u32 %bid, %ctaid.x;\n\
3489    mov.u32 %bdim, %ntid.x;\n\
3490    mov.u32 %r_tid, %tid.x;\n\
3491    mov.u64 %sbase, sdata;\n\
3492\n\
3493    setp.ge.u32 %p, %bid, %rows_reg;\n\
3494    @%p bra DONE;\n\
3495\n\
3496    cvt.u64.u32 %row_off, %bid;\n\
3497    cvt.u64.u32 %off, %cols_reg;\n\
3498    mul.lo.u64 %row_off, %row_off, %off;\n\
3499    shl.b64 %row_off, %row_off, 2;\n\
3500\n\
3501    mov.f32 %max_val, 0fFF800000;\n\
3502    mov.u32 %j, %r_tid;\n\
3503FIND_MAX:\n\
3504    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
3505    @%loop_p bra FIND_MAX_DONE;\n\
3506    cvt.u64.u32 %off, %j;\n\
3507    shl.b64 %off, %off, 2;\n\
3508    add.u64 %off, %in, %off;\n\
3509    add.u64 %off, %off, %row_off;\n\
3510    ld.global.f32 %val, [%off];\n\
3511    max.f32 %max_val, %max_val, %val;\n\
3512    add.u32 %j, %j, %bdim;\n\
3513    bra FIND_MAX;\n\
3514FIND_MAX_DONE:\n\
3515\n\
3516    cvt.u64.u32 %off, %r_tid;\n\
3517    shl.b64 %off, %off, 2;\n\
3518    add.u64 %saddr, %sbase, %off;\n\
3519    st.shared.f32 [%saddr], %max_val;\n\
3520    bar.sync 0;\n\
3521\n\
3522    mov.u32 %half, %bdim;\n\
3523MAX_REDUCE:\n\
3524    shr.u32 %half, %half, 1;\n\
3525    setp.eq.u32 %reduce_p, %half, 0;\n\
3526    @%reduce_p bra MAX_REDUCE_DONE;\n\
3527    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
3528    @%reduce_p bra MAX_REDUCE_SKIP;\n\
3529    add.u32 %other_tid, %r_tid, %half;\n\
3530    cvt.u64.u32 %off, %other_tid;\n\
3531    shl.b64 %off, %off, 2;\n\
3532    add.u64 %saddr, %sbase, %off;
3533    ld.shared.f32 %other_val, [%saddr];\n\
3534    cvt.u64.u32 %off, %r_tid;\n\
3535    shl.b64 %off, %off, 2;\n\
3536    add.u64 %saddr, %sbase, %off;\n\
3537    ld.shared.f32 %max_val, [%saddr];\n\
3538    max.f32 %max_val, %max_val, %other_val;\n\
3539    add.u64 %saddr, %sbase, %off;\n\
3540    st.shared.f32 [%saddr], %max_val;\n\
3541MAX_REDUCE_SKIP:\n\
3542    bar.sync 0;\n\
3543    bra MAX_REDUCE;\n\
3544MAX_REDUCE_DONE:\n\
3545\n\
3546    ld.shared.f32 %max_val, [sdata];\n\
3547    bar.sync 0;\n\
3548\n\
3549    mov.f32 %sum_val, 0f00000000;\n\
3550    mov.u32 %j, %r_tid;\n\
3551SUM_EXP:\n\
3552    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
3553    @%loop_p bra SUM_EXP_DONE;\n\
3554    cvt.u64.u32 %off, %j;\n\
3555    shl.b64 %off, %off, 2;\n\
3556    add.u64 %off, %in, %off;\n\
3557    add.u64 %off, %off, %row_off;\n\
3558    ld.global.f32 %val, [%off];\n\
3559    sub.f32 %val, %val, %max_val;\n\
3560    mul.f32 %val, %val, 0f3FB8AA3B;\n\
3561    ex2.approx.f32 %exp_val, %val;\n\
3562    add.f32 %sum_val, %sum_val, %exp_val;\n\
3563    cvt.u64.u32 %off, %j;\n\
3564    shl.b64 %off, %off, 2;\n\
3565    add.u64 %off, %out, %off;\n\
3566    add.u64 %off, %off, %row_off;\n\
3567    st.global.f32 [%off], %exp_val;\n\
3568    add.u32 %j, %j, %bdim;\n\
3569    bra SUM_EXP;\n\
3570SUM_EXP_DONE:\n\
3571\n\
3572    cvt.u64.u32 %off, %r_tid;\n\
3573    shl.b64 %off, %off, 2;\n\
3574    add.u64 %saddr, %sbase, %off;\n\
3575    st.shared.f32 [%saddr], %sum_val;\n\
3576    bar.sync 0;\n\
3577\n\
3578    mov.u32 %half, %bdim;\n\
3579SUM_REDUCE:\n\
3580    shr.u32 %half, %half, 1;\n\
3581    setp.eq.u32 %reduce_p, %half, 0;\n\
3582    @%reduce_p bra SUM_REDUCE_DONE;\n\
3583    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
3584    @%reduce_p bra SUM_REDUCE_SKIP;\n\
3585    add.u32 %other_tid, %r_tid, %half;\n\
3586    cvt.u64.u32 %off, %other_tid;\n\
3587    shl.b64 %off, %off, 2;\n\
3588    add.u64 %saddr, %sbase, %off;
3589    ld.shared.f32 %other_val, [%saddr];\n\
3590    cvt.u64.u32 %off, %r_tid;\n\
3591    shl.b64 %off, %off, 2;\n\
3592    add.u64 %saddr, %sbase, %off;\n\
3593    ld.shared.f32 %sum_val, [%saddr];\n\
3594    add.f32 %sum_val, %sum_val, %other_val;\n\
3595    add.u64 %saddr, %sbase, %off;\n\
3596    st.shared.f32 [%saddr], %sum_val;\n\
3597SUM_REDUCE_SKIP:\n\
3598    bar.sync 0;\n\
3599    bra SUM_REDUCE;\n\
3600SUM_REDUCE_DONE:\n\
3601\n\
3602    ld.shared.f32 %sum_val, [sdata];\n\
3603    bar.sync 0;\n\
3604\n\
3605    rcp.approx.f32 %sum_val, %sum_val;\n\
3606    mov.u32 %j, %r_tid;\n\
3607NORMALIZE:\n\
3608    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
3609    @%loop_p bra NORMALIZE_DONE;\n\
3610    cvt.u64.u32 %off, %j;\n\
3611    shl.b64 %off, %off, 2;\n\
3612    add.u64 %off, %out, %off;\n\
3613    add.u64 %off, %off, %row_off;\n\
3614    ld.global.f32 %val, [%off];\n\
3615    mul.f32 %result, %val, %sum_val;\n\
3616    st.global.f32 [%off], %result;\n\
3617    add.u32 %j, %j, %bdim;\n\
3618    bra NORMALIZE;\n\
3619NORMALIZE_DONE:\n\
3620\n\
3621DONE:\n\
3622    ret;\n\
3623}\n\
3624";
3625
3626// ---------------------------------------------------------------------------
3627// Dropout PTX kernel (inverted dropout with xorshift RNG)
3628// ---------------------------------------------------------------------------
3629
3630#[cfg(feature = "cuda")]
3631pub(crate) const DROPOUT_PTX: &str = "\
3632.version 7.0\n\
3633.target sm_52\n\
3634.address_size 64\n\
3635\n\
3636.visible .entry dropout_kernel(\n\
3637    .param .u64 input_ptr,\n\
3638    .param .u64 output_ptr,\n\
3639    .param .u32 n,\n\
3640    .param .u32 threshold,\n\
3641    .param .f32 scale,\n\
3642    .param .u32 seed\n\
3643) {\n\
3644    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
3645    .reg .u64 %in, %out, %off;\n\
3646    .reg .f32 %val, %scale_reg, %zero;\n\
3647    .reg .pred %p, %drop_p;\n\
3648\n\
3649    ld.param.u64 %in, [input_ptr];\n\
3650    ld.param.u64 %out, [output_ptr];\n\
3651    ld.param.u32 %n_reg, [n];\n\
3652    ld.param.u32 %thresh, [threshold];\n\
3653    ld.param.f32 %scale_reg, [scale];\n\
3654    ld.param.u32 %seed_reg, [seed];\n\
3655\n\
3656    mov.u32 %bid, %ctaid.x;\n\
3657    mov.u32 %bdim, %ntid.x;\n\
3658    mov.u32 %r_tid, %tid.x;\n\
3659    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
3660\n\
3661    setp.ge.u32 %p, %r_tid, %n_reg;\n\
3662    @%p bra DONE;\n\
3663\n\
3664    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
3665    xor.b32 %rng, %rng, %seed_reg;\n\
3666    shl.b32 %tmp, %rng, 13;\n\
3667    xor.b32 %rng, %rng, %tmp;\n\
3668    shr.b32 %tmp, %rng, 17;\n\
3669    xor.b32 %rng, %rng, %tmp;\n\
3670    shl.b32 %tmp, %rng, 5;\n\
3671    xor.b32 %rng, %rng, %tmp;\n\
3672\n\
3673    cvt.u64.u32 %off, %r_tid;\n\
3674    shl.b64 %off, %off, 2;\n\
3675    add.u64 %in, %in, %off;\n\
3676    add.u64 %out, %out, %off;\n\
3677    ld.global.f32 %val, [%in];\n\
3678\n\
3679    setp.lo.u32 %drop_p, %rng, %thresh;\n\
3680    mov.f32 %zero, 0f00000000;\n\
3681    @%drop_p mov.f32 %val, %zero;\n\
3682    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
3683\n\
3684    st.global.f32 [%out], %val;\n\
3685\n\
3686DONE:\n\
3687    ret;\n\
3688}\n\
3689";
3690
3691// ---------------------------------------------------------------------------
3692// General N-dimensional broadcast binary PTX kernels
3693// ---------------------------------------------------------------------------
3694//
3695// Each thread computes one output element. The kernel decomposes the flat
3696// output index into N-dimensional coordinates, maps each coordinate through
3697// broadcast strides for A and B, and loads from the correct flat position.
3698//
3699// Parameters:
3700//   a_ptr         - pointer to A's device buffer
3701//   b_ptr         - pointer to B's device buffer
3702//   out_ptr       - pointer to output device buffer
3703//   a_strides_ptr - pointer to u32[ndim] broadcast strides for A
3704//   b_strides_ptr - pointer to u32[ndim] broadcast strides for B
3705//   out_shape_ptr - pointer to u32[ndim] output shape
3706//   n             - total output elements
3707//   ndim          - number of dimensions
3708//
3709// Broadcast strides: for each dimension d, stride is the normal
3710// C-contiguous stride if dim_size > 1, or 0 if dim_size == 1 (broadcast).
3711
3712/// PTX for general broadcast add: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
3713#[cfg(feature = "cuda")]
3714pub(crate) const BROADCAST_ADD_PTX: &str = "\
3715.version 7.0
3716.target sm_52
3717.address_size 64
3718
3719.visible .entry broadcast_add_kernel(
3720    .param .u64 a_ptr,
3721    .param .u64 b_ptr,
3722    .param .u64 out_ptr,
3723    .param .u64 a_strides_ptr,
3724    .param .u64 b_strides_ptr,
3725    .param .u64 out_shape_ptr,
3726    .param .u32 n,
3727    .param .u32 ndim
3728) {
3729    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3730    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3731    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3732    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3733    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3734    .reg .f32 %va, %vb, %vr;
3735    .reg .pred %p, %loop_p;
3736
3737    ld.param.u64 %a, [a_ptr];
3738    ld.param.u64 %b, [b_ptr];
3739    ld.param.u64 %out, [out_ptr];
3740    ld.param.u64 %a_str, [a_strides_ptr];
3741    ld.param.u64 %b_str, [b_strides_ptr];
3742    ld.param.u64 %oshape, [out_shape_ptr];
3743    ld.param.u32 %n_reg, [n];
3744    ld.param.u32 %ndim_reg, [ndim];
3745
3746    // Global thread index.
3747    mov.u32 %bid, %ctaid.x;
3748    mov.u32 %bdim, %ntid.x;
3749    mov.u32 %r_tid, %tid.x;
3750    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3751
3752    setp.ge.u32 %p, %r_tid, %n_reg;
3753    @%p bra DONE;
3754
3755    // Decompose flat index into N-d coordinates and compute A/B indices.
3756    mov.u32 %remaining, %r_tid;
3757    mov.u32 %a_idx, 0;
3758    mov.u32 %b_idx, 0;
3759    mov.u32 %d, %ndim_reg;
3760
3761LOOP:
3762    setp.eq.u32 %loop_p, %d, 0;
3763    @%loop_p bra END_LOOP;
3764
3765    sub.u32 %d, %d, 1;
3766
3767    // Byte offset for dimension d: d * 4.
3768    cvt.u64.u32 %d64, %d;
3769    shl.b64 %d64, %d64, 2;
3770
3771    // Load out_shape[d].
3772    add.u64 %tmp, %oshape, %d64;
3773    ld.global.u32 %shape_d, [%tmp];
3774
3775    // Load a_strides[d] and b_strides[d].
3776    add.u64 %tmp, %a_str, %d64;
3777    ld.global.u32 %a_str_d, [%tmp];
3778    add.u64 %tmp, %b_str, %d64;
3779    ld.global.u32 %b_str_d, [%tmp];
3780
3781    // coord = remaining % shape_d; remaining /= shape_d.
3782    rem.u32 %coord, %remaining, %shape_d;
3783    div.u32 %remaining, %remaining, %shape_d;
3784
3785    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
3786    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3787    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3788
3789    bra LOOP;
3790END_LOOP:
3791
3792    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
3793    cvt.u64.u32 %off_a, %a_idx;
3794    shl.b64 %off_a, %off_a, 2;
3795    add.u64 %off_a, %a, %off_a;
3796    ld.global.f32 %va, [%off_a];
3797
3798    cvt.u64.u32 %off_b, %b_idx;
3799    shl.b64 %off_b, %off_b, 2;
3800    add.u64 %off_b, %b, %off_b;
3801    ld.global.f32 %vb, [%off_b];
3802
3803    // Operation: add.
3804    add.f32 %vr, %va, %vb;
3805
3806    // Store to out[tid].
3807    cvt.u64.u32 %off_out, %r_tid;
3808    shl.b64 %off_out, %off_out, 2;
3809    add.u64 %off_out, %out, %off_out;
3810    st.global.f32 [%off_out], %vr;
3811
3812DONE:
3813    ret;
3814}
3815";
3816
3817/// PTX for general broadcast sub: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
3818#[cfg(feature = "cuda")]
3819pub(crate) const BROADCAST_SUB_PTX: &str = "\
3820.version 7.0
3821.target sm_52
3822.address_size 64
3823
3824.visible .entry broadcast_sub_kernel(
3825    .param .u64 a_ptr,
3826    .param .u64 b_ptr,
3827    .param .u64 out_ptr,
3828    .param .u64 a_strides_ptr,
3829    .param .u64 b_strides_ptr,
3830    .param .u64 out_shape_ptr,
3831    .param .u32 n,
3832    .param .u32 ndim
3833) {
3834    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3835    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3836    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3837    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3838    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3839    .reg .f32 %va, %vb, %vr;
3840    .reg .pred %p, %loop_p;
3841
3842    ld.param.u64 %a, [a_ptr];
3843    ld.param.u64 %b, [b_ptr];
3844    ld.param.u64 %out, [out_ptr];
3845    ld.param.u64 %a_str, [a_strides_ptr];
3846    ld.param.u64 %b_str, [b_strides_ptr];
3847    ld.param.u64 %oshape, [out_shape_ptr];
3848    ld.param.u32 %n_reg, [n];
3849    ld.param.u32 %ndim_reg, [ndim];
3850
3851    mov.u32 %bid, %ctaid.x;
3852    mov.u32 %bdim, %ntid.x;
3853    mov.u32 %r_tid, %tid.x;
3854    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3855    setp.ge.u32 %p, %r_tid, %n_reg;
3856    @%p bra DONE;
3857
3858    mov.u32 %remaining, %r_tid;
3859    mov.u32 %a_idx, 0;
3860    mov.u32 %b_idx, 0;
3861    mov.u32 %d, %ndim_reg;
3862LOOP:
3863    setp.eq.u32 %loop_p, %d, 0;
3864    @%loop_p bra END_LOOP;
3865    sub.u32 %d, %d, 1;
3866    cvt.u64.u32 %d64, %d;
3867    shl.b64 %d64, %d64, 2;
3868    add.u64 %tmp, %oshape, %d64;
3869    ld.global.u32 %shape_d, [%tmp];
3870    add.u64 %tmp, %a_str, %d64;
3871    ld.global.u32 %a_str_d, [%tmp];
3872    add.u64 %tmp, %b_str, %d64;
3873    ld.global.u32 %b_str_d, [%tmp];
3874    rem.u32 %coord, %remaining, %shape_d;
3875    div.u32 %remaining, %remaining, %shape_d;
3876    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3877    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3878    bra LOOP;
3879END_LOOP:
3880
3881    cvt.u64.u32 %off_a, %a_idx;
3882    shl.b64 %off_a, %off_a, 2;
3883    add.u64 %off_a, %a, %off_a;
3884    ld.global.f32 %va, [%off_a];
3885    cvt.u64.u32 %off_b, %b_idx;
3886    shl.b64 %off_b, %off_b, 2;
3887    add.u64 %off_b, %b, %off_b;
3888    ld.global.f32 %vb, [%off_b];
3889
3890    sub.f32 %vr, %va, %vb;
3891
3892    cvt.u64.u32 %off_out, %r_tid;
3893    shl.b64 %off_out, %off_out, 2;
3894    add.u64 %off_out, %out, %off_out;
3895    st.global.f32 [%off_out], %vr;
3896DONE:
3897    ret;
3898}
3899";
3900
3901/// PTX for general broadcast mul: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
3902#[cfg(feature = "cuda")]
3903pub(crate) const BROADCAST_MUL_PTX: &str = "\
3904.version 7.0
3905.target sm_52
3906.address_size 64
3907
3908.visible .entry broadcast_mul_kernel(
3909    .param .u64 a_ptr,
3910    .param .u64 b_ptr,
3911    .param .u64 out_ptr,
3912    .param .u64 a_strides_ptr,
3913    .param .u64 b_strides_ptr,
3914    .param .u64 out_shape_ptr,
3915    .param .u32 n,
3916    .param .u32 ndim
3917) {
3918    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3919    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3920    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3921    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3922    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3923    .reg .f32 %va, %vb, %vr;
3924    .reg .pred %p, %loop_p;
3925
3926    ld.param.u64 %a, [a_ptr];
3927    ld.param.u64 %b, [b_ptr];
3928    ld.param.u64 %out, [out_ptr];
3929    ld.param.u64 %a_str, [a_strides_ptr];
3930    ld.param.u64 %b_str, [b_strides_ptr];
3931    ld.param.u64 %oshape, [out_shape_ptr];
3932    ld.param.u32 %n_reg, [n];
3933    ld.param.u32 %ndim_reg, [ndim];
3934
3935    mov.u32 %bid, %ctaid.x;
3936    mov.u32 %bdim, %ntid.x;
3937    mov.u32 %r_tid, %tid.x;
3938    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3939    setp.ge.u32 %p, %r_tid, %n_reg;
3940    @%p bra DONE;
3941
3942    mov.u32 %remaining, %r_tid;
3943    mov.u32 %a_idx, 0;
3944    mov.u32 %b_idx, 0;
3945    mov.u32 %d, %ndim_reg;
3946LOOP:
3947    setp.eq.u32 %loop_p, %d, 0;
3948    @%loop_p bra END_LOOP;
3949    sub.u32 %d, %d, 1;
3950    cvt.u64.u32 %d64, %d;
3951    shl.b64 %d64, %d64, 2;
3952    add.u64 %tmp, %oshape, %d64;
3953    ld.global.u32 %shape_d, [%tmp];
3954    add.u64 %tmp, %a_str, %d64;
3955    ld.global.u32 %a_str_d, [%tmp];
3956    add.u64 %tmp, %b_str, %d64;
3957    ld.global.u32 %b_str_d, [%tmp];
3958    rem.u32 %coord, %remaining, %shape_d;
3959    div.u32 %remaining, %remaining, %shape_d;
3960    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3961    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3962    bra LOOP;
3963END_LOOP:
3964
3965    cvt.u64.u32 %off_a, %a_idx;
3966    shl.b64 %off_a, %off_a, 2;
3967    add.u64 %off_a, %a, %off_a;
3968    ld.global.f32 %va, [%off_a];
3969    cvt.u64.u32 %off_b, %b_idx;
3970    shl.b64 %off_b, %off_b, 2;
3971    add.u64 %off_b, %b, %off_b;
3972    ld.global.f32 %vb, [%off_b];
3973
3974    mul.f32 %vr, %va, %vb;
3975
3976    cvt.u64.u32 %off_out, %r_tid;
3977    shl.b64 %off_out, %off_out, 2;
3978    add.u64 %off_out, %out, %off_out;
3979    st.global.f32 [%off_out], %vr;
3980DONE:
3981    ret;
3982}
3983";
3984
3985/// PTX source for `broadcast_div_kernel`: broadcast division, identical structure
3986/// to `broadcast_mul_kernel` but uses `div.f32` instead of `mul.f32`.
3987#[cfg(feature = "cuda")]
3988pub(crate) const BROADCAST_DIV_PTX: &str = "\
3989.version 7.0
3990.target sm_52
3991.address_size 64
3992
3993.visible .entry broadcast_div_kernel(
3994    .param .u64 a_ptr,
3995    .param .u64 b_ptr,
3996    .param .u64 out_ptr,
3997    .param .u64 a_strides_ptr,
3998    .param .u64 b_strides_ptr,
3999    .param .u64 out_shape_ptr,
4000    .param .u32 n,
4001    .param .u32 ndim
4002) {
4003    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
4004    .reg .u32 %remaining, %a_idx, %b_idx, %d;
4005    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
4006    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
4007    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
4008    .reg .f32 %va, %vb, %vr;
4009    .reg .pred %p, %loop_p;
4010
4011    ld.param.u64 %a, [a_ptr];
4012    ld.param.u64 %b, [b_ptr];
4013    ld.param.u64 %out, [out_ptr];
4014    ld.param.u64 %a_str, [a_strides_ptr];
4015    ld.param.u64 %b_str, [b_strides_ptr];
4016    ld.param.u64 %oshape, [out_shape_ptr];
4017    ld.param.u32 %n_reg, [n];
4018    ld.param.u32 %ndim_reg, [ndim];
4019
4020    mov.u32 %bid, %ctaid.x;
4021    mov.u32 %bdim, %ntid.x;
4022    mov.u32 %r_tid, %tid.x;
4023    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4024    setp.ge.u32 %p, %r_tid, %n_reg;
4025    @%p bra DONE;
4026
4027    mov.u32 %remaining, %r_tid;
4028    mov.u32 %a_idx, 0;
4029    mov.u32 %b_idx, 0;
4030    mov.u32 %d, %ndim_reg;
4031LOOP:
4032    setp.eq.u32 %loop_p, %d, 0;
4033    @%loop_p bra END_LOOP;
4034    sub.u32 %d, %d, 1;
4035    cvt.u64.u32 %d64, %d;
4036    shl.b64 %d64, %d64, 2;
4037    add.u64 %tmp, %oshape, %d64;
4038    ld.global.u32 %shape_d, [%tmp];
4039    add.u64 %tmp, %a_str, %d64;
4040    ld.global.u32 %a_str_d, [%tmp];
4041    add.u64 %tmp, %b_str, %d64;
4042    ld.global.u32 %b_str_d, [%tmp];
4043    rem.u32 %coord, %remaining, %shape_d;
4044    div.u32 %remaining, %remaining, %shape_d;
4045    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
4046    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
4047    bra LOOP;
4048END_LOOP:
4049
4050    cvt.u64.u32 %off_a, %a_idx;
4051    shl.b64 %off_a, %off_a, 2;
4052    add.u64 %off_a, %a, %off_a;
4053    ld.global.f32 %va, [%off_a];
4054    cvt.u64.u32 %off_b, %b_idx;
4055    shl.b64 %off_b, %off_b, 2;
4056    add.u64 %off_b, %b, %off_b;
4057    ld.global.f32 %vb, [%off_b];
4058
4059    div.f32 %vr, %va, %vb;
4060
4061    cvt.u64.u32 %off_out, %r_tid;
4062    shl.b64 %off_out, %off_out, 2;
4063    add.u64 %off_out, %out, %off_out;
4064    st.global.f32 [%off_out], %vr;
4065DONE:
4066    ret;
4067}
4068";
4069
4070/// PTX source for `strided_split_kernel`: extract a sub-tensor along a given axis.
4071///
4072/// Thread `i` computes:
4073///   `outer_idx = i / (split_size * inner_size)`
4074///   `within    = i % (split_size * inner_size)`
4075///   `src_idx   = outer_idx * total_along_axis * inner_size + (split_offset * inner_size) + within`
4076///   `out[i]    = in[src_idx]`
4077#[cfg(feature = "cuda")]
4078pub(crate) const STRIDED_SPLIT_PTX: &str = "\
4079.version 7.0
4080.target sm_52
4081.address_size 64
4082
4083.visible .entry strided_split_kernel(
4084    .param .u64 input_ptr,
4085    .param .u64 output_ptr,
4086    .param .u32 total_along_axis,
4087    .param .u32 split_offset,
4088    .param .u32 split_size,
4089    .param .u32 inner_size,
4090    .param .u32 n
4091) {
4092    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4093    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
4094    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
4095    .reg .u64 %in, %out, %off;
4096    .reg .f32 %val;
4097    .reg .pred %p;
4098
4099    ld.param.u64 %in, [input_ptr];
4100    ld.param.u64 %out, [output_ptr];
4101    ld.param.u32 %total_ax, [total_along_axis];
4102    ld.param.u32 %sp_off, [split_offset];
4103    ld.param.u32 %sp_sz, [split_size];
4104    ld.param.u32 %inner_sz, [inner_size];
4105    ld.param.u32 %n_reg, [n];
4106
4107    mov.u32 %bid, %ctaid.x;
4108    mov.u32 %bdim, %ntid.x;
4109    mov.u32 %r_tid, %tid.x;
4110    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4111
4112    setp.ge.u32 %p, %r_tid, %n_reg;
4113    @%p bra DONE;
4114
4115    // chunk_stride = split_size * inner_size
4116    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;
4117
4118    // outer_idx = r_tid / chunk_stride
4119    div.u32 %outer_idx, %r_tid, %chunk_stride;
4120
4121    // within = r_tid % chunk_stride
4122    rem.u32 %within, %r_tid, %chunk_stride;
4123
4124    // base_off = split_offset * inner_size
4125    mul.lo.u32 %base_off, %sp_off, %inner_sz;
4126
4127    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
4128    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
4129    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
4130    add.u32 %src_idx, %src_idx, %base_off;
4131    add.u32 %src_idx, %src_idx, %within;
4132
4133    // Load from in[src_idx]
4134    cvt.u64.u32 %off, %src_idx;
4135    shl.b64 %off, %off, 2;
4136    add.u64 %off, %in, %off;
4137    ld.global.f32 %val, [%off];
4138
4139    // Store to out[r_tid]
4140    cvt.u64.u32 %off, %r_tid;
4141    shl.b64 %off, %off, 2;
4142    add.u64 %off, %out, %off;
4143    st.global.f32 [%off], %val;
4144
4145DONE:
4146    ret;
4147}
4148";
4149
4150/// PTX source for `strided_cat_kernel`: write a sub-tensor into a larger tensor
4151/// at an offset along an axis.
4152///
4153/// Thread `i` computes:
4154///   `outer_idx = i / (part_size * inner_size)`
4155///   `within    = i % (part_size * inner_size)`
4156///   `dst_idx   = outer_idx * total_along_axis * inner_size + (cat_offset * inner_size) + within`
4157///   `out[dst_idx] = in[i]`
4158#[cfg(feature = "cuda")]
4159pub(crate) const STRIDED_CAT_PTX: &str = "\
4160.version 7.0
4161.target sm_52
4162.address_size 64
4163
4164.visible .entry strided_cat_kernel(
4165    .param .u64 input_ptr,
4166    .param .u64 output_ptr,
4167    .param .u32 total_along_axis,
4168    .param .u32 cat_offset,
4169    .param .u32 part_size,
4170    .param .u32 inner_size,
4171    .param .u32 n
4172) {
4173    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4174    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
4175    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
4176    .reg .u64 %in, %out, %off;
4177    .reg .f32 %val;
4178    .reg .pred %p;
4179
4180    ld.param.u64 %in, [input_ptr];
4181    ld.param.u64 %out, [output_ptr];
4182    ld.param.u32 %total_ax, [total_along_axis];
4183    ld.param.u32 %cat_off, [cat_offset];
4184    ld.param.u32 %part_sz, [part_size];
4185    ld.param.u32 %inner_sz, [inner_size];
4186    ld.param.u32 %n_reg, [n];
4187
4188    mov.u32 %bid, %ctaid.x;
4189    mov.u32 %bdim, %ntid.x;
4190    mov.u32 %r_tid, %tid.x;
4191    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4192
4193    setp.ge.u32 %p, %r_tid, %n_reg;
4194    @%p bra DONE;
4195
4196    // chunk_stride = part_size * inner_size
4197    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;
4198
4199    // outer_idx = r_tid / chunk_stride
4200    div.u32 %outer_idx, %r_tid, %chunk_stride;
4201
4202    // within = r_tid % chunk_stride
4203    rem.u32 %within, %r_tid, %chunk_stride;
4204
4205    // base_off = cat_offset * inner_size
4206    mul.lo.u32 %base_off, %cat_off, %inner_sz;
4207
4208    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
4209    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
4210    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
4211    add.u32 %dst_idx, %dst_idx, %base_off;
4212    add.u32 %dst_idx, %dst_idx, %within;
4213
4214    // Load from in[r_tid]
4215    cvt.u64.u32 %off, %r_tid;
4216    shl.b64 %off, %off, 2;
4217    add.u64 %off, %in, %off;
4218    ld.global.f32 %val, [%off];
4219
4220    // Store to out[dst_idx]
4221    cvt.u64.u32 %off, %dst_idx;
4222    shl.b64 %off, %off, 2;
4223    add.u64 %off, %out, %off;
4224    st.global.f32 [%off], %val;
4225
4226DONE:
4227    ret;
4228}
4229";
4230
4231/// PTX source for `div_kernel`: `out[i] = a[i] / b[i]`.
4232#[cfg(feature = "cuda")]
4233pub(crate) const DIV_PTX: &str = "\
4234.version 7.0
4235.target sm_52
4236.address_size 64
4237
4238.visible .entry div_kernel(
4239    .param .u64 a_ptr,
4240    .param .u64 b_ptr,
4241    .param .u64 out_ptr,
4242    .param .u32 n
4243) {
4244    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4245    .reg .u64 %a, %b, %out, %off;
4246    .reg .f32 %va, %vb, %vr;
4247    .reg .pred %p;
4248
4249    ld.param.u64 %a, [a_ptr];
4250    ld.param.u64 %b, [b_ptr];
4251    ld.param.u64 %out, [out_ptr];
4252    ld.param.u32 %n_reg, [n];
4253
4254    mov.u32 %bid, %ctaid.x;
4255    mov.u32 %bdim, %ntid.x;
4256    mov.u32 %r_tid, %tid.x;
4257    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4258
4259    setp.ge.u32 %p, %r_tid, %n_reg;
4260    @%p bra DONE;
4261
4262    cvt.u64.u32 %off, %r_tid;
4263    shl.b64 %off, %off, 2;
4264
4265    add.u64 %a, %a, %off;
4266    add.u64 %b, %b, %off;
4267    add.u64 %out, %out, %off;
4268
4269    ld.global.f32 %va, [%a];
4270    ld.global.f32 %vb, [%b];
4271    div.rn.f32 %vr, %va, %vb;
4272    st.global.f32 [%out], %vr;
4273
4274DONE:
4275    ret;
4276}
4277";
4278
4279/// PTX source for `exp_kernel`: `out[i] = exp(a[i])`.
4280#[cfg(feature = "cuda")]
4281pub(crate) const EXP_PTX: &str = "\
4282.version 7.0
4283.target sm_52
4284.address_size 64
4285
4286.visible .entry exp_kernel(
4287    .param .u64 a_ptr,
4288    .param .u64 out_ptr,
4289    .param .u32 n
4290) {
4291    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4292    .reg .u64 %a, %out, %off;
4293    .reg .f32 %va, %vr;
4294    .reg .pred %p;
4295
4296    ld.param.u64 %a, [a_ptr];
4297    ld.param.u64 %out, [out_ptr];
4298    ld.param.u32 %n_reg, [n];
4299
4300    mov.u32 %bid, %ctaid.x;
4301    mov.u32 %bdim, %ntid.x;
4302    mov.u32 %r_tid, %tid.x;
4303    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4304
4305    setp.ge.u32 %p, %r_tid, %n_reg;
4306    @%p bra DONE;
4307
4308    cvt.u64.u32 %off, %r_tid;
4309    shl.b64 %off, %off, 2;
4310
4311    add.u64 %a, %a, %off;
4312    add.u64 %out, %out, %off;
4313
4314    ld.global.f32 %va, [%a];
4315    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
4316    // log2(e) = 1.4426950408889634
4317    mul.f32 %va, %va, 0f3FB8AA3B;
4318    ex2.approx.f32 %vr, %va;
4319    st.global.f32 [%out], %vr;
4320
4321DONE:
4322    ret;
4323}
4324";
4325
4326/// PTX source for `log_kernel`: `out[i] = ln(a[i])`.
4327#[cfg(feature = "cuda")]
4328pub(crate) const LOG_PTX: &str = "\
4329.version 7.0
4330.target sm_52
4331.address_size 64
4332
4333.visible .entry log_kernel(
4334    .param .u64 a_ptr,
4335    .param .u64 out_ptr,
4336    .param .u32 n
4337) {
4338    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4339    .reg .u64 %a, %out, %off;
4340    .reg .f32 %va, %vr;
4341    .reg .pred %p;
4342
4343    ld.param.u64 %a, [a_ptr];
4344    ld.param.u64 %out, [out_ptr];
4345    ld.param.u32 %n_reg, [n];
4346
4347    mov.u32 %bid, %ctaid.x;
4348    mov.u32 %bdim, %ntid.x;
4349    mov.u32 %r_tid, %tid.x;
4350    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4351
4352    setp.ge.u32 %p, %r_tid, %n_reg;
4353    @%p bra DONE;
4354
4355    cvt.u64.u32 %off, %r_tid;
4356    shl.b64 %off, %off, 2;
4357
4358    add.u64 %a, %a, %off;
4359    add.u64 %out, %out, %off;
4360
4361    ld.global.f32 %va, [%a];
4362    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
4363    // 1/log2(e) = ln(2) = 0.6931471805599453
4364    lg2.approx.f32 %vr, %va;
4365    mul.f32 %vr, %vr, 0f3F317218;
4366    st.global.f32 [%out], %vr;
4367
4368DONE:
4369    ret;
4370}
4371";
4372
4373/// PTX source for `sqrt_kernel`: `out[i] = sqrt(a[i])`.
4374#[cfg(feature = "cuda")]
4375pub(crate) const SQRT_PTX: &str = "\
4376.version 7.0
4377.target sm_52
4378.address_size 64
4379
4380.visible .entry sqrt_kernel(
4381    .param .u64 a_ptr,
4382    .param .u64 out_ptr,
4383    .param .u32 n
4384) {
4385    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4386    .reg .u64 %a, %out, %off;
4387    .reg .f32 %va, %vr;
4388    .reg .pred %p;
4389
4390    ld.param.u64 %a, [a_ptr];
4391    ld.param.u64 %out, [out_ptr];
4392    ld.param.u32 %n_reg, [n];
4393
4394    mov.u32 %bid, %ctaid.x;
4395    mov.u32 %bdim, %ntid.x;
4396    mov.u32 %r_tid, %tid.x;
4397    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4398
4399    setp.ge.u32 %p, %r_tid, %n_reg;
4400    @%p bra DONE;
4401
4402    cvt.u64.u32 %off, %r_tid;
4403    shl.b64 %off, %off, 2;
4404
4405    add.u64 %a, %a, %off;
4406    add.u64 %out, %out, %off;
4407
4408    ld.global.f32 %va, [%a];
4409    sqrt.rn.f32 %vr, %va;
4410    st.global.f32 [%out], %vr;
4411
4412DONE:
4413    ret;
4414}
4415";
4416
4417/// PTX source for `pow_kernel`: `out[i] = a[i] ^ exponent`.
4418/// Uses the identity: x^e = 2^(e * log2(x)).
4419#[cfg(feature = "cuda")]
4420pub(crate) const POW_PTX: &str = "\
4421.version 7.0
4422.target sm_52
4423.address_size 64
4424
4425.visible .entry pow_kernel(
4426    .param .u64 a_ptr,
4427    .param .u64 out_ptr,
4428    .param .f32 exponent,
4429    .param .u32 n
4430) {
4431    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4432    .reg .u64 %a, %out, %off;
4433    .reg .f32 %va, %vr, %exp, %lg;
4434    .reg .pred %p;
4435
4436    ld.param.u64 %a, [a_ptr];
4437    ld.param.u64 %out, [out_ptr];
4438    ld.param.f32 %exp, [exponent];
4439    ld.param.u32 %n_reg, [n];
4440
4441    mov.u32 %bid, %ctaid.x;
4442    mov.u32 %bdim, %ntid.x;
4443    mov.u32 %r_tid, %tid.x;
4444    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4445
4446    setp.ge.u32 %p, %r_tid, %n_reg;
4447    @%p bra DONE;
4448
4449    cvt.u64.u32 %off, %r_tid;
4450    shl.b64 %off, %off, 2;
4451
4452    add.u64 %a, %a, %off;
4453    add.u64 %out, %out, %off;
4454
4455    ld.global.f32 %va, [%a];
4456    // x^e = 2^(e * log2(x))
4457    lg2.approx.f32 %lg, %va;
4458    mul.f32 %lg, %lg, %exp;
4459    ex2.approx.f32 %vr, %lg;
4460    st.global.f32 [%out], %vr;
4461
4462DONE:
4463    ret;
4464}
4465";
4466
4467/// PTX source for `abs_kernel`: `out[i] = |a[i]|`.
4468#[cfg(feature = "cuda")]
4469pub(crate) const ABS_PTX: &str = "\
4470.version 7.0
4471.target sm_52
4472.address_size 64
4473
4474.visible .entry abs_kernel(
4475    .param .u64 a_ptr,
4476    .param .u64 out_ptr,
4477    .param .u32 n
4478) {
4479    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4480    .reg .u64 %a, %out, %off;
4481    .reg .f32 %va, %vr;
4482    .reg .pred %p;
4483
4484    ld.param.u64 %a, [a_ptr];
4485    ld.param.u64 %out, [out_ptr];
4486    ld.param.u32 %n_reg, [n];
4487
4488    mov.u32 %bid, %ctaid.x;
4489    mov.u32 %bdim, %ntid.x;
4490    mov.u32 %r_tid, %tid.x;
4491    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4492
4493    setp.ge.u32 %p, %r_tid, %n_reg;
4494    @%p bra DONE;
4495
4496    cvt.u64.u32 %off, %r_tid;
4497    shl.b64 %off, %off, 2;
4498
4499    add.u64 %a, %a, %off;
4500    add.u64 %out, %out, %off;
4501
4502    ld.global.f32 %va, [%a];
4503    abs.f32 %vr, %va;
4504    st.global.f32 [%out], %vr;
4505
4506DONE:
4507    ret;
4508}
4509";
4510
4511/// PTX source for `sigmoid_kernel`: `out[i] = 1 / (1 + exp(-a[i]))`.
4512#[cfg(feature = "cuda")]
4513pub(crate) const SIGMOID_PTX: &str = "\
4514.version 7.0
4515.target sm_52
4516.address_size 64
4517
4518.visible .entry sigmoid_kernel(
4519    .param .u64 a_ptr,
4520    .param .u64 out_ptr,
4521    .param .u32 n
4522) {
4523    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4524    .reg .u64 %a, %out, %off;
4525    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
4526    .reg .pred %p;
4527
4528    ld.param.u64 %a, [a_ptr];
4529    ld.param.u64 %out, [out_ptr];
4530    ld.param.u32 %n_reg, [n];
4531
4532    mov.u32 %bid, %ctaid.x;
4533    mov.u32 %bdim, %ntid.x;
4534    mov.u32 %r_tid, %tid.x;
4535    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4536
4537    setp.ge.u32 %p, %r_tid, %n_reg;
4538    @%p bra DONE;
4539
4540    cvt.u64.u32 %off, %r_tid;
4541    shl.b64 %off, %off, 2;
4542
4543    add.u64 %a, %a, %off;
4544    add.u64 %out, %out, %off;
4545
4546    ld.global.f32 %va, [%a];
4547    // sigmoid(x) = 1 / (1 + exp(-x))
4548    neg.f32 %neg, %va;
4549    mov.f32 %lg2e, 0f3FB8AA3B;
4550    mul.f32 %neg, %neg, %lg2e;
4551    ex2.approx.f32 %e, %neg;
4552    mov.f32 %one, 0f3F800000;
4553    add.f32 %denom, %one, %e;
4554    div.rn.f32 %vr, %one, %denom;
4555    st.global.f32 [%out], %vr;
4556
4557DONE:
4558    ret;
4559}
4560";
4561
4562/// PTX source for `tanh_kernel`: `out[i] = tanh(a[i])`.
4563/// Uses the identity: tanh(x) = 2*sigmoid(2x) - 1.
4564#[cfg(feature = "cuda")]
4565pub(crate) const TANH_PTX: &str = "\
4566.version 7.0
4567.target sm_52
4568.address_size 64
4569
4570.visible .entry tanh_kernel(
4571    .param .u64 a_ptr,
4572    .param .u64 out_ptr,
4573    .param .u32 n
4574) {
4575    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4576    .reg .u64 %a, %out, %off;
4577    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
4578    .reg .pred %p;
4579
4580    ld.param.u64 %a, [a_ptr];
4581    ld.param.u64 %out, [out_ptr];
4582    ld.param.u32 %n_reg, [n];
4583
4584    mov.u32 %bid, %ctaid.x;
4585    mov.u32 %bdim, %ntid.x;
4586    mov.u32 %r_tid, %tid.x;
4587    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4588
4589    setp.ge.u32 %p, %r_tid, %n_reg;
4590    @%p bra DONE;
4591
4592    cvt.u64.u32 %off, %r_tid;
4593    shl.b64 %off, %off, 2;
4594
4595    add.u64 %a, %a, %off;
4596    add.u64 %out, %out, %off;
4597
4598    ld.global.f32 %va, [%a];
4599    // tanh(x) = 2*sigmoid(2x) - 1
4600    mov.f32 %two, 0f40000000;
4601    mul.f32 %neg2x, %va, %two;
4602    neg.f32 %neg2x, %neg2x;
4603    mov.f32 %lg2e, 0f3FB8AA3B;
4604    mul.f32 %neg2x, %neg2x, %lg2e;
4605    ex2.approx.f32 %e, %neg2x;
4606    mov.f32 %one, 0f3F800000;
4607    add.f32 %denom, %one, %e;
4608    div.rn.f32 %sig, %one, %denom;
4609    mul.f32 %vr, %two, %sig;
4610    sub.f32 %vr, %vr, %one;
4611    st.global.f32 [%out], %vr;
4612
4613DONE:
4614    ret;
4615}
4616";
4617
4618/// PTX source for `fused_adam_kernel`: in-place Adam optimizer update.
4619///
4620/// For each element i:
4621///   g = grad[i] + weight_decay * param[i]  (if wd > 0)
4622///   exp_avg[i] = beta1 * exp_avg[i] + (1-beta1) * g
4623///   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1-beta2) * g * g
4624///   m_hat = exp_avg[i] / bc1
4625///   v_hat = exp_avg_sq[i] / bc2
4626///   param[i] = param[i] - lr * m_hat / (sqrt(v_hat) + eps)
4627#[cfg(feature = "cuda")]
4628pub(crate) const FUSED_ADAM_PTX: &str = "\
4629.version 7.0
4630.target sm_52
4631.address_size 64
4632
4633.visible .entry fused_adam_kernel(
4634    .param .u64 param_ptr,
4635    .param .u64 grad_ptr,
4636    .param .u64 exp_avg_ptr,
4637    .param .u64 exp_avg_sq_ptr,
4638    .param .f32 beta1,
4639    .param .f32 beta2,
4640    .param .f32 lr,
4641    .param .f32 eps,
4642    .param .f32 bc1,
4643    .param .f32 bc2,
4644    .param .f32 weight_decay,
4645    .param .u32 n
4646) {
4647    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4648    .reg .u64 %p, %g, %m, %v, %off;
4649    .reg .f32 %vp, %vg, %vm, %vv;
4650    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
4651    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
4652    .reg .f32 %one;
4653    .reg .pred %p_bound, %p_wd;
4654
4655    ld.param.u64 %p, [param_ptr];
4656    ld.param.u64 %g, [grad_ptr];
4657    ld.param.u64 %m, [exp_avg_ptr];
4658    ld.param.u64 %v, [exp_avg_sq_ptr];
4659    ld.param.f32 %b1, [beta1];
4660    ld.param.f32 %b2, [beta2];
4661    ld.param.f32 %f_lr, [lr];
4662    ld.param.f32 %f_eps, [eps];
4663    ld.param.f32 %f_bc1, [bc1];
4664    ld.param.f32 %f_bc2, [bc2];
4665    ld.param.f32 %f_wd, [weight_decay];
4666    ld.param.u32 %n_reg, [n];
4667
4668    mov.u32 %bid, %ctaid.x;
4669    mov.u32 %bdim, %ntid.x;
4670    mov.u32 %r_tid, %tid.x;
4671    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4672
4673    setp.ge.u32 %p_bound, %r_tid, %n_reg;
4674    @%p_bound bra DONE;
4675
4676    cvt.u64.u32 %off, %r_tid;
4677    shl.b64 %off, %off, 2;
4678
4679    add.u64 %p, %p, %off;
4680    add.u64 %g, %g, %off;
4681    add.u64 %m, %m, %off;
4682    add.u64 %v, %v, %off;
4683
4684    ld.global.f32 %vp, [%p];
4685    ld.global.f32 %vg, [%g];
4686    ld.global.f32 %vm, [%m];
4687    ld.global.f32 %vv, [%v];
4688
4689    // L2 weight decay: g = g + wd * p
4690    mov.f32 %one, 0f00000000;
4691    setp.gt.f32 %p_wd, %f_wd, %one;
4692    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;
4693
4694    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
4695    mov.f32 %one, 0f3F800000;
4696    sub.f32 %t1, %one, %b1;
4697    mul.f32 %vm, %vm, %b1;
4698    fma.rn.f32 %vm, %t1, %vg, %vm;
4699
4700    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
4701    sub.f32 %t2, %one, %b2;
4702    mul.f32 %vv, %vv, %b2;
4703    mul.f32 %t1, %vg, %vg;
4704    fma.rn.f32 %vv, %t2, %t1, %vv;
4705
4706    // m_hat = exp_avg / bc1
4707    div.rn.f32 %m_hat, %vm, %f_bc1;
4708
4709    // v_hat = exp_avg_sq / bc2
4710    div.rn.f32 %v_hat, %vv, %f_bc2;
4711
4712    // denom = sqrt(v_hat) + eps
4713    sqrt.rn.f32 %denom, %v_hat;
4714    add.f32 %denom, %denom, %f_eps;
4715
4716    // param = param - lr * m_hat / denom
4717    div.rn.f32 %update, %m_hat, %denom;
4718    mul.f32 %update, %update, %f_lr;
4719    sub.f32 %vp, %vp, %update;
4720
4721    st.global.f32 [%p], %vp;
4722    st.global.f32 [%m], %vm;
4723    st.global.f32 [%v], %vv;
4724
4725DONE:
4726    ret;
4727}
4728";
4729
4730/// PTX source for fused GRU cell forward kernel.
4731///
4732/// Takes pre-computed input_gates [B, 3*H] and hidden_gates [B, 3*H]
4733/// (from cuBLAS GEMMs), biases, and previous hidden state. Computes all
4734/// gate activations and the new hidden state in a single kernel launch.
4735///
4736/// One thread per hidden unit. Each thread reads 3 values from input_gates
4737/// and 3 from hidden_gates, applies sigmoid/tanh, computes the GRU update,
4738/// and writes hy + workspace (5*H values for backward).
4739///
4740/// Matches PyTorch's _thnn_fused_gru_cell kernel from RNN.cu.
4741#[cfg(feature = "cuda")]
4742pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
4743.version 7.0
4744.target sm_52
4745.address_size 64
4746
4747.visible .entry fused_gru_forward_kernel(
4748    .param .u64 input_gates_ptr,
4749    .param .u64 hidden_gates_ptr,
4750    .param .u64 bias_ih_ptr,
4751    .param .u64 bias_hh_ptr,
4752    .param .u64 hx_ptr,
4753    .param .u64 hy_ptr,
4754    .param .u64 workspace_ptr,
4755    .param .u32 hsz,
4756    .param .u32 total
4757) {
4758    .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
4759    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
4760    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
4761    .reg .u64 %off64, %tmp64;
4762    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
4763    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
4764    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
4765    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
4766    .reg .pred %p;
4767
4768    ld.param.u64 %ig, [input_gates_ptr];
4769    ld.param.u64 %hg, [hidden_gates_ptr];
4770    ld.param.u64 %b1, [bias_ih_ptr];
4771    ld.param.u64 %b2, [bias_hh_ptr];
4772    ld.param.u64 %hx, [hx_ptr];
4773    ld.param.u64 %hy, [hy_ptr];
4774    ld.param.u64 %ws, [workspace_ptr];
4775    ld.param.u32 %hsz_reg, [hsz];
4776    ld.param.u32 %total_reg, [total];
4777
4778    mov.u32 %bid, %ctaid.x;
4779    mov.u32 %bdim, %ntid.x;
4780    mov.u32 %tid, %tid.x;
4781    mov.u32 %gdim, %nctaid.x;
4782    mad.lo.u32 %idx, %bid, %bdim, %tid;
4783    mul.lo.u32 %stride, %bdim, %gdim;
4784    mov.f32 %one, 0f3F800000;
4785
4786LOOP:
4787    setp.ge.u32 %p, %idx, %total_reg;
4788    @%p bra END;
4789
4790    // offset3 = (idx/hsz)*3*hsz + idx%hsz  (into [B, 3*H] gates tensor)
4791    div.u32 %batch_idx, %idx, %hsz_reg;
4792    rem.u32 %hmod, %idx, %hsz_reg;
4793    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
4794    mul.lo.u32 %offset3, %offset3, 3;
4795    add.u32 %offset3, %offset3, %hmod;
4796
4797    // Load input gate components: ir, ii, in
4798    cvt.u64.u32 %off64, %offset3;
4799    shl.b64 %off64, %off64, 2;
4800    add.u64 %tmp64, %ig, %off64;
4801    ld.global.f32 %ir, [%tmp64];
4802    cvt.u64.u32 %off64, %hsz_reg;
4803    shl.b64 %off64, %off64, 2;
4804    add.u64 %tmp64, %tmp64, %off64;
4805    ld.global.f32 %ii, [%tmp64];
4806    add.u64 %tmp64, %tmp64, %off64;
4807    ld.global.f32 %in, [%tmp64];
4808
4809    // Load hidden gate components: hr, hi, hn
4810    cvt.u64.u32 %off64, %offset3;
4811    shl.b64 %off64, %off64, 2;
4812    add.u64 %tmp64, %hg, %off64;
4813    ld.global.f32 %hr, [%tmp64];
4814    cvt.u64.u32 %off64, %hsz_reg;
4815    shl.b64 %off64, %off64, 2;
4816    add.u64 %tmp64, %tmp64, %off64;
4817    ld.global.f32 %hi, [%tmp64];
4818    add.u64 %tmp64, %tmp64, %off64;
4819    ld.global.f32 %hn, [%tmp64];
4820
4821    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
4822    cvt.u64.u32 %off64, %hmod;
4823    shl.b64 %off64, %off64, 2;
4824    add.u64 %tmp64, %b1, %off64;
4825    ld.global.f32 %b1r, [%tmp64];
4826    cvt.u64.u32 %off64, %hsz_reg;
4827    shl.b64 %off64, %off64, 2;
4828    add.u64 %tmp64, %tmp64, %off64;
4829    ld.global.f32 %b1i, [%tmp64];
4830    add.u64 %tmp64, %tmp64, %off64;
4831    ld.global.f32 %b1n, [%tmp64];
4832
4833    cvt.u64.u32 %off64, %hmod;
4834    shl.b64 %off64, %off64, 2;
4835    add.u64 %tmp64, %b2, %off64;
4836    ld.global.f32 %b2r, [%tmp64];
4837    cvt.u64.u32 %off64, %hsz_reg;
4838    shl.b64 %off64, %off64, 2;
4839    add.u64 %tmp64, %tmp64, %off64;
4840    ld.global.f32 %b2i, [%tmp64];
4841    add.u64 %tmp64, %tmp64, %off64;
4842    ld.global.f32 %b2n, [%tmp64];
4843
4844    // Load hx[idx]
4845    cvt.u64.u32 %off64, %idx;
4846    shl.b64 %off64, %off64, 2;
4847    add.u64 %tmp64, %hx, %off64;
4848    ld.global.f32 %hx_val, [%tmp64];
4849
4850    // r = sigmoid(ir + hr + b1r + b2r)
4851    add.f32 %rg, %ir, %hr;
4852    add.f32 %rg, %rg, %b1r;
4853    add.f32 %rg, %rg, %b2r;
4854    neg.f32 %tmp, %rg;
4855    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
4856    ex2.approx.f32 %exp_val, %tmp;
4857    add.f32 %denom, %one, %exp_val;
4858    div.rn.f32 %rg, %one, %denom;
4859
4860    // z = sigmoid(ii + hi + b1i + b2i)
4861    add.f32 %zg, %ii, %hi;
4862    add.f32 %zg, %zg, %b1i;
4863    add.f32 %zg, %zg, %b2i;
4864    neg.f32 %tmp, %zg;
4865    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
4866    ex2.approx.f32 %exp_val, %tmp;
4867    add.f32 %denom, %one, %exp_val;
4868    div.rn.f32 %zg, %one, %denom;
4869
4870    // n = tanh(in + b1n + r*(hn + b2n))
4871    add.f32 %tmp, %hn, %b2n;
4872    fma.rn.f32 %ng, %rg, %tmp, %in;
4873    add.f32 %ng, %ng, %b1n;
4874    // tanh via 2*sigmoid(2x)-1
4875    mul.f32 %tmp, %ng, 0f40000000;
4876    neg.f32 %tmp, %tmp;
4877    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
4878    ex2.approx.f32 %exp_val, %tmp;
4879    add.f32 %denom, %one, %exp_val;
4880    div.rn.f32 %ng, %one, %denom;
4881    mul.f32 %ng, %ng, 0f40000000;
4882    sub.f32 %ng, %ng, %one;
4883
4884    // hy = n + z * (hx - n)
4885    sub.f32 %tmp, %hx_val, %ng;
4886    fma.rn.f32 %hy_val, %zg, %tmp, %ng;
4887
4888    // Store hy[idx]
4889    cvt.u64.u32 %off64, %idx;
4890    shl.b64 %off64, %off64, 2;
4891    add.u64 %tmp64, %hy, %off64;
4892    st.global.f32 [%tmp64], %hy_val;
4893
4894    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
4895    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
4896    mul.lo.u32 %offset5, %offset5, 5;
4897    add.u32 %offset5, %offset5, %hmod;
4898
4899    cvt.u64.u32 %off64, %offset5;
4900    shl.b64 %off64, %off64, 2;
4901    add.u64 %tmp64, %ws, %off64;
4902    st.global.f32 [%tmp64], %rg;
4903    cvt.u64.u32 %off64, %hsz_reg;
4904    shl.b64 %off64, %off64, 2;
4905    add.u64 %tmp64, %tmp64, %off64;
4906    st.global.f32 [%tmp64], %zg;
4907    add.u64 %tmp64, %tmp64, %off64;
4908    st.global.f32 [%tmp64], %ng;
4909    add.u64 %tmp64, %tmp64, %off64;
4910    st.global.f32 [%tmp64], %hx_val;
4911    add.u64 %tmp64, %tmp64, %off64;
4912    add.f32 %tmp, %hn, %b2n;
4913    st.global.f32 [%tmp64], %tmp;
4914
4915    add.u32 %idx, %idx, %stride;
4916    bra LOOP;
4917
4918END:
4919    ret;
4920}
4921";
4922
4923// ---------------------------------------------------------------------------
4924// Launch configuration helper
4925// ---------------------------------------------------------------------------
4926
4927/// Standard 1-D launch config for `n` elements.
4928///
4929/// Uses 256 threads per block, which is a good default for elementwise ops
4930/// on all modern NVIDIA architectures.
4931///
4932/// # Errors
4933///
4934/// Returns [`GpuError::ShapeMismatch`] if `n` exceeds `u32::MAX`, which
4935/// would silently truncate the grid dimension.
4936#[cfg(feature = "cuda")]
4937fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
4938    if n > u32::MAX as usize {
4939        return Err(GpuError::ShapeMismatch {
4940            op: "kernel_launch",
4941            expected: vec![u32::MAX as usize],
4942            got: vec![n],
4943        });
4944    }
4945    const BLOCK: u32 = 256;
4946    let grid = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
4947    Ok(LaunchConfig {
4948        grid_dim: (grid.max(1), 1, 1),
4949        block_dim: (BLOCK, 1, 1),
4950        shared_mem_bytes: 0,
4951    })
4952}
4953
4954// ---------------------------------------------------------------------------
4955// Validation helpers
4956// ---------------------------------------------------------------------------
4957
4958/// Validate that two buffers are on the same device and have the same length.
4959#[cfg(feature = "cuda")]
4960fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
4961    if a.device_ordinal() != device.ordinal() {
4962        return Err(GpuError::DeviceMismatch {
4963            expected: a.device_ordinal(),
4964            got: device.ordinal(),
4965        });
4966    }
4967    if b.device_ordinal() != device.ordinal() {
4968        return Err(GpuError::DeviceMismatch {
4969            expected: b.device_ordinal(),
4970            got: device.ordinal(),
4971        });
4972    }
4973    if a.len() != b.len() {
4974        return Err(GpuError::LengthMismatch {
4975            a: a.len(),
4976            b: b.len(),
4977        });
4978    }
4979    Ok(())
4980}
4981
4982/// Validate that a unary buffer is on the correct device.
4983#[cfg(feature = "cuda")]
4984fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
4985    if a.device_ordinal() != device.ordinal() {
4986        return Err(GpuError::DeviceMismatch {
4987            expected: a.device_ordinal(),
4988            got: device.ordinal(),
4989        });
4990    }
4991    Ok(())
4992}
4993
4994// ---------------------------------------------------------------------------
4995// PTX kernel launch helpers
4996// ---------------------------------------------------------------------------
4997
4998/// Try to launch a binary PTX kernel. Returns `Ok(Some(buf))` on success,
4999/// `Ok(None)` if the PTX module failed to load (caller should fall back to
5000/// CPU), or `Err` on a real CUDA error after a successful launch.
5001#[cfg(feature = "cuda")]
5002fn try_launch_binary(
5003    a: &CudaBuffer<f32>,
5004    b: &CudaBuffer<f32>,
5005    device: &GpuDevice,
5006    ptx_src: &'static str,
5007    kernel_name: &'static str,
5008) -> GpuResult<Option<CudaBuffer<f32>>> {
5009    use cudarc::driver::PushKernelArg;
5010
5011    let n = a.len();
5012    let ctx = device.context();
5013    let stream = device.stream();
5014
5015    // Attempt to load the kernel (cached after first compilation).
5016    // If it fails (e.g. unsupported arch), return None so the caller
5017    // can use the CPU fallback.
5018    let f = match crate::module_cache::get_or_compile(
5019        ctx,
5020        ptx_src,
5021        kernel_name,
5022        device.ordinal() as u32,
5023    ) {
5024        Ok(f) => f,
5025        Err(_) => return Ok(None),
5026    };
5027
5028    let mut out = alloc_zeros_f32(n, device)?;
5029    let cfg = launch_cfg(n)?;
5030    let n_u32 = n as u32;
5031
5032    // SAFETY: The kernel reads `n` f32 values from `a` and `b`, writes `n`
5033    // f32 values to `out`. All three buffers are device-resident and at
5034    // least `n` elements long. The grid covers exactly `n` threads.
5035    unsafe {
5036        stream
5037            .launch_builder(&f)
5038            .arg(a.inner())
5039            .arg(b.inner())
5040            .arg(out.inner_mut())
5041            .arg(&n_u32)
5042            .launch(cfg)?;
5043    }
5044
5045    Ok(Some(out))
5046}
5047
5048/// Try to launch a vectorized (vec4) binary PTX kernel.
5049///
5050/// Each thread processes 4 elements using 128-bit loads/stores.
5051/// `n` must be divisible by 4. Returns `Ok(None)` if compilation fails.
5052#[cfg(feature = "cuda")]
5053fn try_launch_binary_vec4(
5054    a: &CudaBuffer<f32>,
5055    b: &CudaBuffer<f32>,
5056    device: &GpuDevice,
5057    ptx_src: &'static str,
5058    kernel_name: &'static str,
5059) -> GpuResult<Option<CudaBuffer<f32>>> {
5060    use cudarc::driver::PushKernelArg;
5061
5062    let n = a.len();
5063    let n4 = (n / 4) as u32;
5064    let ctx = device.context();
5065    let stream = device.stream();
5066
5067    let f = match crate::module_cache::get_or_compile(
5068        ctx,
5069        ptx_src,
5070        kernel_name,
5071        device.ordinal() as u32,
5072    ) {
5073        Ok(f) => f,
5074        Err(_) => return Ok(None),
5075    };
5076
5077    let mut out = alloc_zeros_f32(n, device)?;
5078    let cfg = launch_cfg(n4 as usize)?;
5079
5080    unsafe {
5081        stream
5082            .launch_builder(&f)
5083            .arg(a.inner())
5084            .arg(b.inner())
5085            .arg(out.inner_mut())
5086            .arg(&n4)
5087            .launch(cfg)?;
5088    }
5089
5090    Ok(Some(out))
5091}
5092
5093/// Try to launch a unary PTX kernel. Returns `Ok(Some(buf))` on success,
5094/// `Ok(None)` if the PTX module failed to load.
5095#[cfg(feature = "cuda")]
5096fn try_launch_unary(
5097    a: &CudaBuffer<f32>,
5098    device: &GpuDevice,
5099    ptx_src: &'static str,
5100    kernel_name: &'static str,
5101) -> GpuResult<Option<CudaBuffer<f32>>> {
5102    use cudarc::driver::PushKernelArg;
5103
5104    let n = a.len();
5105    let ctx = device.context();
5106    let stream = device.stream();
5107
5108    // Attempt to load the kernel (cached after first compilation).
5109    let f = match crate::module_cache::get_or_compile(
5110        ctx,
5111        ptx_src,
5112        kernel_name,
5113        device.ordinal() as u32,
5114    ) {
5115        Ok(f) => f,
5116        Err(_) => return Ok(None),
5117    };
5118
5119    let mut out = alloc_zeros_f32(n, device)?;
5120    let cfg = launch_cfg(n)?;
5121    let n_u32 = n as u32;
5122
5123    // SAFETY: The kernel reads `n` f32 values from `a` and writes `n` f32
5124    // values to `out`. Both buffers are device-resident with length >= n.
5125    unsafe {
5126        stream
5127            .launch_builder(&f)
5128            .arg(a.inner())
5129            .arg(out.inner_mut())
5130            .arg(&n_u32)
5131            .launch(cfg)?;
5132    }
5133
5134    Ok(Some(out))
5135}
5136
5137// ---------------------------------------------------------------------------
5138// _into helpers — write to pre-allocated output buffer (no allocation)
5139// ---------------------------------------------------------------------------
5140
5141/// Launch a binary PTX kernel into a pre-allocated output buffer.
5142/// Returns `Ok(true)` on success, `Ok(false)` if the PTX module failed to load.
5143#[cfg(feature = "cuda")]
5144fn try_launch_binary_into(
5145    a: &CudaBuffer<f32>,
5146    b: &CudaBuffer<f32>,
5147    out: &mut CudaBuffer<f32>,
5148    device: &GpuDevice,
5149    ptx_src: &'static str,
5150    kernel_name: &'static str,
5151) -> GpuResult<bool> {
5152    use cudarc::driver::PushKernelArg;
5153
5154    let n = a.len();
5155    let ctx = device.context();
5156    let stream = device.stream();
5157
5158    let f = match crate::module_cache::get_or_compile(
5159        ctx,
5160        ptx_src,
5161        kernel_name,
5162        device.ordinal() as u32,
5163    ) {
5164        Ok(f) => f,
5165        Err(_) => return Ok(false),
5166    };
5167
5168    let cfg = launch_cfg(n)?;
5169    let n_u32 = n as u32;
5170
5171    unsafe {
5172        stream
5173            .launch_builder(&f)
5174            .arg(a.inner())
5175            .arg(b.inner())
5176            .arg(out.inner_mut())
5177            .arg(&n_u32)
5178            .launch(cfg)?;
5179    }
5180
5181    Ok(true)
5182}
5183
5184/// Launch a unary PTX kernel into a pre-allocated output buffer.
5185/// Returns `Ok(true)` on success, `Ok(false)` if the PTX module failed to load.
5186#[cfg(feature = "cuda")]
5187fn try_launch_unary_into(
5188    a: &CudaBuffer<f32>,
5189    out: &mut CudaBuffer<f32>,
5190    device: &GpuDevice,
5191    ptx_src: &'static str,
5192    kernel_name: &'static str,
5193) -> GpuResult<bool> {
5194    use cudarc::driver::PushKernelArg;
5195
5196    let n = a.len();
5197    let ctx = device.context();
5198    let stream = device.stream();
5199
5200    let f = match crate::module_cache::get_or_compile(
5201        ctx,
5202        ptx_src,
5203        kernel_name,
5204        device.ordinal() as u32,
5205    ) {
5206        Ok(f) => f,
5207        Err(_) => return Ok(false),
5208    };
5209
5210    let cfg = launch_cfg(n)?;
5211    let n_u32 = n as u32;
5212
5213    unsafe {
5214        stream
5215            .launch_builder(&f)
5216            .arg(a.inner())
5217            .arg(out.inner_mut())
5218            .arg(&n_u32)
5219            .launch(cfg)?;
5220    }
5221
5222    Ok(true)
5223}
5224
5225/// Try to launch a general N-dimensional broadcast binary PTX kernel.
5226///
5227/// `a_strides` and `b_strides` are broadcast strides: normal C-contiguous
5228/// stride for non-broadcast dims, 0 for broadcast (size-1) dims.
5229/// `out_shape` is the broadcast-resolved output shape.
5230/// All three arrays have length `ndim`.
5231#[cfg(feature = "cuda")]
5232#[allow(clippy::too_many_arguments)]
5233fn try_launch_broadcast_binary(
5234    a: &CudaBuffer<f32>,
5235    b: &CudaBuffer<f32>,
5236    a_strides: &[u32],
5237    b_strides: &[u32],
5238    out_shape: &[u32],
5239    out_numel: usize,
5240    device: &GpuDevice,
5241    ptx_src: &'static str,
5242    kernel_name: &'static str,
5243) -> GpuResult<Option<CudaBuffer<f32>>> {
5244    use cudarc::driver::PushKernelArg;
5245
5246    let ndim = out_shape.len();
5247    let ctx = device.context();
5248    let stream = device.stream();
5249
5250    let f = match crate::module_cache::get_or_compile(
5251        ctx,
5252        ptx_src,
5253        kernel_name,
5254        device.ordinal() as u32,
5255    ) {
5256        Ok(f) => f,
5257        Err(_) => return Ok(None),
5258    };
5259
5260    // Upload stride/shape metadata as small device buffers.
5261    let a_str_buf = cpu_to_gpu(a_strides, device)?;
5262    let b_str_buf = cpu_to_gpu(b_strides, device)?;
5263    let shape_buf = cpu_to_gpu(out_shape, device)?;
5264
5265    let mut out = alloc_zeros_f32(out_numel, device)?;
5266    let cfg = launch_cfg(out_numel)?;
5267    let n_u32 = out_numel as u32;
5268    let ndim_u32 = ndim as u32;
5269
5270    // SAFETY: Kernel reads from a, b using broadcast indices computed from
5271    // the stride/shape buffers. Output buffer has out_numel elements.
5272    unsafe {
5273        stream
5274            .launch_builder(&f)
5275            .arg(a.inner())
5276            .arg(b.inner())
5277            .arg(out.inner_mut())
5278            .arg(a_str_buf.inner())
5279            .arg(b_str_buf.inner())
5280            .arg(shape_buf.inner())
5281            .arg(&n_u32)
5282            .arg(&ndim_u32)
5283            .launch(cfg)?;
5284    }
5285
5286    Ok(Some(out))
5287}
5288
5289/// Compute broadcast strides for a tensor shape relative to an output shape.
5290///
5291/// For each dimension, the stride is the normal C-contiguous stride if the
5292/// dimension size matches the output, or 0 if the dimension size is 1
5293/// (broadcast). Missing leading dimensions (when input has fewer dims) are
5294/// treated as size-1.
5295#[cfg(feature = "cuda")]
5296fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
5297    let ndim = out_shape.len();
5298    let in_ndim = in_shape.len();
5299    let mut strides = vec![0u32; ndim];
5300
5301    // C-contiguous strides for the input shape.
5302    let mut stride: u32 = 1;
5303    for d in (0..ndim).rev() {
5304        let in_d = if d + in_ndim >= ndim {
5305            d + in_ndim - ndim
5306        } else {
5307            // Leading dimension not present in input — broadcast.
5308            strides[d] = 0;
5309            continue;
5310        };
5311
5312        if in_shape[in_d] == 1 {
5313            strides[d] = 0; // Broadcast dimension.
5314        } else {
5315            strides[d] = stride;
5316        }
5317        stride *= in_shape[in_d] as u32;
5318    }
5319
5320    strides
5321}
5322
5323// ---------------------------------------------------------------------------
5324// CPU fallback helpers
5325// ---------------------------------------------------------------------------
5326
5327/// CPU fallback for binary ops: copy both inputs to host, apply `op`, copy
5328/// the result back.
5329#[cfg(feature = "cuda")]
5330fn cpu_fallback_binary(
5331    a: &CudaBuffer<f32>,
5332    b: &CudaBuffer<f32>,
5333    device: &GpuDevice,
5334    op: fn(f32, f32) -> f32,
5335) -> GpuResult<CudaBuffer<f32>> {
5336    let a_host = gpu_to_cpu(a, device)?;
5337    let b_host = gpu_to_cpu(b, device)?;
5338    let result: Vec<f32> = a_host
5339        .iter()
5340        .zip(b_host.iter())
5341        .map(|(&x, &y)| op(x, y))
5342        .collect();
5343    cpu_to_gpu(&result, device)
5344}
5345
5346/// CPU fallback for unary ops.
5347#[cfg(feature = "cuda")]
5348fn cpu_fallback_unary(
5349    a: &CudaBuffer<f32>,
5350    device: &GpuDevice,
5351    op: fn(f32) -> f32,
5352) -> GpuResult<CudaBuffer<f32>> {
5353    let a_host = gpu_to_cpu(a, device)?;
5354    let result: Vec<f32> = a_host.iter().map(|&x| op(x)).collect();
5355    cpu_to_gpu(&result, device)
5356}
5357
5358// ---------------------------------------------------------------------------
5359// Public API -- binary ops
5360// ---------------------------------------------------------------------------
5361
5362/// Elementwise addition: `out[i] = a[i] + b[i]`.
5363///
5364/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5365/// if the PTX module cannot be loaded.
5366///
5367/// # Errors
5368///
5369/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
5370///   different CUDA devices.
5371/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
5372/// - [`GpuError::Driver`] on CUDA runtime errors.
5373#[cfg(feature = "cuda")]
5374pub fn gpu_add(
5375    a: &CudaBuffer<f32>,
5376    b: &CudaBuffer<f32>,
5377    device: &GpuDevice,
5378) -> GpuResult<CudaBuffer<f32>> {
5379    validate_binary(a, b, device)?;
5380
5381    // Try vec4 kernel for 4x memory throughput (128-bit loads).
5382    let n = a.len();
5383    if n >= 16 && n % 4 == 0 {
5384        if let Some(out) = try_launch_binary_vec4(
5385            a, b, device, ADD_VEC4_PTX, "add_vec4_kernel",
5386        )? {
5387            return Ok(out);
5388        }
5389    }
5390
5391    if let Some(out) = try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
5392        return Ok(out);
5393    }
5394
5395    cpu_fallback_binary(a, b, device, |x, y| x + y)
5396}
5397
5398/// Elementwise subtraction: `out[i] = a[i] - b[i]`.
5399///
5400/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5401/// if the PTX module cannot be loaded.
5402///
5403/// # Errors
5404///
5405/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
5406///   different CUDA devices.
5407/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
5408/// - [`GpuError::Driver`] on CUDA runtime errors.
5409#[cfg(feature = "cuda")]
5410pub fn gpu_sub(
5411    a: &CudaBuffer<f32>,
5412    b: &CudaBuffer<f32>,
5413    device: &GpuDevice,
5414) -> GpuResult<CudaBuffer<f32>> {
5415    validate_binary(a, b, device)?;
5416
5417    if let Some(out) = try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
5418        return Ok(out);
5419    }
5420
5421    cpu_fallback_binary(a, b, device, |x, y| x - y)
5422}
5423
5424/// Elementwise multiplication: `out[i] = a[i] * b[i]`.
5425///
5426/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5427/// if the PTX module cannot be loaded.
5428///
5429/// # Errors
5430///
5431/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
5432///   different CUDA devices.
5433/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
5434/// - [`GpuError::Driver`] on CUDA runtime errors.
5435#[cfg(feature = "cuda")]
5436pub fn gpu_mul(
5437    a: &CudaBuffer<f32>,
5438    b: &CudaBuffer<f32>,
5439    device: &GpuDevice,
5440) -> GpuResult<CudaBuffer<f32>> {
5441    validate_binary(a, b, device)?;
5442
5443    let n = a.len();
5444    if n >= 16 && n % 4 == 0 {
5445        if let Some(out) = try_launch_binary_vec4(
5446            a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel",
5447        )? {
5448            return Ok(out);
5449        }
5450    }
5451
5452    if let Some(out) = try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
5453        return Ok(out);
5454    }
5455
5456    cpu_fallback_binary(a, b, device, |x, y| x * y)
5457}
5458
5459// ---------------------------------------------------------------------------
5460// Public API -- broadcast binary ops
5461// ---------------------------------------------------------------------------
5462
5463/// Broadcast addition: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
5464///
5465/// Handles arbitrary N-dimensional broadcasting on the GPU. The kernel
5466/// decomposes each output index into coordinates, maps them through
5467/// broadcast strides, and loads from the correct positions in A and B.
5468///
5469/// `a_shape` and `b_shape` are the original shapes; the output shape is
5470/// computed via numpy-style broadcast rules.
5471#[cfg(feature = "cuda")]
5472pub fn gpu_broadcast_add(
5473    a: &CudaBuffer<f32>,
5474    b: &CudaBuffer<f32>,
5475    a_shape: &[usize],
5476    b_shape: &[usize],
5477    out_shape: &[usize],
5478    device: &GpuDevice,
5479) -> GpuResult<CudaBuffer<f32>> {
5480    let a_str = broadcast_strides(a_shape, out_shape);
5481    let b_str = broadcast_strides(b_shape, out_shape);
5482    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5483    let out_numel: usize = out_shape.iter().product();
5484
5485    if let Some(out) = try_launch_broadcast_binary(
5486        a,
5487        b,
5488        &a_str,
5489        &b_str,
5490        &shape_u32,
5491        out_numel,
5492        device,
5493        BROADCAST_ADD_PTX,
5494        "broadcast_add_kernel",
5495    )? {
5496        return Ok(out);
5497    }
5498
5499    // CPU fallback for broadcast.
5500    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
5501}
5502
5503/// Broadcast subtraction: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
5504#[cfg(feature = "cuda")]
5505pub fn gpu_broadcast_sub(
5506    a: &CudaBuffer<f32>,
5507    b: &CudaBuffer<f32>,
5508    a_shape: &[usize],
5509    b_shape: &[usize],
5510    out_shape: &[usize],
5511    device: &GpuDevice,
5512) -> GpuResult<CudaBuffer<f32>> {
5513    let a_str = broadcast_strides(a_shape, out_shape);
5514    let b_str = broadcast_strides(b_shape, out_shape);
5515    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5516    let out_numel: usize = out_shape.iter().product();
5517
5518    if let Some(out) = try_launch_broadcast_binary(
5519        a,
5520        b,
5521        &a_str,
5522        &b_str,
5523        &shape_u32,
5524        out_numel,
5525        device,
5526        BROADCAST_SUB_PTX,
5527        "broadcast_sub_kernel",
5528    )? {
5529        return Ok(out);
5530    }
5531
5532    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
5533}
5534
5535/// Broadcast multiplication: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
5536#[cfg(feature = "cuda")]
5537pub fn gpu_broadcast_mul(
5538    a: &CudaBuffer<f32>,
5539    b: &CudaBuffer<f32>,
5540    a_shape: &[usize],
5541    b_shape: &[usize],
5542    out_shape: &[usize],
5543    device: &GpuDevice,
5544) -> GpuResult<CudaBuffer<f32>> {
5545    let a_str = broadcast_strides(a_shape, out_shape);
5546    let b_str = broadcast_strides(b_shape, out_shape);
5547    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5548    let out_numel: usize = out_shape.iter().product();
5549
5550    if let Some(out) = try_launch_broadcast_binary(
5551        a,
5552        b,
5553        &a_str,
5554        &b_str,
5555        &shape_u32,
5556        out_numel,
5557        device,
5558        BROADCAST_MUL_PTX,
5559        "broadcast_mul_kernel",
5560    )? {
5561        return Ok(out);
5562    }
5563
5564    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
5565}
5566
5567/// Broadcast division: `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
5568#[cfg(feature = "cuda")]
5569pub fn gpu_broadcast_div(
5570    a: &CudaBuffer<f32>,
5571    b: &CudaBuffer<f32>,
5572    a_shape: &[usize],
5573    b_shape: &[usize],
5574    out_shape: &[usize],
5575    device: &GpuDevice,
5576) -> GpuResult<CudaBuffer<f32>> {
5577    let a_str = broadcast_strides(a_shape, out_shape);
5578    let b_str = broadcast_strides(b_shape, out_shape);
5579    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5580    let out_numel: usize = out_shape.iter().product();
5581
5582    if let Some(out) = try_launch_broadcast_binary(
5583        a,
5584        b,
5585        &a_str,
5586        &b_str,
5587        &shape_u32,
5588        out_numel,
5589        device,
5590        BROADCAST_DIV_PTX,
5591        "broadcast_div_kernel",
5592    )? {
5593        return Ok(out);
5594    }
5595
5596    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
5597}
5598
5599/// CPU fallback for broadcast binary ops — downloads, applies op with
5600/// broadcast indexing, re-uploads.
5601#[cfg(feature = "cuda")]
5602fn cpu_fallback_broadcast_binary(
5603    a: &CudaBuffer<f32>,
5604    b: &CudaBuffer<f32>,
5605    a_shape: &[usize],
5606    b_shape: &[usize],
5607    out_shape: &[usize],
5608    device: &GpuDevice,
5609    op: fn(f32, f32) -> f32,
5610) -> GpuResult<CudaBuffer<f32>> {
5611    let a_host = gpu_to_cpu(a, device)?;
5612    let b_host = gpu_to_cpu(b, device)?;
5613    let out_numel: usize = out_shape.iter().product();
5614
5615    let a_str = broadcast_strides(a_shape, out_shape);
5616    let b_str = broadcast_strides(b_shape, out_shape);
5617
5618    let mut result = Vec::with_capacity(out_numel);
5619    for i in 0..out_numel {
5620        let mut remaining = i;
5621        let mut a_idx = 0usize;
5622        let mut b_idx = 0usize;
5623        for d in (0..out_shape.len()).rev() {
5624            let coord = remaining % out_shape[d];
5625            remaining /= out_shape[d];
5626            a_idx += coord * a_str[d] as usize;
5627            b_idx += coord * b_str[d] as usize;
5628        }
5629        result.push(op(a_host[a_idx], b_host[b_idx]));
5630    }
5631    cpu_to_gpu(&result, device)
5632}
5633
5634// ---------------------------------------------------------------------------
5635// Public API -- unary ops
5636// ---------------------------------------------------------------------------
5637
5638/// Elementwise negation: `out[i] = -a[i]`.
5639///
5640/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5641/// if the PTX module cannot be loaded.
5642///
5643/// # Errors
5644///
5645/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
5646///   CUDA devices.
5647/// - [`GpuError::Driver`] on CUDA runtime errors.
5648#[cfg(feature = "cuda")]
5649pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
5650    validate_unary(a, device)?;
5651
5652    if let Some(out) = try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
5653        return Ok(out);
5654    }
5655
5656    cpu_fallback_unary(a, device, |x| -x)
5657}
5658
5659/// Elementwise ReLU: `out[i] = max(a[i], 0.0)`.
5660///
5661/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5662/// if the PTX module cannot be loaded.
5663///
5664/// # Errors
5665///
5666/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
5667///   CUDA devices.
5668/// - [`GpuError::Driver`] on CUDA runtime errors.
5669#[cfg(feature = "cuda")]
5670pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
5671    validate_unary(a, device)?;
5672
5673    if let Some(out) = try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
5674        return Ok(out);
5675    }
5676
5677    cpu_fallback_unary(a, device, |x| x.max(0.0))
5678}
5679
5680/// ReLU backward: `out[i] = (input[i] > 0) ? grad[i] : 0`.
5681#[cfg(feature = "cuda")]
5682pub fn gpu_relu_backward(
5683    grad: &CudaBuffer<f32>,
5684    input: &CudaBuffer<f32>,
5685    device: &GpuDevice,
5686) -> GpuResult<CudaBuffer<f32>> {
5687    validate_binary(grad, input, device)?;
5688
5689    if let Some(out) = try_launch_binary(
5690        grad,
5691        input,
5692        device,
5693        RELU_BACKWARD_PTX,
5694        "relu_backward_kernel",
5695    )? {
5696        return Ok(out);
5697    }
5698
5699    // CPU fallback
5700    let grad_host = gpu_to_cpu(grad, device)?;
5701    let input_host = gpu_to_cpu(input, device)?;
5702    let result: Vec<f32> = grad_host
5703        .iter()
5704        .zip(input_host.iter())
5705        .map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
5706        .collect();
5707    cpu_to_gpu(&result, device)
5708}
5709
5710/// GELU backward: `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
5711/// where `sig = sigmoid(1.702 * x)`.
5712#[cfg(feature = "cuda")]
5713pub fn gpu_gelu_backward(
5714    grad: &CudaBuffer<f32>,
5715    input: &CudaBuffer<f32>,
5716    device: &GpuDevice,
5717) -> GpuResult<CudaBuffer<f32>> {
5718    validate_binary(grad, input, device)?;
5719
5720    if let Some(out) = try_launch_binary(
5721        grad,
5722        input,
5723        device,
5724        GELU_BACKWARD_PTX,
5725        "gelu_backward_kernel",
5726    )? {
5727        return Ok(out);
5728    }
5729
5730    // CPU fallback
5731    let grad_host = gpu_to_cpu(grad, device)?;
5732    let input_host = gpu_to_cpu(input, device)?;
5733    let result: Vec<f32> = grad_host
5734        .iter()
5735        .zip(input_host.iter())
5736        .map(|(&g, &x)| {
5737            let k: f32 = 1.702;
5738            let sig = 1.0 / (1.0 + (-k * x).exp());
5739            g * (sig + k * x * sig * (1.0 - sig))
5740        })
5741        .collect();
5742    cpu_to_gpu(&result, device)
5743}
5744
5745/// GELU backward (exact erf mode):
5746/// `out[i] = grad[i] * (Φ(x) + x·φ(x))`
5747/// where Φ = normal CDF, φ = normal PDF.
5748#[cfg(feature = "cuda")]
5749pub fn gpu_gelu_backward_erf(
5750    grad: &CudaBuffer<f32>,
5751    input: &CudaBuffer<f32>,
5752    device: &GpuDevice,
5753) -> GpuResult<CudaBuffer<f32>> {
5754    validate_binary(grad, input, device)?;
5755
5756    if let Some(out) = try_launch_binary(
5757        grad,
5758        input,
5759        device,
5760        GELU_BACKWARD_ERF_PTX,
5761        "gelu_backward_erf_kernel",
5762    )? {
5763        return Ok(out);
5764    }
5765
5766    // CPU fallback — Abramowitz & Stegun erf approximation (|ε| < 1.5e-7)
5767    let grad_host = gpu_to_cpu(grad, device)?;
5768    let input_host = gpu_to_cpu(input, device)?;
5769    let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
5770    let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
5771    let result: Vec<f32> = grad_host
5772        .iter()
5773        .zip(input_host.iter())
5774        .map(|(&g, &x)| {
5775            let z = x * inv_sqrt_2;
5776            let az = z.abs();
5777            let t = 1.0 / (1.0 + 0.3275911 * az);
5778            let poly = t * (0.2548296 + t * (-0.2844967 + t * (1.4214137 + t * (-1.4531520 + t * 0.3275911))));
5779            let erf_abs = 1.0 - poly * (-az * az).exp();
5780            let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
5781            let cdf = 0.5 * (1.0 + erf_val);
5782            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
5783            g * (cdf + x * pdf)
5784        })
5785        .collect();
5786    cpu_to_gpu(&result, device)
5787}
5788
5789// ---------------------------------------------------------------------------
5790// Public API -- Index-select 1-D (gather)
5791// ---------------------------------------------------------------------------
5792
5793/// Gather elements from `input` at positions given by `indices`.
5794///
5795/// `indices` is a GPU buffer of f32 values encoding integer indices.
5796/// Output has `indices.len()` elements: `out[i] = input[indices[i]]`.
5797#[cfg(feature = "cuda")]
5798pub fn gpu_index_select_1d(
5799    input: &CudaBuffer<f32>,
5800    indices: &CudaBuffer<f32>,
5801    device: &GpuDevice,
5802) -> GpuResult<CudaBuffer<f32>> {
5803    use cudarc::driver::PushKernelArg;
5804
5805    validate_unary(input, device)?;
5806
5807    let n = indices.len();
5808    let ctx = device.context();
5809    let stream = device.stream();
5810
5811    let f = match crate::module_cache::get_or_compile(
5812        ctx,
5813        INDEX_SELECT_1D_PTX,
5814        "index_select_1d_kernel",
5815        device.ordinal() as u32,
5816    ) {
5817        Ok(f) => f,
5818        Err(_) => {
5819            // CPU fallback.
5820            let input_host = gpu_to_cpu(input, device)?;
5821            let indices_host = gpu_to_cpu(indices, device)?;
5822            let result: Vec<f32> = indices_host
5823                .iter()
5824                .map(|&idx_f| input_host[idx_f as usize])
5825                .collect();
5826            return cpu_to_gpu(&result, device);
5827        }
5828    };
5829
5830    let mut out = alloc_zeros_f32(n, device)?;
5831    let cfg = launch_cfg(n)?;
5832    let n_u32 = n as u32;
5833
5834    unsafe {
5835        stream
5836            .launch_builder(&f)
5837            .arg(input.inner())
5838            .arg(indices.inner())
5839            .arg(out.inner_mut())
5840            .arg(&n_u32)
5841            .launch(cfg)?;
5842    }
5843
5844    Ok(out)
5845}
5846
5847// ---------------------------------------------------------------------------
5848// Public API -- Scatter-add 1-D (backward of index_select)
5849// ---------------------------------------------------------------------------
5850
5851/// Scatter-add `grad_output` back into an output buffer of `input_len` elements,
5852/// using positions from `indices`.
5853///
5854/// `indices` is a GPU buffer of f32 values encoding integer indices.
5855/// Output: `out = zeros(input_len); for i: out[indices[i]] += grad_output[i]`
5856///
5857/// Uses atomic adds for safe concurrent accumulation.
5858#[cfg(feature = "cuda")]
5859pub fn gpu_scatter_add_1d(
5860    grad_output: &CudaBuffer<f32>,
5861    indices: &CudaBuffer<f32>,
5862    input_len: usize,
5863    device: &GpuDevice,
5864) -> GpuResult<CudaBuffer<f32>> {
5865    use cudarc::driver::PushKernelArg;
5866
5867    validate_unary(grad_output, device)?;
5868
5869    let n = grad_output.len();
5870    let ctx = device.context();
5871    let stream = device.stream();
5872
5873    let f = match crate::module_cache::get_or_compile(
5874        ctx,
5875        SCATTER_ADD_1D_PTX,
5876        "scatter_add_1d_kernel",
5877        device.ordinal() as u32,
5878    ) {
5879        Ok(f) => f,
5880        Err(_) => {
5881            // CPU fallback.
5882            let go_host = gpu_to_cpu(grad_output, device)?;
5883            let idx_host = gpu_to_cpu(indices, device)?;
5884            let mut result = vec![0.0f32; input_len];
5885            for (i, &idx_f) in idx_host.iter().enumerate() {
5886                result[idx_f as usize] += go_host[i];
5887            }
5888            return cpu_to_gpu(&result, device);
5889        }
5890    };
5891
5892    let mut out = alloc_zeros_f32(input_len, device)?;
5893    let cfg = launch_cfg(n)?;
5894    let n_u32 = n as u32;
5895
5896    unsafe {
5897        stream
5898            .launch_builder(&f)
5899            .arg(grad_output.inner())
5900            .arg(indices.inner())
5901            .arg(out.inner_mut())
5902            .arg(&n_u32)
5903            .launch(cfg)?;
5904    }
5905
5906    Ok(out)
5907}
5908
5909// ---------------------------------------------------------------------------
5910// Public API -- Masked fill
5911// ---------------------------------------------------------------------------
5912
5913/// Fill elements of `input` with `value` where `mask` is true.
5914///
5915/// `mask` is a GPU buffer of f32 values (1.0 = true, 0.0 = false).
5916/// Output: `out[i] = mask[i] >= 0.5 ? value : input[i]`
5917#[cfg(feature = "cuda")]
5918pub fn gpu_masked_fill(
5919    input: &CudaBuffer<f32>,
5920    mask: &CudaBuffer<f32>,
5921    value: f32,
5922    device: &GpuDevice,
5923) -> GpuResult<CudaBuffer<f32>> {
5924    use cudarc::driver::PushKernelArg;
5925
5926    validate_binary(input, mask, device)?;
5927
5928    let n = input.len();
5929    let ctx = device.context();
5930    let stream = device.stream();
5931
5932    let f = match crate::module_cache::get_or_compile(
5933        ctx,
5934        MASKED_FILL_PTX,
5935        "masked_fill_kernel",
5936        device.ordinal() as u32,
5937    ) {
5938        Ok(f) => f,
5939        Err(_) => {
5940            // CPU fallback.
5941            let input_host = gpu_to_cpu(input, device)?;
5942            let mask_host = gpu_to_cpu(mask, device)?;
5943            let result: Vec<f32> = input_host
5944                .iter()
5945                .zip(mask_host.iter())
5946                .map(|(&x, &m)| if m >= 0.5 { value } else { x })
5947                .collect();
5948            return cpu_to_gpu(&result, device);
5949        }
5950    };
5951
5952    let mut out = alloc_zeros_f32(n, device)?;
5953    let cfg = launch_cfg(n)?;
5954    let n_u32 = n as u32;
5955
5956    unsafe {
5957        stream
5958            .launch_builder(&f)
5959            .arg(input.inner())
5960            .arg(mask.inner())
5961            .arg(out.inner_mut())
5962            .arg(&value)
5963            .arg(&n_u32)
5964            .launch(cfg)?;
5965    }
5966
5967    Ok(out)
5968}
5969
5970// ---------------------------------------------------------------------------
5971// Public API -- Masked zero (backward of masked_fill)
5972// ---------------------------------------------------------------------------
5973
5974/// Zero out gradient at positions where `mask` is true.
5975///
5976/// `mask` is a GPU buffer of f32 values (1.0 = true, 0.0 = false).
5977/// Output: `out[i] = mask[i] >= 0.5 ? 0.0 : grad[i]`
5978#[cfg(feature = "cuda")]
5979pub fn gpu_masked_zero(
5980    grad: &CudaBuffer<f32>,
5981    mask: &CudaBuffer<f32>,
5982    device: &GpuDevice,
5983) -> GpuResult<CudaBuffer<f32>> {
5984    validate_binary(grad, mask, device)?;
5985
5986    if let Some(out) = try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?
5987    {
5988        return Ok(out);
5989    }
5990
5991    // CPU fallback.
5992    let grad_host = gpu_to_cpu(grad, device)?;
5993    let mask_host = gpu_to_cpu(mask, device)?;
5994    let result: Vec<f32> = grad_host
5995        .iter()
5996        .zip(mask_host.iter())
5997        .map(|(&g, &m)| if m >= 0.5 { 0.0 } else { g })
5998        .collect();
5999    cpu_to_gpu(&result, device)
6000}
6001
6002// ---------------------------------------------------------------------------
6003// Public API -- Sigmoid backward
6004// ---------------------------------------------------------------------------
6005
6006/// Sigmoid backward: `out[i] = grad[i] * output[i] * (1 - output[i])`.
6007///
6008/// `grad` and `output` must have the same length and reside on `device`.
6009#[cfg(feature = "cuda")]
6010pub fn gpu_sigmoid_backward(
6011    grad: &CudaBuffer<f32>,
6012    output: &CudaBuffer<f32>,
6013    device: &GpuDevice,
6014) -> GpuResult<CudaBuffer<f32>> {
6015    validate_binary(grad, output, device)?;
6016
6017    if let Some(out) = try_launch_binary(
6018        grad,
6019        output,
6020        device,
6021        SIGMOID_BACKWARD_PTX,
6022        "sigmoid_backward_kernel",
6023    )? {
6024        return Ok(out);
6025    }
6026
6027    // CPU fallback
6028    let grad_host = gpu_to_cpu(grad, device)?;
6029    let output_host = gpu_to_cpu(output, device)?;
6030    let result: Vec<f32> = grad_host
6031        .iter()
6032        .zip(output_host.iter())
6033        .map(|(&g, &o)| g * o * (1.0 - o))
6034        .collect();
6035    cpu_to_gpu(&result, device)
6036}
6037
6038// ---------------------------------------------------------------------------
6039// Public API -- Tanh backward
6040// ---------------------------------------------------------------------------
6041
6042/// Tanh backward: `out[i] = grad[i] * (1 - output[i]^2)`.
6043///
6044/// `grad` and `output` must have the same length and reside on `device`.
6045#[cfg(feature = "cuda")]
6046pub fn gpu_tanh_backward(
6047    grad: &CudaBuffer<f32>,
6048    output: &CudaBuffer<f32>,
6049    device: &GpuDevice,
6050) -> GpuResult<CudaBuffer<f32>> {
6051    validate_binary(grad, output, device)?;
6052
6053    if let Some(out) = try_launch_binary(
6054        grad,
6055        output,
6056        device,
6057        TANH_BACKWARD_PTX,
6058        "tanh_backward_kernel",
6059    )? {
6060        return Ok(out);
6061    }
6062
6063    // CPU fallback
6064    let grad_host = gpu_to_cpu(grad, device)?;
6065    let output_host = gpu_to_cpu(output, device)?;
6066    let result: Vec<f32> = grad_host
6067        .iter()
6068        .zip(output_host.iter())
6069        .map(|(&g, &o)| g * (1.0 - o * o))
6070        .collect();
6071    cpu_to_gpu(&result, device)
6072}
6073
6074// ---------------------------------------------------------------------------
6075// Public API -- Softmax backward
6076// ---------------------------------------------------------------------------
6077
6078/// Softmax backward (row-wise): one block per row, shared-memory dot reduction.
6079///
6080/// For each row of length `cols`:
6081///   `dot = sum(grad[row] * output[row])`
6082///   `out[i] = output[i] * (grad[i] - dot)`
6083///
6084/// `rows` = total elements / cols. Both `grad` and `output` have `rows * cols` elements.
6085#[cfg(feature = "cuda")]
6086pub fn gpu_softmax_backward(
6087    grad: &CudaBuffer<f32>,
6088    output: &CudaBuffer<f32>,
6089    cols: usize,
6090    device: &GpuDevice,
6091) -> GpuResult<CudaBuffer<f32>> {
6092    use cudarc::driver::PushKernelArg;
6093
6094    validate_binary(grad, output, device)?;
6095
6096    let total = grad.len();
6097    let rows = total / cols;
6098
6099    let ctx = device.context();
6100    let stream = device.stream();
6101
6102    let f = match crate::module_cache::get_or_compile(
6103        ctx,
6104        SOFTMAX_BACKWARD_PTX,
6105        "softmax_backward_kernel",
6106        device.ordinal() as u32,
6107    ) {
6108        Ok(f) => f,
6109        Err(_) => {
6110            // CPU fallback
6111            let grad_host = gpu_to_cpu(grad, device)?;
6112            let output_host = gpu_to_cpu(output, device)?;
6113            let mut result = vec![0.0f32; total];
6114            for r in 0..rows {
6115                let base = r * cols;
6116                let mut dot = 0.0f32;
6117                for c in 0..cols {
6118                    dot += grad_host[base + c] * output_host[base + c];
6119                }
6120                for c in 0..cols {
6121                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
6122                }
6123            }
6124            return cpu_to_gpu(&result, device);
6125        }
6126    };
6127
6128    let mut out = alloc_zeros_f32(total, device)?;
6129    let rows_u32 = rows as u32;
6130    let cols_u32 = cols as u32;
6131
6132    // One block per row, 256 threads per block.
6133    let cfg = LaunchConfig {
6134        grid_dim: ((rows as u32).max(1), 1, 1),
6135        block_dim: (256, 1, 1),
6136        shared_mem_bytes: 256 * 4,
6137    };
6138
6139    unsafe {
6140        stream
6141            .launch_builder(&f)
6142            .arg(grad.inner())
6143            .arg(output.inner())
6144            .arg(out.inner_mut())
6145            .arg(&rows_u32)
6146            .arg(&cols_u32)
6147            .launch(cfg)?;
6148    }
6149
6150    Ok(out)
6151}
6152
6153// ---------------------------------------------------------------------------
6154// Public API -- Sum axis
6155// ---------------------------------------------------------------------------
6156
6157/// Reduce along one axis of a tensor.
6158///
6159/// Thread i computes:
6160/// Full parallel sum reduction on GPU.
6161///
6162/// Uses a two-pass approach: first pass reduces `n` elements to `num_blocks`
6163/// partial sums via the `reduce_sum_kernel`, second pass reduces the partial
6164/// sums to a single scalar. For small inputs (< 256 blocks), the second pass
6165/// runs on CPU to avoid kernel launch overhead.
6166#[cfg(feature = "cuda")]
6167pub fn gpu_reduce_sum(
6168    a: &CudaBuffer<f32>,
6169    device: &GpuDevice,
6170) -> GpuResult<CudaBuffer<f32>> {
6171    use cudarc::driver::PushKernelArg;
6172
6173    let n = a.len();
6174    if n == 0 {
6175        return cpu_to_gpu(&[0.0f32], device);
6176    }
6177
6178    let ctx = device.context();
6179    let stream = device.stream();
6180
6181    let f = match crate::module_cache::get_or_compile(
6182        ctx,
6183        REDUCE_SUM_PTX,
6184        "reduce_sum_kernel",
6185        device.ordinal() as u32,
6186    ) {
6187        Ok(f) => f,
6188        Err(_) => {
6189            // CPU fallback
6190            let host = gpu_to_cpu(a, device)?;
6191            let total: f32 = host.iter().sum();
6192            return cpu_to_gpu(&[total], device);
6193        }
6194    };
6195
6196    // Pass 1: reduce to partial sums (one per block).
6197    const BLOCK: u32 = 256;
6198    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
6199    // Cap blocks to avoid excessive partial sums.
6200    let num_blocks = num_blocks.min(1024);
6201
6202    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
6203    let n_u32 = n as u32;
6204
6205    let cfg = cudarc::driver::LaunchConfig {
6206        grid_dim: (num_blocks.max(1), 1, 1),
6207        block_dim: (BLOCK, 1, 1),
6208        shared_mem_bytes: 0, // Statically allocated in PTX
6209    };
6210
6211    unsafe {
6212        stream
6213            .launch_builder(&f)
6214            .arg(a.inner())
6215            .arg(partials.inner_mut())
6216            .arg(&n_u32)
6217            .launch(cfg)?;
6218    }
6219
6220    // Pass 2: reduce partial sums.
6221    if num_blocks <= 1 {
6222        return Ok(partials);
6223    }
6224
6225    // For small number of blocks, reduce on CPU (cheaper than another kernel launch).
6226    if num_blocks <= 256 {
6227        let host_partials = gpu_to_cpu(&partials, device)?;
6228        let total: f32 = host_partials.iter().sum();
6229        return cpu_to_gpu(&[total], device);
6230    }
6231
6232    // For many blocks, recurse with another kernel launch.
6233    gpu_reduce_sum(&partials, device)
6234}
6235
6236/// Stub -- always returns [`GpuError::NoCudaFeature`].
6237#[cfg(not(feature = "cuda"))]
6238pub fn gpu_reduce_sum(
6239    _a: &CudaBuffer<f32>,
6240    _device: &GpuDevice,
6241) -> GpuResult<CudaBuffer<f32>> {
6242    Err(GpuError::NoCudaFeature)
6243}
6244
6245///   `output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]`
6246///
6247/// where `outer_idx = i / inner_size`, `inner_idx = i % inner_size`.
6248#[cfg(feature = "cuda")]
6249pub fn gpu_sum_axis(
6250    a: &CudaBuffer<f32>,
6251    outer: usize,
6252    axis_size: usize,
6253    inner: usize,
6254    device: &GpuDevice,
6255) -> GpuResult<CudaBuffer<f32>> {
6256    use cudarc::driver::PushKernelArg;
6257
6258    validate_unary(a, device)?;
6259
6260    let total_output = outer * inner;
6261    let ctx = device.context();
6262    let stream = device.stream();
6263
6264    let f = match crate::module_cache::get_or_compile(
6265        ctx,
6266        SUM_AXIS_PTX,
6267        "sum_axis_kernel",
6268        device.ordinal() as u32,
6269    ) {
6270        Ok(f) => f,
6271        Err(_) => {
6272            // CPU fallback
6273            let host = gpu_to_cpu(a, device)?;
6274            let mut result = vec![0.0f32; total_output];
6275            for (i, out) in result.iter_mut().enumerate() {
6276                let outer_idx = i / inner;
6277                let inner_idx = i % inner;
6278                let mut sum = 0.0f32;
6279                for k in 0..axis_size {
6280                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
6281                }
6282                *out = sum;
6283            }
6284            return cpu_to_gpu(&result, device);
6285        }
6286    };
6287
6288    let mut out = alloc_zeros_f32(total_output, device)?;
6289    let cfg = launch_cfg(total_output)?;
6290    let outer_u32 = outer as u32;
6291    let axis_size_u32 = axis_size as u32;
6292    let inner_u32 = inner as u32;
6293    let total_u32 = total_output as u32;
6294
6295    unsafe {
6296        stream
6297            .launch_builder(&f)
6298            .arg(a.inner())
6299            .arg(out.inner_mut())
6300            .arg(&outer_u32)
6301            .arg(&axis_size_u32)
6302            .arg(&inner_u32)
6303            .arg(&total_u32)
6304            .launch(cfg)?;
6305    }
6306
6307    Ok(out)
6308}
6309
6310// ---------------------------------------------------------------------------
6311// Public API -- Strided split
6312// ---------------------------------------------------------------------------
6313
6314/// Extract a sub-tensor along one axis entirely on GPU.
6315///
6316/// Given an input buffer representing a tensor with `total_along_axis` elements
6317/// along the split axis, extracts the slice `[split_offset .. split_offset + split_size]`
6318/// along that axis.
6319///
6320/// - `inner_size` = product of dimensions after the split axis.
6321/// - `n` = total number of output elements (outer * split_size * inner_size).
6322///
6323/// # Errors
6324///
6325/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
6326/// - [`GpuError::Driver`] on CUDA runtime errors.
6327#[cfg(feature = "cuda")]
6328pub fn gpu_strided_split(
6329    input: &CudaBuffer<f32>,
6330    total_along_axis: usize,
6331    split_offset: usize,
6332    split_size: usize,
6333    inner_size: usize,
6334    n: usize,
6335    device: &GpuDevice,
6336) -> GpuResult<CudaBuffer<f32>> {
6337    use cudarc::driver::PushKernelArg;
6338
6339    validate_unary(input, device)?;
6340
6341    let ctx = device.context();
6342    let stream = device.stream();
6343
6344    let f = match crate::module_cache::get_or_compile(
6345        ctx,
6346        STRIDED_SPLIT_PTX,
6347        "strided_split_kernel",
6348        device.ordinal() as u32,
6349    ) {
6350        Ok(f) => f,
6351        Err(_) => {
6352            // CPU fallback
6353            let host = gpu_to_cpu(input, device)?;
6354            let outer = n / (split_size * inner_size);
6355            let mut result = vec![0.0f32; n];
6356            for (i, out) in result.iter_mut().enumerate() {
6357                let outer_idx = i / (split_size * inner_size);
6358                let within = i % (split_size * inner_size);
6359                let src_idx =
6360                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
6361                *out = host[src_idx];
6362            }
6363            let _ = outer;
6364            return cpu_to_gpu(&result, device);
6365        }
6366    };
6367
6368    let mut out = alloc_zeros_f32(n, device)?;
6369    let cfg = launch_cfg(n)?;
6370    let total_ax_u32 = total_along_axis as u32;
6371    let offset_u32 = split_offset as u32;
6372    let split_sz_u32 = split_size as u32;
6373    let inner_u32 = inner_size as u32;
6374    let n_u32 = n as u32;
6375
6376    unsafe {
6377        stream
6378            .launch_builder(&f)
6379            .arg(input.inner())
6380            .arg(out.inner_mut())
6381            .arg(&total_ax_u32)
6382            .arg(&offset_u32)
6383            .arg(&split_sz_u32)
6384            .arg(&inner_u32)
6385            .arg(&n_u32)
6386            .launch(cfg)?;
6387    }
6388
6389    Ok(out)
6390}
6391
6392// ---------------------------------------------------------------------------
6393// Public API -- Strided cat
6394// ---------------------------------------------------------------------------
6395
6396/// Write a sub-tensor into a larger output buffer at an offset along one axis,
6397/// entirely on GPU.
6398///
6399/// Given an input buffer representing a chunk with `part_size` elements along
6400/// the cat axis, writes it into `output` at position `cat_offset` along that axis.
6401///
6402/// - `inner_size` = product of dimensions after the cat axis.
6403/// - `n` = total number of input elements (outer * part_size * inner_size).
6404///
6405/// # Safety
6406///
6407/// `output` must be large enough to hold the written region. The caller is
6408/// responsible for ensuring non-overlapping writes when multiple chunks are
6409/// written into the same output buffer.
6410///
6411/// # Errors
6412///
6413/// - [`GpuError::DeviceMismatch`] if buffers and `device` are on different devices.
6414/// - [`GpuError::Driver`] on CUDA runtime errors.
6415#[cfg(feature = "cuda")]
6416#[allow(clippy::too_many_arguments)]
6417pub fn gpu_strided_cat(
6418    input: &CudaBuffer<f32>,
6419    output: &mut CudaBuffer<f32>,
6420    total_along_axis: usize,
6421    cat_offset: usize,
6422    part_size: usize,
6423    inner_size: usize,
6424    n: usize,
6425    device: &GpuDevice,
6426) -> GpuResult<()> {
6427    use cudarc::driver::PushKernelArg;
6428
6429    validate_unary(input, device)?;
6430
6431    let ctx = device.context();
6432    let stream = device.stream();
6433
6434    let f = match crate::module_cache::get_or_compile(
6435        ctx,
6436        STRIDED_CAT_PTX,
6437        "strided_cat_kernel",
6438        device.ordinal() as u32,
6439    ) {
6440        Ok(f) => f,
6441        Err(_) => {
6442            // CPU fallback
6443            let host_in = gpu_to_cpu(input, device)?;
6444            let mut host_out = gpu_to_cpu(output, device)?;
6445            for (i, &val) in host_in.iter().enumerate().take(n) {
6446                let outer_idx = i / (part_size * inner_size);
6447                let within = i % (part_size * inner_size);
6448                let dst_idx =
6449                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
6450                host_out[dst_idx] = val;
6451            }
6452            *output = cpu_to_gpu(&host_out, device)?;
6453            return Ok(());
6454        }
6455    };
6456
6457    let cfg = launch_cfg(n)?;
6458    let total_ax_u32 = total_along_axis as u32;
6459    let offset_u32 = cat_offset as u32;
6460    let part_sz_u32 = part_size as u32;
6461    let inner_u32 = inner_size as u32;
6462    let n_u32 = n as u32;
6463
6464    unsafe {
6465        stream
6466            .launch_builder(&f)
6467            .arg(input.inner())
6468            .arg(output.inner_mut())
6469            .arg(&total_ax_u32)
6470            .arg(&offset_u32)
6471            .arg(&part_sz_u32)
6472            .arg(&inner_u32)
6473            .arg(&n_u32)
6474            .launch(cfg)?;
6475    }
6476
6477    Ok(())
6478}
6479
6480/// Scalar multiply: `out[i] = a[i] * scalar`.
6481///
6482/// Multiplies every element by a constant float value on the GPU.
6483///
6484/// # Errors
6485///
6486/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
6487/// - [`GpuError::Driver`] on CUDA runtime errors.
6488#[cfg(feature = "cuda")]
6489pub fn gpu_scale(
6490    a: &CudaBuffer<f32>,
6491    scalar: f32,
6492    device: &GpuDevice,
6493) -> GpuResult<CudaBuffer<f32>> {
6494    use cudarc::driver::PushKernelArg;
6495
6496    validate_unary(a, device)?;
6497
6498    let n = a.len();
6499    let ctx = device.context();
6500    let stream = device.stream();
6501
6502    let f = match crate::module_cache::get_or_compile(
6503        ctx,
6504        SCALE_PTX,
6505        "scale_kernel",
6506        device.ordinal() as u32,
6507    ) {
6508        Ok(f) => f,
6509        Err(_) => {
6510            // CPU fallback
6511            let host = gpu_to_cpu(a, device)?;
6512            let result: Vec<f32> = host.iter().map(|&x| x * scalar).collect();
6513            return cpu_to_gpu(&result, device);
6514        }
6515    };
6516
6517    let mut out = alloc_zeros_f32(n, device)?;
6518    let cfg = launch_cfg(n)?;
6519    let n_u32 = n as u32;
6520
6521    unsafe {
6522        stream
6523            .launch_builder(&f)
6524            .arg(a.inner())
6525            .arg(out.inner_mut())
6526            .arg(&scalar)
6527            .arg(&n_u32)
6528            .launch(cfg)?;
6529    }
6530
6531    Ok(out)
6532}
6533
6534// ---------------------------------------------------------------------------
6535// Public API -- softmax
6536// ---------------------------------------------------------------------------
6537
6538/// Row-wise softmax on GPU: one thread block per row, shared-memory reduction.
6539///
6540/// `rows` = product of all dims except the last. `cols` = last dim size.
6541#[cfg(feature = "cuda")]
6542pub fn gpu_softmax(
6543    input: &CudaBuffer<f32>,
6544    rows: usize,
6545    cols: usize,
6546    device: &GpuDevice,
6547) -> GpuResult<CudaBuffer<f32>> {
6548    use cudarc::driver::PushKernelArg;
6549
6550    validate_unary(input, device)?;
6551
6552    let ctx = device.context();
6553    let stream = device.stream();
6554
6555    let f = match crate::module_cache::get_or_compile(
6556        ctx,
6557        SOFTMAX_PTX,
6558        "softmax_kernel",
6559        device.ordinal() as u32,
6560    ) {
6561        Ok(f) => f,
6562        Err(_) => {
6563            // CPU fallback.
6564            let host = gpu_to_cpu(input, device)?;
6565            let mut out = vec![0.0f32; host.len()];
6566            for r in 0..rows {
6567                let base = r * cols;
6568                let mut max_v = f32::NEG_INFINITY;
6569                for c in 0..cols {
6570                    max_v = max_v.max(host[base + c]);
6571                }
6572                let mut sum = 0.0f32;
6573                for c in 0..cols {
6574                    let e = (host[base + c] - max_v).exp();
6575                    out[base + c] = e;
6576                    sum += e;
6577                }
6578                let inv = 1.0 / sum;
6579                for c in 0..cols {
6580                    out[base + c] *= inv;
6581                }
6582            }
6583            return cpu_to_gpu(&out, device);
6584        }
6585    };
6586
6587    let mut out = alloc_zeros_f32(rows * cols, device)?;
6588    let rows_u32 = rows as u32;
6589    let cols_u32 = cols as u32;
6590
6591    // One block per row, 256 threads per block.
6592    let cfg = LaunchConfig {
6593        grid_dim: ((rows as u32).max(1), 1, 1),
6594        block_dim: (256, 1, 1),
6595        shared_mem_bytes: 256 * 4, // sdata[256] f32
6596    };
6597
6598    unsafe {
6599        stream
6600            .launch_builder(&f)
6601            .arg(input.inner())
6602            .arg(out.inner_mut())
6603            .arg(&rows_u32)
6604            .arg(&cols_u32)
6605            .launch(cfg)?;
6606    }
6607
6608    Ok(out)
6609}
6610
6611// ---------------------------------------------------------------------------
6612// Public API -- dropout
6613// ---------------------------------------------------------------------------
6614
6615/// Inverted dropout on GPU: `out[i] = input[i] * scale` or `0` with probability `p`.
6616///
6617/// `threshold` = `(p * u32::MAX as f64) as u32` — the RNG cutoff.
6618/// `scale` = `1.0 / (1.0 - p)`.
6619/// `seed` = random seed for the RNG.
6620///
6621/// **Known limitation**: This kernel uses a simple per-element hash
6622/// (`tid * 2654435761 ^ seed` with xorshift mixing), not the full
6623/// Philox 4x32-10 counter-based RNG that PyTorch uses. A proper Philox
6624/// dropout kernel would generate the mask via `philox_uniform_kernel`
6625/// and then threshold — producing higher-quality randomness and exact
6626/// reproducibility across CPU/GPU. The current hash is sufficient for
6627/// training but should be upgraded for research requiring strict
6628/// statistical properties.
6629#[cfg(feature = "cuda")]
6630pub fn gpu_dropout(
6631    input: &CudaBuffer<f32>,
6632    threshold: u32,
6633    scale: f32,
6634    seed: u32,
6635    device: &GpuDevice,
6636) -> GpuResult<CudaBuffer<f32>> {
6637    use cudarc::driver::PushKernelArg;
6638
6639    validate_unary(input, device)?;
6640
6641    let n = input.len();
6642    let ctx = device.context();
6643    let stream = device.stream();
6644
6645    let f = match crate::module_cache::get_or_compile(
6646        ctx,
6647        DROPOUT_PTX,
6648        "dropout_kernel",
6649        device.ordinal() as u32,
6650    ) {
6651        Ok(f) => f,
6652        Err(_) => {
6653            // CPU fallback.
6654            let host = gpu_to_cpu(input, device)?;
6655            // Stateless per-element hash matching the GPU kernel: each element
6656            // independently computes its own pseudorandom value from (tid, seed)
6657            // with no state carried between elements.
6658            let result: Vec<f32> = host
6659                .iter()
6660                .enumerate()
6661                .map(|(i, &x)| {
6662                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
6663                    r ^= r << 13;
6664                    r ^= r >> 17;
6665                    r ^= r << 5;
6666                    if r < threshold { 0.0 } else { x * scale }
6667                })
6668                .collect();
6669            return cpu_to_gpu(&result, device);
6670        }
6671    };
6672
6673    let mut out = alloc_zeros_f32(n, device)?;
6674    let cfg = launch_cfg(n)?;
6675    let n_u32 = n as u32;
6676
6677    unsafe {
6678        stream
6679            .launch_builder(&f)
6680            .arg(input.inner())
6681            .arg(out.inner_mut())
6682            .arg(&n_u32)
6683            .arg(&threshold)
6684            .arg(&scale)
6685            .arg(&seed)
6686            .launch(cfg)?;
6687    }
6688
6689    Ok(out)
6690}
6691
6692// ---------------------------------------------------------------------------
6693// Public API -- 2D transpose
6694// ---------------------------------------------------------------------------
6695
6696/// 2D matrix transpose on GPU: `[M, N]` -> `[N, M]`.
6697#[cfg(feature = "cuda")]
6698pub fn gpu_transpose_2d(
6699    input: &CudaBuffer<f32>,
6700    m: usize,
6701    n: usize,
6702    device: &GpuDevice,
6703) -> GpuResult<CudaBuffer<f32>> {
6704    use cudarc::driver::PushKernelArg;
6705
6706    validate_unary(input, device)?;
6707
6708    let total = m * n;
6709    let ctx = device.context();
6710    let stream = device.stream();
6711
6712    let f = match crate::module_cache::get_or_compile(
6713        ctx,
6714        TRANSPOSE_2D_PTX,
6715        "transpose_2d_kernel",
6716        device.ordinal() as u32,
6717    ) {
6718        Ok(f) => f,
6719        Err(_) => {
6720            // CPU fallback.
6721            let host = gpu_to_cpu(input, device)?;
6722            let mut out = vec![0.0f32; total];
6723            for i in 0..m {
6724                for j in 0..n {
6725                    out[j * m + i] = host[i * n + j];
6726                }
6727            }
6728            return cpu_to_gpu(&out, device);
6729        }
6730    };
6731
6732    let mut out = alloc_zeros_f32(total, device)?;
6733    let cfg = launch_cfg(total)?;
6734    let m_u32 = m as u32;
6735    let n_u32 = n as u32;
6736    let total_u32 = total as u32;
6737
6738    unsafe {
6739        stream
6740            .launch_builder(&f)
6741            .arg(input.inner())
6742            .arg(out.inner_mut())
6743            .arg(&m_u32)
6744            .arg(&n_u32)
6745            .arg(&total_u32)
6746            .launch(cfg)?;
6747    }
6748
6749    Ok(out)
6750}
6751
6752// ---------------------------------------------------------------------------
6753// Public API -- 4D permute (0,2,1,3)
6754// ---------------------------------------------------------------------------
6755
6756/// Permute a 4D tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]` on GPU.
6757/// Used for attention head reshaping: `[B, S, H, D_h]` -> `[B, H, S, D_h]`.
6758#[cfg(feature = "cuda")]
6759pub fn gpu_permute_0213(
6760    input: &CudaBuffer<f32>,
6761    d0: usize,
6762    d1: usize,
6763    d2: usize,
6764    d3: usize,
6765    device: &GpuDevice,
6766) -> GpuResult<CudaBuffer<f32>> {
6767    use cudarc::driver::PushKernelArg;
6768
6769    validate_unary(input, device)?;
6770
6771    let total = d0 * d1 * d2 * d3;
6772    let ctx = device.context();
6773    let stream = device.stream();
6774
6775    let f = match crate::module_cache::get_or_compile(
6776        ctx,
6777        PERMUTE_0213_PTX,
6778        "permute_0213_kernel",
6779        device.ordinal() as u32,
6780    ) {
6781        Ok(f) => f,
6782        Err(_) => {
6783            // CPU fallback.
6784            let host = gpu_to_cpu(input, device)?;
6785            let mut out = vec![0.0f32; total];
6786            for i0 in 0..d0 {
6787                for i1 in 0..d1 {
6788                    for i2 in 0..d2 {
6789                        for i3 in 0..d3 {
6790                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
6791                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
6792                            out[out_idx] = host[in_idx];
6793                        }
6794                    }
6795                }
6796            }
6797            return cpu_to_gpu(&out, device);
6798        }
6799    };
6800
6801    let mut out = alloc_zeros_f32(total, device)?;
6802    let cfg = launch_cfg(total)?;
6803    let d0_u32 = d0 as u32;
6804    let d1_u32 = d1 as u32;
6805    let d2_u32 = d2 as u32;
6806    let d3_u32 = d3 as u32;
6807    let total_u32 = total as u32;
6808
6809    unsafe {
6810        stream
6811            .launch_builder(&f)
6812            .arg(input.inner())
6813            .arg(out.inner_mut())
6814            .arg(&d0_u32)
6815            .arg(&d1_u32)
6816            .arg(&d2_u32)
6817            .arg(&d3_u32)
6818            .arg(&total_u32)
6819            .launch(cfg)?;
6820    }
6821
6822    Ok(out)
6823}
6824
6825// ---------------------------------------------------------------------------
6826// Public API -- Small matmul (bypasses cuBLAS JIT)
6827// ---------------------------------------------------------------------------
6828
6829/// Small matrix multiply using our own PTX kernel. Avoids cuBLAS JIT
6830/// compilation overhead for tiny matrices where JIT cost > compute cost.
6831///
6832/// `a`: `[M, K]`, `b`: `[K, N]` → `c`: `[M, N]`.
6833#[cfg(feature = "cuda")]
6834pub fn gpu_small_matmul(
6835    a: &CudaBuffer<f32>,
6836    b: &CudaBuffer<f32>,
6837    m: usize,
6838    k: usize,
6839    n: usize,
6840    device: &GpuDevice,
6841) -> GpuResult<CudaBuffer<f32>> {
6842    use cudarc::driver::PushKernelArg;
6843
6844    let total = m * n;
6845    let ctx = device.context();
6846    let stream = device.stream();
6847
6848    let f = match crate::module_cache::get_or_compile(
6849        ctx,
6850        SMALL_MATMUL_PTX,
6851        "small_matmul_kernel",
6852        device.ordinal() as u32,
6853    ) {
6854        Ok(f) => f,
6855        Err(_) => {
6856            // Fall back to cuBLAS if our kernel can't compile.
6857            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
6858        }
6859    };
6860
6861    let mut c = alloc_zeros_f32(total, device)?;
6862    let cfg = launch_cfg(total)?;
6863    let m_u32 = m as u32;
6864    let k_u32 = k as u32;
6865    let n_u32 = n as u32;
6866    let total_u32 = total as u32;
6867
6868    unsafe {
6869        stream
6870            .launch_builder(&f)
6871            .arg(a.inner())
6872            .arg(b.inner())
6873            .arg(c.inner_mut())
6874            .arg(&m_u32)
6875            .arg(&k_u32)
6876            .arg(&n_u32)
6877            .arg(&total_u32)
6878            .launch(cfg)?;
6879    }
6880
6881    Ok(c)
6882}
6883
6884/// Small batched matmul: C[i] = A[i] @ B[i] for i in 0..batch.
6885/// Uses the small_matmul_kernel by reshaping the problem: treat it as a single
6886/// large matmul of [batch*M, K] @ [K, N] — but that doesn't work because B is
6887/// batched. Instead, we use a modified approach: thread `idx` computes element
6888/// (batch_i, row, col) where batch_i = idx / (M*N).
6889///
6890/// For simplicity and correctness, we fall back to cpu_bmm for now when
6891/// cuBLAS fails, but route through gpu_small_matmul for the single-matrix case.
6892#[cfg(feature = "cuda")]
6893pub fn gpu_small_bmm(
6894    a: &CudaBuffer<f32>,
6895    b: &CudaBuffer<f32>,
6896    batch: usize,
6897    m: usize,
6898    k: usize,
6899    n: usize,
6900    device: &GpuDevice,
6901) -> GpuResult<CudaBuffer<f32>> {
6902    // For batch=1, just use the single matmul kernel.
6903    if batch == 1 {
6904        return gpu_small_matmul(a, b, m, k, n, device);
6905    }
6906    // For batched case, fall back to cuBLAS (the batched PTX kernel is complex).
6907    // The main win is from the single-matrix decode case (batch=1 for attention scores).
6908    crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device)
6909}
6910
6911// ---------------------------------------------------------------------------
6912// Public API -- Embedding lookup (GPU-native)
6913// ---------------------------------------------------------------------------
6914
6915/// GPU embedding lookup: reads token ID from `idx` (single f32 on GPU),
6916/// gathers row from `weight` `[V, D]`, writes to `out` `[D]`.
6917/// Entire operation stays on GPU — no CPU involvement.
6918#[cfg(feature = "cuda")]
6919pub fn gpu_embed_lookup(
6920    idx: &CudaBuffer<f32>,
6921    weight: &CudaBuffer<f32>,
6922    d: usize,
6923    device: &GpuDevice,
6924) -> GpuResult<CudaBuffer<f32>> {
6925    use cudarc::driver::PushKernelArg;
6926
6927    let ctx = device.context();
6928    let stream = device.stream();
6929
6930    let f = match crate::module_cache::get_or_compile(
6931        ctx,
6932        EMBED_LOOKUP_PTX,
6933        "embed_lookup_kernel",
6934        device.ordinal() as u32,
6935    ) {
6936        Ok(f) => f,
6937        Err(_) => {
6938            // CPU fallback.
6939            let idx_host = gpu_to_cpu(idx, device)?;
6940            let weight_host = gpu_to_cpu(weight, device)?;
6941            let row = idx_host[0] as usize;
6942            let start = row * d;
6943            let out = weight_host[start..start + d].to_vec();
6944            return cpu_to_gpu(&out, device);
6945        }
6946    };
6947
6948    let mut out = alloc_zeros_f32(d, device)?;
6949    let cfg = launch_cfg(d)?;
6950    let d_u32 = d as u32;
6951
6952    unsafe {
6953        stream
6954            .launch_builder(&f)
6955            .arg(idx.inner())
6956            .arg(weight.inner())
6957            .arg(out.inner_mut())
6958            .arg(&d_u32)
6959            .launch(cfg)?;
6960    }
6961
6962    Ok(out)
6963}
6964
6965// ---------------------------------------------------------------------------
6966// Public API -- Slice write (for KV cache)
6967// ---------------------------------------------------------------------------
6968
6969/// Write `src` of shape `[N, D]` into row `pos` of `dst` of shape `[N, max_len, D]`.
6970/// This is an in-place GPU operation — `dst` is modified.
6971#[cfg(feature = "cuda")]
6972pub fn gpu_slice_write(
6973    src: &CudaBuffer<f32>,
6974    dst: &mut CudaBuffer<f32>,
6975    n_batch: usize,
6976    d: usize,
6977    max_len: usize,
6978    pos: usize,
6979    device: &GpuDevice,
6980) -> GpuResult<()> {
6981    use cudarc::driver::PushKernelArg;
6982
6983    let total = n_batch * d;
6984    let ctx = device.context();
6985    let stream = device.stream();
6986
6987    let f = match crate::module_cache::get_or_compile(
6988        ctx,
6989        SLICE_WRITE_PTX,
6990        "slice_write_kernel",
6991        device.ordinal() as u32,
6992    ) {
6993        Ok(f) => f,
6994        Err(_) => {
6995            // CPU fallback.
6996            let src_host = gpu_to_cpu(src, device)?;
6997            let mut dst_host = gpu_to_cpu(dst, device)?;
6998            for b in 0..n_batch {
6999                for di in 0..d {
7000                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
7001                }
7002            }
7003            let new_dst = cpu_to_gpu(&dst_host, device)?;
7004            *dst = new_dst;
7005            return Ok(());
7006        }
7007    };
7008
7009    let cfg = launch_cfg(total)?;
7010    let n_u32 = total as u32;
7011    let d_u32 = d as u32;
7012    let max_len_u32 = max_len as u32;
7013    let pos_u32 = pos as u32;
7014
7015    unsafe {
7016        stream
7017            .launch_builder(&f)
7018            .arg(src.inner())
7019            .arg(dst.inner_mut())
7020            .arg(&n_u32)
7021            .arg(&d_u32)
7022            .arg(&max_len_u32)
7023            .arg(&pos_u32)
7024            .launch(cfg)?;
7025    }
7026
7027    Ok(())
7028}
7029
7030// ---------------------------------------------------------------------------
7031// Public API -- Slice read (for KV cache)
7032// ---------------------------------------------------------------------------
7033
7034/// Read first `len` rows from each batch of `[N, max_len, D]` → `[N, len, D]`.
7035#[cfg(feature = "cuda")]
7036pub fn gpu_slice_read(
7037    src: &CudaBuffer<f32>,
7038    n_batch: usize,
7039    d: usize,
7040    len: usize,
7041    max_len: usize,
7042    device: &GpuDevice,
7043) -> GpuResult<CudaBuffer<f32>> {
7044    use cudarc::driver::PushKernelArg;
7045
7046    let total = n_batch * len * d;
7047    let ctx = device.context();
7048    let stream = device.stream();
7049
7050    let f = match crate::module_cache::get_or_compile(
7051        ctx,
7052        SLICE_READ_PTX,
7053        "slice_read_kernel",
7054        device.ordinal() as u32,
7055    ) {
7056        Ok(f) => f,
7057        Err(_) => {
7058            let host = gpu_to_cpu(src, device)?;
7059            let mut out = vec![0.0f32; total];
7060            for b in 0..n_batch {
7061                for r in 0..len {
7062                    for di in 0..d {
7063                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
7064                    }
7065                }
7066            }
7067            return cpu_to_gpu(&out, device);
7068        }
7069    };
7070
7071    let mut out = alloc_zeros_f32(total, device)?;
7072    let cfg = launch_cfg(total)?;
7073    let total_u32 = total as u32;
7074    let d_u32 = d as u32;
7075    let len_u32 = len as u32;
7076    let max_len_u32 = max_len as u32;
7077
7078    unsafe {
7079        stream
7080            .launch_builder(&f)
7081            .arg(src.inner())
7082            .arg(out.inner_mut())
7083            .arg(&total_u32)
7084            .arg(&d_u32)
7085            .arg(&len_u32)
7086            .arg(&max_len_u32)
7087            .launch(cfg)?;
7088    }
7089
7090    Ok(out)
7091}
7092
7093// ---------------------------------------------------------------------------
7094// Public API -- GELU
7095// ---------------------------------------------------------------------------
7096
7097/// Elementwise GELU activation on GPU: `gelu(x) = x * sigmoid(1.702 * x)`.
7098#[cfg(feature = "cuda")]
7099pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7100    validate_unary(input, device)?;
7101    if let Some(out) = try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
7102        return Ok(out);
7103    }
7104    cpu_fallback_unary(input, device, |x| {
7105        let s = 1.0 / (1.0 + (-1.702 * x).exp());
7106        x * s
7107    })
7108}
7109
7110/// Elementwise GELU activation on GPU using the tanh approximation:
7111/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
7112///
7113/// Matches PyTorch `nn.GELU(approximate="tanh")`.
7114#[cfg(feature = "cuda")]
7115pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7116    validate_unary(input, device)?;
7117    if let Some(out) = try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
7118        return Ok(out);
7119    }
7120    cpu_fallback_unary(input, device, |x| {
7121        let sqrt_2_over_pi: f32 = 0.7978845608;
7122        let c: f32 = 0.044715;
7123        let inner = sqrt_2_over_pi * (x + c * x * x * x);
7124        0.5 * x * (1.0 + inner.tanh())
7125    })
7126}
7127
7128/// Elementwise GELU activation on GPU using exact erf:
7129/// `gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2)))`.
7130///
7131/// Matches PyTorch `nn.GELU(approximate="none")` (the default).
7132#[cfg(feature = "cuda")]
7133pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7134    validate_unary(input, device)?;
7135    if let Some(out) = try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
7136        return Ok(out);
7137    }
7138    cpu_fallback_unary(input, device, |x| {
7139        // Abramowitz & Stegun 7.1.26 erf approximation (matches PTX kernel)
7140        let z = x * std::f32::consts::FRAC_1_SQRT_2;
7141        let az = z.abs();
7142        let t = 1.0 / (1.0 + 0.3275911 * az);
7143        let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
7144        let erf_abs = 1.0 - poly * (-az * az).exp();
7145        let erf_val = if z < 0.0 { -erf_abs } else { erf_abs };
7146        x * 0.5 * (1.0 + erf_val)
7147    })
7148}
7149
7150/// GELU backward for the tanh approximation mode.
7151/// Let `u = sqrt(2/π) * (x + 0.044715 * x³)`, `t = tanh(u)`.
7152/// `d/dx = 0.5 * (1 + t) + 0.5 * x * (1 - t²) * sqrt(2/π) * (1 + 3*0.044715*x²)`
7153#[cfg(feature = "cuda")]
7154pub fn gpu_gelu_backward_tanh(
7155    grad: &CudaBuffer<f32>,
7156    input: &CudaBuffer<f32>,
7157    device: &GpuDevice,
7158) -> GpuResult<CudaBuffer<f32>> {
7159    validate_binary(grad, input, device)?;
7160    if let Some(out) = try_launch_binary(
7161        grad,
7162        input,
7163        device,
7164        GELU_BACKWARD_TANH_PTX,
7165        "gelu_backward_tanh_kernel",
7166    )? {
7167        return Ok(out);
7168    }
7169    // CPU fallback
7170    let grad_host = gpu_to_cpu(grad, device)?;
7171    let input_host = gpu_to_cpu(input, device)?;
7172    let result: Vec<f32> = grad_host
7173        .iter()
7174        .zip(input_host.iter())
7175        .map(|(&g, &x)| {
7176            let sqrt_2_over_pi: f32 = 0.7978845608;
7177            let c: f32 = 0.044715;
7178            let c3: f32 = 0.134145;
7179            let u = sqrt_2_over_pi * (x + c * x * x * x);
7180            let t = u.tanh();
7181            let dt = 1.0 - t * t;
7182            let d_inner = sqrt_2_over_pi * (1.0 + c3 * x * x);
7183            g * (0.5 * (1.0 + t) + 0.5 * x * dt * d_inner)
7184        })
7185        .collect();
7186    cpu_to_gpu(&result, device)
7187}
7188
7189// ---------------------------------------------------------------------------
7190// Public API -- elementwise transcendentals & math ops
7191// ---------------------------------------------------------------------------
7192
7193/// Elementwise division: `out[i] = a[i] / b[i]`.
7194#[cfg(feature = "cuda")]
7195pub fn gpu_div(
7196    a: &CudaBuffer<f32>,
7197    b: &CudaBuffer<f32>,
7198    device: &GpuDevice,
7199) -> GpuResult<CudaBuffer<f32>> {
7200    validate_binary(a, b, device)?;
7201
7202    if let Some(out) = try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
7203        return Ok(out);
7204    }
7205
7206    // CPU fallback
7207    let a_host = gpu_to_cpu(a, device)?;
7208    let b_host = gpu_to_cpu(b, device)?;
7209    let result: Vec<f32> = a_host
7210        .iter()
7211        .zip(b_host.iter())
7212        .map(|(&x, &y)| x / y)
7213        .collect();
7214    cpu_to_gpu(&result, device)
7215}
7216
7217/// Elementwise exponential: `out[i] = exp(a[i])`.
7218#[cfg(feature = "cuda")]
7219pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7220    validate_unary(a, device)?;
7221    if let Some(out) = try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
7222        return Ok(out);
7223    }
7224    cpu_fallback_unary(a, device, |x| x.exp())
7225}
7226
7227/// Elementwise natural log: `out[i] = ln(a[i])`.
7228#[cfg(feature = "cuda")]
7229pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7230    validate_unary(a, device)?;
7231    if let Some(out) = try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
7232        return Ok(out);
7233    }
7234    cpu_fallback_unary(a, device, |x| x.ln())
7235}
7236
7237/// Elementwise square root: `out[i] = sqrt(a[i])`.
7238#[cfg(feature = "cuda")]
7239pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7240    validate_unary(a, device)?;
7241    if let Some(out) = try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
7242        return Ok(out);
7243    }
7244    cpu_fallback_unary(a, device, |x| x.sqrt())
7245}
7246
7247/// Elementwise power: `out[i] = a[i] ^ exponent`.
7248#[cfg(feature = "cuda")]
7249pub fn gpu_pow(
7250    a: &CudaBuffer<f32>,
7251    exponent: f32,
7252    device: &GpuDevice,
7253) -> GpuResult<CudaBuffer<f32>> {
7254    use cudarc::driver::PushKernelArg;
7255
7256    validate_unary(a, device)?;
7257
7258    let n = a.len();
7259    let ctx = device.context();
7260    let stream = device.stream();
7261
7262    let f = match crate::module_cache::get_or_compile(
7263        ctx,
7264        POW_PTX,
7265        "pow_kernel",
7266        device.ordinal() as u32,
7267    ) {
7268        Ok(f) => f,
7269        Err(_) => {
7270            let host = gpu_to_cpu(a, device)?;
7271            let result: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
7272            return cpu_to_gpu(&result, device);
7273        }
7274    };
7275
7276    let mut out = alloc_zeros_f32(n, device)?;
7277    let cfg = launch_cfg(n)?;
7278    let n_u32 = n as u32;
7279
7280    unsafe {
7281        stream
7282            .launch_builder(&f)
7283            .arg(a.inner())
7284            .arg(out.inner_mut())
7285            .arg(&exponent)
7286            .arg(&n_u32)
7287            .launch(cfg)?;
7288    }
7289
7290    Ok(out)
7291}
7292
7293/// Elementwise absolute value: `out[i] = |a[i]|`.
7294#[cfg(feature = "cuda")]
7295pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7296    validate_unary(a, device)?;
7297    if let Some(out) = try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
7298        return Ok(out);
7299    }
7300    cpu_fallback_unary(a, device, |x| x.abs())
7301}
7302
7303/// Elementwise sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
7304#[cfg(feature = "cuda")]
7305pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7306    validate_unary(a, device)?;
7307    if let Some(out) = try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
7308        return Ok(out);
7309    }
7310    cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp()))
7311}
7312
7313/// Elementwise tanh: `out[i] = tanh(a[i])`.
7314#[cfg(feature = "cuda")]
7315pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
7316    validate_unary(a, device)?;
7317    if let Some(out) = try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
7318        return Ok(out);
7319    }
7320    cpu_fallback_unary(a, device, |x| x.tanh())
7321}
7322
7323// ---------------------------------------------------------------------------
7324// Public API -- fused Adam optimizer step
7325// ---------------------------------------------------------------------------
7326
7327/// Fused Adam optimizer step: updates param, exp_avg, and exp_avg_sq in-place
7328/// in a single kernel launch.
7329///
7330/// All four buffers must have the same length `n`. `param`, `exp_avg`, and
7331/// `exp_avg_sq` are modified in-place. `grad` is read-only.
7332#[cfg(feature = "cuda")]
7333#[allow(clippy::too_many_arguments)]
7334pub fn gpu_fused_adam(
7335    param: &mut CudaBuffer<f32>,
7336    grad: &CudaBuffer<f32>,
7337    exp_avg: &mut CudaBuffer<f32>,
7338    exp_avg_sq: &mut CudaBuffer<f32>,
7339    beta1: f32,
7340    beta2: f32,
7341    lr: f32,
7342    eps: f32,
7343    bc1: f32,
7344    bc2: f32,
7345    weight_decay: f32,
7346    device: &GpuDevice,
7347) -> GpuResult<()> {
7348    use cudarc::driver::PushKernelArg;
7349
7350    let n = param.len();
7351    if grad.len() != n || exp_avg.len() != n || exp_avg_sq.len() != n {
7352        return Err(GpuError::LengthMismatch {
7353            a: n,
7354            b: grad.len(),
7355        });
7356    }
7357
7358    let ctx = device.context();
7359    let stream = device.stream();
7360
7361    let f = match crate::module_cache::get_or_compile(
7362        ctx,
7363        FUSED_ADAM_PTX,
7364        "fused_adam_kernel",
7365        device.ordinal() as u32,
7366    ) {
7367        Ok(f) => f,
7368        Err(_) => {
7369            // CPU fallback: download, compute, upload.
7370            let mut p_host = gpu_to_cpu(param, device)?;
7371            let g_host = gpu_to_cpu(grad, device)?;
7372            let mut m_host = gpu_to_cpu(exp_avg, device)?;
7373            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;
7374
7375            for i in 0..n {
7376                let mut g = g_host[i];
7377                if weight_decay > 0.0 {
7378                    g += weight_decay * p_host[i];
7379                }
7380                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
7381                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
7382                let m_hat = m_host[i] / bc1;
7383                let v_hat = v_host[i] / bc2;
7384                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
7385            }
7386
7387            *param = cpu_to_gpu(&p_host, device)?;
7388            *exp_avg = cpu_to_gpu(&m_host, device)?;
7389            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
7390            return Ok(());
7391        }
7392    };
7393
7394    let cfg = launch_cfg(n)?;
7395    let n_u32 = n as u32;
7396
7397    unsafe {
7398        stream
7399            .launch_builder(&f)
7400            .arg(param.inner_mut())
7401            .arg(grad.inner())
7402            .arg(exp_avg.inner_mut())
7403            .arg(exp_avg_sq.inner_mut())
7404            .arg(&beta1)
7405            .arg(&beta2)
7406            .arg(&lr)
7407            .arg(&eps)
7408            .arg(&bc1)
7409            .arg(&bc2)
7410            .arg(&weight_decay)
7411            .arg(&n_u32)
7412            .launch(cfg)?;
7413    }
7414
7415    Ok(())
7416}
7417
7418/// Stub -- always returns [`GpuError::NoCudaFeature`].
7419#[cfg(not(feature = "cuda"))]
7420#[allow(clippy::too_many_arguments)]
7421pub fn gpu_fused_adam(
7422    _param: &mut CudaBuffer<f32>,
7423    _grad: &CudaBuffer<f32>,
7424    _exp_avg: &mut CudaBuffer<f32>,
7425    _exp_avg_sq: &mut CudaBuffer<f32>,
7426    _beta1: f32,
7427    _beta2: f32,
7428    _lr: f32,
7429    _eps: f32,
7430    _bc1: f32,
7431    _bc2: f32,
7432    _weight_decay: f32,
7433    _device: &GpuDevice,
7434) -> GpuResult<()> {
7435    Err(GpuError::NoCudaFeature)
7436}
7437
7438// ---------------------------------------------------------------------------
7439// Public API -- fused GRU cell
7440// ---------------------------------------------------------------------------
7441
7442/// Fused GRU cell forward: takes pre-computed gate matrices and produces
7443/// new hidden state + workspace for backward.
7444///
7445/// Inputs:
7446/// - `input_gates`: `[batch, 3*hsz]` — result of `x @ W_ih^T`
7447/// - `hidden_gates`: `[batch, 3*hsz]` — result of `h @ W_hh^T`
7448/// - `bias_ih`: `[3*hsz]` — input bias
7449/// - `bias_hh`: `[3*hsz]` — hidden bias
7450/// - `hx`: `[batch, hsz]` — previous hidden state
7451///
7452/// Outputs:
7453/// - `hy`: `[batch, hsz]` — new hidden state
7454/// - `workspace`: `[batch, 5*hsz]` — saved for backward (r, z, n, hx, hn+b2n)
7455#[cfg(feature = "cuda")]
7456pub fn gpu_fused_gru_forward(
7457    input_gates: &CudaBuffer<f32>,
7458    hidden_gates: &CudaBuffer<f32>,
7459    bias_ih: &CudaBuffer<f32>,
7460    bias_hh: &CudaBuffer<f32>,
7461    hx: &CudaBuffer<f32>,
7462    hsz: usize,
7463    device: &GpuDevice,
7464) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
7465    use cudarc::driver::PushKernelArg;
7466
7467    let total = hx.len(); // batch * hsz
7468    let batch = total / hsz;
7469
7470    let ctx = device.context();
7471    let stream = device.stream();
7472
7473    let f = match crate::module_cache::get_or_compile(
7474        ctx,
7475        FUSED_GRU_FORWARD_PTX,
7476        "fused_gru_forward_kernel",
7477        device.ordinal() as u32,
7478    ) {
7479        Ok(f) => f,
7480        Err(_) => {
7481            return Err(GpuError::PtxCompileFailed {
7482                kernel: "fused_gru_forward_kernel",
7483            });
7484        }
7485    };
7486
7487    let mut hy = alloc_zeros_f32(total, device)?;
7488    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;
7489
7490    let cfg = launch_cfg(total)?;
7491    let hsz_u32 = hsz as u32;
7492    let total_u32 = total as u32;
7493
7494    unsafe {
7495        stream
7496            .launch_builder(&f)
7497            .arg(input_gates.inner())
7498            .arg(hidden_gates.inner())
7499            .arg(bias_ih.inner())
7500            .arg(bias_hh.inner())
7501            .arg(hx.inner())
7502            .arg(hy.inner_mut())
7503            .arg(workspace.inner_mut())
7504            .arg(&hsz_u32)
7505            .arg(&total_u32)
7506            .launch(cfg)?;
7507    }
7508
7509    Ok((hy, workspace))
7510}
7511
7512/// Stub -- always returns [`GpuError::NoCudaFeature`].
7513#[cfg(not(feature = "cuda"))]
7514pub fn gpu_fused_gru_forward(
7515    _input_gates: &CudaBuffer<f32>,
7516    _hidden_gates: &CudaBuffer<f32>,
7517    _bias_ih: &CudaBuffer<f32>,
7518    _bias_hh: &CudaBuffer<f32>,
7519    _hx: &CudaBuffer<f32>,
7520    _hsz: usize,
7521    _device: &GpuDevice,
7522) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
7523    Err(GpuError::NoCudaFeature)
7524}
7525
7526// ---------------------------------------------------------------------------
7527// Public API -- MaxPool2d / AvgPool2d
7528// ---------------------------------------------------------------------------
7529
7530/// MaxPool2d forward on GPU. One thread per output element.
7531#[cfg(feature = "cuda")]
7532#[allow(clippy::too_many_arguments)]
7533pub fn gpu_maxpool2d(
7534    input: &CudaBuffer<f32>,
7535    batch: usize,
7536    channels: usize,
7537    h_in: usize,
7538    w_in: usize,
7539    kh: usize,
7540    kw: usize,
7541    sh: usize,
7542    sw: usize,
7543    ph: usize,
7544    pw: usize,
7545    device: &GpuDevice,
7546) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7547    use cudarc::driver::PushKernelArg;
7548
7549    let h_out = (h_in + 2 * ph - kh) / sh + 1;
7550    let w_out = (w_in + 2 * pw - kw) / sw + 1;
7551    let total = batch * channels * h_out * w_out;
7552
7553    let ctx = device.context();
7554    let stream = device.stream();
7555
7556    let f = match crate::module_cache::get_or_compile(
7557        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
7558    ) {
7559        Ok(f) => f,
7560        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
7561    };
7562
7563    let mut out = alloc_zeros_f32(total, device)?;
7564    let cfg = launch_cfg(total)?;
7565
7566    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
7567    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
7568    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
7569    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
7570    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
7571    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
7572    let total_u32 = total as u32;
7573
7574    unsafe {
7575        stream.launch_builder(&f)
7576            .arg(input.inner())
7577            .arg(out.inner_mut())
7578            .arg(&batch_u32).arg(&ch_u32)
7579            .arg(&h_in_u32).arg(&w_in_u32)
7580            .arg(&h_out_u32).arg(&w_out_u32)
7581            .arg(&kh_u32).arg(&kw_u32)
7582            .arg(&sh_u32).arg(&sw_u32)
7583            .arg(&ph_u32).arg(&pw_u32)
7584            .arg(&total_u32)
7585            .launch(cfg)?;
7586    }
7587
7588    Ok((out, [batch, channels, h_out, w_out]))
7589}
7590
7591/// Stub.
7592#[cfg(not(feature = "cuda"))]
7593#[allow(clippy::too_many_arguments)]
7594pub fn gpu_maxpool2d(
7595    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
7596    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
7597    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
7598    _device: &GpuDevice,
7599) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7600    Err(GpuError::NoCudaFeature)
7601}
7602
7603/// AvgPool2d forward on GPU. One thread per output element.
7604#[cfg(feature = "cuda")]
7605#[allow(clippy::too_many_arguments)]
7606pub fn gpu_avgpool2d(
7607    input: &CudaBuffer<f32>,
7608    batch: usize,
7609    channels: usize,
7610    h_in: usize,
7611    w_in: usize,
7612    kh: usize,
7613    kw: usize,
7614    sh: usize,
7615    sw: usize,
7616    ph: usize,
7617    pw: usize,
7618    device: &GpuDevice,
7619) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7620    use cudarc::driver::PushKernelArg;
7621
7622    let h_out = (h_in + 2 * ph - kh) / sh + 1;
7623    let w_out = (w_in + 2 * pw - kw) / sw + 1;
7624    let total = batch * channels * h_out * w_out;
7625
7626    let ctx = device.context();
7627    let stream = device.stream();
7628
7629    let f = match crate::module_cache::get_or_compile(
7630        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
7631    ) {
7632        Ok(f) => f,
7633        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
7634    };
7635
7636    let mut out = alloc_zeros_f32(total, device)?;
7637    let cfg = launch_cfg(total)?;
7638
7639    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
7640    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
7641    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
7642    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
7643    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
7644    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
7645    let total_u32 = total as u32;
7646
7647    unsafe {
7648        stream.launch_builder(&f)
7649            .arg(input.inner())
7650            .arg(out.inner_mut())
7651            .arg(&batch_u32).arg(&ch_u32)
7652            .arg(&h_in_u32).arg(&w_in_u32)
7653            .arg(&h_out_u32).arg(&w_out_u32)
7654            .arg(&kh_u32).arg(&kw_u32)
7655            .arg(&sh_u32).arg(&sw_u32)
7656            .arg(&ph_u32).arg(&pw_u32)
7657            .arg(&total_u32)
7658            .launch(cfg)?;
7659    }
7660
7661    Ok((out, [batch, channels, h_out, w_out]))
7662}
7663
7664/// Stub.
7665#[cfg(not(feature = "cuda"))]
7666#[allow(clippy::too_many_arguments)]
7667pub fn gpu_avgpool2d(
7668    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
7669    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
7670    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
7671    _device: &GpuDevice,
7672) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7673    Err(GpuError::NoCudaFeature)
7674}
7675
7676// ---------------------------------------------------------------------------
7677// Public API -- BatchNorm2d
7678// ---------------------------------------------------------------------------
7679
7680/// BatchNorm2d forward on GPU (placeholder — kernel pass-1 indexing needs
7681/// refinement). Currently validates the kernel compiles and falls back to
7682/// returning an error so callers use the CPU path.
7683#[cfg(feature = "cuda")]
7684#[allow(clippy::too_many_arguments)]
7685pub fn gpu_batchnorm_forward(
7686    _input: &CudaBuffer<f32>,
7687    _weight: &CudaBuffer<f32>,
7688    _bias: &CudaBuffer<f32>,
7689    _running_mean: &mut CudaBuffer<f32>,
7690    _running_var: &mut CudaBuffer<f32>,
7691    _channels: usize,
7692    _spatial: usize,
7693    _eps: f32,
7694    _momentum: f32,
7695    _training: bool,
7696    device: &GpuDevice,
7697) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7698    // Validate the PTX compiles (catches syntax errors at first call).
7699    let ctx = device.context();
7700    let _f = crate::module_cache::get_or_compile(
7701        ctx,
7702        BATCHNORM_FORWARD_PTX,
7703        "batchnorm_forward_kernel",
7704        device.ordinal() as u32,
7705    );
7706    // Full implementation pending — pass-1 loop indexing needs refinement.
7707    Err(GpuError::ShapeMismatch {
7708        op: "batchnorm_forward",
7709        expected: vec![0],
7710        got: vec![1],
7711    })
7712}
7713
7714/// Stub.
7715#[cfg(not(feature = "cuda"))]
7716#[allow(clippy::too_many_arguments)]
7717pub fn gpu_batchnorm_forward(
7718    _input: &CudaBuffer<f32>,
7719    _weight: &CudaBuffer<f32>,
7720    _bias: &CudaBuffer<f32>,
7721    _running_mean: &mut CudaBuffer<f32>,
7722    _running_var: &mut CudaBuffer<f32>,
7723    _channels: usize,
7724    _spatial: usize,
7725    _eps: f32,
7726    _momentum: f32,
7727    _training: bool,
7728    _device: &GpuDevice,
7729) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7730    Err(GpuError::NoCudaFeature)
7731}
7732
7733// ---------------------------------------------------------------------------
7734// Public API -- LayerNorm
7735// ---------------------------------------------------------------------------
7736
7737/// Row-wise layer normalization on GPU.
7738///
7739/// `input`: `[rows * cols]`, `weight`/`bias`: `[cols]`.
7740/// Output: normalized and affine-transformed `[rows * cols]`.
7741#[cfg(feature = "cuda")]
7742pub fn gpu_layernorm(
7743    input: &CudaBuffer<f32>,
7744    weight: &CudaBuffer<f32>,
7745    bias: &CudaBuffer<f32>,
7746    rows: usize,
7747    cols: usize,
7748    eps: f32,
7749    device: &GpuDevice,
7750) -> GpuResult<CudaBuffer<f32>> {
7751    use cudarc::driver::PushKernelArg;
7752
7753    validate_unary(input, device)?;
7754
7755    let ctx = device.context();
7756    let stream = device.stream();
7757
7758    let f = match crate::module_cache::get_or_compile(
7759        ctx,
7760        LAYERNORM_PTX,
7761        "layernorm_kernel",
7762        device.ordinal() as u32,
7763    ) {
7764        Ok(f) => f,
7765        Err(e) => {
7766            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
7767            std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).ok();
7768            eprintln!(
7769                "ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
7770                LAYERNORM_PTX.len()
7771            );
7772            let h_in = gpu_to_cpu(input, device)?;
7773            let h_w = gpu_to_cpu(weight, device)?;
7774            let h_b = gpu_to_cpu(bias, device)?;
7775            let mut out = vec![0.0f32; rows * cols];
7776            for r in 0..rows {
7777                let base = r * cols;
7778                let slice = &h_in[base..base + cols];
7779                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
7780                let var: f32 =
7781                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
7782                let inv_std = 1.0 / (var + eps).sqrt();
7783                for c in 0..cols {
7784                    let normed = (slice[c] - mean) * inv_std;
7785                    out[base + c] = h_w[c] * normed + h_b[c];
7786                }
7787            }
7788            return cpu_to_gpu(&out, device);
7789        }
7790    };
7791
7792    let mut out = alloc_zeros_f32(rows * cols, device)?;
7793    let rows_u32 = rows as u32;
7794    let cols_u32 = cols as u32;
7795
7796    let cfg = LaunchConfig {
7797        grid_dim: ((rows as u32).max(1), 1, 1),
7798        block_dim: (256, 1, 1),
7799        shared_mem_bytes: 256 * 4,
7800    };
7801
7802    unsafe {
7803        stream
7804            .launch_builder(&f)
7805            .arg(input.inner())
7806            .arg(out.inner_mut())
7807            .arg(weight.inner())
7808            .arg(bias.inner())
7809            .arg(&rows_u32)
7810            .arg(&cols_u32)
7811            .arg(&eps)
7812            .launch(cfg)?;
7813    }
7814
7815    Ok(out)
7816}
7817
7818// ---------------------------------------------------------------------------
7819// Public API -- LayerNorm backward
7820// ---------------------------------------------------------------------------
7821
7822/// LayerNorm backward pass on GPU.
7823///
7824/// Computes grad_input, grad_weight, and grad_bias entirely on GPU.
7825/// One block per batch element (row), 256 threads per block.
7826/// grad_weight and grad_bias are accumulated across batches via atomicAdd.
7827///
7828/// `input`: `[rows * cols]`, `grad_output`: `[rows * cols]`, `weight`: `[cols]`.
7829/// Returns: `(grad_input [rows * cols], grad_weight [cols], grad_bias [cols])`.
7830#[cfg(feature = "cuda")]
7831pub fn gpu_layernorm_backward(
7832    input: &CudaBuffer<f32>,
7833    grad_output: &CudaBuffer<f32>,
7834    weight: &CudaBuffer<f32>,
7835    rows: usize,
7836    cols: usize,
7837    eps: f32,
7838    device: &GpuDevice,
7839) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7840    use cudarc::driver::PushKernelArg;
7841
7842    validate_unary(input, device)?;
7843
7844    let ctx = device.context();
7845    let stream = device.stream();
7846
7847    let f = match crate::module_cache::get_or_compile(
7848        ctx,
7849        LAYERNORM_BACKWARD_PTX,
7850        "layernorm_backward_kernel",
7851        device.ordinal() as u32,
7852    ) {
7853        Ok(f) => f,
7854        Err(_) => {
7855            // CPU fallback
7856            let h_in = gpu_to_cpu(input, device)?;
7857            let h_go = gpu_to_cpu(grad_output, device)?;
7858            let h_w = gpu_to_cpu(weight, device)?;
7859            let mut grad_input = vec![0.0f32; rows * cols];
7860            let mut grad_weight = vec![0.0f32; cols];
7861            let mut grad_bias = vec![0.0f32; cols];
7862            let n_f = cols as f32;
7863            for r in 0..rows {
7864                let base = r * cols;
7865                let x_slice = &h_in[base..base + cols];
7866                let go_slice = &h_go[base..base + cols];
7867                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
7868                let var: f32 = x_slice
7869                    .iter()
7870                    .map(|&x| (x - mean) * (x - mean))
7871                    .sum::<f32>()
7872                    / n_f;
7873                let inv_std = 1.0 / (var + eps).sqrt();
7874                let mut sum1 = 0.0f32;
7875                let mut sum2 = 0.0f32;
7876                for c in 0..cols {
7877                    let x_hat = (x_slice[c] - mean) * inv_std;
7878                    let dl = go_slice[c] * h_w[c];
7879                    sum1 += dl;
7880                    sum2 += dl * x_hat;
7881                    grad_weight[c] += go_slice[c] * x_hat;
7882                    grad_bias[c] += go_slice[c];
7883                }
7884                let m1 = sum1 / n_f;
7885                let m2 = sum2 / n_f;
7886                for c in 0..cols {
7887                    let x_hat = (x_slice[c] - mean) * inv_std;
7888                    let dl = go_slice[c] * h_w[c];
7889                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
7890                }
7891            }
7892            let gi = cpu_to_gpu(&grad_input, device)?;
7893            let gw = cpu_to_gpu(&grad_weight, device)?;
7894            let gb = cpu_to_gpu(&grad_bias, device)?;
7895            return Ok((gi, gw, gb));
7896        }
7897    };
7898
7899    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
7900    let mut grad_w = alloc_zeros_f32(cols, device)?;
7901    let mut grad_b = alloc_zeros_f32(cols, device)?;
7902    let rows_u32 = rows as u32;
7903    let cols_u32 = cols as u32;
7904
7905    // One block per row, 256 threads per block.
7906    let cfg = LaunchConfig {
7907        grid_dim: ((rows as u32).max(1), 1, 1),
7908        block_dim: (256, 1, 1),
7909        shared_mem_bytes: 256 * 4,
7910    };
7911
7912    unsafe {
7913        stream
7914            .launch_builder(&f)
7915            .arg(input.inner())
7916            .arg(grad_output.inner())
7917            .arg(weight.inner())
7918            .arg(grad_in.inner_mut())
7919            .arg(grad_w.inner_mut())
7920            .arg(grad_b.inner_mut())
7921            .arg(&rows_u32)
7922            .arg(&cols_u32)
7923            .arg(&eps)
7924            .launch(cfg)?;
7925    }
7926
7927    Ok((grad_in, grad_w, grad_b))
7928}
7929
7930/// Stub -- always returns [`GpuError::NoCudaFeature`].
7931#[cfg(not(feature = "cuda"))]
7932pub fn gpu_layernorm_backward(
7933    _input: &CudaBuffer<f32>,
7934    _grad_output: &CudaBuffer<f32>,
7935    _weight: &CudaBuffer<f32>,
7936    _rows: usize,
7937    _cols: usize,
7938    _eps: f32,
7939    _device: &GpuDevice,
7940) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7941    Err(GpuError::NoCudaFeature)
7942}
7943
7944// ===========================================================================
7945// _into variants — write to pre-allocated output buffers (zero allocation)
7946//
7947// These are used for CUDA graph capture, where all buffer addresses must be
7948// fixed at capture time. The PTX kernels are identical — only the Rust
7949// wrapper skips allocation.
7950// ===========================================================================
7951
7952/// Elementwise add into pre-allocated output: `out[i] = a[i] + b[i]`.
7953#[cfg(feature = "cuda")]
7954pub fn gpu_add_into(
7955    a: &CudaBuffer<f32>,
7956    b: &CudaBuffer<f32>,
7957    out: &mut CudaBuffer<f32>,
7958    device: &GpuDevice,
7959) -> GpuResult<()> {
7960    validate_binary(a, b, device)?;
7961    if out.len() < a.len() {
7962        return Err(GpuError::ShapeMismatch {
7963            op: "add_into",
7964            expected: vec![a.len()],
7965            got: vec![out.len()],
7966        });
7967    }
7968    if try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
7969        return Ok(());
7970    }
7971    Err(GpuError::PtxCompileFailed {
7972        kernel: "add_kernel",
7973    })
7974}
7975
7976/// Elementwise mul into pre-allocated output: `out[i] = a[i] * b[i]`.
7977#[cfg(feature = "cuda")]
7978pub fn gpu_mul_into(
7979    a: &CudaBuffer<f32>,
7980    b: &CudaBuffer<f32>,
7981    out: &mut CudaBuffer<f32>,
7982    device: &GpuDevice,
7983) -> GpuResult<()> {
7984    validate_binary(a, b, device)?;
7985    if out.len() < a.len() {
7986        return Err(GpuError::ShapeMismatch {
7987            op: "mul_into",
7988            expected: vec![a.len()],
7989            got: vec![out.len()],
7990        });
7991    }
7992    if try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
7993        return Ok(());
7994    }
7995    Err(GpuError::PtxCompileFailed {
7996        kernel: "mul_kernel",
7997    })
7998}
7999
8000/// Scalar multiply into pre-allocated output: `out[i] = a[i] * scalar`.
8001#[cfg(feature = "cuda")]
8002pub fn gpu_scale_into(
8003    a: &CudaBuffer<f32>,
8004    scalar: f32,
8005    out: &mut CudaBuffer<f32>,
8006    device: &GpuDevice,
8007) -> GpuResult<()> {
8008    use cudarc::driver::PushKernelArg;
8009    validate_unary(a, device)?;
8010    let n = a.len();
8011    let ctx = device.context();
8012    let stream = device.stream();
8013    let f = crate::module_cache::get_or_compile(
8014        ctx,
8015        SCALE_PTX,
8016        "scale_kernel",
8017        device.ordinal() as u32,
8018    )
8019    .map_err(|_| GpuError::PtxCompileFailed {
8020        kernel: "scale_kernel",
8021    })?;
8022    let cfg = launch_cfg(n)?;
8023    let n_u32 = n as u32;
8024    unsafe {
8025        stream
8026            .launch_builder(&f)
8027            .arg(a.inner())
8028            .arg(out.inner_mut())
8029            .arg(&scalar)
8030            .arg(&n_u32)
8031            .launch(cfg)?;
8032    }
8033    Ok(())
8034}
8035
8036/// Check whether a GPU buffer contains any inf or NaN values.
8037///
8038/// Downloads the buffer contents to the host and scans for non-finite
8039/// values. This is correct for any buffer size and requires no custom
8040/// reduction kernel.
8041///
8042/// For a future optimization, a dedicated GPU reduction kernel could be
8043/// used to produce a single boolean flag on device, avoiding the full
8044/// download. The current approach is already much faster than the old
8045/// per-element CPU loop in `unscale_()` because the scaling itself
8046/// runs on GPU — only the inf/NaN check touches the host.
8047///
8048/// # Errors
8049///
8050/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
8051/// - [`GpuError::Driver`] on CUDA runtime errors.
8052#[cfg(feature = "cuda")]
8053pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
8054    let n = a.len();
8055    if n == 0 {
8056        return Ok(false);
8057    }
8058
8059    validate_unary(a, device)?;
8060
8061    let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
8062    Ok(host.iter().any(|v| !v.is_finite()))
8063}
8064
8065/// Stub -- always returns [`GpuError::NoCudaFeature`].
8066#[cfg(not(feature = "cuda"))]
8067pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
8068    Err(GpuError::NoCudaFeature)
8069}
8070
8071/// GELU into pre-allocated output.
8072#[cfg(feature = "cuda")]
8073pub fn gpu_gelu_into(
8074    a: &CudaBuffer<f32>,
8075    out: &mut CudaBuffer<f32>,
8076    device: &GpuDevice,
8077) -> GpuResult<()> {
8078    validate_unary(a, device)?;
8079    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
8080        return Ok(());
8081    }
8082    Err(GpuError::PtxCompileFailed {
8083        kernel: "gelu_kernel",
8084    })
8085}
8086
8087/// Embedding lookup into pre-allocated output.
8088#[cfg(feature = "cuda")]
8089pub fn gpu_embed_lookup_into(
8090    idx: &CudaBuffer<f32>,
8091    weight: &CudaBuffer<f32>,
8092    d: usize,
8093    out: &mut CudaBuffer<f32>,
8094    device: &GpuDevice,
8095) -> GpuResult<()> {
8096    use cudarc::driver::PushKernelArg;
8097    let ctx = device.context();
8098    let stream = device.stream();
8099    let f = crate::module_cache::get_or_compile(
8100        ctx,
8101        EMBED_LOOKUP_PTX,
8102        "embed_lookup_kernel",
8103        device.ordinal() as u32,
8104    )
8105    .map_err(|_| GpuError::PtxCompileFailed {
8106        kernel: "embed_lookup_kernel",
8107    })?;
8108    let cfg = launch_cfg(d)?;
8109    let d_u32 = d as u32;
8110    unsafe {
8111        stream
8112            .launch_builder(&f)
8113            .arg(idx.inner())
8114            .arg(weight.inner())
8115            .arg(out.inner_mut())
8116            .arg(&d_u32)
8117            .launch(cfg)?;
8118    }
8119    Ok(())
8120}
8121
8122// ---------------------------------------------------------------------------
8123// Public API -- Batch embedding lookup (GPU-native)
8124// ---------------------------------------------------------------------------
8125
8126/// GPU batch embedding lookup: given `indices` (N f32 values on GPU) and
8127/// `weight` `[V, D]`, gather N rows to produce output `[N, D]`.
8128/// Entire operation stays on GPU -- no CPU roundtrip.
8129#[cfg(feature = "cuda")]
8130pub fn gpu_embed_lookup_batch(
8131    indices: &CudaBuffer<f32>,
8132    weight: &CudaBuffer<f32>,
8133    n: usize,
8134    d: usize,
8135    device: &GpuDevice,
8136) -> GpuResult<CudaBuffer<f32>> {
8137    use cudarc::driver::PushKernelArg;
8138
8139    let total = n * d;
8140    if total == 0 {
8141        return alloc_zeros_f32(0, device);
8142    }
8143
8144    let ctx = device.context();
8145    let stream = device.stream();
8146
8147    let f = match crate::module_cache::get_or_compile(
8148        ctx,
8149        EMBED_LOOKUP_BATCH_PTX,
8150        "embed_lookup_batch_kernel",
8151        device.ordinal() as u32,
8152    ) {
8153        Ok(f) => f,
8154        Err(_) => {
8155            // CPU fallback.
8156            let idx_host = gpu_to_cpu(indices, device)?;
8157            let weight_host = gpu_to_cpu(weight, device)?;
8158            let mut out = Vec::with_capacity(total);
8159            for &idx_f in &idx_host {
8160                let row = idx_f as usize;
8161                let start = row * d;
8162                out.extend_from_slice(&weight_host[start..start + d]);
8163            }
8164            return cpu_to_gpu(&out, device);
8165        }
8166    };
8167
8168    let mut out = alloc_zeros_f32(total, device)?;
8169    let cfg = launch_cfg(total)?;
8170    let d_u32 = d as u32;
8171    let total_u32 = total as u32;
8172
8173    unsafe {
8174        stream
8175            .launch_builder(&f)
8176            .arg(indices.inner())
8177            .arg(weight.inner())
8178            .arg(out.inner_mut())
8179            .arg(&d_u32)
8180            .arg(&total_u32)
8181            .launch(cfg)?;
8182    }
8183
8184    Ok(out)
8185}
8186
8187// ---------------------------------------------------------------------------
8188// Public API -- Scatter-add rows (for embedding backward, GPU-native)
8189// ---------------------------------------------------------------------------
8190
8191/// GPU scatter-add rows: given `grad_output` `[N, D]` and `indices` `[N]` (f32),
8192/// atomically accumulate into `grad_weight` `[V, D]` (pre-zeroed):
8193///   `grad_weight[indices[i], :] += grad_output[i, :]`
8194///
8195/// Duplicate indices accumulate correctly via atomic adds.
8196#[cfg(feature = "cuda")]
8197pub fn gpu_scatter_add_rows(
8198    grad_output: &CudaBuffer<f32>,
8199    indices: &CudaBuffer<f32>,
8200    num_embeddings: usize,
8201    d: usize,
8202    device: &GpuDevice,
8203) -> GpuResult<CudaBuffer<f32>> {
8204    use cudarc::driver::PushKernelArg;
8205
8206    let n = indices.len();
8207    let total = n * d;
8208
8209    if total == 0 {
8210        return alloc_zeros_f32(num_embeddings * d, device);
8211    }
8212
8213    let ctx = device.context();
8214    let stream = device.stream();
8215
8216    let f = match crate::module_cache::get_or_compile(
8217        ctx,
8218        SCATTER_ADD_ROWS_PTX,
8219        "scatter_add_rows_kernel",
8220        device.ordinal() as u32,
8221    ) {
8222        Ok(f) => f,
8223        Err(_) => {
8224            // CPU fallback.
8225            let go_host = gpu_to_cpu(grad_output, device)?;
8226            let idx_host = gpu_to_cpu(indices, device)?;
8227            let mut result = vec![0.0f32; num_embeddings * d];
8228            for (i, &idx_f) in idx_host.iter().enumerate() {
8229                let row = idx_f as usize;
8230                for j in 0..d {
8231                    result[row * d + j] += go_host[i * d + j];
8232                }
8233            }
8234            return cpu_to_gpu(&result, device);
8235        }
8236    };
8237
8238    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
8239    let cfg = launch_cfg(total)?;
8240    let d_u32 = d as u32;
8241    let total_u32 = total as u32;
8242
8243    unsafe {
8244        stream
8245            .launch_builder(&f)
8246            .arg(grad_output.inner())
8247            .arg(indices.inner())
8248            .arg(out.inner_mut())
8249            .arg(&d_u32)
8250            .arg(&total_u32)
8251            .launch(cfg)?;
8252    }
8253
8254    Ok(out)
8255}
8256
8257/// 2D transpose into pre-allocated output.
8258#[cfg(feature = "cuda")]
8259pub fn gpu_transpose_2d_into(
8260    a: &CudaBuffer<f32>,
8261    m: usize,
8262    n: usize,
8263    out: &mut CudaBuffer<f32>,
8264    device: &GpuDevice,
8265) -> GpuResult<()> {
8266    use cudarc::driver::PushKernelArg;
8267    let total = m * n;
8268    let ctx = device.context();
8269    let stream = device.stream();
8270    let f = crate::module_cache::get_or_compile(
8271        ctx,
8272        TRANSPOSE_2D_PTX,
8273        "transpose_2d_kernel",
8274        device.ordinal() as u32,
8275    )
8276    .map_err(|_| GpuError::PtxCompileFailed {
8277        kernel: "transpose_2d_kernel",
8278    })?;
8279    let cfg = launch_cfg(total)?;
8280    let m_u32 = m as u32;
8281    let n_u32 = n as u32;
8282    let total_u32 = total as u32;
8283    unsafe {
8284        stream
8285            .launch_builder(&f)
8286            .arg(a.inner())
8287            .arg(out.inner_mut())
8288            .arg(&m_u32)
8289            .arg(&n_u32)
8290            .arg(&total_u32)
8291            .launch(cfg)?;
8292    }
8293    Ok(())
8294}
8295
8296/// Permute (0,2,1,3) into pre-allocated output.
8297#[cfg(feature = "cuda")]
8298pub fn gpu_permute_0213_into(
8299    a: &CudaBuffer<f32>,
8300    d0: usize,
8301    d1: usize,
8302    d2: usize,
8303    d3: usize,
8304    out: &mut CudaBuffer<f32>,
8305    device: &GpuDevice,
8306) -> GpuResult<()> {
8307    use cudarc::driver::PushKernelArg;
8308    let total = d0 * d1 * d2 * d3;
8309    let ctx = device.context();
8310    let stream = device.stream();
8311    let f = crate::module_cache::get_or_compile(
8312        ctx,
8313        PERMUTE_0213_PTX,
8314        "permute_0213_kernel",
8315        device.ordinal() as u32,
8316    )
8317    .map_err(|_| GpuError::PtxCompileFailed {
8318        kernel: "permute_0213_kernel",
8319    })?;
8320    let cfg = launch_cfg(total)?;
8321    let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);
8322    unsafe {
8323        stream
8324            .launch_builder(&f)
8325            .arg(a.inner())
8326            .arg(out.inner_mut())
8327            .arg(&d0u)
8328            .arg(&d1u)
8329            .arg(&d2u)
8330            .arg(&d3u)
8331            .arg(&tu)
8332            .launch(cfg)?;
8333    }
8334    Ok(())
8335}
8336
8337/// Softmax into pre-allocated output (row-wise).
8338#[cfg(feature = "cuda")]
8339pub fn gpu_softmax_into(
8340    a: &CudaBuffer<f32>,
8341    rows: usize,
8342    cols: usize,
8343    out: &mut CudaBuffer<f32>,
8344    device: &GpuDevice,
8345) -> GpuResult<()> {
8346    use cudarc::driver::PushKernelArg;
8347    let ctx = device.context();
8348    let stream = device.stream();
8349    let f = crate::module_cache::get_or_compile(
8350        ctx,
8351        SOFTMAX_PTX,
8352        "softmax_kernel",
8353        device.ordinal() as u32,
8354    )
8355    .map_err(|_| GpuError::PtxCompileFailed {
8356        kernel: "softmax_kernel",
8357    })?;
8358    let block_size = 256u32;
8359    let grid_size = rows as u32;
8360    let cfg = LaunchConfig {
8361        grid_dim: (grid_size, 1, 1),
8362        block_dim: (block_size, 1, 1),
8363        shared_mem_bytes: (cols as u32) * 4,
8364    };
8365    let rows_u32 = rows as u32;
8366    let cols_u32 = cols as u32;
8367    unsafe {
8368        stream
8369            .launch_builder(&f)
8370            .arg(a.inner())
8371            .arg(out.inner_mut())
8372            .arg(&rows_u32)
8373            .arg(&cols_u32)
8374            .launch(cfg)?;
8375    }
8376    Ok(())
8377}
8378
8379/// LayerNorm into pre-allocated output.
8380#[cfg(feature = "cuda")]
8381#[allow(clippy::too_many_arguments)]
8382pub fn gpu_layernorm_into(
8383    input: &CudaBuffer<f32>,
8384    weight: &CudaBuffer<f32>,
8385    bias: &CudaBuffer<f32>,
8386    rows: usize,
8387    cols: usize,
8388    eps: f32,
8389    out: &mut CudaBuffer<f32>,
8390    device: &GpuDevice,
8391) -> GpuResult<()> {
8392    use cudarc::driver::PushKernelArg;
8393    let ctx = device.context();
8394    let stream = device.stream();
8395    let f = crate::module_cache::get_or_compile(
8396        ctx,
8397        LAYERNORM_PTX,
8398        "layernorm_kernel",
8399        device.ordinal() as u32,
8400    )
8401    .map_err(|_| GpuError::PtxCompileFailed {
8402        kernel: "layernorm_kernel",
8403    })?;
8404    let block_size = 256u32;
8405    let grid_size = rows as u32;
8406    let cfg = LaunchConfig {
8407        grid_dim: (grid_size, 1, 1),
8408        block_dim: (block_size, 1, 1),
8409        shared_mem_bytes: (cols as u32) * 4,
8410    };
8411    let rows_u32 = rows as u32;
8412    let cols_u32 = cols as u32;
8413    unsafe {
8414        stream
8415            .launch_builder(&f)
8416            .arg(input.inner())
8417            .arg(out.inner_mut())
8418            .arg(weight.inner())
8419            .arg(bias.inner())
8420            .arg(&rows_u32)
8421            .arg(&cols_u32)
8422            .arg(&eps)
8423            .launch(cfg)?;
8424    }
8425    Ok(())
8426}
8427
8428/// Slice read into pre-allocated output: read first `len` rows from
8429/// `[n_batch, max_len, d]` into out `[n_batch, len, d]`.
8430#[cfg(feature = "cuda")]
8431pub fn gpu_slice_read_into(
8432    src: &CudaBuffer<f32>,
8433    n_batch: usize,
8434    d: usize,
8435    len: usize,
8436    max_len: usize,
8437    out: &mut CudaBuffer<f32>,
8438    device: &GpuDevice,
8439) -> GpuResult<()> {
8440    use cudarc::driver::PushKernelArg;
8441    let total = n_batch * len * d;
8442    let ctx = device.context();
8443    let stream = device.stream();
8444    let f = crate::module_cache::get_or_compile(
8445        ctx,
8446        SLICE_READ_PTX,
8447        "slice_read_kernel",
8448        device.ordinal() as u32,
8449    )
8450    .map_err(|_| GpuError::PtxCompileFailed {
8451        kernel: "slice_read_kernel",
8452    })?;
8453    let cfg = launch_cfg(total)?;
8454    let total_u32 = total as u32;
8455    let d_u32 = d as u32;
8456    let len_u32 = len as u32;
8457    let max_len_u32 = max_len as u32;
8458    unsafe {
8459        stream
8460            .launch_builder(&f)
8461            .arg(src.inner())
8462            .arg(out.inner_mut())
8463            .arg(&total_u32)
8464            .arg(&d_u32)
8465            .arg(&len_u32)
8466            .arg(&max_len_u32)
8467            .launch(cfg)?;
8468    }
8469    Ok(())
8470}
8471
8472/// Small matmul (PTX kernel) into pre-allocated output.
8473#[cfg(feature = "cuda")]
8474pub fn gpu_small_matmul_into(
8475    a: &CudaBuffer<f32>,
8476    b: &CudaBuffer<f32>,
8477    m: usize,
8478    k: usize,
8479    n: usize,
8480    out: &mut CudaBuffer<f32>,
8481    device: &GpuDevice,
8482) -> GpuResult<()> {
8483    use cudarc::driver::PushKernelArg;
8484    let total = m * n;
8485    let ctx = device.context();
8486    let stream = device.stream();
8487    let f = crate::module_cache::get_or_compile(
8488        ctx,
8489        SMALL_MATMUL_PTX,
8490        "small_matmul_kernel",
8491        device.ordinal() as u32,
8492    )
8493    .map_err(|_| GpuError::PtxCompileFailed {
8494        kernel: "small_matmul_kernel",
8495    })?;
8496    let cfg = launch_cfg(total)?;
8497    let (m_u32, k_u32, n_u32, total_u32) = (m as u32, k as u32, n as u32, total as u32);
8498    unsafe {
8499        stream
8500            .launch_builder(&f)
8501            .arg(a.inner())
8502            .arg(b.inner())
8503            .arg(out.inner_mut())
8504            .arg(&m_u32)
8505            .arg(&k_u32)
8506            .arg(&n_u32)
8507            .arg(&total_u32)
8508            .launch(cfg)?;
8509    }
8510    Ok(())
8511}
8512
8513// ===========================================================================
8514// Indirect-parameter kernels for CUDA graph capture
8515// ===========================================================================
8516
8517/// Slice write with position read from device memory (for CUDA graph capture).
8518/// Writes `src [n_batch, d]` into row `*pos_ptr` of `dst [n_batch, max_len, d]`.
8519#[cfg(feature = "cuda")]
8520pub fn gpu_slice_write_indirect(
8521    src: &CudaBuffer<f32>,
8522    dst: &mut CudaBuffer<f32>,
8523    n_batch: usize,
8524    d: usize,
8525    max_len: usize,
8526    pos_ptr: &cudarc::driver::CudaSlice<u32>,
8527    device: &GpuDevice,
8528) -> GpuResult<()> {
8529    use cudarc::driver::PushKernelArg;
8530    let total = n_batch * d;
8531    let ctx = device.context();
8532    let stream = device.stream();
8533    let f = crate::module_cache::get_or_compile(
8534        ctx,
8535        SLICE_WRITE_INDIRECT_PTX,
8536        "slice_write_indirect_kernel",
8537        device.ordinal() as u32,
8538    )
8539    .map_err(|_| GpuError::PtxCompileFailed {
8540        kernel: "slice_write_indirect_kernel",
8541    })?;
8542    let cfg = launch_cfg(total)?;
8543    let n_u32 = total as u32;
8544    let d_u32 = d as u32;
8545    let max_len_u32 = max_len as u32;
8546    unsafe {
8547        stream
8548            .launch_builder(&f)
8549            .arg(src.inner())
8550            .arg(dst.inner_mut())
8551            .arg(&n_u32)
8552            .arg(&d_u32)
8553            .arg(&max_len_u32)
8554            .arg(pos_ptr)
8555            .launch(cfg)?;
8556    }
8557    Ok(())
8558}
8559
8560/// Build causal attention mask with total_len read from device memory.
8561/// Writes `out[h, col] = 0.0` if `col < *total_len_ptr`, else `-1e9`.
8562/// Output shape: `[n_head, max_pos]` (n_head rows, each max_pos wide).
8563#[cfg(feature = "cuda")]
8564pub fn gpu_causal_mask_indirect(
8565    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
8566    n_head: usize,
8567    max_pos: usize,
8568    out: &mut CudaBuffer<f32>,
8569    device: &GpuDevice,
8570) -> GpuResult<()> {
8571    use cudarc::driver::PushKernelArg;
8572    let total = n_head * max_pos;
8573    let ctx = device.context();
8574    let stream = device.stream();
8575    let f = crate::module_cache::get_or_compile(
8576        ctx,
8577        CAUSAL_MASK_INDIRECT_PTX,
8578        "causal_mask_indirect_kernel",
8579        device.ordinal() as u32,
8580    )
8581    .map_err(|_| GpuError::PtxCompileFailed {
8582        kernel: "causal_mask_indirect_kernel",
8583    })?;
8584    let cfg = launch_cfg(total)?;
8585    let max_pos_u32 = max_pos as u32;
8586    let total_u32 = total as u32;
8587    unsafe {
8588        stream
8589            .launch_builder(&f)
8590            .arg(total_len_ptr)
8591            .arg(out.inner_mut())
8592            .arg(&max_pos_u32)
8593            .arg(&total_u32)
8594            .launch(cfg)?;
8595    }
8596    Ok(())
8597}
8598
8599// ===========================================================================
8600// Pre-compilation of all decode-path PTX modules
8601// ===========================================================================
8602
8603/// Pre-compile all PTX kernels used by the decode pass into the module cache.
8604/// Call this before CUDA graph capture to ensure no `cuModuleLoadData` calls
8605/// occur during capture (which is not a capturable operation).
8606#[cfg(feature = "cuda")]
8607pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
8608    let ctx = device.context();
8609    ctx.bind_to_thread()?;
8610    let ord = device.ordinal() as u32;
8611    let compile = |ptx: &'static str, name: &'static str| -> GpuResult<()> {
8612        crate::module_cache::get_or_compile(ctx, ptx, name, ord)
8613            .map(|_| ())
8614            .map_err(GpuError::Driver)
8615    };
8616    compile(ADD_PTX, "add_kernel")?;
8617    compile(MUL_PTX, "mul_kernel")?;
8618    compile(SCALE_PTX, "scale_kernel")?;
8619    compile(GELU_PTX, "gelu_kernel")?;
8620    compile(SOFTMAX_PTX, "softmax_kernel")?;
8621    compile(LAYERNORM_PTX, "layernorm_kernel")?;
8622    compile(PERMUTE_0213_PTX, "permute_0213_kernel")?;
8623    compile(EMBED_LOOKUP_PTX, "embed_lookup_kernel")?;
8624    compile(EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel")?;
8625    compile(SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel")?;
8626    compile(SMALL_MATMUL_PTX, "small_matmul_kernel")?;
8627    compile(SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel")?;
8628    compile(CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel")?;
8629    compile(SLICE_READ_PTX, "slice_read_kernel")?;
8630    compile(RELU_BACKWARD_PTX, "relu_backward_kernel")?;
8631    compile(GELU_BACKWARD_PTX, "gelu_backward_kernel")?;
8632    Ok(())
8633}
8634
8635/// Stub — no-op without cuda.
8636#[cfg(not(feature = "cuda"))]
8637pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
8638    Err(GpuError::NoCudaFeature)
8639}
8640
8641// ---------------------------------------------------------------------------
8642// Stubs when `cuda` feature is disabled
8643// ---------------------------------------------------------------------------
8644
8645/// Stub -- always returns [`GpuError::NoCudaFeature`].
8646#[cfg(not(feature = "cuda"))]
8647pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8648    Err(GpuError::NoCudaFeature)
8649}
8650
8651/// Stub -- always returns [`GpuError::NoCudaFeature`].
8652#[cfg(not(feature = "cuda"))]
8653pub fn gpu_gelu_tanh(
8654    _input: &CudaBuffer<f32>,
8655    _device: &GpuDevice,
8656) -> GpuResult<CudaBuffer<f32>> {
8657    Err(GpuError::NoCudaFeature)
8658}
8659
8660/// Stub -- always returns [`GpuError::NoCudaFeature`].
8661#[cfg(not(feature = "cuda"))]
8662pub fn gpu_gelu_erf(
8663    _input: &CudaBuffer<f32>,
8664    _device: &GpuDevice,
8665) -> GpuResult<CudaBuffer<f32>> {
8666    Err(GpuError::NoCudaFeature)
8667}
8668
8669/// Stub -- always returns [`GpuError::NoCudaFeature`].
8670#[cfg(not(feature = "cuda"))]
8671pub fn gpu_gelu_backward_tanh(
8672    _grad: &CudaBuffer<f32>,
8673    _input: &CudaBuffer<f32>,
8674    _device: &GpuDevice,
8675) -> GpuResult<CudaBuffer<f32>> {
8676    Err(GpuError::NoCudaFeature)
8677}
8678
8679/// Stub -- always returns [`GpuError::NoCudaFeature`].
8680#[cfg(not(feature = "cuda"))]
8681pub fn gpu_div(
8682    _a: &CudaBuffer<f32>,
8683    _b: &CudaBuffer<f32>,
8684    _device: &GpuDevice,
8685) -> GpuResult<CudaBuffer<f32>> {
8686    Err(GpuError::NoCudaFeature)
8687}
8688
8689/// Stub -- always returns [`GpuError::NoCudaFeature`].
8690#[cfg(not(feature = "cuda"))]
8691pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8692    Err(GpuError::NoCudaFeature)
8693}
8694
8695/// Stub -- always returns [`GpuError::NoCudaFeature`].
8696#[cfg(not(feature = "cuda"))]
8697pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8698    Err(GpuError::NoCudaFeature)
8699}
8700
8701/// Stub -- always returns [`GpuError::NoCudaFeature`].
8702#[cfg(not(feature = "cuda"))]
8703pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8704    Err(GpuError::NoCudaFeature)
8705}
8706
8707/// Stub -- always returns [`GpuError::NoCudaFeature`].
8708#[cfg(not(feature = "cuda"))]
8709pub fn gpu_pow(
8710    _a: &CudaBuffer<f32>,
8711    _exponent: f32,
8712    _device: &GpuDevice,
8713) -> GpuResult<CudaBuffer<f32>> {
8714    Err(GpuError::NoCudaFeature)
8715}
8716
8717/// Stub -- always returns [`GpuError::NoCudaFeature`].
8718#[cfg(not(feature = "cuda"))]
8719pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8720    Err(GpuError::NoCudaFeature)
8721}
8722
8723/// Stub -- always returns [`GpuError::NoCudaFeature`].
8724#[cfg(not(feature = "cuda"))]
8725pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8726    Err(GpuError::NoCudaFeature)
8727}
8728
8729/// Stub -- always returns [`GpuError::NoCudaFeature`].
8730#[cfg(not(feature = "cuda"))]
8731pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8732    Err(GpuError::NoCudaFeature)
8733}
8734
8735/// Stub -- always returns [`GpuError::NoCudaFeature`].
8736#[cfg(not(feature = "cuda"))]
8737pub fn gpu_layernorm(
8738    _input: &CudaBuffer<f32>,
8739    _weight: &CudaBuffer<f32>,
8740    _bias: &CudaBuffer<f32>,
8741    _rows: usize,
8742    _cols: usize,
8743    _eps: f32,
8744    _device: &GpuDevice,
8745) -> GpuResult<CudaBuffer<f32>> {
8746    Err(GpuError::NoCudaFeature)
8747}
8748
8749/// Stub -- always returns [`GpuError::NoCudaFeature`].
8750#[cfg(not(feature = "cuda"))]
8751pub fn gpu_transpose_2d(
8752    _input: &CudaBuffer<f32>,
8753    _m: usize,
8754    _n: usize,
8755    _device: &GpuDevice,
8756) -> GpuResult<CudaBuffer<f32>> {
8757    Err(GpuError::NoCudaFeature)
8758}
8759
8760/// Stub -- always returns [`GpuError::NoCudaFeature`].
8761#[cfg(not(feature = "cuda"))]
8762pub fn gpu_add(
8763    _a: &CudaBuffer<f32>,
8764    _b: &CudaBuffer<f32>,
8765    _device: &GpuDevice,
8766) -> GpuResult<CudaBuffer<f32>> {
8767    Err(GpuError::NoCudaFeature)
8768}
8769
8770/// Stub -- always returns [`GpuError::NoCudaFeature`].
8771#[cfg(not(feature = "cuda"))]
8772pub fn gpu_sub(
8773    _a: &CudaBuffer<f32>,
8774    _b: &CudaBuffer<f32>,
8775    _device: &GpuDevice,
8776) -> GpuResult<CudaBuffer<f32>> {
8777    Err(GpuError::NoCudaFeature)
8778}
8779
8780/// Stub -- always returns [`GpuError::NoCudaFeature`].
8781#[cfg(not(feature = "cuda"))]
8782pub fn gpu_mul(
8783    _a: &CudaBuffer<f32>,
8784    _b: &CudaBuffer<f32>,
8785    _device: &GpuDevice,
8786) -> GpuResult<CudaBuffer<f32>> {
8787    Err(GpuError::NoCudaFeature)
8788}
8789
8790/// Stub -- always returns [`GpuError::NoCudaFeature`].
8791#[cfg(not(feature = "cuda"))]
8792pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8793    Err(GpuError::NoCudaFeature)
8794}
8795
8796/// Stub -- always returns [`GpuError::NoCudaFeature`].
8797#[cfg(not(feature = "cuda"))]
8798pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8799    Err(GpuError::NoCudaFeature)
8800}
8801
8802/// Stub -- always returns [`GpuError::NoCudaFeature`].
8803#[cfg(not(feature = "cuda"))]
8804pub fn gpu_scale(
8805    _a: &CudaBuffer<f32>,
8806    _scalar: f32,
8807    _device: &GpuDevice,
8808) -> GpuResult<CudaBuffer<f32>> {
8809    Err(GpuError::NoCudaFeature)
8810}
8811
8812/// Stub -- always returns [`GpuError::NoCudaFeature`].
8813#[cfg(not(feature = "cuda"))]
8814pub fn gpu_broadcast_add(
8815    _a: &CudaBuffer<f32>,
8816    _b: &CudaBuffer<f32>,
8817    _a_shape: &[usize],
8818    _b_shape: &[usize],
8819    _out_shape: &[usize],
8820    _device: &GpuDevice,
8821) -> GpuResult<CudaBuffer<f32>> {
8822    Err(GpuError::NoCudaFeature)
8823}
8824
8825/// Stub -- always returns [`GpuError::NoCudaFeature`].
8826#[cfg(not(feature = "cuda"))]
8827pub fn gpu_broadcast_sub(
8828    _a: &CudaBuffer<f32>,
8829    _b: &CudaBuffer<f32>,
8830    _a_shape: &[usize],
8831    _b_shape: &[usize],
8832    _out_shape: &[usize],
8833    _device: &GpuDevice,
8834) -> GpuResult<CudaBuffer<f32>> {
8835    Err(GpuError::NoCudaFeature)
8836}
8837
8838/// Stub -- always returns [`GpuError::NoCudaFeature`].
8839#[cfg(not(feature = "cuda"))]
8840pub fn gpu_broadcast_mul(
8841    _a: &CudaBuffer<f32>,
8842    _b: &CudaBuffer<f32>,
8843    _a_shape: &[usize],
8844    _b_shape: &[usize],
8845    _out_shape: &[usize],
8846    _device: &GpuDevice,
8847) -> GpuResult<CudaBuffer<f32>> {
8848    Err(GpuError::NoCudaFeature)
8849}
8850
8851/// Stub -- always returns [`GpuError::NoCudaFeature`].
8852#[cfg(not(feature = "cuda"))]
8853pub fn gpu_softmax(
8854    _input: &CudaBuffer<f32>,
8855    _rows: usize,
8856    _cols: usize,
8857    _device: &GpuDevice,
8858) -> GpuResult<CudaBuffer<f32>> {
8859    Err(GpuError::NoCudaFeature)
8860}
8861
8862/// Stub -- always returns [`GpuError::NoCudaFeature`].
8863#[cfg(not(feature = "cuda"))]
8864pub fn gpu_dropout(
8865    _input: &CudaBuffer<f32>,
8866    _threshold: u32,
8867    _scale: f32,
8868    _seed: u32,
8869    _device: &GpuDevice,
8870) -> GpuResult<CudaBuffer<f32>> {
8871    Err(GpuError::NoCudaFeature)
8872}
8873
8874/// Stub -- always returns [`GpuError::NoCudaFeature`].
8875#[cfg(not(feature = "cuda"))]
8876pub fn gpu_permute_0213(
8877    _input: &CudaBuffer<f32>,
8878    _d0: usize,
8879    _d1: usize,
8880    _d2: usize,
8881    _d3: usize,
8882    _device: &GpuDevice,
8883) -> GpuResult<CudaBuffer<f32>> {
8884    Err(GpuError::NoCudaFeature)
8885}
8886
8887/// Stub -- always returns [`GpuError::NoCudaFeature`].
8888#[cfg(not(feature = "cuda"))]
8889pub fn gpu_slice_write(
8890    _src: &CudaBuffer<f32>,
8891    _dst: &mut CudaBuffer<f32>,
8892    _n_batch: usize,
8893    _d: usize,
8894    _max_len: usize,
8895    _pos: usize,
8896    _device: &GpuDevice,
8897) -> GpuResult<()> {
8898    Err(GpuError::NoCudaFeature)
8899}
8900
8901/// Stub -- always returns [`GpuError::NoCudaFeature`].
8902#[cfg(not(feature = "cuda"))]
8903pub fn gpu_slice_read(
8904    _src: &CudaBuffer<f32>,
8905    _n_batch: usize,
8906    _d: usize,
8907    _len: usize,
8908    _max_len: usize,
8909    _device: &GpuDevice,
8910) -> GpuResult<CudaBuffer<f32>> {
8911    Err(GpuError::NoCudaFeature)
8912}
8913
8914/// Stub -- always returns [`GpuError::NoCudaFeature`].
8915#[cfg(not(feature = "cuda"))]
8916pub fn gpu_embed_lookup(
8917    _idx: &CudaBuffer<f32>,
8918    _weight: &CudaBuffer<f32>,
8919    _d: usize,
8920    _device: &GpuDevice,
8921) -> GpuResult<CudaBuffer<f32>> {
8922    Err(GpuError::NoCudaFeature)
8923}
8924
8925/// Stub -- always returns [`GpuError::NoCudaFeature`].
8926#[cfg(not(feature = "cuda"))]
8927pub fn gpu_embed_lookup_batch(
8928    _indices: &CudaBuffer<f32>,
8929    _weight: &CudaBuffer<f32>,
8930    _n: usize,
8931    _d: usize,
8932    _device: &GpuDevice,
8933) -> GpuResult<CudaBuffer<f32>> {
8934    Err(GpuError::NoCudaFeature)
8935}
8936
8937/// Stub -- always returns [`GpuError::NoCudaFeature`].
8938#[cfg(not(feature = "cuda"))]
8939pub fn gpu_scatter_add_rows(
8940    _grad_output: &CudaBuffer<f32>,
8941    _indices: &CudaBuffer<f32>,
8942    _num_embeddings: usize,
8943    _d: usize,
8944    _device: &GpuDevice,
8945) -> GpuResult<CudaBuffer<f32>> {
8946    Err(GpuError::NoCudaFeature)
8947}
8948
8949/// Stub -- always returns [`GpuError::NoCudaFeature`].
8950#[cfg(not(feature = "cuda"))]
8951pub fn gpu_relu_backward(
8952    _grad: &CudaBuffer<f32>,
8953    _input: &CudaBuffer<f32>,
8954    _device: &GpuDevice,
8955) -> GpuResult<CudaBuffer<f32>> {
8956    Err(GpuError::NoCudaFeature)
8957}
8958
8959/// Stub -- always returns [`GpuError::NoCudaFeature`].
8960#[cfg(not(feature = "cuda"))]
8961pub fn gpu_gelu_backward(
8962    _grad: &CudaBuffer<f32>,
8963    _input: &CudaBuffer<f32>,
8964    _device: &GpuDevice,
8965) -> GpuResult<CudaBuffer<f32>> {
8966    Err(GpuError::NoCudaFeature)
8967}
8968
8969/// Stub -- always returns [`GpuError::NoCudaFeature`].
8970#[cfg(not(feature = "cuda"))]
8971pub fn gpu_index_select_1d(
8972    _input: &CudaBuffer<f32>,
8973    _indices: &CudaBuffer<f32>,
8974    _device: &GpuDevice,
8975) -> GpuResult<CudaBuffer<f32>> {
8976    Err(GpuError::NoCudaFeature)
8977}
8978
8979/// Stub -- always returns [`GpuError::NoCudaFeature`].
8980#[cfg(not(feature = "cuda"))]
8981pub fn gpu_scatter_add_1d(
8982    _grad_output: &CudaBuffer<f32>,
8983    _indices: &CudaBuffer<f32>,
8984    _input_len: usize,
8985    _device: &GpuDevice,
8986) -> GpuResult<CudaBuffer<f32>> {
8987    Err(GpuError::NoCudaFeature)
8988}
8989
8990/// Stub -- always returns [`GpuError::NoCudaFeature`].
8991#[cfg(not(feature = "cuda"))]
8992pub fn gpu_masked_fill(
8993    _input: &CudaBuffer<f32>,
8994    _mask: &CudaBuffer<f32>,
8995    _value: f32,
8996    _device: &GpuDevice,
8997) -> GpuResult<CudaBuffer<f32>> {
8998    Err(GpuError::NoCudaFeature)
8999}
9000
9001/// Stub -- always returns [`GpuError::NoCudaFeature`].
9002#[cfg(not(feature = "cuda"))]
9003pub fn gpu_masked_zero(
9004    _grad: &CudaBuffer<f32>,
9005    _mask: &CudaBuffer<f32>,
9006    _device: &GpuDevice,
9007) -> GpuResult<CudaBuffer<f32>> {
9008    Err(GpuError::NoCudaFeature)
9009}
9010
9011/// Stub -- always returns [`GpuError::NoCudaFeature`].
9012#[cfg(not(feature = "cuda"))]
9013pub fn gpu_sigmoid_backward(
9014    _grad: &CudaBuffer<f32>,
9015    _output: &CudaBuffer<f32>,
9016    _device: &GpuDevice,
9017) -> GpuResult<CudaBuffer<f32>> {
9018    Err(GpuError::NoCudaFeature)
9019}
9020
9021/// Stub -- always returns [`GpuError::NoCudaFeature`].
9022#[cfg(not(feature = "cuda"))]
9023pub fn gpu_tanh_backward(
9024    _grad: &CudaBuffer<f32>,
9025    _output: &CudaBuffer<f32>,
9026    _device: &GpuDevice,
9027) -> GpuResult<CudaBuffer<f32>> {
9028    Err(GpuError::NoCudaFeature)
9029}
9030
9031/// Stub -- always returns [`GpuError::NoCudaFeature`].
9032#[cfg(not(feature = "cuda"))]
9033pub fn gpu_softmax_backward(
9034    _grad: &CudaBuffer<f32>,
9035    _output: &CudaBuffer<f32>,
9036    _cols: usize,
9037    _device: &GpuDevice,
9038) -> GpuResult<CudaBuffer<f32>> {
9039    Err(GpuError::NoCudaFeature)
9040}
9041
9042/// Stub -- always returns [`GpuError::NoCudaFeature`].
9043#[cfg(not(feature = "cuda"))]
9044pub fn gpu_sum_axis(
9045    _a: &CudaBuffer<f32>,
9046    _outer: usize,
9047    _axis_size: usize,
9048    _inner: usize,
9049    _device: &GpuDevice,
9050) -> GpuResult<CudaBuffer<f32>> {
9051    Err(GpuError::NoCudaFeature)
9052}
9053
9054/// Stub -- always returns [`GpuError::NoCudaFeature`].
9055#[cfg(not(feature = "cuda"))]
9056pub fn gpu_strided_split(
9057    _input: &CudaBuffer<f32>,
9058    _total_along_axis: usize,
9059    _split_offset: usize,
9060    _split_size: usize,
9061    _inner_size: usize,
9062    _n: usize,
9063    _device: &GpuDevice,
9064) -> GpuResult<CudaBuffer<f32>> {
9065    Err(GpuError::NoCudaFeature)
9066}
9067
9068/// Stub -- always returns [`GpuError::NoCudaFeature`].
9069#[cfg(not(feature = "cuda"))]
9070pub fn gpu_strided_cat(
9071    _input: &CudaBuffer<f32>,
9072    _output: &mut CudaBuffer<f32>,
9073    _total_along_axis: usize,
9074    _cat_offset: usize,
9075    _part_size: usize,
9076    _inner_size: usize,
9077    _n: usize,
9078    _device: &GpuDevice,
9079) -> GpuResult<()> {
9080    Err(GpuError::NoCudaFeature)
9081}
9082
9083// ---------------------------------------------------------------------------
9084// f32-to-f16 GPU conversion
9085// ---------------------------------------------------------------------------
9086
9087/// Convert an f32 GPU buffer to f16 (represented as `CudaSlice<u16>`).
9088///
9089/// Each element is converted using IEEE 754 round-to-nearest-even via the
9090/// PTX `cvt.rn.f16.f32` instruction. The output is a `CudaSlice<u16>` where
9091/// each `u16` holds the bit pattern of an IEEE 754 half-precision float.
9092///
9093/// # Errors
9094///
9095/// - [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be compiled
9096///   (e.g., GPU architecture too old to support f16 conversion instructions).
9097/// - [`GpuError::Driver`] on CUDA launch errors.
9098#[cfg(feature = "cuda")]
9099pub(crate) fn gpu_f32_to_f16(
9100    input: &CudaBuffer<f32>,
9101    device: &GpuDevice,
9102) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
9103    use cudarc::driver::PushKernelArg;
9104
9105    let n = input.len();
9106    if n == 0 {
9107        let empty = device.stream().alloc_zeros::<u16>(0)?;
9108        return Ok(empty);
9109    }
9110
9111    let ctx = device.context();
9112    let stream = device.stream();
9113
9114    let f = crate::module_cache::get_or_compile(
9115        ctx,
9116        F32_TO_F16_PTX,
9117        "f32_to_f16_kernel",
9118        device.ordinal() as u32,
9119    )
9120    .map_err(|_| GpuError::PtxCompileFailed {
9121        kernel: "f32_to_f16_kernel",
9122    })?;
9123
9124    let mut out = stream.alloc_zeros::<u16>(n)?;
9125    let cfg = launch_cfg(n)?;
9126    let n_u32 = n as u32;
9127
9128    // SAFETY: The kernel reads `n` f32 values from `input` and writes `n`
9129    // u16 values (f16 bit patterns) to `out`. Both buffers are device-resident
9130    // and correctly sized. The grid is configured to cover exactly `n` threads.
9131    unsafe {
9132        stream
9133            .launch_builder(&f)
9134            .arg(input.inner())
9135            .arg(&mut out)
9136            .arg(&n_u32)
9137            .launch(cfg)?;
9138    }
9139
9140    Ok(out)
9141}
9142
9143/// Stub -- always returns [`GpuError::NoCudaFeature`].
9144#[cfg(not(feature = "cuda"))]
9145pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
9146    Err(GpuError::NoCudaFeature)
9147}
9148
9149/// Convert f32 GPU buffer to bf16 (stored as u16) on-device.
9150///
9151/// Uses bit manipulation for round-to-nearest-even bf16 conversion.
9152/// Works on sm_52+ (no special bf16 hardware required).
9153#[cfg(feature = "cuda")]
9154pub(crate) fn gpu_f32_to_bf16(
9155    input: &CudaBuffer<f32>,
9156    device: &GpuDevice,
9157) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
9158    use cudarc::driver::PushKernelArg;
9159
9160    let n = input.len();
9161    if n == 0 {
9162        let empty = device.stream().alloc_zeros::<u16>(0)?;
9163        return Ok(empty);
9164    }
9165
9166    let ctx = device.context();
9167    let stream = device.stream();
9168
9169    let f = crate::module_cache::get_or_compile(
9170        ctx,
9171        F32_TO_BF16_PTX,
9172        "f32_to_bf16_kernel",
9173        device.ordinal() as u32,
9174    )
9175    .map_err(|_| GpuError::PtxCompileFailed {
9176        kernel: "f32_to_bf16_kernel",
9177    })?;
9178
9179    let mut out = stream.alloc_zeros::<u16>(n)?;
9180    let cfg = launch_cfg(n)?;
9181    let n_u32 = n as u32;
9182
9183    unsafe {
9184        stream
9185            .launch_builder(&f)
9186            .arg(input.inner())
9187            .arg(&mut out)
9188            .arg(&n_u32)
9189            .launch(cfg)?;
9190    }
9191
9192    Ok(out)
9193}
9194
9195/// Stub -- always returns [`GpuError::NoCudaFeature`].
9196#[cfg(not(feature = "cuda"))]
9197pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
9198    Err(GpuError::NoCudaFeature)
9199}
9200
9201// ---------------------------------------------------------------------------
9202// Tests -- require a real CUDA GPU
9203// ---------------------------------------------------------------------------
9204
9205#[cfg(test)]
9206#[cfg(feature = "cuda")]
9207mod tests {
9208    use super::*;
9209
9210    /// Helper: set up device + upload a slice.
9211    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
9212        let dev = GpuDevice::new(0).expect("CUDA device 0");
9213        let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
9214        (dev, buf)
9215    }
9216
9217    /// Round-trip helper: download a GPU buffer and compare against expected
9218    /// CPU output element-wise.
9219    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
9220        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
9221        assert_eq!(host.len(), expected.len(), "length mismatch");
9222        for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
9223            assert!(
9224                (got - exp).abs() < 1e-6,
9225                "element {i}: got {got}, expected {exp}",
9226            );
9227        }
9228    }
9229
9230    // -- gpu_add -------------------------------------------------------------
9231
9232    #[test]
9233    fn add_basic() {
9234        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
9235        let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
9236        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
9237
9238        let (dev, a) = setup(&a_data);
9239        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9240        let out = gpu_add(&a, &b, &dev).expect("gpu_add");
9241        assert_buf_eq(&out, &dev, &expected);
9242    }
9243
9244    #[test]
9245    fn add_empty() {
9246        let (dev, a) = setup(&[]);
9247        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
9248        let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
9249        assert_eq!(out.len(), 0);
9250    }
9251
9252    #[test]
9253    fn add_large() {
9254        let n = 100_000;
9255        let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
9256        let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
9257        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
9258
9259        let (dev, a) = setup(&a_data);
9260        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9261        let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
9262        assert_buf_eq(&out, &dev, &expected);
9263    }
9264
9265    #[test]
9266    fn add_length_mismatch() {
9267        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
9268        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
9269        let err = gpu_add(&a, &b, &dev).unwrap_err();
9270        match err {
9271            GpuError::LengthMismatch { a: 3, b: 2 } => {}
9272            other => panic!("unexpected error: {other}"),
9273        }
9274    }
9275
9276    // -- gpu_sub -------------------------------------------------------------
9277
9278    #[test]
9279    fn sub_basic() {
9280        let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
9281        let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
9282        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();
9283
9284        let (dev, a) = setup(&a_data);
9285        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9286        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
9287        assert_buf_eq(&out, &dev, &expected);
9288    }
9289
9290    #[test]
9291    fn sub_negative_result() {
9292        let a_data = vec![1.0f32, 2.0];
9293        let b_data = vec![5.0f32, 10.0];
9294        let expected: Vec<f32> = vec![-4.0, -8.0];
9295
9296        let (dev, a) = setup(&a_data);
9297        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9298        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
9299        assert_buf_eq(&out, &dev, &expected);
9300    }
9301
9302    // -- gpu_mul -------------------------------------------------------------
9303
9304    #[test]
9305    fn mul_basic() {
9306        let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
9307        let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
9308        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();
9309
9310        let (dev, a) = setup(&a_data);
9311        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9312        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
9313        assert_buf_eq(&out, &dev, &expected);
9314    }
9315
9316    #[test]
9317    fn mul_by_zero() {
9318        let a_data = vec![1.0f32, 2.0, 3.0];
9319        let b_data = vec![0.0f32, 0.0, 0.0];
9320        let expected = vec![0.0f32, 0.0, 0.0];
9321
9322        let (dev, a) = setup(&a_data);
9323        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9324        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
9325        assert_buf_eq(&out, &dev, &expected);
9326    }
9327
9328    // -- gpu_neg -------------------------------------------------------------
9329
9330    #[test]
9331    fn neg_basic() {
9332        let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
9333        let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();
9334
9335        let (dev, a) = setup(&a_data);
9336        let out = gpu_neg(&a, &dev).expect("gpu_neg");
9337        assert_buf_eq(&out, &dev, &expected);
9338    }
9339
9340    #[test]
9341    fn neg_double_negation() {
9342        let a_data = vec![1.0f32, -2.0, 3.0];
9343        let (dev, a) = setup(&a_data);
9344        let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
9345        let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
9346        assert_buf_eq(&neg2, &dev, &a_data);
9347    }
9348
9349    // -- gpu_relu ------------------------------------------------------------
9350
9351    #[test]
9352    fn relu_basic() {
9353        let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
9354        let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];
9355
9356        let (dev, a) = setup(&a_data);
9357        let out = gpu_relu(&a, &dev).expect("gpu_relu");
9358        assert_buf_eq(&out, &dev, &expected);
9359    }
9360
9361    #[test]
9362    fn relu_all_negative() {
9363        let a_data = vec![-5.0f32, -0.1, -100.0];
9364        let expected = vec![0.0f32, 0.0, 0.0];
9365
9366        let (dev, a) = setup(&a_data);
9367        let out = gpu_relu(&a, &dev).expect("gpu_relu");
9368        assert_buf_eq(&out, &dev, &expected);
9369    }
9370
9371    #[test]
9372    fn relu_all_positive() {
9373        let a_data = vec![0.1f32, 1.0, 100.0];
9374
9375        let (dev, a) = setup(&a_data);
9376        let out = gpu_relu(&a, &dev).expect("gpu_relu");
9377        assert_buf_eq(&out, &dev, &a_data);
9378    }
9379
9380    #[test]
9381    fn relu_empty() {
9382        let (dev, a) = setup(&[]);
9383        let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
9384        assert_eq!(out.len(), 0);
9385    }
9386
9387    #[test]
9388    fn small_matmul_2x2() {
9389        let dev = GpuDevice::new(0).expect("CUDA device 0");
9390        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
9391        // C = A@B = [[19, 22], [43, 50]]
9392        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
9393        let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
9394        let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
9395        assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
9396    }
9397
9398    #[test]
9399    fn small_matmul_1xk_kxn() {
9400        let dev = GpuDevice::new(0).expect("CUDA device 0");
9401        // A = [1, 2, 3] (1x3), B = [[1, 0], [0, 1], [1, 1]] (3x2)
9402        // C = [4, 5] (1x2)
9403        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
9404        let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
9405        let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
9406        assert_buf_eq(&c, &dev, &[4.0, 5.0]);
9407    }
9408
9409    #[test]
9410    fn small_matmul_vs_cublas() {
9411        // Compare our small matmul against cuBLAS for a realistic decode-step size.
9412        // Linear layer: [1, 64] @ [64, 64] = [1, 64]
9413        let dev = GpuDevice::new(0).expect("CUDA device 0");
9414        let m = 1;
9415        let k = 64;
9416        let n = 64;
9417
9418        // Deterministic data.
9419        let a_data: Vec<f32> = (0..m * k)
9420            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
9421            .collect();
9422        let b_data: Vec<f32> = (0..k * n)
9423            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
9424            .collect();
9425
9426        let a = cpu_to_gpu(&a_data, &dev).unwrap();
9427        let b = cpu_to_gpu(&b_data, &dev).unwrap();
9428
9429        // cuBLAS reference.
9430        let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
9431        let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();
9432
9433        // Our kernel.
9434        let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
9435        let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();
9436
9437        assert_eq!(cublas_result.len(), our_result.len());
9438        for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
9439            assert!(
9440                (cb - ours).abs() < 0.1,
9441                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
9442                (cb - ours).abs()
9443            );
9444        }
9445    }
9446}