1#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, cpu_to_gpu, gpu_to_cpu};
30
/// PTX for `add_kernel`: element-wise f32 addition, `out[i] = a[i] + b[i]`.
///
/// One thread per element: the global index is `ctaid.x * ntid.x + tid.x`,
/// and threads with index >= `n` branch to `DONE` without touching memory.
/// Byte offset is `index * 4` (f32). Targets PTX ISA 7.0 / sm_52, 64-bit
/// addressing.
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    add.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
82
/// PTX for `add_vec4_kernel`: vectorized f32 addition, 4 elements per thread.
///
/// Each thread loads/stores one `v4.f32` group (16 bytes) at byte offset
/// `tid * 16`; `n4` is the number of 4-element groups, not the element count.
/// Threads with index >= `n4` exit early. NOTE: per the PTX ISA, `v4.f32`
/// global accesses require 16-byte-aligned addresses, so all three buffers
/// must be 16-byte aligned and have a length that is a multiple of 4.
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    // Byte offset = tid * 16 (4 floats × 4 bytes)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    add.f32 %r0, %a0, %b0;
    add.f32 %r1, %a1, %b1;
    add.f32 %r2, %a2, %b2;
    add.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
139
/// PTX for `mul_vec4_kernel`: vectorized f32 element-wise multiply,
/// 4 elements per thread.
///
/// Identical structure to [`ADD_VEC4_PTX`] with `mul.f32` in place of
/// `add.f32`: each thread handles one `v4.f32` group at byte offset
/// `tid * 16`, bounded by `n4` (group count). Same 16-byte alignment /
/// multiple-of-4 length requirement as any `v4` global access.
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    mul.f32 %r0, %a0, %b0;
    mul.f32 %r1, %a1, %b1;
    mul.f32 %r2, %a2, %b2;
    mul.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
192
/// PTX for `sub_kernel`: element-wise f32 subtraction, `out[i] = a[i] - b[i]`.
///
/// Same scalar one-thread-per-element layout as [`ADD_PTX`], with `sub.f32`
/// as the arithmetic op. Threads with global index >= `n` exit early.
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    sub.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
240
/// PTX for `mul_kernel`: element-wise f32 multiplication,
/// `out[i] = a[i] * b[i]`.
///
/// Same scalar one-thread-per-element layout as [`ADD_PTX`], with `mul.f32`
/// as the arithmetic op. Threads with global index >= `n` exit early.
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    mul.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
288
/// PTX for `neg_kernel`: unary f32 negation, `out[i] = -a[i]`.
///
/// Single-input variant of the scalar element-wise pattern: one thread per
/// element, early exit for global index >= `n`, `neg.f32` on the loaded
/// value.
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    neg.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
332
/// PTX for `relu_kernel`: element-wise ReLU, `out[i] = max(a[i], 0.0)`.
///
/// One thread per element with early exit for global index >= `n`;
/// the clamp is `max.f32` against the f32 literal `0f00000000` (+0.0).
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %zero;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mov.f32 %zero, 0f00000000;
    max.f32 %vr, %va, %zero;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
377
/// PTX for `scale_kernel`: scalar multiply, `out[i] = a[i] * scalar`.
///
/// The multiplier is passed as an f32 kernel parameter (`scalar`) rather
/// than a buffer, so one launch scales the whole array by a single host-side
/// value. One thread per element; early exit for global index >= `n`.
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %s;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %s, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mul.f32 %vr, %va, %s;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
423
/// PTX for `transpose_2d_kernel`: out-of-place 2-D f32 transpose.
///
/// Input is row-major `[M, N]`; output is row-major `[N, M]` with
/// `out[r, c] = in[c, r]`. One thread per *output* element: `tid` is
/// decomposed as `out_row = tid / M`, `out_col = tid % M`, and the gather
/// index is `out_col * N + out_row`. `total` (= M*N) bounds the grid.
/// Writes are coalesced (contiguous in `tid`); reads are strided.
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 M,\n\
    .param .u32 N,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
    .reg .u32 %out_row, %out_col, %in_idx;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %M_reg, [M];\n\
    ld.param.u32 %N_reg, [N];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
    div.u32 %out_row, %r_tid, %M_reg;\n\
    rem.u32 %out_col, %r_tid, %M_reg;\n\
    // Input index: out_col * N + out_row (transposed).\n\
    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
479
/// PTX for `permute_0213_kernel`: 4-D axis permutation
/// `[d0, d1, d2, d3] -> [d0, d2, d1, d3]` (swap the middle two axes).
///
/// One thread per output element. The output linear index `tid` is
/// decomposed into `(i0, i2, i1, i3)` using the *output* strides
/// (`s_out1 = d2*d1*d3`, `s_out2 = d1*d3`), then the input linear index is
/// rebuilt with the *input* strides as
/// `(i0*d1 + i1) * (d2*d3) + i2*d3 + i3`. `total` (= d0*d1*d2*d3) bounds
/// the grid; writes are contiguous, reads are strided.
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 d0,\n\
    .param .u32 d1,\n\
    .param .u32 d2,\n\
    .param .u32 d3,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %d0r, [d0];\n\
    ld.param.u32 %d1r, [d1];\n\
    ld.param.u32 %d2r, [d2];\n\
    ld.param.u32 %d3r, [d3];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape: [d0, d2, d1, d3]\n\
    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
    div.u32 %i0, %r_tid, %s_out1;\n\
    rem.u32 %rem, %r_tid, %s_out1;\n\
    div.u32 %i2, %rem, %s_out2;\n\
    rem.u32 %rem, %rem, %s_out2;\n\
    div.u32 %i1, %rem, %d3r;\n\
    rem.u32 %i3, %rem, %d3r;\n\
\n\
    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
    mul.lo.u32 %in_idx, %i0, %d1r;\n\
    add.u32 %in_idx, %in_idx, %i1;\n\
    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
    add.u32 %in_idx, %in_idx, %i3;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
560
/// PTX for `f32_to_f16_kernel`: narrowing conversion f32 -> IEEE f16.
///
/// One thread per element; input stride is 4 bytes, output stride 2 bytes.
/// Uses the hardware `cvt.rn.f16.f32` (round-to-nearest-even) and stores
/// the 16-bit result with `st.global.b16`. Threads with index >= `n` exit
/// early.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .b16 %vh;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Compute input offset: i * 4 (f32 = 4 bytes)
    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    // Compute output offset: i * 2 (f16 = 2 bytes)
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32, convert to f16 (round-to-nearest-even), store as u16
    ld.global.f32 %vf, [%in];
    cvt.rn.f16.f32 %vh, %vf;
    st.global.b16 [%out], %vh;

DONE:
    ret;
}
";
615
/// PTX for `f32_to_bf16_kernel`: narrowing conversion f32 -> bfloat16,
/// done with integer bit manipulation (no bf16 hardware instruction needed
/// at sm_52).
///
/// The f32 is loaded as raw `u32` bits; round-to-nearest-even to the top
/// 16 bits is performed by adding `0x7FFF` plus the current bit 16 (the
/// future bf16 LSB), then shifting right 16 and storing the low half.
///
/// NOTE(review): this bias-add rounding does not special-case NaN. A NaN
/// whose payload lives only in the low 16 mantissa bits (e.g. 0x7F800001)
/// carries into the exponent and rounds to 0x7F80, i.e. infinity — confirm
/// callers never rely on NaN propagation through this path.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .u32 %bits, %round, %lsb, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32 as raw bits
    ld.global.u32 %bits, [%in];

    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
    shr.u32 %lsb, %bits, 16;
    and.b32 %lsb, %lsb, 1;
    add.u32 %round, %bits, 0x7FFF;
    add.u32 %round, %round, %lsb;
    shr.u32 %result, %round, 16;

    // Store as u16
    st.global.u16 [%out], %result;

DONE:
    ret;
}
";
676
/// PTX for `small_matmul_kernel`: naive row-major matmul
/// `C[M,N] = A[M,K] * B[K,N]`, intended for small problem sizes.
///
/// One thread per output element: `row = tid / N`, `col = tid % N`, then a
/// serial loop over `K` accumulating `fma.rn.f32` of `A[row,k] * B[k,col]`.
/// No shared-memory tiling, so each A/B element is re-read from global
/// memory by every thread that needs it. Bounds come from `total` (= M*N);
/// the `M` parameter is loaded into `%M_reg` but never used afterwards.
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 K,
    .param .u32 N,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
    .reg .u32 %row, %col, %p, %idx;
    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
    .reg .f32 %sum, %va, %vb;
    .reg .pred %bounds_p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %c, [c_ptr];
    ld.param.u32 %M_reg, [M];
    ld.param.u32 %K_reg, [K];
    ld.param.u32 %N_reg, [N];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    div.u32 %row, %r_tid, %N_reg;
    rem.u32 %col, %r_tid, %N_reg;

    mov.f32 %sum, 0f00000000;
    mov.u32 %p, 0;
DOT:
    setp.ge.u32 %loop_p, %p, %K_reg;
    @%loop_p bra DOT_DONE;

    mad.lo.u32 %idx, %row, %K_reg, %p;
    cvt.u64.u32 %a_off, %idx;
    shl.b64 %a_off, %a_off, 2;
    add.u64 %a_off, %a, %a_off;
    ld.global.f32 %va, [%a_off];

    mad.lo.u32 %idx, %p, %N_reg, %col;
    cvt.u64.u32 %b_off, %idx;
    shl.b64 %b_off, %b_off, 2;
    add.u64 %b_off, %b, %b_off;
    ld.global.f32 %vb, [%b_off];

    fma.rn.f32 %sum, %va, %vb, %sum;
    add.u32 %p, %p, 1;
    bra DOT;
DOT_DONE:

    cvt.u64.u32 %c_off, %r_tid;
    shl.b64 %c_off, %c_off, 2;
    add.u64 %c_off, %c, %c_off;
    st.global.f32 [%c_off], %sum;

DONE:
    ret;
}
";
755
/// PTX for `slice_write_kernel`: scatter a contiguous `[B, D]` source into
/// one time-step of a `[B, max_len, D]` destination.
///
/// For each of the `n` (= B*D) source elements, the thread computes
/// `batch = tid / D`, `d = tid % D`, and writes to destination index
/// `(batch * max_len + pos) * D + d` — i.e. `dst[batch, pos, d] =
/// src[batch, d]`, with `pos` supplied as a host-side kernel parameter.
/// See [`SLICE_WRITE_INDIRECT_PTX`] for the variant that reads `pos` from
/// device memory.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u32 pos
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u32 %pos_reg, [pos];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
815
/// PTX for `slice_write_indirect_kernel`: same scatter as
/// [`SLICE_WRITE_PTX`] (`dst[batch, pos, d] = src[batch, d]`), except the
/// write position `pos` is read from device memory (`ld.global.u32` at
/// `pos_ptr`) instead of being a kernel parameter.
///
/// This lets the position be produced by a previous kernel without a
/// round-trip to the host. Every thread performs its own load of `pos`;
/// no bounds check is applied to the loaded value — it is assumed to be
/// `< max_len` (presumably guaranteed by the caller; TODO confirm).
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u64 pos_ptr
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u64 %pos_p, [pos_ptr];
    ld.global.u32 %pos_reg, [%pos_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
876
/// PTX for `causal_mask_indirect_kernel`: fill an attention-mask buffer
/// based on a sequence length read from device memory.
///
/// The effective length `tlen` is loaded from `total_len_ptr` on the device
/// (no host round-trip). For each of `total` output elements, the column is
/// `col = tid % max_pos`; columns `col < tlen` get `0.0` and the rest get
/// `-1.0e9` (`0fCE6E6B28`), the additive large-negative value that drives
/// masked logits to ~0 after softmax.
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
    .param .u64 total_len_ptr,
    .param .u64 out_ptr,
    .param .u32 max_pos,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
    .reg .u64 %out, %off, %tl_p;
    .reg .f32 %val;
    .reg .pred %bounds_p, %mask_p;

    ld.param.u64 %tl_p, [total_len_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %max_pos_reg, [max_pos];
    ld.param.u32 %total_reg, [total];

    ld.global.u32 %tlen, [%tl_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    rem.u32 %col, %r_tid, %max_pos_reg;
    setp.lt.u32 %mask_p, %col, %tlen;
    @%mask_p bra WRITE_ZERO;

    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
    // to effectively zero out masked positions after softmax.
    mov.f32 %val, 0fCE6E6B28;
    bra WRITE;

WRITE_ZERO:
    mov.f32 %val, 0f00000000;

WRITE:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %val;

DONE:
    ret;
}
";
936
/// PTX for `embed_lookup_kernel`: copy one embedding row,
/// `out[0..D] = weight[row, 0..D]`.
///
/// The single token index is stored as an *f32* at `idx_ptr`; each thread
/// loads it and truncates toward zero (`cvt.rzi.u32.f32`) to get the row.
/// One thread per embedding dimension, bounded by `D`.
/// NOTE(review): the index is not range-checked — it is assumed to be a
/// non-negative integer < the vocabulary size (verify against callers).
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %D_reg;
    @%p bra DONE;

    ld.global.f32 %idx_f, [%idx_addr];
    cvt.rzi.u32.f32 %row, %idx_f;

    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
989
/// PTX for `embed_lookup_batch_kernel`: batched embedding gather,
/// `out[row, col] = weight[indices[row], col]`.
///
/// One thread per output element over `total` (= rows * D) elements:
/// `row = tid / D`, `col = tid % D`. `indices` is an f32 array; each entry
/// is truncated toward zero (`cvt.rzi.u32.f32`) to a u32 row index into
/// `weight` (row-major, row stride `D`). As with [`EMBED_LOOKUP_PTX`],
/// indices are not range-checked.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mad.lo.u32 %tid, %bid, %bdim, %tid;

    setp.ge.u32 %p, %tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %tid, %D_reg;
    rem.u32 %col, %tid, %D_reg;

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %src_idx, %idx_f;

    // src_idx = indices[row] * D + col
    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    // Write to out[tid]
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1058
/// PTX for `scatter_add_rows_kernel`: embedding-backward scatter-add,
/// `grad_weight[indices[row], col] += grad_output[row, col]`.
///
/// One thread per gradient element over `total` (= rows * D). Indices are
/// stored as f32 and truncated toward zero to u32. The accumulation uses
/// `atom.global.add.f32` so duplicate indices across rows combine correctly
/// (the fetched old value goes to `%dummy` and is discarded). Note that
/// atomic f32 addition makes the summation order nondeterministic, so
/// results may differ between runs at the ULP level.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_weight_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %dst_idx;
    .reg .u64 %go, %idx_addr, %gw, %off;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %idx_addr, [indices_ptr];
    ld.param.u64 %gw, [grad_weight_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mad.lo.u32 %tid, %bid, %bdim, %tid;

    setp.ge.u32 %p, %tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %tid, %D_reg;
    rem.u32 %col, %tid, %D_reg;

    // Read grad_output[tid]
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %go, %off;
    ld.global.f32 %grad_val, [%off];

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %dst_idx, %idx_f;

    // dst_idx = indices[row] * D + col
    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %gw, %off;
    atom.global.add.f32 %dummy, [%off], %grad_val;

DONE:
    ret;
}
";
1127
/// PTX for `slice_read_kernel`: gather the first `len` time-steps of a
/// `[B, max_len, D]` source into a contiguous `[B, len, D]` destination.
///
/// Inverse companion to [`SLICE_WRITE_PTX`]. One thread per destination
/// element over `total` (= B * len * D):
/// `batch = tid / (len*D)`, `row = (tid % (len*D)) / D`, `col = tid % D`,
/// reading from `src[batch * max_len * D + row * D + col]` and writing to
/// `dst[tid]`.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 total,
    .param .u32 D,
    .param .u32 len,
    .param .u32 max_len
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
    .reg .u32 %len_d;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %total_reg, [total];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %len_reg, [len];
    ld.param.u32 %max_len_reg, [max_len];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // dst index = r_tid
    // batch_idx = r_tid / (len * D)
    // within = r_tid % (len * D)
    // row = within / D
    // col = within % D
    // src_idx = batch_idx * max_len * D + row * D + col
    mul.lo.u32 %len_d, %len_reg, %D_reg;
    div.u32 %batch_idx, %r_tid, %len_d;
    rem.u32 %within, %r_tid, %len_d;
    div.u32 %row, %within, %D_reg;
    rem.u32 %col, %within, %D_reg;

    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
    add.u32 %src_idx, %src_idx, %row;
    mul.lo.u32 %src_idx, %src_idx, %D_reg;
    add.u32 %src_idx, %src_idx, %col;

    cvt.u64.u32 %src_off, %src_idx;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src_off, %src, %src_off;
    ld.global.f32 %val, [%src_off];

    cvt.u64.u32 %dst_off, %r_tid;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst_off, %dst, %dst_off;
    st.global.f32 [%dst_off], %val;

DONE:
    ret;
}
";
1201
/// PTX for `gelu_kernel`: element-wise GELU via the sigmoid approximation
/// `gelu(x) ≈ x * sigmoid(k * x)`.
///
/// The sigmoid is computed with fast-math instructions:
/// `sigmoid(y) = 1 / (1 + exp(-y))`, with `exp` built from
/// `ex2.approx` (2^z) scaled by log2(e) = `0f3FB8AA3B`, and the division
/// done with `rcp.approx`. So this kernel trades a few ULPs of accuracy
/// for speed.
///
/// NOTE(review): `%k = 0f3FDA2720` decodes to ≈ 1.7043, whereas the
/// standard sigmoid-GELU constant is 1.702 (≈ 0f3FD9DBA6) — confirm
/// whether the deviation is intentional.
#[cfg(feature = "cuda")]
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    mov.f32 %k, 0f3FDA2720;
    mul.f32 %neg_kx, %k, %x;
    neg.f32 %neg_kx, %neg_kx;
    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;
    mul.f32 %result, %x, %sig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1261
/// PTX for `gelu_tanh_kernel`: element-wise GELU via the tanh approximation
/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
///
/// Constants (f32 hex): 0.044715 = `0f3D372713`, sqrt(2/π) ≈ 0.7978846 =
/// `0f3F4C422A`, log2(e) = `0f3FB8AA3B`. `tanh(y)` is computed as
/// `(exp(2y) - 1) / (exp(2y) + 1)` using `ex2.approx` for the exponential
/// and `rcp.approx` for the division, so results carry fast-math error.
/// NOTE(review): for large `y`, `exp(2y)` overflows to +inf and the
/// quotient becomes inf/inf = NaN instead of saturating tanh at ±1 —
/// confirm the expected input range keeps `2y·log2(e)` below f32 overflow.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // inner = sqrt(2/π) * (x + 0.044715 * x³)
    // sqrt(2/π) = 0.7978845608 = 0x3F4C422A
    // 0.044715 = 0x3D372713
    mul.f32 %x3, %x, %x;
    mul.f32 %x3, %x3, %x;
    mov.f32 %c, 0f3D372713;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
    // exp(2y) = 2^(2y * log2(e))
    mov.f32 %log2e, 0f3FB8AA3B;
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    mov.f32 %one, 0f3F800000;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f32 %th, %one, %th;
    mov.f32 %half, 0f3F000000;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %th;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1337
/// PTX for `gelu_erf_kernel`: "exact" element-wise GELU,
/// `gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// erf is evaluated on `|z|` (z = x * 0.70710678 = `0f3F3504F3`) with a
/// rational/exponential approximation of the form
/// `erf(|z|) ≈ 1 - poly(t) * exp(-z²)`, `t = 1 / (1 + 0.3275911·|z|)`,
/// where `poly` is a degree-5 Horner polynomial, then sign-corrected via
/// the odd symmetry `erf(-z) = -erf(z)`. `exp` and the reciprocal use
/// `ex2.approx` / `rcp.approx` fast-math instructions.
///
/// NOTE(review): the structure matches the classic Abramowitz–Stegun
/// 7.1.26 erf approximation (same `p = 0.3275911`), but the hex-decoded
/// coefficients a1..a5 here (≈ 0.3141, -0.5147, 1.4143, -0.3509, 0.1354)
/// differ from the textbook values — presumably a refitted polynomial;
/// confirm accuracy against a reference erf before changing anything.
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %z, %ax, %one, %half, %log2e;
    .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mov.f32 %a5, 0f3E0AAAAB;
    mov.f32 %a4, 0fBEB3A903;
    mov.f32 %a3, 0f3FB506DD;
    mov.f32 %a2, 0fBF03C1E1;
    mov.f32 %a1, 0f3EA0D6BB;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // out = x * 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %erf_val, %one, %erf_val;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %erf_val;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1436
1437#[cfg(feature = "cuda")]
/// PTX for tanh-approximation GELU backward:
/// `out = grad * d/dx[0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))]`
/// with the derivative expanded as
/// `0.5*(1+th) + 0.5*x*(1-th^2)*sqrt(2/pi)*(1 + 3*0.044715*x^2)`.
/// One thread per dense f32 element; tanh is built from `ex2.approx` /
/// `rcp.approx`, so results carry hardware-approximation error.
///
/// Fix: the `c3` constant was `0f3E096B8C`, which decodes to ~0.1341993,
/// not the intended 3 * 0.044715 = 0.134145; corrected to `0f3E095D4F`.
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mov.f32 %c, 0f3D372713;
    // 3 * 0.044715 = 0.134145 = 0x3E095D4F
    mov.f32 %c3, 0f3E095D4F;

    // u = sqrt(2/π) * (x + 0.044715 * x³)
    mul.f32 %x2, %x, %x;
    mul.f32 %x3, %x2, %x;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) via exp
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh²)*sqrt(2/π)*(1+3*0.044715*x²)
    // term1 = 0.5 * (1 + th)
    add.f32 %term1, %one, %th;
    mul.f32 %term1, %half, %term1;

    // (1 - th²)
    mul.f32 %th2, %th, %th;
    sub.f32 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/π) * (1 + 3*0.044715*x²)
    mul.f32 %d_inner, %c3, %x2;
    add.f32 %d_inner, %one, %d_inner;
    mul.f32 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th²) * d_inner
    mul.f32 %term2, %half, %x;
    mul.f32 %term2, %term2, %one_m_th2;
    mul.f32 %term2, %term2, %d_inner;

    add.f32 %d_gelu, %term1, %term2;
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1534
/// PTX for SiLU (swish) forward: `out[i] = x * sigmoid(x)`.
/// sigmoid(x) = 1/(1 + exp(-x)), with exp built from `ex2.approx`
/// (exp(y) = 2^(y*log2(e))) and the division from `rcp.approx`, so
/// results carry hardware-approximation error. One thread per dense
/// f32 element (byte offset = tid * 4); launch a 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    // exp(-x) = 2^(-x * log2(e))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;
    // silu(x) = x * sigmoid(x)
    mul.f32 %vr, %x, %sig;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
