Skip to main content

ferrotorch_gpu/
kernels.rs

1//! Custom PTX CUDA kernels for elementwise GPU operations.
2//!
3//! Each operation has two code paths:
4//!
5//! 1. **PTX kernel** -- a hand-written PTX string that is loaded into the CUDA
6//!    driver at runtime via [`cudarc`]. This is the fast path and runs entirely
7//!    on the GPU.
8//! 2. **CPU fallback** -- copies data to the host, performs the operation with
9//!    standard Rust iterators, and copies the result back. Correct but slow;
10//!    used when the PTX module cannot be loaded (e.g. architecture mismatch).
11//!
12//! # Supported operations
13//!
14//! | Function | Formula |
15//! |----------|---------|
16//! | [`gpu_add`] | `out[i] = a[i] + b[i]` |
17//! | [`gpu_sub`] | `out[i] = a[i] - b[i]` |
18//! | [`gpu_mul`] | `out[i] = a[i] * b[i]` |
19//! | [`gpu_neg`] | `out[i] = -a[i]` |
20//! | [`gpu_relu`] | `out[i] = max(a[i], 0.0)` |
21
22#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, cpu_to_gpu, gpu_to_cpu};
30
31// ---------------------------------------------------------------------------
32// PTX kernel source strings
33// ---------------------------------------------------------------------------
34
35/// PTX source for `add_kernel`: `out[i] = a[i] + b[i]`.
36#[cfg(feature = "cuda")]
37pub(crate) const ADD_PTX: &str = "\
38.version 7.0
39.target sm_52
40.address_size 64
41
42.visible .entry add_kernel(
43    .param .u64 a_ptr,
44    .param .u64 b_ptr,
45    .param .u64 out_ptr,
46    .param .u32 n
47) {
48    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
49    .reg .u64 %a, %b, %out, %off;
50    .reg .f32 %va, %vb, %vr;
51    .reg .pred %p;
52
53    ld.param.u64 %a, [a_ptr];
54    ld.param.u64 %b, [b_ptr];
55    ld.param.u64 %out, [out_ptr];
56    ld.param.u32 %n_reg, [n];
57
58    mov.u32 %bid, %ctaid.x;
59    mov.u32 %bdim, %ntid.x;
60    mov.u32 %r_tid, %tid.x;
61    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
62
63    setp.ge.u32 %p, %r_tid, %n_reg;
64    @%p bra DONE;
65
66    cvt.u64.u32 %off, %r_tid;
67    shl.b64 %off, %off, 2;
68
69    add.u64 %a, %a, %off;
70    add.u64 %b, %b, %off;
71    add.u64 %out, %out, %off;
72
73    ld.global.f32 %va, [%a];
74    ld.global.f32 %vb, [%b];
75    add.f32 %vr, %va, %vb;
76    st.global.f32 [%out], %vr;
77
78DONE:
79    ret;
80}
81";
82
83/// PTX source for `add_vec4_kernel`: vectorized add, 4 elements per thread.
84///
85/// Uses `ld.global.v4.f32` (128-bit load) for 4x memory throughput vs scalar.
86/// Thread i processes elements [i*4 .. i*4+3].
87#[cfg(feature = "cuda")]
88pub(crate) const ADD_VEC4_PTX: &str = "\
89.version 7.0
90.target sm_52
91.address_size 64
92
93.visible .entry add_vec4_kernel(
94    .param .u64 a_ptr,
95    .param .u64 b_ptr,
96    .param .u64 out_ptr,
97    .param .u32 n4
98) {
99    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
100    .reg .u64 %a, %b, %out, %off;
101    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
102    .reg .pred %p;
103
104    ld.param.u64 %a, [a_ptr];
105    ld.param.u64 %b, [b_ptr];
106    ld.param.u64 %out, [out_ptr];
107    ld.param.u32 %n4_reg, [n4];
108
109    mov.u32 %bid, %ctaid.x;
110    mov.u32 %bdim, %ntid.x;
111    mov.u32 %r_tid, %tid.x;
112    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
113
114    setp.ge.u32 %p, %r_tid, %n4_reg;
115    @%p bra DONE;
116
117    // Byte offset = tid * 16 (4 floats × 4 bytes)
118    cvt.u64.u32 %off, %r_tid;
119    shl.b64 %off, %off, 4;
120
121    add.u64 %a, %a, %off;
122    add.u64 %b, %b, %off;
123    add.u64 %out, %out, %off;
124
125    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
126    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
127
128    add.f32 %r0, %a0, %b0;
129    add.f32 %r1, %a1, %b1;
130    add.f32 %r2, %a2, %b2;
131    add.f32 %r3, %a3, %b3;
132
133    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
134
135DONE:
136    ret;
137}
138";
139
140/// PTX source for `mul_vec4_kernel`: vectorized multiply, 4 elements per thread.
141#[cfg(feature = "cuda")]
142pub(crate) const MUL_VEC4_PTX: &str = "\
143.version 7.0
144.target sm_52
145.address_size 64
146
147.visible .entry mul_vec4_kernel(
148    .param .u64 a_ptr,
149    .param .u64 b_ptr,
150    .param .u64 out_ptr,
151    .param .u32 n4
152) {
153    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
154    .reg .u64 %a, %b, %out, %off;
155    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
156    .reg .pred %p;
157
158    ld.param.u64 %a, [a_ptr];
159    ld.param.u64 %b, [b_ptr];
160    ld.param.u64 %out, [out_ptr];
161    ld.param.u32 %n4_reg, [n4];
162
163    mov.u32 %bid, %ctaid.x;
164    mov.u32 %bdim, %ntid.x;
165    mov.u32 %r_tid, %tid.x;
166    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
167
168    setp.ge.u32 %p, %r_tid, %n4_reg;
169    @%p bra DONE;
170
171    cvt.u64.u32 %off, %r_tid;
172    shl.b64 %off, %off, 4;
173
174    add.u64 %a, %a, %off;
175    add.u64 %b, %b, %off;
176    add.u64 %out, %out, %off;
177
178    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
179    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
180
181    mul.f32 %r0, %a0, %b0;
182    mul.f32 %r1, %a1, %b1;
183    mul.f32 %r2, %a2, %b2;
184    mul.f32 %r3, %a3, %b3;
185
186    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
187
188DONE:
189    ret;
190}
191";
192
193/// PTX source for `sub_kernel`: `out[i] = a[i] - b[i]`.
194#[cfg(feature = "cuda")]
195pub(crate) const SUB_PTX: &str = "\
196.version 7.0
197.target sm_52
198.address_size 64
199
200.visible .entry sub_kernel(
201    .param .u64 a_ptr,
202    .param .u64 b_ptr,
203    .param .u64 out_ptr,
204    .param .u32 n
205) {
206    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
207    .reg .u64 %a, %b, %out, %off;
208    .reg .f32 %va, %vb, %vr;
209    .reg .pred %p;
210
211    ld.param.u64 %a, [a_ptr];
212    ld.param.u64 %b, [b_ptr];
213    ld.param.u64 %out, [out_ptr];
214    ld.param.u32 %n_reg, [n];
215
216    mov.u32 %bid, %ctaid.x;
217    mov.u32 %bdim, %ntid.x;
218    mov.u32 %r_tid, %tid.x;
219    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
220
221    setp.ge.u32 %p, %r_tid, %n_reg;
222    @%p bra DONE;
223
224    cvt.u64.u32 %off, %r_tid;
225    shl.b64 %off, %off, 2;
226
227    add.u64 %a, %a, %off;
228    add.u64 %b, %b, %off;
229    add.u64 %out, %out, %off;
230
231    ld.global.f32 %va, [%a];
232    ld.global.f32 %vb, [%b];
233    sub.f32 %vr, %va, %vb;
234    st.global.f32 [%out], %vr;
235
236DONE:
237    ret;
238}
239";
240
241/// PTX source for `mul_kernel`: `out[i] = a[i] * b[i]`.
242#[cfg(feature = "cuda")]
243pub(crate) const MUL_PTX: &str = "\
244.version 7.0
245.target sm_52
246.address_size 64
247
248.visible .entry mul_kernel(
249    .param .u64 a_ptr,
250    .param .u64 b_ptr,
251    .param .u64 out_ptr,
252    .param .u32 n
253) {
254    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
255    .reg .u64 %a, %b, %out, %off;
256    .reg .f32 %va, %vb, %vr;
257    .reg .pred %p;
258
259    ld.param.u64 %a, [a_ptr];
260    ld.param.u64 %b, [b_ptr];
261    ld.param.u64 %out, [out_ptr];
262    ld.param.u32 %n_reg, [n];
263
264    mov.u32 %bid, %ctaid.x;
265    mov.u32 %bdim, %ntid.x;
266    mov.u32 %r_tid, %tid.x;
267    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
268
269    setp.ge.u32 %p, %r_tid, %n_reg;
270    @%p bra DONE;
271
272    cvt.u64.u32 %off, %r_tid;
273    shl.b64 %off, %off, 2;
274
275    add.u64 %a, %a, %off;
276    add.u64 %b, %b, %off;
277    add.u64 %out, %out, %off;
278
279    ld.global.f32 %va, [%a];
280    ld.global.f32 %vb, [%b];
281    mul.f32 %vr, %va, %vb;
282    st.global.f32 [%out], %vr;
283
284DONE:
285    ret;
286}
287";
288
289/// PTX source for `neg_kernel`: `out[i] = -a[i]`.
290#[cfg(feature = "cuda")]
291pub(crate) const NEG_PTX: &str = "\
292.version 7.0
293.target sm_52
294.address_size 64
295
296.visible .entry neg_kernel(
297    .param .u64 a_ptr,
298    .param .u64 out_ptr,
299    .param .u32 n
300) {
301    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
302    .reg .u64 %a, %out, %off;
303    .reg .f32 %va, %vr;
304    .reg .pred %p;
305
306    ld.param.u64 %a, [a_ptr];
307    ld.param.u64 %out, [out_ptr];
308    ld.param.u32 %n_reg, [n];
309
310    mov.u32 %bid, %ctaid.x;
311    mov.u32 %bdim, %ntid.x;
312    mov.u32 %r_tid, %tid.x;
313    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
314
315    setp.ge.u32 %p, %r_tid, %n_reg;
316    @%p bra DONE;
317
318    cvt.u64.u32 %off, %r_tid;
319    shl.b64 %off, %off, 2;
320
321    add.u64 %a, %a, %off;
322    add.u64 %out, %out, %off;
323
324    ld.global.f32 %va, [%a];
325    neg.f32 %vr, %va;
326    st.global.f32 [%out], %vr;
327
328DONE:
329    ret;
330}
331";
332
333/// PTX source for `relu_kernel`: `out[i] = max(a[i], 0.0)`.
334#[cfg(feature = "cuda")]
335pub(crate) const RELU_PTX: &str = "\
336.version 7.0
337.target sm_52
338.address_size 64
339
340.visible .entry relu_kernel(
341    .param .u64 a_ptr,
342    .param .u64 out_ptr,
343    .param .u32 n
344) {
345    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
346    .reg .u64 %a, %out, %off;
347    .reg .f32 %va, %vr, %zero;
348    .reg .pred %p;
349
350    ld.param.u64 %a, [a_ptr];
351    ld.param.u64 %out, [out_ptr];
352    ld.param.u32 %n_reg, [n];
353
354    mov.u32 %bid, %ctaid.x;
355    mov.u32 %bdim, %ntid.x;
356    mov.u32 %r_tid, %tid.x;
357    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
358
359    setp.ge.u32 %p, %r_tid, %n_reg;
360    @%p bra DONE;
361
362    cvt.u64.u32 %off, %r_tid;
363    shl.b64 %off, %off, 2;
364
365    add.u64 %a, %a, %off;
366    add.u64 %out, %out, %off;
367
368    ld.global.f32 %va, [%a];
369    mov.f32 %zero, 0f00000000;
370    max.f32 %vr, %va, %zero;
371    st.global.f32 [%out], %vr;
372
373DONE:
374    ret;
375}
376";
377
378/// PTX source for `scale_kernel`: `out[i] = a[i] * scalar`.
379#[cfg(feature = "cuda")]
380pub(crate) const SCALE_PTX: &str = "\
381.version 7.0
382.target sm_52
383.address_size 64
384
385.visible .entry scale_kernel(
386    .param .u64 a_ptr,
387    .param .u64 out_ptr,
388    .param .f32 scalar,
389    .param .u32 n
390) {
391    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
392    .reg .u64 %a, %out, %off;
393    .reg .f32 %va, %vr, %s;
394    .reg .pred %p;
395
396    ld.param.u64 %a, [a_ptr];
397    ld.param.u64 %out, [out_ptr];
398    ld.param.f32 %s, [scalar];
399    ld.param.u32 %n_reg, [n];
400
401    mov.u32 %bid, %ctaid.x;
402    mov.u32 %bdim, %ntid.x;
403    mov.u32 %r_tid, %tid.x;
404    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
405
406    setp.ge.u32 %p, %r_tid, %n_reg;
407    @%p bra DONE;
408
409    cvt.u64.u32 %off, %r_tid;
410    shl.b64 %off, %off, 2;
411
412    add.u64 %a, %a, %off;
413    add.u64 %out, %out, %off;
414
415    ld.global.f32 %va, [%a];
416    mul.f32 %vr, %va, %s;
417    st.global.f32 [%out], %vr;
418
419DONE:
420    ret;
421}
422";
423
424/// PTX for 2D matrix transpose: `out[j * M + i] = in[i * N + j]`.
425/// Thread `tid` maps to output index; computes the corresponding input index.
426#[cfg(feature = "cuda")]
427pub(crate) const TRANSPOSE_2D_PTX: &str = "\
428.version 7.0\n\
429.target sm_52\n\
430.address_size 64\n\
431\n\
432.visible .entry transpose_2d_kernel(\n\
433    .param .u64 in_ptr,\n\
434    .param .u64 out_ptr,\n\
435    .param .u32 M,\n\
436    .param .u32 N,\n\
437    .param .u32 total\n\
438) {\n\
439    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
440    .reg .u32 %out_row, %out_col, %in_idx;\n\
441    .reg .u64 %in, %out, %off_in, %off_out;\n\
442    .reg .f32 %val;\n\
443    .reg .pred %p;\n\
444\n\
445    ld.param.u64 %in, [in_ptr];\n\
446    ld.param.u64 %out, [out_ptr];\n\
447    ld.param.u32 %M_reg, [M];\n\
448    ld.param.u32 %N_reg, [N];\n\
449    ld.param.u32 %total_reg, [total];\n\
450\n\
451    mov.u32 %bid, %ctaid.x;\n\
452    mov.u32 %bdim, %ntid.x;\n\
453    mov.u32 %r_tid, %tid.x;\n\
454    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
455\n\
456    setp.ge.u32 %p, %r_tid, %total_reg;\n\
457    @%p bra DONE;\n\
458\n\
459    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
460    div.u32 %out_row, %r_tid, %M_reg;\n\
461    rem.u32 %out_col, %r_tid, %M_reg;\n\
462    // Input index: out_col * N + out_row (transposed).\n\
463    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
464\n\
465    cvt.u64.u32 %off_in, %in_idx;\n\
466    shl.b64 %off_in, %off_in, 2;\n\
467    add.u64 %off_in, %in, %off_in;\n\
468    ld.global.f32 %val, [%off_in];\n\
469\n\
470    cvt.u64.u32 %off_out, %r_tid;\n\
471    shl.b64 %off_out, %off_out, 2;\n\
472    add.u64 %off_out, %out, %off_out;\n\
473    st.global.f32 [%off_out], %val;\n\
474\n\
475DONE:\n\
476    ret;\n\
477}\n\
478";
479
480// ---------------------------------------------------------------------------
481// 4D permute (0,2,1,3) PTX kernel — swap dims 1 and 2
482// ---------------------------------------------------------------------------
483// Input:  [d0, d1, d2, d3]
484// Output: [d0, d2, d1, d3]
485// Thread i computes output[i] by mapping to the transposed input index.
486
487#[cfg(feature = "cuda")]
488pub(crate) const PERMUTE_0213_PTX: &str = "\
489.version 7.0\n\
490.target sm_52\n\
491.address_size 64\n\
492\n\
493.visible .entry permute_0213_kernel(\n\
494    .param .u64 in_ptr,\n\
495    .param .u64 out_ptr,\n\
496    .param .u32 d0,\n\
497    .param .u32 d1,\n\
498    .param .u32 d2,\n\
499    .param .u32 d3,\n\
500    .param .u32 total\n\
501) {\n\
502    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
503    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
504    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
505    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
506    .reg .u64 %in, %out, %off_in, %off_out;\n\
507    .reg .f32 %val;\n\
508    .reg .pred %p;\n\
509\n\
510    ld.param.u64 %in, [in_ptr];\n\
511    ld.param.u64 %out, [out_ptr];\n\
512    ld.param.u32 %d0r, [d0];\n\
513    ld.param.u32 %d1r, [d1];\n\
514    ld.param.u32 %d2r, [d2];\n\
515    ld.param.u32 %d3r, [d3];\n\
516    ld.param.u32 %total_reg, [total];\n\
517\n\
518    mov.u32 %bid, %ctaid.x;\n\
519    mov.u32 %bdim, %ntid.x;\n\
520    mov.u32 %r_tid, %tid.x;\n\
521    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
522\n\
523    setp.ge.u32 %p, %r_tid, %total_reg;\n\
524    @%p bra DONE;\n\
525\n\
526    // Output shape: [d0, d2, d1, d3]\n\
527    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
528    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
529    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
530\n\
531    div.u32 %i0, %r_tid, %s_out1;\n\
532    rem.u32 %rem, %r_tid, %s_out1;\n\
533    div.u32 %i2, %rem, %s_out2;\n\
534    rem.u32 %rem, %rem, %s_out2;\n\
535    div.u32 %i1, %rem, %d3r;\n\
536    rem.u32 %i3, %rem, %d3r;\n\
537\n\
538    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
539    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
540    mul.lo.u32 %in_idx, %i0, %d1r;\n\
541    add.u32 %in_idx, %in_idx, %i1;\n\
542    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
543    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
544    add.u32 %in_idx, %in_idx, %i3;\n\
545\n\
546    cvt.u64.u32 %off_in, %in_idx;\n\
547    shl.b64 %off_in, %off_in, 2;\n\
548    add.u64 %off_in, %in, %off_in;\n\
549    ld.global.f32 %val, [%off_in];\n\
550\n\
551    cvt.u64.u32 %off_out, %r_tid;\n\
552    shl.b64 %off_out, %off_out, 2;\n\
553    add.u64 %off_out, %out, %off_out;\n\
554    st.global.f32 [%off_out], %val;\n\
555\n\
556DONE:\n\
557    ret;\n\
558}\n\
559";
560
561// ---------------------------------------------------------------------------
562// f32-to-f16 conversion PTX kernel: out_f16[i] = float2half(in_f32[i])
563// ---------------------------------------------------------------------------
564// Used by gpu_matmul_f16 to cast f32 inputs to f16 on-GPU before calling
565// cublasGemmEx. The output is stored as u16 (IEEE 754 half-precision bits).
566
567#[cfg(feature = "cuda")]
568pub(crate) const F32_TO_F16_PTX: &str = "\
569.version 7.0
570.target sm_52
571.address_size 64
572
573.visible .entry f32_to_f16_kernel(
574    .param .u64 in_ptr,
575    .param .u64 out_ptr,
576    .param .u32 n
577) {
578    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
579    .reg .u64 %in, %out, %off_in, %off_out;
580    .reg .f32 %vf;
581    .reg .b16 %vh;
582    .reg .pred %p;
583
584    ld.param.u64 %in, [in_ptr];
585    ld.param.u64 %out, [out_ptr];
586    ld.param.u32 %n_reg, [n];
587
588    mov.u32 %bid, %ctaid.x;
589    mov.u32 %bdim, %ntid.x;
590    mov.u32 %r_tid, %tid.x;
591    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
592
593    setp.ge.u32 %p, %r_tid, %n_reg;
594    @%p bra DONE;
595
596    // Compute input offset: i * 4 (f32 = 4 bytes)
597    cvt.u64.u32 %off_in, %r_tid;
598    shl.b64 %off_in, %off_in, 2;
599    add.u64 %in, %in, %off_in;
600
601    // Compute output offset: i * 2 (f16 = 2 bytes)
602    cvt.u64.u32 %off_out, %r_tid;
603    shl.b64 %off_out, %off_out, 1;
604    add.u64 %out, %out, %off_out;
605
606    // Load f32, convert to f16 (round-to-nearest-even), store as u16
607    ld.global.f32 %vf, [%in];
608    cvt.rn.f16.f32 %vh, %vf;
609    st.global.b16 [%out], %vh;
610
611DONE:
612    ret;
613}
614";
615
616/// PTX source for `f32_to_bf16_kernel`: convert f32 → bf16 (stored as u16).
617///
618/// BF16 is the top 16 bits of f32 with round-to-nearest-even. We do this
619/// with integer bit ops: add rounding bias 0x7FFF + bit 16 of the value,
620/// then shift right 16. This works on sm_52+ (no special bf16 instructions
621/// needed).
622#[cfg(feature = "cuda")]
623pub(crate) const F32_TO_BF16_PTX: &str = "\
624.version 7.0
625.target sm_52
626.address_size 64
627
628.visible .entry f32_to_bf16_kernel(
629    .param .u64 in_ptr,
630    .param .u64 out_ptr,
631    .param .u32 n
632) {
633    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
634    .reg .u64 %in, %out, %off_in, %off_out;
635    .reg .f32 %vf;
636    .reg .u32 %bits, %round, %lsb, %result;
637    .reg .pred %p;
638
639    ld.param.u64 %in, [in_ptr];
640    ld.param.u64 %out, [out_ptr];
641    ld.param.u32 %n_reg, [n];
642
643    mov.u32 %bid, %ctaid.x;
644    mov.u32 %bdim, %ntid.x;
645    mov.u32 %r_tid, %tid.x;
646    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
647
648    setp.ge.u32 %p, %r_tid, %n_reg;
649    @%p bra DONE;
650
651    cvt.u64.u32 %off_in, %r_tid;
652    shl.b64 %off_in, %off_in, 2;
653    add.u64 %in, %in, %off_in;
654
655    cvt.u64.u32 %off_out, %r_tid;
656    shl.b64 %off_out, %off_out, 1;
657    add.u64 %out, %out, %off_out;
658
659    // Load f32 as raw bits
660    ld.global.u32 %bits, [%in];
661
662    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
663    shr.u32 %lsb, %bits, 16;
664    and.b32 %lsb, %lsb, 1;
665    add.u32 %round, %bits, 0x7FFF;
666    add.u32 %round, %round, %lsb;
667    shr.u32 %result, %round, 16;
668
669    // Store as u16
670    st.global.u16 [%out], %result;
671
672DONE:
673    ret;
674}
675";
676
677// ---------------------------------------------------------------------------
678// Small matmul PTX kernel: C = A @ B, one thread per output element
679// ---------------------------------------------------------------------------
680// For small matrices where cuBLAS JIT compilation overhead > compute time.
681// Compiles once via module_cache, never JIT-recompiles for different sizes.
682
683#[cfg(feature = "cuda")]
684pub(crate) const SMALL_MATMUL_PTX: &str = "\
685.version 7.0
686.target sm_52
687.address_size 64
688
689.visible .entry small_matmul_kernel(
690    .param .u64 a_ptr,
691    .param .u64 b_ptr,
692    .param .u64 c_ptr,
693    .param .u32 M,
694    .param .u32 K,
695    .param .u32 N,
696    .param .u32 total
697) {
698    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
699    .reg .u32 %row, %col, %p, %idx;
700    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
701    .reg .f32 %sum, %va, %vb;
702    .reg .pred %bounds_p, %loop_p;
703
704    ld.param.u64 %a, [a_ptr];
705    ld.param.u64 %b, [b_ptr];
706    ld.param.u64 %c, [c_ptr];
707    ld.param.u32 %M_reg, [M];
708    ld.param.u32 %K_reg, [K];
709    ld.param.u32 %N_reg, [N];
710    ld.param.u32 %total_reg, [total];
711
712    mov.u32 %bid, %ctaid.x;
713    mov.u32 %bdim, %ntid.x;
714    mov.u32 %r_tid, %tid.x;
715    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
716
717    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
718    @%bounds_p bra DONE;
719
720    div.u32 %row, %r_tid, %N_reg;
721    rem.u32 %col, %r_tid, %N_reg;
722
723    mov.f32 %sum, 0f00000000;
724    mov.u32 %p, 0;
725DOT:
726    setp.ge.u32 %loop_p, %p, %K_reg;
727    @%loop_p bra DOT_DONE;
728
729    mad.lo.u32 %idx, %row, %K_reg, %p;
730    cvt.u64.u32 %a_off, %idx;
731    shl.b64 %a_off, %a_off, 2;
732    add.u64 %a_off, %a, %a_off;
733    ld.global.f32 %va, [%a_off];
734
735    mad.lo.u32 %idx, %p, %N_reg, %col;
736    cvt.u64.u32 %b_off, %idx;
737    shl.b64 %b_off, %b_off, 2;
738    add.u64 %b_off, %b, %b_off;
739    ld.global.f32 %vb, [%b_off];
740
741    fma.rn.f32 %sum, %va, %vb, %sum;
742    add.u32 %p, %p, 1;
743    bra DOT;
744DOT_DONE:
745
746    cvt.u64.u32 %c_off, %r_tid;
747    shl.b64 %c_off, %c_off, 2;
748    add.u64 %c_off, %c, %c_off;
749    st.global.f32 [%c_off], %sum;
750
751DONE:
752    ret;
753}
754";
755
756// ---------------------------------------------------------------------------
757// Slice-write PTX kernel: copy [N, D] into row `pos` of [N, max_len, D]
758// ---------------------------------------------------------------------------
759
760#[cfg(feature = "cuda")]
761pub(crate) const SLICE_WRITE_PTX: &str = "\
762.version 7.0
763.target sm_52
764.address_size 64
765
766.visible .entry slice_write_kernel(
767    .param .u64 src_ptr,
768    .param .u64 dst_ptr,
769    .param .u32 n,
770    .param .u32 D,
771    .param .u32 max_len,
772    .param .u32 pos
773) {
774    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
775    .reg .u32 %batch_idx, %d_idx, %dst_row;
776    .reg .u64 %src, %dst, %src_off, %dst_off;
777    .reg .f32 %val;
778    .reg .pred %p;
779
780    ld.param.u64 %src, [src_ptr];
781    ld.param.u64 %dst, [dst_ptr];
782    ld.param.u32 %n_reg, [n];
783    ld.param.u32 %D_reg, [D];
784    ld.param.u32 %max_len_reg, [max_len];
785    ld.param.u32 %pos_reg, [pos];
786
787    mov.u32 %bid, %ctaid.x;
788    mov.u32 %bdim, %ntid.x;
789    mov.u32 %r_tid, %tid.x;
790    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
791
792    setp.ge.u32 %p, %r_tid, %n_reg;
793    @%p bra DONE;
794
795    cvt.u64.u32 %src_off, %r_tid;
796    shl.b64 %src_off, %src_off, 2;
797    add.u64 %src, %src, %src_off;
798    ld.global.f32 %val, [%src];
799
800    div.u32 %batch_idx, %r_tid, %D_reg;
801    rem.u32 %d_idx, %r_tid, %D_reg;
802    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
803    add.u32 %dst_row, %dst_row, %pos_reg;
804    mul.lo.u32 %dst_row, %dst_row, %D_reg;
805    add.u32 %dst_row, %dst_row, %d_idx;
806    cvt.u64.u32 %dst_off, %dst_row;
807    shl.b64 %dst_off, %dst_off, 2;
808    add.u64 %dst, %dst, %dst_off;
809    st.global.f32 [%dst], %val;
810
811DONE:
812    ret;
813}
814";
815
816/// PTX for `slice_write_indirect_kernel`: same as `slice_write_kernel` but
817/// reads `pos` from a device pointer. This enables CUDA graph capture — the
818/// graph records the pointer address (fixed), and we update the u32 value
819/// at that address before each graph replay.
820#[cfg(feature = "cuda")]
821pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
822.version 7.0
823.target sm_52
824.address_size 64
825
826.visible .entry slice_write_indirect_kernel(
827    .param .u64 src_ptr,
828    .param .u64 dst_ptr,
829    .param .u32 n,
830    .param .u32 D,
831    .param .u32 max_len,
832    .param .u64 pos_ptr
833) {
834    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
835    .reg .u32 %batch_idx, %d_idx, %dst_row;
836    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
837    .reg .f32 %val;
838    .reg .pred %p;
839
840    ld.param.u64 %src, [src_ptr];
841    ld.param.u64 %dst, [dst_ptr];
842    ld.param.u32 %n_reg, [n];
843    ld.param.u32 %D_reg, [D];
844    ld.param.u32 %max_len_reg, [max_len];
845    ld.param.u64 %pos_p, [pos_ptr];
846    ld.global.u32 %pos_reg, [%pos_p];
847
848    mov.u32 %bid, %ctaid.x;
849    mov.u32 %bdim, %ntid.x;
850    mov.u32 %r_tid, %tid.x;
851    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
852
853    setp.ge.u32 %p, %r_tid, %n_reg;
854    @%p bra DONE;
855
856    cvt.u64.u32 %src_off, %r_tid;
857    shl.b64 %src_off, %src_off, 2;
858    add.u64 %src, %src, %src_off;
859    ld.global.f32 %val, [%src];
860
861    div.u32 %batch_idx, %r_tid, %D_reg;
862    rem.u32 %d_idx, %r_tid, %D_reg;
863    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
864    add.u32 %dst_row, %dst_row, %pos_reg;
865    mul.lo.u32 %dst_row, %dst_row, %D_reg;
866    add.u32 %dst_row, %dst_row, %d_idx;
867    cvt.u64.u32 %dst_off, %dst_row;
868    shl.b64 %dst_off, %dst_off, 2;
869    add.u64 %dst, %dst, %dst_off;
870    st.global.f32 [%dst], %val;
871
872DONE:
873    ret;
874}
875";
876
877/// PTX for `causal_mask_indirect_kernel`: builds an attention mask where
878/// `out[h, col] = 0.0` for `col < total_len` and `-1e9` for `col >= total_len`.
879/// `total_len` is read from a device pointer (for CUDA graph capture).
880/// Output shape: `[n_head, max_pos]` — one mask row per head (all identical).
881/// Thread `tid` maps to flat index; column = `tid % max_pos`.
882#[cfg(feature = "cuda")]
883pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
884.version 7.0
885.target sm_52
886.address_size 64
887
888.visible .entry causal_mask_indirect_kernel(
889    .param .u64 total_len_ptr,
890    .param .u64 out_ptr,
891    .param .u32 max_pos,
892    .param .u32 total
893) {
894    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
895    .reg .u64 %out, %off, %tl_p;
896    .reg .f32 %val;
897    .reg .pred %bounds_p, %mask_p;
898
899    ld.param.u64 %tl_p, [total_len_ptr];
900    ld.param.u64 %out, [out_ptr];
901    ld.param.u32 %max_pos_reg, [max_pos];
902    ld.param.u32 %total_reg, [total];
903
904    ld.global.u32 %tlen, [%tl_p];
905
906    mov.u32 %bid, %ctaid.x;
907    mov.u32 %bdim, %ntid.x;
908    mov.u32 %r_tid, %tid.x;
909    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
910
911    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
912    @%bounds_p bra DONE;
913
914    rem.u32 %col, %r_tid, %max_pos_reg;
915    setp.lt.u32 %mask_p, %col, %tlen;
916    @%mask_p bra WRITE_ZERO;
917
918    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
919    // to effectively zero out masked positions after softmax.
920    mov.f32 %val, 0fCE6E6B28;
921    bra WRITE;
922
923WRITE_ZERO:
924    mov.f32 %val, 0f00000000;
925
926WRITE:
927    cvt.u64.u32 %off, %r_tid;
928    shl.b64 %off, %off, 2;
929    add.u64 %out, %out, %off;
930    st.global.f32 [%out], %val;
931
932DONE:
933    ret;
934}
935";
936
937// ---------------------------------------------------------------------------
938// Embedding lookup PTX kernel: output[d] = weight[token_id * D + d]
939// ---------------------------------------------------------------------------
940
941#[cfg(feature = "cuda")]
942pub(crate) const EMBED_LOOKUP_PTX: &str = "\
943.version 7.0
944.target sm_52
945.address_size 64
946
947.visible .entry embed_lookup_kernel(
948    .param .u64 idx_ptr,
949    .param .u64 weight_ptr,
950    .param .u64 out_ptr,
951    .param .u32 D
952) {
953    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
954    .reg .u64 %idx_addr, %w, %out, %off;
955    .reg .f32 %idx_f, %val;
956    .reg .pred %p;
957
958    ld.param.u64 %idx_addr, [idx_ptr];
959    ld.param.u64 %w, [weight_ptr];
960    ld.param.u64 %out, [out_ptr];
961    ld.param.u32 %D_reg, [D];
962
963    mov.u32 %bid, %ctaid.x;
964    mov.u32 %bdim, %ntid.x;
965    mov.u32 %r_tid, %tid.x;
966    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
967
968    setp.ge.u32 %p, %r_tid, %D_reg;
969    @%p bra DONE;
970
971    ld.global.f32 %idx_f, [%idx_addr];
972    cvt.rzi.u32.f32 %row, %idx_f;
973
974    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
975    cvt.u64.u32 %off, %src_idx;
976    shl.b64 %off, %off, 2;
977    add.u64 %off, %w, %off;
978    ld.global.f32 %val, [%off];
979
980    cvt.u64.u32 %off, %r_tid;
981    shl.b64 %off, %off, 2;
982    add.u64 %off, %out, %off;
983    st.global.f32 [%off], %val;
984
985DONE:
986    ret;
987}
988";
989
990// ---------------------------------------------------------------------------
991// Batch embedding lookup PTX kernel
992// ---------------------------------------------------------------------------
993// Given N f32 indices and a weight matrix [V, D], gather N rows into [N, D].
994// Thread `tid` computes one element: row = tid / D, col = tid % D.
995// out[tid] = weight[indices[row] * D + col]
996
997#[cfg(feature = "cuda")]
998pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
999.version 7.0
1000.target sm_52
1001.address_size 64
1002
1003.visible .entry embed_lookup_batch_kernel(
1004    .param .u64 idx_ptr,
1005    .param .u64 weight_ptr,
1006    .param .u64 out_ptr,
1007    .param .u32 D,
1008    .param .u32 total
1009) {
1010    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
1011    .reg .u32 %row, %col, %src_idx;
1012    .reg .u64 %idx_addr, %w, %out, %off;
1013    .reg .f32 %idx_f, %val;
1014    .reg .pred %p;
1015
1016    ld.param.u64 %idx_addr, [idx_ptr];
1017    ld.param.u64 %w, [weight_ptr];
1018    ld.param.u64 %out, [out_ptr];
1019    ld.param.u32 %D_reg, [D];
1020    ld.param.u32 %total_reg, [total];
1021
1022    mov.u32 %bid, %ctaid.x;
1023    mov.u32 %bdim, %ntid.x;
1024    mov.u32 %tid, %tid.x;
1025    mad.lo.u32 %tid, %bid, %bdim, %tid;
1026
1027    setp.ge.u32 %p, %tid, %total_reg;
1028    @%p bra DONE;
1029
1030    // row = tid / D, col = tid % D
1031    div.u32 %row, %tid, %D_reg;
1032    rem.u32 %col, %tid, %D_reg;
1033
1034    // Read indices[row] (f32 -> u32)
1035    cvt.u64.u32 %off, %row;
1036    shl.b64 %off, %off, 2;
1037    add.u64 %off, %idx_addr, %off;
1038    ld.global.f32 %idx_f, [%off];
1039    cvt.rzi.u32.f32 %src_idx, %idx_f;
1040
1041    // src_idx = indices[row] * D + col
1042    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
1043    cvt.u64.u32 %off, %src_idx;
1044    shl.b64 %off, %off, 2;
1045    add.u64 %off, %w, %off;
1046    ld.global.f32 %val, [%off];
1047
1048    // Write to out[tid]
1049    cvt.u64.u32 %off, %tid;
1050    shl.b64 %off, %off, 2;
1051    add.u64 %off, %out, %off;
1052    st.global.f32 [%off], %val;
1053
1054DONE:
1055    ret;
1056}
1057";
1058
1059// ---------------------------------------------------------------------------
1060// Scatter-add rows PTX kernel (for embedding backward)
1061// ---------------------------------------------------------------------------
1062// Given grad_output [N, D] and indices [N] (f32), atomically accumulate:
1063//   grad_weight[indices[row], col] += grad_output[row * D + col]
1064// Thread `tid` handles one element: row = tid / D, col = tid % D.
1065
1066#[cfg(feature = "cuda")]
1067pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
1068.version 7.0
1069.target sm_52
1070.address_size 64
1071
1072.visible .entry scatter_add_rows_kernel(
1073    .param .u64 grad_output_ptr,
1074    .param .u64 indices_ptr,
1075    .param .u64 grad_weight_ptr,
1076    .param .u32 D,
1077    .param .u32 total
1078) {
1079    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
1080    .reg .u32 %row, %col, %dst_idx;
1081    .reg .u64 %go, %idx_addr, %gw, %off;
1082    .reg .f32 %idx_f, %grad_val, %dummy;
1083    .reg .pred %p;
1084
1085    ld.param.u64 %go, [grad_output_ptr];
1086    ld.param.u64 %idx_addr, [indices_ptr];
1087    ld.param.u64 %gw, [grad_weight_ptr];
1088    ld.param.u32 %D_reg, [D];
1089    ld.param.u32 %total_reg, [total];
1090
1091    mov.u32 %bid, %ctaid.x;
1092    mov.u32 %bdim, %ntid.x;
1093    mov.u32 %tid, %tid.x;
1094    mad.lo.u32 %tid, %bid, %bdim, %tid;
1095
1096    setp.ge.u32 %p, %tid, %total_reg;
1097    @%p bra DONE;
1098
1099    // row = tid / D, col = tid % D
1100    div.u32 %row, %tid, %D_reg;
1101    rem.u32 %col, %tid, %D_reg;
1102
1103    // Read grad_output[tid]
1104    cvt.u64.u32 %off, %tid;
1105    shl.b64 %off, %off, 2;
1106    add.u64 %off, %go, %off;
1107    ld.global.f32 %grad_val, [%off];
1108
1109    // Read indices[row] (f32 -> u32)
1110    cvt.u64.u32 %off, %row;
1111    shl.b64 %off, %off, 2;
1112    add.u64 %off, %idx_addr, %off;
1113    ld.global.f32 %idx_f, [%off];
1114    cvt.rzi.u32.f32 %dst_idx, %idx_f;
1115
1116    // dst_idx = indices[row] * D + col
1117    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
1118    cvt.u64.u32 %off, %dst_idx;
1119    shl.b64 %off, %off, 2;
1120    add.u64 %off, %gw, %off;
1121    atom.global.add.f32 %dummy, [%off], %grad_val;
1122
1123DONE:
1124    ret;
1125}
1126";
1127
1128// ---------------------------------------------------------------------------
1129// Slice-read PTX kernel: read first `len` rows from [N, max_len, D]
1130// ---------------------------------------------------------------------------
1131// Thread i writes: dst[i] = src[batch_idx * max_len * D + (i % (len*D))]
1132// where batch_idx = i / (len * D)
1133
1134#[cfg(feature = "cuda")]
1135pub(crate) const SLICE_READ_PTX: &str = "\
1136.version 7.0
1137.target sm_52
1138.address_size 64
1139
1140.visible .entry slice_read_kernel(
1141    .param .u64 src_ptr,
1142    .param .u64 dst_ptr,
1143    .param .u32 total,
1144    .param .u32 D,
1145    .param .u32 len,
1146    .param .u32 max_len
1147) {
1148    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
1149    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
1150    .reg .u32 %len_d;
1151    .reg .u64 %src, %dst, %src_off, %dst_off;
1152    .reg .f32 %val;
1153    .reg .pred %p;
1154
1155    ld.param.u64 %src, [src_ptr];
1156    ld.param.u64 %dst, [dst_ptr];
1157    ld.param.u32 %total_reg, [total];
1158    ld.param.u32 %D_reg, [D];
1159    ld.param.u32 %len_reg, [len];
1160    ld.param.u32 %max_len_reg, [max_len];
1161
1162    mov.u32 %bid, %ctaid.x;
1163    mov.u32 %bdim, %ntid.x;
1164    mov.u32 %r_tid, %tid.x;
1165    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1166
1167    setp.ge.u32 %p, %r_tid, %total_reg;
1168    @%p bra DONE;
1169
1170    // dst index = r_tid
1171    // batch_idx = r_tid / (len * D)
1172    // within = r_tid % (len * D)
1173    // row = within / D
1174    // col = within % D
1175    // src_idx = batch_idx * max_len * D + row * D + col
1176    mul.lo.u32 %len_d, %len_reg, %D_reg;
1177    div.u32 %batch_idx, %r_tid, %len_d;
1178    rem.u32 %within, %r_tid, %len_d;
1179    div.u32 %row, %within, %D_reg;
1180    rem.u32 %col, %within, %D_reg;
1181
1182    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
1183    add.u32 %src_idx, %src_idx, %row;
1184    mul.lo.u32 %src_idx, %src_idx, %D_reg;
1185    add.u32 %src_idx, %src_idx, %col;
1186
1187    cvt.u64.u32 %src_off, %src_idx;
1188    shl.b64 %src_off, %src_off, 2;
1189    add.u64 %src_off, %src, %src_off;
1190    ld.global.f32 %val, [%src_off];
1191
1192    cvt.u64.u32 %dst_off, %r_tid;
1193    shl.b64 %dst_off, %dst_off, 2;
1194    add.u64 %dst_off, %dst, %dst_off;
1195    st.global.f32 [%dst_off], %val;
1196
1197DONE:
1198    ret;
1199}
1200";
1201
1202// ---------------------------------------------------------------------------
1203// GELU PTX kernel: gelu(x) = x * sigmoid(1.702 * x)
1204//
1205// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
1206// for performance. These have reduced precision (~2^-22 relative error)
1207// compared to the full-precision variants, which is acceptable for neural
1208// network training/inference where f32 precision is already limited.
1209// ---------------------------------------------------------------------------
1210
1211#[cfg(feature = "cuda")]
1212pub(crate) const GELU_PTX: &str = "\
1213.version 7.0
1214.target sm_52
1215.address_size 64
1216
1217.visible .entry gelu_kernel(
1218    .param .u64 in_ptr,
1219    .param .u64 out_ptr,
1220    .param .u32 n
1221) {
1222    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1223    .reg .u64 %in, %out, %off;
1224    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
1225    .reg .pred %p;
1226
1227    ld.param.u64 %in, [in_ptr];
1228    ld.param.u64 %out, [out_ptr];
1229    ld.param.u32 %n_reg, [n];
1230
1231    mov.u32 %bid, %ctaid.x;
1232    mov.u32 %bdim, %ntid.x;
1233    mov.u32 %r_tid, %tid.x;
1234    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1235
1236    setp.ge.u32 %p, %r_tid, %n_reg;
1237    @%p bra DONE;
1238
1239    cvt.u64.u32 %off, %r_tid;
1240    shl.b64 %off, %off, 2;
1241    add.u64 %in, %in, %off;
1242    add.u64 %out, %out, %off;
1243
1244    ld.global.f32 %x, [%in];
1245
1246    mov.f32 %k, 0f3FDA2720;
1247    mul.f32 %neg_kx, %k, %x;
1248    neg.f32 %neg_kx, %neg_kx;
1249    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
1250    ex2.approx.f32 %exp_neg, %neg_kx;
1251    mov.f32 %one, 0f3F800000;
1252    add.f32 %denom, %one, %exp_neg;
1253    rcp.approx.f32 %sig, %denom;
1254    mul.f32 %result, %x, %sig;
1255    st.global.f32 [%out], %result;
1256
1257DONE:
1258    ret;
1259}
1260";
1261
1262// ---------------------------------------------------------------------------
1263// Backward activation kernels
1264// ---------------------------------------------------------------------------
1265
1266/// PTX source for `relu_backward_kernel`: `out[i] = (input[i] > 0) ? grad[i] : 0`.
1267/// Takes two inputs: grad (upstream gradient) and input (forward activation input).
1268#[cfg(feature = "cuda")]
1269pub(crate) const RELU_BACKWARD_PTX: &str = "\
1270.version 7.0
1271.target sm_52
1272.address_size 64
1273
1274.visible .entry relu_backward_kernel(
1275    .param .u64 grad_ptr,
1276    .param .u64 input_ptr,
1277    .param .u64 out_ptr,
1278    .param .u32 n
1279) {
1280    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1281    .reg .u64 %grad, %input, %out, %off;
1282    .reg .f32 %vg, %vi, %zero, %vr;
1283    .reg .pred %p, %pos;
1284
1285    ld.param.u64 %grad, [grad_ptr];
1286    ld.param.u64 %input, [input_ptr];
1287    ld.param.u64 %out, [out_ptr];
1288    ld.param.u32 %n_reg, [n];
1289
1290    mov.u32 %bid, %ctaid.x;
1291    mov.u32 %bdim, %ntid.x;
1292    mov.u32 %r_tid, %tid.x;
1293    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1294
1295    setp.ge.u32 %p, %r_tid, %n_reg;
1296    @%p bra DONE;
1297
1298    cvt.u64.u32 %off, %r_tid;
1299    shl.b64 %off, %off, 2;
1300
1301    add.u64 %grad, %grad, %off;
1302    add.u64 %input, %input, %off;
1303    add.u64 %out, %out, %off;
1304
1305    ld.global.f32 %vg, [%grad];
1306    ld.global.f32 %vi, [%input];
1307    mov.f32 %zero, 0f00000000;
1308    setp.gt.f32 %pos, %vi, %zero;
1309    selp.f32 %vr, %vg, %zero, %pos;
1310    st.global.f32 [%out], %vr;
1311
1312DONE:
1313    ret;
1314}
1315";
1316
1317/// PTX source for `gelu_backward_kernel`:
1318/// `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
1319/// where `sig = sigmoid(1.702 * x)`.
1320/// This is the exact derivative of `gelu(x) = x * sigmoid(1.702 * x)`.
1321///
1322/// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
1323/// for performance. These have reduced precision (~2^-22 relative error)
1324/// compared to the full-precision variants, which is acceptable for neural
1325/// network training/inference where f32 precision is already limited.
1326#[cfg(feature = "cuda")]
1327pub(crate) const GELU_BACKWARD_PTX: &str = "\
1328.version 7.0
1329.target sm_52
1330.address_size 64
1331
1332.visible .entry gelu_backward_kernel(
1333    .param .u64 grad_ptr,
1334    .param .u64 input_ptr,
1335    .param .u64 out_ptr,
1336    .param .u32 n
1337) {
1338    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1339    .reg .u64 %grad, %input, %out, %off;
1340    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
1341    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
1342    .reg .pred %p;
1343
1344    ld.param.u64 %grad, [grad_ptr];
1345    ld.param.u64 %input, [input_ptr];
1346    ld.param.u64 %out, [out_ptr];
1347    ld.param.u32 %n_reg, [n];
1348
1349    mov.u32 %bid, %ctaid.x;
1350    mov.u32 %bdim, %ntid.x;
1351    mov.u32 %r_tid, %tid.x;
1352    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1353
1354    setp.ge.u32 %p, %r_tid, %n_reg;
1355    @%p bra DONE;
1356
1357    cvt.u64.u32 %off, %r_tid;
1358    shl.b64 %off, %off, 2;
1359
1360    add.u64 %grad, %grad, %off;
1361    add.u64 %input, %input, %off;
1362    add.u64 %out, %out, %off;
1363
1364    ld.global.f32 %vg, [%grad];
1365    ld.global.f32 %x, [%input];
1366
1367    // sig = sigmoid(1.702 * x)
1368    mov.f32 %k, 0f3FDA2720;
1369    mul.f32 %kx, %k, %x;
1370    neg.f32 %neg_kx, %kx;
1371    mov.f32 %log2e, 0f3FB8AA3B;
1372    mul.f32 %neg_kx, %neg_kx, %log2e;
1373    ex2.approx.f32 %exp_neg, %neg_kx;
1374    mov.f32 %one, 0f3F800000;
1375    add.f32 %denom, %one, %exp_neg;
1376    rcp.approx.f32 %sig, %denom;
1377
1378    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
1379    sub.f32 %one_minus_sig, %one, %sig;
1380    mul.f32 %kx_sig_oms, %kx, %sig;
1381    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
1382    add.f32 %dsig, %sig, %kx_sig_oms;
1383
1384    // out = grad * d_gelu
1385    mul.f32 %result, %vg, %dsig;
1386    st.global.f32 [%out], %result;
1387
1388DONE:
1389    ret;
1390}
1391";
1392
1393// ---------------------------------------------------------------------------
1394// Index-select (1-D gather) PTX kernel
1395// ---------------------------------------------------------------------------
1396// Thread i: output[i] = input[indices[i]]
1397// Indices are stored as f32 on the GPU (cast to u32 via truncation).
1398
1399#[cfg(feature = "cuda")]
1400pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
1401.version 7.0
1402.target sm_52
1403.address_size 64
1404
1405.visible .entry index_select_1d_kernel(
1406    .param .u64 input_ptr,
1407    .param .u64 indices_ptr,
1408    .param .u64 out_ptr,
1409    .param .u32 n_indices
1410) {
1411    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
1412    .reg .u64 %input, %indices, %out, %off, %addr;
1413    .reg .f32 %idx_f, %val;
1414    .reg .pred %p;
1415
1416    ld.param.u64 %input, [input_ptr];
1417    ld.param.u64 %indices, [indices_ptr];
1418    ld.param.u64 %out, [out_ptr];
1419    ld.param.u32 %n_reg, [n_indices];
1420
1421    mov.u32 %bid, %ctaid.x;
1422    mov.u32 %bdim, %ntid.x;
1423    mov.u32 %r_tid, %tid.x;
1424    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1425
1426    setp.ge.u32 %p, %r_tid, %n_reg;
1427    @%p bra DONE;
1428
1429    // Byte offset for thread
1430    cvt.u64.u32 %off, %r_tid;
1431    shl.b64 %off, %off, 2;
1432
1433    // Read indices[tid] (f32 -> u32)
1434    add.u64 %addr, %indices, %off;
1435    ld.global.f32 %idx_f, [%addr];
1436    cvt.rzi.u32.f32 %idx, %idx_f;
1437
1438    // Read input[idx]
1439    cvt.u64.u32 %addr, %idx;
1440    shl.b64 %addr, %addr, 2;
1441    add.u64 %addr, %input, %addr;
1442    ld.global.f32 %val, [%addr];
1443
1444    // Write output[tid]
1445    add.u64 %addr, %out, %off;
1446    st.global.f32 [%addr], %val;
1447
1448DONE:
1449    ret;
1450}
1451";
1452
1453// ---------------------------------------------------------------------------
1454// Scatter-add (1-D) PTX kernel — backward of index_select
1455// ---------------------------------------------------------------------------
1456// Thread i: atomicAdd(grad_input[indices[i]], grad_output[i])
1457// The output buffer (grad_input) must be pre-zeroed.
1458// Uses atom.global.add.f32 for safe concurrent accumulation when
1459// duplicate indices map multiple threads to the same output slot.
1460
1461#[cfg(feature = "cuda")]
1462pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
1463.version 7.0
1464.target sm_52
1465.address_size 64
1466
1467.visible .entry scatter_add_1d_kernel(
1468    .param .u64 grad_output_ptr,
1469    .param .u64 indices_ptr,
1470    .param .u64 grad_input_ptr,
1471    .param .u32 n_indices
1472) {
1473    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
1474    .reg .u64 %go, %indices, %gi, %off, %addr;
1475    .reg .f32 %idx_f, %grad_val, %dummy;
1476    .reg .pred %p;
1477
1478    ld.param.u64 %go, [grad_output_ptr];
1479    ld.param.u64 %indices, [indices_ptr];
1480    ld.param.u64 %gi, [grad_input_ptr];
1481    ld.param.u32 %n_reg, [n_indices];
1482
1483    mov.u32 %bid, %ctaid.x;
1484    mov.u32 %bdim, %ntid.x;
1485    mov.u32 %r_tid, %tid.x;
1486    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1487
1488    setp.ge.u32 %p, %r_tid, %n_reg;
1489    @%p bra DONE;
1490
1491    // Byte offset for thread
1492    cvt.u64.u32 %off, %r_tid;
1493    shl.b64 %off, %off, 2;
1494
1495    // Read grad_output[tid]
1496    add.u64 %addr, %go, %off;
1497    ld.global.f32 %grad_val, [%addr];
1498
1499    // Read indices[tid] (f32 -> u32)
1500    add.u64 %addr, %indices, %off;
1501    ld.global.f32 %idx_f, [%addr];
1502    cvt.rzi.u32.f32 %idx, %idx_f;
1503
1504    // Atomic add: grad_input[idx] += grad_val
1505    cvt.u64.u32 %addr, %idx;
1506    shl.b64 %addr, %addr, 2;
1507    add.u64 %addr, %gi, %addr;
1508    atom.global.add.f32 %dummy, [%addr], %grad_val;
1509
1510DONE:
1511    ret;
1512}
1513";
1514
1515// ---------------------------------------------------------------------------
1516// Masked-fill PTX kernel
1517// ---------------------------------------------------------------------------
1518// Thread i: output[i] = mask[i] >= 0.5 ? fill_value : input[i]
1519// Mask is stored as f32 (1.0 = true, 0.0 = false).
1520
1521#[cfg(feature = "cuda")]
1522pub(crate) const MASKED_FILL_PTX: &str = "\
1523.version 7.0
1524.target sm_52
1525.address_size 64
1526
1527.visible .entry masked_fill_kernel(
1528    .param .u64 input_ptr,
1529    .param .u64 mask_ptr,
1530    .param .u64 out_ptr,
1531    .param .f32 fill_value,
1532    .param .u32 n
1533) {
1534    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1535    .reg .u64 %input, %mask, %out, %off;
1536    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
1537    .reg .pred %p, %pmask;
1538
1539    ld.param.u64 %input, [input_ptr];
1540    ld.param.u64 %mask, [mask_ptr];
1541    ld.param.u64 %out, [out_ptr];
1542    ld.param.f32 %fill, [fill_value];
1543    ld.param.u32 %n_reg, [n];
1544
1545    mov.u32 %bid, %ctaid.x;
1546    mov.u32 %bdim, %ntid.x;
1547    mov.u32 %r_tid, %tid.x;
1548    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1549
1550    setp.ge.u32 %p, %r_tid, %n_reg;
1551    @%p bra DONE;
1552
1553    cvt.u64.u32 %off, %r_tid;
1554    shl.b64 %off, %off, 2;
1555
1556    add.u64 %input, %input, %off;
1557    add.u64 %mask, %mask, %off;
1558    add.u64 %out, %out, %off;
1559
1560    ld.global.f32 %in_val, [%input];
1561    ld.global.f32 %mask_val, [%mask];
1562    mov.f32 %half, 0f3F000000;
1563    setp.ge.f32 %pmask, %mask_val, %half;
1564    selp.f32 %result, %fill, %in_val, %pmask;
1565    st.global.f32 [%out], %result;
1566
1567DONE:
1568    ret;
1569}
1570";
1571
1572// ---------------------------------------------------------------------------
1573// Masked-zero PTX kernel — backward of masked_fill
1574// ---------------------------------------------------------------------------
1575// Thread i: output[i] = mask[i] >= 0.5 ? 0.0 : grad_output[i]
1576// Zeroes gradient at positions where the forward mask was true.
1577
1578#[cfg(feature = "cuda")]
1579pub(crate) const MASKED_ZERO_PTX: &str = "\
1580.version 7.0
1581.target sm_52
1582.address_size 64
1583
1584.visible .entry masked_zero_kernel(
1585    .param .u64 grad_ptr,
1586    .param .u64 mask_ptr,
1587    .param .u64 out_ptr,
1588    .param .u32 n
1589) {
1590    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1591    .reg .u64 %grad, %mask, %out, %off;
1592    .reg .f32 %vg, %mask_val, %zero, %result, %half;
1593    .reg .pred %p, %pmask;
1594
1595    ld.param.u64 %grad, [grad_ptr];
1596    ld.param.u64 %mask, [mask_ptr];
1597    ld.param.u64 %out, [out_ptr];
1598    ld.param.u32 %n_reg, [n];
1599
1600    mov.u32 %bid, %ctaid.x;
1601    mov.u32 %bdim, %ntid.x;
1602    mov.u32 %r_tid, %tid.x;
1603    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1604
1605    setp.ge.u32 %p, %r_tid, %n_reg;
1606    @%p bra DONE;
1607
1608    cvt.u64.u32 %off, %r_tid;
1609    shl.b64 %off, %off, 2;
1610
1611    add.u64 %grad, %grad, %off;
1612    add.u64 %mask, %mask, %off;
1613    add.u64 %out, %out, %off;
1614
1615    ld.global.f32 %vg, [%grad];
1616    ld.global.f32 %mask_val, [%mask];
1617    mov.f32 %zero, 0f00000000;
1618    mov.f32 %half, 0f3F000000;
1619    setp.ge.f32 %pmask, %mask_val, %half;
1620    selp.f32 %result, %zero, %vg, %pmask;
1621    st.global.f32 [%out], %result;
1622
1623DONE:
1624    ret;
1625}
1626";
1627
1628// ---------------------------------------------------------------------------
1629// Sigmoid backward PTX kernel: out[i] = grad[i] * output[i] * (1 - output[i])
1630// ---------------------------------------------------------------------------
1631
1632#[cfg(feature = "cuda")]
1633pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
1634.version 7.0
1635.target sm_52
1636.address_size 64
1637
1638.visible .entry sigmoid_backward_kernel(
1639    .param .u64 grad_ptr,
1640    .param .u64 output_ptr,
1641    .param .u64 out_ptr,
1642    .param .u32 n
1643) {
1644    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1645    .reg .u64 %grad, %output, %out, %off;
1646    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
1647    .reg .pred %p;
1648
1649    ld.param.u64 %grad, [grad_ptr];
1650    ld.param.u64 %output, [output_ptr];
1651    ld.param.u64 %out, [out_ptr];
1652    ld.param.u32 %n_reg, [n];
1653
1654    mov.u32 %bid, %ctaid.x;
1655    mov.u32 %bdim, %ntid.x;
1656    mov.u32 %r_tid, %tid.x;
1657    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1658
1659    setp.ge.u32 %p, %r_tid, %n_reg;
1660    @%p bra DONE;
1661
1662    cvt.u64.u32 %off, %r_tid;
1663    shl.b64 %off, %off, 2;
1664
1665    add.u64 %grad, %grad, %off;
1666    add.u64 %output, %output, %off;
1667    add.u64 %out, %out, %off;
1668
1669    ld.global.f32 %vg, [%grad];
1670    ld.global.f32 %vo, [%output];
1671    mov.f32 %one, 0f3F800000;
1672    sub.f32 %one_minus_o, %one, %vo;
1673    mul.f32 %result, %vo, %one_minus_o;
1674    mul.f32 %result, %vg, %result;
1675    st.global.f32 [%out], %result;
1676
1677DONE:
1678    ret;
1679}
1680";
1681
1682// ---------------------------------------------------------------------------
1683// Tanh backward PTX kernel: out[i] = grad[i] * (1 - output[i]^2)
1684// ---------------------------------------------------------------------------
1685
1686#[cfg(feature = "cuda")]
1687pub(crate) const TANH_BACKWARD_PTX: &str = "\
1688.version 7.0
1689.target sm_52
1690.address_size 64
1691
1692.visible .entry tanh_backward_kernel(
1693    .param .u64 grad_ptr,
1694    .param .u64 output_ptr,
1695    .param .u64 out_ptr,
1696    .param .u32 n
1697) {
1698    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1699    .reg .u64 %grad, %output, %out, %off;
1700    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
1701    .reg .pred %p;
1702
1703    ld.param.u64 %grad, [grad_ptr];
1704    ld.param.u64 %output, [output_ptr];
1705    ld.param.u64 %out, [out_ptr];
1706    ld.param.u32 %n_reg, [n];
1707
1708    mov.u32 %bid, %ctaid.x;
1709    mov.u32 %bdim, %ntid.x;
1710    mov.u32 %r_tid, %tid.x;
1711    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1712
1713    setp.ge.u32 %p, %r_tid, %n_reg;
1714    @%p bra DONE;
1715
1716    cvt.u64.u32 %off, %r_tid;
1717    shl.b64 %off, %off, 2;
1718
1719    add.u64 %grad, %grad, %off;
1720    add.u64 %output, %output, %off;
1721    add.u64 %out, %out, %off;
1722
1723    ld.global.f32 %vg, [%grad];
1724    ld.global.f32 %vo, [%output];
1725    mov.f32 %one, 0f3F800000;
1726    mul.f32 %o_sq, %vo, %vo;
1727    sub.f32 %one_minus_sq, %one, %o_sq;
1728    mul.f32 %result, %vg, %one_minus_sq;
1729    st.global.f32 [%out], %result;
1730
1731DONE:
1732    ret;
1733}
1734";
1735
1736// ---------------------------------------------------------------------------
1737// Softmax backward PTX kernel (row-wise, shared-memory dot product)
1738// ---------------------------------------------------------------------------
1739// For each row of length `cols`:
1740//   dot = sum(grad[row] * output[row])
1741//   out[i] = output[i] * (grad[i] - dot)
1742// One block per row, 256 threads per block.
1743
1744#[cfg(feature = "cuda")]
1745pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
1746.version 7.0\n\
1747.target sm_52\n\
1748.address_size 64\n\
1749\n\
1750.shared .align 4 .f32 sdata[256];\n\
1751\n\
1752.visible .entry softmax_backward_kernel(\n\
1753    .param .u64 grad_ptr,\n\
1754    .param .u64 output_ptr,\n\
1755    .param .u64 out_ptr,\n\
1756    .param .u32 rows,\n\
1757    .param .u32 cols\n\
1758) {\n\
1759    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
1760    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
1761    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
1762    .reg .pred %p, %loop_p, %reduce_p;\n\
1763\n\
1764    ld.param.u64 %grad, [grad_ptr];\n\
1765    ld.param.u64 %output, [output_ptr];\n\
1766    ld.param.u64 %out, [out_ptr];\n\
1767    ld.param.u32 %rows_reg, [rows];\n\
1768    ld.param.u32 %cols_reg, [cols];\n\
1769\n\
1770    mov.u32 %bid, %ctaid.x;\n\
1771    mov.u32 %bdim, %ntid.x;\n\
1772    mov.u32 %r_tid, %tid.x;\n\
1773    mov.u64 %sbase, sdata;\n\
1774\n\
1775    setp.ge.u32 %p, %bid, %rows_reg;\n\
1776    @%p bra DONE;\n\
1777\n\
1778    // row_off = bid * cols * 4 (byte offset)\n\
1779    cvt.u64.u32 %row_off, %bid;\n\
1780    cvt.u64.u32 %off, %cols_reg;\n\
1781    mul.lo.u64 %row_off, %row_off, %off;\n\
1782    shl.b64 %row_off, %row_off, 2;\n\
1783\n\
1784    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
1785    mov.f32 %dot, 0f00000000;\n\
1786    mov.u32 %j, %r_tid;\n\
1787DOT_LOOP:\n\
1788    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
1789    @%loop_p bra DOT_LOOP_DONE;\n\
1790    cvt.u64.u32 %off, %j;\n\
1791    shl.b64 %off, %off, 2;\n\
1792    add.u64 %saddr, %grad, %off;\n\
1793    add.u64 %saddr, %saddr, %row_off;\n\
1794    ld.global.f32 %vg, [%saddr];\n\
1795    add.u64 %saddr, %output, %off;\n\
1796    add.u64 %saddr, %saddr, %row_off;\n\
1797    ld.global.f32 %vo, [%saddr];\n\
1798    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
1799    add.u32 %j, %j, %bdim;\n\
1800    bra DOT_LOOP;\n\
1801DOT_LOOP_DONE:\n\
1802\n\
1803    // Store partial dot into shared memory and reduce\n\
1804    cvt.u64.u32 %off, %r_tid;\n\
1805    shl.b64 %off, %off, 2;\n\
1806    add.u64 %saddr, %sbase, %off;\n\
1807    st.shared.f32 [%saddr], %dot;\n\
1808    bar.sync 0;\n\
1809\n\
1810    mov.u32 %half, %bdim;\n\
1811DOT_REDUCE:\n\
1812    shr.u32 %half, %half, 1;\n\
1813    setp.eq.u32 %reduce_p, %half, 0;\n\
1814    @%reduce_p bra DOT_REDUCE_DONE;\n\
1815    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
1816    @%reduce_p bra DOT_REDUCE_SKIP;\n\
1817    add.u32 %other_tid, %r_tid, %half;\n\
1818    cvt.u64.u32 %off, %other_tid;\n\
1819    shl.b64 %off, %off, 2;\n\
1820    add.u64 %saddr, %sbase, %off;\n\
1821    ld.shared.f32 %other_val, [%saddr];\n\
1822    cvt.u64.u32 %off, %r_tid;\n\
1823    shl.b64 %off, %off, 2;\n\
1824    add.u64 %saddr, %sbase, %off;\n\
1825    ld.shared.f32 %dot, [%saddr];\n\
1826    add.f32 %dot, %dot, %other_val;\n\
1827    st.shared.f32 [%saddr], %dot;\n\
1828DOT_REDUCE_SKIP:\n\
1829    bar.sync 0;\n\
1830    bra DOT_REDUCE;\n\
1831DOT_REDUCE_DONE:\n\
1832\n\
1833    // Broadcast dot to all threads\n\
1834    ld.shared.f32 %dot, [sdata];\n\
1835    bar.sync 0;\n\
1836\n\
1837    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
1838    mov.u32 %j, %r_tid;\n\
1839WRITE_LOOP:\n\
1840    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
1841    @%loop_p bra WRITE_LOOP_DONE;\n\
1842    cvt.u64.u32 %off, %j;\n\
1843    shl.b64 %off, %off, 2;\n\
1844    add.u64 %saddr, %grad, %off;\n\
1845    add.u64 %saddr, %saddr, %row_off;\n\
1846    ld.global.f32 %vg, [%saddr];\n\
1847    add.u64 %saddr, %output, %off;\n\
1848    add.u64 %saddr, %saddr, %row_off;\n\
1849    ld.global.f32 %vo, [%saddr];\n\
1850    sub.f32 %diff, %vg, %dot;\n\
1851    mul.f32 %result, %vo, %diff;\n\
1852    add.u64 %saddr, %out, %off;\n\
1853    add.u64 %saddr, %saddr, %row_off;\n\
1854    st.global.f32 [%saddr], %result;\n\
1855    add.u32 %j, %j, %bdim;\n\
1856    bra WRITE_LOOP;\n\
1857WRITE_LOOP_DONE:\n\
1858\n\
1859DONE:\n\
1860    ret;\n\
1861}\n\
1862";
1863
1864// ---------------------------------------------------------------------------
1865// Sum-axis PTX kernel: reduce along one axis of a tensor
1866// ---------------------------------------------------------------------------
1867// Parameters: input_ptr, output_ptr, outer_size, axis_size, inner_size, total_output
1868/// PTX source for `reduce_sum_kernel`: parallel block-level sum reduction.
1869///
1870/// Each block reduces a contiguous chunk of the input array using shared
1871/// memory. Threads first accumulate a sequential sum (grid-stride loop),
1872/// store to shared memory, then do a tree reduction within the block.
1873/// Each block writes one partial sum to `output[blockIdx.x]`.
1874///
1875/// For a full reduction, launch once to get partial sums, then launch
1876/// again on the partial sums (or reduce on CPU if few blocks).
1877#[cfg(feature = "cuda")]
1878pub(crate) const REDUCE_SUM_PTX: &str = "\
1879.version 7.0
1880.target sm_52
1881.address_size 64
1882
1883// Shared memory for intra-block reduction (256 floats = 1024 bytes).
1884.shared .align 4 .f32 sdata[256];
1885
1886.visible .entry reduce_sum_kernel(
1887    .param .u64 in_ptr,
1888    .param .u64 out_ptr,
1889    .param .u32 n
1890) {
1891    .reg .u32 %tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
1892    .reg .u64 %in, %out, %off;
1893    .reg .f32 %sum, %other;
1894    .reg .pred %p, %ptid;
1895
1896    ld.param.u64 %in, [in_ptr];
1897    ld.param.u64 %out, [out_ptr];
1898    ld.param.u32 %n_reg, [n];
1899
1900    mov.u32 %tid, %tid.x;
1901    mov.u32 %bid, %ctaid.x;
1902    mov.u32 %bdim, %ntid.x;
1903    mov.u32 %gdim, %nctaid.x;
1904
1905    // Grid-stride accumulation: each thread sums multiple elements.
1906    // idx = bid * bdim + tid; stride = bdim * gdim
1907    mad.lo.u32 %idx, %bid, %bdim, %tid;
1908    mul.lo.u32 %stride, %bdim, %gdim;
1909    mov.f32 %sum, 0f00000000;
1910
1911GRID_LOOP:
1912    setp.ge.u32 %p, %idx, %n_reg;
1913    @%p bra GRID_DONE;
1914
1915    cvt.u64.u32 %off, %idx;
1916    shl.b64 %off, %off, 2;
1917    add.u64 %off, %in, %off;
1918    ld.global.f32 %other, [%off];
1919    add.f32 %sum, %sum, %other;
1920    add.u32 %idx, %idx, %stride;
1921    bra GRID_LOOP;
1922
1923GRID_DONE:
1924    // Write thread's partial sum to shared memory.
1925    cvt.u64.u32 %off, %tid;
1926    shl.b64 %off, %off, 2;
1927    st.shared.f32 [sdata + %off], %sum;
1928    bar.sync 0;
1929
1930    // Tree reduction in shared memory.
1931    mov.u32 %half, 128;
1932TREE_LOOP:
1933    setp.lt.u32 %p, %half, 1;
1934    @%p bra TREE_DONE;
1935
1936    setp.ge.u32 %ptid, %tid, %half;
1937    @%ptid bra TREE_SKIP;
1938
1939    // Load partner's value from sdata[tid + half].
1940    add.u32 %idx, %tid, %half;
1941    cvt.u64.u32 %off, %idx;
1942    shl.b64 %off, %off, 2;
1943    ld.shared.f32 %other, [sdata + %off];
1944    // Load own value.
1945    cvt.u64.u32 %off, %tid;
1946    shl.b64 %off, %off, 2;
1947    ld.shared.f32 %sum, [sdata + %off];
1948    add.f32 %sum, %sum, %other;
1949    st.shared.f32 [sdata + %off], %sum;
1950
1951TREE_SKIP:
1952    bar.sync 0;
1953    shr.u32 %half, %half, 1;
1954    bra TREE_LOOP;
1955
1956TREE_DONE:
1957    // Thread 0 writes block result.
1958    setp.ne.u32 %ptid, %tid, 0;
1959    @%ptid bra END;
1960
1961    ld.shared.f32 %sum, [sdata];
1962    cvt.u64.u32 %off, %bid;
1963    shl.b64 %off, %off, 2;
1964    add.u64 %out, %out, %off;
1965    st.global.f32 [%out], %sum;
1966
1967END:
1968    ret;
1969}
1970";
1971
1972// Thread i: output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]
1973// where outer_idx = i / inner_size, inner_idx = i % inner_size.
1974
1975#[cfg(feature = "cuda")]
1976pub(crate) const SUM_AXIS_PTX: &str = "\
1977.version 7.0
1978.target sm_52
1979.address_size 64
1980
1981.visible .entry sum_axis_kernel(
1982    .param .u64 input_ptr,
1983    .param .u64 output_ptr,
1984    .param .u32 outer_size,
1985    .param .u32 axis_size,
1986    .param .u32 inner_size,
1987    .param .u32 total_output
1988) {
1989    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
1990    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
1991    .reg .u64 %in, %out, %off, %addr;
1992    .reg .f32 %val, %sum;
1993    .reg .pred %p, %lp;
1994
1995    ld.param.u64 %in, [input_ptr];
1996    ld.param.u64 %out, [output_ptr];
1997    ld.param.u32 %outer_sz, [outer_size];
1998    ld.param.u32 %axis_sz, [axis_size];
1999    ld.param.u32 %inner_sz, [inner_size];
2000    ld.param.u32 %n_reg, [total_output];
2001
2002    mov.u32 %bid, %ctaid.x;
2003    mov.u32 %bdim, %ntid.x;
2004    mov.u32 %r_tid, %tid.x;
2005    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2006
2007    setp.ge.u32 %p, %r_tid, %n_reg;
2008    @%p bra DONE;
2009
2010    // outer_idx = r_tid / inner_size
2011    div.u32 %outer_idx, %r_tid, %inner_sz;
2012    // inner_idx = r_tid % inner_size
2013    rem.u32 %inner_idx, %r_tid, %inner_sz;
2014
2015    // base = outer_idx * axis_size * inner_size + inner_idx
2016    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
2017    mul.lo.u32 %tmp, %tmp, %inner_sz;
2018    add.u32 %tmp, %tmp, %inner_idx;
2019
2020    mov.f32 %sum, 0f00000000;
2021    mov.u32 %k, 0;
2022SUM_LOOP:
2023    setp.ge.u32 %lp, %k, %axis_sz;
2024    @%lp bra SUM_LOOP_DONE;
2025
2026    // addr = in + (tmp + k * inner_size) * 4
2027    mul.lo.u32 %inner_idx, %k, %inner_sz;
2028    add.u32 %inner_idx, %tmp, %inner_idx;
2029    cvt.u64.u32 %off, %inner_idx;
2030    shl.b64 %off, %off, 2;
2031    add.u64 %addr, %in, %off;
2032    ld.global.f32 %val, [%addr];
2033    add.f32 %sum, %sum, %val;
2034
2035    add.u32 %k, %k, 1;
2036    bra SUM_LOOP;
2037SUM_LOOP_DONE:
2038
2039    // output[r_tid] = sum
2040    cvt.u64.u32 %off, %r_tid;
2041    shl.b64 %off, %off, 2;
2042    add.u64 %addr, %out, %off;
2043    st.global.f32 [%addr], %sum;
2044
2045DONE:
2046    ret;
2047}
2048";
2049
2050// ---------------------------------------------------------------------------
2051// LayerNorm PTX kernel (row-wise: mean, var, normalize+affine)
2052//
2053// Uses `.approx` PTX instructions (`div.approx.f32`, `sqrt.approx.f32`,
2054// `rcp.approx.f32`) for performance. These have reduced precision (~2^-22
2055// relative error) compared to the full-precision variants, which is
2056// acceptable for neural network training/inference.
2057// ---------------------------------------------------------------------------
2058
2059#[cfg(feature = "cuda")]
2060pub(crate) const LAYERNORM_PTX: &str = "\
2061.version 7.0
2062.target sm_52
2063.address_size 64
2064
2065.shared .align 4 .f32 sdata[256];
2066
2067.visible .entry layernorm_kernel(
2068    .param .u64 in_ptr,
2069    .param .u64 out_ptr,
2070    .param .u64 w_ptr,
2071    .param .u64 b_ptr,
2072    .param .u32 rows,
2073    .param .u32 cols,
2074    .param .f32 eps
2075) {
2076    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
2077    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
2078    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
2079    .reg .pred %p, %lp, %rp;
2080
2081    ld.param.u64 %in, [in_ptr];
2082    ld.param.u64 %out, [out_ptr];
2083    ld.param.u64 %w, [w_ptr];
2084    ld.param.u64 %b, [b_ptr];
2085    ld.param.u32 %rows_reg, [rows];
2086    ld.param.u32 %cols_reg, [cols];
2087    ld.param.f32 %eps_r, [eps];
2088
2089    mov.u64 %sbase, sdata;
2090
2091    mov.u32 %r_bid, %ctaid.x;
2092    mov.u32 %r_bdim, %ntid.x;
2093    mov.u32 %r_tid, %tid.x;
2094
2095    setp.ge.u32 %p, %r_bid, %rows_reg;
2096    @%p bra DONE;
2097
2098    cvt.u64.u32 %row_off, %r_bid;
2099    cvt.u64.u32 %off, %cols_reg;
2100    mul.lo.u64 %row_off, %row_off, %off;
2101    shl.b64 %row_off, %row_off, 2;
2102    cvt.rn.f32.u32 %n_f, %cols_reg;
2103
2104    mov.f32 %mean, 0f00000000;
2105    mov.u32 %j, %r_tid;
2106SM:
2107    setp.ge.u32 %lp, %j, %cols_reg;
2108    @%lp bra SMD;
2109    cvt.u64.u32 %off, %j;
2110    shl.b64 %off, %off, 2;
2111    add.u64 %off, %in, %off;
2112    add.u64 %off, %off, %row_off;
2113    ld.global.f32 %val, [%off];
2114    add.f32 %mean, %mean, %val;
2115    add.u32 %j, %j, %r_bdim;
2116    bra SM;
2117SMD:
2118    cvt.u64.u32 %off, %r_tid;
2119    shl.b64 %off, %off, 2;
2120    add.u64 %saddr, %sbase, %off;
2121    st.shared.f32 [%saddr], %mean;
2122    bar.sync 0;
2123    mov.u32 %half, %r_bdim;
2124MR:
2125    shr.u32 %half, %half, 1;
2126    setp.eq.u32 %rp, %half, 0;
2127    @%rp bra MRD;
2128    setp.ge.u32 %rp, %r_tid, %half;
2129    @%rp bra MRS;
2130    add.u32 %r_otid, %r_tid, %half;
2131    cvt.u64.u32 %off, %r_otid;
2132    shl.b64 %off, %off, 2;
2133    add.u64 %saddr, %sbase, %off;
2134    ld.shared.f32 %other_val, [%saddr];
2135    cvt.u64.u32 %off, %r_tid;
2136    shl.b64 %off, %off, 2;
2137    add.u64 %saddr, %sbase, %off;
2138    ld.shared.f32 %mean, [%saddr];
2139    add.f32 %mean, %mean, %other_val;
2140    add.u64 %saddr, %sbase, %off;
2141    st.shared.f32 [%saddr], %mean;
2142MRS:
2143    bar.sync 0;
2144    bra MR;
2145MRD:
2146    ld.shared.f32 %mean, [%sbase];
2147    div.approx.f32 %mean, %mean, %n_f;
2148    bar.sync 0;
2149
2150    mov.f32 %var, 0f00000000;
2151    mov.u32 %j, %r_tid;
2152SV:
2153    setp.ge.u32 %lp, %j, %cols_reg;
2154    @%lp bra SVD;
2155    cvt.u64.u32 %off, %j;
2156    shl.b64 %off, %off, 2;
2157    add.u64 %off, %in, %off;
2158    add.u64 %off, %off, %row_off;
2159    ld.global.f32 %val, [%off];
2160    sub.f32 %diff, %val, %mean;
2161    fma.rn.f32 %var, %diff, %diff, %var;
2162    add.u32 %j, %j, %r_bdim;
2163    bra SV;
2164SVD:
2165    cvt.u64.u32 %off, %r_tid;
2166    shl.b64 %off, %off, 2;
2167    add.u64 %saddr, %sbase, %off;
2168    st.shared.f32 [%saddr], %var;
2169    bar.sync 0;
2170    mov.u32 %half, %r_bdim;
2171VR:
2172    shr.u32 %half, %half, 1;
2173    setp.eq.u32 %rp, %half, 0;
2174    @%rp bra VRD;
2175    setp.ge.u32 %rp, %r_tid, %half;
2176    @%rp bra VRS;
2177    add.u32 %r_otid, %r_tid, %half;
2178    cvt.u64.u32 %off, %r_otid;
2179    shl.b64 %off, %off, 2;
2180    add.u64 %saddr, %sbase, %off;
2181    ld.shared.f32 %other_val, [%saddr];
2182    cvt.u64.u32 %off, %r_tid;
2183    shl.b64 %off, %off, 2;
2184    add.u64 %saddr, %sbase, %off;
2185    ld.shared.f32 %var, [%saddr];
2186    add.f32 %var, %var, %other_val;
2187    add.u64 %saddr, %sbase, %off;
2188    st.shared.f32 [%saddr], %var;
2189VRS:
2190    bar.sync 0;
2191    bra VR;
2192VRD:
2193    ld.shared.f32 %var, [%sbase];
2194    div.approx.f32 %var, %var, %n_f;
2195    add.f32 %var, %var, %eps_r;
2196    sqrt.approx.f32 %inv_std, %var;
2197    rcp.approx.f32 %inv_std, %inv_std;
2198    bar.sync 0;
2199
2200    mov.u32 %j, %r_tid;
2201NM:
2202    setp.ge.u32 %lp, %j, %cols_reg;
2203    @%lp bra NMD;
2204    cvt.u64.u32 %off, %j;
2205    shl.b64 %off, %off, 2;
2206    add.u64 %off, %in, %off;
2207    add.u64 %off, %off, %row_off;
2208    ld.global.f32 %val, [%off];
2209    sub.f32 %normed, %val, %mean;
2210    mul.f32 %normed, %normed, %inv_std;
2211    cvt.u64.u32 %off, %j;
2212    shl.b64 %off, %off, 2;
2213    add.u64 %off, %w, %off;
2214    ld.global.f32 %wv, [%off];
2215    cvt.u64.u32 %off, %j;
2216    shl.b64 %off, %off, 2;
2217    add.u64 %off, %b, %off;
2218    ld.global.f32 %bv, [%off];
2219    fma.rn.f32 %result, %wv, %normed, %bv;
2220    cvt.u64.u32 %off, %j;
2221    shl.b64 %off, %off, 2;
2222    add.u64 %off, %out, %off;
2223    add.u64 %off, %off, %row_off;
2224    st.global.f32 [%off], %result;
2225    add.u32 %j, %j, %r_bdim;
2226    bra NM;
2227NMD:
2228
2229DONE:
2230    ret;
2231}
2232";
2233
2234// ---------------------------------------------------------------------------
2235// LayerNorm backward PTX kernel
2236// ---------------------------------------------------------------------------
2237//
2238// One block per batch element (row). Each block:
2239//   1. Recompute mean and variance from input
2240//   2. Compute x_hat = (x - mean) * rsqrt(var + eps)
2241//   3. Compute dl_dx_hat = grad_output * weight
2242//   4. Reduce dl_dx_hat and dl_dx_hat * x_hat across the normalized dimension
2243//   5. Compute grad_input = rsqrt(var+eps) * (dl_dx_hat - mean(dl_dx_hat) - x_hat * mean(dl_dx_hat * x_hat))
2244//   6. Accumulate grad_weight (atomicAdd) and grad_bias (atomicAdd) across batch elements
2245//
2246// Uses shared memory for per-row reductions, 256 threads per block.
2247// Parameters:
2248//   in_ptr      - pointer to input f32 buffer [rows * cols]
2249//   grad_out_ptr - pointer to grad_output f32 buffer [rows * cols]
2250//   w_ptr       - pointer to weight f32 buffer [cols]
2251//   grad_in_ptr - pointer to grad_input f32 output buffer [rows * cols]
2252//   grad_w_ptr  - pointer to grad_weight f32 output buffer [cols] (atomicAdd)
2253//   grad_b_ptr  - pointer to grad_bias f32 output buffer [cols] (atomicAdd)
2254//   rows        - number of batch elements
2255//   cols        - normalized dimension size
2256//   eps         - epsilon for numerical stability
2257
2258#[cfg(feature = "cuda")]
2259pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
2260.version 7.0
2261.target sm_52
2262.address_size 64
2263
2264.shared .align 4 .f32 sdata[256];
2265
2266.visible .entry layernorm_backward_kernel(
2267    .param .u64 in_ptr,
2268    .param .u64 grad_out_ptr,
2269    .param .u64 w_ptr,
2270    .param .u64 grad_in_ptr,
2271    .param .u64 grad_w_ptr,
2272    .param .u64 grad_b_ptr,
2273    .param .u32 rows,
2274    .param .u32 cols,
2275    .param .f32 eps
2276) {
2277    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
2278    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
2279    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
2280    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
2281    .reg .pred %p, %lp, %rp;
2282
2283    ld.param.u64 %in, [in_ptr];
2284    ld.param.u64 %go, [grad_out_ptr];
2285    ld.param.u64 %w, [w_ptr];
2286    ld.param.u64 %gi, [grad_in_ptr];
2287    ld.param.u64 %gw, [grad_w_ptr];
2288    ld.param.u64 %gb, [grad_b_ptr];
2289    ld.param.u32 %rows_reg, [rows];
2290    ld.param.u32 %cols_reg, [cols];
2291    ld.param.f32 %eps_r, [eps];
2292
2293    mov.u64 %sbase, sdata;
2294
2295    mov.u32 %r_bid, %ctaid.x;
2296    mov.u32 %r_bdim, %ntid.x;
2297    mov.u32 %r_tid, %tid.x;
2298
2299    setp.ge.u32 %p, %r_bid, %rows_reg;
2300    @%p bra LNB_DONE;
2301
2302    // row_off = bid * cols * 4 (byte offset for this row)
2303    cvt.u64.u32 %row_off, %r_bid;
2304    cvt.u64.u32 %off, %cols_reg;
2305    mul.lo.u64 %row_off, %row_off, %off;
2306    shl.b64 %row_off, %row_off, 2;
2307    cvt.rn.f32.u32 %n_f, %cols_reg;
2308
2309    // ===== Phase 1: Compute mean =====
2310    mov.f32 %mean, 0f00000000;
2311    mov.u32 %j, %r_tid;
2312LNB_SM:
2313    setp.ge.u32 %lp, %j, %cols_reg;
2314    @%lp bra LNB_SMD;
2315    cvt.u64.u32 %off, %j;
2316    shl.b64 %off, %off, 2;
2317    add.u64 %addr, %in, %off;
2318    add.u64 %addr, %addr, %row_off;
2319    ld.global.f32 %val, [%addr];
2320    add.f32 %mean, %mean, %val;
2321    add.u32 %j, %j, %r_bdim;
2322    bra LNB_SM;
2323LNB_SMD:
2324    // Shared memory reduce for mean
2325    cvt.u64.u32 %off, %r_tid;
2326    shl.b64 %off, %off, 2;
2327    add.u64 %saddr, %sbase, %off;
2328    st.shared.f32 [%saddr], %mean;
2329    bar.sync 0;
2330    mov.u32 %half, %r_bdim;
2331LNB_MR:
2332    shr.u32 %half, %half, 1;
2333    setp.eq.u32 %rp, %half, 0;
2334    @%rp bra LNB_MRD;
2335    setp.ge.u32 %rp, %r_tid, %half;
2336    @%rp bra LNB_MRS;
2337    add.u32 %r_otid, %r_tid, %half;
2338    cvt.u64.u32 %off, %r_otid;
2339    shl.b64 %off, %off, 2;
2340    add.u64 %saddr, %sbase, %off;
2341    ld.shared.f32 %other_val, [%saddr];
2342    cvt.u64.u32 %off, %r_tid;
2343    shl.b64 %off, %off, 2;
2344    add.u64 %saddr, %sbase, %off;
2345    ld.shared.f32 %mean, [%saddr];
2346    add.f32 %mean, %mean, %other_val;
2347    st.shared.f32 [%saddr], %mean;
2348LNB_MRS:
2349    bar.sync 0;
2350    bra LNB_MR;
2351LNB_MRD:
2352    ld.shared.f32 %mean, [%sbase];
2353    div.approx.f32 %mean, %mean, %n_f;
2354    bar.sync 0;
2355
2356    // ===== Phase 2: Compute variance =====
2357    mov.f32 %var, 0f00000000;
2358    mov.u32 %j, %r_tid;
2359LNB_SV:
2360    setp.ge.u32 %lp, %j, %cols_reg;
2361    @%lp bra LNB_SVD;
2362    cvt.u64.u32 %off, %j;
2363    shl.b64 %off, %off, 2;
2364    add.u64 %addr, %in, %off;
2365    add.u64 %addr, %addr, %row_off;
2366    ld.global.f32 %val, [%addr];
2367    sub.f32 %diff, %val, %mean;
2368    fma.rn.f32 %var, %diff, %diff, %var;
2369    add.u32 %j, %j, %r_bdim;
2370    bra LNB_SV;
2371LNB_SVD:
2372    // Shared memory reduce for variance
2373    cvt.u64.u32 %off, %r_tid;
2374    shl.b64 %off, %off, 2;
2375    add.u64 %saddr, %sbase, %off;
2376    st.shared.f32 [%saddr], %var;
2377    bar.sync 0;
2378    mov.u32 %half, %r_bdim;
2379LNB_VR:
2380    shr.u32 %half, %half, 1;
2381    setp.eq.u32 %rp, %half, 0;
2382    @%rp bra LNB_VRD;
2383    setp.ge.u32 %rp, %r_tid, %half;
2384    @%rp bra LNB_VRS;
2385    add.u32 %r_otid, %r_tid, %half;
2386    cvt.u64.u32 %off, %r_otid;
2387    shl.b64 %off, %off, 2;
2388    add.u64 %saddr, %sbase, %off;
2389    ld.shared.f32 %other_val, [%saddr];
2390    cvt.u64.u32 %off, %r_tid;
2391    shl.b64 %off, %off, 2;
2392    add.u64 %saddr, %sbase, %off;
2393    ld.shared.f32 %var, [%saddr];
2394    add.f32 %var, %var, %other_val;
2395    st.shared.f32 [%saddr], %var;
2396LNB_VRS:
2397    bar.sync 0;
2398    bra LNB_VR;
2399LNB_VRD:
2400    ld.shared.f32 %var, [%sbase];
2401    div.approx.f32 %var, %var, %n_f;
2402    add.f32 %var, %var, %eps_r;
2403    sqrt.approx.f32 %inv_std, %var;
2404    rcp.approx.f32 %inv_std, %inv_std;
2405    bar.sync 0;
2406
2407    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
2408    // Also accumulate grad_weight and grad_bias via atomicAdd
2409    mov.f32 %sum1, 0f00000000;
2410    mov.f32 %sum2, 0f00000000;
2411    mov.u32 %j, %r_tid;
2412LNB_S12:
2413    setp.ge.u32 %lp, %j, %cols_reg;
2414    @%lp bra LNB_S12D;
2415    // Load input[row, j]
2416    cvt.u64.u32 %off, %j;
2417    shl.b64 %off, %off, 2;
2418    add.u64 %addr, %in, %off;
2419    add.u64 %addr, %addr, %row_off;
2420    ld.global.f32 %val, [%addr];
2421    // x_hat = (val - mean) * inv_std
2422    sub.f32 %x_hat, %val, %mean;
2423    mul.f32 %x_hat, %x_hat, %inv_std;
2424    // Load grad_output[row, j]
2425    cvt.u64.u32 %off, %j;
2426    shl.b64 %off, %off, 2;
2427    add.u64 %addr, %go, %off;
2428    add.u64 %addr, %addr, %row_off;
2429    ld.global.f32 %gov, [%addr];
2430    // Load weight[j]
2431    cvt.u64.u32 %off, %j;
2432    shl.b64 %off, %off, 2;
2433    add.u64 %addr, %w, %off;
2434    ld.global.f32 %wv, [%addr];
2435    // dl_dx_hat = grad_output * weight
2436    mul.f32 %dl_dx_hat, %gov, %wv;
2437    // Accumulate sums
2438    add.f32 %sum1, %sum1, %dl_dx_hat;
2439    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
2440    // atomicAdd grad_weight[j] += grad_output * x_hat
2441    cvt.u64.u32 %off, %j;
2442    shl.b64 %off, %off, 2;
2443    add.u64 %addr, %gw, %off;
2444    mul.f32 %result, %gov, %x_hat;
2445    atom.global.add.f32 %result, [%addr], %result;
2446    // atomicAdd grad_bias[j] += grad_output
2447    add.u64 %addr, %gb, %off;
2448    atom.global.add.f32 %result, [%addr], %gov;
2449    add.u32 %j, %j, %r_bdim;
2450    bra LNB_S12;
2451LNB_S12D:
2452    // Reduce sum1 in shared memory
2453    cvt.u64.u32 %off, %r_tid;
2454    shl.b64 %off, %off, 2;
2455    add.u64 %saddr, %sbase, %off;
2456    st.shared.f32 [%saddr], %sum1;
2457    bar.sync 0;
2458    mov.u32 %half, %r_bdim;
2459LNB_R1:
2460    shr.u32 %half, %half, 1;
2461    setp.eq.u32 %rp, %half, 0;
2462    @%rp bra LNB_R1D;
2463    setp.ge.u32 %rp, %r_tid, %half;
2464    @%rp bra LNB_R1S;
2465    add.u32 %r_otid, %r_tid, %half;
2466    cvt.u64.u32 %off, %r_otid;
2467    shl.b64 %off, %off, 2;
2468    add.u64 %saddr, %sbase, %off;
2469    ld.shared.f32 %other_val, [%saddr];
2470    cvt.u64.u32 %off, %r_tid;
2471    shl.b64 %off, %off, 2;
2472    add.u64 %saddr, %sbase, %off;
2473    ld.shared.f32 %sum1, [%saddr];
2474    add.f32 %sum1, %sum1, %other_val;
2475    st.shared.f32 [%saddr], %sum1;
2476LNB_R1S:
2477    bar.sync 0;
2478    bra LNB_R1;
2479LNB_R1D:
2480    ld.shared.f32 %sum1, [%sbase];
2481    // mean1 = sum1 / n
2482    div.approx.f32 %mean1, %sum1, %n_f;
2483    bar.sync 0;
2484
2485    // Reduce sum2 in shared memory
2486    cvt.u64.u32 %off, %r_tid;
2487    shl.b64 %off, %off, 2;
2488    add.u64 %saddr, %sbase, %off;
2489    st.shared.f32 [%saddr], %sum2;
2490    bar.sync 0;
2491    mov.u32 %half, %r_bdim;
2492LNB_R2:
2493    shr.u32 %half, %half, 1;
2494    setp.eq.u32 %rp, %half, 0;
2495    @%rp bra LNB_R2D;
2496    setp.ge.u32 %rp, %r_tid, %half;
2497    @%rp bra LNB_R2S;
2498    add.u32 %r_otid, %r_tid, %half;
2499    cvt.u64.u32 %off, %r_otid;
2500    shl.b64 %off, %off, 2;
2501    add.u64 %saddr, %sbase, %off;
2502    ld.shared.f32 %other_val, [%saddr];
2503    cvt.u64.u32 %off, %r_tid;
2504    shl.b64 %off, %off, 2;
2505    add.u64 %saddr, %sbase, %off;
2506    ld.shared.f32 %sum2, [%saddr];
2507    add.f32 %sum2, %sum2, %other_val;
2508    st.shared.f32 [%saddr], %sum2;
2509LNB_R2S:
2510    bar.sync 0;
2511    bra LNB_R2;
2512LNB_R2D:
2513    ld.shared.f32 %sum2, [%sbase];
2514    // mean2 = sum2 / n
2515    div.approx.f32 %mean2, %sum2, %n_f;
2516    bar.sync 0;
2517
2518    // ===== Phase 4: Compute grad_input =====
2519    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
2520    mov.u32 %j, %r_tid;
2521LNB_GI:
2522    setp.ge.u32 %lp, %j, %cols_reg;
2523    @%lp bra LNB_GID;
2524    // Reload input to recompute x_hat
2525    cvt.u64.u32 %off, %j;
2526    shl.b64 %off, %off, 2;
2527    add.u64 %addr, %in, %off;
2528    add.u64 %addr, %addr, %row_off;
2529    ld.global.f32 %val, [%addr];
2530    sub.f32 %x_hat, %val, %mean;
2531    mul.f32 %x_hat, %x_hat, %inv_std;
2532    // Reload grad_output and weight to recompute dl_dx_hat
2533    cvt.u64.u32 %off, %j;
2534    shl.b64 %off, %off, 2;
2535    add.u64 %addr, %go, %off;
2536    add.u64 %addr, %addr, %row_off;
2537    ld.global.f32 %gov, [%addr];
2538    cvt.u64.u32 %off, %j;
2539    shl.b64 %off, %off, 2;
2540    add.u64 %addr, %w, %off;
2541    ld.global.f32 %wv, [%addr];
2542    mul.f32 %dl_dx_hat, %gov, %wv;
2543    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
2544    sub.f32 %result, %dl_dx_hat, %mean1;
2545    mul.f32 %diff, %x_hat, %mean2;
2546    sub.f32 %result, %result, %diff;
2547    mul.f32 %result, %inv_std, %result;
2548    // Store grad_input[row, j]
2549    cvt.u64.u32 %off, %j;
2550    shl.b64 %off, %off, 2;
2551    add.u64 %addr, %gi, %off;
2552    add.u64 %addr, %addr, %row_off;
2553    st.global.f32 [%addr], %result;
2554    add.u32 %j, %j, %r_bdim;
2555    bra LNB_GI;
2556LNB_GID:
2557
2558LNB_DONE:
2559    ret;
2560}
2561";
2562
2563// ---------------------------------------------------------------------------
2564// Softmax PTX kernel (row-wise, numerically stable)
2565// ---------------------------------------------------------------------------
2566//
2567// One thread block per row. Each block:
2568//   1. Finds the max in shared memory (for numerical stability)
2569//   2. Computes exp(x - max) and sums in shared memory
2570//   3. Normalizes by the sum
2571//
2572// Uses `.approx` PTX instructions (`ex2.approx.f32`, `rcp.approx.f32`)
2573// for performance. These have reduced precision (~2^-22 relative error)
2574// compared to the full-precision variants, which is acceptable for neural
2575// network training/inference.
2576//
2577// Parameters:
2578//   input_ptr  - pointer to input f32 buffer
2579//   output_ptr - pointer to output f32 buffer
2580//   rows       - number of rows (outer dimension)
2581//   cols       - number of columns (softmax dimension, = last_dim)
2582
2583/// PTX kernel for BatchNorm2d forward: per-channel normalize + affine.
2584///
2585/// Input layout: [B*C*spatial] flattened, where spatial = H*W.
2586/// One block per channel. Each block computes mean + variance for its
2587/// channel across all batch elements and spatial positions, then
2588/// normalizes in a second pass.
2589///
2590/// Parameters:
2591///   input[B*C*S], output[B*C*S], weight[C], bias[C],
2592///   running_mean[C], running_var[C], save_mean[C], save_invstd[C],
2593///   channels, spatial, eps, momentum, total_per_channel (= B*S),
2594///   training (0 or 1)
2595#[cfg(feature = "cuda")]
2596pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
2597.version 7.0
2598.target sm_52
2599.address_size 64
2600
2601// Shared memory for block reduction
2602.shared .align 4 .f32 smem_sum[256];
2603.shared .align 4 .f32 smem_sq[256];
2604
2605.visible .entry batchnorm_forward_kernel(
2606    .param .u64 input_ptr,
2607    .param .u64 output_ptr,
2608    .param .u64 weight_ptr,
2609    .param .u64 bias_ptr,
2610    .param .u64 rmean_ptr,
2611    .param .u64 rvar_ptr,
2612    .param .u64 save_mean_ptr,
2613    .param .u64 save_invstd_ptr,
2614    .param .u32 channels,
2615    .param .u32 spatial,
2616    .param .f32 eps,
2617    .param .f32 momentum,
2618    .param .u32 total_per_ch,
2619    .param .u32 training
2620) {
2621    .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
2622    .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
2623    .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
2624    .reg .f32 %gamma, %beta, %eps_reg, %mom, %other;
2625    .reg .f32 %n_f, %one, %normalized;
2626    .reg .pred %p, %ptrain, %ptid0;
2627    .reg .u32 %half;
2628
2629    ld.param.u64 %in, [input_ptr];
2630    ld.param.u64 %out, [output_ptr];
2631    ld.param.u64 %w, [weight_ptr];
2632    ld.param.u64 %b, [bias_ptr];
2633    ld.param.u64 %rm, [rmean_ptr];
2634    ld.param.u64 %rv, [rvar_ptr];
2635    ld.param.u64 %sm, [save_mean_ptr];
2636    ld.param.u64 %si, [save_invstd_ptr];
2637    ld.param.u32 %n_ch, [channels];
2638    ld.param.u32 %sp, [spatial];
2639    ld.param.f32 %eps_reg, [eps];
2640    ld.param.f32 %mom, [momentum];
2641    ld.param.u32 %tpc, [total_per_ch];
2642    ld.param.u32 %train, [training];
2643
2644    mov.u32 %bid, %ctaid.x;
2645    mov.u32 %tid, %tid.x;
2646    mov.u32 %bdim, %ntid.x;
2647    mov.u32 %ch, %bid;
2648    mov.f32 %one, 0f3F800000;
2649
2650    setp.ge.u32 %p, %ch, %n_ch;
2651    @%p bra END;
2652
2653    setp.ne.u32 %ptrain, %train, 0;
2654
2655    // ---- Pass 1: compute sum and sum-of-squares for this channel ----
2656    mov.f32 %sum, 0f00000000;
2657    mov.f32 %sqsum, 0f00000000;
2658
2659    // Grid-stride loop over B*spatial for this channel
2660    mov.u32 %idx, %tid;
2661PASS1_LOOP:
2662    setp.ge.u32 %p, %idx, %tpc;
2663    @%p bra PASS1_DONE;
2664
2665    // Linear offset = (idx / spatial) * channels * spatial + ch * spatial + idx % spatial
2666    div.u32 %half, %idx, %sp;
2667    rem.u32 %half, %idx, %sp;  // reuse half as spatial_idx
2668    // batch_offset = (idx / sp) * (n_ch * sp) + ch * sp + (idx % sp)
2669    div.u32 %half, %idx, %sp;  // batch_idx
2670    mul.lo.u32 %half, %half, %n_ch;
2671    add.u32 %half, %half, %ch;
2672    mul.lo.u32 %half, %half, %sp;
2673    rem.u32 %idx, %idx, %sp;   // spatial_idx
2674    add.u32 %half, %half, %idx;
2675
2676    cvt.u64.u32 %off64, %half;
2677    shl.b64 %off64, %off64, 2;
2678    add.u64 %tmp64, %in, %off64;
2679    ld.global.f32 %val, [%tmp64];
2680    add.f32 %sum, %sum, %val;
2681    fma.rn.f32 %sqsum, %val, %val, %sqsum;
2682
2683    // Restore idx for stride
2684    // Recompute idx from tid + iteration * bdim
2685    add.u32 %idx, %idx, %bdim;  // This is wrong - need proper loop counter
2686    bra PASS1_LOOP;
2687
2688PASS1_DONE:
2689    // Store to shared memory for block reduction
2690    cvt.u64.u32 %off64, %tid;
2691    shl.b64 %off64, %off64, 2;
2692    st.shared.f32 [smem_sum + %off64], %sum;
2693    st.shared.f32 [smem_sq + %off64], %sqsum;
2694    bar.sync 0;
2695
2696    // Tree reduction
2697    mov.u32 %half, 128;
2698REDUCE_LOOP:
2699    setp.lt.u32 %p, %half, 1;
2700    @%p bra REDUCE_DONE;
2701    setp.ge.u32 %p, %tid, %half;
2702    @%p bra REDUCE_SKIP;
2703
2704    add.u32 %idx, %tid, %half;
2705    cvt.u64.u32 %off64, %idx;
2706    shl.b64 %off64, %off64, 2;
2707    ld.shared.f32 %other, [smem_sum + %off64];
2708    cvt.u64.u32 %tmp64, %tid;
2709    shl.b64 %tmp64, %tmp64, 2;
2710    ld.shared.f32 %sum, [smem_sum + %tmp64];
2711    add.f32 %sum, %sum, %other;
2712    st.shared.f32 [smem_sum + %tmp64], %sum;
2713
2714    ld.shared.f32 %other, [smem_sq + %off64];
2715    ld.shared.f32 %sqsum, [smem_sq + %tmp64];
2716    add.f32 %sqsum, %sqsum, %other;
2717    st.shared.f32 [smem_sq + %tmp64], %sqsum;
2718
2719REDUCE_SKIP:
2720    bar.sync 0;
2721    shr.u32 %half, %half, 1;
2722    bra REDUCE_LOOP;
2723
2724REDUCE_DONE:
2725    // Thread 0 computes mean and invstd
2726    setp.ne.u32 %ptid0, %tid, 0;
2727
2728    @%ptid0 bra WAIT_STATS;
2729
2730    ld.shared.f32 %sum, [smem_sum];
2731    ld.shared.f32 %sqsum, [smem_sq];
2732    cvt.rn.f32.u32 %n_f, %tpc;
2733    div.rn.f32 %mean, %sum, %n_f;
2734    // var = sqsum/n - mean^2
2735    div.rn.f32 %var, %sqsum, %n_f;
2736    fma.rn.f32 %var, %mean, %mean, %var;  // This adds mean^2, need to subtract
2737    // Actually: var = E[x^2] - E[x]^2, so var = sqsum/n - mean^2
2738    // We had: var = sqsum/n, now subtract mean^2
2739    neg.f32 %other, %mean;
2740    fma.rn.f32 %var, %other, %mean, %var; // var = var + (-mean)*mean = sqsum/n - mean^2
2741
2742    // invstd = 1/sqrt(var + eps)
2743    add.f32 %other, %var, %eps_reg;
2744    sqrt.rn.f32 %other, %other;
2745    div.rn.f32 %invstd, %one, %other;
2746
2747    // Save mean and invstd
2748    cvt.u64.u32 %off64, %ch;
2749    shl.b64 %off64, %off64, 2;
2750    add.u64 %tmp64, %sm, %off64;
2751    st.global.f32 [%tmp64], %mean;
2752    add.u64 %tmp64, %si, %off64;
2753    st.global.f32 [%tmp64], %invstd;
2754
2755    // Store to shared for other threads
2756    st.shared.f32 [smem_sum], %mean;
2757    st.shared.f32 [smem_sq], %invstd;
2758
2759WAIT_STATS:
2760    bar.sync 0;
2761    // All threads read mean and invstd from shared
2762    ld.shared.f32 %mean, [smem_sum];
2763    ld.shared.f32 %invstd, [smem_sq];
2764
2765    // Load weight and bias for this channel
2766    cvt.u64.u32 %off64, %ch;
2767    shl.b64 %off64, %off64, 2;
2768    add.u64 %tmp64, %w, %off64;
2769    ld.global.f32 %gamma, [%tmp64];
2770    add.u64 %tmp64, %b, %off64;
2771    ld.global.f32 %beta, [%tmp64];
2772
2773    // ---- Pass 2: normalize + affine ----
2774    // For now this is a placeholder - the indexing needs to match pass 1
2775    // Each thread normalizes its elements
2776
2777END:
2778    ret;
2779}
2780";
2781
2782/// PTX kernel for MaxPool2d forward: sliding window max.
2783///
2784/// One thread per output element. Reads the kernel-sized window from the
2785/// input and computes the maximum value.
2786#[cfg(feature = "cuda")]
2787pub(crate) const MAXPOOL2D_PTX: &str = "\
2788.version 7.0
2789.target sm_52
2790.address_size 64
2791
2792.visible .entry maxpool2d_forward_kernel(
2793    .param .u64 input_ptr,
2794    .param .u64 output_ptr,
2795    .param .u32 batch,
2796    .param .u32 channels,
2797    .param .u32 h_in,
2798    .param .u32 w_in,
2799    .param .u32 h_out,
2800    .param .u32 w_out,
2801    .param .u32 kh,
2802    .param .u32 kw,
2803    .param .u32 sh,
2804    .param .u32 sw,
2805    .param .u32 ph,
2806    .param .u32 pw,
2807    .param .u32 total
2808) {
2809    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
2810    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
2811    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
2812    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
2813    .reg .u32 %batch_reg, %ch_reg;
2814    .reg .u64 %in, %out, %off64, %tmp64;
2815    .reg .f32 %max_val, %cur_val, %neg_inf;
2816    .reg .pred %p, %p_bounds, %p_gt;
2817
2818    ld.param.u64 %in, [input_ptr];
2819    ld.param.u64 %out, [output_ptr];
2820    ld.param.u32 %batch_reg, [batch];
2821    ld.param.u32 %ch_reg, [channels];
2822    ld.param.u32 %h_in_reg, [h_in];
2823    ld.param.u32 %w_in_reg, [w_in];
2824    ld.param.u32 %h_out_reg, [h_out];
2825    ld.param.u32 %w_out_reg, [w_out];
2826    ld.param.u32 %kh_reg, [kh];
2827    ld.param.u32 %kw_reg, [kw];
2828    ld.param.u32 %sh_reg, [sh];
2829    ld.param.u32 %sw_reg, [sw];
2830    ld.param.u32 %ph_reg, [ph];
2831    ld.param.u32 %pw_reg, [pw];
2832    ld.param.u32 %total_reg, [total];
2833
2834    mov.u32 %bid, %ctaid.x;
2835    mov.u32 %bdim, %ntid.x;
2836    mov.u32 %tid, %tid.x;
2837    mov.u32 %gdim, %nctaid.x;
2838    mad.lo.u32 %idx, %bid, %bdim, %tid;
2839    mul.lo.u32 %stride, %bdim, %gdim;
2840
2841    // -inf for max initialization
2842    mov.f32 %neg_inf, 0fFF800000;
2843
2844LOOP:
2845    setp.ge.u32 %p, %idx, %total_reg;
2846    @%p bra END;
2847
2848    // Decompose idx into (b, c, oh, ow)
2849    mov.u32 %rem, %idx;
2850    div.u32 %b_idx, %rem, %ch_reg;
2851    // Actually need: idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow
2852    // So decompose from the right:
2853    rem.u32 %ow, %rem, %w_out_reg;
2854    div.u32 %rem, %rem, %w_out_reg;
2855    rem.u32 %oh, %rem, %h_out_reg;
2856    div.u32 %rem, %rem, %h_out_reg;
2857    rem.u32 %c_idx, %rem, %ch_reg;
2858    div.u32 %b_idx, %rem, %ch_reg;
2859
2860    mov.f32 %max_val, %neg_inf;
2861
2862    // Slide the kernel window
2863    mov.u32 %i, 0;
2864KH_LOOP:
2865    setp.ge.u32 %p, %i, %kh_reg;
2866    @%p bra KH_DONE;
2867
2868    mov.u32 %j, 0;
2869KW_LOOP:
2870    setp.ge.u32 %p, %j, %kw_reg;
2871    @%p bra KW_DONE;
2872
2873    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
2874    mad.lo.u32 %ih, %oh, %sh_reg, %i;
2875    sub.u32 %ih, %ih, %ph_reg;
2876    mad.lo.u32 %iw, %ow, %sw_reg, %j;
2877    sub.u32 %iw, %iw, %pw_reg;
2878
2879    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
2880    // Since unsigned, just check < h_in and < w_in
2881    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
2882    @%p_bounds bra KW_NEXT;
2883    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
2884    @%p_bounds bra KW_NEXT;
2885
2886    // input_offset = b * C * H * W + c * H * W + ih * W + iw
2887    mul.lo.u32 %tmp, %b_idx, %ch_reg;
2888    add.u32 %tmp, %tmp, %c_idx;
2889    mul.lo.u32 %tmp, %tmp, %h_in_reg;
2890    add.u32 %tmp, %tmp, %ih;
2891    mul.lo.u32 %tmp, %tmp, %w_in_reg;
2892    add.u32 %tmp, %tmp, %iw;
2893
2894    cvt.u64.u32 %off64, %tmp;
2895    shl.b64 %off64, %off64, 2;
2896    add.u64 %tmp64, %in, %off64;
2897    ld.global.f32 %cur_val, [%tmp64];
2898
2899    max.f32 %max_val, %max_val, %cur_val;
2900
2901KW_NEXT:
2902    add.u32 %j, %j, 1;
2903    bra KW_LOOP;
2904
2905KW_DONE:
2906    add.u32 %i, %i, 1;
2907    bra KH_LOOP;
2908
2909KH_DONE:
2910    // Store output
2911    cvt.u64.u32 %off64, %idx;
2912    shl.b64 %off64, %off64, 2;
2913    add.u64 %tmp64, %out, %off64;
2914    st.global.f32 [%tmp64], %max_val;
2915
2916    add.u32 %idx, %idx, %stride;
2917    bra LOOP;
2918
2919END:
2920    ret;
2921}
2922";
2923
2924/// PTX kernel for AvgPool2d forward: sliding window average.
2925///
2926/// One thread per output element. Same structure as MaxPool2d but
2927/// computes sum / count instead of max.
2928#[cfg(feature = "cuda")]
2929pub(crate) const AVGPOOL2D_PTX: &str = "\
2930.version 7.0
2931.target sm_52
2932.address_size 64
2933
2934.visible .entry avgpool2d_forward_kernel(
2935    .param .u64 input_ptr,
2936    .param .u64 output_ptr,
2937    .param .u32 batch,
2938    .param .u32 channels,
2939    .param .u32 h_in,
2940    .param .u32 w_in,
2941    .param .u32 h_out,
2942    .param .u32 w_out,
2943    .param .u32 kh,
2944    .param .u32 kw,
2945    .param .u32 sh,
2946    .param .u32 sw,
2947    .param .u32 ph,
2948    .param .u32 pw,
2949    .param .u32 total
2950) {
2951    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
2952    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
2953    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
2954    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
2955    .reg .u32 %batch_reg, %ch_reg;
2956    .reg .u64 %in, %out, %off64, %tmp64;
2957    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
2958    .reg .pred %p, %p_bounds;
2959
2960    ld.param.u64 %in, [input_ptr];
2961    ld.param.u64 %out, [output_ptr];
2962    ld.param.u32 %batch_reg, [batch];
2963    ld.param.u32 %ch_reg, [channels];
2964    ld.param.u32 %h_in_reg, [h_in];
2965    ld.param.u32 %w_in_reg, [w_in];
2966    ld.param.u32 %h_out_reg, [h_out];
2967    ld.param.u32 %w_out_reg, [w_out];
2968    ld.param.u32 %kh_reg, [kh];
2969    ld.param.u32 %kw_reg, [kw];
2970    ld.param.u32 %sh_reg, [sh];
2971    ld.param.u32 %sw_reg, [sw];
2972    ld.param.u32 %ph_reg, [ph];
2973    ld.param.u32 %pw_reg, [pw];
2974    ld.param.u32 %total_reg, [total];
2975
2976    mov.u32 %bid, %ctaid.x;
2977    mov.u32 %bdim, %ntid.x;
2978    mov.u32 %tid, %tid.x;
2979    mov.u32 %gdim, %nctaid.x;
2980    mad.lo.u32 %idx, %bid, %bdim, %tid;
2981    mul.lo.u32 %stride, %bdim, %gdim;
2982
2983LOOP:
2984    setp.ge.u32 %p, %idx, %total_reg;
2985    @%p bra END;
2986
2987    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
2988    mov.u32 %rem, %idx;
2989    rem.u32 %ow, %rem, %w_out_reg;
2990    div.u32 %rem, %rem, %w_out_reg;
2991    rem.u32 %oh, %rem, %h_out_reg;
2992    div.u32 %rem, %rem, %h_out_reg;
2993    rem.u32 %c_idx, %rem, %ch_reg;
2994    div.u32 %b_idx, %rem, %ch_reg;
2995
2996    mov.f32 %sum_val, 0f00000000;
2997    mov.u32 %count, 0;
2998
2999    mov.u32 %i, 0;
3000AKH_LOOP:
3001    setp.ge.u32 %p, %i, %kh_reg;
3002    @%p bra AKH_DONE;
3003
3004    mov.u32 %j, 0;
3005AKW_LOOP:
3006    setp.ge.u32 %p, %j, %kw_reg;
3007    @%p bra AKW_DONE;
3008
3009    mad.lo.u32 %ih, %oh, %sh_reg, %i;
3010    sub.u32 %ih, %ih, %ph_reg;
3011    mad.lo.u32 %iw, %ow, %sw_reg, %j;
3012    sub.u32 %iw, %iw, %pw_reg;
3013
3014    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
3015    @%p_bounds bra AKW_NEXT;
3016    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
3017    @%p_bounds bra AKW_NEXT;
3018
3019    mul.lo.u32 %tmp, %b_idx, %ch_reg;
3020    add.u32 %tmp, %tmp, %c_idx;
3021    mul.lo.u32 %tmp, %tmp, %h_in_reg;
3022    add.u32 %tmp, %tmp, %ih;
3023    mul.lo.u32 %tmp, %tmp, %w_in_reg;
3024    add.u32 %tmp, %tmp, %iw;
3025
3026    cvt.u64.u32 %off64, %tmp;
3027    shl.b64 %off64, %off64, 2;
3028    add.u64 %tmp64, %in, %off64;
3029    ld.global.f32 %cur_val, [%tmp64];
3030
3031    add.f32 %sum_val, %sum_val, %cur_val;
3032    add.u32 %count, %count, 1;
3033
3034AKW_NEXT:
3035    add.u32 %j, %j, 1;
3036    bra AKW_LOOP;
3037
3038AKW_DONE:
3039    add.u32 %i, %i, 1;
3040    bra AKH_LOOP;
3041
3042AKH_DONE:
3043    // avg = sum / count (count_include_pad = false behavior)
3044    cvt.rn.f32.u32 %count_f, %count;
3045    div.rn.f32 %avg, %sum_val, %count_f;
3046
3047    cvt.u64.u32 %off64, %idx;
3048    shl.b64 %off64, %off64, 2;
3049    add.u64 %tmp64, %out, %off64;
3050    st.global.f32 [%tmp64], %avg;
3051
3052    add.u32 %idx, %idx, %stride;
3053    bra LOOP;
3054
3055END:
3056    ret;
3057}
3058";
3059
3060#[cfg(feature = "cuda")]
3061pub(crate) const SOFTMAX_PTX: &str = "\
3062.version 7.0\n\
3063.target sm_52\n\
3064.address_size 64\n\
3065\n\
3066.shared .align 4 .f32 sdata[256];\n\
3067\n\
3068.visible .entry softmax_kernel(\n\
3069    .param .u64 input_ptr,\n\
3070    .param .u64 output_ptr,\n\
3071    .param .u32 rows,\n\
3072    .param .u32 cols\n\
3073) {\n\
3074    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
3075    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
3076    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
3077    .reg .pred %p, %loop_p;\n\
3078    .reg .u32 %half, %other_tid;\n\
3079    .reg .f32 %other_val;\n\
3080    .reg .pred %reduce_p;\n\
3081\n\
3082    ld.param.u64 %in, [input_ptr];\n\
3083    ld.param.u64 %out, [output_ptr];\n\
3084    ld.param.u32 %rows_reg, [rows];\n\
3085    ld.param.u32 %cols_reg, [cols];\n\
3086\n\
3087    mov.u32 %bid, %ctaid.x;\n\
3088    mov.u32 %bdim, %ntid.x;\n\
3089    mov.u32 %r_tid, %tid.x;\n\
3090    mov.u64 %sbase, sdata;\n\
3091\n\
3092    setp.ge.u32 %p, %bid, %rows_reg;\n\
3093    @%p bra DONE;\n\
3094\n\
3095    cvt.u64.u32 %row_off, %bid;\n\
3096    cvt.u64.u32 %off, %cols_reg;\n\
3097    mul.lo.u64 %row_off, %row_off, %off;\n\
3098    shl.b64 %row_off, %row_off, 2;\n\
3099\n\
3100    mov.f32 %max_val, 0fFF800000;\n\
3101    mov.u32 %j, %r_tid;\n\
3102FIND_MAX:\n\
3103    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
3104    @%loop_p bra FIND_MAX_DONE;\n\
3105    cvt.u64.u32 %off, %j;\n\
3106    shl.b64 %off, %off, 2;\n\
3107    add.u64 %off, %in, %off;\n\
3108    add.u64 %off, %off, %row_off;\n\
3109    ld.global.f32 %val, [%off];\n\
3110    max.f32 %max_val, %max_val, %val;\n\
3111    add.u32 %j, %j, %bdim;\n\
3112    bra FIND_MAX;\n\
3113FIND_MAX_DONE:\n\
3114\n\
3115    cvt.u64.u32 %off, %r_tid;\n\
3116    shl.b64 %off, %off, 2;\n\
3117    add.u64 %saddr, %sbase, %off;\n\
3118    st.shared.f32 [%saddr], %max_val;\n\
3119    bar.sync 0;\n\
3120\n\
3121    mov.u32 %half, %bdim;\n\
3122MAX_REDUCE:\n\
3123    shr.u32 %half, %half, 1;\n\
3124    setp.eq.u32 %reduce_p, %half, 0;\n\
3125    @%reduce_p bra MAX_REDUCE_DONE;\n\
3126    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
3127    @%reduce_p bra MAX_REDUCE_SKIP;\n\
3128    add.u32 %other_tid, %r_tid, %half;\n\
3129    cvt.u64.u32 %off, %other_tid;\n\
3130    shl.b64 %off, %off, 2;\n\
3131    add.u64 %saddr, %sbase, %off;
3132    ld.shared.f32 %other_val, [%saddr];\n\
3133    cvt.u64.u32 %off, %r_tid;\n\
3134    shl.b64 %off, %off, 2;\n\
3135    add.u64 %saddr, %sbase, %off;\n\
3136    ld.shared.f32 %max_val, [%saddr];\n\
3137    max.f32 %max_val, %max_val, %other_val;\n\
3138    add.u64 %saddr, %sbase, %off;\n\
3139    st.shared.f32 [%saddr], %max_val;\n\
3140MAX_REDUCE_SKIP:\n\
3141    bar.sync 0;\n\
3142    bra MAX_REDUCE;\n\
3143MAX_REDUCE_DONE:\n\
3144\n\
3145    ld.shared.f32 %max_val, [sdata];\n\
3146    bar.sync 0;\n\
3147\n\
3148    mov.f32 %sum_val, 0f00000000;\n\
3149    mov.u32 %j, %r_tid;\n\
3150SUM_EXP:\n\
3151    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
3152    @%loop_p bra SUM_EXP_DONE;\n\
3153    cvt.u64.u32 %off, %j;\n\
3154    shl.b64 %off, %off, 2;\n\
3155    add.u64 %off, %in, %off;\n\
3156    add.u64 %off, %off, %row_off;\n\
3157    ld.global.f32 %val, [%off];\n\
3158    sub.f32 %val, %val, %max_val;\n\
3159    mul.f32 %val, %val, 0f3FB8AA3B;\n\
3160    ex2.approx.f32 %exp_val, %val;\n\
3161    add.f32 %sum_val, %sum_val, %exp_val;\n\
3162    cvt.u64.u32 %off, %j;\n\
3163    shl.b64 %off, %off, 2;\n\
3164    add.u64 %off, %out, %off;\n\
3165    add.u64 %off, %off, %row_off;\n\
3166    st.global.f32 [%off], %exp_val;\n\
3167    add.u32 %j, %j, %bdim;\n\
3168    bra SUM_EXP;\n\
3169SUM_EXP_DONE:\n\
3170\n\
3171    cvt.u64.u32 %off, %r_tid;\n\
3172    shl.b64 %off, %off, 2;\n\
3173    add.u64 %saddr, %sbase, %off;\n\
3174    st.shared.f32 [%saddr], %sum_val;\n\
3175    bar.sync 0;\n\
3176\n\
3177    mov.u32 %half, %bdim;\n\
3178SUM_REDUCE:\n\
3179    shr.u32 %half, %half, 1;\n\
3180    setp.eq.u32 %reduce_p, %half, 0;\n\
3181    @%reduce_p bra SUM_REDUCE_DONE;\n\
3182    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
3183    @%reduce_p bra SUM_REDUCE_SKIP;\n\
3184    add.u32 %other_tid, %r_tid, %half;\n\
3185    cvt.u64.u32 %off, %other_tid;\n\
3186    shl.b64 %off, %off, 2;\n\
3187    add.u64 %saddr, %sbase, %off;
3188    ld.shared.f32 %other_val, [%saddr];\n\
3189    cvt.u64.u32 %off, %r_tid;\n\
3190    shl.b64 %off, %off, 2;\n\
3191    add.u64 %saddr, %sbase, %off;\n\
3192    ld.shared.f32 %sum_val, [%saddr];\n\
3193    add.f32 %sum_val, %sum_val, %other_val;\n\
3194    add.u64 %saddr, %sbase, %off;\n\
3195    st.shared.f32 [%saddr], %sum_val;\n\
3196SUM_REDUCE_SKIP:\n\
3197    bar.sync 0;\n\
3198    bra SUM_REDUCE;\n\
3199SUM_REDUCE_DONE:\n\
3200\n\
3201    ld.shared.f32 %sum_val, [sdata];\n\
3202    bar.sync 0;\n\
3203\n\
3204    rcp.approx.f32 %sum_val, %sum_val;\n\
3205    mov.u32 %j, %r_tid;\n\
3206NORMALIZE:\n\
3207    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
3208    @%loop_p bra NORMALIZE_DONE;\n\
3209    cvt.u64.u32 %off, %j;\n\
3210    shl.b64 %off, %off, 2;\n\
3211    add.u64 %off, %out, %off;\n\
3212    add.u64 %off, %off, %row_off;\n\
3213    ld.global.f32 %val, [%off];\n\
3214    mul.f32 %result, %val, %sum_val;\n\
3215    st.global.f32 [%off], %result;\n\
3216    add.u32 %j, %j, %bdim;\n\
3217    bra NORMALIZE;\n\
3218NORMALIZE_DONE:\n\
3219\n\
3220DONE:\n\
3221    ret;\n\
3222}\n\
3223";
3224
3225// ---------------------------------------------------------------------------
3226// Dropout PTX kernel (inverted dropout with xorshift RNG)
3227// ---------------------------------------------------------------------------
3228
3229#[cfg(feature = "cuda")]
3230pub(crate) const DROPOUT_PTX: &str = "\
3231.version 7.0\n\
3232.target sm_52\n\
3233.address_size 64\n\
3234\n\
3235.visible .entry dropout_kernel(\n\
3236    .param .u64 input_ptr,\n\
3237    .param .u64 output_ptr,\n\
3238    .param .u32 n,\n\
3239    .param .u32 threshold,\n\
3240    .param .f32 scale,\n\
3241    .param .u32 seed\n\
3242) {\n\
3243    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
3244    .reg .u64 %in, %out, %off;\n\
3245    .reg .f32 %val, %scale_reg, %zero;\n\
3246    .reg .pred %p, %drop_p;\n\
3247\n\
3248    ld.param.u64 %in, [input_ptr];\n\
3249    ld.param.u64 %out, [output_ptr];\n\
3250    ld.param.u32 %n_reg, [n];\n\
3251    ld.param.u32 %thresh, [threshold];\n\
3252    ld.param.f32 %scale_reg, [scale];\n\
3253    ld.param.u32 %seed_reg, [seed];\n\
3254\n\
3255    mov.u32 %bid, %ctaid.x;\n\
3256    mov.u32 %bdim, %ntid.x;\n\
3257    mov.u32 %r_tid, %tid.x;\n\
3258    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
3259\n\
3260    setp.ge.u32 %p, %r_tid, %n_reg;\n\
3261    @%p bra DONE;\n\
3262\n\
3263    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
3264    xor.b32 %rng, %rng, %seed_reg;\n\
3265    shl.b32 %tmp, %rng, 13;\n\
3266    xor.b32 %rng, %rng, %tmp;\n\
3267    shr.b32 %tmp, %rng, 17;\n\
3268    xor.b32 %rng, %rng, %tmp;\n\
3269    shl.b32 %tmp, %rng, 5;\n\
3270    xor.b32 %rng, %rng, %tmp;\n\
3271\n\
3272    cvt.u64.u32 %off, %r_tid;\n\
3273    shl.b64 %off, %off, 2;\n\
3274    add.u64 %in, %in, %off;\n\
3275    add.u64 %out, %out, %off;\n\
3276    ld.global.f32 %val, [%in];\n\
3277\n\
3278    setp.lo.u32 %drop_p, %rng, %thresh;\n\
3279    mov.f32 %zero, 0f00000000;\n\
3280    @%drop_p mov.f32 %val, %zero;\n\
3281    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
3282\n\
3283    st.global.f32 [%out], %val;\n\
3284\n\
3285DONE:\n\
3286    ret;\n\
3287}\n\
3288";
3289
3290// ---------------------------------------------------------------------------
3291// General N-dimensional broadcast binary PTX kernels
3292// ---------------------------------------------------------------------------
3293//
3294// Each thread computes one output element. The kernel decomposes the flat
3295// output index into N-dimensional coordinates, maps each coordinate through
3296// broadcast strides for A and B, and loads from the correct flat position.
3297//
3298// Parameters:
3299//   a_ptr         - pointer to A's device buffer
3300//   b_ptr         - pointer to B's device buffer
3301//   out_ptr       - pointer to output device buffer
3302//   a_strides_ptr - pointer to u32[ndim] broadcast strides for A
3303//   b_strides_ptr - pointer to u32[ndim] broadcast strides for B
3304//   out_shape_ptr - pointer to u32[ndim] output shape
3305//   n             - total output elements
3306//   ndim          - number of dimensions
3307//
3308// Broadcast strides: for each dimension d, stride is the normal
3309// C-contiguous stride if dim_size > 1, or 0 if dim_size == 1 (broadcast).
3310
3311/// PTX for general broadcast add: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
3312#[cfg(feature = "cuda")]
3313pub(crate) const BROADCAST_ADD_PTX: &str = "\
3314.version 7.0
3315.target sm_52
3316.address_size 64
3317
3318.visible .entry broadcast_add_kernel(
3319    .param .u64 a_ptr,
3320    .param .u64 b_ptr,
3321    .param .u64 out_ptr,
3322    .param .u64 a_strides_ptr,
3323    .param .u64 b_strides_ptr,
3324    .param .u64 out_shape_ptr,
3325    .param .u32 n,
3326    .param .u32 ndim
3327) {
3328    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3329    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3330    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3331    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3332    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3333    .reg .f32 %va, %vb, %vr;
3334    .reg .pred %p, %loop_p;
3335
3336    ld.param.u64 %a, [a_ptr];
3337    ld.param.u64 %b, [b_ptr];
3338    ld.param.u64 %out, [out_ptr];
3339    ld.param.u64 %a_str, [a_strides_ptr];
3340    ld.param.u64 %b_str, [b_strides_ptr];
3341    ld.param.u64 %oshape, [out_shape_ptr];
3342    ld.param.u32 %n_reg, [n];
3343    ld.param.u32 %ndim_reg, [ndim];
3344
3345    // Global thread index.
3346    mov.u32 %bid, %ctaid.x;
3347    mov.u32 %bdim, %ntid.x;
3348    mov.u32 %r_tid, %tid.x;
3349    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3350
3351    setp.ge.u32 %p, %r_tid, %n_reg;
3352    @%p bra DONE;
3353
3354    // Decompose flat index into N-d coordinates and compute A/B indices.
3355    mov.u32 %remaining, %r_tid;
3356    mov.u32 %a_idx, 0;
3357    mov.u32 %b_idx, 0;
3358    mov.u32 %d, %ndim_reg;
3359
3360LOOP:
3361    setp.eq.u32 %loop_p, %d, 0;
3362    @%loop_p bra END_LOOP;
3363
3364    sub.u32 %d, %d, 1;
3365
3366    // Byte offset for dimension d: d * 4.
3367    cvt.u64.u32 %d64, %d;
3368    shl.b64 %d64, %d64, 2;
3369
3370    // Load out_shape[d].
3371    add.u64 %tmp, %oshape, %d64;
3372    ld.global.u32 %shape_d, [%tmp];
3373
3374    // Load a_strides[d] and b_strides[d].
3375    add.u64 %tmp, %a_str, %d64;
3376    ld.global.u32 %a_str_d, [%tmp];
3377    add.u64 %tmp, %b_str, %d64;
3378    ld.global.u32 %b_str_d, [%tmp];
3379
3380    // coord = remaining % shape_d; remaining /= shape_d.
3381    rem.u32 %coord, %remaining, %shape_d;
3382    div.u32 %remaining, %remaining, %shape_d;
3383
3384    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
3385    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3386    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3387
3388    bra LOOP;
3389END_LOOP:
3390
3391    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
3392    cvt.u64.u32 %off_a, %a_idx;
3393    shl.b64 %off_a, %off_a, 2;
3394    add.u64 %off_a, %a, %off_a;
3395    ld.global.f32 %va, [%off_a];
3396
3397    cvt.u64.u32 %off_b, %b_idx;
3398    shl.b64 %off_b, %off_b, 2;
3399    add.u64 %off_b, %b, %off_b;
3400    ld.global.f32 %vb, [%off_b];
3401
3402    // Operation: add.
3403    add.f32 %vr, %va, %vb;
3404
3405    // Store to out[tid].
3406    cvt.u64.u32 %off_out, %r_tid;
3407    shl.b64 %off_out, %off_out, 2;
3408    add.u64 %off_out, %out, %off_out;
3409    st.global.f32 [%off_out], %vr;
3410
3411DONE:
3412    ret;
3413}
3414";
3415
3416/// PTX for general broadcast sub: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
3417#[cfg(feature = "cuda")]
3418pub(crate) const BROADCAST_SUB_PTX: &str = "\
3419.version 7.0
3420.target sm_52
3421.address_size 64
3422
3423.visible .entry broadcast_sub_kernel(
3424    .param .u64 a_ptr,
3425    .param .u64 b_ptr,
3426    .param .u64 out_ptr,
3427    .param .u64 a_strides_ptr,
3428    .param .u64 b_strides_ptr,
3429    .param .u64 out_shape_ptr,
3430    .param .u32 n,
3431    .param .u32 ndim
3432) {
3433    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3434    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3435    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3436    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3437    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3438    .reg .f32 %va, %vb, %vr;
3439    .reg .pred %p, %loop_p;
3440
3441    ld.param.u64 %a, [a_ptr];
3442    ld.param.u64 %b, [b_ptr];
3443    ld.param.u64 %out, [out_ptr];
3444    ld.param.u64 %a_str, [a_strides_ptr];
3445    ld.param.u64 %b_str, [b_strides_ptr];
3446    ld.param.u64 %oshape, [out_shape_ptr];
3447    ld.param.u32 %n_reg, [n];
3448    ld.param.u32 %ndim_reg, [ndim];
3449
3450    mov.u32 %bid, %ctaid.x;
3451    mov.u32 %bdim, %ntid.x;
3452    mov.u32 %r_tid, %tid.x;
3453    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3454    setp.ge.u32 %p, %r_tid, %n_reg;
3455    @%p bra DONE;
3456
3457    mov.u32 %remaining, %r_tid;
3458    mov.u32 %a_idx, 0;
3459    mov.u32 %b_idx, 0;
3460    mov.u32 %d, %ndim_reg;
3461LOOP:
3462    setp.eq.u32 %loop_p, %d, 0;
3463    @%loop_p bra END_LOOP;
3464    sub.u32 %d, %d, 1;
3465    cvt.u64.u32 %d64, %d;
3466    shl.b64 %d64, %d64, 2;
3467    add.u64 %tmp, %oshape, %d64;
3468    ld.global.u32 %shape_d, [%tmp];
3469    add.u64 %tmp, %a_str, %d64;
3470    ld.global.u32 %a_str_d, [%tmp];
3471    add.u64 %tmp, %b_str, %d64;
3472    ld.global.u32 %b_str_d, [%tmp];
3473    rem.u32 %coord, %remaining, %shape_d;
3474    div.u32 %remaining, %remaining, %shape_d;
3475    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3476    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3477    bra LOOP;
3478END_LOOP:
3479
3480    cvt.u64.u32 %off_a, %a_idx;
3481    shl.b64 %off_a, %off_a, 2;
3482    add.u64 %off_a, %a, %off_a;
3483    ld.global.f32 %va, [%off_a];
3484    cvt.u64.u32 %off_b, %b_idx;
3485    shl.b64 %off_b, %off_b, 2;
3486    add.u64 %off_b, %b, %off_b;
3487    ld.global.f32 %vb, [%off_b];
3488
3489    sub.f32 %vr, %va, %vb;
3490
3491    cvt.u64.u32 %off_out, %r_tid;
3492    shl.b64 %off_out, %off_out, 2;
3493    add.u64 %off_out, %out, %off_out;
3494    st.global.f32 [%off_out], %vr;
3495DONE:
3496    ret;
3497}
3498";
3499
3500/// PTX for general broadcast mul: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
3501#[cfg(feature = "cuda")]
3502pub(crate) const BROADCAST_MUL_PTX: &str = "\
3503.version 7.0
3504.target sm_52
3505.address_size 64
3506
3507.visible .entry broadcast_mul_kernel(
3508    .param .u64 a_ptr,
3509    .param .u64 b_ptr,
3510    .param .u64 out_ptr,
3511    .param .u64 a_strides_ptr,
3512    .param .u64 b_strides_ptr,
3513    .param .u64 out_shape_ptr,
3514    .param .u32 n,
3515    .param .u32 ndim
3516) {
3517    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3518    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3519    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3520    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3521    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3522    .reg .f32 %va, %vb, %vr;
3523    .reg .pred %p, %loop_p;
3524
3525    ld.param.u64 %a, [a_ptr];
3526    ld.param.u64 %b, [b_ptr];
3527    ld.param.u64 %out, [out_ptr];
3528    ld.param.u64 %a_str, [a_strides_ptr];
3529    ld.param.u64 %b_str, [b_strides_ptr];
3530    ld.param.u64 %oshape, [out_shape_ptr];
3531    ld.param.u32 %n_reg, [n];
3532    ld.param.u32 %ndim_reg, [ndim];
3533
3534    mov.u32 %bid, %ctaid.x;
3535    mov.u32 %bdim, %ntid.x;
3536    mov.u32 %r_tid, %tid.x;
3537    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3538    setp.ge.u32 %p, %r_tid, %n_reg;
3539    @%p bra DONE;
3540
3541    mov.u32 %remaining, %r_tid;
3542    mov.u32 %a_idx, 0;
3543    mov.u32 %b_idx, 0;
3544    mov.u32 %d, %ndim_reg;
3545LOOP:
3546    setp.eq.u32 %loop_p, %d, 0;
3547    @%loop_p bra END_LOOP;
3548    sub.u32 %d, %d, 1;
3549    cvt.u64.u32 %d64, %d;
3550    shl.b64 %d64, %d64, 2;
3551    add.u64 %tmp, %oshape, %d64;
3552    ld.global.u32 %shape_d, [%tmp];
3553    add.u64 %tmp, %a_str, %d64;
3554    ld.global.u32 %a_str_d, [%tmp];
3555    add.u64 %tmp, %b_str, %d64;
3556    ld.global.u32 %b_str_d, [%tmp];
3557    rem.u32 %coord, %remaining, %shape_d;
3558    div.u32 %remaining, %remaining, %shape_d;
3559    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3560    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3561    bra LOOP;
3562END_LOOP:
3563
3564    cvt.u64.u32 %off_a, %a_idx;
3565    shl.b64 %off_a, %off_a, 2;
3566    add.u64 %off_a, %a, %off_a;
3567    ld.global.f32 %va, [%off_a];
3568    cvt.u64.u32 %off_b, %b_idx;
3569    shl.b64 %off_b, %off_b, 2;
3570    add.u64 %off_b, %b, %off_b;
3571    ld.global.f32 %vb, [%off_b];
3572
3573    mul.f32 %vr, %va, %vb;
3574
3575    cvt.u64.u32 %off_out, %r_tid;
3576    shl.b64 %off_out, %off_out, 2;
3577    add.u64 %off_out, %out, %off_out;
3578    st.global.f32 [%off_out], %vr;
3579DONE:
3580    ret;
3581}
3582";
3583
3584/// PTX source for `broadcast_div_kernel`: broadcast division, identical structure
3585/// to `broadcast_mul_kernel` but uses `div.f32` instead of `mul.f32`.
3586#[cfg(feature = "cuda")]
3587pub(crate) const BROADCAST_DIV_PTX: &str = "\
3588.version 7.0
3589.target sm_52
3590.address_size 64
3591
3592.visible .entry broadcast_div_kernel(
3593    .param .u64 a_ptr,
3594    .param .u64 b_ptr,
3595    .param .u64 out_ptr,
3596    .param .u64 a_strides_ptr,
3597    .param .u64 b_strides_ptr,
3598    .param .u64 out_shape_ptr,
3599    .param .u32 n,
3600    .param .u32 ndim
3601) {
3602    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
3603    .reg .u32 %remaining, %a_idx, %b_idx, %d;
3604    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
3605    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
3606    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
3607    .reg .f32 %va, %vb, %vr;
3608    .reg .pred %p, %loop_p;
3609
3610    ld.param.u64 %a, [a_ptr];
3611    ld.param.u64 %b, [b_ptr];
3612    ld.param.u64 %out, [out_ptr];
3613    ld.param.u64 %a_str, [a_strides_ptr];
3614    ld.param.u64 %b_str, [b_strides_ptr];
3615    ld.param.u64 %oshape, [out_shape_ptr];
3616    ld.param.u32 %n_reg, [n];
3617    ld.param.u32 %ndim_reg, [ndim];
3618
3619    mov.u32 %bid, %ctaid.x;
3620    mov.u32 %bdim, %ntid.x;
3621    mov.u32 %r_tid, %tid.x;
3622    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3623    setp.ge.u32 %p, %r_tid, %n_reg;
3624    @%p bra DONE;
3625
3626    mov.u32 %remaining, %r_tid;
3627    mov.u32 %a_idx, 0;
3628    mov.u32 %b_idx, 0;
3629    mov.u32 %d, %ndim_reg;
3630LOOP:
3631    setp.eq.u32 %loop_p, %d, 0;
3632    @%loop_p bra END_LOOP;
3633    sub.u32 %d, %d, 1;
3634    cvt.u64.u32 %d64, %d;
3635    shl.b64 %d64, %d64, 2;
3636    add.u64 %tmp, %oshape, %d64;
3637    ld.global.u32 %shape_d, [%tmp];
3638    add.u64 %tmp, %a_str, %d64;
3639    ld.global.u32 %a_str_d, [%tmp];
3640    add.u64 %tmp, %b_str, %d64;
3641    ld.global.u32 %b_str_d, [%tmp];
3642    rem.u32 %coord, %remaining, %shape_d;
3643    div.u32 %remaining, %remaining, %shape_d;
3644    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
3645    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
3646    bra LOOP;
3647END_LOOP:
3648
3649    cvt.u64.u32 %off_a, %a_idx;
3650    shl.b64 %off_a, %off_a, 2;
3651    add.u64 %off_a, %a, %off_a;
3652    ld.global.f32 %va, [%off_a];
3653    cvt.u64.u32 %off_b, %b_idx;
3654    shl.b64 %off_b, %off_b, 2;
3655    add.u64 %off_b, %b, %off_b;
3656    ld.global.f32 %vb, [%off_b];
3657
3658    div.f32 %vr, %va, %vb;
3659
3660    cvt.u64.u32 %off_out, %r_tid;
3661    shl.b64 %off_out, %off_out, 2;
3662    add.u64 %off_out, %out, %off_out;
3663    st.global.f32 [%off_out], %vr;
3664DONE:
3665    ret;
3666}
3667";
3668
3669/// PTX source for `strided_split_kernel`: extract a sub-tensor along a given axis.
3670///
3671/// Thread `i` computes:
3672///   `outer_idx = i / (split_size * inner_size)`
3673///   `within    = i % (split_size * inner_size)`
3674///   `src_idx   = outer_idx * total_along_axis * inner_size + (split_offset * inner_size) + within`
3675///   `out[i]    = in[src_idx]`
3676#[cfg(feature = "cuda")]
3677pub(crate) const STRIDED_SPLIT_PTX: &str = "\
3678.version 7.0
3679.target sm_52
3680.address_size 64
3681
3682.visible .entry strided_split_kernel(
3683    .param .u64 input_ptr,
3684    .param .u64 output_ptr,
3685    .param .u32 total_along_axis,
3686    .param .u32 split_offset,
3687    .param .u32 split_size,
3688    .param .u32 inner_size,
3689    .param .u32 n
3690) {
3691    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3692    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
3693    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
3694    .reg .u64 %in, %out, %off;
3695    .reg .f32 %val;
3696    .reg .pred %p;
3697
3698    ld.param.u64 %in, [input_ptr];
3699    ld.param.u64 %out, [output_ptr];
3700    ld.param.u32 %total_ax, [total_along_axis];
3701    ld.param.u32 %sp_off, [split_offset];
3702    ld.param.u32 %sp_sz, [split_size];
3703    ld.param.u32 %inner_sz, [inner_size];
3704    ld.param.u32 %n_reg, [n];
3705
3706    mov.u32 %bid, %ctaid.x;
3707    mov.u32 %bdim, %ntid.x;
3708    mov.u32 %r_tid, %tid.x;
3709    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3710
3711    setp.ge.u32 %p, %r_tid, %n_reg;
3712    @%p bra DONE;
3713
3714    // chunk_stride = split_size * inner_size
3715    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;
3716
3717    // outer_idx = r_tid / chunk_stride
3718    div.u32 %outer_idx, %r_tid, %chunk_stride;
3719
3720    // within = r_tid % chunk_stride
3721    rem.u32 %within, %r_tid, %chunk_stride;
3722
3723    // base_off = split_offset * inner_size
3724    mul.lo.u32 %base_off, %sp_off, %inner_sz;
3725
3726    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
3727    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
3728    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
3729    add.u32 %src_idx, %src_idx, %base_off;
3730    add.u32 %src_idx, %src_idx, %within;
3731
3732    // Load from in[src_idx]
3733    cvt.u64.u32 %off, %src_idx;
3734    shl.b64 %off, %off, 2;
3735    add.u64 %off, %in, %off;
3736    ld.global.f32 %val, [%off];
3737
3738    // Store to out[r_tid]
3739    cvt.u64.u32 %off, %r_tid;
3740    shl.b64 %off, %off, 2;
3741    add.u64 %off, %out, %off;
3742    st.global.f32 [%off], %val;
3743
3744DONE:
3745    ret;
3746}
3747";
3748
3749/// PTX source for `strided_cat_kernel`: write a sub-tensor into a larger tensor
3750/// at an offset along an axis.
3751///
3752/// Thread `i` computes:
3753///   `outer_idx = i / (part_size * inner_size)`
3754///   `within    = i % (part_size * inner_size)`
3755///   `dst_idx   = outer_idx * total_along_axis * inner_size + (cat_offset * inner_size) + within`
3756///   `out[dst_idx] = in[i]`
3757#[cfg(feature = "cuda")]
3758pub(crate) const STRIDED_CAT_PTX: &str = "\
3759.version 7.0
3760.target sm_52
3761.address_size 64
3762
3763.visible .entry strided_cat_kernel(
3764    .param .u64 input_ptr,
3765    .param .u64 output_ptr,
3766    .param .u32 total_along_axis,
3767    .param .u32 cat_offset,
3768    .param .u32 part_size,
3769    .param .u32 inner_size,
3770    .param .u32 n
3771) {
3772    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3773    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
3774    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
3775    .reg .u64 %in, %out, %off;
3776    .reg .f32 %val;
3777    .reg .pred %p;
3778
3779    ld.param.u64 %in, [input_ptr];
3780    ld.param.u64 %out, [output_ptr];
3781    ld.param.u32 %total_ax, [total_along_axis];
3782    ld.param.u32 %cat_off, [cat_offset];
3783    ld.param.u32 %part_sz, [part_size];
3784    ld.param.u32 %inner_sz, [inner_size];
3785    ld.param.u32 %n_reg, [n];
3786
3787    mov.u32 %bid, %ctaid.x;
3788    mov.u32 %bdim, %ntid.x;
3789    mov.u32 %r_tid, %tid.x;
3790    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3791
3792    setp.ge.u32 %p, %r_tid, %n_reg;
3793    @%p bra DONE;
3794
3795    // chunk_stride = part_size * inner_size
3796    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;
3797
3798    // outer_idx = r_tid / chunk_stride
3799    div.u32 %outer_idx, %r_tid, %chunk_stride;
3800
3801    // within = r_tid % chunk_stride
3802    rem.u32 %within, %r_tid, %chunk_stride;
3803
3804    // base_off = cat_offset * inner_size
3805    mul.lo.u32 %base_off, %cat_off, %inner_sz;
3806
3807    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
3808    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
3809    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
3810    add.u32 %dst_idx, %dst_idx, %base_off;
3811    add.u32 %dst_idx, %dst_idx, %within;
3812
3813    // Load from in[r_tid]
3814    cvt.u64.u32 %off, %r_tid;
3815    shl.b64 %off, %off, 2;
3816    add.u64 %off, %in, %off;
3817    ld.global.f32 %val, [%off];
3818
3819    // Store to out[dst_idx]
3820    cvt.u64.u32 %off, %dst_idx;
3821    shl.b64 %off, %off, 2;
3822    add.u64 %off, %out, %off;
3823    st.global.f32 [%off], %val;
3824
3825DONE:
3826    ret;
3827}
3828";
3829
3830/// PTX source for `div_kernel`: `out[i] = a[i] / b[i]`.
3831#[cfg(feature = "cuda")]
3832pub(crate) const DIV_PTX: &str = "\
3833.version 7.0
3834.target sm_52
3835.address_size 64
3836
3837.visible .entry div_kernel(
3838    .param .u64 a_ptr,
3839    .param .u64 b_ptr,
3840    .param .u64 out_ptr,
3841    .param .u32 n
3842) {
3843    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3844    .reg .u64 %a, %b, %out, %off;
3845    .reg .f32 %va, %vb, %vr;
3846    .reg .pred %p;
3847
3848    ld.param.u64 %a, [a_ptr];
3849    ld.param.u64 %b, [b_ptr];
3850    ld.param.u64 %out, [out_ptr];
3851    ld.param.u32 %n_reg, [n];
3852
3853    mov.u32 %bid, %ctaid.x;
3854    mov.u32 %bdim, %ntid.x;
3855    mov.u32 %r_tid, %tid.x;
3856    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3857
3858    setp.ge.u32 %p, %r_tid, %n_reg;
3859    @%p bra DONE;
3860
3861    cvt.u64.u32 %off, %r_tid;
3862    shl.b64 %off, %off, 2;
3863
3864    add.u64 %a, %a, %off;
3865    add.u64 %b, %b, %off;
3866    add.u64 %out, %out, %off;
3867
3868    ld.global.f32 %va, [%a];
3869    ld.global.f32 %vb, [%b];
3870    div.rn.f32 %vr, %va, %vb;
3871    st.global.f32 [%out], %vr;
3872
3873DONE:
3874    ret;
3875}
3876";
3877
3878/// PTX source for `exp_kernel`: `out[i] = exp(a[i])`.
3879#[cfg(feature = "cuda")]
3880pub(crate) const EXP_PTX: &str = "\
3881.version 7.0
3882.target sm_52
3883.address_size 64
3884
3885.visible .entry exp_kernel(
3886    .param .u64 a_ptr,
3887    .param .u64 out_ptr,
3888    .param .u32 n
3889) {
3890    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3891    .reg .u64 %a, %out, %off;
3892    .reg .f32 %va, %vr;
3893    .reg .pred %p;
3894
3895    ld.param.u64 %a, [a_ptr];
3896    ld.param.u64 %out, [out_ptr];
3897    ld.param.u32 %n_reg, [n];
3898
3899    mov.u32 %bid, %ctaid.x;
3900    mov.u32 %bdim, %ntid.x;
3901    mov.u32 %r_tid, %tid.x;
3902    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3903
3904    setp.ge.u32 %p, %r_tid, %n_reg;
3905    @%p bra DONE;
3906
3907    cvt.u64.u32 %off, %r_tid;
3908    shl.b64 %off, %off, 2;
3909
3910    add.u64 %a, %a, %off;
3911    add.u64 %out, %out, %off;
3912
3913    ld.global.f32 %va, [%a];
3914    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
3915    // log2(e) = 1.4426950408889634
3916    mul.f32 %va, %va, 0f3FB8AA3B;
3917    ex2.approx.f32 %vr, %va;
3918    st.global.f32 [%out], %vr;
3919
3920DONE:
3921    ret;
3922}
3923";
3924
3925/// PTX source for `log_kernel`: `out[i] = ln(a[i])`.
3926#[cfg(feature = "cuda")]
3927pub(crate) const LOG_PTX: &str = "\
3928.version 7.0
3929.target sm_52
3930.address_size 64
3931
3932.visible .entry log_kernel(
3933    .param .u64 a_ptr,
3934    .param .u64 out_ptr,
3935    .param .u32 n
3936) {
3937    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3938    .reg .u64 %a, %out, %off;
3939    .reg .f32 %va, %vr;
3940    .reg .pred %p;
3941
3942    ld.param.u64 %a, [a_ptr];
3943    ld.param.u64 %out, [out_ptr];
3944    ld.param.u32 %n_reg, [n];
3945
3946    mov.u32 %bid, %ctaid.x;
3947    mov.u32 %bdim, %ntid.x;
3948    mov.u32 %r_tid, %tid.x;
3949    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3950
3951    setp.ge.u32 %p, %r_tid, %n_reg;
3952    @%p bra DONE;
3953
3954    cvt.u64.u32 %off, %r_tid;
3955    shl.b64 %off, %off, 2;
3956
3957    add.u64 %a, %a, %off;
3958    add.u64 %out, %out, %off;
3959
3960    ld.global.f32 %va, [%a];
3961    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
3962    // 1/log2(e) = ln(2) = 0.6931471805599453
3963    lg2.approx.f32 %vr, %va;
3964    mul.f32 %vr, %vr, 0f3F317218;
3965    st.global.f32 [%out], %vr;
3966
3967DONE:
3968    ret;
3969}
3970";
3971
3972/// PTX source for `sqrt_kernel`: `out[i] = sqrt(a[i])`.
3973#[cfg(feature = "cuda")]
3974pub(crate) const SQRT_PTX: &str = "\
3975.version 7.0
3976.target sm_52
3977.address_size 64
3978
3979.visible .entry sqrt_kernel(
3980    .param .u64 a_ptr,
3981    .param .u64 out_ptr,
3982    .param .u32 n
3983) {
3984    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
3985    .reg .u64 %a, %out, %off;
3986    .reg .f32 %va, %vr;
3987    .reg .pred %p;
3988
3989    ld.param.u64 %a, [a_ptr];
3990    ld.param.u64 %out, [out_ptr];
3991    ld.param.u32 %n_reg, [n];
3992
3993    mov.u32 %bid, %ctaid.x;
3994    mov.u32 %bdim, %ntid.x;
3995    mov.u32 %r_tid, %tid.x;
3996    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
3997
3998    setp.ge.u32 %p, %r_tid, %n_reg;
3999    @%p bra DONE;
4000
4001    cvt.u64.u32 %off, %r_tid;
4002    shl.b64 %off, %off, 2;
4003
4004    add.u64 %a, %a, %off;
4005    add.u64 %out, %out, %off;
4006
4007    ld.global.f32 %va, [%a];
4008    sqrt.rn.f32 %vr, %va;
4009    st.global.f32 [%out], %vr;
4010
4011DONE:
4012    ret;
4013}
4014";
4015
4016/// PTX source for `pow_kernel`: `out[i] = a[i] ^ exponent`.
4017/// Uses the identity: x^e = 2^(e * log2(x)).
4018#[cfg(feature = "cuda")]
4019pub(crate) const POW_PTX: &str = "\
4020.version 7.0
4021.target sm_52
4022.address_size 64
4023
4024.visible .entry pow_kernel(
4025    .param .u64 a_ptr,
4026    .param .u64 out_ptr,
4027    .param .f32 exponent,
4028    .param .u32 n
4029) {
4030    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4031    .reg .u64 %a, %out, %off;
4032    .reg .f32 %va, %vr, %exp, %lg;
4033    .reg .pred %p;
4034
4035    ld.param.u64 %a, [a_ptr];
4036    ld.param.u64 %out, [out_ptr];
4037    ld.param.f32 %exp, [exponent];
4038    ld.param.u32 %n_reg, [n];
4039
4040    mov.u32 %bid, %ctaid.x;
4041    mov.u32 %bdim, %ntid.x;
4042    mov.u32 %r_tid, %tid.x;
4043    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4044
4045    setp.ge.u32 %p, %r_tid, %n_reg;
4046    @%p bra DONE;
4047
4048    cvt.u64.u32 %off, %r_tid;
4049    shl.b64 %off, %off, 2;
4050
4051    add.u64 %a, %a, %off;
4052    add.u64 %out, %out, %off;
4053
4054    ld.global.f32 %va, [%a];
4055    // x^e = 2^(e * log2(x))
4056    lg2.approx.f32 %lg, %va;
4057    mul.f32 %lg, %lg, %exp;
4058    ex2.approx.f32 %vr, %lg;
4059    st.global.f32 [%out], %vr;
4060
4061DONE:
4062    ret;
4063}
4064";
4065
4066/// PTX source for `abs_kernel`: `out[i] = |a[i]|`.
4067#[cfg(feature = "cuda")]
4068pub(crate) const ABS_PTX: &str = "\
4069.version 7.0
4070.target sm_52
4071.address_size 64
4072
4073.visible .entry abs_kernel(
4074    .param .u64 a_ptr,
4075    .param .u64 out_ptr,
4076    .param .u32 n
4077) {
4078    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4079    .reg .u64 %a, %out, %off;
4080    .reg .f32 %va, %vr;
4081    .reg .pred %p;
4082
4083    ld.param.u64 %a, [a_ptr];
4084    ld.param.u64 %out, [out_ptr];
4085    ld.param.u32 %n_reg, [n];
4086
4087    mov.u32 %bid, %ctaid.x;
4088    mov.u32 %bdim, %ntid.x;
4089    mov.u32 %r_tid, %tid.x;
4090    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4091
4092    setp.ge.u32 %p, %r_tid, %n_reg;
4093    @%p bra DONE;
4094
4095    cvt.u64.u32 %off, %r_tid;
4096    shl.b64 %off, %off, 2;
4097
4098    add.u64 %a, %a, %off;
4099    add.u64 %out, %out, %off;
4100
4101    ld.global.f32 %va, [%a];
4102    abs.f32 %vr, %va;
4103    st.global.f32 [%out], %vr;
4104
4105DONE:
4106    ret;
4107}
4108";
4109
4110/// PTX source for `sigmoid_kernel`: `out[i] = 1 / (1 + exp(-a[i]))`.
4111#[cfg(feature = "cuda")]
4112pub(crate) const SIGMOID_PTX: &str = "\
4113.version 7.0
4114.target sm_52
4115.address_size 64
4116
4117.visible .entry sigmoid_kernel(
4118    .param .u64 a_ptr,
4119    .param .u64 out_ptr,
4120    .param .u32 n
4121) {
4122    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4123    .reg .u64 %a, %out, %off;
4124    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
4125    .reg .pred %p;
4126
4127    ld.param.u64 %a, [a_ptr];
4128    ld.param.u64 %out, [out_ptr];
4129    ld.param.u32 %n_reg, [n];
4130
4131    mov.u32 %bid, %ctaid.x;
4132    mov.u32 %bdim, %ntid.x;
4133    mov.u32 %r_tid, %tid.x;
4134    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4135
4136    setp.ge.u32 %p, %r_tid, %n_reg;
4137    @%p bra DONE;
4138
4139    cvt.u64.u32 %off, %r_tid;
4140    shl.b64 %off, %off, 2;
4141
4142    add.u64 %a, %a, %off;
4143    add.u64 %out, %out, %off;
4144
4145    ld.global.f32 %va, [%a];
4146    // sigmoid(x) = 1 / (1 + exp(-x))
4147    neg.f32 %neg, %va;
4148    mov.f32 %lg2e, 0f3FB8AA3B;
4149    mul.f32 %neg, %neg, %lg2e;
4150    ex2.approx.f32 %e, %neg;
4151    mov.f32 %one, 0f3F800000;
4152    add.f32 %denom, %one, %e;
4153    div.rn.f32 %vr, %one, %denom;
4154    st.global.f32 [%out], %vr;
4155
4156DONE:
4157    ret;
4158}
4159";
4160
4161/// PTX source for `tanh_kernel`: `out[i] = tanh(a[i])`.
4162/// Uses the identity: tanh(x) = 2*sigmoid(2x) - 1.
4163#[cfg(feature = "cuda")]
4164pub(crate) const TANH_PTX: &str = "\
4165.version 7.0
4166.target sm_52
4167.address_size 64
4168
4169.visible .entry tanh_kernel(
4170    .param .u64 a_ptr,
4171    .param .u64 out_ptr,
4172    .param .u32 n
4173) {
4174    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4175    .reg .u64 %a, %out, %off;
4176    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
4177    .reg .pred %p;
4178
4179    ld.param.u64 %a, [a_ptr];
4180    ld.param.u64 %out, [out_ptr];
4181    ld.param.u32 %n_reg, [n];
4182
4183    mov.u32 %bid, %ctaid.x;
4184    mov.u32 %bdim, %ntid.x;
4185    mov.u32 %r_tid, %tid.x;
4186    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4187
4188    setp.ge.u32 %p, %r_tid, %n_reg;
4189    @%p bra DONE;
4190
4191    cvt.u64.u32 %off, %r_tid;
4192    shl.b64 %off, %off, 2;
4193
4194    add.u64 %a, %a, %off;
4195    add.u64 %out, %out, %off;
4196
4197    ld.global.f32 %va, [%a];
4198    // tanh(x) = 2*sigmoid(2x) - 1
4199    mov.f32 %two, 0f40000000;
4200    mul.f32 %neg2x, %va, %two;
4201    neg.f32 %neg2x, %neg2x;
4202    mov.f32 %lg2e, 0f3FB8AA3B;
4203    mul.f32 %neg2x, %neg2x, %lg2e;
4204    ex2.approx.f32 %e, %neg2x;
4205    mov.f32 %one, 0f3F800000;
4206    add.f32 %denom, %one, %e;
4207    div.rn.f32 %sig, %one, %denom;
4208    mul.f32 %vr, %two, %sig;
4209    sub.f32 %vr, %vr, %one;
4210    st.global.f32 [%out], %vr;
4211
4212DONE:
4213    ret;
4214}
4215";
4216
4217/// PTX source for `fused_adam_kernel`: in-place Adam optimizer update.
4218///
4219/// For each element i:
4220///   g = grad[i] + weight_decay * param[i]  (if wd > 0)
4221///   exp_avg[i] = beta1 * exp_avg[i] + (1-beta1) * g
4222///   exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1-beta2) * g * g
4223///   m_hat = exp_avg[i] / bc1
4224///   v_hat = exp_avg_sq[i] / bc2
4225///   param[i] = param[i] - lr * m_hat / (sqrt(v_hat) + eps)
4226#[cfg(feature = "cuda")]
4227pub(crate) const FUSED_ADAM_PTX: &str = "\
4228.version 7.0
4229.target sm_52
4230.address_size 64
4231
4232.visible .entry fused_adam_kernel(
4233    .param .u64 param_ptr,
4234    .param .u64 grad_ptr,
4235    .param .u64 exp_avg_ptr,
4236    .param .u64 exp_avg_sq_ptr,
4237    .param .f32 beta1,
4238    .param .f32 beta2,
4239    .param .f32 lr,
4240    .param .f32 eps,
4241    .param .f32 bc1,
4242    .param .f32 bc2,
4243    .param .f32 weight_decay,
4244    .param .u32 n
4245) {
4246    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
4247    .reg .u64 %p, %g, %m, %v, %off;
4248    .reg .f32 %vp, %vg, %vm, %vv;
4249    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
4250    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
4251    .reg .f32 %one;
4252    .reg .pred %p_bound, %p_wd;
4253
4254    ld.param.u64 %p, [param_ptr];
4255    ld.param.u64 %g, [grad_ptr];
4256    ld.param.u64 %m, [exp_avg_ptr];
4257    ld.param.u64 %v, [exp_avg_sq_ptr];
4258    ld.param.f32 %b1, [beta1];
4259    ld.param.f32 %b2, [beta2];
4260    ld.param.f32 %f_lr, [lr];
4261    ld.param.f32 %f_eps, [eps];
4262    ld.param.f32 %f_bc1, [bc1];
4263    ld.param.f32 %f_bc2, [bc2];
4264    ld.param.f32 %f_wd, [weight_decay];
4265    ld.param.u32 %n_reg, [n];
4266
4267    mov.u32 %bid, %ctaid.x;
4268    mov.u32 %bdim, %ntid.x;
4269    mov.u32 %r_tid, %tid.x;
4270    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
4271
4272    setp.ge.u32 %p_bound, %r_tid, %n_reg;
4273    @%p_bound bra DONE;
4274
4275    cvt.u64.u32 %off, %r_tid;
4276    shl.b64 %off, %off, 2;
4277
4278    add.u64 %p, %p, %off;
4279    add.u64 %g, %g, %off;
4280    add.u64 %m, %m, %off;
4281    add.u64 %v, %v, %off;
4282
4283    ld.global.f32 %vp, [%p];
4284    ld.global.f32 %vg, [%g];
4285    ld.global.f32 %vm, [%m];
4286    ld.global.f32 %vv, [%v];
4287
4288    // L2 weight decay: g = g + wd * p
4289    mov.f32 %one, 0f00000000;
4290    setp.gt.f32 %p_wd, %f_wd, %one;
4291    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;
4292
4293    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
4294    mov.f32 %one, 0f3F800000;
4295    sub.f32 %t1, %one, %b1;
4296    mul.f32 %vm, %vm, %b1;
4297    fma.rn.f32 %vm, %t1, %vg, %vm;
4298
4299    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
4300    sub.f32 %t2, %one, %b2;
4301    mul.f32 %vv, %vv, %b2;
4302    mul.f32 %t1, %vg, %vg;
4303    fma.rn.f32 %vv, %t2, %t1, %vv;
4304
4305    // m_hat = exp_avg / bc1
4306    div.rn.f32 %m_hat, %vm, %f_bc1;
4307
4308    // v_hat = exp_avg_sq / bc2
4309    div.rn.f32 %v_hat, %vv, %f_bc2;
4310
4311    // denom = sqrt(v_hat) + eps
4312    sqrt.rn.f32 %denom, %v_hat;
4313    add.f32 %denom, %denom, %f_eps;
4314
4315    // param = param - lr * m_hat / denom
4316    div.rn.f32 %update, %m_hat, %denom;
4317    mul.f32 %update, %update, %f_lr;
4318    sub.f32 %vp, %vp, %update;
4319
4320    st.global.f32 [%p], %vp;
4321    st.global.f32 [%m], %vm;
4322    st.global.f32 [%v], %vv;
4323
4324DONE:
4325    ret;
4326}
4327";
4328
4329/// PTX source for fused GRU cell forward kernel.
4330///
4331/// Takes pre-computed input_gates [B, 3*H] and hidden_gates [B, 3*H]
4332/// (from cuBLAS GEMMs), biases, and previous hidden state. Computes all
4333/// gate activations and the new hidden state in a single kernel launch.
4334///
4335/// One thread per hidden unit. Each thread reads 3 values from input_gates
4336/// and 3 from hidden_gates, applies sigmoid/tanh, computes the GRU update,
4337/// and writes hy + workspace (5*H values for backward).
4338///
4339/// Matches PyTorch's _thnn_fused_gru_cell kernel from RNN.cu.
4340#[cfg(feature = "cuda")]
4341pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
4342.version 7.0
4343.target sm_52
4344.address_size 64
4345
4346.visible .entry fused_gru_forward_kernel(
4347    .param .u64 input_gates_ptr,
4348    .param .u64 hidden_gates_ptr,
4349    .param .u64 bias_ih_ptr,
4350    .param .u64 bias_hh_ptr,
4351    .param .u64 hx_ptr,
4352    .param .u64 hy_ptr,
4353    .param .u64 workspace_ptr,
4354    .param .u32 hsz,
4355    .param .u32 total
4356) {
4357    .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
4358    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
4359    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
4360    .reg .u64 %off64, %tmp64;
4361    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
4362    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
4363    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
4364    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
4365    .reg .pred %p;
4366
4367    ld.param.u64 %ig, [input_gates_ptr];
4368    ld.param.u64 %hg, [hidden_gates_ptr];
4369    ld.param.u64 %b1, [bias_ih_ptr];
4370    ld.param.u64 %b2, [bias_hh_ptr];
4371    ld.param.u64 %hx, [hx_ptr];
4372    ld.param.u64 %hy, [hy_ptr];
4373    ld.param.u64 %ws, [workspace_ptr];
4374    ld.param.u32 %hsz_reg, [hsz];
4375    ld.param.u32 %total_reg, [total];
4376
4377    mov.u32 %bid, %ctaid.x;
4378    mov.u32 %bdim, %ntid.x;
4379    mov.u32 %tid, %tid.x;
4380    mov.u32 %gdim, %nctaid.x;
4381    mad.lo.u32 %idx, %bid, %bdim, %tid;
4382    mul.lo.u32 %stride, %bdim, %gdim;
4383    mov.f32 %one, 0f3F800000;
4384
4385LOOP:
4386    setp.ge.u32 %p, %idx, %total_reg;
4387    @%p bra END;
4388
4389    // offset3 = (idx/hsz)*3*hsz + idx%hsz  (into [B, 3*H] gates tensor)
4390    div.u32 %batch_idx, %idx, %hsz_reg;
4391    rem.u32 %hmod, %idx, %hsz_reg;
4392    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
4393    mul.lo.u32 %offset3, %offset3, 3;
4394    add.u32 %offset3, %offset3, %hmod;
4395
4396    // Load input gate components: ir, ii, in
4397    cvt.u64.u32 %off64, %offset3;
4398    shl.b64 %off64, %off64, 2;
4399    add.u64 %tmp64, %ig, %off64;
4400    ld.global.f32 %ir, [%tmp64];
4401    cvt.u64.u32 %off64, %hsz_reg;
4402    shl.b64 %off64, %off64, 2;
4403    add.u64 %tmp64, %tmp64, %off64;
4404    ld.global.f32 %ii, [%tmp64];
4405    add.u64 %tmp64, %tmp64, %off64;
4406    ld.global.f32 %in, [%tmp64];
4407
4408    // Load hidden gate components: hr, hi, hn
4409    cvt.u64.u32 %off64, %offset3;
4410    shl.b64 %off64, %off64, 2;
4411    add.u64 %tmp64, %hg, %off64;
4412    ld.global.f32 %hr, [%tmp64];
4413    cvt.u64.u32 %off64, %hsz_reg;
4414    shl.b64 %off64, %off64, 2;
4415    add.u64 %tmp64, %tmp64, %off64;
4416    ld.global.f32 %hi, [%tmp64];
4417    add.u64 %tmp64, %tmp64, %off64;
4418    ld.global.f32 %hn, [%tmp64];
4419
4420    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
4421    cvt.u64.u32 %off64, %hmod;
4422    shl.b64 %off64, %off64, 2;
4423    add.u64 %tmp64, %b1, %off64;
4424    ld.global.f32 %b1r, [%tmp64];
4425    cvt.u64.u32 %off64, %hsz_reg;
4426    shl.b64 %off64, %off64, 2;
4427    add.u64 %tmp64, %tmp64, %off64;
4428    ld.global.f32 %b1i, [%tmp64];
4429    add.u64 %tmp64, %tmp64, %off64;
4430    ld.global.f32 %b1n, [%tmp64];
4431
4432    cvt.u64.u32 %off64, %hmod;
4433    shl.b64 %off64, %off64, 2;
4434    add.u64 %tmp64, %b2, %off64;
4435    ld.global.f32 %b2r, [%tmp64];
4436    cvt.u64.u32 %off64, %hsz_reg;
4437    shl.b64 %off64, %off64, 2;
4438    add.u64 %tmp64, %tmp64, %off64;
4439    ld.global.f32 %b2i, [%tmp64];
4440    add.u64 %tmp64, %tmp64, %off64;
4441    ld.global.f32 %b2n, [%tmp64];
4442
4443    // Load hx[idx]
4444    cvt.u64.u32 %off64, %idx;
4445    shl.b64 %off64, %off64, 2;
4446    add.u64 %tmp64, %hx, %off64;
4447    ld.global.f32 %hx_val, [%tmp64];
4448
4449    // r = sigmoid(ir + hr + b1r + b2r)
4450    add.f32 %rg, %ir, %hr;
4451    add.f32 %rg, %rg, %b1r;
4452    add.f32 %rg, %rg, %b2r;
4453    neg.f32 %tmp, %rg;
4454    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
4455    ex2.approx.f32 %exp_val, %tmp;
4456    add.f32 %denom, %one, %exp_val;
4457    div.rn.f32 %rg, %one, %denom;
4458
4459    // z = sigmoid(ii + hi + b1i + b2i)
4460    add.f32 %zg, %ii, %hi;
4461    add.f32 %zg, %zg, %b1i;
4462    add.f32 %zg, %zg, %b2i;
4463    neg.f32 %tmp, %zg;
4464    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
4465    ex2.approx.f32 %exp_val, %tmp;
4466    add.f32 %denom, %one, %exp_val;
4467    div.rn.f32 %zg, %one, %denom;
4468
4469    // n = tanh(in + b1n + r*(hn + b2n))
4470    add.f32 %tmp, %hn, %b2n;
4471    fma.rn.f32 %ng, %rg, %tmp, %in;
4472    add.f32 %ng, %ng, %b1n;
4473    // tanh via 2*sigmoid(2x)-1
4474    mul.f32 %tmp, %ng, 0f40000000;
4475    neg.f32 %tmp, %tmp;
4476    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
4477    ex2.approx.f32 %exp_val, %tmp;
4478    add.f32 %denom, %one, %exp_val;
4479    div.rn.f32 %ng, %one, %denom;
4480    mul.f32 %ng, %ng, 0f40000000;
4481    sub.f32 %ng, %ng, %one;
4482
4483    // hy = n + z * (hx - n)
4484    sub.f32 %tmp, %hx_val, %ng;
4485    fma.rn.f32 %hy_val, %zg, %tmp, %ng;
4486
4487    // Store hy[idx]
4488    cvt.u64.u32 %off64, %idx;
4489    shl.b64 %off64, %off64, 2;
4490    add.u64 %tmp64, %hy, %off64;
4491    st.global.f32 [%tmp64], %hy_val;
4492
4493    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
4494    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
4495    mul.lo.u32 %offset5, %offset5, 5;
4496    add.u32 %offset5, %offset5, %hmod;
4497
4498    cvt.u64.u32 %off64, %offset5;
4499    shl.b64 %off64, %off64, 2;
4500    add.u64 %tmp64, %ws, %off64;
4501    st.global.f32 [%tmp64], %rg;
4502    cvt.u64.u32 %off64, %hsz_reg;
4503    shl.b64 %off64, %off64, 2;
4504    add.u64 %tmp64, %tmp64, %off64;
4505    st.global.f32 [%tmp64], %zg;
4506    add.u64 %tmp64, %tmp64, %off64;
4507    st.global.f32 [%tmp64], %ng;
4508    add.u64 %tmp64, %tmp64, %off64;
4509    st.global.f32 [%tmp64], %hx_val;
4510    add.u64 %tmp64, %tmp64, %off64;
4511    add.f32 %tmp, %hn, %b2n;
4512    st.global.f32 [%tmp64], %tmp;
4513
4514    add.u32 %idx, %idx, %stride;
4515    bra LOOP;
4516
4517END:
4518    ret;
4519}
4520";
4521
4522// ---------------------------------------------------------------------------
4523// Launch configuration helper
4524// ---------------------------------------------------------------------------
4525
4526/// Standard 1-D launch config for `n` elements.
4527///
4528/// Uses 256 threads per block, which is a good default for elementwise ops
4529/// on all modern NVIDIA architectures.
4530///
4531/// # Errors
4532///
4533/// Returns [`GpuError::ShapeMismatch`] if `n` exceeds `u32::MAX`, which
4534/// would silently truncate the grid dimension.
4535#[cfg(feature = "cuda")]
4536fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
4537    if n > u32::MAX as usize {
4538        return Err(GpuError::ShapeMismatch {
4539            op: "kernel_launch",
4540            expected: vec![u32::MAX as usize],
4541            got: vec![n],
4542        });
4543    }
4544    const BLOCK: u32 = 256;
4545    let grid = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
4546    Ok(LaunchConfig {
4547        grid_dim: (grid.max(1), 1, 1),
4548        block_dim: (BLOCK, 1, 1),
4549        shared_mem_bytes: 0,
4550    })
4551}
4552
4553// ---------------------------------------------------------------------------
4554// Validation helpers
4555// ---------------------------------------------------------------------------
4556
4557/// Validate that two buffers are on the same device and have the same length.
4558#[cfg(feature = "cuda")]
4559fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
4560    if a.device_ordinal() != device.ordinal() {
4561        return Err(GpuError::DeviceMismatch {
4562            expected: a.device_ordinal(),
4563            got: device.ordinal(),
4564        });
4565    }
4566    if b.device_ordinal() != device.ordinal() {
4567        return Err(GpuError::DeviceMismatch {
4568            expected: b.device_ordinal(),
4569            got: device.ordinal(),
4570        });
4571    }
4572    if a.len() != b.len() {
4573        return Err(GpuError::LengthMismatch {
4574            a: a.len(),
4575            b: b.len(),
4576        });
4577    }
4578    Ok(())
4579}
4580
4581/// Validate that a unary buffer is on the correct device.
4582#[cfg(feature = "cuda")]
4583fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
4584    if a.device_ordinal() != device.ordinal() {
4585        return Err(GpuError::DeviceMismatch {
4586            expected: a.device_ordinal(),
4587            got: device.ordinal(),
4588        });
4589    }
4590    Ok(())
4591}
4592
4593// ---------------------------------------------------------------------------
4594// PTX kernel launch helpers
4595// ---------------------------------------------------------------------------
4596
4597/// Try to launch a binary PTX kernel. Returns `Ok(Some(buf))` on success,
4598/// `Ok(None)` if the PTX module failed to load (caller should fall back to
4599/// CPU), or `Err` on a real CUDA error after a successful launch.
4600#[cfg(feature = "cuda")]
4601fn try_launch_binary(
4602    a: &CudaBuffer<f32>,
4603    b: &CudaBuffer<f32>,
4604    device: &GpuDevice,
4605    ptx_src: &'static str,
4606    kernel_name: &'static str,
4607) -> GpuResult<Option<CudaBuffer<f32>>> {
4608    use cudarc::driver::PushKernelArg;
4609
4610    let n = a.len();
4611    let ctx = device.context();
4612    let stream = device.stream();
4613
4614    // Attempt to load the kernel (cached after first compilation).
4615    // If it fails (e.g. unsupported arch), return None so the caller
4616    // can use the CPU fallback.
4617    let f = match crate::module_cache::get_or_compile(
4618        ctx,
4619        ptx_src,
4620        kernel_name,
4621        device.ordinal() as u32,
4622    ) {
4623        Ok(f) => f,
4624        Err(_) => return Ok(None),
4625    };
4626
4627    let mut out = alloc_zeros_f32(n, device)?;
4628    let cfg = launch_cfg(n)?;
4629    let n_u32 = n as u32;
4630
4631    // SAFETY: The kernel reads `n` f32 values from `a` and `b`, writes `n`
4632    // f32 values to `out`. All three buffers are device-resident and at
4633    // least `n` elements long. The grid covers exactly `n` threads.
4634    unsafe {
4635        stream
4636            .launch_builder(&f)
4637            .arg(a.inner())
4638            .arg(b.inner())
4639            .arg(out.inner_mut())
4640            .arg(&n_u32)
4641            .launch(cfg)?;
4642    }
4643
4644    Ok(Some(out))
4645}
4646
4647/// Try to launch a vectorized (vec4) binary PTX kernel.
4648///
4649/// Each thread processes 4 elements using 128-bit loads/stores.
4650/// `n` must be divisible by 4. Returns `Ok(None)` if compilation fails.
4651#[cfg(feature = "cuda")]
4652fn try_launch_binary_vec4(
4653    a: &CudaBuffer<f32>,
4654    b: &CudaBuffer<f32>,
4655    device: &GpuDevice,
4656    ptx_src: &'static str,
4657    kernel_name: &'static str,
4658) -> GpuResult<Option<CudaBuffer<f32>>> {
4659    use cudarc::driver::PushKernelArg;
4660
4661    let n = a.len();
4662    let n4 = (n / 4) as u32;
4663    let ctx = device.context();
4664    let stream = device.stream();
4665
4666    let f = match crate::module_cache::get_or_compile(
4667        ctx,
4668        ptx_src,
4669        kernel_name,
4670        device.ordinal() as u32,
4671    ) {
4672        Ok(f) => f,
4673        Err(_) => return Ok(None),
4674    };
4675
4676    let mut out = alloc_zeros_f32(n, device)?;
4677    let cfg = launch_cfg(n4 as usize)?;
4678
4679    unsafe {
4680        stream
4681            .launch_builder(&f)
4682            .arg(a.inner())
4683            .arg(b.inner())
4684            .arg(out.inner_mut())
4685            .arg(&n4)
4686            .launch(cfg)?;
4687    }
4688
4689    Ok(Some(out))
4690}
4691
4692/// Try to launch a unary PTX kernel. Returns `Ok(Some(buf))` on success,
4693/// `Ok(None)` if the PTX module failed to load.
4694#[cfg(feature = "cuda")]
4695fn try_launch_unary(
4696    a: &CudaBuffer<f32>,
4697    device: &GpuDevice,
4698    ptx_src: &'static str,
4699    kernel_name: &'static str,
4700) -> GpuResult<Option<CudaBuffer<f32>>> {
4701    use cudarc::driver::PushKernelArg;
4702
4703    let n = a.len();
4704    let ctx = device.context();
4705    let stream = device.stream();
4706
4707    // Attempt to load the kernel (cached after first compilation).
4708    let f = match crate::module_cache::get_or_compile(
4709        ctx,
4710        ptx_src,
4711        kernel_name,
4712        device.ordinal() as u32,
4713    ) {
4714        Ok(f) => f,
4715        Err(_) => return Ok(None),
4716    };
4717
4718    let mut out = alloc_zeros_f32(n, device)?;
4719    let cfg = launch_cfg(n)?;
4720    let n_u32 = n as u32;
4721
4722    // SAFETY: The kernel reads `n` f32 values from `a` and writes `n` f32
4723    // values to `out`. Both buffers are device-resident with length >= n.
4724    unsafe {
4725        stream
4726            .launch_builder(&f)
4727            .arg(a.inner())
4728            .arg(out.inner_mut())
4729            .arg(&n_u32)
4730            .launch(cfg)?;
4731    }
4732
4733    Ok(Some(out))
4734}
4735
4736// ---------------------------------------------------------------------------
4737// _into helpers — write to pre-allocated output buffer (no allocation)
4738// ---------------------------------------------------------------------------
4739
4740/// Launch a binary PTX kernel into a pre-allocated output buffer.
4741/// Returns `Ok(true)` on success, `Ok(false)` if the PTX module failed to load.
4742#[cfg(feature = "cuda")]
4743fn try_launch_binary_into(
4744    a: &CudaBuffer<f32>,
4745    b: &CudaBuffer<f32>,
4746    out: &mut CudaBuffer<f32>,
4747    device: &GpuDevice,
4748    ptx_src: &'static str,
4749    kernel_name: &'static str,
4750) -> GpuResult<bool> {
4751    use cudarc::driver::PushKernelArg;
4752
4753    let n = a.len();
4754    let ctx = device.context();
4755    let stream = device.stream();
4756
4757    let f = match crate::module_cache::get_or_compile(
4758        ctx,
4759        ptx_src,
4760        kernel_name,
4761        device.ordinal() as u32,
4762    ) {
4763        Ok(f) => f,
4764        Err(_) => return Ok(false),
4765    };
4766
4767    let cfg = launch_cfg(n)?;
4768    let n_u32 = n as u32;
4769
4770    unsafe {
4771        stream
4772            .launch_builder(&f)
4773            .arg(a.inner())
4774            .arg(b.inner())
4775            .arg(out.inner_mut())
4776            .arg(&n_u32)
4777            .launch(cfg)?;
4778    }
4779
4780    Ok(true)
4781}
4782
4783/// Launch a unary PTX kernel into a pre-allocated output buffer.
4784/// Returns `Ok(true)` on success, `Ok(false)` if the PTX module failed to load.
4785#[cfg(feature = "cuda")]
4786fn try_launch_unary_into(
4787    a: &CudaBuffer<f32>,
4788    out: &mut CudaBuffer<f32>,
4789    device: &GpuDevice,
4790    ptx_src: &'static str,
4791    kernel_name: &'static str,
4792) -> GpuResult<bool> {
4793    use cudarc::driver::PushKernelArg;
4794
4795    let n = a.len();
4796    let ctx = device.context();
4797    let stream = device.stream();
4798
4799    let f = match crate::module_cache::get_or_compile(
4800        ctx,
4801        ptx_src,
4802        kernel_name,
4803        device.ordinal() as u32,
4804    ) {
4805        Ok(f) => f,
4806        Err(_) => return Ok(false),
4807    };
4808
4809    let cfg = launch_cfg(n)?;
4810    let n_u32 = n as u32;
4811
4812    unsafe {
4813        stream
4814            .launch_builder(&f)
4815            .arg(a.inner())
4816            .arg(out.inner_mut())
4817            .arg(&n_u32)
4818            .launch(cfg)?;
4819    }
4820
4821    Ok(true)
4822}
4823
4824/// Try to launch a general N-dimensional broadcast binary PTX kernel.
4825///
4826/// `a_strides` and `b_strides` are broadcast strides: normal C-contiguous
4827/// stride for non-broadcast dims, 0 for broadcast (size-1) dims.
4828/// `out_shape` is the broadcast-resolved output shape.
4829/// All three arrays have length `ndim`.
4830#[cfg(feature = "cuda")]
4831#[allow(clippy::too_many_arguments)]
4832fn try_launch_broadcast_binary(
4833    a: &CudaBuffer<f32>,
4834    b: &CudaBuffer<f32>,
4835    a_strides: &[u32],
4836    b_strides: &[u32],
4837    out_shape: &[u32],
4838    out_numel: usize,
4839    device: &GpuDevice,
4840    ptx_src: &'static str,
4841    kernel_name: &'static str,
4842) -> GpuResult<Option<CudaBuffer<f32>>> {
4843    use cudarc::driver::PushKernelArg;
4844
4845    let ndim = out_shape.len();
4846    let ctx = device.context();
4847    let stream = device.stream();
4848
4849    let f = match crate::module_cache::get_or_compile(
4850        ctx,
4851        ptx_src,
4852        kernel_name,
4853        device.ordinal() as u32,
4854    ) {
4855        Ok(f) => f,
4856        Err(_) => return Ok(None),
4857    };
4858
4859    // Upload stride/shape metadata as small device buffers.
4860    let a_str_buf = cpu_to_gpu(a_strides, device)?;
4861    let b_str_buf = cpu_to_gpu(b_strides, device)?;
4862    let shape_buf = cpu_to_gpu(out_shape, device)?;
4863
4864    let mut out = alloc_zeros_f32(out_numel, device)?;
4865    let cfg = launch_cfg(out_numel)?;
4866    let n_u32 = out_numel as u32;
4867    let ndim_u32 = ndim as u32;
4868
4869    // SAFETY: Kernel reads from a, b using broadcast indices computed from
4870    // the stride/shape buffers. Output buffer has out_numel elements.
4871    unsafe {
4872        stream
4873            .launch_builder(&f)
4874            .arg(a.inner())
4875            .arg(b.inner())
4876            .arg(out.inner_mut())
4877            .arg(a_str_buf.inner())
4878            .arg(b_str_buf.inner())
4879            .arg(shape_buf.inner())
4880            .arg(&n_u32)
4881            .arg(&ndim_u32)
4882            .launch(cfg)?;
4883    }
4884
4885    Ok(Some(out))
4886}
4887
4888/// Compute broadcast strides for a tensor shape relative to an output shape.
4889///
4890/// For each dimension, the stride is the normal C-contiguous stride if the
4891/// dimension size matches the output, or 0 if the dimension size is 1
4892/// (broadcast). Missing leading dimensions (when input has fewer dims) are
4893/// treated as size-1.
4894#[cfg(feature = "cuda")]
4895fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
4896    let ndim = out_shape.len();
4897    let in_ndim = in_shape.len();
4898    let mut strides = vec![0u32; ndim];
4899
4900    // C-contiguous strides for the input shape.
4901    let mut stride: u32 = 1;
4902    for d in (0..ndim).rev() {
4903        let in_d = if d + in_ndim >= ndim {
4904            d + in_ndim - ndim
4905        } else {
4906            // Leading dimension not present in input — broadcast.
4907            strides[d] = 0;
4908            continue;
4909        };
4910
4911        if in_shape[in_d] == 1 {
4912            strides[d] = 0; // Broadcast dimension.
4913        } else {
4914            strides[d] = stride;
4915        }
4916        stride *= in_shape[in_d] as u32;
4917    }
4918
4919    strides
4920}
4921
4922// ---------------------------------------------------------------------------
4923// CPU fallback helpers
4924// ---------------------------------------------------------------------------
4925
4926/// CPU fallback for binary ops: copy both inputs to host, apply `op`, copy
4927/// the result back.
4928#[cfg(feature = "cuda")]
4929fn cpu_fallback_binary(
4930    a: &CudaBuffer<f32>,
4931    b: &CudaBuffer<f32>,
4932    device: &GpuDevice,
4933    op: fn(f32, f32) -> f32,
4934) -> GpuResult<CudaBuffer<f32>> {
4935    let a_host = gpu_to_cpu(a, device)?;
4936    let b_host = gpu_to_cpu(b, device)?;
4937    let result: Vec<f32> = a_host
4938        .iter()
4939        .zip(b_host.iter())
4940        .map(|(&x, &y)| op(x, y))
4941        .collect();
4942    cpu_to_gpu(&result, device)
4943}
4944
4945/// CPU fallback for unary ops.
4946#[cfg(feature = "cuda")]
4947fn cpu_fallback_unary(
4948    a: &CudaBuffer<f32>,
4949    device: &GpuDevice,
4950    op: fn(f32) -> f32,
4951) -> GpuResult<CudaBuffer<f32>> {
4952    let a_host = gpu_to_cpu(a, device)?;
4953    let result: Vec<f32> = a_host.iter().map(|&x| op(x)).collect();
4954    cpu_to_gpu(&result, device)
4955}
4956
4957// ---------------------------------------------------------------------------
4958// Public API -- binary ops
4959// ---------------------------------------------------------------------------
4960
4961/// Elementwise addition: `out[i] = a[i] + b[i]`.
4962///
4963/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
4964/// if the PTX module cannot be loaded.
4965///
4966/// # Errors
4967///
4968/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
4969///   different CUDA devices.
4970/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
4971/// - [`GpuError::Driver`] on CUDA runtime errors.
4972#[cfg(feature = "cuda")]
4973pub fn gpu_add(
4974    a: &CudaBuffer<f32>,
4975    b: &CudaBuffer<f32>,
4976    device: &GpuDevice,
4977) -> GpuResult<CudaBuffer<f32>> {
4978    validate_binary(a, b, device)?;
4979
4980    // Try vec4 kernel for 4x memory throughput (128-bit loads).
4981    let n = a.len();
4982    if n >= 16 && n % 4 == 0 {
4983        if let Some(out) = try_launch_binary_vec4(
4984            a, b, device, ADD_VEC4_PTX, "add_vec4_kernel",
4985        )? {
4986            return Ok(out);
4987        }
4988    }
4989
4990    if let Some(out) = try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
4991        return Ok(out);
4992    }
4993
4994    cpu_fallback_binary(a, b, device, |x, y| x + y)
4995}
4996
4997/// Elementwise subtraction: `out[i] = a[i] - b[i]`.
4998///
4999/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5000/// if the PTX module cannot be loaded.
5001///
5002/// # Errors
5003///
5004/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
5005///   different CUDA devices.
5006/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
5007/// - [`GpuError::Driver`] on CUDA runtime errors.
5008#[cfg(feature = "cuda")]
5009pub fn gpu_sub(
5010    a: &CudaBuffer<f32>,
5011    b: &CudaBuffer<f32>,
5012    device: &GpuDevice,
5013) -> GpuResult<CudaBuffer<f32>> {
5014    validate_binary(a, b, device)?;
5015
5016    if let Some(out) = try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
5017        return Ok(out);
5018    }
5019
5020    cpu_fallback_binary(a, b, device, |x, y| x - y)
5021}
5022
5023/// Elementwise multiplication: `out[i] = a[i] * b[i]`.
5024///
5025/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5026/// if the PTX module cannot be loaded.
5027///
5028/// # Errors
5029///
5030/// - [`GpuError::DeviceMismatch`] if `a`, `b`, or `device` refer to
5031///   different CUDA devices.
5032/// - [`GpuError::LengthMismatch`] if `a` and `b` have different lengths.
5033/// - [`GpuError::Driver`] on CUDA runtime errors.
5034#[cfg(feature = "cuda")]
5035pub fn gpu_mul(
5036    a: &CudaBuffer<f32>,
5037    b: &CudaBuffer<f32>,
5038    device: &GpuDevice,
5039) -> GpuResult<CudaBuffer<f32>> {
5040    validate_binary(a, b, device)?;
5041
5042    let n = a.len();
5043    if n >= 16 && n % 4 == 0 {
5044        if let Some(out) = try_launch_binary_vec4(
5045            a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel",
5046        )? {
5047            return Ok(out);
5048        }
5049    }
5050
5051    if let Some(out) = try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
5052        return Ok(out);
5053    }
5054
5055    cpu_fallback_binary(a, b, device, |x, y| x * y)
5056}
5057
5058// ---------------------------------------------------------------------------
5059// Public API -- broadcast binary ops
5060// ---------------------------------------------------------------------------
5061
5062/// Broadcast addition: `out[i] = a[bcast_a(i)] + b[bcast_b(i)]`.
5063///
5064/// Handles arbitrary N-dimensional broadcasting on the GPU. The kernel
5065/// decomposes each output index into coordinates, maps them through
5066/// broadcast strides, and loads from the correct positions in A and B.
5067///
5068/// `a_shape` and `b_shape` are the original shapes; the output shape is
5069/// computed via numpy-style broadcast rules.
5070#[cfg(feature = "cuda")]
5071pub fn gpu_broadcast_add(
5072    a: &CudaBuffer<f32>,
5073    b: &CudaBuffer<f32>,
5074    a_shape: &[usize],
5075    b_shape: &[usize],
5076    out_shape: &[usize],
5077    device: &GpuDevice,
5078) -> GpuResult<CudaBuffer<f32>> {
5079    let a_str = broadcast_strides(a_shape, out_shape);
5080    let b_str = broadcast_strides(b_shape, out_shape);
5081    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5082    let out_numel: usize = out_shape.iter().product();
5083
5084    if let Some(out) = try_launch_broadcast_binary(
5085        a,
5086        b,
5087        &a_str,
5088        &b_str,
5089        &shape_u32,
5090        out_numel,
5091        device,
5092        BROADCAST_ADD_PTX,
5093        "broadcast_add_kernel",
5094    )? {
5095        return Ok(out);
5096    }
5097
5098    // CPU fallback for broadcast.
5099    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
5100}
5101
5102/// Broadcast subtraction: `out[i] = a[bcast_a(i)] - b[bcast_b(i)]`.
5103#[cfg(feature = "cuda")]
5104pub fn gpu_broadcast_sub(
5105    a: &CudaBuffer<f32>,
5106    b: &CudaBuffer<f32>,
5107    a_shape: &[usize],
5108    b_shape: &[usize],
5109    out_shape: &[usize],
5110    device: &GpuDevice,
5111) -> GpuResult<CudaBuffer<f32>> {
5112    let a_str = broadcast_strides(a_shape, out_shape);
5113    let b_str = broadcast_strides(b_shape, out_shape);
5114    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5115    let out_numel: usize = out_shape.iter().product();
5116
5117    if let Some(out) = try_launch_broadcast_binary(
5118        a,
5119        b,
5120        &a_str,
5121        &b_str,
5122        &shape_u32,
5123        out_numel,
5124        device,
5125        BROADCAST_SUB_PTX,
5126        "broadcast_sub_kernel",
5127    )? {
5128        return Ok(out);
5129    }
5130
5131    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
5132}
5133
5134/// Broadcast multiplication: `out[i] = a[bcast_a(i)] * b[bcast_b(i)]`.
5135#[cfg(feature = "cuda")]
5136pub fn gpu_broadcast_mul(
5137    a: &CudaBuffer<f32>,
5138    b: &CudaBuffer<f32>,
5139    a_shape: &[usize],
5140    b_shape: &[usize],
5141    out_shape: &[usize],
5142    device: &GpuDevice,
5143) -> GpuResult<CudaBuffer<f32>> {
5144    let a_str = broadcast_strides(a_shape, out_shape);
5145    let b_str = broadcast_strides(b_shape, out_shape);
5146    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5147    let out_numel: usize = out_shape.iter().product();
5148
5149    if let Some(out) = try_launch_broadcast_binary(
5150        a,
5151        b,
5152        &a_str,
5153        &b_str,
5154        &shape_u32,
5155        out_numel,
5156        device,
5157        BROADCAST_MUL_PTX,
5158        "broadcast_mul_kernel",
5159    )? {
5160        return Ok(out);
5161    }
5162
5163    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
5164}
5165
5166/// Broadcast division: `out[i] = a[bcast_a(i)] / b[bcast_b(i)]`.
5167#[cfg(feature = "cuda")]
5168pub fn gpu_broadcast_div(
5169    a: &CudaBuffer<f32>,
5170    b: &CudaBuffer<f32>,
5171    a_shape: &[usize],
5172    b_shape: &[usize],
5173    out_shape: &[usize],
5174    device: &GpuDevice,
5175) -> GpuResult<CudaBuffer<f32>> {
5176    let a_str = broadcast_strides(a_shape, out_shape);
5177    let b_str = broadcast_strides(b_shape, out_shape);
5178    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
5179    let out_numel: usize = out_shape.iter().product();
5180
5181    if let Some(out) = try_launch_broadcast_binary(
5182        a,
5183        b,
5184        &a_str,
5185        &b_str,
5186        &shape_u32,
5187        out_numel,
5188        device,
5189        BROADCAST_DIV_PTX,
5190        "broadcast_div_kernel",
5191    )? {
5192        return Ok(out);
5193    }
5194
5195    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
5196}
5197
5198/// CPU fallback for broadcast binary ops — downloads, applies op with
5199/// broadcast indexing, re-uploads.
5200#[cfg(feature = "cuda")]
5201fn cpu_fallback_broadcast_binary(
5202    a: &CudaBuffer<f32>,
5203    b: &CudaBuffer<f32>,
5204    a_shape: &[usize],
5205    b_shape: &[usize],
5206    out_shape: &[usize],
5207    device: &GpuDevice,
5208    op: fn(f32, f32) -> f32,
5209) -> GpuResult<CudaBuffer<f32>> {
5210    let a_host = gpu_to_cpu(a, device)?;
5211    let b_host = gpu_to_cpu(b, device)?;
5212    let out_numel: usize = out_shape.iter().product();
5213
5214    let a_str = broadcast_strides(a_shape, out_shape);
5215    let b_str = broadcast_strides(b_shape, out_shape);
5216
5217    let mut result = Vec::with_capacity(out_numel);
5218    for i in 0..out_numel {
5219        let mut remaining = i;
5220        let mut a_idx = 0usize;
5221        let mut b_idx = 0usize;
5222        for d in (0..out_shape.len()).rev() {
5223            let coord = remaining % out_shape[d];
5224            remaining /= out_shape[d];
5225            a_idx += coord * a_str[d] as usize;
5226            b_idx += coord * b_str[d] as usize;
5227        }
5228        result.push(op(a_host[a_idx], b_host[b_idx]));
5229    }
5230    cpu_to_gpu(&result, device)
5231}
5232
5233// ---------------------------------------------------------------------------
5234// Public API -- unary ops
5235// ---------------------------------------------------------------------------
5236
5237/// Elementwise negation: `out[i] = -a[i]`.
5238///
5239/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5240/// if the PTX module cannot be loaded.
5241///
5242/// # Errors
5243///
5244/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
5245///   CUDA devices.
5246/// - [`GpuError::Driver`] on CUDA runtime errors.
5247#[cfg(feature = "cuda")]
5248pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
5249    validate_unary(a, device)?;
5250
5251    if let Some(out) = try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
5252        return Ok(out);
5253    }
5254
5255    cpu_fallback_unary(a, device, |x| -x)
5256}
5257
5258/// Elementwise ReLU: `out[i] = max(a[i], 0.0)`.
5259///
5260/// Attempts to run a PTX kernel on the GPU. Falls back to a CPU round-trip
5261/// if the PTX module cannot be loaded.
5262///
5263/// # Errors
5264///
5265/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different
5266///   CUDA devices.
5267/// - [`GpuError::Driver`] on CUDA runtime errors.
5268#[cfg(feature = "cuda")]
5269pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
5270    validate_unary(a, device)?;
5271
5272    if let Some(out) = try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
5273        return Ok(out);
5274    }
5275
5276    cpu_fallback_unary(a, device, |x| x.max(0.0))
5277}
5278
5279/// ReLU backward: `out[i] = (input[i] > 0) ? grad[i] : 0`.
5280#[cfg(feature = "cuda")]
5281pub fn gpu_relu_backward(
5282    grad: &CudaBuffer<f32>,
5283    input: &CudaBuffer<f32>,
5284    device: &GpuDevice,
5285) -> GpuResult<CudaBuffer<f32>> {
5286    validate_binary(grad, input, device)?;
5287
5288    if let Some(out) = try_launch_binary(
5289        grad,
5290        input,
5291        device,
5292        RELU_BACKWARD_PTX,
5293        "relu_backward_kernel",
5294    )? {
5295        return Ok(out);
5296    }
5297
5298    // CPU fallback
5299    let grad_host = gpu_to_cpu(grad, device)?;
5300    let input_host = gpu_to_cpu(input, device)?;
5301    let result: Vec<f32> = grad_host
5302        .iter()
5303        .zip(input_host.iter())
5304        .map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
5305        .collect();
5306    cpu_to_gpu(&result, device)
5307}
5308
5309/// GELU backward: `out[i] = grad[i] * (sig + 1.702 * x * sig * (1 - sig))`
5310/// where `sig = sigmoid(1.702 * x)`.
5311#[cfg(feature = "cuda")]
5312pub fn gpu_gelu_backward(
5313    grad: &CudaBuffer<f32>,
5314    input: &CudaBuffer<f32>,
5315    device: &GpuDevice,
5316) -> GpuResult<CudaBuffer<f32>> {
5317    validate_binary(grad, input, device)?;
5318
5319    if let Some(out) = try_launch_binary(
5320        grad,
5321        input,
5322        device,
5323        GELU_BACKWARD_PTX,
5324        "gelu_backward_kernel",
5325    )? {
5326        return Ok(out);
5327    }
5328
5329    // CPU fallback
5330    let grad_host = gpu_to_cpu(grad, device)?;
5331    let input_host = gpu_to_cpu(input, device)?;
5332    let result: Vec<f32> = grad_host
5333        .iter()
5334        .zip(input_host.iter())
5335        .map(|(&g, &x)| {
5336            let k: f32 = 1.702;
5337            let sig = 1.0 / (1.0 + (-k * x).exp());
5338            g * (sig + k * x * sig * (1.0 - sig))
5339        })
5340        .collect();
5341    cpu_to_gpu(&result, device)
5342}
5343
5344// ---------------------------------------------------------------------------
5345// Public API -- Index-select 1-D (gather)
5346// ---------------------------------------------------------------------------
5347
5348/// Gather elements from `input` at positions given by `indices`.
5349///
5350/// `indices` is a GPU buffer of f32 values encoding integer indices.
5351/// Output has `indices.len()` elements: `out[i] = input[indices[i]]`.
5352#[cfg(feature = "cuda")]
5353pub fn gpu_index_select_1d(
5354    input: &CudaBuffer<f32>,
5355    indices: &CudaBuffer<f32>,
5356    device: &GpuDevice,
5357) -> GpuResult<CudaBuffer<f32>> {
5358    use cudarc::driver::PushKernelArg;
5359
5360    validate_unary(input, device)?;
5361
5362    let n = indices.len();
5363    let ctx = device.context();
5364    let stream = device.stream();
5365
5366    let f = match crate::module_cache::get_or_compile(
5367        ctx,
5368        INDEX_SELECT_1D_PTX,
5369        "index_select_1d_kernel",
5370        device.ordinal() as u32,
5371    ) {
5372        Ok(f) => f,
5373        Err(_) => {
5374            // CPU fallback.
5375            let input_host = gpu_to_cpu(input, device)?;
5376            let indices_host = gpu_to_cpu(indices, device)?;
5377            let result: Vec<f32> = indices_host
5378                .iter()
5379                .map(|&idx_f| input_host[idx_f as usize])
5380                .collect();
5381            return cpu_to_gpu(&result, device);
5382        }
5383    };
5384
5385    let mut out = alloc_zeros_f32(n, device)?;
5386    let cfg = launch_cfg(n)?;
5387    let n_u32 = n as u32;
5388
5389    unsafe {
5390        stream
5391            .launch_builder(&f)
5392            .arg(input.inner())
5393            .arg(indices.inner())
5394            .arg(out.inner_mut())
5395            .arg(&n_u32)
5396            .launch(cfg)?;
5397    }
5398
5399    Ok(out)
5400}
5401
5402// ---------------------------------------------------------------------------
5403// Public API -- Scatter-add 1-D (backward of index_select)
5404// ---------------------------------------------------------------------------
5405
5406/// Scatter-add `grad_output` back into an output buffer of `input_len` elements,
5407/// using positions from `indices`.
5408///
5409/// `indices` is a GPU buffer of f32 values encoding integer indices.
5410/// Output: `out = zeros(input_len); for i: out[indices[i]] += grad_output[i]`
5411///
5412/// Uses atomic adds for safe concurrent accumulation.
5413#[cfg(feature = "cuda")]
5414pub fn gpu_scatter_add_1d(
5415    grad_output: &CudaBuffer<f32>,
5416    indices: &CudaBuffer<f32>,
5417    input_len: usize,
5418    device: &GpuDevice,
5419) -> GpuResult<CudaBuffer<f32>> {
5420    use cudarc::driver::PushKernelArg;
5421
5422    validate_unary(grad_output, device)?;
5423
5424    let n = grad_output.len();
5425    let ctx = device.context();
5426    let stream = device.stream();
5427
5428    let f = match crate::module_cache::get_or_compile(
5429        ctx,
5430        SCATTER_ADD_1D_PTX,
5431        "scatter_add_1d_kernel",
5432        device.ordinal() as u32,
5433    ) {
5434        Ok(f) => f,
5435        Err(_) => {
5436            // CPU fallback.
5437            let go_host = gpu_to_cpu(grad_output, device)?;
5438            let idx_host = gpu_to_cpu(indices, device)?;
5439            let mut result = vec![0.0f32; input_len];
5440            for (i, &idx_f) in idx_host.iter().enumerate() {
5441                result[idx_f as usize] += go_host[i];
5442            }
5443            return cpu_to_gpu(&result, device);
5444        }
5445    };
5446
5447    let mut out = alloc_zeros_f32(input_len, device)?;
5448    let cfg = launch_cfg(n)?;
5449    let n_u32 = n as u32;
5450
5451    unsafe {
5452        stream
5453            .launch_builder(&f)
5454            .arg(grad_output.inner())
5455            .arg(indices.inner())
5456            .arg(out.inner_mut())
5457            .arg(&n_u32)
5458            .launch(cfg)?;
5459    }
5460
5461    Ok(out)
5462}
5463
5464// ---------------------------------------------------------------------------
5465// Public API -- Masked fill
5466// ---------------------------------------------------------------------------
5467
5468/// Fill elements of `input` with `value` where `mask` is true.
5469///
5470/// `mask` is a GPU buffer of f32 values (1.0 = true, 0.0 = false).
5471/// Output: `out[i] = mask[i] >= 0.5 ? value : input[i]`
5472#[cfg(feature = "cuda")]
5473pub fn gpu_masked_fill(
5474    input: &CudaBuffer<f32>,
5475    mask: &CudaBuffer<f32>,
5476    value: f32,
5477    device: &GpuDevice,
5478) -> GpuResult<CudaBuffer<f32>> {
5479    use cudarc::driver::PushKernelArg;
5480
5481    validate_binary(input, mask, device)?;
5482
5483    let n = input.len();
5484    let ctx = device.context();
5485    let stream = device.stream();
5486
5487    let f = match crate::module_cache::get_or_compile(
5488        ctx,
5489        MASKED_FILL_PTX,
5490        "masked_fill_kernel",
5491        device.ordinal() as u32,
5492    ) {
5493        Ok(f) => f,
5494        Err(_) => {
5495            // CPU fallback.
5496            let input_host = gpu_to_cpu(input, device)?;
5497            let mask_host = gpu_to_cpu(mask, device)?;
5498            let result: Vec<f32> = input_host
5499                .iter()
5500                .zip(mask_host.iter())
5501                .map(|(&x, &m)| if m >= 0.5 { value } else { x })
5502                .collect();
5503            return cpu_to_gpu(&result, device);
5504        }
5505    };
5506
5507    let mut out = alloc_zeros_f32(n, device)?;
5508    let cfg = launch_cfg(n)?;
5509    let n_u32 = n as u32;
5510
5511    unsafe {
5512        stream
5513            .launch_builder(&f)
5514            .arg(input.inner())
5515            .arg(mask.inner())
5516            .arg(out.inner_mut())
5517            .arg(&value)
5518            .arg(&n_u32)
5519            .launch(cfg)?;
5520    }
5521
5522    Ok(out)
5523}
5524
5525// ---------------------------------------------------------------------------
5526// Public API -- Masked zero (backward of masked_fill)
5527// ---------------------------------------------------------------------------
5528
5529/// Zero out gradient at positions where `mask` is true.
5530///
5531/// `mask` is a GPU buffer of f32 values (1.0 = true, 0.0 = false).
5532/// Output: `out[i] = mask[i] >= 0.5 ? 0.0 : grad[i]`
5533#[cfg(feature = "cuda")]
5534pub fn gpu_masked_zero(
5535    grad: &CudaBuffer<f32>,
5536    mask: &CudaBuffer<f32>,
5537    device: &GpuDevice,
5538) -> GpuResult<CudaBuffer<f32>> {
5539    validate_binary(grad, mask, device)?;
5540
5541    if let Some(out) = try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?
5542    {
5543        return Ok(out);
5544    }
5545
5546    // CPU fallback.
5547    let grad_host = gpu_to_cpu(grad, device)?;
5548    let mask_host = gpu_to_cpu(mask, device)?;
5549    let result: Vec<f32> = grad_host
5550        .iter()
5551        .zip(mask_host.iter())
5552        .map(|(&g, &m)| if m >= 0.5 { 0.0 } else { g })
5553        .collect();
5554    cpu_to_gpu(&result, device)
5555}
5556
5557// ---------------------------------------------------------------------------
5558// Public API -- Sigmoid backward
5559// ---------------------------------------------------------------------------
5560
5561/// Sigmoid backward: `out[i] = grad[i] * output[i] * (1 - output[i])`.
5562///
5563/// `grad` and `output` must have the same length and reside on `device`.
5564#[cfg(feature = "cuda")]
5565pub fn gpu_sigmoid_backward(
5566    grad: &CudaBuffer<f32>,
5567    output: &CudaBuffer<f32>,
5568    device: &GpuDevice,
5569) -> GpuResult<CudaBuffer<f32>> {
5570    validate_binary(grad, output, device)?;
5571
5572    if let Some(out) = try_launch_binary(
5573        grad,
5574        output,
5575        device,
5576        SIGMOID_BACKWARD_PTX,
5577        "sigmoid_backward_kernel",
5578    )? {
5579        return Ok(out);
5580    }
5581
5582    // CPU fallback
5583    let grad_host = gpu_to_cpu(grad, device)?;
5584    let output_host = gpu_to_cpu(output, device)?;
5585    let result: Vec<f32> = grad_host
5586        .iter()
5587        .zip(output_host.iter())
5588        .map(|(&g, &o)| g * o * (1.0 - o))
5589        .collect();
5590    cpu_to_gpu(&result, device)
5591}
5592
5593// ---------------------------------------------------------------------------
5594// Public API -- Tanh backward
5595// ---------------------------------------------------------------------------
5596
5597/// Tanh backward: `out[i] = grad[i] * (1 - output[i]^2)`.
5598///
5599/// `grad` and `output` must have the same length and reside on `device`.
5600#[cfg(feature = "cuda")]
5601pub fn gpu_tanh_backward(
5602    grad: &CudaBuffer<f32>,
5603    output: &CudaBuffer<f32>,
5604    device: &GpuDevice,
5605) -> GpuResult<CudaBuffer<f32>> {
5606    validate_binary(grad, output, device)?;
5607
5608    if let Some(out) = try_launch_binary(
5609        grad,
5610        output,
5611        device,
5612        TANH_BACKWARD_PTX,
5613        "tanh_backward_kernel",
5614    )? {
5615        return Ok(out);
5616    }
5617
5618    // CPU fallback
5619    let grad_host = gpu_to_cpu(grad, device)?;
5620    let output_host = gpu_to_cpu(output, device)?;
5621    let result: Vec<f32> = grad_host
5622        .iter()
5623        .zip(output_host.iter())
5624        .map(|(&g, &o)| g * (1.0 - o * o))
5625        .collect();
5626    cpu_to_gpu(&result, device)
5627}
5628
5629// ---------------------------------------------------------------------------
5630// Public API -- Softmax backward
5631// ---------------------------------------------------------------------------
5632
5633/// Softmax backward (row-wise): one block per row, shared-memory dot reduction.
5634///
5635/// For each row of length `cols`:
5636///   `dot = sum(grad[row] * output[row])`
5637///   `out[i] = output[i] * (grad[i] - dot)`
5638///
5639/// `rows` = total elements / cols. Both `grad` and `output` have `rows * cols` elements.
5640#[cfg(feature = "cuda")]
5641pub fn gpu_softmax_backward(
5642    grad: &CudaBuffer<f32>,
5643    output: &CudaBuffer<f32>,
5644    cols: usize,
5645    device: &GpuDevice,
5646) -> GpuResult<CudaBuffer<f32>> {
5647    use cudarc::driver::PushKernelArg;
5648
5649    validate_binary(grad, output, device)?;
5650
5651    let total = grad.len();
5652    let rows = total / cols;
5653
5654    let ctx = device.context();
5655    let stream = device.stream();
5656
5657    let f = match crate::module_cache::get_or_compile(
5658        ctx,
5659        SOFTMAX_BACKWARD_PTX,
5660        "softmax_backward_kernel",
5661        device.ordinal() as u32,
5662    ) {
5663        Ok(f) => f,
5664        Err(_) => {
5665            // CPU fallback
5666            let grad_host = gpu_to_cpu(grad, device)?;
5667            let output_host = gpu_to_cpu(output, device)?;
5668            let mut result = vec![0.0f32; total];
5669            for r in 0..rows {
5670                let base = r * cols;
5671                let mut dot = 0.0f32;
5672                for c in 0..cols {
5673                    dot += grad_host[base + c] * output_host[base + c];
5674                }
5675                for c in 0..cols {
5676                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
5677                }
5678            }
5679            return cpu_to_gpu(&result, device);
5680        }
5681    };
5682
5683    let mut out = alloc_zeros_f32(total, device)?;
5684    let rows_u32 = rows as u32;
5685    let cols_u32 = cols as u32;
5686
5687    // One block per row, 256 threads per block.
5688    let cfg = LaunchConfig {
5689        grid_dim: ((rows as u32).max(1), 1, 1),
5690        block_dim: (256, 1, 1),
5691        shared_mem_bytes: 256 * 4,
5692    };
5693
5694    unsafe {
5695        stream
5696            .launch_builder(&f)
5697            .arg(grad.inner())
5698            .arg(output.inner())
5699            .arg(out.inner_mut())
5700            .arg(&rows_u32)
5701            .arg(&cols_u32)
5702            .launch(cfg)?;
5703    }
5704
5705    Ok(out)
5706}
5707
5708// ---------------------------------------------------------------------------
5709// Public API -- Sum axis
5710// ---------------------------------------------------------------------------
5711
5712/// Reduce along one axis of a tensor.
5713///
5714/// Thread i computes:
5715/// Full parallel sum reduction on GPU.
5716///
5717/// Uses a two-pass approach: first pass reduces `n` elements to `num_blocks`
5718/// partial sums via the `reduce_sum_kernel`, second pass reduces the partial
5719/// sums to a single scalar. For small inputs (< 256 blocks), the second pass
5720/// runs on CPU to avoid kernel launch overhead.
5721#[cfg(feature = "cuda")]
5722pub fn gpu_reduce_sum(
5723    a: &CudaBuffer<f32>,
5724    device: &GpuDevice,
5725) -> GpuResult<CudaBuffer<f32>> {
5726    use cudarc::driver::PushKernelArg;
5727
5728    let n = a.len();
5729    if n == 0 {
5730        return cpu_to_gpu(&[0.0f32], device);
5731    }
5732
5733    let ctx = device.context();
5734    let stream = device.stream();
5735
5736    let f = match crate::module_cache::get_or_compile(
5737        ctx,
5738        REDUCE_SUM_PTX,
5739        "reduce_sum_kernel",
5740        device.ordinal() as u32,
5741    ) {
5742        Ok(f) => f,
5743        Err(_) => {
5744            // CPU fallback
5745            let host = gpu_to_cpu(a, device)?;
5746            let total: f32 = host.iter().sum();
5747            return cpu_to_gpu(&[total], device);
5748        }
5749    };
5750
5751    // Pass 1: reduce to partial sums (one per block).
5752    const BLOCK: u32 = 256;
5753    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
5754    // Cap blocks to avoid excessive partial sums.
5755    let num_blocks = num_blocks.min(1024);
5756
5757    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
5758    let n_u32 = n as u32;
5759
5760    let cfg = cudarc::driver::LaunchConfig {
5761        grid_dim: (num_blocks.max(1), 1, 1),
5762        block_dim: (BLOCK, 1, 1),
5763        shared_mem_bytes: 0, // Statically allocated in PTX
5764    };
5765
5766    unsafe {
5767        stream
5768            .launch_builder(&f)
5769            .arg(a.inner())
5770            .arg(partials.inner_mut())
5771            .arg(&n_u32)
5772            .launch(cfg)?;
5773    }
5774
5775    // Pass 2: reduce partial sums.
5776    if num_blocks <= 1 {
5777        return Ok(partials);
5778    }
5779
5780    // For small number of blocks, reduce on CPU (cheaper than another kernel launch).
5781    if num_blocks <= 256 {
5782        let host_partials = gpu_to_cpu(&partials, device)?;
5783        let total: f32 = host_partials.iter().sum();
5784        return cpu_to_gpu(&[total], device);
5785    }
5786
5787    // For many blocks, recurse with another kernel launch.
5788    gpu_reduce_sum(&partials, device)
5789}
5790
5791/// Stub -- always returns [`GpuError::NoCudaFeature`].
5792#[cfg(not(feature = "cuda"))]
5793pub fn gpu_reduce_sum(
5794    _a: &CudaBuffer<f32>,
5795    _device: &GpuDevice,
5796) -> GpuResult<CudaBuffer<f32>> {
5797    Err(GpuError::NoCudaFeature)
5798}
5799
5800///   `output[i] = sum_{k=0}^{axis_size-1} input[outer_idx * axis_size * inner_size + k * inner_size + inner_idx]`
5801///
5802/// where `outer_idx = i / inner_size`, `inner_idx = i % inner_size`.
5803#[cfg(feature = "cuda")]
5804pub fn gpu_sum_axis(
5805    a: &CudaBuffer<f32>,
5806    outer: usize,
5807    axis_size: usize,
5808    inner: usize,
5809    device: &GpuDevice,
5810) -> GpuResult<CudaBuffer<f32>> {
5811    use cudarc::driver::PushKernelArg;
5812
5813    validate_unary(a, device)?;
5814
5815    let total_output = outer * inner;
5816    let ctx = device.context();
5817    let stream = device.stream();
5818
5819    let f = match crate::module_cache::get_or_compile(
5820        ctx,
5821        SUM_AXIS_PTX,
5822        "sum_axis_kernel",
5823        device.ordinal() as u32,
5824    ) {
5825        Ok(f) => f,
5826        Err(_) => {
5827            // CPU fallback
5828            let host = gpu_to_cpu(a, device)?;
5829            let mut result = vec![0.0f32; total_output];
5830            for (i, out) in result.iter_mut().enumerate() {
5831                let outer_idx = i / inner;
5832                let inner_idx = i % inner;
5833                let mut sum = 0.0f32;
5834                for k in 0..axis_size {
5835                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
5836                }
5837                *out = sum;
5838            }
5839            return cpu_to_gpu(&result, device);
5840        }
5841    };
5842
5843    let mut out = alloc_zeros_f32(total_output, device)?;
5844    let cfg = launch_cfg(total_output)?;
5845    let outer_u32 = outer as u32;
5846    let axis_size_u32 = axis_size as u32;
5847    let inner_u32 = inner as u32;
5848    let total_u32 = total_output as u32;
5849
5850    unsafe {
5851        stream
5852            .launch_builder(&f)
5853            .arg(a.inner())
5854            .arg(out.inner_mut())
5855            .arg(&outer_u32)
5856            .arg(&axis_size_u32)
5857            .arg(&inner_u32)
5858            .arg(&total_u32)
5859            .launch(cfg)?;
5860    }
5861
5862    Ok(out)
5863}
5864
5865// ---------------------------------------------------------------------------
5866// Public API -- Strided split
5867// ---------------------------------------------------------------------------
5868
5869/// Extract a sub-tensor along one axis entirely on GPU.
5870///
5871/// Given an input buffer representing a tensor with `total_along_axis` elements
5872/// along the split axis, extracts the slice `[split_offset .. split_offset + split_size]`
5873/// along that axis.
5874///
5875/// - `inner_size` = product of dimensions after the split axis.
5876/// - `n` = total number of output elements (outer * split_size * inner_size).
5877///
5878/// # Errors
5879///
5880/// - [`GpuError::DeviceMismatch`] if `input` and `device` are on different devices.
5881/// - [`GpuError::Driver`] on CUDA runtime errors.
5882#[cfg(feature = "cuda")]
5883pub fn gpu_strided_split(
5884    input: &CudaBuffer<f32>,
5885    total_along_axis: usize,
5886    split_offset: usize,
5887    split_size: usize,
5888    inner_size: usize,
5889    n: usize,
5890    device: &GpuDevice,
5891) -> GpuResult<CudaBuffer<f32>> {
5892    use cudarc::driver::PushKernelArg;
5893
5894    validate_unary(input, device)?;
5895
5896    let ctx = device.context();
5897    let stream = device.stream();
5898
5899    let f = match crate::module_cache::get_or_compile(
5900        ctx,
5901        STRIDED_SPLIT_PTX,
5902        "strided_split_kernel",
5903        device.ordinal() as u32,
5904    ) {
5905        Ok(f) => f,
5906        Err(_) => {
5907            // CPU fallback
5908            let host = gpu_to_cpu(input, device)?;
5909            let outer = n / (split_size * inner_size);
5910            let mut result = vec![0.0f32; n];
5911            for (i, out) in result.iter_mut().enumerate() {
5912                let outer_idx = i / (split_size * inner_size);
5913                let within = i % (split_size * inner_size);
5914                let src_idx =
5915                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
5916                *out = host[src_idx];
5917            }
5918            let _ = outer;
5919            return cpu_to_gpu(&result, device);
5920        }
5921    };
5922
5923    let mut out = alloc_zeros_f32(n, device)?;
5924    let cfg = launch_cfg(n)?;
5925    let total_ax_u32 = total_along_axis as u32;
5926    let offset_u32 = split_offset as u32;
5927    let split_sz_u32 = split_size as u32;
5928    let inner_u32 = inner_size as u32;
5929    let n_u32 = n as u32;
5930
5931    unsafe {
5932        stream
5933            .launch_builder(&f)
5934            .arg(input.inner())
5935            .arg(out.inner_mut())
5936            .arg(&total_ax_u32)
5937            .arg(&offset_u32)
5938            .arg(&split_sz_u32)
5939            .arg(&inner_u32)
5940            .arg(&n_u32)
5941            .launch(cfg)?;
5942    }
5943
5944    Ok(out)
5945}
5946
5947// ---------------------------------------------------------------------------
5948// Public API -- Strided cat
5949// ---------------------------------------------------------------------------
5950
5951/// Write a sub-tensor into a larger output buffer at an offset along one axis,
5952/// entirely on GPU.
5953///
5954/// Given an input buffer representing a chunk with `part_size` elements along
5955/// the cat axis, writes it into `output` at position `cat_offset` along that axis.
5956///
5957/// - `inner_size` = product of dimensions after the cat axis.
5958/// - `n` = total number of input elements (outer * part_size * inner_size).
5959///
5960/// # Safety
5961///
5962/// `output` must be large enough to hold the written region. The caller is
5963/// responsible for ensuring non-overlapping writes when multiple chunks are
5964/// written into the same output buffer.
5965///
5966/// # Errors
5967///
5968/// - [`GpuError::DeviceMismatch`] if buffers and `device` are on different devices.
5969/// - [`GpuError::Driver`] on CUDA runtime errors.
5970#[cfg(feature = "cuda")]
5971#[allow(clippy::too_many_arguments)]
5972pub fn gpu_strided_cat(
5973    input: &CudaBuffer<f32>,
5974    output: &mut CudaBuffer<f32>,
5975    total_along_axis: usize,
5976    cat_offset: usize,
5977    part_size: usize,
5978    inner_size: usize,
5979    n: usize,
5980    device: &GpuDevice,
5981) -> GpuResult<()> {
5982    use cudarc::driver::PushKernelArg;
5983
5984    validate_unary(input, device)?;
5985
5986    let ctx = device.context();
5987    let stream = device.stream();
5988
5989    let f = match crate::module_cache::get_or_compile(
5990        ctx,
5991        STRIDED_CAT_PTX,
5992        "strided_cat_kernel",
5993        device.ordinal() as u32,
5994    ) {
5995        Ok(f) => f,
5996        Err(_) => {
5997            // CPU fallback
5998            let host_in = gpu_to_cpu(input, device)?;
5999            let mut host_out = gpu_to_cpu(output, device)?;
6000            for (i, &val) in host_in.iter().enumerate().take(n) {
6001                let outer_idx = i / (part_size * inner_size);
6002                let within = i % (part_size * inner_size);
6003                let dst_idx =
6004                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
6005                host_out[dst_idx] = val;
6006            }
6007            *output = cpu_to_gpu(&host_out, device)?;
6008            return Ok(());
6009        }
6010    };
6011
6012    let cfg = launch_cfg(n)?;
6013    let total_ax_u32 = total_along_axis as u32;
6014    let offset_u32 = cat_offset as u32;
6015    let part_sz_u32 = part_size as u32;
6016    let inner_u32 = inner_size as u32;
6017    let n_u32 = n as u32;
6018
6019    unsafe {
6020        stream
6021            .launch_builder(&f)
6022            .arg(input.inner())
6023            .arg(output.inner_mut())
6024            .arg(&total_ax_u32)
6025            .arg(&offset_u32)
6026            .arg(&part_sz_u32)
6027            .arg(&inner_u32)
6028            .arg(&n_u32)
6029            .launch(cfg)?;
6030    }
6031
6032    Ok(())
6033}
6034
6035/// Scalar multiply: `out[i] = a[i] * scalar`.
6036///
6037/// Multiplies every element by a constant float value on the GPU.
6038///
6039/// # Errors
6040///
6041/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
6042/// - [`GpuError::Driver`] on CUDA runtime errors.
6043#[cfg(feature = "cuda")]
6044pub fn gpu_scale(
6045    a: &CudaBuffer<f32>,
6046    scalar: f32,
6047    device: &GpuDevice,
6048) -> GpuResult<CudaBuffer<f32>> {
6049    use cudarc::driver::PushKernelArg;
6050
6051    validate_unary(a, device)?;
6052
6053    let n = a.len();
6054    let ctx = device.context();
6055    let stream = device.stream();
6056
6057    let f = match crate::module_cache::get_or_compile(
6058        ctx,
6059        SCALE_PTX,
6060        "scale_kernel",
6061        device.ordinal() as u32,
6062    ) {
6063        Ok(f) => f,
6064        Err(_) => {
6065            // CPU fallback
6066            let host = gpu_to_cpu(a, device)?;
6067            let result: Vec<f32> = host.iter().map(|&x| x * scalar).collect();
6068            return cpu_to_gpu(&result, device);
6069        }
6070    };
6071
6072    let mut out = alloc_zeros_f32(n, device)?;
6073    let cfg = launch_cfg(n)?;
6074    let n_u32 = n as u32;
6075
6076    unsafe {
6077        stream
6078            .launch_builder(&f)
6079            .arg(a.inner())
6080            .arg(out.inner_mut())
6081            .arg(&scalar)
6082            .arg(&n_u32)
6083            .launch(cfg)?;
6084    }
6085
6086    Ok(out)
6087}
6088
6089// ---------------------------------------------------------------------------
6090// Public API -- softmax
6091// ---------------------------------------------------------------------------
6092
6093/// Row-wise softmax on GPU: one thread block per row, shared-memory reduction.
6094///
6095/// `rows` = product of all dims except the last. `cols` = last dim size.
6096#[cfg(feature = "cuda")]
6097pub fn gpu_softmax(
6098    input: &CudaBuffer<f32>,
6099    rows: usize,
6100    cols: usize,
6101    device: &GpuDevice,
6102) -> GpuResult<CudaBuffer<f32>> {
6103    use cudarc::driver::PushKernelArg;
6104
6105    validate_unary(input, device)?;
6106
6107    let ctx = device.context();
6108    let stream = device.stream();
6109
6110    let f = match crate::module_cache::get_or_compile(
6111        ctx,
6112        SOFTMAX_PTX,
6113        "softmax_kernel",
6114        device.ordinal() as u32,
6115    ) {
6116        Ok(f) => f,
6117        Err(_) => {
6118            // CPU fallback.
6119            let host = gpu_to_cpu(input, device)?;
6120            let mut out = vec![0.0f32; host.len()];
6121            for r in 0..rows {
6122                let base = r * cols;
6123                let mut max_v = f32::NEG_INFINITY;
6124                for c in 0..cols {
6125                    max_v = max_v.max(host[base + c]);
6126                }
6127                let mut sum = 0.0f32;
6128                for c in 0..cols {
6129                    let e = (host[base + c] - max_v).exp();
6130                    out[base + c] = e;
6131                    sum += e;
6132                }
6133                let inv = 1.0 / sum;
6134                for c in 0..cols {
6135                    out[base + c] *= inv;
6136                }
6137            }
6138            return cpu_to_gpu(&out, device);
6139        }
6140    };
6141
6142    let mut out = alloc_zeros_f32(rows * cols, device)?;
6143    let rows_u32 = rows as u32;
6144    let cols_u32 = cols as u32;
6145
6146    // One block per row, 256 threads per block.
6147    let cfg = LaunchConfig {
6148        grid_dim: ((rows as u32).max(1), 1, 1),
6149        block_dim: (256, 1, 1),
6150        shared_mem_bytes: 256 * 4, // sdata[256] f32
6151    };
6152
6153    unsafe {
6154        stream
6155            .launch_builder(&f)
6156            .arg(input.inner())
6157            .arg(out.inner_mut())
6158            .arg(&rows_u32)
6159            .arg(&cols_u32)
6160            .launch(cfg)?;
6161    }
6162
6163    Ok(out)
6164}
6165
6166// ---------------------------------------------------------------------------
6167// Public API -- dropout
6168// ---------------------------------------------------------------------------
6169
6170/// Inverted dropout on GPU: `out[i] = input[i] * scale` or `0` with probability `p`.
6171///
6172/// `threshold` = `(p * u32::MAX as f64) as u32` — the RNG cutoff.
6173/// `scale` = `1.0 / (1.0 - p)`.
6174/// `seed` = random seed for the RNG.
6175///
6176/// **Known limitation**: This kernel uses a simple per-element hash
6177/// (`tid * 2654435761 ^ seed` with xorshift mixing), not the full
6178/// Philox 4x32-10 counter-based RNG that PyTorch uses. A proper Philox
6179/// dropout kernel would generate the mask via `philox_uniform_kernel`
6180/// and then threshold — producing higher-quality randomness and exact
6181/// reproducibility across CPU/GPU. The current hash is sufficient for
6182/// training but should be upgraded for research requiring strict
6183/// statistical properties.
6184#[cfg(feature = "cuda")]
6185pub fn gpu_dropout(
6186    input: &CudaBuffer<f32>,
6187    threshold: u32,
6188    scale: f32,
6189    seed: u32,
6190    device: &GpuDevice,
6191) -> GpuResult<CudaBuffer<f32>> {
6192    use cudarc::driver::PushKernelArg;
6193
6194    validate_unary(input, device)?;
6195
6196    let n = input.len();
6197    let ctx = device.context();
6198    let stream = device.stream();
6199
6200    let f = match crate::module_cache::get_or_compile(
6201        ctx,
6202        DROPOUT_PTX,
6203        "dropout_kernel",
6204        device.ordinal() as u32,
6205    ) {
6206        Ok(f) => f,
6207        Err(_) => {
6208            // CPU fallback.
6209            let host = gpu_to_cpu(input, device)?;
6210            // Stateless per-element hash matching the GPU kernel: each element
6211            // independently computes its own pseudorandom value from (tid, seed)
6212            // with no state carried between elements.
6213            let result: Vec<f32> = host
6214                .iter()
6215                .enumerate()
6216                .map(|(i, &x)| {
6217                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
6218                    r ^= r << 13;
6219                    r ^= r >> 17;
6220                    r ^= r << 5;
6221                    if r < threshold { 0.0 } else { x * scale }
6222                })
6223                .collect();
6224            return cpu_to_gpu(&result, device);
6225        }
6226    };
6227
6228    let mut out = alloc_zeros_f32(n, device)?;
6229    let cfg = launch_cfg(n)?;
6230    let n_u32 = n as u32;
6231
6232    unsafe {
6233        stream
6234            .launch_builder(&f)
6235            .arg(input.inner())
6236            .arg(out.inner_mut())
6237            .arg(&n_u32)
6238            .arg(&threshold)
6239            .arg(&scale)
6240            .arg(&seed)
6241            .launch(cfg)?;
6242    }
6243
6244    Ok(out)
6245}
6246
6247// ---------------------------------------------------------------------------
6248// Public API -- 2D transpose
6249// ---------------------------------------------------------------------------
6250
6251/// 2D matrix transpose on GPU: `[M, N]` -> `[N, M]`.
6252#[cfg(feature = "cuda")]
6253pub fn gpu_transpose_2d(
6254    input: &CudaBuffer<f32>,
6255    m: usize,
6256    n: usize,
6257    device: &GpuDevice,
6258) -> GpuResult<CudaBuffer<f32>> {
6259    use cudarc::driver::PushKernelArg;
6260
6261    validate_unary(input, device)?;
6262
6263    let total = m * n;
6264    let ctx = device.context();
6265    let stream = device.stream();
6266
6267    let f = match crate::module_cache::get_or_compile(
6268        ctx,
6269        TRANSPOSE_2D_PTX,
6270        "transpose_2d_kernel",
6271        device.ordinal() as u32,
6272    ) {
6273        Ok(f) => f,
6274        Err(_) => {
6275            // CPU fallback.
6276            let host = gpu_to_cpu(input, device)?;
6277            let mut out = vec![0.0f32; total];
6278            for i in 0..m {
6279                for j in 0..n {
6280                    out[j * m + i] = host[i * n + j];
6281                }
6282            }
6283            return cpu_to_gpu(&out, device);
6284        }
6285    };
6286
6287    let mut out = alloc_zeros_f32(total, device)?;
6288    let cfg = launch_cfg(total)?;
6289    let m_u32 = m as u32;
6290    let n_u32 = n as u32;
6291    let total_u32 = total as u32;
6292
6293    unsafe {
6294        stream
6295            .launch_builder(&f)
6296            .arg(input.inner())
6297            .arg(out.inner_mut())
6298            .arg(&m_u32)
6299            .arg(&n_u32)
6300            .arg(&total_u32)
6301            .launch(cfg)?;
6302    }
6303
6304    Ok(out)
6305}
6306
6307// ---------------------------------------------------------------------------
6308// Public API -- 4D permute (0,2,1,3)
6309// ---------------------------------------------------------------------------
6310
6311/// Permute a 4D tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]` on GPU.
6312/// Used for attention head reshaping: `[B, S, H, D_h]` -> `[B, H, S, D_h]`.
6313#[cfg(feature = "cuda")]
6314pub fn gpu_permute_0213(
6315    input: &CudaBuffer<f32>,
6316    d0: usize,
6317    d1: usize,
6318    d2: usize,
6319    d3: usize,
6320    device: &GpuDevice,
6321) -> GpuResult<CudaBuffer<f32>> {
6322    use cudarc::driver::PushKernelArg;
6323
6324    validate_unary(input, device)?;
6325
6326    let total = d0 * d1 * d2 * d3;
6327    let ctx = device.context();
6328    let stream = device.stream();
6329
6330    let f = match crate::module_cache::get_or_compile(
6331        ctx,
6332        PERMUTE_0213_PTX,
6333        "permute_0213_kernel",
6334        device.ordinal() as u32,
6335    ) {
6336        Ok(f) => f,
6337        Err(_) => {
6338            // CPU fallback.
6339            let host = gpu_to_cpu(input, device)?;
6340            let mut out = vec![0.0f32; total];
6341            for i0 in 0..d0 {
6342                for i1 in 0..d1 {
6343                    for i2 in 0..d2 {
6344                        for i3 in 0..d3 {
6345                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
6346                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
6347                            out[out_idx] = host[in_idx];
6348                        }
6349                    }
6350                }
6351            }
6352            return cpu_to_gpu(&out, device);
6353        }
6354    };
6355
6356    let mut out = alloc_zeros_f32(total, device)?;
6357    let cfg = launch_cfg(total)?;
6358    let d0_u32 = d0 as u32;
6359    let d1_u32 = d1 as u32;
6360    let d2_u32 = d2 as u32;
6361    let d3_u32 = d3 as u32;
6362    let total_u32 = total as u32;
6363
6364    unsafe {
6365        stream
6366            .launch_builder(&f)
6367            .arg(input.inner())
6368            .arg(out.inner_mut())
6369            .arg(&d0_u32)
6370            .arg(&d1_u32)
6371            .arg(&d2_u32)
6372            .arg(&d3_u32)
6373            .arg(&total_u32)
6374            .launch(cfg)?;
6375    }
6376
6377    Ok(out)
6378}
6379
6380// ---------------------------------------------------------------------------
6381// Public API -- Small matmul (bypasses cuBLAS JIT)
6382// ---------------------------------------------------------------------------
6383
6384/// Small matrix multiply using our own PTX kernel. Avoids cuBLAS JIT
6385/// compilation overhead for tiny matrices where JIT cost > compute cost.
6386///
6387/// `a`: `[M, K]`, `b`: `[K, N]` → `c`: `[M, N]`.
6388#[cfg(feature = "cuda")]
6389pub fn gpu_small_matmul(
6390    a: &CudaBuffer<f32>,
6391    b: &CudaBuffer<f32>,
6392    m: usize,
6393    k: usize,
6394    n: usize,
6395    device: &GpuDevice,
6396) -> GpuResult<CudaBuffer<f32>> {
6397    use cudarc::driver::PushKernelArg;
6398
6399    let total = m * n;
6400    let ctx = device.context();
6401    let stream = device.stream();
6402
6403    let f = match crate::module_cache::get_or_compile(
6404        ctx,
6405        SMALL_MATMUL_PTX,
6406        "small_matmul_kernel",
6407        device.ordinal() as u32,
6408    ) {
6409        Ok(f) => f,
6410        Err(_) => {
6411            // Fall back to cuBLAS if our kernel can't compile.
6412            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
6413        }
6414    };
6415
6416    let mut c = alloc_zeros_f32(total, device)?;
6417    let cfg = launch_cfg(total)?;
6418    let m_u32 = m as u32;
6419    let k_u32 = k as u32;
6420    let n_u32 = n as u32;
6421    let total_u32 = total as u32;
6422
6423    unsafe {
6424        stream
6425            .launch_builder(&f)
6426            .arg(a.inner())
6427            .arg(b.inner())
6428            .arg(c.inner_mut())
6429            .arg(&m_u32)
6430            .arg(&k_u32)
6431            .arg(&n_u32)
6432            .arg(&total_u32)
6433            .launch(cfg)?;
6434    }
6435
6436    Ok(c)
6437}
6438
6439/// Small batched matmul: C[i] = A[i] @ B[i] for i in 0..batch.
6440/// Uses the small_matmul_kernel by reshaping the problem: treat it as a single
6441/// large matmul of [batch*M, K] @ [K, N] — but that doesn't work because B is
6442/// batched. Instead, we use a modified approach: thread `idx` computes element
6443/// (batch_i, row, col) where batch_i = idx / (M*N).
6444///
6445/// For simplicity and correctness, we fall back to cpu_bmm for now when
6446/// cuBLAS fails, but route through gpu_small_matmul for the single-matrix case.
6447#[cfg(feature = "cuda")]
6448pub fn gpu_small_bmm(
6449    a: &CudaBuffer<f32>,
6450    b: &CudaBuffer<f32>,
6451    batch: usize,
6452    m: usize,
6453    k: usize,
6454    n: usize,
6455    device: &GpuDevice,
6456) -> GpuResult<CudaBuffer<f32>> {
6457    // For batch=1, just use the single matmul kernel.
6458    if batch == 1 {
6459        return gpu_small_matmul(a, b, m, k, n, device);
6460    }
6461    // For batched case, fall back to cuBLAS (the batched PTX kernel is complex).
6462    // The main win is from the single-matrix decode case (batch=1 for attention scores).
6463    crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device)
6464}
6465
6466// ---------------------------------------------------------------------------
6467// Public API -- Embedding lookup (GPU-native)
6468// ---------------------------------------------------------------------------
6469
6470/// GPU embedding lookup: reads token ID from `idx` (single f32 on GPU),
6471/// gathers row from `weight` `[V, D]`, writes to `out` `[D]`.
6472/// Entire operation stays on GPU — no CPU involvement.
6473#[cfg(feature = "cuda")]
6474pub fn gpu_embed_lookup(
6475    idx: &CudaBuffer<f32>,
6476    weight: &CudaBuffer<f32>,
6477    d: usize,
6478    device: &GpuDevice,
6479) -> GpuResult<CudaBuffer<f32>> {
6480    use cudarc::driver::PushKernelArg;
6481
6482    let ctx = device.context();
6483    let stream = device.stream();
6484
6485    let f = match crate::module_cache::get_or_compile(
6486        ctx,
6487        EMBED_LOOKUP_PTX,
6488        "embed_lookup_kernel",
6489        device.ordinal() as u32,
6490    ) {
6491        Ok(f) => f,
6492        Err(_) => {
6493            // CPU fallback.
6494            let idx_host = gpu_to_cpu(idx, device)?;
6495            let weight_host = gpu_to_cpu(weight, device)?;
6496            let row = idx_host[0] as usize;
6497            let start = row * d;
6498            let out = weight_host[start..start + d].to_vec();
6499            return cpu_to_gpu(&out, device);
6500        }
6501    };
6502
6503    let mut out = alloc_zeros_f32(d, device)?;
6504    let cfg = launch_cfg(d)?;
6505    let d_u32 = d as u32;
6506
6507    unsafe {
6508        stream
6509            .launch_builder(&f)
6510            .arg(idx.inner())
6511            .arg(weight.inner())
6512            .arg(out.inner_mut())
6513            .arg(&d_u32)
6514            .launch(cfg)?;
6515    }
6516
6517    Ok(out)
6518}
6519
6520// ---------------------------------------------------------------------------
6521// Public API -- Slice write (for KV cache)
6522// ---------------------------------------------------------------------------
6523
6524/// Write `src` of shape `[N, D]` into row `pos` of `dst` of shape `[N, max_len, D]`.
6525/// This is an in-place GPU operation — `dst` is modified.
6526#[cfg(feature = "cuda")]
6527pub fn gpu_slice_write(
6528    src: &CudaBuffer<f32>,
6529    dst: &mut CudaBuffer<f32>,
6530    n_batch: usize,
6531    d: usize,
6532    max_len: usize,
6533    pos: usize,
6534    device: &GpuDevice,
6535) -> GpuResult<()> {
6536    use cudarc::driver::PushKernelArg;
6537
6538    let total = n_batch * d;
6539    let ctx = device.context();
6540    let stream = device.stream();
6541
6542    let f = match crate::module_cache::get_or_compile(
6543        ctx,
6544        SLICE_WRITE_PTX,
6545        "slice_write_kernel",
6546        device.ordinal() as u32,
6547    ) {
6548        Ok(f) => f,
6549        Err(_) => {
6550            // CPU fallback.
6551            let src_host = gpu_to_cpu(src, device)?;
6552            let mut dst_host = gpu_to_cpu(dst, device)?;
6553            for b in 0..n_batch {
6554                for di in 0..d {
6555                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
6556                }
6557            }
6558            let new_dst = cpu_to_gpu(&dst_host, device)?;
6559            *dst = new_dst;
6560            return Ok(());
6561        }
6562    };
6563
6564    let cfg = launch_cfg(total)?;
6565    let n_u32 = total as u32;
6566    let d_u32 = d as u32;
6567    let max_len_u32 = max_len as u32;
6568    let pos_u32 = pos as u32;
6569
6570    unsafe {
6571        stream
6572            .launch_builder(&f)
6573            .arg(src.inner())
6574            .arg(dst.inner_mut())
6575            .arg(&n_u32)
6576            .arg(&d_u32)
6577            .arg(&max_len_u32)
6578            .arg(&pos_u32)
6579            .launch(cfg)?;
6580    }
6581
6582    Ok(())
6583}
6584
6585// ---------------------------------------------------------------------------
6586// Public API -- Slice read (for KV cache)
6587// ---------------------------------------------------------------------------
6588
6589/// Read first `len` rows from each batch of `[N, max_len, D]` → `[N, len, D]`.
6590#[cfg(feature = "cuda")]
6591pub fn gpu_slice_read(
6592    src: &CudaBuffer<f32>,
6593    n_batch: usize,
6594    d: usize,
6595    len: usize,
6596    max_len: usize,
6597    device: &GpuDevice,
6598) -> GpuResult<CudaBuffer<f32>> {
6599    use cudarc::driver::PushKernelArg;
6600
6601    let total = n_batch * len * d;
6602    let ctx = device.context();
6603    let stream = device.stream();
6604
6605    let f = match crate::module_cache::get_or_compile(
6606        ctx,
6607        SLICE_READ_PTX,
6608        "slice_read_kernel",
6609        device.ordinal() as u32,
6610    ) {
6611        Ok(f) => f,
6612        Err(_) => {
6613            let host = gpu_to_cpu(src, device)?;
6614            let mut out = vec![0.0f32; total];
6615            for b in 0..n_batch {
6616                for r in 0..len {
6617                    for di in 0..d {
6618                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
6619                    }
6620                }
6621            }
6622            return cpu_to_gpu(&out, device);
6623        }
6624    };
6625
6626    let mut out = alloc_zeros_f32(total, device)?;
6627    let cfg = launch_cfg(total)?;
6628    let total_u32 = total as u32;
6629    let d_u32 = d as u32;
6630    let len_u32 = len as u32;
6631    let max_len_u32 = max_len as u32;
6632
6633    unsafe {
6634        stream
6635            .launch_builder(&f)
6636            .arg(src.inner())
6637            .arg(out.inner_mut())
6638            .arg(&total_u32)
6639            .arg(&d_u32)
6640            .arg(&len_u32)
6641            .arg(&max_len_u32)
6642            .launch(cfg)?;
6643    }
6644
6645    Ok(out)
6646}
6647
6648// ---------------------------------------------------------------------------
6649// Public API -- GELU
6650// ---------------------------------------------------------------------------
6651
6652/// Elementwise GELU activation on GPU: `gelu(x) = x * sigmoid(1.702 * x)`.
6653#[cfg(feature = "cuda")]
6654pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6655    validate_unary(input, device)?;
6656    if let Some(out) = try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
6657        return Ok(out);
6658    }
6659    cpu_fallback_unary(input, device, |x| {
6660        let s = 1.0 / (1.0 + (-1.702 * x).exp());
6661        x * s
6662    })
6663}
6664
6665// ---------------------------------------------------------------------------
6666// Public API -- elementwise transcendentals & math ops
6667// ---------------------------------------------------------------------------
6668
6669/// Elementwise division: `out[i] = a[i] / b[i]`.
6670#[cfg(feature = "cuda")]
6671pub fn gpu_div(
6672    a: &CudaBuffer<f32>,
6673    b: &CudaBuffer<f32>,
6674    device: &GpuDevice,
6675) -> GpuResult<CudaBuffer<f32>> {
6676    validate_binary(a, b, device)?;
6677
6678    if let Some(out) = try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
6679        return Ok(out);
6680    }
6681
6682    // CPU fallback
6683    let a_host = gpu_to_cpu(a, device)?;
6684    let b_host = gpu_to_cpu(b, device)?;
6685    let result: Vec<f32> = a_host
6686        .iter()
6687        .zip(b_host.iter())
6688        .map(|(&x, &y)| x / y)
6689        .collect();
6690    cpu_to_gpu(&result, device)
6691}
6692
6693/// Elementwise exponential: `out[i] = exp(a[i])`.
6694#[cfg(feature = "cuda")]
6695pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6696    validate_unary(a, device)?;
6697    if let Some(out) = try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
6698        return Ok(out);
6699    }
6700    cpu_fallback_unary(a, device, |x| x.exp())
6701}
6702
6703/// Elementwise natural log: `out[i] = ln(a[i])`.
6704#[cfg(feature = "cuda")]
6705pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6706    validate_unary(a, device)?;
6707    if let Some(out) = try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
6708        return Ok(out);
6709    }
6710    cpu_fallback_unary(a, device, |x| x.ln())
6711}
6712
6713/// Elementwise square root: `out[i] = sqrt(a[i])`.
6714#[cfg(feature = "cuda")]
6715pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6716    validate_unary(a, device)?;
6717    if let Some(out) = try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
6718        return Ok(out);
6719    }
6720    cpu_fallback_unary(a, device, |x| x.sqrt())
6721}
6722
6723/// Elementwise power: `out[i] = a[i] ^ exponent`.
6724#[cfg(feature = "cuda")]
6725pub fn gpu_pow(
6726    a: &CudaBuffer<f32>,
6727    exponent: f32,
6728    device: &GpuDevice,
6729) -> GpuResult<CudaBuffer<f32>> {
6730    use cudarc::driver::PushKernelArg;
6731
6732    validate_unary(a, device)?;
6733
6734    let n = a.len();
6735    let ctx = device.context();
6736    let stream = device.stream();
6737
6738    let f = match crate::module_cache::get_or_compile(
6739        ctx,
6740        POW_PTX,
6741        "pow_kernel",
6742        device.ordinal() as u32,
6743    ) {
6744        Ok(f) => f,
6745        Err(_) => {
6746            let host = gpu_to_cpu(a, device)?;
6747            let result: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
6748            return cpu_to_gpu(&result, device);
6749        }
6750    };
6751
6752    let mut out = alloc_zeros_f32(n, device)?;
6753    let cfg = launch_cfg(n)?;
6754    let n_u32 = n as u32;
6755
6756    unsafe {
6757        stream
6758            .launch_builder(&f)
6759            .arg(a.inner())
6760            .arg(out.inner_mut())
6761            .arg(&exponent)
6762            .arg(&n_u32)
6763            .launch(cfg)?;
6764    }
6765
6766    Ok(out)
6767}
6768
6769/// Elementwise absolute value: `out[i] = |a[i]|`.
6770#[cfg(feature = "cuda")]
6771pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6772    validate_unary(a, device)?;
6773    if let Some(out) = try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
6774        return Ok(out);
6775    }
6776    cpu_fallback_unary(a, device, |x| x.abs())
6777}
6778
6779/// Elementwise sigmoid: `out[i] = 1 / (1 + exp(-a[i]))`.
6780#[cfg(feature = "cuda")]
6781pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6782    validate_unary(a, device)?;
6783    if let Some(out) = try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
6784        return Ok(out);
6785    }
6786    cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp()))
6787}
6788
6789/// Elementwise tanh: `out[i] = tanh(a[i])`.
6790#[cfg(feature = "cuda")]
6791pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
6792    validate_unary(a, device)?;
6793    if let Some(out) = try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
6794        return Ok(out);
6795    }
6796    cpu_fallback_unary(a, device, |x| x.tanh())
6797}
6798
6799// ---------------------------------------------------------------------------
6800// Public API -- fused Adam optimizer step
6801// ---------------------------------------------------------------------------
6802
6803/// Fused Adam optimizer step: updates param, exp_avg, and exp_avg_sq in-place
6804/// in a single kernel launch.
6805///
6806/// All four buffers must have the same length `n`. `param`, `exp_avg`, and
6807/// `exp_avg_sq` are modified in-place. `grad` is read-only.
6808#[cfg(feature = "cuda")]
6809#[allow(clippy::too_many_arguments)]
6810pub fn gpu_fused_adam(
6811    param: &mut CudaBuffer<f32>,
6812    grad: &CudaBuffer<f32>,
6813    exp_avg: &mut CudaBuffer<f32>,
6814    exp_avg_sq: &mut CudaBuffer<f32>,
6815    beta1: f32,
6816    beta2: f32,
6817    lr: f32,
6818    eps: f32,
6819    bc1: f32,
6820    bc2: f32,
6821    weight_decay: f32,
6822    device: &GpuDevice,
6823) -> GpuResult<()> {
6824    use cudarc::driver::PushKernelArg;
6825
6826    let n = param.len();
6827    if grad.len() != n || exp_avg.len() != n || exp_avg_sq.len() != n {
6828        return Err(GpuError::LengthMismatch {
6829            a: n,
6830            b: grad.len(),
6831        });
6832    }
6833
6834    let ctx = device.context();
6835    let stream = device.stream();
6836
6837    let f = match crate::module_cache::get_or_compile(
6838        ctx,
6839        FUSED_ADAM_PTX,
6840        "fused_adam_kernel",
6841        device.ordinal() as u32,
6842    ) {
6843        Ok(f) => f,
6844        Err(_) => {
6845            // CPU fallback: download, compute, upload.
6846            let mut p_host = gpu_to_cpu(param, device)?;
6847            let g_host = gpu_to_cpu(grad, device)?;
6848            let mut m_host = gpu_to_cpu(exp_avg, device)?;
6849            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;
6850
6851            for i in 0..n {
6852                let mut g = g_host[i];
6853                if weight_decay > 0.0 {
6854                    g += weight_decay * p_host[i];
6855                }
6856                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
6857                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
6858                let m_hat = m_host[i] / bc1;
6859                let v_hat = v_host[i] / bc2;
6860                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
6861            }
6862
6863            *param = cpu_to_gpu(&p_host, device)?;
6864            *exp_avg = cpu_to_gpu(&m_host, device)?;
6865            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
6866            return Ok(());
6867        }
6868    };
6869
6870    let cfg = launch_cfg(n)?;
6871    let n_u32 = n as u32;
6872
6873    unsafe {
6874        stream
6875            .launch_builder(&f)
6876            .arg(param.inner_mut())
6877            .arg(grad.inner())
6878            .arg(exp_avg.inner_mut())
6879            .arg(exp_avg_sq.inner_mut())
6880            .arg(&beta1)
6881            .arg(&beta2)
6882            .arg(&lr)
6883            .arg(&eps)
6884            .arg(&bc1)
6885            .arg(&bc2)
6886            .arg(&weight_decay)
6887            .arg(&n_u32)
6888            .launch(cfg)?;
6889    }
6890
6891    Ok(())
6892}
6893
6894/// Stub -- always returns [`GpuError::NoCudaFeature`].
6895#[cfg(not(feature = "cuda"))]
6896#[allow(clippy::too_many_arguments)]
6897pub fn gpu_fused_adam(
6898    _param: &mut CudaBuffer<f32>,
6899    _grad: &CudaBuffer<f32>,
6900    _exp_avg: &mut CudaBuffer<f32>,
6901    _exp_avg_sq: &mut CudaBuffer<f32>,
6902    _beta1: f32,
6903    _beta2: f32,
6904    _lr: f32,
6905    _eps: f32,
6906    _bc1: f32,
6907    _bc2: f32,
6908    _weight_decay: f32,
6909    _device: &GpuDevice,
6910) -> GpuResult<()> {
6911    Err(GpuError::NoCudaFeature)
6912}
6913
6914// ---------------------------------------------------------------------------
6915// Public API -- fused GRU cell
6916// ---------------------------------------------------------------------------
6917
6918/// Fused GRU cell forward: takes pre-computed gate matrices and produces
6919/// new hidden state + workspace for backward.
6920///
6921/// Inputs:
6922/// - `input_gates`: `[batch, 3*hsz]` — result of `x @ W_ih^T`
6923/// - `hidden_gates`: `[batch, 3*hsz]` — result of `h @ W_hh^T`
6924/// - `bias_ih`: `[3*hsz]` — input bias
6925/// - `bias_hh`: `[3*hsz]` — hidden bias
6926/// - `hx`: `[batch, hsz]` — previous hidden state
6927///
6928/// Outputs:
6929/// - `hy`: `[batch, hsz]` — new hidden state
6930/// - `workspace`: `[batch, 5*hsz]` — saved for backward (r, z, n, hx, hn+b2n)
6931#[cfg(feature = "cuda")]
6932pub fn gpu_fused_gru_forward(
6933    input_gates: &CudaBuffer<f32>,
6934    hidden_gates: &CudaBuffer<f32>,
6935    bias_ih: &CudaBuffer<f32>,
6936    bias_hh: &CudaBuffer<f32>,
6937    hx: &CudaBuffer<f32>,
6938    hsz: usize,
6939    device: &GpuDevice,
6940) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
6941    use cudarc::driver::PushKernelArg;
6942
6943    let total = hx.len(); // batch * hsz
6944    let batch = total / hsz;
6945
6946    let ctx = device.context();
6947    let stream = device.stream();
6948
6949    let f = match crate::module_cache::get_or_compile(
6950        ctx,
6951        FUSED_GRU_FORWARD_PTX,
6952        "fused_gru_forward_kernel",
6953        device.ordinal() as u32,
6954    ) {
6955        Ok(f) => f,
6956        Err(_) => {
6957            return Err(GpuError::PtxCompileFailed {
6958                kernel: "fused_gru_forward_kernel",
6959            });
6960        }
6961    };
6962
6963    let mut hy = alloc_zeros_f32(total, device)?;
6964    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;
6965
6966    let cfg = launch_cfg(total)?;
6967    let hsz_u32 = hsz as u32;
6968    let total_u32 = total as u32;
6969
6970    unsafe {
6971        stream
6972            .launch_builder(&f)
6973            .arg(input_gates.inner())
6974            .arg(hidden_gates.inner())
6975            .arg(bias_ih.inner())
6976            .arg(bias_hh.inner())
6977            .arg(hx.inner())
6978            .arg(hy.inner_mut())
6979            .arg(workspace.inner_mut())
6980            .arg(&hsz_u32)
6981            .arg(&total_u32)
6982            .launch(cfg)?;
6983    }
6984
6985    Ok((hy, workspace))
6986}
6987
6988/// Stub -- always returns [`GpuError::NoCudaFeature`].
6989#[cfg(not(feature = "cuda"))]
6990pub fn gpu_fused_gru_forward(
6991    _input_gates: &CudaBuffer<f32>,
6992    _hidden_gates: &CudaBuffer<f32>,
6993    _bias_ih: &CudaBuffer<f32>,
6994    _bias_hh: &CudaBuffer<f32>,
6995    _hx: &CudaBuffer<f32>,
6996    _hsz: usize,
6997    _device: &GpuDevice,
6998) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
6999    Err(GpuError::NoCudaFeature)
7000}
7001
7002// ---------------------------------------------------------------------------
7003// Public API -- MaxPool2d / AvgPool2d
7004// ---------------------------------------------------------------------------
7005
7006/// MaxPool2d forward on GPU. One thread per output element.
7007#[cfg(feature = "cuda")]
7008#[allow(clippy::too_many_arguments)]
7009pub fn gpu_maxpool2d(
7010    input: &CudaBuffer<f32>,
7011    batch: usize,
7012    channels: usize,
7013    h_in: usize,
7014    w_in: usize,
7015    kh: usize,
7016    kw: usize,
7017    sh: usize,
7018    sw: usize,
7019    ph: usize,
7020    pw: usize,
7021    device: &GpuDevice,
7022) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7023    use cudarc::driver::PushKernelArg;
7024
7025    let h_out = (h_in + 2 * ph - kh) / sh + 1;
7026    let w_out = (w_in + 2 * pw - kw) / sw + 1;
7027    let total = batch * channels * h_out * w_out;
7028
7029    let ctx = device.context();
7030    let stream = device.stream();
7031
7032    let f = match crate::module_cache::get_or_compile(
7033        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
7034    ) {
7035        Ok(f) => f,
7036        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
7037    };
7038
7039    let mut out = alloc_zeros_f32(total, device)?;
7040    let cfg = launch_cfg(total)?;
7041
7042    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
7043    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
7044    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
7045    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
7046    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
7047    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
7048    let total_u32 = total as u32;
7049
7050    unsafe {
7051        stream.launch_builder(&f)
7052            .arg(input.inner())
7053            .arg(out.inner_mut())
7054            .arg(&batch_u32).arg(&ch_u32)
7055            .arg(&h_in_u32).arg(&w_in_u32)
7056            .arg(&h_out_u32).arg(&w_out_u32)
7057            .arg(&kh_u32).arg(&kw_u32)
7058            .arg(&sh_u32).arg(&sw_u32)
7059            .arg(&ph_u32).arg(&pw_u32)
7060            .arg(&total_u32)
7061            .launch(cfg)?;
7062    }
7063
7064    Ok((out, [batch, channels, h_out, w_out]))
7065}
7066
7067/// Stub.
7068#[cfg(not(feature = "cuda"))]
7069#[allow(clippy::too_many_arguments)]
7070pub fn gpu_maxpool2d(
7071    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
7072    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
7073    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
7074    _device: &GpuDevice,
7075) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7076    Err(GpuError::NoCudaFeature)
7077}
7078
7079/// AvgPool2d forward on GPU. One thread per output element.
7080#[cfg(feature = "cuda")]
7081#[allow(clippy::too_many_arguments)]
7082pub fn gpu_avgpool2d(
7083    input: &CudaBuffer<f32>,
7084    batch: usize,
7085    channels: usize,
7086    h_in: usize,
7087    w_in: usize,
7088    kh: usize,
7089    kw: usize,
7090    sh: usize,
7091    sw: usize,
7092    ph: usize,
7093    pw: usize,
7094    device: &GpuDevice,
7095) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7096    use cudarc::driver::PushKernelArg;
7097
7098    let h_out = (h_in + 2 * ph - kh) / sh + 1;
7099    let w_out = (w_in + 2 * pw - kw) / sw + 1;
7100    let total = batch * channels * h_out * w_out;
7101
7102    let ctx = device.context();
7103    let stream = device.stream();
7104
7105    let f = match crate::module_cache::get_or_compile(
7106        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
7107    ) {
7108        Ok(f) => f,
7109        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
7110    };
7111
7112    let mut out = alloc_zeros_f32(total, device)?;
7113    let cfg = launch_cfg(total)?;
7114
7115    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
7116    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
7117    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
7118    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
7119    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
7120    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
7121    let total_u32 = total as u32;
7122
7123    unsafe {
7124        stream.launch_builder(&f)
7125            .arg(input.inner())
7126            .arg(out.inner_mut())
7127            .arg(&batch_u32).arg(&ch_u32)
7128            .arg(&h_in_u32).arg(&w_in_u32)
7129            .arg(&h_out_u32).arg(&w_out_u32)
7130            .arg(&kh_u32).arg(&kw_u32)
7131            .arg(&sh_u32).arg(&sw_u32)
7132            .arg(&ph_u32).arg(&pw_u32)
7133            .arg(&total_u32)
7134            .launch(cfg)?;
7135    }
7136
7137    Ok((out, [batch, channels, h_out, w_out]))
7138}
7139
7140/// Stub.
7141#[cfg(not(feature = "cuda"))]
7142#[allow(clippy::too_many_arguments)]
7143pub fn gpu_avgpool2d(
7144    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
7145    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
7146    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
7147    _device: &GpuDevice,
7148) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
7149    Err(GpuError::NoCudaFeature)
7150}
7151
7152// ---------------------------------------------------------------------------
7153// Public API -- BatchNorm2d
7154// ---------------------------------------------------------------------------
7155
7156/// BatchNorm2d forward on GPU (placeholder — kernel pass-1 indexing needs
7157/// refinement). Currently validates the kernel compiles and falls back to
7158/// returning an error so callers use the CPU path.
7159#[cfg(feature = "cuda")]
7160#[allow(clippy::too_many_arguments)]
7161pub fn gpu_batchnorm_forward(
7162    _input: &CudaBuffer<f32>,
7163    _weight: &CudaBuffer<f32>,
7164    _bias: &CudaBuffer<f32>,
7165    _running_mean: &mut CudaBuffer<f32>,
7166    _running_var: &mut CudaBuffer<f32>,
7167    _channels: usize,
7168    _spatial: usize,
7169    _eps: f32,
7170    _momentum: f32,
7171    _training: bool,
7172    device: &GpuDevice,
7173) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7174    // Validate the PTX compiles (catches syntax errors at first call).
7175    let ctx = device.context();
7176    let _f = crate::module_cache::get_or_compile(
7177        ctx,
7178        BATCHNORM_FORWARD_PTX,
7179        "batchnorm_forward_kernel",
7180        device.ordinal() as u32,
7181    );
7182    // Full implementation pending — pass-1 loop indexing needs refinement.
7183    Err(GpuError::ShapeMismatch {
7184        op: "batchnorm_forward",
7185        expected: vec![0],
7186        got: vec![1],
7187    })
7188}
7189
7190/// Stub.
7191#[cfg(not(feature = "cuda"))]
7192#[allow(clippy::too_many_arguments)]
7193pub fn gpu_batchnorm_forward(
7194    _input: &CudaBuffer<f32>,
7195    _weight: &CudaBuffer<f32>,
7196    _bias: &CudaBuffer<f32>,
7197    _running_mean: &mut CudaBuffer<f32>,
7198    _running_var: &mut CudaBuffer<f32>,
7199    _channels: usize,
7200    _spatial: usize,
7201    _eps: f32,
7202    _momentum: f32,
7203    _training: bool,
7204    _device: &GpuDevice,
7205) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7206    Err(GpuError::NoCudaFeature)
7207}
7208
7209// ---------------------------------------------------------------------------
7210// Public API -- LayerNorm
7211// ---------------------------------------------------------------------------
7212
7213/// Row-wise layer normalization on GPU.
7214///
7215/// `input`: `[rows * cols]`, `weight`/`bias`: `[cols]`.
7216/// Output: normalized and affine-transformed `[rows * cols]`.
7217#[cfg(feature = "cuda")]
7218pub fn gpu_layernorm(
7219    input: &CudaBuffer<f32>,
7220    weight: &CudaBuffer<f32>,
7221    bias: &CudaBuffer<f32>,
7222    rows: usize,
7223    cols: usize,
7224    eps: f32,
7225    device: &GpuDevice,
7226) -> GpuResult<CudaBuffer<f32>> {
7227    use cudarc::driver::PushKernelArg;
7228
7229    validate_unary(input, device)?;
7230
7231    let ctx = device.context();
7232    let stream = device.stream();
7233
7234    let f = match crate::module_cache::get_or_compile(
7235        ctx,
7236        LAYERNORM_PTX,
7237        "layernorm_kernel",
7238        device.ordinal() as u32,
7239    ) {
7240        Ok(f) => f,
7241        Err(e) => {
7242            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
7243            std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).ok();
7244            eprintln!(
7245                "ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
7246                LAYERNORM_PTX.len()
7247            );
7248            let h_in = gpu_to_cpu(input, device)?;
7249            let h_w = gpu_to_cpu(weight, device)?;
7250            let h_b = gpu_to_cpu(bias, device)?;
7251            let mut out = vec![0.0f32; rows * cols];
7252            for r in 0..rows {
7253                let base = r * cols;
7254                let slice = &h_in[base..base + cols];
7255                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
7256                let var: f32 =
7257                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
7258                let inv_std = 1.0 / (var + eps).sqrt();
7259                for c in 0..cols {
7260                    let normed = (slice[c] - mean) * inv_std;
7261                    out[base + c] = h_w[c] * normed + h_b[c];
7262                }
7263            }
7264            return cpu_to_gpu(&out, device);
7265        }
7266    };
7267
7268    let mut out = alloc_zeros_f32(rows * cols, device)?;
7269    let rows_u32 = rows as u32;
7270    let cols_u32 = cols as u32;
7271
7272    let cfg = LaunchConfig {
7273        grid_dim: ((rows as u32).max(1), 1, 1),
7274        block_dim: (256, 1, 1),
7275        shared_mem_bytes: 256 * 4,
7276    };
7277
7278    unsafe {
7279        stream
7280            .launch_builder(&f)
7281            .arg(input.inner())
7282            .arg(out.inner_mut())
7283            .arg(weight.inner())
7284            .arg(bias.inner())
7285            .arg(&rows_u32)
7286            .arg(&cols_u32)
7287            .arg(&eps)
7288            .launch(cfg)?;
7289    }
7290
7291    Ok(out)
7292}
7293
7294// ---------------------------------------------------------------------------
7295// Public API -- LayerNorm backward
7296// ---------------------------------------------------------------------------
7297
7298/// LayerNorm backward pass on GPU.
7299///
7300/// Computes grad_input, grad_weight, and grad_bias entirely on GPU.
7301/// One block per batch element (row), 256 threads per block.
7302/// grad_weight and grad_bias are accumulated across batches via atomicAdd.
7303///
7304/// `input`: `[rows * cols]`, `grad_output`: `[rows * cols]`, `weight`: `[cols]`.
7305/// Returns: `(grad_input [rows * cols], grad_weight [cols], grad_bias [cols])`.
7306#[cfg(feature = "cuda")]
7307pub fn gpu_layernorm_backward(
7308    input: &CudaBuffer<f32>,
7309    grad_output: &CudaBuffer<f32>,
7310    weight: &CudaBuffer<f32>,
7311    rows: usize,
7312    cols: usize,
7313    eps: f32,
7314    device: &GpuDevice,
7315) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7316    use cudarc::driver::PushKernelArg;
7317
7318    validate_unary(input, device)?;
7319
7320    let ctx = device.context();
7321    let stream = device.stream();
7322
7323    let f = match crate::module_cache::get_or_compile(
7324        ctx,
7325        LAYERNORM_BACKWARD_PTX,
7326        "layernorm_backward_kernel",
7327        device.ordinal() as u32,
7328    ) {
7329        Ok(f) => f,
7330        Err(_) => {
7331            // CPU fallback
7332            let h_in = gpu_to_cpu(input, device)?;
7333            let h_go = gpu_to_cpu(grad_output, device)?;
7334            let h_w = gpu_to_cpu(weight, device)?;
7335            let mut grad_input = vec![0.0f32; rows * cols];
7336            let mut grad_weight = vec![0.0f32; cols];
7337            let mut grad_bias = vec![0.0f32; cols];
7338            let n_f = cols as f32;
7339            for r in 0..rows {
7340                let base = r * cols;
7341                let x_slice = &h_in[base..base + cols];
7342                let go_slice = &h_go[base..base + cols];
7343                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
7344                let var: f32 = x_slice
7345                    .iter()
7346                    .map(|&x| (x - mean) * (x - mean))
7347                    .sum::<f32>()
7348                    / n_f;
7349                let inv_std = 1.0 / (var + eps).sqrt();
7350                let mut sum1 = 0.0f32;
7351                let mut sum2 = 0.0f32;
7352                for c in 0..cols {
7353                    let x_hat = (x_slice[c] - mean) * inv_std;
7354                    let dl = go_slice[c] * h_w[c];
7355                    sum1 += dl;
7356                    sum2 += dl * x_hat;
7357                    grad_weight[c] += go_slice[c] * x_hat;
7358                    grad_bias[c] += go_slice[c];
7359                }
7360                let m1 = sum1 / n_f;
7361                let m2 = sum2 / n_f;
7362                for c in 0..cols {
7363                    let x_hat = (x_slice[c] - mean) * inv_std;
7364                    let dl = go_slice[c] * h_w[c];
7365                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
7366                }
7367            }
7368            let gi = cpu_to_gpu(&grad_input, device)?;
7369            let gw = cpu_to_gpu(&grad_weight, device)?;
7370            let gb = cpu_to_gpu(&grad_bias, device)?;
7371            return Ok((gi, gw, gb));
7372        }
7373    };
7374
7375    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
7376    let mut grad_w = alloc_zeros_f32(cols, device)?;
7377    let mut grad_b = alloc_zeros_f32(cols, device)?;
7378    let rows_u32 = rows as u32;
7379    let cols_u32 = cols as u32;
7380
7381    // One block per row, 256 threads per block.
7382    let cfg = LaunchConfig {
7383        grid_dim: ((rows as u32).max(1), 1, 1),
7384        block_dim: (256, 1, 1),
7385        shared_mem_bytes: 256 * 4,
7386    };
7387
7388    unsafe {
7389        stream
7390            .launch_builder(&f)
7391            .arg(input.inner())
7392            .arg(grad_output.inner())
7393            .arg(weight.inner())
7394            .arg(grad_in.inner_mut())
7395            .arg(grad_w.inner_mut())
7396            .arg(grad_b.inner_mut())
7397            .arg(&rows_u32)
7398            .arg(&cols_u32)
7399            .arg(&eps)
7400            .launch(cfg)?;
7401    }
7402
7403    Ok((grad_in, grad_w, grad_b))
7404}
7405
7406/// Stub -- always returns [`GpuError::NoCudaFeature`].
7407#[cfg(not(feature = "cuda"))]
7408pub fn gpu_layernorm_backward(
7409    _input: &CudaBuffer<f32>,
7410    _grad_output: &CudaBuffer<f32>,
7411    _weight: &CudaBuffer<f32>,
7412    _rows: usize,
7413    _cols: usize,
7414    _eps: f32,
7415    _device: &GpuDevice,
7416) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
7417    Err(GpuError::NoCudaFeature)
7418}
7419
7420// ===========================================================================
7421// _into variants — write to pre-allocated output buffers (zero allocation)
7422//
7423// These are used for CUDA graph capture, where all buffer addresses must be
7424// fixed at capture time. The PTX kernels are identical — only the Rust
7425// wrapper skips allocation.
7426// ===========================================================================
7427
7428/// Elementwise add into pre-allocated output: `out[i] = a[i] + b[i]`.
7429#[cfg(feature = "cuda")]
7430pub fn gpu_add_into(
7431    a: &CudaBuffer<f32>,
7432    b: &CudaBuffer<f32>,
7433    out: &mut CudaBuffer<f32>,
7434    device: &GpuDevice,
7435) -> GpuResult<()> {
7436    validate_binary(a, b, device)?;
7437    if out.len() < a.len() {
7438        return Err(GpuError::ShapeMismatch {
7439            op: "add_into",
7440            expected: vec![a.len()],
7441            got: vec![out.len()],
7442        });
7443    }
7444    if try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
7445        return Ok(());
7446    }
7447    Err(GpuError::PtxCompileFailed {
7448        kernel: "add_kernel",
7449    })
7450}
7451
7452/// Elementwise mul into pre-allocated output: `out[i] = a[i] * b[i]`.
7453#[cfg(feature = "cuda")]
7454pub fn gpu_mul_into(
7455    a: &CudaBuffer<f32>,
7456    b: &CudaBuffer<f32>,
7457    out: &mut CudaBuffer<f32>,
7458    device: &GpuDevice,
7459) -> GpuResult<()> {
7460    validate_binary(a, b, device)?;
7461    if out.len() < a.len() {
7462        return Err(GpuError::ShapeMismatch {
7463            op: "mul_into",
7464            expected: vec![a.len()],
7465            got: vec![out.len()],
7466        });
7467    }
7468    if try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
7469        return Ok(());
7470    }
7471    Err(GpuError::PtxCompileFailed {
7472        kernel: "mul_kernel",
7473    })
7474}
7475
7476/// Scalar multiply into pre-allocated output: `out[i] = a[i] * scalar`.
7477#[cfg(feature = "cuda")]
7478pub fn gpu_scale_into(
7479    a: &CudaBuffer<f32>,
7480    scalar: f32,
7481    out: &mut CudaBuffer<f32>,
7482    device: &GpuDevice,
7483) -> GpuResult<()> {
7484    use cudarc::driver::PushKernelArg;
7485    validate_unary(a, device)?;
7486    let n = a.len();
7487    let ctx = device.context();
7488    let stream = device.stream();
7489    let f = crate::module_cache::get_or_compile(
7490        ctx,
7491        SCALE_PTX,
7492        "scale_kernel",
7493        device.ordinal() as u32,
7494    )
7495    .map_err(|_| GpuError::PtxCompileFailed {
7496        kernel: "scale_kernel",
7497    })?;
7498    let cfg = launch_cfg(n)?;
7499    let n_u32 = n as u32;
7500    unsafe {
7501        stream
7502            .launch_builder(&f)
7503            .arg(a.inner())
7504            .arg(out.inner_mut())
7505            .arg(&scalar)
7506            .arg(&n_u32)
7507            .launch(cfg)?;
7508    }
7509    Ok(())
7510}
7511
7512/// Check whether a GPU buffer contains any inf or NaN values.
7513///
7514/// Downloads the buffer contents to the host and scans for non-finite
7515/// values. This is correct for any buffer size and requires no custom
7516/// reduction kernel.
7517///
7518/// For a future optimization, a dedicated GPU reduction kernel could be
7519/// used to produce a single boolean flag on device, avoiding the full
7520/// download. The current approach is already much faster than the old
7521/// per-element CPU loop in `unscale_()` because the scaling itself
7522/// runs on GPU — only the inf/NaN check touches the host.
7523///
7524/// # Errors
7525///
7526/// - [`GpuError::DeviceMismatch`] if `a` and `device` refer to different CUDA devices.
7527/// - [`GpuError::Driver`] on CUDA runtime errors.
7528#[cfg(feature = "cuda")]
7529pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
7530    let n = a.len();
7531    if n == 0 {
7532        return Ok(false);
7533    }
7534
7535    validate_unary(a, device)?;
7536
7537    let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
7538    Ok(host.iter().any(|v| !v.is_finite()))
7539}
7540
7541/// Stub -- always returns [`GpuError::NoCudaFeature`].
7542#[cfg(not(feature = "cuda"))]
7543pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
7544    Err(GpuError::NoCudaFeature)
7545}
7546
7547/// GELU into pre-allocated output.
7548#[cfg(feature = "cuda")]
7549pub fn gpu_gelu_into(
7550    a: &CudaBuffer<f32>,
7551    out: &mut CudaBuffer<f32>,
7552    device: &GpuDevice,
7553) -> GpuResult<()> {
7554    validate_unary(a, device)?;
7555    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
7556        return Ok(());
7557    }
7558    Err(GpuError::PtxCompileFailed {
7559        kernel: "gelu_kernel",
7560    })
7561}
7562
7563/// Embedding lookup into pre-allocated output.
7564#[cfg(feature = "cuda")]
7565pub fn gpu_embed_lookup_into(
7566    idx: &CudaBuffer<f32>,
7567    weight: &CudaBuffer<f32>,
7568    d: usize,
7569    out: &mut CudaBuffer<f32>,
7570    device: &GpuDevice,
7571) -> GpuResult<()> {
7572    use cudarc::driver::PushKernelArg;
7573    let ctx = device.context();
7574    let stream = device.stream();
7575    let f = crate::module_cache::get_or_compile(
7576        ctx,
7577        EMBED_LOOKUP_PTX,
7578        "embed_lookup_kernel",
7579        device.ordinal() as u32,
7580    )
7581    .map_err(|_| GpuError::PtxCompileFailed {
7582        kernel: "embed_lookup_kernel",
7583    })?;
7584    let cfg = launch_cfg(d)?;
7585    let d_u32 = d as u32;
7586    unsafe {
7587        stream
7588            .launch_builder(&f)
7589            .arg(idx.inner())
7590            .arg(weight.inner())
7591            .arg(out.inner_mut())
7592            .arg(&d_u32)
7593            .launch(cfg)?;
7594    }
7595    Ok(())
7596}
7597
7598// ---------------------------------------------------------------------------
7599// Public API -- Batch embedding lookup (GPU-native)
7600// ---------------------------------------------------------------------------
7601
7602/// GPU batch embedding lookup: given `indices` (N f32 values on GPU) and
7603/// `weight` `[V, D]`, gather N rows to produce output `[N, D]`.
7604/// Entire operation stays on GPU -- no CPU roundtrip.
7605#[cfg(feature = "cuda")]
7606pub fn gpu_embed_lookup_batch(
7607    indices: &CudaBuffer<f32>,
7608    weight: &CudaBuffer<f32>,
7609    n: usize,
7610    d: usize,
7611    device: &GpuDevice,
7612) -> GpuResult<CudaBuffer<f32>> {
7613    use cudarc::driver::PushKernelArg;
7614
7615    let total = n * d;
7616    if total == 0 {
7617        return alloc_zeros_f32(0, device);
7618    }
7619
7620    let ctx = device.context();
7621    let stream = device.stream();
7622
7623    let f = match crate::module_cache::get_or_compile(
7624        ctx,
7625        EMBED_LOOKUP_BATCH_PTX,
7626        "embed_lookup_batch_kernel",
7627        device.ordinal() as u32,
7628    ) {
7629        Ok(f) => f,
7630        Err(_) => {
7631            // CPU fallback.
7632            let idx_host = gpu_to_cpu(indices, device)?;
7633            let weight_host = gpu_to_cpu(weight, device)?;
7634            let mut out = Vec::with_capacity(total);
7635            for &idx_f in &idx_host {
7636                let row = idx_f as usize;
7637                let start = row * d;
7638                out.extend_from_slice(&weight_host[start..start + d]);
7639            }
7640            return cpu_to_gpu(&out, device);
7641        }
7642    };
7643
7644    let mut out = alloc_zeros_f32(total, device)?;
7645    let cfg = launch_cfg(total)?;
7646    let d_u32 = d as u32;
7647    let total_u32 = total as u32;
7648
7649    unsafe {
7650        stream
7651            .launch_builder(&f)
7652            .arg(indices.inner())
7653            .arg(weight.inner())
7654            .arg(out.inner_mut())
7655            .arg(&d_u32)
7656            .arg(&total_u32)
7657            .launch(cfg)?;
7658    }
7659
7660    Ok(out)
7661}
7662
7663// ---------------------------------------------------------------------------
7664// Public API -- Scatter-add rows (for embedding backward, GPU-native)
7665// ---------------------------------------------------------------------------
7666
7667/// GPU scatter-add rows: given `grad_output` `[N, D]` and `indices` `[N]` (f32),
7668/// atomically accumulate into `grad_weight` `[V, D]` (pre-zeroed):
7669///   `grad_weight[indices[i], :] += grad_output[i, :]`
7670///
7671/// Duplicate indices accumulate correctly via atomic adds.
7672#[cfg(feature = "cuda")]
7673pub fn gpu_scatter_add_rows(
7674    grad_output: &CudaBuffer<f32>,
7675    indices: &CudaBuffer<f32>,
7676    num_embeddings: usize,
7677    d: usize,
7678    device: &GpuDevice,
7679) -> GpuResult<CudaBuffer<f32>> {
7680    use cudarc::driver::PushKernelArg;
7681
7682    let n = indices.len();
7683    let total = n * d;
7684
7685    if total == 0 {
7686        return alloc_zeros_f32(num_embeddings * d, device);
7687    }
7688
7689    let ctx = device.context();
7690    let stream = device.stream();
7691
7692    let f = match crate::module_cache::get_or_compile(
7693        ctx,
7694        SCATTER_ADD_ROWS_PTX,
7695        "scatter_add_rows_kernel",
7696        device.ordinal() as u32,
7697    ) {
7698        Ok(f) => f,
7699        Err(_) => {
7700            // CPU fallback.
7701            let go_host = gpu_to_cpu(grad_output, device)?;
7702            let idx_host = gpu_to_cpu(indices, device)?;
7703            let mut result = vec![0.0f32; num_embeddings * d];
7704            for (i, &idx_f) in idx_host.iter().enumerate() {
7705                let row = idx_f as usize;
7706                for j in 0..d {
7707                    result[row * d + j] += go_host[i * d + j];
7708                }
7709            }
7710            return cpu_to_gpu(&result, device);
7711        }
7712    };
7713
7714    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
7715    let cfg = launch_cfg(total)?;
7716    let d_u32 = d as u32;
7717    let total_u32 = total as u32;
7718
7719    unsafe {
7720        stream
7721            .launch_builder(&f)
7722            .arg(grad_output.inner())
7723            .arg(indices.inner())
7724            .arg(out.inner_mut())
7725            .arg(&d_u32)
7726            .arg(&total_u32)
7727            .launch(cfg)?;
7728    }
7729
7730    Ok(out)
7731}
7732
7733/// 2D transpose into pre-allocated output.
7734#[cfg(feature = "cuda")]
7735pub fn gpu_transpose_2d_into(
7736    a: &CudaBuffer<f32>,
7737    m: usize,
7738    n: usize,
7739    out: &mut CudaBuffer<f32>,
7740    device: &GpuDevice,
7741) -> GpuResult<()> {
7742    use cudarc::driver::PushKernelArg;
7743    let total = m * n;
7744    let ctx = device.context();
7745    let stream = device.stream();
7746    let f = crate::module_cache::get_or_compile(
7747        ctx,
7748        TRANSPOSE_2D_PTX,
7749        "transpose_2d_kernel",
7750        device.ordinal() as u32,
7751    )
7752    .map_err(|_| GpuError::PtxCompileFailed {
7753        kernel: "transpose_2d_kernel",
7754    })?;
7755    let cfg = launch_cfg(total)?;
7756    let m_u32 = m as u32;
7757    let n_u32 = n as u32;
7758    let total_u32 = total as u32;
7759    unsafe {
7760        stream
7761            .launch_builder(&f)
7762            .arg(a.inner())
7763            .arg(out.inner_mut())
7764            .arg(&m_u32)
7765            .arg(&n_u32)
7766            .arg(&total_u32)
7767            .launch(cfg)?;
7768    }
7769    Ok(())
7770}
7771
7772/// Permute (0,2,1,3) into pre-allocated output.
7773#[cfg(feature = "cuda")]
7774pub fn gpu_permute_0213_into(
7775    a: &CudaBuffer<f32>,
7776    d0: usize,
7777    d1: usize,
7778    d2: usize,
7779    d3: usize,
7780    out: &mut CudaBuffer<f32>,
7781    device: &GpuDevice,
7782) -> GpuResult<()> {
7783    use cudarc::driver::PushKernelArg;
7784    let total = d0 * d1 * d2 * d3;
7785    let ctx = device.context();
7786    let stream = device.stream();
7787    let f = crate::module_cache::get_or_compile(
7788        ctx,
7789        PERMUTE_0213_PTX,
7790        "permute_0213_kernel",
7791        device.ordinal() as u32,
7792    )
7793    .map_err(|_| GpuError::PtxCompileFailed {
7794        kernel: "permute_0213_kernel",
7795    })?;
7796    let cfg = launch_cfg(total)?;
7797    let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);
7798    unsafe {
7799        stream
7800            .launch_builder(&f)
7801            .arg(a.inner())
7802            .arg(out.inner_mut())
7803            .arg(&d0u)
7804            .arg(&d1u)
7805            .arg(&d2u)
7806            .arg(&d3u)
7807            .arg(&tu)
7808            .launch(cfg)?;
7809    }
7810    Ok(())
7811}
7812
7813/// Softmax into pre-allocated output (row-wise).
7814#[cfg(feature = "cuda")]
7815pub fn gpu_softmax_into(
7816    a: &CudaBuffer<f32>,
7817    rows: usize,
7818    cols: usize,
7819    out: &mut CudaBuffer<f32>,
7820    device: &GpuDevice,
7821) -> GpuResult<()> {
7822    use cudarc::driver::PushKernelArg;
7823    let ctx = device.context();
7824    let stream = device.stream();
7825    let f = crate::module_cache::get_or_compile(
7826        ctx,
7827        SOFTMAX_PTX,
7828        "softmax_kernel",
7829        device.ordinal() as u32,
7830    )
7831    .map_err(|_| GpuError::PtxCompileFailed {
7832        kernel: "softmax_kernel",
7833    })?;
7834    let block_size = 256u32;
7835    let grid_size = rows as u32;
7836    let cfg = LaunchConfig {
7837        grid_dim: (grid_size, 1, 1),
7838        block_dim: (block_size, 1, 1),
7839        shared_mem_bytes: (cols as u32) * 4,
7840    };
7841    let rows_u32 = rows as u32;
7842    let cols_u32 = cols as u32;
7843    unsafe {
7844        stream
7845            .launch_builder(&f)
7846            .arg(a.inner())
7847            .arg(out.inner_mut())
7848            .arg(&rows_u32)
7849            .arg(&cols_u32)
7850            .launch(cfg)?;
7851    }
7852    Ok(())
7853}
7854
7855/// LayerNorm into pre-allocated output.
7856#[cfg(feature = "cuda")]
7857#[allow(clippy::too_many_arguments)]
7858pub fn gpu_layernorm_into(
7859    input: &CudaBuffer<f32>,
7860    weight: &CudaBuffer<f32>,
7861    bias: &CudaBuffer<f32>,
7862    rows: usize,
7863    cols: usize,
7864    eps: f32,
7865    out: &mut CudaBuffer<f32>,
7866    device: &GpuDevice,
7867) -> GpuResult<()> {
7868    use cudarc::driver::PushKernelArg;
7869    let ctx = device.context();
7870    let stream = device.stream();
7871    let f = crate::module_cache::get_or_compile(
7872        ctx,
7873        LAYERNORM_PTX,
7874        "layernorm_kernel",
7875        device.ordinal() as u32,
7876    )
7877    .map_err(|_| GpuError::PtxCompileFailed {
7878        kernel: "layernorm_kernel",
7879    })?;
7880    let block_size = 256u32;
7881    let grid_size = rows as u32;
7882    let cfg = LaunchConfig {
7883        grid_dim: (grid_size, 1, 1),
7884        block_dim: (block_size, 1, 1),
7885        shared_mem_bytes: (cols as u32) * 4,
7886    };
7887    let rows_u32 = rows as u32;
7888    let cols_u32 = cols as u32;
7889    unsafe {
7890        stream
7891            .launch_builder(&f)
7892            .arg(input.inner())
7893            .arg(out.inner_mut())
7894            .arg(weight.inner())
7895            .arg(bias.inner())
7896            .arg(&rows_u32)
7897            .arg(&cols_u32)
7898            .arg(&eps)
7899            .launch(cfg)?;
7900    }
7901    Ok(())
7902}
7903
7904/// Slice read into pre-allocated output: read first `len` rows from
7905/// `[n_batch, max_len, d]` into out `[n_batch, len, d]`.
7906#[cfg(feature = "cuda")]
7907pub fn gpu_slice_read_into(
7908    src: &CudaBuffer<f32>,
7909    n_batch: usize,
7910    d: usize,
7911    len: usize,
7912    max_len: usize,
7913    out: &mut CudaBuffer<f32>,
7914    device: &GpuDevice,
7915) -> GpuResult<()> {
7916    use cudarc::driver::PushKernelArg;
7917    let total = n_batch * len * d;
7918    let ctx = device.context();
7919    let stream = device.stream();
7920    let f = crate::module_cache::get_or_compile(
7921        ctx,
7922        SLICE_READ_PTX,
7923        "slice_read_kernel",
7924        device.ordinal() as u32,
7925    )
7926    .map_err(|_| GpuError::PtxCompileFailed {
7927        kernel: "slice_read_kernel",
7928    })?;
7929    let cfg = launch_cfg(total)?;
7930    let total_u32 = total as u32;
7931    let d_u32 = d as u32;
7932    let len_u32 = len as u32;
7933    let max_len_u32 = max_len as u32;
7934    unsafe {
7935        stream
7936            .launch_builder(&f)
7937            .arg(src.inner())
7938            .arg(out.inner_mut())
7939            .arg(&total_u32)
7940            .arg(&d_u32)
7941            .arg(&len_u32)
7942            .arg(&max_len_u32)
7943            .launch(cfg)?;
7944    }
7945    Ok(())
7946}
7947
7948/// Small matmul (PTX kernel) into pre-allocated output.
7949#[cfg(feature = "cuda")]
7950pub fn gpu_small_matmul_into(
7951    a: &CudaBuffer<f32>,
7952    b: &CudaBuffer<f32>,
7953    m: usize,
7954    k: usize,
7955    n: usize,
7956    out: &mut CudaBuffer<f32>,
7957    device: &GpuDevice,
7958) -> GpuResult<()> {
7959    use cudarc::driver::PushKernelArg;
7960    let total = m * n;
7961    let ctx = device.context();
7962    let stream = device.stream();
7963    let f = crate::module_cache::get_or_compile(
7964        ctx,
7965        SMALL_MATMUL_PTX,
7966        "small_matmul_kernel",
7967        device.ordinal() as u32,
7968    )
7969    .map_err(|_| GpuError::PtxCompileFailed {
7970        kernel: "small_matmul_kernel",
7971    })?;
7972    let cfg = launch_cfg(total)?;
7973    let (m_u32, k_u32, n_u32, total_u32) = (m as u32, k as u32, n as u32, total as u32);
7974    unsafe {
7975        stream
7976            .launch_builder(&f)
7977            .arg(a.inner())
7978            .arg(b.inner())
7979            .arg(out.inner_mut())
7980            .arg(&m_u32)
7981            .arg(&k_u32)
7982            .arg(&n_u32)
7983            .arg(&total_u32)
7984            .launch(cfg)?;
7985    }
7986    Ok(())
7987}
7988
7989// ===========================================================================
7990// Indirect-parameter kernels for CUDA graph capture
7991// ===========================================================================
7992
7993/// Slice write with position read from device memory (for CUDA graph capture).
7994/// Writes `src [n_batch, d]` into row `*pos_ptr` of `dst [n_batch, max_len, d]`.
7995#[cfg(feature = "cuda")]
7996pub fn gpu_slice_write_indirect(
7997    src: &CudaBuffer<f32>,
7998    dst: &mut CudaBuffer<f32>,
7999    n_batch: usize,
8000    d: usize,
8001    max_len: usize,
8002    pos_ptr: &cudarc::driver::CudaSlice<u32>,
8003    device: &GpuDevice,
8004) -> GpuResult<()> {
8005    use cudarc::driver::PushKernelArg;
8006    let total = n_batch * d;
8007    let ctx = device.context();
8008    let stream = device.stream();
8009    let f = crate::module_cache::get_or_compile(
8010        ctx,
8011        SLICE_WRITE_INDIRECT_PTX,
8012        "slice_write_indirect_kernel",
8013        device.ordinal() as u32,
8014    )
8015    .map_err(|_| GpuError::PtxCompileFailed {
8016        kernel: "slice_write_indirect_kernel",
8017    })?;
8018    let cfg = launch_cfg(total)?;
8019    let n_u32 = total as u32;
8020    let d_u32 = d as u32;
8021    let max_len_u32 = max_len as u32;
8022    unsafe {
8023        stream
8024            .launch_builder(&f)
8025            .arg(src.inner())
8026            .arg(dst.inner_mut())
8027            .arg(&n_u32)
8028            .arg(&d_u32)
8029            .arg(&max_len_u32)
8030            .arg(pos_ptr)
8031            .launch(cfg)?;
8032    }
8033    Ok(())
8034}
8035
8036/// Build causal attention mask with total_len read from device memory.
8037/// Writes `out[h, col] = 0.0` if `col < *total_len_ptr`, else `-1e9`.
8038/// Output shape: `[n_head, max_pos]` (n_head rows, each max_pos wide).
8039#[cfg(feature = "cuda")]
8040pub fn gpu_causal_mask_indirect(
8041    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
8042    n_head: usize,
8043    max_pos: usize,
8044    out: &mut CudaBuffer<f32>,
8045    device: &GpuDevice,
8046) -> GpuResult<()> {
8047    use cudarc::driver::PushKernelArg;
8048    let total = n_head * max_pos;
8049    let ctx = device.context();
8050    let stream = device.stream();
8051    let f = crate::module_cache::get_or_compile(
8052        ctx,
8053        CAUSAL_MASK_INDIRECT_PTX,
8054        "causal_mask_indirect_kernel",
8055        device.ordinal() as u32,
8056    )
8057    .map_err(|_| GpuError::PtxCompileFailed {
8058        kernel: "causal_mask_indirect_kernel",
8059    })?;
8060    let cfg = launch_cfg(total)?;
8061    let max_pos_u32 = max_pos as u32;
8062    let total_u32 = total as u32;
8063    unsafe {
8064        stream
8065            .launch_builder(&f)
8066            .arg(total_len_ptr)
8067            .arg(out.inner_mut())
8068            .arg(&max_pos_u32)
8069            .arg(&total_u32)
8070            .launch(cfg)?;
8071    }
8072    Ok(())
8073}
8074
8075// ===========================================================================
8076// Pre-compilation of all decode-path PTX modules
8077// ===========================================================================
8078
8079/// Pre-compile all PTX kernels used by the decode pass into the module cache.
8080/// Call this before CUDA graph capture to ensure no `cuModuleLoadData` calls
8081/// occur during capture (which is not a capturable operation).
8082#[cfg(feature = "cuda")]
8083pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
8084    let ctx = device.context();
8085    ctx.bind_to_thread()?;
8086    let ord = device.ordinal() as u32;
8087    let compile = |ptx: &'static str, name: &'static str| -> GpuResult<()> {
8088        crate::module_cache::get_or_compile(ctx, ptx, name, ord)
8089            .map(|_| ())
8090            .map_err(GpuError::Driver)
8091    };
8092    compile(ADD_PTX, "add_kernel")?;
8093    compile(MUL_PTX, "mul_kernel")?;
8094    compile(SCALE_PTX, "scale_kernel")?;
8095    compile(GELU_PTX, "gelu_kernel")?;
8096    compile(SOFTMAX_PTX, "softmax_kernel")?;
8097    compile(LAYERNORM_PTX, "layernorm_kernel")?;
8098    compile(PERMUTE_0213_PTX, "permute_0213_kernel")?;
8099    compile(EMBED_LOOKUP_PTX, "embed_lookup_kernel")?;
8100    compile(EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel")?;
8101    compile(SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel")?;
8102    compile(SMALL_MATMUL_PTX, "small_matmul_kernel")?;
8103    compile(SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel")?;
8104    compile(CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel")?;
8105    compile(SLICE_READ_PTX, "slice_read_kernel")?;
8106    compile(RELU_BACKWARD_PTX, "relu_backward_kernel")?;
8107    compile(GELU_BACKWARD_PTX, "gelu_backward_kernel")?;
8108    Ok(())
8109}
8110
8111/// Stub — no-op without cuda.
8112#[cfg(not(feature = "cuda"))]
8113pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
8114    Err(GpuError::NoCudaFeature)
8115}
8116
8117// ---------------------------------------------------------------------------
8118// Stubs when `cuda` feature is disabled
8119// ---------------------------------------------------------------------------
8120
8121/// Stub -- always returns [`GpuError::NoCudaFeature`].
8122#[cfg(not(feature = "cuda"))]
8123pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8124    Err(GpuError::NoCudaFeature)
8125}
8126
8127/// Stub -- always returns [`GpuError::NoCudaFeature`].
8128#[cfg(not(feature = "cuda"))]
8129pub fn gpu_div(
8130    _a: &CudaBuffer<f32>,
8131    _b: &CudaBuffer<f32>,
8132    _device: &GpuDevice,
8133) -> GpuResult<CudaBuffer<f32>> {
8134    Err(GpuError::NoCudaFeature)
8135}
8136
8137/// Stub -- always returns [`GpuError::NoCudaFeature`].
8138#[cfg(not(feature = "cuda"))]
8139pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8140    Err(GpuError::NoCudaFeature)
8141}
8142
8143/// Stub -- always returns [`GpuError::NoCudaFeature`].
8144#[cfg(not(feature = "cuda"))]
8145pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8146    Err(GpuError::NoCudaFeature)
8147}
8148
8149/// Stub -- always returns [`GpuError::NoCudaFeature`].
8150#[cfg(not(feature = "cuda"))]
8151pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8152    Err(GpuError::NoCudaFeature)
8153}
8154
8155/// Stub -- always returns [`GpuError::NoCudaFeature`].
8156#[cfg(not(feature = "cuda"))]
8157pub fn gpu_pow(
8158    _a: &CudaBuffer<f32>,
8159    _exponent: f32,
8160    _device: &GpuDevice,
8161) -> GpuResult<CudaBuffer<f32>> {
8162    Err(GpuError::NoCudaFeature)
8163}
8164
8165/// Stub -- always returns [`GpuError::NoCudaFeature`].
8166#[cfg(not(feature = "cuda"))]
8167pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8168    Err(GpuError::NoCudaFeature)
8169}
8170
8171/// Stub -- always returns [`GpuError::NoCudaFeature`].
8172#[cfg(not(feature = "cuda"))]
8173pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8174    Err(GpuError::NoCudaFeature)
8175}
8176
8177/// Stub -- always returns [`GpuError::NoCudaFeature`].
8178#[cfg(not(feature = "cuda"))]
8179pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8180    Err(GpuError::NoCudaFeature)
8181}
8182
8183/// Stub -- always returns [`GpuError::NoCudaFeature`].
8184#[cfg(not(feature = "cuda"))]
8185pub fn gpu_layernorm(
8186    _input: &CudaBuffer<f32>,
8187    _weight: &CudaBuffer<f32>,
8188    _bias: &CudaBuffer<f32>,
8189    _rows: usize,
8190    _cols: usize,
8191    _eps: f32,
8192    _device: &GpuDevice,
8193) -> GpuResult<CudaBuffer<f32>> {
8194    Err(GpuError::NoCudaFeature)
8195}
8196
8197/// Stub -- always returns [`GpuError::NoCudaFeature`].
8198#[cfg(not(feature = "cuda"))]
8199pub fn gpu_transpose_2d(
8200    _input: &CudaBuffer<f32>,
8201    _m: usize,
8202    _n: usize,
8203    _device: &GpuDevice,
8204) -> GpuResult<CudaBuffer<f32>> {
8205    Err(GpuError::NoCudaFeature)
8206}
8207
8208/// Stub -- always returns [`GpuError::NoCudaFeature`].
8209#[cfg(not(feature = "cuda"))]
8210pub fn gpu_add(
8211    _a: &CudaBuffer<f32>,
8212    _b: &CudaBuffer<f32>,
8213    _device: &GpuDevice,
8214) -> GpuResult<CudaBuffer<f32>> {
8215    Err(GpuError::NoCudaFeature)
8216}
8217
8218/// Stub -- always returns [`GpuError::NoCudaFeature`].
8219#[cfg(not(feature = "cuda"))]
8220pub fn gpu_sub(
8221    _a: &CudaBuffer<f32>,
8222    _b: &CudaBuffer<f32>,
8223    _device: &GpuDevice,
8224) -> GpuResult<CudaBuffer<f32>> {
8225    Err(GpuError::NoCudaFeature)
8226}
8227
8228/// Stub -- always returns [`GpuError::NoCudaFeature`].
8229#[cfg(not(feature = "cuda"))]
8230pub fn gpu_mul(
8231    _a: &CudaBuffer<f32>,
8232    _b: &CudaBuffer<f32>,
8233    _device: &GpuDevice,
8234) -> GpuResult<CudaBuffer<f32>> {
8235    Err(GpuError::NoCudaFeature)
8236}
8237
8238/// Stub -- always returns [`GpuError::NoCudaFeature`].
8239#[cfg(not(feature = "cuda"))]
8240pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8241    Err(GpuError::NoCudaFeature)
8242}
8243
8244/// Stub -- always returns [`GpuError::NoCudaFeature`].
8245#[cfg(not(feature = "cuda"))]
8246pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8247    Err(GpuError::NoCudaFeature)
8248}
8249
8250/// Stub -- always returns [`GpuError::NoCudaFeature`].
8251#[cfg(not(feature = "cuda"))]
8252pub fn gpu_scale(
8253    _a: &CudaBuffer<f32>,
8254    _scalar: f32,
8255    _device: &GpuDevice,
8256) -> GpuResult<CudaBuffer<f32>> {
8257    Err(GpuError::NoCudaFeature)
8258}
8259
8260/// Stub -- always returns [`GpuError::NoCudaFeature`].
8261#[cfg(not(feature = "cuda"))]
8262pub fn gpu_broadcast_add(
8263    _a: &CudaBuffer<f32>,
8264    _b: &CudaBuffer<f32>,
8265    _a_shape: &[usize],
8266    _b_shape: &[usize],
8267    _out_shape: &[usize],
8268    _device: &GpuDevice,
8269) -> GpuResult<CudaBuffer<f32>> {
8270    Err(GpuError::NoCudaFeature)
8271}
8272
8273/// Stub -- always returns [`GpuError::NoCudaFeature`].
8274#[cfg(not(feature = "cuda"))]
8275pub fn gpu_broadcast_sub(
8276    _a: &CudaBuffer<f32>,
8277    _b: &CudaBuffer<f32>,
8278    _a_shape: &[usize],
8279    _b_shape: &[usize],
8280    _out_shape: &[usize],
8281    _device: &GpuDevice,
8282) -> GpuResult<CudaBuffer<f32>> {
8283    Err(GpuError::NoCudaFeature)
8284}
8285
8286/// Stub -- always returns [`GpuError::NoCudaFeature`].
8287#[cfg(not(feature = "cuda"))]
8288pub fn gpu_broadcast_mul(
8289    _a: &CudaBuffer<f32>,
8290    _b: &CudaBuffer<f32>,
8291    _a_shape: &[usize],
8292    _b_shape: &[usize],
8293    _out_shape: &[usize],
8294    _device: &GpuDevice,
8295) -> GpuResult<CudaBuffer<f32>> {
8296    Err(GpuError::NoCudaFeature)
8297}
8298
8299/// Stub -- always returns [`GpuError::NoCudaFeature`].
8300#[cfg(not(feature = "cuda"))]
8301pub fn gpu_softmax(
8302    _input: &CudaBuffer<f32>,
8303    _rows: usize,
8304    _cols: usize,
8305    _device: &GpuDevice,
8306) -> GpuResult<CudaBuffer<f32>> {
8307    Err(GpuError::NoCudaFeature)
8308}
8309
8310/// Stub -- always returns [`GpuError::NoCudaFeature`].
8311#[cfg(not(feature = "cuda"))]
8312pub fn gpu_dropout(
8313    _input: &CudaBuffer<f32>,
8314    _threshold: u32,
8315    _scale: f32,
8316    _seed: u32,
8317    _device: &GpuDevice,
8318) -> GpuResult<CudaBuffer<f32>> {
8319    Err(GpuError::NoCudaFeature)
8320}
8321
8322/// Stub -- always returns [`GpuError::NoCudaFeature`].
8323#[cfg(not(feature = "cuda"))]
8324pub fn gpu_permute_0213(
8325    _input: &CudaBuffer<f32>,
8326    _d0: usize,
8327    _d1: usize,
8328    _d2: usize,
8329    _d3: usize,
8330    _device: &GpuDevice,
8331) -> GpuResult<CudaBuffer<f32>> {
8332    Err(GpuError::NoCudaFeature)
8333}
8334
8335/// Stub -- always returns [`GpuError::NoCudaFeature`].
8336#[cfg(not(feature = "cuda"))]
8337pub fn gpu_slice_write(
8338    _src: &CudaBuffer<f32>,
8339    _dst: &mut CudaBuffer<f32>,
8340    _n_batch: usize,
8341    _d: usize,
8342    _max_len: usize,
8343    _pos: usize,
8344    _device: &GpuDevice,
8345) -> GpuResult<()> {
8346    Err(GpuError::NoCudaFeature)
8347}
8348
8349/// Stub -- always returns [`GpuError::NoCudaFeature`].
8350#[cfg(not(feature = "cuda"))]
8351pub fn gpu_slice_read(
8352    _src: &CudaBuffer<f32>,
8353    _n_batch: usize,
8354    _d: usize,
8355    _len: usize,
8356    _max_len: usize,
8357    _device: &GpuDevice,
8358) -> GpuResult<CudaBuffer<f32>> {
8359    Err(GpuError::NoCudaFeature)
8360}
8361
8362/// Stub -- always returns [`GpuError::NoCudaFeature`].
8363#[cfg(not(feature = "cuda"))]
8364pub fn gpu_embed_lookup(
8365    _idx: &CudaBuffer<f32>,
8366    _weight: &CudaBuffer<f32>,
8367    _d: usize,
8368    _device: &GpuDevice,
8369) -> GpuResult<CudaBuffer<f32>> {
8370    Err(GpuError::NoCudaFeature)
8371}
8372
8373/// Stub -- always returns [`GpuError::NoCudaFeature`].
8374#[cfg(not(feature = "cuda"))]
8375pub fn gpu_embed_lookup_batch(
8376    _indices: &CudaBuffer<f32>,
8377    _weight: &CudaBuffer<f32>,
8378    _n: usize,
8379    _d: usize,
8380    _device: &GpuDevice,
8381) -> GpuResult<CudaBuffer<f32>> {
8382    Err(GpuError::NoCudaFeature)
8383}
8384
8385/// Stub -- always returns [`GpuError::NoCudaFeature`].
8386#[cfg(not(feature = "cuda"))]
8387pub fn gpu_scatter_add_rows(
8388    _grad_output: &CudaBuffer<f32>,
8389    _indices: &CudaBuffer<f32>,
8390    _num_embeddings: usize,
8391    _d: usize,
8392    _device: &GpuDevice,
8393) -> GpuResult<CudaBuffer<f32>> {
8394    Err(GpuError::NoCudaFeature)
8395}
8396
8397/// Stub -- always returns [`GpuError::NoCudaFeature`].
8398#[cfg(not(feature = "cuda"))]
8399pub fn gpu_relu_backward(
8400    _grad: &CudaBuffer<f32>,
8401    _input: &CudaBuffer<f32>,
8402    _device: &GpuDevice,
8403) -> GpuResult<CudaBuffer<f32>> {
8404    Err(GpuError::NoCudaFeature)
8405}
8406
8407/// Stub -- always returns [`GpuError::NoCudaFeature`].
8408#[cfg(not(feature = "cuda"))]
8409pub fn gpu_gelu_backward(
8410    _grad: &CudaBuffer<f32>,
8411    _input: &CudaBuffer<f32>,
8412    _device: &GpuDevice,
8413) -> GpuResult<CudaBuffer<f32>> {
8414    Err(GpuError::NoCudaFeature)
8415}
8416
8417/// Stub -- always returns [`GpuError::NoCudaFeature`].
8418#[cfg(not(feature = "cuda"))]
8419pub fn gpu_index_select_1d(
8420    _input: &CudaBuffer<f32>,
8421    _indices: &CudaBuffer<f32>,
8422    _device: &GpuDevice,
8423) -> GpuResult<CudaBuffer<f32>> {
8424    Err(GpuError::NoCudaFeature)
8425}
8426
8427/// Stub -- always returns [`GpuError::NoCudaFeature`].
8428#[cfg(not(feature = "cuda"))]
8429pub fn gpu_scatter_add_1d(
8430    _grad_output: &CudaBuffer<f32>,
8431    _indices: &CudaBuffer<f32>,
8432    _input_len: usize,
8433    _device: &GpuDevice,
8434) -> GpuResult<CudaBuffer<f32>> {
8435    Err(GpuError::NoCudaFeature)
8436}
8437
8438/// Stub -- always returns [`GpuError::NoCudaFeature`].
8439#[cfg(not(feature = "cuda"))]
8440pub fn gpu_masked_fill(
8441    _input: &CudaBuffer<f32>,
8442    _mask: &CudaBuffer<f32>,
8443    _value: f32,
8444    _device: &GpuDevice,
8445) -> GpuResult<CudaBuffer<f32>> {
8446    Err(GpuError::NoCudaFeature)
8447}
8448
8449/// Stub -- always returns [`GpuError::NoCudaFeature`].
8450#[cfg(not(feature = "cuda"))]
8451pub fn gpu_masked_zero(
8452    _grad: &CudaBuffer<f32>,
8453    _mask: &CudaBuffer<f32>,
8454    _device: &GpuDevice,
8455) -> GpuResult<CudaBuffer<f32>> {
8456    Err(GpuError::NoCudaFeature)
8457}
8458
8459/// Stub -- always returns [`GpuError::NoCudaFeature`].
8460#[cfg(not(feature = "cuda"))]
8461pub fn gpu_sigmoid_backward(
8462    _grad: &CudaBuffer<f32>,
8463    _output: &CudaBuffer<f32>,
8464    _device: &GpuDevice,
8465) -> GpuResult<CudaBuffer<f32>> {
8466    Err(GpuError::NoCudaFeature)
8467}
8468
8469/// Stub -- always returns [`GpuError::NoCudaFeature`].
8470#[cfg(not(feature = "cuda"))]
8471pub fn gpu_tanh_backward(
8472    _grad: &CudaBuffer<f32>,
8473    _output: &CudaBuffer<f32>,
8474    _device: &GpuDevice,
8475) -> GpuResult<CudaBuffer<f32>> {
8476    Err(GpuError::NoCudaFeature)
8477}
8478
8479/// Stub -- always returns [`GpuError::NoCudaFeature`].
8480#[cfg(not(feature = "cuda"))]
8481pub fn gpu_softmax_backward(
8482    _grad: &CudaBuffer<f32>,
8483    _output: &CudaBuffer<f32>,
8484    _cols: usize,
8485    _device: &GpuDevice,
8486) -> GpuResult<CudaBuffer<f32>> {
8487    Err(GpuError::NoCudaFeature)
8488}
8489
8490/// Stub -- always returns [`GpuError::NoCudaFeature`].
8491#[cfg(not(feature = "cuda"))]
8492pub fn gpu_sum_axis(
8493    _a: &CudaBuffer<f32>,
8494    _outer: usize,
8495    _axis_size: usize,
8496    _inner: usize,
8497    _device: &GpuDevice,
8498) -> GpuResult<CudaBuffer<f32>> {
8499    Err(GpuError::NoCudaFeature)
8500}
8501
8502/// Stub -- always returns [`GpuError::NoCudaFeature`].
8503#[cfg(not(feature = "cuda"))]
8504pub fn gpu_strided_split(
8505    _input: &CudaBuffer<f32>,
8506    _total_along_axis: usize,
8507    _split_offset: usize,
8508    _split_size: usize,
8509    _inner_size: usize,
8510    _n: usize,
8511    _device: &GpuDevice,
8512) -> GpuResult<CudaBuffer<f32>> {
8513    Err(GpuError::NoCudaFeature)
8514}
8515
8516/// Stub -- always returns [`GpuError::NoCudaFeature`].
8517#[cfg(not(feature = "cuda"))]
8518pub fn gpu_strided_cat(
8519    _input: &CudaBuffer<f32>,
8520    _output: &mut CudaBuffer<f32>,
8521    _total_along_axis: usize,
8522    _cat_offset: usize,
8523    _part_size: usize,
8524    _inner_size: usize,
8525    _n: usize,
8526    _device: &GpuDevice,
8527) -> GpuResult<()> {
8528    Err(GpuError::NoCudaFeature)
8529}
8530
8531// ---------------------------------------------------------------------------
8532// f32-to-f16 GPU conversion
8533// ---------------------------------------------------------------------------
8534
8535/// Convert an f32 GPU buffer to f16 (represented as `CudaSlice<u16>`).
8536///
8537/// Each element is converted using IEEE 754 round-to-nearest-even via the
8538/// PTX `cvt.rn.f16.f32` instruction. The output is a `CudaSlice<u16>` where
8539/// each `u16` holds the bit pattern of an IEEE 754 half-precision float.
8540///
8541/// # Errors
8542///
8543/// - [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be compiled
8544///   (e.g., GPU architecture too old to support f16 conversion instructions).
8545/// - [`GpuError::Driver`] on CUDA launch errors.
8546#[cfg(feature = "cuda")]
8547pub(crate) fn gpu_f32_to_f16(
8548    input: &CudaBuffer<f32>,
8549    device: &GpuDevice,
8550) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
8551    use cudarc::driver::PushKernelArg;
8552
8553    let n = input.len();
8554    if n == 0 {
8555        let empty = device.stream().alloc_zeros::<u16>(0)?;
8556        return Ok(empty);
8557    }
8558
8559    let ctx = device.context();
8560    let stream = device.stream();
8561
8562    let f = crate::module_cache::get_or_compile(
8563        ctx,
8564        F32_TO_F16_PTX,
8565        "f32_to_f16_kernel",
8566        device.ordinal() as u32,
8567    )
8568    .map_err(|_| GpuError::PtxCompileFailed {
8569        kernel: "f32_to_f16_kernel",
8570    })?;
8571
8572    let mut out = stream.alloc_zeros::<u16>(n)?;
8573    let cfg = launch_cfg(n)?;
8574    let n_u32 = n as u32;
8575
8576    // SAFETY: The kernel reads `n` f32 values from `input` and writes `n`
8577    // u16 values (f16 bit patterns) to `out`. Both buffers are device-resident
8578    // and correctly sized. The grid is configured to cover exactly `n` threads.
8579    unsafe {
8580        stream
8581            .launch_builder(&f)
8582            .arg(input.inner())
8583            .arg(&mut out)
8584            .arg(&n_u32)
8585            .launch(cfg)?;
8586    }
8587
8588    Ok(out)
8589}
8590
8591/// Stub -- always returns [`GpuError::NoCudaFeature`].
8592#[cfg(not(feature = "cuda"))]
8593pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
8594    Err(GpuError::NoCudaFeature)
8595}
8596
8597/// Convert f32 GPU buffer to bf16 (stored as u16) on-device.
8598///
8599/// Uses bit manipulation for round-to-nearest-even bf16 conversion.
8600/// Works on sm_52+ (no special bf16 hardware required).
8601#[cfg(feature = "cuda")]
8602pub(crate) fn gpu_f32_to_bf16(
8603    input: &CudaBuffer<f32>,
8604    device: &GpuDevice,
8605) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
8606    use cudarc::driver::PushKernelArg;
8607
8608    let n = input.len();
8609    if n == 0 {
8610        let empty = device.stream().alloc_zeros::<u16>(0)?;
8611        return Ok(empty);
8612    }
8613
8614    let ctx = device.context();
8615    let stream = device.stream();
8616
8617    let f = crate::module_cache::get_or_compile(
8618        ctx,
8619        F32_TO_BF16_PTX,
8620        "f32_to_bf16_kernel",
8621        device.ordinal() as u32,
8622    )
8623    .map_err(|_| GpuError::PtxCompileFailed {
8624        kernel: "f32_to_bf16_kernel",
8625    })?;
8626
8627    let mut out = stream.alloc_zeros::<u16>(n)?;
8628    let cfg = launch_cfg(n)?;
8629    let n_u32 = n as u32;
8630
8631    unsafe {
8632        stream
8633            .launch_builder(&f)
8634            .arg(input.inner())
8635            .arg(&mut out)
8636            .arg(&n_u32)
8637            .launch(cfg)?;
8638    }
8639
8640    Ok(out)
8641}
8642
8643/// Stub -- always returns [`GpuError::NoCudaFeature`].
8644#[cfg(not(feature = "cuda"))]
8645pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
8646    Err(GpuError::NoCudaFeature)
8647}
8648
8649// ---------------------------------------------------------------------------
8650// Tests -- require a real CUDA GPU
8651// ---------------------------------------------------------------------------
8652
8653#[cfg(test)]
8654#[cfg(feature = "cuda")]
8655mod tests {
8656    use super::*;
8657
8658    /// Helper: set up device + upload a slice.
8659    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
8660        let dev = GpuDevice::new(0).expect("CUDA device 0");
8661        let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
8662        (dev, buf)
8663    }
8664
8665    /// Round-trip helper: download a GPU buffer and compare against expected
8666    /// CPU output element-wise.
8667    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
8668        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
8669        assert_eq!(host.len(), expected.len(), "length mismatch");
8670        for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
8671            assert!(
8672                (got - exp).abs() < 1e-6,
8673                "element {i}: got {got}, expected {exp}",
8674            );
8675        }
8676    }
8677
8678    // -- gpu_add -------------------------------------------------------------
8679
8680    #[test]
8681    fn add_basic() {
8682        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
8683        let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
8684        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
8685
8686        let (dev, a) = setup(&a_data);
8687        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
8688        let out = gpu_add(&a, &b, &dev).expect("gpu_add");
8689        assert_buf_eq(&out, &dev, &expected);
8690    }
8691
8692    #[test]
8693    fn add_empty() {
8694        let (dev, a) = setup(&[]);
8695        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
8696        let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
8697        assert_eq!(out.len(), 0);
8698    }
8699
8700    #[test]
8701    fn add_large() {
8702        let n = 100_000;
8703        let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
8704        let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
8705        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
8706
8707        let (dev, a) = setup(&a_data);
8708        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
8709        let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
8710        assert_buf_eq(&out, &dev, &expected);
8711    }
8712
8713    #[test]
8714    fn add_length_mismatch() {
8715        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
8716        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
8717        let err = gpu_add(&a, &b, &dev).unwrap_err();
8718        match err {
8719            GpuError::LengthMismatch { a: 3, b: 2 } => {}
8720            other => panic!("unexpected error: {other}"),
8721        }
8722    }
8723
8724    // -- gpu_sub -------------------------------------------------------------
8725
8726    #[test]
8727    fn sub_basic() {
8728        let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
8729        let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
8730        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();
8731
8732        let (dev, a) = setup(&a_data);
8733        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
8734        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
8735        assert_buf_eq(&out, &dev, &expected);
8736    }
8737
8738    #[test]
8739    fn sub_negative_result() {
8740        let a_data = vec![1.0f32, 2.0];
8741        let b_data = vec![5.0f32, 10.0];
8742        let expected: Vec<f32> = vec![-4.0, -8.0];
8743
8744        let (dev, a) = setup(&a_data);
8745        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
8746        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
8747        assert_buf_eq(&out, &dev, &expected);
8748    }
8749
8750    // -- gpu_mul -------------------------------------------------------------
8751
8752    #[test]
8753    fn mul_basic() {
8754        let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
8755        let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
8756        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();
8757
8758        let (dev, a) = setup(&a_data);
8759        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
8760        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
8761        assert_buf_eq(&out, &dev, &expected);
8762    }
8763
8764    #[test]
8765    fn mul_by_zero() {
8766        let a_data = vec![1.0f32, 2.0, 3.0];
8767        let b_data = vec![0.0f32, 0.0, 0.0];
8768        let expected = vec![0.0f32, 0.0, 0.0];
8769
8770        let (dev, a) = setup(&a_data);
8771        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
8772        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
8773        assert_buf_eq(&out, &dev, &expected);
8774    }
8775
8776    // -- gpu_neg -------------------------------------------------------------
8777
8778    #[test]
8779    fn neg_basic() {
8780        let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
8781        let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();
8782
8783        let (dev, a) = setup(&a_data);
8784        let out = gpu_neg(&a, &dev).expect("gpu_neg");
8785        assert_buf_eq(&out, &dev, &expected);
8786    }
8787
8788    #[test]
8789    fn neg_double_negation() {
8790        let a_data = vec![1.0f32, -2.0, 3.0];
8791        let (dev, a) = setup(&a_data);
8792        let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
8793        let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
8794        assert_buf_eq(&neg2, &dev, &a_data);
8795    }
8796
8797    // -- gpu_relu ------------------------------------------------------------
8798
8799    #[test]
8800    fn relu_basic() {
8801        let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
8802        let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];
8803
8804        let (dev, a) = setup(&a_data);
8805        let out = gpu_relu(&a, &dev).expect("gpu_relu");
8806        assert_buf_eq(&out, &dev, &expected);
8807    }
8808
8809    #[test]
8810    fn relu_all_negative() {
8811        let a_data = vec![-5.0f32, -0.1, -100.0];
8812        let expected = vec![0.0f32, 0.0, 0.0];
8813
8814        let (dev, a) = setup(&a_data);
8815        let out = gpu_relu(&a, &dev).expect("gpu_relu");
8816        assert_buf_eq(&out, &dev, &expected);
8817    }
8818
8819    #[test]
8820    fn relu_all_positive() {
8821        let a_data = vec![0.1f32, 1.0, 100.0];
8822
8823        let (dev, a) = setup(&a_data);
8824        let out = gpu_relu(&a, &dev).expect("gpu_relu");
8825        assert_buf_eq(&out, &dev, &a_data);
8826    }
8827
8828    #[test]
8829    fn relu_empty() {
8830        let (dev, a) = setup(&[]);
8831        let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
8832        assert_eq!(out.len(), 0);
8833    }
8834
8835    #[test]
8836    fn small_matmul_2x2() {
8837        let dev = GpuDevice::new(0).expect("CUDA device 0");
8838        // A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]]
8839        // C = A@B = [[19, 22], [43, 50]]
8840        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
8841        let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
8842        let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
8843        assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
8844    }
8845
8846    #[test]
8847    fn small_matmul_1xk_kxn() {
8848        let dev = GpuDevice::new(0).expect("CUDA device 0");
8849        // A = [1, 2, 3] (1x3), B = [[1, 0], [0, 1], [1, 1]] (3x2)
8850        // C = [4, 5] (1x2)
8851        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
8852        let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
8853        let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
8854        assert_buf_eq(&c, &dev, &[4.0, 5.0]);
8855    }
8856
8857    #[test]
8858    fn small_matmul_vs_cublas() {
8859        // Compare our small matmul against cuBLAS for a realistic decode-step size.
8860        // Linear layer: [1, 64] @ [64, 64] = [1, 64]
8861        let dev = GpuDevice::new(0).expect("CUDA device 0");
8862        let m = 1;
8863        let k = 64;
8864        let n = 64;
8865
8866        // Deterministic data.
8867        let a_data: Vec<f32> = (0..m * k)
8868            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
8869            .collect();
8870        let b_data: Vec<f32> = (0..k * n)
8871            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
8872            .collect();
8873
8874        let a = cpu_to_gpu(&a_data, &dev).unwrap();
8875        let b = cpu_to_gpu(&b_data, &dev).unwrap();
8876
8877        // cuBLAS reference.
8878        let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
8879        let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();
8880
8881        // Our kernel.
8882        let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
8883        let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();
8884
8885        assert_eq!(cublas_result.len(), our_result.len());
8886        for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
8887            assert!(
8888                (cb - ours).abs() < 0.1,
8889                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
8890                (cb - ours).abs()
8891            );
8892        }
8893    }
8894}