1593
/// PTX for SiLU backward: `out[i] = grad[i] * (sig + x*sig*(1-sig))`
/// where `sig = sigmoid(input[i])`. The sigmoid is recomputed here from
/// the saved forward *input* (not the forward output), using the same
/// `ex2.approx`/`rcp.approx` construction as the forward kernel.
/// One thread per dense f32 element; 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
    .reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(x) = 1 / (1 + exp(-x))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f32 %one_m_sig, %one, %sig;
    mul.f32 %x_sig_omsig, %x, %sig;
    mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f32 %deriv, %sig, %x_sig_omsig;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1657
/// PTX for ELU forward: `out[i] = x > 0 ? x : alpha * (exp(x) - 1)`.
/// `alpha` arrives as an f32 kernel parameter. The negative branch is
/// computed unconditionally (exp via `ex2.approx`) and the final value
/// is chosen branchlessly with `selp`; for x > 0 any overflow in the
/// unused exp result is discarded by the select. One thread per dense
/// f32 element; 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    sub.f32 %em1, %ex, %one;
    mul.f32 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha*(exp(x)-1)
    mov.f32 %vr, 0f00000000;
    setp.gt.f32 %pos, %x, %vr;
    selp.f32 %vr, %x, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
1716
/// PTX for ELU backward: `out[i] = x > 0 ? grad : grad * alpha * exp(x)`,
/// using d/dx[alpha*(exp(x)-1)] = alpha*exp(x). The exponential is
/// recomputed from the saved forward *input* via `ex2.approx`; the
/// branch is selected with `selp`. One thread per dense f32 element.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %lg2e, 0f3FB8AA3B;
    mov.f32 %zero, 0f00000000;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    // negative branch: grad * alpha * exp(x)
    mul.f32 %neg_branch, %vg, %alpha_r;
    mul.f32 %neg_branch, %neg_branch, %ex;

    // x > 0 ? grad : grad * alpha * exp(x)
    setp.gt.f32 %pos, %x, %zero;
    selp.f32 %vr, %vg, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
1780
1781#[cfg(feature = "cuda")]
/// PTX for Mish forward: `out[i] = x * tanh(softplus(x))`,
/// softplus(x) = ln(1 + exp(x)), built from `ex2.approx`/`lg2.approx`/
/// `rcp.approx`. One thread per dense f32 element; 1-D grid covering `n`.
///
/// For x > 20 a fast path is taken: softplus(x) ~ x and, in f32,
/// tanh(x) == 1.0 exactly for x > 20 (tanh(20) = 1 - 5e-18), so
/// mish(x) == x.
///
/// Fix: the old fast path still evaluated tanh via exp(2x), which
/// overflows to +inf for x > ~44.4 and then produced NaN through
/// inf * rcp(inf); it now stores x directly, which is identical to
/// f32 precision for all x > 20 and never overflows.
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0 = 0x41A00000
    mov.f32 %threshold, 0f41A00000;

    // softplus(x) = ln(1 + exp(x))
    // For large x (> 20), softplus ~ x to avoid overflow
    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    // ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
    lg2.approx.f32 %lg_ep1, %ep1;
    // 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %th, %e2sp_m1, %e2sp_p1;

    mul.f32 %vr, %x, %th;
    st.global.f32 [%out], %vr;
    bra DONE;

LARGE_X:
    // For x > 20, tanh(softplus(x)) == 1.0 in f32, so mish(x) == x.
    // (Evaluating tanh via exp(2x) here would overflow to +inf for
    // x > ~44 and yield NaN through inf * rcp(inf).)
    st.global.f32 [%out], %x;

DONE:
    ret;
}
";
1870
1871#[cfg(feature = "cuda")]
/// PTX for Mish backward. With t = tanh(softplus(x)) and
/// sig = sigmoid(x), uses d/dx mish(x) = t + x * sig * (1 - t^2),
/// so `out[i] = grad[i] * deriv`. All transcendentals are built from
/// `ex2.approx`/`lg2.approx`/`rcp.approx`. One thread per dense f32
/// element; 1-D grid covering `n`.
///
/// For x > 20 a fast path is taken: in f32, sigmoid(x) == 1 and
/// tanh(softplus(x)) == 1 exactly for x > 20, so the derivative is 1
/// and the gradient passes through unchanged.
///
/// Fix: the old fast path still evaluated tanh via exp(2x), which
/// overflows to +inf for x > ~44.4 and then produced NaN through
/// inf * rcp(inf); it now stores grad directly, identical to f32
/// precision for all x > 20 and never overflowing.
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
    .reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0
    mov.f32 %threshold, 0f41A00000;

    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // --- Normal path ---
    // softplus: sp = ln(1 + exp(x))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    lg2.approx.f32 %lg_ep1, %ep1;
    // ln(2) = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %t, %e2sp_m1, %e2sp_p1;

    // sig = sigmoid(x) = 1/(1+exp(-x))
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %en, %neg;
    add.f32 %denom, %one, %en;
    rcp.approx.f32 %sig, %denom;

    // deriv = t + x * sig * (1 - t*t)
    mul.f32 %t2, %t, %t;
    sub.f32 %one_m_t2, %one, %t2;
    mul.f32 %x_sig_omt2, %x, %sig;
    mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
    add.f32 %deriv, %t, %x_sig_omt2;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;
    bra DONE;

LARGE_X:
    // For x > 20: sigmoid(x) == 1 and tanh(softplus(x)) == 1 in f32,
    // so deriv == 1 and the gradient passes through unchanged.
    // (Evaluating tanh via exp(2x) here would overflow to +inf for
    // x > ~44 and yield NaN through inf * rcp(inf).)
    st.global.f32 [%out], %vg;

DONE:
    ret;
}
";
1987
/// PTX for element-wise clamp: `out[i] = min(max(x, min_val), max_val)`.
/// `min_val`/`max_val` arrive as f32 kernel parameters; the clamp is
/// done directly with `max.f32` then `min.f32`. One thread per dense
/// f32 element; 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry clamp_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 min_val,
    .param .f32 max_val
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %mn, %mx, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %mn, [min_val];
    ld.param.f32 %mx, [max_val];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    max.f32 %result, %x, %mn;
    min.f32 %result, %result, %mx;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2036
/// PTX for ReLU backward: `out[i] = input[i] > 0 ? grad[i] : 0`.
/// Uses the saved forward *input* for the sign test; the select is
/// branchless (`setp` + `selp`). One thread per dense f32 element;
/// 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %zero, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2091
/// PTX for the sigmoid-approximation GELU backward. With the forward
/// approximated as gelu(x) ~ x * sigmoid(k*x), this computes
/// `out[i] = grad[i] * (sig + k*x*sig*(1-sig))` where
/// sig = sigmoid(k * input[i]), built from `ex2.approx`/`rcp.approx`.
/// One thread per dense f32 element; 1-D grid covering `n`.
///
/// NOTE(review): the inline comment says k = 1.702, but 0f3FDA2720
/// decodes to ~1.70431 (1.702f would be 0f3FD9DB23). If the forward
/// sigmoid-GELU kernel uses the same bits the pair is at least
/// self-consistent — verify against the forward kernel before changing
/// either constant.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(1.702 * x)
    mov.f32 %k, 0f3FDA2720;
    mul.f32 %kx, %k, %x;
    neg.f32 %neg_kx, %kx;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_kx, %neg_kx, %log2e;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;

    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
    sub.f32 %one_minus_sig, %one, %sig;
    mul.f32 %kx_sig_oms, %kx, %sig;
    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f32 %dsig, %sig, %kx_sig_oms;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %dsig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2167
/// PTX for exact (erf-based) GELU backward:
/// `out[i] = grad[i] * (Phi(x) + x * phi(x))` where
/// Phi(x) = 0.5*(1 + erf(x/sqrt(2))) is the standard normal CDF and
/// phi(x) = exp(-x^2/2)/sqrt(2*pi) its PDF. erf uses the same
/// polynomial-in-t approximation as `GELU_ERF_PTX`; all exponentials go
/// through `ex2.approx`. One thread per dense f32 element.
///
/// NOTE(review): same caveat as GELU_ERF_PTX — the a1..a5 hex constants
/// below do not decode to the Abramowitz & Stegun 7.1.26 coefficients
/// that p = 0.3275911 implies (they decode to ~0.3141, -0.5147, 1.4143,
/// -0.3509, 0.1354). Verify the intended coefficient set / accuracy.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f32 %d_gelu, %result;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mov.f32 %a5, 0f3E0AAAAB;
    mov.f32 %a4, 0fBEB3A903;
    mov.f32 %a3, 0f3FB506DD;
    mov.f32 %a2, 0fBF03C1E1;
    mov.f32 %a1, 0f3EA0D6BB;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // Φ(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %cdf, %one, %erf_val;
    mul.f32 %cdf, %half, %cdf;

    // φ(x) = exp(-x²/2) / sqrt(2π)
    // exp(-x²/2):
    mul.f32 %neg_x2h, %x, %x;
    mul.f32 %neg_x2h, %neg_x2h, %half;
    neg.f32 %neg_x2h, %neg_x2h;
    mul.f32 %neg_x2h, %neg_x2h, %log2e;
    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;

    // 1/sqrt(2π) = 0.39894228
    mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Φ(x) + x * φ(x)
    mul.f32 %x_pdf, %x, %pdf;
    add.f32 %d_gelu, %cdf, %x_pdf;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2295
/// PTX for 1-D gather: `out[i] = input[indices[i]]` for i in 0..n_indices.
/// Indices are stored as f32 and truncated toward zero to u32
/// (`cvt.rzi.u32.f32`). Note: the gathered index itself is NOT bounds
/// checked — the caller must guarantee every index is < len(input).
/// One thread per output element; 1-D grid covering `n_indices`.
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
2355
/// PTX for 1-D scatter-add (the backward of index_select):
/// `grad_input[indices[i]] += grad_output[i]`, using
/// `atom.global.add.f32` so duplicate indices accumulate correctly.
/// Indices are f32, truncated toward zero to u32, and are NOT bounds
/// checked — callers must keep them in range. `grad_input` is an
/// accumulator buffer (presumably zero-initialized by the caller —
/// verify at the call site). One thread per index.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_input_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %go, %indices, %gi, %off, %addr;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %gi, [grad_input_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read grad_output[tid]
    add.u64 %addr, %go, %off;
    ld.global.f32 %grad_val, [%addr];

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Atomic add: grad_input[idx] += grad_val
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %gi, %addr;
    atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
    ret;
}
";
2417
/// PTX for masked fill: `out[i] = mask[i] >= 0.5 ? fill_value : input[i]`.
/// The mask is an f32 buffer interpreted as boolean via the >= 0.5 test
/// (so 1.0 means "fill"). `fill_value` arrives as an f32 kernel
/// parameter; the select is branchless (`setp` + `selp`). One thread
/// per dense f32 element; 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
    .param .u64 input_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .f32 fill_value,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %input, %mask, %out, %off;
    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %fill, [fill_value];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %input, %input, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %in_val, [%input];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %fill, %in_val, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2474
/// PTX for masked-fill backward: `out[i] = mask[i] >= 0.5 ? 0 : grad[i]`
/// — the gradient is zeroed wherever the forward pass overwrote the
/// input with the fill value. Same f32-mask convention (>= 0.5 means
/// "masked") as `MASKED_FILL_PTX`. One thread per dense f32 element.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_zero_kernel(
    .param .u64 grad_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %mask, %out, %off;
    .reg .f32 %vg, %mask_val, %zero, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %zero, 0f00000000;
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %zero, %vg, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2530
/// PTX for sigmoid backward using the saved forward *output* o:
/// `out[i] = grad[i] * o * (1 - o)`. No exponentials are recomputed —
/// the identity d/dx sigmoid(x) = o*(1-o) needs only the forward
/// result. One thread per dense f32 element; 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    sub.f32 %one_minus_o, %one, %vo;
    mul.f32 %result, %vo, %one_minus_o;
    mul.f32 %result, %vg, %result;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2584
/// PTX for tanh backward using the saved forward *output* o:
/// `out[i] = grad[i] * (1 - o^2)` (the identity d/dx tanh(x) = 1 - o^2
/// needs only the forward result — no exponentials recomputed).
/// One thread per dense f32 element; 1-D grid covering `n`.
#[cfg(feature = "cuda")]
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    mul.f32 %o_sq, %vo, %vo;
    sub.f32 %one_minus_sq, %one, %o_sq;
    mul.f32 %result, %vg, %one_minus_sq;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2638
2639#[cfg(feature = "cuda")]
/// PTX for the softmax backward pass, one thread block per row:
///
///   out[j] = y[j] * (g[j] - dot(g, y))
///
/// where `y` is the forward softmax output and `g` the incoming gradient.
/// The per-row dot product is built from block-strided partial sums followed
/// by a shared-memory tree reduction over the 256-float `sdata` scratch
/// (the halving loop assumes a power-of-two blockDim.x of at most 256),
/// then broadcast to all threads from sdata[0].
///
/// Written as a plain multiline literal (matching the other PTX consts in
/// this file); the string value is byte-identical to the previous
/// backslash-continued form, whose continuations stripped all leading
/// whitespace — hence the flush-left PTX body.
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry softmax_backward_kernel(
.param .u64 grad_ptr,
.param .u64 output_ptr,
.param .u64 out_ptr,
.param .u32 rows,
.param .u32 cols
) {
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;
.reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;
.reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;
.reg .pred %p, %loop_p, %reduce_p;

ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %output, [output_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];

mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mov.u64 %sbase, sdata;

setp.ge.u32 %p, %bid, %rows_reg;
@%p bra DONE;

// row_off = bid * cols * 4 (byte offset)
cvt.u64.u32 %row_off, %bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;

// Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements
mov.f32 %dot, 0f00000000;
mov.u32 %j, %r_tid;
DOT_LOOP:
setp.ge.u32 %loop_p, %j, %cols_reg;
@%loop_p bra DOT_LOOP_DONE;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %saddr, %grad, %off;
add.u64 %saddr, %saddr, %row_off;
ld.global.f32 %vg, [%saddr];
add.u64 %saddr, %output, %off;
add.u64 %saddr, %saddr, %row_off;
ld.global.f32 %vo, [%saddr];
fma.rn.f32 %dot, %vg, %vo, %dot;
add.u32 %j, %j, %bdim;
bra DOT_LOOP;
DOT_LOOP_DONE:

// Store partial dot into shared memory and reduce
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %dot;
bar.sync 0;

mov.u32 %half, %bdim;
DOT_REDUCE:
shr.u32 %half, %half, 1;
setp.eq.u32 %reduce_p, %half, 0;
@%reduce_p bra DOT_REDUCE_DONE;
setp.ge.u32 %reduce_p, %r_tid, %half;
@%reduce_p bra DOT_REDUCE_SKIP;
add.u32 %other_tid, %r_tid, %half;
cvt.u64.u32 %off, %other_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %dot, [%saddr];
add.f32 %dot, %dot, %other_val;
st.shared.f32 [%saddr], %dot;
DOT_REDUCE_SKIP:
bar.sync 0;
bra DOT_REDUCE;
DOT_REDUCE_DONE:

// Broadcast dot to all threads
ld.shared.f32 %dot, [sdata];
bar.sync 0;

// Phase 2: out[j] = output[j] * (grad[j] - dot)
mov.u32 %j, %r_tid;
WRITE_LOOP:
setp.ge.u32 %loop_p, %j, %cols_reg;
@%loop_p bra WRITE_LOOP_DONE;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %saddr, %grad, %off;
add.u64 %saddr, %saddr, %row_off;
ld.global.f32 %vg, [%saddr];
add.u64 %saddr, %output, %off;
add.u64 %saddr, %saddr, %row_off;
ld.global.f32 %vo, [%saddr];
sub.f32 %diff, %vg, %dot;
mul.f32 %result, %vo, %diff;
add.u64 %saddr, %out, %off;
add.u64 %saddr, %saddr, %row_off;
st.global.f32 [%saddr], %result;
add.u32 %j, %j, %bdim;
bra WRITE_LOOP;
WRITE_LOOP_DONE:

DONE:
ret;
}
";
2766
2767#[cfg(feature = "cuda")]
/// PTX kernel computing a row-wise log-softmax over an f32 matrix.
///
/// Launch with one thread block per row (`%ctaid.x` selects the row) and at
/// most 256 threads per block: the shared scratch `sdata` holds 256 floats,
/// and the tree reductions repeatedly halve `%ntid.x`, which assumes a
/// power-of-two block size.
///
/// Three block-strided phases per row, each followed by a shared-memory tree
/// reduction and a broadcast from sdata[0] where needed:
///   1. max  = max_j x[j]                   (numerical stability; seeded -inf)
///   2. sum  = sum_j exp(x[j] - max)        (exp via `ex2.approx` * log2(e))
///   3. out[j] = x[j] - (max + log(sum))    (log via `lg2.approx` * ln(2))
///
/// Results are approximate due to the `.approx` transcendentals.
pub(crate) const LOG_SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: find max across row (grid-stride over columns)\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    // Shared-memory tree reduction for max\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    // Broadcast max to all threads\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: compute partial sum of exp(x[j] - max)\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    // exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    // Shared-memory tree reduction for sum\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
    // ln(2) = 0x3F317218\n\
    lg2.approx.f32 %log_sum_exp, %sum_val;\n\
    mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
    add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    // Phase 3: out[j] = x[j] - log_sum_exp\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %val, [%saddr];\n\
    sub.f32 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
2952
2953#[cfg(feature = "cuda")]
/// PTX kernel for the log-softmax backward pass, one thread block per row:
///
///   out[j] = grad[j] - exp(output[j]) * sum_k grad[k]
///
/// `output` holds the forward log-softmax result, so exp(output[j])
/// (synthesized as `ex2.approx` of output[j] * log2(e)) recovers the softmax
/// probability. The row-wide gradient sum is computed with block-strided
/// partial sums plus a shared-memory tree reduction over the 256-float
/// `sdata` scratch (the halving loop assumes a power-of-two blockDim.x of at
/// most 256), then broadcast to all threads from sdata[0].
pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
    mov.f32 %sum_grad, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Store partial sum into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_grad, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum_grad to all threads\n\
    ld.shared.f32 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    // exp(log_softmax_output) = softmax probability\n\
    mul.f32 %vo, %vo, 0f3FB8AA3B;\n\
    ex2.approx.f32 %softmax_j, %vo;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f32 %result, %softmax_j, %sum_grad;\n\
    sub.f32 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
3082
3083#[cfg(feature = "cuda")]
/// PTX kernel computing per-block partial sums of an f32 array.
///
/// Each thread grid-strides over the input accumulating a private sum; the
/// block then tree-reduces the 256 partial sums in shared memory, and thread
/// 0 writes one result per block to `out_ptr[blockIdx.x]`. The reduction
/// starts at half = 128, so the host must launch with exactly 256 threads
/// per block and finish the reduction over the per-block outputs itself.
///
/// Fix: shared-memory loads/stores previously used the address form
/// `[sdata + %off]` with a *register* offset. The PTX ISA only allows an
/// immediate offset after a variable name (`[var]`, `[var+imm]`, `[reg]`,
/// `[reg+imm]`), so that form is rejected at JIT time. Addresses are now
/// formed explicitly via `mov.u64 %sbase, sdata` + `add.u64`, matching the
/// pattern used by every other kernel in this file.
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off, %sbase, %saddr;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;
    mov.u64 %sbase, sdata;

    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;

GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;

GRID_DONE:
    // Write thread's partial sum to shared memory.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum;
    bar.sync 0;

    // Tree reduction in shared memory (assumes blockDim.x == 256).
    mov.u32 %half, 128;
TREE_LOOP:
    setp.lt.u32 %p, %half, 1;
    @%p bra TREE_DONE;

    setp.ge.u32 %ptid, %tid, %half;
    @%ptid bra TREE_SKIP;

    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other, [%saddr];
    // Load own value.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum, [%saddr];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%saddr], %sum;

TREE_SKIP:
    bar.sync 0;
    shr.u32 %half, %half, 1;
    bra TREE_LOOP;

TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %tid, 0;
    @%ptid bra END;

    ld.shared.f32 %sum, [sdata];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;

END:
    ret;
}
";
3190
3191#[cfg(feature = "cuda")]
/// PTX kernel summing an f32 tensor along one axis.
///
/// The tensor is viewed as [outer_size, axis_size, inner_size]; each thread
/// computes exactly one of the total_output = outer_size * inner_size
/// results by serially adding axis_size elements spaced inner_size apart.
/// No shared memory or cross-thread reduction is involved; threads with
/// tid >= total_output exit immediately.
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;

    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;

    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;

    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:

    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
";
3268
3269#[cfg(feature = "cuda")]
/// PTX kernel for an inclusive cumulative sum (prefix sum) along one
/// dimension of an f32 tensor viewed as [outer_size, dim_size, inner_size].
///
/// One thread serves one (outer, inner) lane: it walks the lane's dim_size
/// elements with stride inner_size, keeping a running sum and storing it at
/// every position. The scan along the dimension is serial; parallelism
/// comes only from the outer_size * inner_size lanes.
///
/// NOTE(review): the bounds check recomputes outer_size * inner_size in the
/// kernel; the `total` parameter is loaded into %n_reg but otherwise unused.
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumsum_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    // total threads = outer * inner
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    add.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
3353
3354#[cfg(feature = "cuda")]
/// PTX kernel for an inclusive cumulative product along one dimension of an
/// f32 tensor viewed as [outer_size, dim_size, inner_size].
///
/// Same lane decomposition as `CUMSUM_PTX`: one thread per (outer, inner)
/// lane, serial over dim_size, writing the running product at every
/// position. The accumulator starts at 1.0 (0f3F800000).
///
/// NOTE(review): the `total` parameter is loaded into %n_reg but unused;
/// the bounds check uses outer_size * inner_size computed in-kernel.
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumprod_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = 1.0
    mov.f32 %acc, 0f3F800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    mul.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
3429
3430#[cfg(feature = "cuda")]
/// PTX kernel for an inclusive cumulative maximum along one dimension of an
/// f32 tensor viewed as [outer_size, dim_size, inner_size], writing both the
/// running max (`output_ptr`) and the position attaining it (`indices_ptr`).
///
/// One thread per (outer, inner) lane, serial over dim_size. The best index
/// is updated via a predicated move only on a strictly-greater compare, so
/// ties keep the earliest index; indices are stored as f32 through
/// cvt.rn.f32.u32. The accumulator starts at -inf (0xFF800000).
///
/// NOTE(review): the `total` parameter is loaded but unused; bounds use
/// outer_size * inner_size computed in-kernel.
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummax_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_max;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0xFF800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.gt.f32 %is_new_max, %val, %acc;
    @%is_new_max mov.u32 %best_k, %k;
    max.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
3515
3516#[cfg(feature = "cuda")]
/// PTX kernel for an inclusive cumulative minimum along one dimension of an
/// f32 tensor viewed as [outer_size, dim_size, inner_size], writing both the
/// running min (`output_ptr`) and the position attaining it (`indices_ptr`).
///
/// Mirror image of `CUMMAX_PTX`: one thread per (outer, inner) lane, serial
/// over dim_size. The best index updates via a predicated move only on a
/// strictly-less compare, so ties keep the earliest index; indices are
/// stored as f32 through cvt.rn.f32.u32. The accumulator starts at +inf
/// (0x7F800000).
///
/// NOTE(review): the `total` parameter is loaded but unused; bounds use
/// outer_size * inner_size computed in-kernel.
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummin_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_min;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0x7F800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.lt.f32 %is_new_min, %val, %acc;
    @%is_new_min mov.u32 %best_k, %k;
    min.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
3599
3600#[cfg(feature = "cuda")]
/// PTX kernel for an inclusive log-cumsum-exp scan along one dimension:
/// out[k] = log(sum_{i<=k} exp(x[i])), tensor viewed as
/// [outer_size, dim_size, inner_size], one thread per (outer, inner) lane.
///
/// Uses the numerically stable pairwise update
///   m = max(acc, x);  acc = m + log(exp(acc - m) + exp(x - m))
/// with exp/log synthesized from ex2.approx / lg2.approx and the constants
/// log2(e) (0x3FB8AA3B) and ln(2) (0x3F317218). acc starts at -inf, so the
/// first output equals the first input.
///
/// NOTE(review): if an element and the accumulator are both -inf, the
/// update forms (-inf) - (-inf) = NaN — confirm inputs exclude -inf. The
/// `total` parameter is loaded but unused; bounds use outer_size *
/// inner_size computed in-kernel.
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc, %m, %ea, %ev, %s, %ls, %log2e, %ln2;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    // log2(e) = 1.4426950408... -> 0x3FB8AA3B
    mov.b32 %log2e, 0x3FB8AA3B;
    // ln(2) = 0.6931471805... -> 0x3F317218
    mov.b32 %ln2, 0x3F317218;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b32 %acc, 0xFF800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // Numerically stable: m = max(acc, x)
    max.f32 %m, %acc, %val;
    // exp(acc - m): (acc - m) * log2(e) -> ex2
    sub.f32 %ea, %acc, %m;
    mul.f32 %ea, %ea, %log2e;
    ex2.approx.f32 %ea, %ea;
    // exp(x - m): (x - m) * log2(e) -> ex2
    sub.f32 %ev, %val, %m;
    mul.f32 %ev, %ev, %log2e;
    ex2.approx.f32 %ev, %ev;
    // sum
    add.f32 %s, %ea, %ev;
    // log(sum) = lg2(sum) * ln(2)
    lg2.approx.f32 %ls, %s;
    mul.f32 %ls, %ls, %ln2;
    // acc = m + log(sum)
    add.f32 %acc, %m, %ls;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
3700
3701#[cfg(feature = "cuda")]
/// PTX kernel applying layer normalization to each row of an f32 matrix:
///
///   out[j] = w[j] * (x[j] - mean) / sqrt(var + eps) + b[j]
///
/// One thread block per row (%ctaid.x), block-strided column loops, and two
/// shared-memory tree reductions (mean, then variance) over the 256-float
/// `sdata` scratch — so blockDim.x must be a power of two and at most 256.
/// The variance is biased (divide by cols, not cols - 1). 1/sqrt is built
/// from sqrt.approx + rcp.approx and the mean/variance divisions use
/// div.approx, so results are approximate.
///
/// Label glossary: SM/SMD = sum-for-mean loop/done, MR/MRS/MRD = mean
/// reduction loop/skip/done, SV/SVD = sum-for-variance, VR/VRS/VRD =
/// variance reduction, NM/NMD = normalize-and-write loop.
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u64 b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra SM;
SMD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
MRS:
    bar.sync 0;
    bra MR;
MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra SV;
SVD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
VRS:
    bar.sync 0;
    bra VR;
VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %normed, %val, %mean;
    mul.f32 %normed, %normed, %inv_std;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %b, %off;
    ld.global.f32 %bv, [%off];
    fma.rn.f32 %result, %wv, %normed, %bv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
3884
/// PTX source for `layernorm_backward_kernel`: fused LayerNorm backward over a
/// `rows x cols` f32 matrix.
///
/// Grid mapping: one CUDA block per row (`%ctaid.x` is checked against `rows`);
/// threads stride over the row's columns by `%ntid.x`.
///
/// Per row the kernel recomputes mean and variance, then computes
/// `dl_dx_hat[j] = grad_out[j] * w[j]` and
/// `grad_in[j] = inv_std * (dl_dx_hat[j] - mean(dl_dx_hat) - x_hat[j] * mean(dl_dx_hat * x_hat))`.
/// `grad_w[j] += grad_out[j] * x_hat[j]` and `grad_b[j] += grad_out[j]` are
/// accumulated across rows with `atom.global.add.f32`, so both gradient
/// buffers are assumed to be zero-initialized before launch — confirm in the
/// host wrapper.
///
/// The 256-slot shared buffer and the halving (`shr`) tree reductions require
/// a launch block size that is a power of two and at most 256 — confirm at
/// the launch site. Uses approximate ops (`div.approx`, `sqrt.approx`,
/// `rcp.approx`), so results differ slightly from exact IEEE arithmetic.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u64 grad_b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u64 %gb, [grad_b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra LNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute mean =====
    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SM;
LNB_SMD:
    // Shared memory reduce for mean
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    st.shared.f32 [%saddr], %mean;
LNB_MRS:
    bar.sync 0;
    bra LNB_MR;
LNB_MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    // ===== Phase 2: Compute variance =====
    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SV;
LNB_SVD:
    // Shared memory reduce for variance
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    st.shared.f32 [%saddr], %var;
LNB_VRS:
    bar.sync 0;
    bra LNB_VR;
LNB_VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
    // Also accumulate grad_weight and grad_bias via atomicAdd
    mov.f32 %sum1, 0f00000000;
    mov.f32 %sum2, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_S12:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_S12D;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // x_hat = (val - mean) * inv_std
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dl_dx_hat = grad_output * weight
    mul.f32 %dl_dx_hat, %gov, %wv;
    // Accumulate sums
    add.f32 %sum1, %sum1, %dl_dx_hat;
    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
    // atomicAdd grad_weight[j] += grad_output * x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %x_hat;
    atom.global.add.f32 %result, [%addr], %result;
    // atomicAdd grad_bias[j] += grad_output
    add.u64 %addr, %gb, %off;
    atom.global.add.f32 %result, [%addr], %gov;
    add.u32 %j, %j, %r_bdim;
    bra LNB_S12;
LNB_S12D:
    // Reduce sum1 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum1;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R1:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R1D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R1S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum1, [%saddr];
    add.f32 %sum1, %sum1, %other_val;
    st.shared.f32 [%saddr], %sum1;
LNB_R1S:
    bar.sync 0;
    bra LNB_R1;
LNB_R1D:
    ld.shared.f32 %sum1, [%sbase];
    // mean1 = sum1 / n
    div.approx.f32 %mean1, %sum1, %n_f;
    bar.sync 0;

    // Reduce sum2 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum2;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R2:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R2D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R2S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum2, [%saddr];
    add.f32 %sum2, %sum2, %other_val;
    st.shared.f32 [%saddr], %sum2;
LNB_R2S:
    bar.sync 0;
    bra LNB_R2;
LNB_R2D:
    ld.shared.f32 %sum2, [%sbase];
    // mean2 = sum2 / n
    div.approx.f32 %mean2, %sum2, %n_f;
    bar.sync 0;

    // ===== Phase 4: Compute grad_input =====
    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
    mov.u32 %j, %r_tid;
LNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_GID;
    // Reload input to recompute x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Reload grad_output and weight to recompute dl_dx_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    mul.f32 %dl_dx_hat, %gov, %wv;
    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
    sub.f32 %result, %dl_dx_hat, %mean1;
    mul.f32 %diff, %x_hat, %mean2;
    sub.f32 %result, %result, %diff;
    mul.f32 %result, %inv_std, %result;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra LNB_GI;
LNB_GID:

LNB_DONE:
    ret;
}
";
4213
/// PTX source for `rmsnorm_kernel`: RMSNorm forward over a `rows x cols` f32
/// matrix: `out[r, j] = in[r, j] * inv_rms(r) * w[j]` with
/// `inv_rms(r) = 1 / sqrt(mean(in[r, :]^2) + eps)`.
///
/// Grid mapping: one CUDA block per row (`%ctaid.x` checked against `rows`);
/// threads stride the row's columns by `%ntid.x`. The per-thread partial sums
/// of squares are combined through a shared-memory tree reduction over the
/// 256-slot `sdata` buffer, so the launch block size must be a power of two
/// and at most 256 — confirm at the launch site. Uses approximate ops
/// (`div.approx`, `sqrt.approx`, `rcp.approx`).
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra SS;
SSD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
SRS:
    bar.sync 0;
    bra SR;
SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Normalize and scale =====
    // out[j] = x[j] * inv_rms * weight[j]
    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    mul.f32 %result, %val, %inv_rms;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    mul.f32 %result, %result, %wv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
4349
/// PTX source for `rmsnorm_backward_kernel`: RMSNorm backward over a
/// `rows x cols` f32 matrix.
///
/// Grid mapping: one CUDA block per row; threads stride the columns by
/// `%ntid.x`. Per row:
/// `inv_rms = 1 / sqrt(mean(x^2) + eps)`,
/// `dot = sum_j(go[j] * x[j] * w[j])`,
/// `coeff = dot * inv_rms^3 / cols`,
/// `grad_in[j] = inv_rms * w[j] * go[j] - x[j] * coeff`.
/// `grad_w[j] += go[j] * x[j] * inv_rms` is accumulated across rows with
/// `atom.global.add.f32`, so `grad_w` is assumed zero-initialized before
/// launch — confirm in the host wrapper.
///
/// The 256-slot shared buffer and the halving tree reductions require a
/// launch block size that is a power of two and at most 256 — confirm at the
/// launch site. Uses approximate ops (`div.approx`, `sqrt.approx`,
/// `rcp.approx`).
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
    .reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra RNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) -> inv_rms =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra RNB_SS;
RNB_SSD:
    // Shared memory reduce for sum(x^2)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    st.shared.f32 [%saddr], %sq_sum;
RNB_SRS:
    bar.sync 0;
    bra RNB_SR;
RNB_SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    // inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
    mul.f32 %inv_rms3, %inv_rms, %inv_rms;
    mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
    // Also accumulate grad_weight via atomicAdd
    mov.f32 %dot, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_DOT:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_DOTD;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dot += go * x * w
    mul.f32 %tmp, %gov, %val;
    fma.rn.f32 %dot, %tmp, %wv, %dot;
    // atomicAdd grad_weight[j] += go * x * inv_rms
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %val;
    mul.f32 %result, %result, %inv_rms;
    atom.global.add.f32 %result, [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_DOT;
RNB_DOTD:
    // Reduce dot in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %dot;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_DR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_DRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_DRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %dot, [%saddr];
    add.f32 %dot, %dot, %other_val;
    st.shared.f32 [%saddr], %dot;
RNB_DRS:
    bar.sync 0;
    bra RNB_DR;
RNB_DRD:
    ld.shared.f32 %dot, [%sbase];
    // coeff = dot * inv_rms3 / n
    mul.f32 %coeff, %dot, %inv_rms3;
    div.approx.f32 %coeff, %coeff, %n_f;
    bar.sync 0;

    // ===== Phase 3: Compute grad_input =====
    // grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
    mov.u32 %j, %r_tid;
RNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_GID;
    // Reload input
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Reload grad_output and weight
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // result = inv_rms * w * go - x * coeff
    mul.f32 %result, %inv_rms, %wv;
    mul.f32 %result, %result, %gov;
    mul.f32 %tmp, %val, %coeff;
    sub.f32 %result, %result, %tmp;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_GI;
RNB_GID:

RNB_DONE:
    ret;
}
";
4587
/// PTX source for `batchnorm_forward_kernel`: per-channel BatchNorm forward.
///
/// Grid mapping: one CUDA block per channel (`%ctaid.x` is the channel index,
/// checked against `channels`); threads stride by `%ntid.x` over the
/// `total_per_ch` elements of that channel, assuming NCHW layout
/// (`linear = batch_idx * channels * spatial + ch * spatial + spatial_idx`).
///
/// Behavior:
/// * `training != 0`: batch mean/var are computed with a shared-memory tree
///   reduction, saved to `save_mean`/`save_invstd` (for the backward pass),
///   and folded into the running statistics as
///   `running = running * (1 - momentum) + batch_stat * momentum`.
///   NOTE(review): the running variance is updated with the *biased* batch
///   variance; frameworks such as PyTorch store the unbiased one — confirm
///   which the host side expects.
/// * `training == 0`: mean/var are read from `running_mean`/`running_var`;
///   `save_mean`/`save_invstd` are not written in this mode — confirm
///   callers do not rely on them in eval.
///
/// Fixes over the previous revision: the pass-1 loop no longer destroys its
/// own counter (the spatial remainder goes to a dedicated register), the
/// variance is genuinely `E[x^2] - mean^2` (the old code cancelled the
/// `-mean^2` term out), shared-memory addressing uses an explicit base
/// register like the other kernels in this file, and pass 2 (normalize +
/// affine, previously a placeholder that wrote nothing) is implemented.
///
/// The two 256-slot shared buffers and the halving tree reduction require a
/// launch block size that is a power of two and at most 256 — confirm at the
/// launch site.
#[cfg(feature = "cuda")]
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for block reduction
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];

.visible .entry batchnorm_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 rmean_ptr,
    .param .u64 rvar_ptr,
    .param .u64 save_mean_ptr,
    .param .u64 save_invstd_ptr,
    .param .u32 channels,
    .param .u32 spatial,
    .param .f32 eps,
    .param .f32 momentum,
    .param .u32 total_per_ch,
    .param .u32 training
) {
    .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
    .reg .u32 %lin, %sp_idx, %half;
    .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
    .reg .u64 %sum_base, %sq_base, %addr64;
    .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
    .reg .f32 %gamma, %beta, %eps_reg, %mom, %other, %omm;
    .reg .f32 %n_f, %one, %normalized;
    .reg .pred %p, %ptrain, %ptid0;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %b, [bias_ptr];
    ld.param.u64 %rm, [rmean_ptr];
    ld.param.u64 %rv, [rvar_ptr];
    ld.param.u64 %sm, [save_mean_ptr];
    ld.param.u64 %si, [save_invstd_ptr];
    ld.param.u32 %n_ch, [channels];
    ld.param.u32 %sp, [spatial];
    ld.param.f32 %eps_reg, [eps];
    ld.param.f32 %mom, [momentum];
    ld.param.u32 %tpc, [total_per_ch];
    ld.param.u32 %train, [training];

    mov.u64 %sum_base, smem_sum;
    mov.u64 %sq_base, smem_sq;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %ch, %bid;
    mov.f32 %one, 0f3F800000;

    setp.ge.u32 %p, %ch, %n_ch;
    @%p bra END;

    setp.ne.u32 %ptrain, %train, 0;

    // ---- Pass 1: per-thread partial sum and sum-of-squares ----
    mov.f32 %sum, 0f00000000;
    mov.f32 %sqsum, 0f00000000;

    mov.u32 %idx, %tid;
PASS1_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra PASS1_DONE;

    // linear = (idx / spatial) * channels * spatial + ch * spatial + idx % spatial
    div.u32 %lin, %idx, %sp;
    mul.lo.u32 %lin, %lin, %n_ch;
    add.u32 %lin, %lin, %ch;
    mul.lo.u32 %lin, %lin, %sp;
    rem.u32 %sp_idx, %idx, %sp;
    add.u32 %lin, %lin, %sp_idx;

    cvt.u64.u32 %off64, %lin;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %val, [%tmp64];
    add.f32 %sum, %sum, %val;
    fma.rn.f32 %sqsum, %val, %val, %sqsum;

    // Stride by the block size; %idx is the loop counter and is never clobbered
    add.u32 %idx, %idx, %bdim;
    bra PASS1_LOOP;

PASS1_DONE:
    // Publish partial sums for the block-wide tree reduction
    cvt.u64.u32 %off64, %tid;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr64, %sum_base, %off64;
    st.shared.f32 [%addr64], %sum;
    add.u64 %addr64, %sq_base, %off64;
    st.shared.f32 [%addr64], %sqsum;
    bar.sync 0;

    // Tree reduction (requires power-of-two block size <= 256)
    mov.u32 %half, %bdim;
REDUCE_LOOP:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %p, %half, 0;
    @%p bra REDUCE_DONE;
    setp.ge.u32 %p, %tid, %half;
    @%p bra REDUCE_SKIP;

    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;      // partner byte offset
    cvt.u64.u32 %tmp64, %tid;
    shl.b64 %tmp64, %tmp64, 2;      // own byte offset

    add.u64 %addr64, %sum_base, %off64;
    ld.shared.f32 %other, [%addr64];
    add.u64 %addr64, %sum_base, %tmp64;
    ld.shared.f32 %sum, [%addr64];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%addr64], %sum;

    add.u64 %addr64, %sq_base, %off64;
    ld.shared.f32 %other, [%addr64];
    add.u64 %addr64, %sq_base, %tmp64;
    ld.shared.f32 %sqsum, [%addr64];
    add.f32 %sqsum, %sqsum, %other;
    st.shared.f32 [%addr64], %sqsum;

REDUCE_SKIP:
    bar.sync 0;
    bra REDUCE_LOOP;

REDUCE_DONE:
    // Thread 0 computes (or loads) mean and invstd
    setp.ne.u32 %ptid0, %tid, 0;
    @%ptid0 bra WAIT_STATS;

    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;      // per-channel byte offset, reused below

    @%ptrain bra TRAIN_STATS;

    // Inference: use the running statistics
    add.u64 %tmp64, %rm, %off64;
    ld.global.f32 %mean, [%tmp64];
    add.u64 %tmp64, %rv, %off64;
    ld.global.f32 %var, [%tmp64];
    add.f32 %other, %var, %eps_reg;
    sqrt.rn.f32 %other, %other;
    div.rn.f32 %invstd, %one, %other;
    bra PUBLISH_STATS;

TRAIN_STATS:
    // Batch statistics: mean = sum/n, var = E[x^2] - mean^2
    ld.shared.f32 %sum, [%sum_base];
    ld.shared.f32 %sqsum, [%sq_base];
    cvt.rn.f32.u32 %n_f, %tpc;
    div.rn.f32 %mean, %sum, %n_f;
    div.rn.f32 %var, %sqsum, %n_f;
    neg.f32 %other, %mean;
    fma.rn.f32 %var, %other, %mean, %var;   // var = sqsum/n - mean^2

    // invstd = 1/sqrt(var + eps)
    add.f32 %other, %var, %eps_reg;
    sqrt.rn.f32 %other, %other;
    div.rn.f32 %invstd, %one, %other;

    // Save batch mean/invstd for the backward pass
    add.u64 %tmp64, %sm, %off64;
    st.global.f32 [%tmp64], %mean;
    add.u64 %tmp64, %si, %off64;
    st.global.f32 [%tmp64], %invstd;

    // running = running * (1 - momentum) + batch_stat * momentum
    sub.f32 %omm, %one, %mom;
    add.u64 %tmp64, %rm, %off64;
    ld.global.f32 %other, [%tmp64];
    mul.f32 %other, %other, %omm;
    fma.rn.f32 %other, %mean, %mom, %other;
    st.global.f32 [%tmp64], %other;
    add.u64 %tmp64, %rv, %off64;
    ld.global.f32 %other, [%tmp64];
    mul.f32 %other, %other, %omm;
    fma.rn.f32 %other, %var, %mom, %other;
    st.global.f32 [%tmp64], %other;

PUBLISH_STATS:
    // Broadcast mean/invstd to the rest of the block via shared memory
    st.shared.f32 [%sum_base], %mean;
    st.shared.f32 [%sq_base], %invstd;

WAIT_STATS:
    bar.sync 0;
    // All threads read the broadcast statistics
    ld.shared.f32 %mean, [%sum_base];
    ld.shared.f32 %invstd, [%sq_base];

    // Per-channel affine parameters
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %w, %off64;
    ld.global.f32 %gamma, [%tmp64];
    add.u64 %tmp64, %b, %off64;
    ld.global.f32 %beta, [%tmp64];

    // ---- Pass 2: out = (x - mean) * invstd * gamma + beta ----
    // Same indexing as pass 1
    mov.u32 %idx, %tid;
PASS2_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra END;

    div.u32 %lin, %idx, %sp;
    mul.lo.u32 %lin, %lin, %n_ch;
    add.u32 %lin, %lin, %ch;
    mul.lo.u32 %lin, %lin, %sp;
    rem.u32 %sp_idx, %idx, %sp;
    add.u32 %lin, %lin, %sp_idx;

    cvt.u64.u32 %off64, %lin;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %val, [%tmp64];
    sub.f32 %normalized, %val, %mean;
    mul.f32 %normalized, %normalized, %invstd;
    fma.rn.f32 %normalized, %normalized, %gamma, %beta;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %normalized;

    add.u32 %idx, %idx, %bdim;
    bra PASS2_LOOP;

END:
    ret;
}
";
4806
/// PTX source for `maxpool2d_forward_kernel`: 2-D max pooling forward.
///
/// Grid-stride loop over the flattened output (`total` elements, stride
/// `%ntid.x * %nctaid.x`). Each output index is decomposed from the right as
/// `(b, c, oh, ow)` assuming NCHW layout, and the max over the `kh x kw`
/// window (stride `sh`/`sw`, padding `ph`/`pw`) is taken, starting from -inf
/// (`0fFF800000`). Padding taps are skipped via the unsigned-wraparound
/// trick: `ih = oh*sh + i - ph` underflows to a huge u32, so the single
/// `ih >= h_in` test covers both `ih < 0` and `ih >= h_in`.
///
/// NOTE(review): the first `div.u32 %b_idx, %rem, %ch_reg` after `LOOP:` is
/// dead — its result is overwritten by the decompose-from-the-right sequence
/// that follows — and could be removed.
#[cfg(feature = "cuda")]
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry maxpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %max_val, %cur_val, %neg_inf;
    .reg .pred %p, %p_bounds, %p_gt;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

    // -inf for max initialization
    mov.f32 %neg_inf, 0fFF800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow)
    mov.u32 %rem, %idx;
    div.u32 %b_idx, %rem, %ch_reg;
    // Actually need: idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow
    // So decompose from the right:
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %max_val, %neg_inf;

    // Slide the kernel window
    mov.u32 %i, 0;
KH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra KH_DONE;

    mov.u32 %j, 0;
KW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra KW_DONE;

    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
    // Since unsigned, just check < h_in and < w_in
    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra KW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra KW_NEXT;

    // input_offset = b * C * H * W + c * H * W + ih * W + iw
    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    max.f32 %max_val, %max_val, %cur_val;

KW_NEXT:
    add.u32 %j, %j, 1;
    bra KW_LOOP;

KW_DONE:
    add.u32 %i, %i, 1;
    bra KH_LOOP;

KH_DONE:
    // Store output
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %max_val;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
4948
/// PTX source for `avgpool2d_forward_kernel`: 2-D average pooling forward.
///
/// Same grid-stride loop, NCHW index decomposition, and unsigned-wraparound
/// padding handling as `maxpool2d_forward_kernel`. Sums only the in-bounds
/// taps of each `kh x kw` window and divides by the number of taps that were
/// actually read, i.e. `count_include_pad = false` semantics.
///
/// NOTE(review): if a window lies entirely in the padded region, `count` is 0
/// and the `div.rn.f32` produces a non-finite value — confirm host-side
/// shape/padding validation rules this out.
#[cfg(feature = "cuda")]
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry avgpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %sum_val, 0f00000000;
    mov.u32 %count, 0;

    mov.u32 %i, 0;
AKH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra AKH_DONE;

    mov.u32 %j, 0;
AKW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra AKW_DONE;

    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra AKW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra AKW_NEXT;

    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    add.f32 %sum_val, %sum_val, %cur_val;
    add.u32 %count, %count, 1;

AKW_NEXT:
    add.u32 %j, %j, 1;
    bra AKW_LOOP;

AKW_DONE:
    add.u32 %i, %i, 1;
    bra AKH_LOOP;

AKH_DONE:
    // avg = sum / count (count_include_pad = false behavior)
    cvt.rn.f32.u32 %count_f, %count;
    div.rn.f32 %avg, %sum_val, %count_f;

    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %avg;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
5084
/// Row-wise softmax over a `[rows, cols]` f32 matrix.
///
/// Launch: one CTA per row; the row's threads cooperate. The 256-slot shared
/// array `sdata` backs two tree reductions (row max, then row sum of
/// exponentials); the halving reduction loop assumes the block size is a
/// power of two no larger than 256 — TODO confirm at the launch site.
///
/// exp(x) is evaluated as `ex2.approx(x * log2(e))` (0f3FB8AA3B = log2(e)),
/// and the final normalization multiplies by `rcp.approx` of the row sum.
///
/// Fixes vs. previous revision: two lines were missing the trailing `\n\`
/// escape (emitting a raw newline plus stray source indentation into the
/// PTX), and two redundant recomputations of `%saddr` were dropped —
/// `%saddr` still holds `%sbase + %off` at the store sites.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f32 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    rcp.approx.f32 %sum_val, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    mul.f32 %result, %val, %sum_val;\n\
    st.global.f32 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
5249
/// Dropout forward pass: one thread per element.
///
/// A per-element xorshift-style hash of (thread index, seed) yields a
/// pseudo-random u32; elements whose hash falls below `threshold` are
/// zeroed, survivors are multiplied by `scale` (inverted dropout).
///
/// Written as a plain multi-line literal (matching the other kernels in this
/// file); the string value is byte-identical to the previous `\n\`-escaped
/// form, whose line continuations stripped all leading whitespace.
#[cfg(feature = "cuda")]
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry dropout_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 n,
.param .u32 threshold,
.param .f32 scale,
.param .u32 seed
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;
.reg .u64 %in, %out, %off;
.reg .f32 %val, %scale_reg, %zero;
.reg .pred %p, %drop_p;

ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %thresh, [threshold];
ld.param.f32 %scale_reg, [scale];
ld.param.u32 %seed_reg, [seed];

mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;

mul.lo.u32 %rng, %r_tid, 2654435761;
xor.b32 %rng, %rng, %seed_reg;
shl.b32 %tmp, %rng, 13;
xor.b32 %rng, %rng, %tmp;
shr.b32 %tmp, %rng, 17;
xor.b32 %rng, %rng, %tmp;
shl.b32 %tmp, %rng, 5;
xor.b32 %rng, %rng, %tmp;

cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %val, [%in];

setp.lo.u32 %drop_p, %rng, %thresh;
mov.f32 %zero, 0f00000000;
@%drop_p mov.f32 %val, %zero;
@!%drop_p mul.f32 %val, %val, %scale_reg;

st.global.f32 [%out], %val;

DONE:
ret;
}
";
5314
/// Elementwise addition with broadcasting over arbitrary-rank tensors.
///
/// One thread per output element. The flat output index is decomposed into
/// N-d coordinates (last dimension fastest) using the `out_shape` array, and
/// the input element indices are rebuilt from the per-dimension
/// `a_strides`/`b_strides` arrays (u32 values read from global memory each
/// iteration). Presumably a broadcast dimension is encoded by the host side
/// as a zero stride — confirm at the call site.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    // Global thread index.
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Decompose flat index into N-d coordinates and compute A/B indices.
    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;

LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;

    sub.u32 %d, %d, 1;

    // Byte offset for dimension d: d * 4.
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;

    // Load out_shape[d].
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];

    // Load a_strides[d] and b_strides[d].
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];

    // coord = remaining % shape_d; remaining /= shape_d.
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;

    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;

    bra LOOP;
END_LOOP:

    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];

    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    // Operation: add.
    add.f32 %vr, %va, %vb;

    // Store to out[tid].
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;

DONE:
    ret;
}
";
5440
/// Elementwise subtraction (`a - b`) with broadcasting; identical indexing
/// scheme to `BROADCAST_ADD_PTX` (flat output index decomposed via
/// `out_shape`, input indices rebuilt from per-dimension stride arrays),
/// only the arithmetic op differs.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    sub.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
5524
/// Elementwise multiplication with broadcasting; identical indexing scheme
/// to `BROADCAST_ADD_PTX`, only the arithmetic op differs.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    mul.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
5608
/// Elementwise division (`a / b`) with broadcasting; identical indexing
/// scheme to `BROADCAST_ADD_PTX`, only the arithmetic op differs.
/// Division by zero follows IEEE-754 f32 semantics (inf/NaN), no guard.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    div.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
5693
/// Gathers one split (a contiguous range along a single axis) of a tensor
/// into a packed output buffer; one thread per output element.
///
/// For output index t with chunk = split_size * inner_size:
///   outer  = t / chunk, within = t % chunk
///   src    = outer * total_along_axis * inner_size
///            + split_offset * inner_size + within
/// so `out[t] = in[src]`. `n` is the number of output elements.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_split_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 split_offset,
    .param .u32 split_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %sp_off, [split_offset];
    ld.param.u32 %sp_sz, [split_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = split_size * inner_size
    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = split_offset * inner_size
    mul.lo.u32 %base_off, %sp_off, %inner_sz;

    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
    add.u32 %src_idx, %src_idx, %base_off;
    add.u32 %src_idx, %src_idx, %within;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
5773
/// Scatters one contiguous input part into its slot of an output tensor
/// concatenated along a single axis; one thread per input element.
/// Mirror image of `STRIDED_SPLIT_PTX`:
///   dst = outer * total_along_axis * inner_size
///         + cat_offset * inner_size + within
/// so `out[dst] = in[t]`. `n` is the number of input elements.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_cat_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 cat_offset,
    .param .u32 part_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %cat_off, [cat_offset];
    ld.param.u32 %part_sz, [part_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = part_size * inner_size
    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = cat_offset * inner_size
    mul.lo.u32 %base_off, %cat_off, %inner_sz;

    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
    add.u32 %dst_idx, %dst_idx, %base_off;
    add.u32 %dst_idx, %dst_idx, %within;

    // Load from in[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[dst_idx]
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
5854
/// Elementwise f32 division `out[i] = a[i] / b[i]` (round-to-nearest-even
/// via `div.rn.f32`). One thread per element, no grid-stride loop: the bounds
/// check only exits excess threads, so the launch is expected to supply at
/// least `n` threads.
#[cfg(feature = "cuda")]
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    div.rn.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
5902
/// Elementwise natural exponential `out[i] = exp(a[i])`, approximated as
/// `ex2.approx(x * log2(e))` (0f3FB8AA3B = log2(e)); one thread per element.
#[cfg(feature = "cuda")]
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
    // log2(e) = 1.4426950408889634
    mul.f32 %va, %va, 0f3FB8AA3B;
    ex2.approx.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
5949
/// Elementwise natural logarithm `out[i] = ln(a[i])`, approximated as
/// `lg2.approx(x) * ln(2)` (0f3F317218 = ln 2); one thread per element.
/// Non-positive inputs yield -inf/NaN per `lg2.approx` semantics.
#[cfg(feature = "cuda")]
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
    // 1/log2(e) = ln(2) = 0.6931471805599453
    lg2.approx.f32 %vr, %va;
    mul.f32 %vr, %vr, 0f3F317218;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
5996
/// Elementwise square root `out[i] = sqrt(a[i])` using the IEEE
/// round-to-nearest form (`sqrt.rn.f32`, not the approx variant);
/// one thread per element.
#[cfg(feature = "cuda")]
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sqrt_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    sqrt.rn.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
6040
/// Elementwise power with a scalar exponent: `out[i] = a[i] ^ exponent`,
/// computed via the identity x^e = 2^(e * log2(x)) with approx instructions.
/// Because it goes through `lg2.approx`, negative bases produce NaN and the
/// x = 0 cases follow log2/ex2 special-value rules — callers relying on
/// exact `powf` edge cases should verify.
#[cfg(feature = "cuda")]
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %exp, %lg;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %exp, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // x^e = 2^(e * log2(x))
    lg2.approx.f32 %lg, %va;
    mul.f32 %lg, %lg, %exp;
    ex2.approx.f32 %vr, %lg;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
6090
/// Elementwise absolute value `out[i] = |a[i]|` (`abs.f32`);
/// one thread per element.
#[cfg(feature = "cuda")]
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    abs.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
6134
/// Elementwise logistic sigmoid `out[i] = 1 / (1 + exp(-a[i]))`;
/// exp is approximated with `ex2.approx` and the log2(e) identity
/// (0f3FB8AA3B = log2(e)); the final division is exact (`div.rn.f32`).
/// One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f32 %neg, %va;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %vr, %one, %denom;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
6185
/// Elementwise hyperbolic tangent computed via the sigmoid identity
/// `tanh(x) = 2 * sigmoid(2x) - 1`, with exp approximated by `ex2.approx`
/// and the log2(e) constant (0f3FB8AA3B). One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // tanh(x) = 2*sigmoid(2x) - 1
    mov.f32 %two, 0f40000000;
    mul.f32 %neg2x, %va, %two;
    neg.f32 %neg2x, %neg2x;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg2x, %neg2x, %lg2e;
    ex2.approx.f32 %e, %neg2x;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %sig, %one, %denom;
    mul.f32 %vr, %two, %sig;
    sub.f32 %vr, %vr, %one;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
6241
/// Fused single-tensor Adam step; one thread per parameter element.
///
/// Updates `param`, `exp_avg` (m) and `exp_avg_sq` (v) in place:
///   g' = g + weight_decay * p       (applied only when weight_decay > 0;
///                                    L2-style decay folded into the gradient,
///                                    i.e. classic Adam, not AdamW)
///   m  = beta1 * m + (1 - beta1) * g'
///   v  = beta2 * v + (1 - beta2) * g'^2
///   p -= lr * (m / bc1) / (sqrt(v / bc2) + eps)
///
/// `bc1`/`bc2` are host-precomputed bias-correction divisors — presumably
/// 1 - beta^t; confirm against the optimizer's step counter on the Rust side.
/// The gradient buffer is read but never written back.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_adam_kernel(
    .param .u64 param_ptr,
    .param .u64 grad_ptr,
    .param .u64 exp_avg_ptr,
    .param .u64 exp_avg_sq_ptr,
    .param .f32 beta1,
    .param .f32 beta2,
    .param .f32 lr,
    .param .f32 eps,
    .param .f32 bc1,
    .param .f32 bc2,
    .param .f32 weight_decay,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %p, %g, %m, %v, %off;
    .reg .f32 %vp, %vg, %vm, %vv;
    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
    .reg .f32 %one;
    .reg .pred %p_bound, %p_wd;

    ld.param.u64 %p, [param_ptr];
    ld.param.u64 %g, [grad_ptr];
    ld.param.u64 %m, [exp_avg_ptr];
    ld.param.u64 %v, [exp_avg_sq_ptr];
    ld.param.f32 %b1, [beta1];
    ld.param.f32 %b2, [beta2];
    ld.param.f32 %f_lr, [lr];
    ld.param.f32 %f_eps, [eps];
    ld.param.f32 %f_bc1, [bc1];
    ld.param.f32 %f_bc2, [bc2];
    ld.param.f32 %f_wd, [weight_decay];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_bound, %r_tid, %n_reg;
    @%p_bound bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %p, %p, %off;
    add.u64 %g, %g, %off;
    add.u64 %m, %m, %off;
    add.u64 %v, %v, %off;

    ld.global.f32 %vp, [%p];
    ld.global.f32 %vg, [%g];
    ld.global.f32 %vm, [%m];
    ld.global.f32 %vv, [%v];

    // L2 weight decay: g = g + wd * p
    mov.f32 %one, 0f00000000;
    setp.gt.f32 %p_wd, %f_wd, %one;
    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;

    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
    mov.f32 %one, 0f3F800000;
    sub.f32 %t1, %one, %b1;
    mul.f32 %vm, %vm, %b1;
    fma.rn.f32 %vm, %t1, %vg, %vm;

    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    sub.f32 %t2, %one, %b2;
    mul.f32 %vv, %vv, %b2;
    mul.f32 %t1, %vg, %vg;
    fma.rn.f32 %vv, %t2, %t1, %vv;

    // m_hat = exp_avg / bc1
    div.rn.f32 %m_hat, %vm, %f_bc1;

    // v_hat = exp_avg_sq / bc2
    div.rn.f32 %v_hat, %vv, %f_bc2;

    // denom = sqrt(v_hat) + eps
    sqrt.rn.f32 %denom, %v_hat;
    add.f32 %denom, %denom, %f_eps;

    // param = param - lr * m_hat / denom
    div.rn.f32 %update, %m_hat, %denom;
    mul.f32 %update, %update, %f_lr;
    sub.f32 %vp, %vp, %update;

    st.global.f32 [%p], %vp;
    st.global.f32 [%m], %vm;
    st.global.f32 [%v], %vv;

DONE:
    ret;
}
";
6353
/// PTX source for a fused GRU forward cell.
///
/// One thread handles one (batch, hidden) element via a grid-stride loop
/// (`idx += blockDim * gridDim` until `idx >= total`). For each element it:
///   1. reads the three gate pre-activations from the input-side and
///      hidden-side GEMM results (both laid out [B, 3*H]);
///   2. adds the two bias vectors (each of length 3*H, indexed by idx % hsz);
///   3. computes r = sigmoid(ir + hr + b1r + b2r),
///              z = sigmoid(ii + hi + b1i + b2i),
///              n = tanh(in + b1n + r * (hn + b2n)),
///      with sigmoid built from ex2.approx (exp via log2(e) scaling) and
///      tanh via the identity 2*sigmoid(2x) - 1;
///   4. writes hy = n + z * (hx - n);
///   5. stores [r, z, n, hx, hn + b2n] into a [B, 5*H] workspace —
///      presumably consumed by the backward pass; confirm at the call site.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_gru_forward_kernel(
    .param .u64 input_gates_ptr,
    .param .u64 hidden_gates_ptr,
    .param .u64 bias_ih_ptr,
    .param .u64 bias_hh_ptr,
    .param .u64 hx_ptr,
    .param .u64 hy_ptr,
    .param .u64 workspace_ptr,
    .param .u32 hsz,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
    .reg .u64 %off64, %tmp64;
    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
    .reg .pred %p;

    ld.param.u64 %ig, [input_gates_ptr];
    ld.param.u64 %hg, [hidden_gates_ptr];
    ld.param.u64 %b1, [bias_ih_ptr];
    ld.param.u64 %b2, [bias_hh_ptr];
    ld.param.u64 %hx, [hx_ptr];
    ld.param.u64 %hy, [hy_ptr];
    ld.param.u64 %ws, [workspace_ptr];
    ld.param.u32 %hsz_reg, [hsz];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %one, 0f3F800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
    div.u32 %batch_idx, %idx, %hsz_reg;
    rem.u32 %hmod, %idx, %hsz_reg;
    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset3, %offset3, 3;
    add.u32 %offset3, %offset3, %hmod;

    // Load input gate components: ir, ii, in
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ig, %off64;
    ld.global.f32 %ir, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %ii, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %in, [%tmp64];

    // Load hidden gate components: hr, hi, hn
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hg, %off64;
    ld.global.f32 %hr, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hi, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hn, [%tmp64];

    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b1, %off64;
    ld.global.f32 %b1r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1n, [%tmp64];

    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b2, %off64;
    ld.global.f32 %b2r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2n, [%tmp64];

    // Load hx[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hx, %off64;
    ld.global.f32 %hx_val, [%tmp64];

    // r = sigmoid(ir + hr + b1r + b2r)
    add.f32 %rg, %ir, %hr;
    add.f32 %rg, %rg, %b1r;
    add.f32 %rg, %rg, %b2r;
    neg.f32 %tmp, %rg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %rg, %one, %denom;

    // z = sigmoid(ii + hi + b1i + b2i)
    add.f32 %zg, %ii, %hi;
    add.f32 %zg, %zg, %b1i;
    add.f32 %zg, %zg, %b2i;
    neg.f32 %tmp, %zg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %zg, %one, %denom;

    // n = tanh(in + b1n + r*(hn + b2n))
    add.f32 %tmp, %hn, %b2n;
    fma.rn.f32 %ng, %rg, %tmp, %in;
    add.f32 %ng, %ng, %b1n;
    // tanh via 2*sigmoid(2x)-1
    mul.f32 %tmp, %ng, 0f40000000;
    neg.f32 %tmp, %tmp;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %ng, %one, %denom;
    mul.f32 %ng, %ng, 0f40000000;
    sub.f32 %ng, %ng, %one;

    // hy = n + z * (hx - n)
    sub.f32 %tmp, %hx_val, %ng;
    fma.rn.f32 %hy_val, %zg, %tmp, %ng;

    // Store hy[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hy, %off64;
    st.global.f32 [%tmp64], %hy_val;

    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset5, %offset5, 5;
    add.u32 %offset5, %offset5, %hmod;

    cvt.u64.u32 %off64, %offset5;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ws, %off64;
    st.global.f32 [%tmp64], %rg;
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %zg;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %ng;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %hx_val;
    add.u64 %tmp64, %tmp64, %off64;
    add.f32 %tmp, %hn, %b2n;
    st.global.f32 [%tmp64], %tmp;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
6546
#[cfg(feature = "cuda")]
/// Builds a 1-D launch configuration covering `n` elements with 256-thread
/// blocks (at least one block, so zero-length launches stay valid).
///
/// # Errors
/// Returns `GpuError::ShapeMismatch` when `n` exceeds `u32::MAX`, since the
/// kernels receive the element count as a 32-bit parameter.
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    const BLOCK: u32 = 256;
    // Ceil-divide in u64. The previous `saturating_add(BLOCK - 1)` clamped at
    // u32::MAX for n close to u32::MAX, under-counting the grid and leaving
    // tail elements unprocessed by kernels without a grid-stride loop.
    let grid = ((n as u64 + (BLOCK as u64 - 1)) / BLOCK as u64) as u32;
    Ok(LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
6577
#[cfg(feature = "cuda")]
/// Verifies that both buffers live on `device` and have equal lengths,
/// returning the first violation found.
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let ordinal = device.ordinal();
    // Check device placement for each operand in order (a first, then b).
    for buf in [a, b] {
        if buf.device_ordinal() != ordinal {
            return Err(GpuError::DeviceMismatch {
                expected: buf.device_ordinal(),
                got: ordinal,
            });
        }
    }
    if a.len() != b.len() {
        return Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        });
    }
    Ok(())
}
6605
#[cfg(feature = "cuda")]
/// Verifies that `a` lives on `device`.
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let ordinal = a.device_ordinal();
    if ordinal == device.ordinal() {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: ordinal,
            got: device.ordinal(),
        })
    }
}
6617
#[cfg(feature = "cuda")]
/// Attempts to run a binary elementwise kernel `out[i] = op(a[i], b[i])`.
///
/// Returns `Ok(None)` when the PTX fails to compile so the caller can fall
/// back to a CPU implementation; returns `Ok(Some(out))` on success.
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        // Compilation failure is non-fatal; the caller handles the fallback.
        return Ok(None);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: all three buffers hold `len` f32 values on this device, and the
    // kernel bounds-checks each thread against the `n` parameter.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
6671
#[cfg(feature = "cuda")]
/// Attempts to run a float4-vectorized binary elementwise kernel.
///
/// Each thread processes exactly four consecutive floats, so the length must
/// be a multiple of 4. Returns `Ok(None)` when that precondition fails or
/// when the PTX fails to compile, letting the caller fall back to the scalar
/// kernel or the CPU path.
fn try_launch_binary_vec4(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // Without this guard a non-multiple-of-4 length would silently leave the
    // last n % 4 output elements at their zero-initialized value.
    if n % 4 != 0 {
        return Ok(None);
    }
    let n4 = (n / 4) as u32;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;
    // One thread per group of 4 floats.
    let cfg = launch_cfg(n4 as usize)?;

    // SAFETY: a, b and out each hold n = 4 * n4 f32 values on this device;
    // the kernel bounds-checks against the n4 parameter.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n4)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
6716
#[cfg(feature = "cuda")]
/// Attempts to run a unary elementwise kernel `out[i] = op(a[i])`.
///
/// Returns `Ok(None)` when the PTX fails to compile so the caller can fall
/// back to a CPU implementation; returns `Ok(Some(out))` on success.
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        // Compilation failure is non-fatal; the caller handles the fallback.
        return Ok(None);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: both buffers hold `len` f32 values on this device, and the
    // kernel bounds-checks each thread against the `n` parameter.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
6760
#[cfg(feature = "cuda")]
/// Launches a binary elementwise kernel writing into an existing buffer.
///
/// Returns `Ok(false)` when the PTX fails to compile (caller should fall
/// back) and `Ok(true)` after a successful launch.
///
/// # Errors
/// Returns `GpuError::LengthMismatch` when `out.len() != a.len()` — without
/// this check a short `out` buffer would be written past its end by the
/// kernel, since the kernel only bounds-checks against `a.len()`.
fn try_launch_binary_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if out.len() != n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: out.len(),
        });
    }
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: a, b and out all hold n f32 values on this device (checked
    // above for out); the kernel bounds-checks against the n parameter.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
6807
#[cfg(feature = "cuda")]
/// Launches a unary elementwise kernel writing into an existing buffer.
///
/// Returns `Ok(false)` when the PTX fails to compile (caller should fall
/// back) and `Ok(true)` after a successful launch.
///
/// # Errors
/// Returns `GpuError::LengthMismatch` when `out.len() != a.len()` — without
/// this check a short `out` buffer would be written past its end by the
/// kernel, since the kernel only bounds-checks against `a.len()`.
fn try_launch_unary_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if out.len() != n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: out.len(),
        });
    }
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: a and out both hold n f32 values on this device (checked above
    // for out); the kernel bounds-checks against the n parameter.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
6848
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Attempts to run a broadcasting binary kernel over an `out_shape`-shaped
/// output, indexing each input through precomputed per-dimension strides
/// (stride 0 marks a broadcast dimension).
///
/// Returns `Ok(None)` when the PTX fails to compile so the caller can fall
/// back to the CPU path.
fn try_launch_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let rank = out_shape.len();
    let Ok(kernel) = crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) else {
        return Ok(None);
    };

    // Upload the stride/shape metadata so the kernel can decompose flat
    // output indices into per-input offsets.
    let a_strides_dev = cpu_to_gpu(a_strides, device)?;
    let b_strides_dev = cpu_to_gpu(b_strides, device)?;
    let shape_dev = cpu_to_gpu(out_shape, device)?;

    let mut out = alloc_zeros_f32(out_numel, device)?;
    let cfg = launch_cfg(out_numel)?;
    let numel_u32 = out_numel as u32;
    let rank_u32 = rank as u32;

    // SAFETY: every buffer lives on this device; `out` holds `out_numel`
    // elements and the kernel bounds-checks against the `n` parameter.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(a_strides_dev.inner())
            .arg(b_strides_dev.inner())
            .arg(shape_dev.inner())
            .arg(&numel_u32)
            .arg(&rank_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
6912
6913#[cfg(feature = "cuda")]
/// Computes row-major element strides that map coordinates in `out_shape`
/// back into `in_shape`, NumPy-broadcast style.
///
/// Dimensions that `in_shape` lacks (it is right-aligned against
/// `out_shape`) and size-1 input dimensions get stride 0, so every output
/// coordinate along them reads the same input element.
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    let ndim = out_shape.len();
    let in_ndim = in_shape.len();
    let mut result = vec![0u32; ndim];

    // Walk dimensions from innermost to outermost, tracking the running
    // row-major stride of the input tensor.
    let mut running: u32 = 1;
    for d in (0..ndim).rev() {
        // Output dims with no corresponding input dim keep stride 0.
        if d + in_ndim < ndim {
            continue;
        }
        let src = d + in_ndim - ndim;
        // Size-1 input dims broadcast: stride 0 (result is pre-zeroed).
        if in_shape[src] != 1 {
            result[d] = running;
        }
        running *= in_shape[src] as u32;
    }

    result
}
6946
#[cfg(feature = "cuda")]
/// CPU fallback for binary elementwise ops: copies both operands to the
/// host, applies `op` pairwise, and uploads the result.
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut combined = Vec::with_capacity(lhs.len());
    for (x, y) in lhs.iter().zip(rhs.iter()) {
        combined.push(op(*x, *y));
    }
    cpu_to_gpu(&combined, device)
}
6969
#[cfg(feature = "cuda")]
/// CPU fallback for unary elementwise ops: copies the operand to the host,
/// applies `op`, and uploads the result.
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mapped: Vec<f32> = host.into_iter().map(op).collect();
    cpu_to_gpu(&mapped, device)
}
6981
#[cfg(feature = "cuda")]
/// Elementwise addition: `out[i] = a[i] + b[i]`.
///
/// Tries the float4-vectorized kernel for larger 4-aligned lengths, then the
/// scalar kernel, then a host-side fallback.
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    let len = a.len();
    let vec4_eligible = len >= 16 && len % 4 == 0;
    if vec4_eligible {
        if let Some(result) =
            try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?
        {
            return Ok(result);
        }
    }

    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
7021
#[cfg(feature = "cuda")]
/// Elementwise subtraction: `out[i] = a[i] - b[i]`.
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(result) => Ok(result),
        // Kernel compilation unavailable — compute on the host instead.
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
7047
#[cfg(feature = "cuda")]
/// Elementwise multiplication: `out[i] = a[i] * b[i]`.
///
/// Tries the float4-vectorized kernel for larger 4-aligned lengths, then the
/// scalar kernel, then a host-side fallback.
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    let len = a.len();
    let vec4_eligible = len >= 16 && len % 4 == 0;
    if vec4_eligible {
        if let Some(result) =
            try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?
        {
            return Ok(result);
        }
    }

    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
7082
#[cfg(feature = "cuda")]
/// Broadcasting addition over tensors of shapes `a_shape`/`b_shape` into
/// `out_shape`, with a host-side fallback when the kernel is unavailable.
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    match try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape_u32,
        numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )? {
        Some(result) => Ok(result),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
        }
    }
}
7126
#[cfg(feature = "cuda")]
/// Broadcasting subtraction over tensors of shapes `a_shape`/`b_shape` into
/// `out_shape`, with a host-side fallback when the kernel is unavailable.
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    match try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape_u32,
        numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )? {
        Some(result) => Ok(result),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
        }
    }
}
7158
#[cfg(feature = "cuda")]
/// Broadcasting multiplication over tensors of shapes `a_shape`/`b_shape`
/// into `out_shape`, with a host-side fallback when the kernel is
/// unavailable.
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    match try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape_u32,
        numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )? {
        Some(result) => Ok(result),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
        }
    }
}
7190
#[cfg(feature = "cuda")]
/// Broadcasting division over tensors of shapes `a_shape`/`b_shape` into
/// `out_shape`, with a host-side fallback when the kernel is unavailable.
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    match try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape_u32,
        numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )? {
        Some(result) => Ok(result),
        None => {
            cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
        }
    }
}
7222
#[cfg(feature = "cuda")]
/// Host-side fallback for broadcasting binary ops: downloads both operands,
/// applies `op` with NumPy-style broadcast indexing, and uploads the result.
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);
    let numel: usize = out_shape.iter().product();

    let mut out = Vec::with_capacity(numel);
    for flat in 0..numel {
        // Decompose the flat output index into per-dimension coordinates
        // (last dimension fastest) and accumulate strided input offsets;
        // broadcast dimensions contribute 0 via their zero stride.
        let mut rest = flat;
        let mut lhs_idx = 0usize;
        let mut rhs_idx = 0usize;
        for d in (0..out_shape.len()).rev() {
            let coord = rest % out_shape[d];
            rest /= out_shape[d];
            lhs_idx += coord * lhs_strides[d] as usize;
            rhs_idx += coord * rhs_strides[d] as usize;
        }
        out.push(op(lhs[lhs_idx], rhs[rhs_idx]));
    }
    cpu_to_gpu(&out, device)
}
7257
#[cfg(feature = "cuda")]
/// Elementwise negation: `out[i] = -a[i]`.
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(result) => Ok(result),
        // Kernel compilation unavailable — compute on the host instead.
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
7282
#[cfg(feature = "cuda")]
/// Elementwise ReLU: `out[i] = max(a[i], 0)`.
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(result) => Ok(result),
        // Kernel compilation unavailable — compute on the host instead.
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
7303
#[cfg(feature = "cuda")]
/// ReLU backward: passes `grad[i]` through where the forward `input[i]` was
/// positive, zero elsewhere.
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(result) =
        try_launch_binary(grad, input, device, RELU_BACKWARD_PTX, "relu_backward_kernel")?
    {
        return Ok(result);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let masked: Vec<f32> = g_host
        .iter()
        .zip(x_host.iter())
        .map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
        .collect();
    cpu_to_gpu(&masked, device)
}
7333
#[cfg(feature = "cuda")]
/// GELU backward using the sigmoid approximation `gelu(x) ≈ x * sigmoid(1.702 x)`:
/// `dgelu/dx = sigmoid(kx) + k x sigmoid(kx) (1 - sigmoid(kx))` with k = 1.702.
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(result) =
        try_launch_binary(grad, input, device, GELU_BACKWARD_PTX, "gelu_backward_kernel")?
    {
        return Ok(result);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let dx: Vec<f32> = g_host
        .iter()
        .zip(x_host.iter())
        .map(|(&g, &x)| {
            let k: f32 = 1.702;
            let sig = 1.0 / (1.0 + (-k * x).exp());
            g * (sig + k * x * sig * (1.0 - sig))
        })
        .collect();
    cpu_to_gpu(&dx, device)
}
7368
#[cfg(feature = "cuda")]
/// GELU backward using the exact erf formulation:
/// `d/dx [x * Phi(x)] = Phi(x) + x * phi(x)` where Phi/phi are the standard
/// normal CDF/PDF.
///
/// The CPU fallback approximates erf with Abramowitz & Stegun eq. 7.1.26.
pub fn gpu_gelu_backward_erf(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_PTX,
        "gelu_backward_erf_kernel",
    )? {
        return Ok(out);
    }

    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let z = x * inv_sqrt_2;
            let az = z.abs();
            // A&S 7.1.26: t = 1/(1 + p|z|), p = 0.3275911.
            let t = 1.0 / (1.0 + 0.3275911 * az);
            // BUGFIX: the innermost coefficient is a5 = 1.0614054; the old
            // code mistakenly reused p = 0.3275911 here, skewing erf (and
            // hence the gradient) for moderate |x|.
            let poly = t
                * (0.2548296
                    + t * (-0.2844967 + t * (1.4214137 + t * (-1.4531520 + t * 1.0614054))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
            let cdf = 0.5 * (1.0 + erf_val);
            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
            g * (cdf + x * pdf)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
7412
#[cfg(feature = "cuda")]
/// Gathers `input[indices[i]]` for every index, producing a buffer of
/// `indices.len()` elements. Indices are carried as f32 and truncated to
/// usize on the CPU path.
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let count = indices.len();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: gather on the CPU and upload.
        let data = gpu_to_cpu(input, device)?;
        let idx = gpu_to_cpu(indices, device)?;
        let gathered: Vec<f32> = idx.iter().map(|&i| data[i as usize]).collect();
        return cpu_to_gpu(&gathered, device);
    };

    let mut out = alloc_zeros_f32(count, device)?;
    let cfg = launch_cfg(count)?;
    let count_u32 = count as u32;

    // SAFETY: all buffers live on this device; `out` holds `count` elements
    // and the kernel bounds-checks against the `n` parameter.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&count_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7470
#[cfg(feature = "cuda")]
/// Scatter-add backward of a 1-D gather: accumulates `grad_output[i]` into
/// position `indices[i]` of a zero-initialized buffer of `input_len`
/// elements.
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(grad_output, device)?;

    let count = grad_output.len();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: accumulate on the CPU and upload.
        let grads = gpu_to_cpu(grad_output, device)?;
        let idx = gpu_to_cpu(indices, device)?;
        let mut acc = vec![0.0f32; input_len];
        for (g, &target) in grads.iter().zip(idx.iter()) {
            acc[target as usize] += *g;
        }
        return cpu_to_gpu(&acc, device);
    };

    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(count)?;
    let count_u32 = count as u32;

    // SAFETY: all buffers live on this device; the kernel bounds-checks each
    // thread against the `n` parameter.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&count_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7532
#[cfg(feature = "cuda")]
/// Writes `value` wherever `mask[i] >= 0.5`, otherwise copies `input[i]`.
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(input, mask, device)?;

    let len = input.len();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback.
        let data = gpu_to_cpu(input, device)?;
        let mask_host = gpu_to_cpu(mask, device)?;
        let filled: Vec<f32> = data
            .iter()
            .zip(mask_host.iter())
            .map(|(&x, &m)| if m >= 0.5 { value } else { x })
            .collect();
        return cpu_to_gpu(&filled, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: all buffers hold `len` elements on this device; the kernel
    // bounds-checks against the `n` parameter.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7593
#[cfg(feature = "cuda")]
/// Zeroes `grad[i]` wherever `mask[i] >= 0.5` (backward of masked_fill).
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;

    match try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")? {
        Some(result) => Ok(result),
        None => {
            // Host fallback.
            let g_host = gpu_to_cpu(grad, device)?;
            let m_host = gpu_to_cpu(mask, device)?;
            let zeroed: Vec<f32> = g_host
                .iter()
                .zip(m_host.iter())
                .map(|(&g, &m)| if m >= 0.5 { 0.0 } else { g })
                .collect();
            cpu_to_gpu(&zeroed, device)
        }
    }
}
7625
#[cfg(feature = "cuda")]
/// Sigmoid backward from the saved forward output `y`:
/// `dL/dx = dL/dy * y * (1 - y)`.
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    let launched = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let y_host = gpu_to_cpu(output, device)?;
    let dx: Vec<f32> = g_host
        .iter()
        .zip(y_host.iter())
        .map(|(&g, &y)| g * y * (1.0 - y))
        .collect();
    cpu_to_gpu(&dx, device)
}
7661
#[cfg(feature = "cuda")]
/// Tanh backward from the saved forward output `y`:
/// `dL/dx = dL/dy * (1 - y^2)`.
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    let launched = try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let y_host = gpu_to_cpu(output, device)?;
    let dx: Vec<f32> = g_host
        .iter()
        .zip(y_host.iter())
        .map(|(&g, &y)| g * (1.0 - y * y))
        .collect();
    cpu_to_gpu(&dx, device)
}
7697
#[cfg(feature = "cuda")]
/// Softmax backward over rows of `cols` elements, using the saved forward
/// output `y`: `dx = y * (dy - sum(dy * y))` per row. One block per row on
/// the GPU path (256 threads, 1 KiB shared memory for the row reduction).
///
/// # Errors
/// Returns `GpuError::ShapeMismatch` when `cols == 0` — previously this
/// panicked with a division by zero computing the row count.
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    if cols == 0 {
        return Err(GpuError::ShapeMismatch {
            op: "softmax_backward",
            expected: vec![1],
            got: vec![cols],
        });
    }

    let total = grad.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: per-row dot product, then the Jacobian-vector
            // product y * (g - dot).
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256-float shared buffer for the reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: all buffers hold `total` elements on this device; the kernel
    // is bounded by the rows/cols parameters.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7776
#[cfg(feature = "cuda")]
/// Row-wise numerically stable log-softmax over a row-major `[rows, cols]`
/// buffer: `out[c] = x[c] - (max + ln Σ exp(x[k] - max))` per row.
///
/// `rows` is derived as `input.len() / cols`; `cols` must be non-zero and
/// divide the buffer length evenly.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors. If the
/// PTX module cannot be compiled, a CPU fallback computes the same result.
pub fn gpu_log_softmax(
    input: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = input.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_PTX,
        "log_softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: subtract the row max before exponentiating so
            // the sum cannot overflow (standard log-sum-exp trick).
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f32;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads; 1 KiB dynamic shared memory is
    // presumably the kernel's reduction scratch — TODO confirm vs the PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7856
#[cfg(feature = "cuda")]
/// Backward pass of row-wise log-softmax: per row,
/// `dx[c] = dy[c] - exp(y[c]) * Σ_k dy[k]`, where `y` is the log-softmax
/// output and `dy` the incoming gradient (see the CPU fallback).
///
/// `rows` is derived as `grad.len() / cols`; `cols` must be non-zero and
/// divide the buffer length evenly.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to an equivalent CPU implementation if PTX compilation fails.
pub fn gpu_log_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_PTX,
        "log_softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: sum the row's gradient once, then apply
            // dy - exp(y) * sum_grad elementwise.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f32;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads; 1 KiB dynamic shared memory is
    // presumably the kernel's reduction scratch — TODO confirm vs the PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7932
#[cfg(feature = "cuda")]
/// Sums all elements of `a` into a single-element GPU buffer.
///
/// Strategy: a first kernel pass produces up to 1024 per-block partial
/// sums; if only one block ran, that buffer already holds the answer. If
/// there are at most 256 partials, they are finished on the CPU (one small
/// transfer); otherwise the function recurses on the partials buffer.
///
/// An empty input yields a buffer containing `0.0`.
///
/// # Errors
/// Propagates transfer, allocation, and launch errors; falls back to a
/// full host-side sum if PTX compilation fails.
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if n == 0 {
        // Empty reduction: return a zero scalar rather than erroring.
        return cpu_to_gpu(&[0.0f32], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy everything down and sum on the host.
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    const BLOCK: u32 = 256;
    // Ceil-divide; saturating_add guards the (theoretical) u32 overflow.
    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
    // Cap the grid at 1024 blocks; the kernel presumably strides over the
    // remaining elements — TODO confirm against REDUCE_SUM_PTX.
    let num_blocks = num_blocks.min(1024);

    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
    let n_u32 = n as u32;

    // shared_mem_bytes is 0, so any shared memory the kernel uses must be
    // statically declared in the PTX — not visible from here.
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0, };

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    if num_blocks <= 1 {
        // Single block: the partials buffer is already the scalar result.
        return Ok(partials);
    }

    if num_blocks <= 256 {
        // Few partials left: cheaper to finish on the CPU than relaunch.
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // Many partials: recurse (terminates because num_blocks shrinks).
    gpu_reduce_sum(&partials, device)
}
8015
#[cfg(not(feature = "cuda"))]
/// Stub compiled when the crate is built without the `cuda` feature;
/// always fails with [`GpuError::NoCudaFeature`].
pub fn gpu_reduce_sum(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8024
#[cfg(feature = "cuda")]
/// Sums a tensor viewed as `[outer, axis_size, inner]` over its middle
/// axis, producing an `outer * inner` element buffer.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to a host-side implementation if PTX compilation fails.
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let out_len = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: accumulate along the middle axis in ascending
            // order (same summation order as the kernel path expects).
            let host = gpu_to_cpu(a, device)?;
            let sums: Vec<f32> = (0..out_len)
                .map(|i| {
                    let o = i / inner;
                    let j = i % inner;
                    (0..axis_size)
                        .map(|k| host[o * axis_size * inner + k * inner + j])
                        .sum::<f32>()
                })
                .collect();
            return cpu_to_gpu(&sums, device);
        }
    };

    let mut out = alloc_zeros_f32(out_len, device)?;
    let cfg = launch_cfg(out_len)?;
    let (outer_u32, axis_u32) = (outer as u32, axis_size as u32);
    let (inner_u32, total_u32) = (inner as u32, out_len as u32);

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8089
#[cfg(feature = "cuda")]
/// Inclusive cumulative sum along the middle axis of a tensor viewed as
/// `[outer, dim_size, inner]` (row-major: consecutive scan elements are
/// `inner` apart in memory).
///
/// The launch sizes one thread per `(outer, inner)` lane (see
/// `launch_cfg(num_threads)`); each lane walks its `dim_size` elements
/// sequentially, as the CPU fallback below mirrors.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_cumsum(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMSUM_PTX,
        "cumsum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: sequential running sum per (outer, inner) lane.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = 0.0f32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc += host[idx];
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8170
#[cfg(feature = "cuda")]
/// Inclusive cumulative product along the middle axis of a tensor viewed
/// as `[outer, dim_size, inner]`; layout and threading are identical to
/// `gpu_cumsum`, with the accumulator starting at `1.0` and multiplying.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_cumprod(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMPROD_PTX,
        "cumprod_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: sequential running product per lane.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = 1.0f32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    acc *= host[idx];
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8244
#[cfg(feature = "cuda")]
/// Inclusive cumulative maximum along the middle axis of a tensor viewed
/// as `[outer, dim_size, inner]`.
///
/// Returns `(values, indices)`; indices are stored as `f32` (position of
/// the running maximum along the scan axis). In the CPU fallback, the
/// strict `>` comparison means ties keep the earliest index and NaNs are
/// never selected once a non-NaN has been seen — kernel parity assumed,
/// TODO confirm against CUMMAX_PTX.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_cummax(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMMAX_PTX,
        "cummax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: track running max and its axis position.
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::NEG_INFINITY;
                let mut best = 0u32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    if host[idx] > acc {
                        acc = host[idx];
                        best = k as u32;
                    }
                    vals[idx] = acc;
                    idxs[idx] = best as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_idx))
}
8325
#[cfg(feature = "cuda")]
/// Inclusive cumulative minimum along the middle axis of a tensor viewed
/// as `[outer, dim_size, inner]`; mirror of `gpu_cummax` with a strict `<`
/// comparison and a `+inf` initial accumulator.
///
/// Returns `(values, indices)`; indices are stored as `f32`. Ties keep the
/// earliest index in the CPU fallback — kernel parity assumed, TODO
/// confirm against CUMMIN_PTX.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_cummin(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        CUMMIN_PTX,
        "cummin_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: track running min and its axis position.
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::INFINITY;
                let mut best = 0u32;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    if host[idx] < acc {
                        acc = host[idx];
                        best = k as u32;
                    }
                    vals[idx] = acc;
                    idxs[idx] = best as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_idx))
}
8406
#[cfg(feature = "cuda")]
/// Inclusive cumulative log-sum-exp along the middle axis of a tensor
/// viewed as `[outer, dim_size, inner]`:
/// `out[k] = ln(Σ_{j<=k} exp(x[j]))` per lane.
///
/// The CPU fallback folds with the numerically stable pairwise update
/// `acc = m + ln(exp(acc - m) + exp(x - m))`, `m = max(acc, x)`, starting
/// from `-inf` (so the first element yields `x[0]` exactly).
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_logcumsumexp(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOGCUMSUMEXP_PTX,
        "logcumsumexp_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: stable running logaddexp per lane.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::NEG_INFINITY;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    let x = host[idx];
                    let m = acc.max(x);
                    acc = m + ((acc - m).exp() + (x - m).exp()).ln();
                    result[idx] = acc;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8482
#[cfg(feature = "cuda")]
/// Extracts one split of size `split_size` (starting at `split_offset`)
/// along an axis of length `total_along_axis`, copying `n` contiguous-per-
/// row elements into a fresh buffer. The source is viewed as
/// `[outer, total_along_axis, inner_size]` and the result as
/// `[outer, split_size, inner_size]` with `n = outer * split_size * inner_size`.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to a host-side gather if PTX compilation fails.
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather each destination element from its
            // strided source position.
            // (Removed a dead `outer` local that was previously computed
            // here and immediately discarded via `let _ = outer;`.)
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                let src_idx =
                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
                *out = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8564
#[cfg(feature = "cuda")]
/// Scatter-writes one concatenation part into `output`: `input` is viewed
/// as `[outer, part_size, inner_size]` (`n` elements total) and copied to
/// the slot starting at `cat_offset` along an output axis of length
/// `total_along_axis`. Inverse of `gpu_strided_split`.
///
/// # Errors
/// Propagates validation, transfer, and launch errors. If PTX compilation
/// fails, the fallback downloads `output`, patches it on the host, and
/// re-uploads a fresh buffer (replacing `*output`).
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: round-trip the output through host memory.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
8652
#[cfg(feature = "cuda")]
/// Multiplies every element of `a` by `scalar`, returning a new buffer.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to a host-side elementwise multiply if PTX compilation fails.
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: elementwise multiply on the host.
            let host = gpu_to_cpu(a, device)?;
            let scaled: Vec<f32> = host.iter().map(|&v| v * scalar).collect();
            return cpu_to_gpu(&scaled, device);
        }
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8706
#[cfg(feature = "cuda")]
/// Row-wise softmax over a row-major `[rows, cols]` buffer, using the
/// max-subtraction trick for numerical stability (see the CPU fallback).
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: per row, exp(x - max) then normalize by the sum.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; host.len()];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f32;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads; 1 KiB dynamic shared memory is
    // presumably the kernel's reduction scratch — TODO confirm vs the PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4, };

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8783
#[cfg(feature = "cuda")]
/// Inverted dropout: each element is zeroed when a per-element hash falls
/// below `threshold`, otherwise multiplied by `scale` (the caller passes
/// `scale = 1/(1-p)` in standard inverted dropout; not enforced here).
///
/// The CPU fallback hashes the element index with a Knuth multiplicative
/// constant (2654435761) xor'd with `seed`, then applies an xorshift
/// scramble (13/17/5). The GPU kernel is assumed to use the same hash so
/// both paths drop the same elements — TODO confirm against DROPOUT_PTX.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to the host implementation if PTX compilation fails.
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: deterministic per-index hash decides keep/drop.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f32> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
8864
#[cfg(feature = "cuda")]
/// Transposes a row-major `m x n` matrix into a row-major `n x m` result.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to a host-side transpose if PTX compilation fails.
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let elems = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: destination element (row, col) of the n x m
            // output pulls source element (col, row) of the m x n input.
            let host = gpu_to_cpu(input, device)?;
            let transposed: Vec<f32> = (0..elems)
                .map(|idx| {
                    let row = idx / m;
                    let col = idx % m;
                    host[col * n + row]
                })
                .collect();
            return cpu_to_gpu(&transposed, device);
        }
    };

    let mut out = alloc_zeros_f32(elems, device)?;
    let cfg = launch_cfg(elems)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = elems as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8924
#[cfg(feature = "cuda")]
/// Permutes a 4-D tensor `[d0, d1, d2, d3]` (row-major) to axis order
/// `(0, 2, 1, 3)`, i.e. output shape `[d0, d2, d1, d3]` — the common
/// reshuffle between `[batch, seq, heads, head_dim]` style layouts.
///
/// # Errors
/// Propagates validation, transfer, allocation, and launch errors; falls
/// back to a host-side permute if PTX compilation fails.
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy element-by-element; input index uses
            // (d1, d2) order, output index swaps to (d2, d1).
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i0 in 0..d0 {
                for i1 in 0..d1 {
                    for i2 in 0..d2 {
                        for i3 in 0..d3 {
                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                            out[out_idx] = host[in_idx];
                        }
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8997
#[cfg(feature = "cuda")]
/// Matrix multiply `C = A(m x k) * B(k x n)` via a simple one-thread-per-
/// output-element kernel, intended for small sizes where cuBLAS launch
/// overhead dominates.
///
/// NOTE(review): unlike the sibling wrappers, this does not call
/// `validate_unary`/`validate_binary` on its inputs — confirm whether the
/// callers guarantee device/layout validity.
///
/// # Errors
/// Propagates allocation and launch errors; falls back to the cuBLAS-backed
/// `crate::blas::gpu_matmul_f32` if PTX compilation fails.
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Fallback: delegate to the general BLAS path.
            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
        }
    };

    let mut c = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
9056
#[cfg(feature = "cuda")]
/// Batched matrix multiply for small problem sizes: a single batch routes
/// to the dedicated small-matmul kernel, anything larger to the
/// cuBLAS-backed batched path.
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        1 => gpu_small_matmul(a, b, m, k, n, device),
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
9083
#[cfg(feature = "cuda")]
/// Looks up a single embedding row: the first element of `idx` (an `f32`
/// index) selects a row of `weight`, which is viewed as a row-major
/// `[vocab, d]` table; the `d`-element row is returned.
///
/// NOTE(review): the CPU fallback reads only `idx_host[0]` and will panic
/// on an empty `idx` buffer or an out-of-range row; the `as usize` cast
/// saturates negative/NaN values to 0. The GPU kernel is assumed to match
/// this single-row contract — TODO confirm against EMBED_LOOKUP_PTX.
///
/// # Errors
/// Propagates transfer, allocation, and launch errors; falls back to a
/// host-side lookup if PTX compilation fails.
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: slice the selected row out of the weight table.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let row = idx_host[0] as usize;
            let start = row * d;
            let out = weight_host[start..start + d].to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
9137
#[cfg(feature = "cuda")]
/// Writes a `[n_batch, d]` time-step `src` into position `pos` of a
/// `[n_batch, max_len, d]` destination buffer (typical KV-cache append).
///
/// # Errors
/// Propagates transfer and launch errors. If PTX compilation fails, the
/// fallback round-trips `dst` through host memory and replaces `*dst`
/// with a freshly uploaded buffer.
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: patch the destination on the host.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    // Despite the name, this carries the total element count n_batch * d.
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
9202
#[cfg(feature = "cuda")]
/// Reads the first `len` time-steps of a `[n_batch, max_len, d]` buffer
/// into a compact `[n_batch, len, d]` result (typical KV-cache read).
///
/// # Errors
/// Propagates transfer, allocation, and launch errors; falls back to a
/// host-side copy if PTX compilation fails.
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather the leading `len` rows per batch.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f32; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // SAFETY: argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
9265
#[cfg(feature = "cuda")]
/// GELU activation, sigmoid approximation: `x * sigmoid(1.702 * x)`.
/// Tries the GPU kernel first, then computes on the host.
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(buf) => Ok(buf),
        None => cpu_fallback_unary(input, device, |x| {
            let sigmoid = 1.0 / (1.0 + (-1.702 * x).exp());
            sigmoid * x
        }),
    }
}
9282
#[cfg(feature = "cuda")]
/// GELU activation, tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
/// Tries the GPU kernel first, then computes on the host.
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
        Some(buf) => Ok(buf),
        None => cpu_fallback_unary(input, device, |x| {
            const SQRT_2_OVER_PI: f32 = 0.7978845608;
            const CUBIC_COEF: f32 = 0.044715;
            let t = (SQRT_2_OVER_PI * (x + CUBIC_COEF * x * x * x)).tanh();
            0.5 * x * (1.0 + t)
        }),
    }
}
9300
#[cfg(feature = "cuda")]
/// GELU activation, erf formulation: `x * 0.5 * (1 + erf(x / sqrt(2)))`.
/// Tries the GPU kernel first, then computes on the host using the
/// Abramowitz–Stegun 7.1.26 rational approximation of erf.
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
        Some(buf) => Ok(buf),
        None => cpu_fallback_unary(input, device, |x| {
            // erf(|z|) ~= 1 - P(t) * exp(-z^2) with t = 1 / (1 + p*|z|);
            // max absolute error of this approximation is ~1.5e-7.
            let z = x * std::f32::consts::FRAC_1_SQRT_2;
            let abs_z = z.abs();
            let t = 1.0 / (1.0 + 0.3275911 * abs_z);
            let horner = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
            let erf_mag = 1.0 - horner * (-abs_z * abs_z).exp();
            let erf = if z < 0.0 { -erf_mag } else { erf_mag };
            x * 0.5 * (1.0 + erf)
        }),
    }
}
9322
#[cfg(feature = "cuda")]
/// Backward pass of the tanh-approximated GELU:
/// with `u = sqrt(2/pi) * (x + 0.044715 x^3)` and `t = tanh(u)`,
/// `dGELU/dx = 0.5*(1+t) + 0.5*x*(1-t^2)*sqrt(2/pi)*(1 + 0.134145 x^2)`
/// (0.134145 = 3 * 0.044715, from differentiating the cubic term).
///
/// # Errors
/// Propagates validation, transfer, and launch errors; computes on the
/// host if the kernel is unavailable.
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )? {
        return Ok(out);
    }
    // CPU fallback: elementwise derivative times incoming gradient.
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let sqrt_2_over_pi: f32 = 0.7978845608;
            let c: f32 = 0.044715;
            let c3: f32 = 0.134145;
            let u = sqrt_2_over_pi * (x + c * x * x * x);
            let t = u.tanh();
            let dt = 1.0 - t * t;
            let d_inner = sqrt_2_over_pi * (1.0 + c3 * x * x);
            g * (0.5 * (1.0 + t) + 0.5 * x * dt * d_inner)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
9361
#[cfg(feature = "cuda")]
/// SiLU (swish) activation: `x * sigmoid(x)`.
///
/// GPU kernel first; CPU round-trip fallback on PTX compile failure.
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, SILU_PTX, "silu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let sigmoid = 1.0 / (1.0 + (-x).exp());
            x * sigmoid
        }),
    }
}
9378
#[cfg(feature = "cuda")]
/// Backward pass of SiLU: `grad * (s + x*s*(1 - s))` with `s = sigmoid(x)`.
pub fn gpu_silu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    match try_launch_binary(
        grad,
        input,
        device,
        SILU_BACKWARD_PTX,
        "silu_backward_kernel",
    )? {
        Some(out) => Ok(out),
        None => {
            // CPU fallback: round-trip both operands through host memory.
            let g_host = gpu_to_cpu(grad, device)?;
            let x_host = gpu_to_cpu(input, device)?;
            let grads: Vec<f32> = g_host
                .iter()
                .zip(&x_host)
                .map(|(&g, &x)| {
                    let s = 1.0 / (1.0 + (-x).exp());
                    g * (s + x * s * (1.0 - s))
                })
                .collect();
            cpu_to_gpu(&grads, device)
        }
    }
}
9412
#[cfg(feature = "cuda")]
/// ELU activation: `x` for `x > 0`, otherwise `alpha * (exp(x) - 1)`.
///
/// Launches the ELU kernel with `alpha` as a scalar argument; on PTX compile
/// failure the same formula is applied on the CPU.
pub fn gpu_elu(
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        ELU_PTX,
        "elu_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback path.
            let host = gpu_to_cpu(input, device)?;
            let mapped: Vec<f32> = host
                .iter()
                .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
                .collect();
            return cpu_to_gpu(&mapped, device);
        }
    };

    let mut result = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: argument order (in, out, n, alpha) must match the kernel's
    // parameter list; buffers are sized to `len` and cfg covers `len` threads.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(result)
}
9467
#[cfg(feature = "cuda")]
/// Backward pass of ELU: `grad` where `x > 0`, otherwise
/// `grad * alpha * exp(x)`.
pub fn gpu_elu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, input, device)?;

    let len = grad.len();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        ELU_BACKWARD_PTX,
        "elu_backward_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback path.
            let g_host = gpu_to_cpu(grad, device)?;
            let x_host = gpu_to_cpu(input, device)?;
            let grads: Vec<f32> = g_host
                .iter()
                .zip(&x_host)
                .map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
                .collect();
            return cpu_to_gpu(&grads, device);
        }
    };

    let mut result = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: argument order (grad, in, out, n, alpha) matches the kernel.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad.inner())
            .arg(input.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(result)
}
9522
#[cfg(feature = "cuda")]
/// Mish activation: `x * tanh(softplus(x))`, where softplus is linearized
/// for `x > 20` to avoid `exp` overflow.
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, MISH_PTX, "mish_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
            x * softplus.tanh()
        }),
    }
}
9539
#[cfg(feature = "cuda")]
/// Backward pass of Mish, using the same softplus linearization (`x > 20`)
/// as the forward pass.
pub fn gpu_mish_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    match try_launch_binary(
        grad,
        input,
        device,
        MISH_BACKWARD_PTX,
        "mish_backward_kernel",
    )? {
        Some(out) => Ok(out),
        None => {
            // CPU fallback: round-trip both operands through host memory.
            let g_host = gpu_to_cpu(grad, device)?;
            let x_host = gpu_to_cpu(input, device)?;
            let grads: Vec<f32> = g_host
                .iter()
                .zip(&x_host)
                .map(|(&g, &x)| {
                    let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
                    let th = softplus.tanh();
                    let sigmoid = 1.0 / (1.0 + (-x).exp());
                    g * (th + x * sigmoid * (1.0 - th * th))
                })
                .collect();
            cpu_to_gpu(&grads, device)
        }
    }
}
9576
#[cfg(feature = "cuda")]
/// Clamps every element of `input` into `[min_val, max_val]` via
/// `x.max(min_val).min(max_val)` (so `max_val` wins if `min_val > max_val`).
pub fn gpu_clamp(
    input: &CudaBuffer<f32>,
    min_val: f32,
    max_val: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        CLAMP_PTX,
        "clamp_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback with the identical max-then-min ordering.
            let host = gpu_to_cpu(input, device)?;
            let clamped: Vec<f32> = host
                .iter()
                .map(|&x| x.max(min_val).min(max_val))
                .collect();
            return cpu_to_gpu(&clamped, device);
        }
    };

    let mut result = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: argument order (in, out, n, min, max) matches the kernel.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .arg(&min_val)
            .arg(&max_val)
            .launch(cfg)?;
    }

    Ok(result)
}
9629
#[cfg(feature = "cuda")]
/// Element-wise division `a / b` (IEEE semantics: division by zero yields
/// inf/NaN rather than an error).
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        Some(out) => Ok(out),
        None => {
            let lhs = gpu_to_cpu(a, device)?;
            let rhs = gpu_to_cpu(b, device)?;
            let quotients: Vec<f32> = lhs.iter().zip(&rhs).map(|(x, y)| x / y).collect();
            cpu_to_gpu(&quotients, device)
        }
    }
}
9657
#[cfg(feature = "cuda")]
/// Element-wise exponential `exp(a[i])`.
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::exp),
    }
}
9667
#[cfg(feature = "cuda")]
/// Element-wise natural logarithm `ln(a[i])`.
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::ln),
    }
}
9677
#[cfg(feature = "cuda")]
/// Element-wise square root `sqrt(a[i])`.
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::sqrt),
    }
}
9687
#[cfg(feature = "cuda")]
/// Element-wise power `a[i].powf(exponent)` with a scalar exponent.
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    );
    let kernel = match compiled {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback path.
            let host = gpu_to_cpu(a, device)?;
            let powered: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
            return cpu_to_gpu(&powered, device);
        }
    };

    let mut result = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: this kernel takes (in, out, exponent, n) — note the exponent
    // comes BEFORE the element count, unlike most kernels in this file.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(result.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(result)
}
9733
#[cfg(feature = "cuda")]
/// Element-wise absolute value `|a[i]|`.
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::abs),
    }
}
9743
#[cfg(feature = "cuda")]
/// Element-wise logistic sigmoid `1 / (1 + exp(-a[i]))`.
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
9753
#[cfg(feature = "cuda")]
/// Element-wise hyperbolic tangent `tanh(a[i])`.
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::tanh),
    }
}
9763
#[cfg(feature = "cuda")]
/// Fused Adam step: updates `param`, `exp_avg` (first moment) and
/// `exp_avg_sq` (second moment) in place.
///
/// `bc1`/`bc2` are the caller-supplied bias-correction factors
/// (`1 - beta1^t`, `1 - beta2^t`). When `weight_decay > 0` the CPU fallback
/// adds `weight_decay * param` to the gradient before the moment updates
/// (L2-style decay); the kernel receives `weight_decay` as an argument and
/// presumably does the same — TODO confirm against FUSED_ADAM_PTX.
///
/// Falls back to a full CPU round-trip when the PTX kernel fails to compile.
///
/// # Errors
/// [`GpuError::LengthMismatch`] when any of `grad`, `exp_avg`, or
/// `exp_avg_sq` differs in length from `param`; the reported `b` is the
/// first buffer length that actually mismatches (previously it always
/// reported `grad.len()`, which was misleading when the moment buffers were
/// the culprit).
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    // Report the first buffer whose length actually disagrees with `param`.
    if let Some(bad) = [grad.len(), exp_avg.len(), exp_avg_sq.len()]
        .into_iter()
        .find(|&l| l != n)
    {
        return Err(GpuError::LengthMismatch { a: n, b: bad });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: compute the identical update on the host and
            // upload the three mutated buffers back.
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

            for i in 0..n {
                let mut g = g_host[i];
                if weight_decay > 0.0 {
                    g += weight_decay * p_host[i];
                }
                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
                let m_hat = m_host[i] / bc1;
                let v_hat = v_host[i] / bc2;
                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
            }

            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order matches the kernel's parameter list:
    // (param, grad, m, v, beta1, beta2, lr, eps, bc1, bc2, wd, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
9858
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_fused_adam`]; always fails with
/// [`GpuError::NoCudaFeature`].
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
    _param: &mut CudaBuffer<f32>,
    _grad: &CudaBuffer<f32>,
    _exp_avg: &mut CudaBuffer<f32>,
    _exp_avg_sq: &mut CudaBuffer<f32>,
    _beta1: f32,
    _beta2: f32,
    _lr: f32,
    _eps: f32,
    _bc1: f32,
    _bc2: f32,
    _weight_decay: f32,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
9878
#[cfg(feature = "cuda")]
/// Fused GRU cell forward pass.
///
/// `hx` holds `batch * hsz` hidden-state values; the batch size is inferred
/// as `hx.len() / hsz`. Returns the new hidden state `hy` and a
/// `batch * 5 * hsz` workspace buffer (presumably the intermediate gate
/// activations needed by the backward pass — TODO confirm against
/// FUSED_GRU_FORWARD_PTX).
///
/// # Errors
/// - [`GpuError::ShapeMismatch`] if `hsz == 0` (would divide by zero) or
///   `hx.len()` is not a multiple of `hsz`.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled
///   (no CPU fallback for this op).
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    let total = hx.len();
    // Guard the batch computation: hsz == 0 would panic with a division by
    // zero, and a non-multiple length means the hidden state is malformed.
    if hsz == 0 || total % hsz != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "fused_gru_forward",
            expected: vec![hsz],
            got: vec![total],
        });
    }
    let batch = total / hsz;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "fused_gru_forward_kernel",
            });
        }
    };

    let mut hy = alloc_zeros_f32(total, device)?;
    // 5 intermediate values per hidden unit per batch element.
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
9952
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_fused_gru_forward`]; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_fused_gru_forward(
    _input_gates: &CudaBuffer<f32>,
    _hidden_gates: &CudaBuffer<f32>,
    _bias_ih: &CudaBuffer<f32>,
    _bias_hh: &CudaBuffer<f32>,
    _hx: &CudaBuffer<f32>,
    _hsz: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
9966
#[cfg(feature = "cuda")]
/// 2D max pooling over an NCHW input.
///
/// Output spatial size is `(h_in + 2*ph - kh) / sh + 1` (and likewise for
/// width). Returns the pooled buffer plus its `[batch, channels, h_out,
/// w_out]` shape.
///
/// # Errors
/// - [`GpuError::ShapeMismatch`] when a stride is zero or the kernel is
///   larger than the padded input. Previously these cases either panicked
///   (division by zero / debug overflow) or wrapped the `usize` subtraction
///   in release builds, producing an absurd output size.
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Validate before the output-size arithmetic: zero strides divide by
    // zero and an oversized kernel underflows the usize subtraction.
    if sh == 0 || sw == 0 || h_in + 2 * ph < kh || w_in + 2 * pw < kw {
        return Err(GpuError::ShapeMismatch {
            op: "maxpool2d",
            expected: vec![kh, kw, 1, 1],
            got: vec![h_in + 2 * ph, w_in + 2 * pw, sh, sw],
        });
    }

    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
10031
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_maxpool2d`]; always fails with
/// [`GpuError::NoCudaFeature`].
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
10043
#[cfg(feature = "cuda")]
/// 2D average pooling over an NCHW input.
///
/// Output spatial size is `(h_in + 2*ph - kh) / sh + 1` (and likewise for
/// width). Returns the pooled buffer plus its `[batch, channels, h_out,
/// w_out]` shape.
///
/// # Errors
/// - [`GpuError::ShapeMismatch`] when a stride is zero or the kernel is
///   larger than the padded input (same validation as [`gpu_maxpool2d`];
///   previously these cases panicked or silently wrapped in release builds).
/// - [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Validate before the output-size arithmetic: zero strides divide by
    // zero and an oversized kernel underflows the usize subtraction.
    if sh == 0 || sw == 0 || h_in + 2 * ph < kh || w_in + 2 * pw < kw {
        return Err(GpuError::ShapeMismatch {
            op: "avgpool2d",
            expected: vec![kh, kw, 1, 1],
            got: vec![h_in + 2 * ph, w_in + 2 * pw, sh, sw],
        });
    }

    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
10104
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_avgpool2d`]; always fails with
/// [`GpuError::NoCudaFeature`].
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
10116
#[cfg(feature = "cuda")]
/// BatchNorm forward — **not yet implemented on the GPU**.
///
/// All tensor/scalar parameters are currently ignored (hence the `_`
/// prefixes). The function attempts to compile the BatchNorm PTX — the
/// result is discarded, so this appears to only warm the module cache —
/// and then unconditionally returns a sentinel [`GpuError::ShapeMismatch`]
/// (`expected [0]`, `got [1]`) that callers should treat as
/// "unsupported, fall back".
///
/// NOTE(review): a dedicated "not implemented" error variant would be less
/// surprising than this sentinel — confirm how callers match on it before
/// changing.
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    let ctx = device.context();
    // Compile attempt only; the handle is intentionally unused.
    let _f = crate::module_cache::get_or_compile(
        ctx,
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );
    // Sentinel "unimplemented" error — see the doc comment above.
    Err(GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    })
}
10154
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_batchnorm_forward`]; always fails with
/// [`GpuError::NoCudaFeature`].
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
10173
#[cfg(feature = "cuda")]
/// LayerNorm forward over a `rows x cols` row-major matrix:
/// each row is normalized to zero mean / unit variance (with `eps` added to
/// the variance), then scaled by `weight` and shifted by `bias`
/// (both length `cols` — assumed, not validated here; TODO confirm callers
/// guarantee `input.len() == rows * cols`).
///
/// On PTX compile failure a CPU fallback computes the same result and the
/// failing PTX is dumped to /tmp for debugging.
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // Diagnostic aid: keep a copy of the PTX that failed to compile.
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).ok();
            eprintln!(
                "ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
                LAYERNORM_PTX.len()
            );
            // CPU fallback: per-row mean/variance normalization.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
                // Biased (population) variance, matching the usual LayerNorm definition.
                let var: f32 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 floats of shared memory —
    // presumably the kernel's block-wide reduction buffer; TODO confirm
    // against LAYERNORM_PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
10258
#[cfg(feature = "cuda")]
/// LayerNorm backward: given the forward `input`, upstream `grad_output`,
/// and `weight`, produces `(grad_input, grad_weight, grad_bias)`.
///
/// Shapes follow [`gpu_layernorm`]: `input`/`grad_output` are `rows x cols`
/// row-major, `weight` and the two gradient outputs are length `cols`
/// (assumed, not validated here).
///
/// The CPU fallback uses the standard LayerNorm gradient:
/// `dx = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat))`
/// with `dl = dy * w`, and accumulates `grad_weight`/`grad_bias` across rows.
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                // Recompute the forward statistics for this row.
                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
                let var: f32 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // First pass: row sums of dl and dl*x_hat, plus the
                // parameter gradients (accumulated over all rows).
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                // Second pass: input gradient using the row means above.
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // Same launch shape as the forward pass: one block per row, 256 threads,
    // 256 floats of shared memory for the reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
10370
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_layernorm_backward`]; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_layernorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
10384
#[cfg(feature = "cuda")]
/// RMSNorm forward over a `rows x cols` row-major matrix:
/// `out = x / rms(x) * weight`, with `rms(x) = sqrt(mean(x^2) + eps)`
/// per row. `weight` is length `cols` (assumed, not validated here).
///
/// On PTX compile failure a CPU fallback computes the same result and the
/// failing PTX is dumped to /tmp for debugging.
pub fn gpu_rmsnorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_PTX,
        "rmsnorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // Diagnostic aid: keep a copy of the PTX that failed to compile.
            eprintln!("ferrotorch-gpu: RMSNorm PTX compilation failed ({e:?}), CPU fallback");
            std::fs::write("/tmp/rmsnorm_debug.ptx", RMSNORM_PTX).ok();
            eprintln!(
                "ferrotorch-gpu: dumped PTX to /tmp/rmsnorm_debug.ptx ({} bytes)",
                RMSNORM_PTX.len()
            );
            // CPU fallback: per-row root-mean-square normalization.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f32 =
                    slice.iter().map(|&x| x * x).sum::<f32>() / cols as f32;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads, 256 floats of shared memory —
    // presumably the kernel's reduction buffer; TODO confirm against
    // RMSNORM_PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
10467
#[cfg(feature = "cuda")]
/// RMSNorm backward: given the forward `input`, upstream `grad_output`, and
/// `weight`, produces `(grad_input, grad_weight)`.
///
/// Shapes follow [`gpu_rmsnorm`]: `input`/`grad_output` are `rows x cols`
/// row-major, `weight` and `grad_weight` are length `cols` (assumed, not
/// validated here).
///
/// CPU fallback gradient per row, with `inv_rms = 1/sqrt(mean(x^2) + eps)`:
/// `dx = inv_rms * w * dy - x * (sum(dy * x * w) * inv_rms^3 / cols)`,
/// and `grad_weight[c] += dy[c] * x[c] * inv_rms` accumulated over rows.
pub fn gpu_rmsnorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_BACKWARD_PTX,
        "rmsnorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                // Recompute the forward statistic for this row.
                let sq_mean: f32 =
                    x_slice.iter().map(|&x| x * x).sum::<f32>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                // First pass: row dot product and the weight gradient.
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                // Second pass: input gradient.
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // Same launch shape as the forward pass: one block per row, 256 threads,
    // 256 floats of shared memory for the reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order matches the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
10565
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_rmsnorm`]; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_rmsnorm(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
10578
#[cfg(not(feature = "cuda"))]
/// CPU-only build stub for [`gpu_rmsnorm_backward`]; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_rmsnorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
10592
#[cfg(feature = "cuda")]
/// Element-wise addition written into a caller-provided buffer
/// (`out[i] = a[i] + b[i]`), avoiding an allocation.
///
/// # Errors
/// [`GpuError::ShapeMismatch`] if `out` is shorter than `a`;
/// [`GpuError::PtxCompileFailed`] if the kernel is unavailable (the `_into`
/// variants have no CPU fallback).
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    let (needed, have) = (a.len(), out.len());
    if have < needed {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![needed],
            got: vec![have],
        });
    }
    if try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        })
    }
}
10624
#[cfg(feature = "cuda")]
/// Element-wise multiplication written into a caller-provided buffer
/// (`out[i] = a[i] * b[i]`), avoiding an allocation.
///
/// # Errors
/// [`GpuError::ShapeMismatch`] if `out` is shorter than `a`;
/// [`GpuError::PtxCompileFailed`] if the kernel is unavailable.
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    let (needed, have) = (a.len(), out.len());
    if have < needed {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![needed],
            got: vec![have],
        });
    }
    if try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        })
    }
}
10648
#[cfg(feature = "cuda")]
/// Scalar multiply `out[i] = a[i] * scalar`, written into a caller-provided
/// buffer.
///
/// # Errors
/// * [`GpuError::ShapeMismatch`] when `out` cannot hold `a.len()` elements
///   (prevents an out-of-bounds device write).
/// * [`GpuError::PtxCompileFailed`] when the scale kernel cannot be compiled.
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    let n = a.len();
    // Consistent with `gpu_add_into`/`gpu_mul_into`: reject a destination
    // that is too small, since the kernel writes `n` elements unconditionally.
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // SAFETY: the destination length was checked above; argument order
    // matches the other `*_into` launchers (input, output, scalar, count).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
10684
#[cfg(feature = "cuda")]
/// Returns `true` when any element of `a` is infinite or NaN.
///
/// Implemented by copying the buffer to the host and scanning it there; an
/// empty buffer trivially contains no non-finite values.
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    if a.len() == 0 {
        return Ok(false);
    }

    validate_unary(a, device)?;

    let values: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
    Ok(!values.iter().all(|v| v.is_finite()))
}
10713
#[cfg(not(feature = "cuda"))]
/// CPU-fallback stub: always fails with [`GpuError::NoCudaFeature`] when the
/// crate is built without the `cuda` feature.
pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
10719
#[cfg(feature = "cuda")]
/// GELU activation applied element-wise, written into `out`.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] when the GELU kernel cannot be compiled.
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;
    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "gelu_kernel",
        })
    }
}
10735
#[cfg(feature = "cuda")]
/// Single embedding lookup: copies one `d`-wide row of `weight` into `out`,
/// selected by the (float-encoded, presumably — see the batch variant's
/// host fallback) index stored in `idx`.
///
/// NOTE(review): unlike the binary/unary ops, this performs no device or
/// length validation before launching — confirm callers guarantee
/// `out.len() >= d` and that the index is in range.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the lookup kernel cannot be compiled.
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;
    // One thread per element of the copied row.
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }
    Ok(())
}
10770
#[cfg(feature = "cuda")]
/// Batched embedding lookup: for each of the `n` float-encoded indices in
/// `indices`, copies the corresponding `d`-wide row of `weight` into the
/// returned `n * d` buffer.
///
/// If the kernel cannot be compiled, falls back to gathering on the host
/// (two device→host copies plus one host→device copy).
///
/// NOTE(review): the host fallback slices `weight_host[start..start + d]`
/// and will panic on an out-of-range index — confirm callers validate
/// indices against the embedding count.
pub fn gpu_embed_lookup_batch(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f32(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather the rows on the host.
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            for &idx_f in &idx_host {
                let row = idx_f as usize;
                let start = row * d;
                out.extend_from_slice(&weight_host[start..start + d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    // One thread per output element.
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10835
#[cfg(feature = "cuda")]
/// Scatter-add of gradient rows (embedding backward): row `i` of
/// `grad_output` is added into row `indices[i]` of a zero-initialized
/// `num_embeddings x d` accumulator, which is returned. Duplicate indices
/// are summed.
///
/// Falls back to accumulating on the host when the kernel cannot be
/// compiled.
///
/// NOTE(review): the host fallback indexes `result[row * d + j]` and will
/// panic if an index is `>= num_embeddings`. Whether the GPU kernel handles
/// duplicate indices atomically is not visible here — confirm against
/// `SCATTER_ADD_ROWS_PTX`.
pub fn gpu_scatter_add_rows(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = indices.len();
    let total = n * d;

    // No gradients to scatter: return the zeroed accumulator directly.
    if total == 0 {
        return alloc_zeros_f32(num_embeddings * d, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: accumulate on the host, then upload.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                let row = idx_f as usize;
                for j in 0..d {
                    result[row * d + j] += go_host[i * d + j];
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
    // One thread per element of `grad_output`.
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10905
#[cfg(feature = "cuda")]
/// Transposes an `m x n` matrix `a` into `out` (which receives `n x m`).
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the transpose kernel cannot be compiled.
pub fn gpu_transpose_2d_into(
    a: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let elem_count = m * n;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "transpose_2d_kernel",
    })?;
    let cfg = launch_cfg(elem_count)?;
    let (rows, cols, count) = (m as u32, n as u32, elem_count as u32);
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows)
            .arg(&cols)
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
10944
#[cfg(feature = "cuda")]
/// Permutes axes of a 4-D tensor from `(0, 1, 2, 3)` to `(0, 2, 1, 3)`
/// (per the kernel name), writing the result into `out`.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the permute kernel cannot be compiled.
pub fn gpu_permute_0213_into(
    a: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let elem_count = d0 * d1 * d2 * d3;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "permute_0213_kernel",
    })?;
    let cfg = launch_cfg(elem_count)?;
    let dims = [d0 as u32, d1 as u32, d2 as u32, d3 as u32];
    let count = elem_count as u32;
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&dims[0])
            .arg(&dims[1])
            .arg(&dims[2])
            .arg(&dims[3])
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
10985
#[cfg(feature = "cuda")]
/// Row-wise softmax: treats `a` as a `rows x cols` matrix and writes the
/// softmax of each row into `out` (kernel semantics defined by
/// `SOFTMAX_PTX`).
///
/// Launch geometry: one 256-thread block per row, with `cols * 4` bytes of
/// dynamic shared memory (one f32 per column).
///
/// NOTE(review): rows where `cols * 4` exceeds the device's dynamic
/// shared-memory limit would fail to launch — confirm callers bound `cols`.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the softmax PTX cannot be compiled.
pub fn gpu_softmax_into(
    a: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "softmax_kernel",
    })?;
    let block_size = 256u32;
    // One block per row.
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(())
}
11027
#[cfg(feature = "cuda")]
/// Row-wise LayerNorm over each of `rows` rows of length `cols`, with
/// per-column `weight` and `bias` and numeric stabilizer `eps` (kernel
/// semantics defined by `LAYERNORM_PTX`), written into `out`.
///
/// Launch geometry: one 256-thread block per row, with `cols * 4` bytes of
/// dynamic shared memory (one f32 per column).
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the layernorm PTX cannot be compiled.
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "layernorm_kernel",
    })?;
    let block_size = 256u32;
    // One block per row.
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            // Argument order must match the PTX entry: input, out, weight,
            // bias, rows, cols, eps — note `out` comes second.
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(())
}
11076
#[cfg(feature = "cuda")]
/// Gathers `n_batch * len * d` elements from the `max_len`-strided cache
/// `src` into the densely-packed `out` (kernel semantics defined by
/// `SLICE_READ_PTX`).
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the slice-read kernel cannot be
/// compiled.
pub fn gpu_slice_read_into(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let elem_count = n_batch * len * d;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_read_kernel",
    })?;
    let cfg = launch_cfg(elem_count)?;
    let (count, width, span, stride) = (elem_count as u32, d as u32, len as u32, max_len as u32);
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&count)
            .arg(&width)
            .arg(&span)
            .arg(&stride)
            .launch(cfg)?;
    }
    Ok(())
}
11120
#[cfg(feature = "cuda")]
/// Matrix multiply for small operands: `out = a (m x k) * b (k x n)`, with
/// the grid sized over the `m * n` output elements.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the matmul kernel cannot be compiled.
pub fn gpu_small_matmul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let elem_count = m * n;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "small_matmul_kernel",
    })?;
    let cfg = launch_cfg(elem_count)?;
    let dims = [m as u32, k as u32, n as u32, elem_count as u32];
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&dims[0])
            .arg(&dims[1])
            .arg(&dims[2])
            .arg(&dims[3])
            .launch(cfg)?;
    }
    Ok(())
}
11161
#[cfg(feature = "cuda")]
/// Writes `n_batch * d` elements from `src` into the `max_len`-strided cache
/// `dst`, at a position read on-device from `pos_ptr` (kernel semantics
/// defined by `SLICE_WRITE_INDIRECT_PTX`).
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
pub fn gpu_slice_write_indirect(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos_ptr: &cudarc::driver::CudaSlice<u32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let elem_count = n_batch * d;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        SLICE_WRITE_INDIRECT_PTX,
        "slice_write_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_write_indirect_kernel",
    })?;
    let cfg = launch_cfg(elem_count)?;
    let (count, width, stride) = (elem_count as u32, d as u32, max_len as u32);
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&count)
            .arg(&width)
            .arg(&stride)
            .arg(pos_ptr)
            .launch(cfg)?;
    }
    Ok(())
}
11208
#[cfg(feature = "cuda")]
/// Fills the `n_head * max_pos` buffer `out` with a causal attention mask,
/// using a sequence length read on-device from `total_len_ptr` (kernel
/// semantics defined by `CAUSAL_MASK_INDIRECT_PTX`).
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the kernel cannot be compiled.
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let elem_count = n_head * max_pos;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;
    let cfg = launch_cfg(elem_count)?;
    let (stride, count) = (max_pos as u32, elem_count as u32);
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&stride)
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
11247
#[cfg(feature = "cuda")]
/// Eagerly compiles (and caches) every kernel used on the decode path so
/// that path does not pay PTX-compilation latency on first use.
///
/// # Errors
/// Propagates context-binding failures and wraps any compilation failure in
/// [`GpuError::Driver`].
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    // (PTX source, entry-point name) for every decode-path kernel.
    const DECODE_KERNELS: [(&str, &str); 16] = [
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
    ];

    let ctx = device.context();
    ctx.bind_to_thread()?;
    let ord = device.ordinal() as u32;
    for &(ptx, name) in DECODE_KERNELS.iter() {
        crate::module_cache::get_or_compile(ctx, ptx, name, ord).map_err(GpuError::Driver)?;
    }
    Ok(())
}
11283
#[cfg(not(feature = "cuda"))]
/// CPU-fallback stub: always fails with [`GpuError::NoCudaFeature`] when the
/// crate is built without the `cuda` feature.
pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
11289
// ---------------------------------------------------------------------------
// CPU-fallback stubs, compiled only when the `cuda` feature is disabled.
// Each mirrors the signature of its CUDA counterpart and unconditionally
// returns `GpuError::NoCudaFeature`, so a build without CUDA gives callers a
// clear runtime error instead of a compile/link failure.
// ---------------------------------------------------------------------------
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_elu(
    _input: &CudaBuffer<f32>,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp(
    _input: &CudaBuffer<f32>,
    _min_val: f32,
    _max_val: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_div(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_pow(
    _a: &CudaBuffer<f32>,
    _exponent: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d(
    _input: &CudaBuffer<f32>,
    _m: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_scale(
    _a: &CudaBuffer<f32>,
    _scalar: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax(
    _input: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout(
    _input: &CudaBuffer<f32>,
    _threshold: u32,
    _scale: f32,
    _seed: u32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213(
    _input: &CudaBuffer<f32>,
    _d0: usize,
    _d1: usize,
    _d2: usize,
    _d3: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write(
    _src: &CudaBuffer<f32>,
    _dst: &mut CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _max_len: usize,
    _pos: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read(
    _src: &CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _len: usize,
    _max_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup(
    _idx: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch(
    _indices: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _n: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _num_embeddings: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d(
    _input: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _input_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill(
    _input: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _value: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero(
    _grad: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax(
    _input: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis(
    _a: &CudaBuffer<f32>,
    _outer: usize,
    _axis_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split(
    _input: &CudaBuffer<f32>,
    _total_along_axis: usize,
    _split_offset: usize,
    _split_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat(
    _input: &CudaBuffer<f32>,
    _output: &mut CudaBuffer<f32>,
    _total_along_axis: usize,
    _cat_offset: usize,
    _part_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
11876
#[cfg(feature = "cuda")]
/// Converts an f32 buffer into raw 16-bit values via the `f32_to_f16_kernel`
/// (per the kernel name, half-precision bit patterns), returning them as a
/// `CudaSlice<u16>`.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be
/// compiled; allocation and launch failures propagate from the driver.
pub(crate) fn gpu_f32_to_f16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let stream = device.stream();
    let n = input.len();
    // Empty input: nothing to convert, hand back an empty allocation.
    if n == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }

    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_F16_PTX,
        "f32_to_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_f16_kernel",
    })?;

    let mut halves = stream.alloc_zeros::<u16>(n)?;
    let count = n as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut halves)
            .arg(&count)
            .launch(launch_cfg(n)?)?;
    }

    Ok(halves)
}
11936
#[cfg(not(feature = "cuda"))]
/// CPU-fallback stub: always fails with [`GpuError::NoCudaFeature`].
/// NOTE(review): returns `()` rather than the `CudaSlice<u16>` of the CUDA
/// build (that type only exists with `cudarc`); fine while it always errors,
/// but confirm no caller matches on the success type across features.
pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
11942
#[cfg(feature = "cuda")]
/// Converts an f32 buffer into raw 16-bit values via the
/// `f32_to_bf16_kernel` (per the kernel name, bfloat16 bit patterns),
/// returning them as a `CudaSlice<u16>`.
///
/// # Errors
/// [`GpuError::PtxCompileFailed`] if the conversion kernel cannot be
/// compiled; allocation and launch failures propagate from the driver.
pub(crate) fn gpu_f32_to_bf16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let stream = device.stream();
    let n = input.len();
    // Empty input: nothing to convert, hand back an empty allocation.
    if n == 0 {
        return Ok(stream.alloc_zeros::<u16>(0)?);
    }

    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_BF16_PTX,
        "f32_to_bf16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_bf16_kernel",
    })?;

    let mut halves = stream.alloc_zeros::<u16>(n)?;
    let count = n as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut halves)
            .arg(&count)
            .launch(launch_cfg(n)?)?;
    }

    Ok(halves)
}
11988
#[cfg(not(feature = "cuda"))]
/// CPU-fallback stub: always fails with [`GpuError::NoCudaFeature`].
/// NOTE(review): returns `()` rather than the `CudaSlice<u16>` of the CUDA
/// build — see the matching note on `gpu_f32_to_f16`'s stub.
pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
11994
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    /// Opens CUDA device 0 and uploads `data`, returning both the device
    /// handle and the resulting device buffer.
    fn fixture(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let device = GpuDevice::new(0).expect("CUDA device 0");
        let gpu_buf = cpu_to_gpu(data, &device).expect("cpu_to_gpu");
        (device, gpu_buf)
    }

    /// Downloads `buf` to the host and verifies it matches `expected`
    /// element-wise within an absolute tolerance of 1e-6.
    fn assert_gpu_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(host.len(), expected.len(), "length mismatch");
        for (i, &exp) in expected.iter().enumerate() {
            let got = host[i];
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    #[test]
    fn add_basic() {
        let lhs = vec![1.0f32, 2.0, 3.0, 4.0];
        let rhs = vec![10.0f32, 20.0, 30.0, 40.0];
        let want: Vec<f32> = lhs.iter().zip(&rhs).map(|(x, y)| x + y).collect();

        let (dev, a) = fixture(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_gpu_eq(&sum, &dev, &want);
    }

    #[test]
    fn add_empty() {
        let (dev, a) = fixture(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(sum.len(), 0);
    }

    #[test]
    fn add_large() {
        // Large enough to span many thread blocks.
        const N: usize = 100_000;
        let lhs: Vec<f32> = (0..N).map(|i| i as f32).collect();
        let rhs: Vec<f32> = (0..N).map(|i| (i as f32) * 0.5).collect();
        let want: Vec<f32> = lhs.iter().zip(&rhs).map(|(x, y)| x + y).collect();

        let (dev, a) = fixture(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_gpu_eq(&sum, &dev, &want);
    }

    #[test]
    fn add_length_mismatch() {
        let (dev, a) = fixture(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        // Mismatched lengths must surface as a typed error, not a launch.
        match gpu_add(&a, &b, &dev).unwrap_err() {
            GpuError::LengthMismatch { a: 3, b: 2 } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn sub_basic() {
        let lhs = vec![10.0f32, 20.0, 30.0, 40.0];
        let rhs = vec![1.0f32, 2.0, 3.0, 4.0];
        let want: Vec<f32> = lhs.iter().zip(&rhs).map(|(x, y)| x - y).collect();

        let (dev, a) = fixture(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_gpu_eq(&diff, &dev, &want);
    }

    #[test]
    fn sub_negative_result() {
        let lhs = vec![1.0f32, 2.0];
        let rhs = vec![5.0f32, 10.0];

        let (dev, a) = fixture(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_gpu_eq(&diff, &dev, &[-4.0, -8.0]);
    }

    #[test]
    fn mul_basic() {
        let lhs = vec![2.0f32, 3.0, 4.0, 5.0];
        let rhs = vec![10.0f32, 10.0, 10.0, 10.0];
        let want: Vec<f32> = lhs.iter().zip(&rhs).map(|(x, y)| x * y).collect();

        let (dev, a) = fixture(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let prod = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_gpu_eq(&prod, &dev, &want);
    }

    #[test]
    fn mul_by_zero() {
        let lhs = vec![1.0f32, 2.0, 3.0];
        let rhs = vec![0.0f32, 0.0, 0.0];

        let (dev, a) = fixture(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let prod = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_gpu_eq(&prod, &dev, &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn neg_basic() {
        let data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
        let want: Vec<f32> = data.iter().map(|x| -x).collect();

        let (dev, a) = fixture(&data);
        let negated = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_gpu_eq(&negated, &dev, &want);
    }

    #[test]
    fn neg_double_negation() {
        // Negating twice must round-trip back to the original values.
        let data = vec![1.0f32, -2.0, 3.0];
        let (dev, a) = fixture(&data);
        let once = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let twice = gpu_neg(&once, &dev).expect("gpu_neg 2");
        assert_gpu_eq(&twice, &dev, &data);
    }

    #[test]
    fn relu_basic() {
        let data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];

        let (dev, a) = fixture(&data);
        let clamped = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_gpu_eq(&clamped, &dev, &[0.0, 0.0, 0.0, 1.0, 3.0]);
    }

    #[test]
    fn relu_all_negative() {
        let data = vec![-5.0f32, -0.1, -100.0];

        let (dev, a) = fixture(&data);
        let clamped = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_gpu_eq(&clamped, &dev, &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn relu_all_positive() {
        // Positive inputs must pass through unchanged.
        let data = vec![0.1f32, 1.0, 100.0];

        let (dev, a) = fixture(&data);
        let clamped = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_gpu_eq(&clamped, &dev, &data);
    }

    #[test]
    fn relu_empty() {
        let (dev, a) = fixture(&[]);
        let clamped = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(clamped.len(), 0);
    }

    #[test]
    fn small_matmul_2x2() {
        // [1 2; 3 4] x [5 6; 7 8] = [19 22; 43 50]
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
        let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
        assert_gpu_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn small_matmul_1xk_kxn() {
        // Row vector (1x3) times a 3x2 matrix.
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
        assert_gpu_eq(&c, &dev, &[4.0, 5.0]);
    }

    #[test]
    fn small_matmul_vs_cublas() {
        // Cross-check the custom kernel against cuBLAS on a 1x64 * 64x64
        // product with deterministic pseudo-random inputs.
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let (m, k, n) = (1, 64, 64);

        let a_data: Vec<f32> = (0..m * k)
            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
            .collect();
        let b_data: Vec<f32> = (0..k * n)
            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
            .collect();

        let a = cpu_to_gpu(&a_data, &dev).unwrap();
        let b = cpu_to_gpu(&b_data, &dev).unwrap();

        let cublas_result =
            gpu_to_cpu(&crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap(), &dev).unwrap();
        let our_result =
            gpu_to_cpu(&gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap(), &dev).unwrap();

        assert_eq!(cublas_result.len(), our_result.len());
        // Summation order differs between the two paths, so allow a loose
        // tolerance rather than exact equality.
        for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
            assert!(
                (cb - ours).abs() < 0.1,
                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
                (cb - ours).abs()
            );
        }
    }
}