1#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, cpu_to_gpu, gpu_to_cpu};
30
/// PTX for elementwise f32 addition: `out[i] = a[i] + b[i]` for `i < n`.
///
/// One thread per element; global id = `ctaid.x * ntid.x + tid.x`.
/// Threads with id >= `n` branch directly to `ret`, so any grid that
/// covers `n` threads is safe.
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    add.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
82
/// PTX for vectorized f32 addition, 4 elements per thread via `ld/st.global.v4.f32`.
///
/// `n4` is the element count divided by 4; byte offset per thread is
/// `tid * 16`. Callers must ensure the buffers are 16-byte aligned and the
/// length is a multiple of 4 (no tail handling here).
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    // Byte offset = tid * 16 (4 floats × 4 bytes)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    add.f32 %r0, %a0, %b0;
    add.f32 %r1, %a1, %b1;
    add.f32 %r2, %a2, %b2;
    add.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
139
/// PTX for vectorized f32 multiply, 4 elements per thread (`v4.f32` loads/stores).
///
/// Same layout contract as `ADD_VEC4_PTX`: `n4` = element count / 4,
/// buffers 16-byte aligned, length a multiple of 4.
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    mul.f32 %r0, %a0, %b0;
    mul.f32 %r1, %a1, %b1;
    mul.f32 %r2, %a2, %b2;
    mul.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
192
/// PTX for elementwise f32 subtraction: `out[i] = a[i] - b[i]` for `i < n`.
///
/// One thread per element; out-of-range threads exit via the guard branch.
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    sub.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
240
/// PTX for elementwise f32 multiplication: `out[i] = a[i] * b[i]` for `i < n`.
///
/// Scalar (one-float-per-thread) counterpart of `MUL_VEC4_PTX`; no
/// alignment or length-multiple requirements.
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    mul.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
288
/// PTX for elementwise f32 negation: `out[i] = -a[i]` for `i < n`.
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    neg.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
332
/// PTX for elementwise ReLU: `out[i] = max(a[i], 0.0)` for `i < n`.
///
/// `0f00000000` is the PTX hex-float literal for 0.0f32.
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %zero;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mov.f32 %zero, 0f00000000;
    max.f32 %vr, %va, %zero;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
377
/// PTX for scalar multiply: `out[i] = a[i] * scalar` for `i < n`.
///
/// `scalar` is passed by value as an f32 kernel parameter.
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %s;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %s, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mul.f32 %vr, %va, %s;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
423
/// PTX for a naive 2-D transpose of an `[M, N]` f32 matrix into `[N, M]`.
///
/// Each thread owns one OUTPUT element (`tid < total = M*N`), decomposes
/// its id into `(out_row, out_col)` of the `[N, M]` output, and gathers
/// `in[out_col, out_row]`. Writes are coalesced; reads are strided.
/// (This constant uses explicit `\n` escapes, unlike the raw multi-line
/// strings above; both produce equivalent PTX text.)
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 M,\n\
    .param .u32 N,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
    .reg .u32 %out_row, %out_col, %in_idx;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %M_reg, [M];\n\
    ld.param.u32 %N_reg, [N];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
    div.u32 %out_row, %r_tid, %M_reg;\n\
    rem.u32 %out_col, %r_tid, %M_reg;\n\
    // Input index: out_col * N + out_row (transposed).\n\
    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
479
/// PTX for the 4-D axis permutation `(0, 2, 1, 3)`: input `[d0, d1, d2, d3]`
/// → output `[d0, d2, d1, d3]` (the common attention-head reshuffle).
///
/// Each thread owns one OUTPUT element: it decomposes `tid` into
/// `(i0, i2, i1, i3)` using the output strides, then gathers from the
/// contiguous input at `i0*d1*d2*d3 + i1*d2*d3 + i2*d3 + i3`.
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 d0,\n\
    .param .u32 d1,\n\
    .param .u32 d2,\n\
    .param .u32 d3,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %d0r, [d0];\n\
    ld.param.u32 %d1r, [d1];\n\
    ld.param.u32 %d2r, [d2];\n\
    ld.param.u32 %d3r, [d3];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape: [d0, d2, d1, d3]\n\
    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
    div.u32 %i0, %r_tid, %s_out1;\n\
    rem.u32 %rem, %r_tid, %s_out1;\n\
    div.u32 %i2, %rem, %s_out2;\n\
    rem.u32 %rem, %rem, %s_out2;\n\
    div.u32 %i1, %rem, %d3r;\n\
    rem.u32 %i3, %rem, %d3r;\n\
\n\
    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
    mul.lo.u32 %in_idx, %i0, %d1r;\n\
    add.u32 %in_idx, %in_idx, %i1;\n\
    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
    add.u32 %in_idx, %in_idx, %i3;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
560
/// PTX converting an f32 buffer to f16 (IEEE half), round-to-nearest-even.
///
/// Input stride is 4 bytes, output stride 2 bytes; the conversion uses
/// the hardware `cvt.rn.f16.f32` instruction and stores the raw 16 bits.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .b16 %vh;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Compute input offset: i * 4 (f32 = 4 bytes)
    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    // Compute output offset: i * 2 (f16 = 2 bytes)
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32, convert to f16 (round-to-nearest-even), store as u16
    ld.global.f32 %vf, [%in];
    cvt.rn.f16.f32 %vh, %vf;
    st.global.b16 [%out], %vh;

DONE:
    ret;
}
";
615
/// PTX converting an f32 buffer to bfloat16 via integer bit manipulation
/// (bf16 is the top 16 bits of f32, with round-to-nearest-even).
///
/// Implemented without `cvt.*.bf16` so it assembles for sm_52 targets.
/// NOTE(review): the RNE add (`bits + 0x7FFF + lsb`) can carry into the
/// exponent field for NaN inputs, potentially turning a NaN into ±Inf —
/// confirm NaN inputs are out of scope for callers.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .u32 %bits, %round, %lsb, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32 as raw bits
    ld.global.u32 %bits, [%in];

    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
    shr.u32 %lsb, %bits, 16;
    and.b32 %lsb, %lsb, 1;
    add.u32 %round, %bits, 0x7FFF;
    add.u32 %round, %round, %lsb;
    shr.u32 %result, %round, 16;

    // Store as u16
    st.global.u16 [%out], %result;

DONE:
    ret;
}
";
676
/// PTX for a naive matmul `C[M,N] = A[M,K] * B[K,N]` (row-major f32).
///
/// One thread per output element (`total = M*N`); each thread runs a
/// sequential K-loop of `fma.rn.f32` accumulations. No tiling or shared
/// memory — intended for small matrices only, as the name says.
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 K,
    .param .u32 N,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
    .reg .u32 %row, %col, %p, %idx;
    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
    .reg .f32 %sum, %va, %vb;
    .reg .pred %bounds_p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %c, [c_ptr];
    ld.param.u32 %M_reg, [M];
    ld.param.u32 %K_reg, [K];
    ld.param.u32 %N_reg, [N];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    div.u32 %row, %r_tid, %N_reg;
    rem.u32 %col, %r_tid, %N_reg;

    mov.f32 %sum, 0f00000000;
    mov.u32 %p, 0;
DOT:
    setp.ge.u32 %loop_p, %p, %K_reg;
    @%loop_p bra DOT_DONE;

    mad.lo.u32 %idx, %row, %K_reg, %p;
    cvt.u64.u32 %a_off, %idx;
    shl.b64 %a_off, %a_off, 2;
    add.u64 %a_off, %a, %a_off;
    ld.global.f32 %va, [%a_off];

    mad.lo.u32 %idx, %p, %N_reg, %col;
    cvt.u64.u32 %b_off, %idx;
    shl.b64 %b_off, %b_off, 2;
    add.u64 %b_off, %b, %b_off;
    ld.global.f32 %vb, [%b_off];

    fma.rn.f32 %sum, %va, %vb, %sum;
    add.u32 %p, %p, 1;
    bra DOT;
DOT_DONE:

    cvt.u64.u32 %c_off, %r_tid;
    shl.b64 %c_off, %c_off, 2;
    add.u64 %c_off, %c, %c_off;
    st.global.f32 [%c_off], %sum;

DONE:
    ret;
}
";
755
/// PTX scattering a `[B, D]` source into row `pos` of a `[B, max_len, D]`
/// destination: `dst[b, pos, d] = src[b, d]`, with `n = B * D` threads.
///
/// Each thread splits its id into `(batch_idx, d_idx)` and writes to
/// `((batch_idx * max_len) + pos) * D + d_idx`. Typical use looks like a
/// KV-cache-style append at a host-known position.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u32 pos
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u32 %pos_reg, [pos];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
815
/// Variant of `slice_write_kernel` where the write position lives on the
/// DEVICE: `pos` is read as a u32 from `pos_ptr` instead of being a kernel
/// parameter, so the position can be produced by a prior kernel without a
/// host round-trip. Otherwise identical indexing: `dst[b, pos, d] = src[b, d]`.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u64 pos_ptr
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u64 %pos_p, [pos_ptr];
    ld.global.u32 %pos_reg, [%pos_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
876
/// PTX filling an additive attention mask with a device-resident length:
/// for each element, `col = tid % max_pos`; writes 0.0 where `col < *total_len_ptr`
/// and -1.0e9 otherwise.
///
/// "Indirect" because the valid length is loaded from GPU memory
/// (`total_len_ptr`) rather than passed by value, matching the indirect
/// slice-write kernel above.
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
    .param .u64 total_len_ptr,
    .param .u64 out_ptr,
    .param .u32 max_pos,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
    .reg .u64 %out, %off, %tl_p;
    .reg .f32 %val;
    .reg .pred %bounds_p, %mask_p;

    ld.param.u64 %tl_p, [total_len_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %max_pos_reg, [max_pos];
    ld.param.u32 %total_reg, [total];

    ld.global.u32 %tlen, [%tl_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    rem.u32 %col, %r_tid, %max_pos_reg;
    setp.lt.u32 %mask_p, %col, %tlen;
    @%mask_p bra WRITE_ZERO;

    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
    // to effectively zero out masked positions after softmax.
    mov.f32 %val, 0fCE6E6B28;
    bra WRITE;

WRITE_ZERO:
    mov.f32 %val, 0f00000000;

WRITE:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %val;

DONE:
    ret;
}
";
936
/// PTX for a SINGLE-token embedding lookup: copies row `indices[0]` of the
/// `[vocab, D]` weight matrix into `out[0..D]`; one thread per feature.
///
/// The index is stored as f32 and truncated to u32 via `cvt.rzi.u32.f32`
/// (this crate apparently keeps index tensors in f32 buffers — the batch
/// lookup and scatter kernels below do the same). No bounds check on the
/// row index; callers must guarantee `indices[0] < vocab`.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %D_reg;
    @%p bra DONE;

    ld.global.f32 %idx_f, [%idx_addr];
    cvt.rzi.u32.f32 %row, %idx_f;

    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
989
/// PTX for a batched embedding lookup: `out[row, col] = weight[indices[row], col]`
/// with `total = n_rows * D` threads (one per output element).
///
/// Indices are stored as f32 and truncated to u32 (`cvt.rzi.u32.f32`),
/// consistent with `EMBED_LOOKUP_PTX`. No bounds check on the index values.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mad.lo.u32 %tid, %bid, %bdim, %tid;

    setp.ge.u32 %p, %tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %tid, %D_reg;
    rem.u32 %col, %tid, %D_reg;

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %src_idx, %idx_f;

    // src_idx = indices[row] * D + col
    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    // Write to out[tid]
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1058
/// PTX for the embedding-lookup backward pass: scatter-adds `grad_output`
/// rows into `grad_weight` — `grad_weight[indices[row], col] += grad_output[row, col]`.
///
/// Uses `atom.global.add.f32` so duplicate indices accumulate correctly
/// across threads; the returned previous value is discarded into `%dummy`.
/// Note: atomic f32 addition order is nondeterministic, so repeated runs
/// may differ in the last ULPs when indices collide.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_weight_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %dst_idx;
    .reg .u64 %go, %idx_addr, %gw, %off;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %idx_addr, [indices_ptr];
    ld.param.u64 %gw, [grad_weight_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mad.lo.u32 %tid, %bid, %bdim, %tid;

    setp.ge.u32 %p, %tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %tid, %D_reg;
    rem.u32 %col, %tid, %D_reg;

    // Read grad_output[tid]
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %go, %off;
    ld.global.f32 %grad_val, [%off];

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %dst_idx, %idx_f;

    // dst_idx = indices[row] * D + col
    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %gw, %off;
    atom.global.add.f32 %dummy, [%off], %grad_val;

DONE:
    ret;
}
";
1127
/// PTX gathering the first `len` rows of each batch from a `[B, max_len, D]`
/// source into a contiguous `[B, len, D]` destination
/// (`dst[b, r, d] = src[b, r, d]` for `r < len`, `total = B * len * D`).
///
/// Inverse direction of `slice_write_kernel`: read a prefix slice out of
/// the padded buffer instead of writing one row into it.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 total,
    .param .u32 D,
    .param .u32 len,
    .param .u32 max_len
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
    .reg .u32 %len_d;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %total_reg, [total];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %len_reg, [len];
    ld.param.u32 %max_len_reg, [max_len];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // dst index = r_tid
    // batch_idx = r_tid / (len * D)
    // within = r_tid % (len * D)
    // row = within / D
    // col = within % D
    // src_idx = batch_idx * max_len * D + row * D + col
    mul.lo.u32 %len_d, %len_reg, %D_reg;
    div.u32 %batch_idx, %r_tid, %len_d;
    rem.u32 %within, %r_tid, %len_d;
    div.u32 %row, %within, %D_reg;
    rem.u32 %col, %within, %D_reg;

    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
    add.u32 %src_idx, %src_idx, %row;
    mul.lo.u32 %src_idx, %src_idx, %D_reg;
    add.u32 %src_idx, %src_idx, %col;

    cvt.u64.u32 %src_off, %src_idx;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src_off, %src, %src_off;
    ld.global.f32 %val, [%src_off];

    cvt.u64.u32 %dst_off, %r_tid;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst_off, %dst, %dst_off;
    st.global.f32 [%dst_off], %val;

DONE:
    ret;
}
";
1201
/// PTX for sigmoid-approximated GELU: `out[i] = x * sigmoid(k * x)`.
///
/// sigmoid is computed as `1 / (1 + exp(-k*x))` using fast approximations:
/// `ex2.approx` (so `exp(t) = 2^(t * log2(e))`, hence the 0f3FB8AA3B =
/// log2(e) factor) and `rcp.approx` for the reciprocal.
/// NOTE(review): k = 0f3FDA2720 ≈ 1.70426, not the conventional 1.702 of
/// the sigmoid-GELU approximation; `GELU_BACKWARD_PTX` uses the same bits,
/// so forward/backward are at least mutually consistent — confirm the
/// intended coefficient.
#[cfg(feature = "cuda")]
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    mov.f32 %k, 0f3FDA2720;
    mul.f32 %neg_kx, %k, %x;
    neg.f32 %neg_kx, %neg_kx;
    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;
    mul.f32 %result, %x, %sig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1261
/// PTX for the ReLU backward pass:
/// `out[i] = grad[i]` where `input[i] > 0`, else `0.0`.
///
/// Implemented branch-free with `setp.gt.f32` + `selp.f32`. Uses a strict
/// `> 0` comparison, so the subgradient at exactly 0 is taken as 0.
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %zero, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
1316
/// PTX for the sigmoid-approximated GELU backward pass:
/// `out = grad * (sig + k*x * sig * (1 - sig))` where `sig = sigmoid(k*x)`.
///
/// This is d/dx of `x * sigmoid(k*x)` by the product/chain rule. Uses the
/// same coefficient bits (0f3FDA2720) and the same `ex2.approx`/`rcp.approx`
/// sigmoid construction as `GELU_PTX`, so forward and backward agree.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(1.702 * x)
    mov.f32 %k, 0f3FDA2720;
    mul.f32 %kx, %k, %x;
    neg.f32 %neg_kx, %kx;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_kx, %neg_kx, %log2e;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;

    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
    sub.f32 %one_minus_sig, %one, %sig;
    mul.f32 %kx_sig_oms, %kx, %sig;
    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f32 %dsig, %sig, %kx_sig_oms;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %dsig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1392
/// PTX for a 1-D gather: `out[i] = input[indices[i]]` for `i < n_indices`.
///
/// Indices are stored as f32 and truncated to u32 (`cvt.rzi.u32.f32`),
/// matching the embedding kernels. No bounds check on the gathered index.
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
1452
/// PTX for `scatter_add_1d_kernel`: `grad_input[indices[i]] += grad_output[i]`
/// (backward of the 1-D gather above).
///
/// One thread per index. The accumulation uses `atom.global.add.f32`, so
/// duplicate indices combine correctly across threads. Indices are f32
/// truncated toward zero to u32; there is no bounds check on the target
/// index.
///
/// NOTE(review): the kernel only accumulates — it presumably relies on the
/// caller zero-initializing `grad_input` before launch; confirm at call site.
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
 .param .u64 grad_output_ptr,
 .param .u64 indices_ptr,
 .param .u64 grad_input_ptr,
 .param .u32 n_indices
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
 .reg .u64 %go, %indices, %gi, %off, %addr;
 .reg .f32 %idx_f, %grad_val, %dummy;
 .reg .pred %p;

 ld.param.u64 %go, [grad_output_ptr];
 ld.param.u64 %indices, [indices_ptr];
 ld.param.u64 %gi, [grad_input_ptr];
 ld.param.u32 %n_reg, [n_indices];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 // Byte offset for thread
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 // Read grad_output[tid]
 add.u64 %addr, %go, %off;
 ld.global.f32 %grad_val, [%addr];

 // Read indices[tid] (f32 -> u32)
 add.u64 %addr, %indices, %off;
 ld.global.f32 %idx_f, [%addr];
 cvt.rzi.u32.f32 %idx, %idx_f;

 // Atomic add: grad_input[idx] += grad_val
 cvt.u64.u32 %addr, %idx;
 shl.b64 %addr, %addr, 2;
 add.u64 %addr, %gi, %addr;
 atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
 ret;
}
";
1514
/// PTX for `masked_fill_kernel`:
/// `out[i] = if mask[i] >= 0.5 { fill_value } else { input[i] }`.
///
/// One thread per element. The mask is an f32 buffer; values >= 0.5 are
/// treated as "true" (the 0.5 threshold is the `%half` constant 0f3F000000).
/// The select is branch-free via `selp.f32`.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
 .param .u64 input_ptr,
 .param .u64 mask_ptr,
 .param .u64 out_ptr,
 .param .f32 fill_value,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %input, %mask, %out, %off;
 .reg .f32 %in_val, %mask_val, %fill, %result, %half;
 .reg .pred %p, %pmask;

 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %mask, [mask_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.f32 %fill, [fill_value];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %input, %input, %off;
 add.u64 %mask, %mask, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %in_val, [%input];
 ld.global.f32 %mask_val, [%mask];
 mov.f32 %half, 0f3F000000;
 setp.ge.f32 %pmask, %mask_val, %half;
 selp.f32 %result, %fill, %in_val, %pmask;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
1571
/// PTX for `masked_zero_kernel`:
/// `out[i] = if mask[i] >= 0.5 { 0.0 } else { grad[i] }`.
///
/// Gradient counterpart of `masked_fill_kernel`: positions that were filled
/// in the forward pass (mask >= 0.5) receive zero gradient; all others pass
/// the incoming gradient through unchanged. Uses the same 0.5 mask threshold
/// and branch-free `selp.f32` as the forward kernel.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_zero_kernel(
 .param .u64 grad_ptr,
 .param .u64 mask_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %mask, %out, %off;
 .reg .f32 %vg, %mask_val, %zero, %result, %half;
 .reg .pred %p, %pmask;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %mask, [mask_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %grad, %grad, %off;
 add.u64 %mask, %mask, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %mask_val, [%mask];
 mov.f32 %zero, 0f00000000;
 mov.f32 %half, 0f3F000000;
 setp.ge.f32 %pmask, %mask_val, %half;
 selp.f32 %result, %zero, %vg, %pmask;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
1627
/// PTX for `sigmoid_backward_kernel`: `out[i] = grad[i] * o[i] * (1 - o[i])`,
/// where `o` is the *forward output* of sigmoid (not the pre-activation
/// input) — the standard d/dx sigmoid expressed in terms of the output.
///
/// One thread per element; out-of-range threads exit via the predicated
/// branch to DONE.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 output_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %output, %out, %off;
 .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
 .reg .pred %p;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %output, [output_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %grad, %grad, %off;
 add.u64 %output, %output, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %vo, [%output];
 mov.f32 %one, 0f3F800000;
 sub.f32 %one_minus_o, %one, %vo;
 mul.f32 %result, %vo, %one_minus_o;
 mul.f32 %result, %vg, %result;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
1681
/// PTX for `tanh_backward_kernel`: `out[i] = grad[i] * (1 - o[i]^2)`,
/// where `o` is the *forward output* of tanh — the standard d/dx tanh
/// expressed in terms of the output.
///
/// One thread per element; same layout and bounds handling as
/// `sigmoid_backward_kernel`.
#[cfg(feature = "cuda")]
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 output_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %output, %out, %off;
 .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
 .reg .pred %p;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %output, [output_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %grad, %grad, %off;
 add.u64 %output, %output, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %vo, [%output];
 mov.f32 %one, 0f3F800000;
 mul.f32 %o_sq, %vo, %vo;
 sub.f32 %one_minus_sq, %one, %o_sq;
 mul.f32 %result, %vg, %one_minus_sq;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
1735
/// PTX for `softmax_backward_kernel`: row-wise softmax gradient
/// `out[r, j] = output[r, j] * (grad[r, j] - dot_r)` where
/// `dot_r = sum_j(grad[r, j] * output[r, j])` and `output` is the forward
/// softmax result.
///
/// Launch shape: one block per row (blockIdx.x = row), threads cooperate
/// over columns with a block-stride loop. The per-thread partial dot
/// products are combined in the 256-slot shared array `sdata` with a
/// halving tree reduction, then broadcast from `sdata[0]`.
///
/// NOTE(review): `sdata` holds 256 floats and the reduction repeatedly
/// halves blockDim.x, so this is only correct for a power-of-two
/// blockDim.x <= 256 — confirm the launch configuration.
///
/// Unlike the sibling kernels, this string is written with explicit `\n`
/// escapes and line continuations rather than a raw multi-line literal.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
 .param .u64 grad_ptr,\n\
 .param .u64 output_ptr,\n\
 .param .u64 out_ptr,\n\
 .param .u32 rows,\n\
 .param .u32 cols\n\
) {\n\
 .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
 .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
 .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
 .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
 ld.param.u64 %grad, [grad_ptr];\n\
 ld.param.u64 %output, [output_ptr];\n\
 ld.param.u64 %out, [out_ptr];\n\
 ld.param.u32 %rows_reg, [rows];\n\
 ld.param.u32 %cols_reg, [cols];\n\
\n\
 mov.u32 %bid, %ctaid.x;\n\
 mov.u32 %bdim, %ntid.x;\n\
 mov.u32 %r_tid, %tid.x;\n\
 mov.u64 %sbase, sdata;\n\
\n\
 setp.ge.u32 %p, %bid, %rows_reg;\n\
 @%p bra DONE;\n\
\n\
 // row_off = bid * cols * 4 (byte offset)\n\
 cvt.u64.u32 %row_off, %bid;\n\
 cvt.u64.u32 %off, %cols_reg;\n\
 mul.lo.u64 %row_off, %row_off, %off;\n\
 shl.b64 %row_off, %row_off, 2;\n\
\n\
 // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
 mov.f32 %dot, 0f00000000;\n\
 mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
 setp.ge.u32 %loop_p, %j, %cols_reg;\n\
 @%loop_p bra DOT_LOOP_DONE;\n\
 cvt.u64.u32 %off, %j;\n\
 shl.b64 %off, %off, 2;\n\
 add.u64 %saddr, %grad, %off;\n\
 add.u64 %saddr, %saddr, %row_off;\n\
 ld.global.f32 %vg, [%saddr];\n\
 add.u64 %saddr, %output, %off;\n\
 add.u64 %saddr, %saddr, %row_off;\n\
 ld.global.f32 %vo, [%saddr];\n\
 fma.rn.f32 %dot, %vg, %vo, %dot;\n\
 add.u32 %j, %j, %bdim;\n\
 bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
 // Store partial dot into shared memory and reduce\n\
 cvt.u64.u32 %off, %r_tid;\n\
 shl.b64 %off, %off, 2;\n\
 add.u64 %saddr, %sbase, %off;\n\
 st.shared.f32 [%saddr], %dot;\n\
 bar.sync 0;\n\
\n\
 mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
 shr.u32 %half, %half, 1;\n\
 setp.eq.u32 %reduce_p, %half, 0;\n\
 @%reduce_p bra DOT_REDUCE_DONE;\n\
 setp.ge.u32 %reduce_p, %r_tid, %half;\n\
 @%reduce_p bra DOT_REDUCE_SKIP;\n\
 add.u32 %other_tid, %r_tid, %half;\n\
 cvt.u64.u32 %off, %other_tid;\n\
 shl.b64 %off, %off, 2;\n\
 add.u64 %saddr, %sbase, %off;\n\
 ld.shared.f32 %other_val, [%saddr];\n\
 cvt.u64.u32 %off, %r_tid;\n\
 shl.b64 %off, %off, 2;\n\
 add.u64 %saddr, %sbase, %off;\n\
 ld.shared.f32 %dot, [%saddr];\n\
 add.f32 %dot, %dot, %other_val;\n\
 st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
 bar.sync 0;\n\
 bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
 // Broadcast dot to all threads\n\
 ld.shared.f32 %dot, [sdata];\n\
 bar.sync 0;\n\
\n\
 // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
 mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
 setp.ge.u32 %loop_p, %j, %cols_reg;\n\
 @%loop_p bra WRITE_LOOP_DONE;\n\
 cvt.u64.u32 %off, %j;\n\
 shl.b64 %off, %off, 2;\n\
 add.u64 %saddr, %grad, %off;\n\
 add.u64 %saddr, %saddr, %row_off;\n\
 ld.global.f32 %vg, [%saddr];\n\
 add.u64 %saddr, %output, %off;\n\
 add.u64 %saddr, %saddr, %row_off;\n\
 ld.global.f32 %vo, [%saddr];\n\
 sub.f32 %diff, %vg, %dot;\n\
 mul.f32 %result, %vo, %diff;\n\
 add.u64 %saddr, %out, %off;\n\
 add.u64 %saddr, %saddr, %row_off;\n\
 st.global.f32 [%saddr], %result;\n\
 add.u32 %j, %j, %bdim;\n\
 bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
 ret;\n\
}\n\
";
1863
/// PTX for `reduce_sum_kernel`: per-block partial sums of an f32 array.
///
/// Each thread accumulates a grid-stride partial sum, the partials are
/// combined with a shared-memory tree reduction, and thread 0 of each block
/// writes one f32 to `out[blockIdx.x]`. The host then reduces the per-block
/// results.
///
/// NOTE(review): `sdata` has 256 slots and the tree starts at half = 128,
/// so this kernel expects exactly 256 threads per block — confirm the
/// launch configuration.
///
/// Fix: the per-thread index register was declared as `%tid`, which clashes
/// with the predefined PTX special register `%tid` (every other kernel in
/// this file uses `%r_tid` for this reason). Renamed to `%r_tid`; no other
/// change.
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
 .reg .u64 %in, %out, %off;
 .reg .f32 %sum, %other;
 .reg .pred %p, %ptid;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %r_tid, %tid.x;
 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %gdim, %nctaid.x;

 // Grid-stride accumulation: each thread sums multiple elements.
 // idx = bid * bdim + tid; stride = bdim * gdim
 mad.lo.u32 %idx, %bid, %bdim, %r_tid;
 mul.lo.u32 %stride, %bdim, %gdim;
 mov.f32 %sum, 0f00000000;

GRID_LOOP:
 setp.ge.u32 %p, %idx, %n_reg;
 @%p bra GRID_DONE;

 cvt.u64.u32 %off, %idx;
 shl.b64 %off, %off, 2;
 add.u64 %off, %in, %off;
 ld.global.f32 %other, [%off];
 add.f32 %sum, %sum, %other;
 add.u32 %idx, %idx, %stride;
 bra GRID_LOOP;

GRID_DONE:
 // Write thread's partial sum to shared memory.
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 st.shared.f32 [sdata + %off], %sum;
 bar.sync 0;

 // Tree reduction in shared memory.
 mov.u32 %half, 128;
TREE_LOOP:
 setp.lt.u32 %p, %half, 1;
 @%p bra TREE_DONE;

 setp.ge.u32 %ptid, %r_tid, %half;
 @%ptid bra TREE_SKIP;

 // Load partner's value from sdata[tid + half].
 add.u32 %idx, %r_tid, %half;
 cvt.u64.u32 %off, %idx;
 shl.b64 %off, %off, 2;
 ld.shared.f32 %other, [sdata + %off];
 // Load own value.
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 ld.shared.f32 %sum, [sdata + %off];
 add.f32 %sum, %sum, %other;
 st.shared.f32 [sdata + %off], %sum;

TREE_SKIP:
 bar.sync 0;
 shr.u32 %half, %half, 1;
 bra TREE_LOOP;

TREE_DONE:
 // Thread 0 writes block result.
 setp.ne.u32 %ptid, %r_tid, 0;
 @%ptid bra END;

 ld.shared.f32 %sum, [sdata];
 cvt.u64.u32 %off, %bid;
 shl.b64 %off, %off, 2;
 add.u64 %out, %out, %off;
 st.global.f32 [%out], %sum;

END:
 ret;
}
";
1971
/// PTX for `sum_axis_kernel`: reduce a tensor along one axis.
///
/// The tensor is viewed as `(outer_size, axis_size, inner_size)` in row-major
/// order; one thread computes one output element
/// `output[outer, inner] = sum_k input[outer, k, inner]` via a serial loop
/// over the axis. `total_output` (= number of output elements) provides the
/// thread bounds check.
///
/// NOTE(review): the `outer_size` parameter is loaded into `%outer_sz` and
/// used only to derive `outer_idx` / base offsets; the bound itself comes
/// from `total_output` supplied by the host.
#[cfg(feature = "cuda")]
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
 .param .u64 input_ptr,
 .param .u64 output_ptr,
 .param .u32 outer_size,
 .param .u32 axis_size,
 .param .u32 inner_size,
 .param .u32 total_output
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
 .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
 .reg .u64 %in, %out, %off, %addr;
 .reg .f32 %val, %sum;
 .reg .pred %p, %lp;

 ld.param.u64 %in, [input_ptr];
 ld.param.u64 %out, [output_ptr];
 ld.param.u32 %outer_sz, [outer_size];
 ld.param.u32 %axis_sz, [axis_size];
 ld.param.u32 %inner_sz, [inner_size];
 ld.param.u32 %n_reg, [total_output];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 // outer_idx = r_tid / inner_size
 div.u32 %outer_idx, %r_tid, %inner_sz;
 // inner_idx = r_tid % inner_size
 rem.u32 %inner_idx, %r_tid, %inner_sz;

 // base = outer_idx * axis_size * inner_size + inner_idx
 mul.lo.u32 %tmp, %outer_idx, %axis_sz;
 mul.lo.u32 %tmp, %tmp, %inner_sz;
 add.u32 %tmp, %tmp, %inner_idx;

 mov.f32 %sum, 0f00000000;
 mov.u32 %k, 0;
SUM_LOOP:
 setp.ge.u32 %lp, %k, %axis_sz;
 @%lp bra SUM_LOOP_DONE;

 // addr = in + (tmp + k * inner_size) * 4
 mul.lo.u32 %inner_idx, %k, %inner_sz;
 add.u32 %inner_idx, %tmp, %inner_idx;
 cvt.u64.u32 %off, %inner_idx;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %in, %off;
 ld.global.f32 %val, [%addr];
 add.f32 %sum, %sum, %val;

 add.u32 %k, %k, 1;
 bra SUM_LOOP;
SUM_LOOP_DONE:

 // output[r_tid] = sum
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %out, %off;
 st.global.f32 [%addr], %sum;

DONE:
 ret;
}
";
2049
/// PTX for `layernorm_kernel`: row-wise layer normalization
/// `out[r, j] = w[j] * (x[r, j] - mean_r) / sqrt(var_r + eps) + b[j]`.
///
/// Launch shape: one block per row (blockIdx.x = row). Three block-stride
/// passes over the columns: (SM) sum for the mean, (SV) sum of squared
/// deviations for the variance (biased, divided by n), (NM) normalize and
/// apply the affine weight/bias. Each pass's partial sums are combined in
/// the 256-slot shared array `sdata` with a halving tree reduction
/// (labels MR / VR), and the result is broadcast from `sdata[0]`.
///
/// Uses `div.approx` / `sqrt.approx` / `rcp.approx`, so results differ
/// slightly from an IEEE-exact CPU implementation.
///
/// NOTE(review): `sdata` holds 256 floats and the reduction repeatedly
/// halves blockDim.x, so this is only correct for a power-of-two
/// blockDim.x <= 256 — confirm the launch configuration.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u64 w_ptr,
 .param .u64 b_ptr,
 .param .u32 rows,
 .param .u32 cols,
 .param .f32 eps
) {
 .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
 .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
 .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
 .reg .pred %p, %lp, %rp;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u64 %w, [w_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u32 %rows_reg, [rows];
 ld.param.u32 %cols_reg, [cols];
 ld.param.f32 %eps_r, [eps];

 mov.u64 %sbase, sdata;

 mov.u32 %r_bid, %ctaid.x;
 mov.u32 %r_bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;

 setp.ge.u32 %p, %r_bid, %rows_reg;
 @%p bra DONE;

 cvt.u64.u32 %row_off, %r_bid;
 cvt.u64.u32 %off, %cols_reg;
 mul.lo.u64 %row_off, %row_off, %off;
 shl.b64 %row_off, %row_off, 2;
 cvt.rn.f32.u32 %n_f, %cols_reg;

 mov.f32 %mean, 0f00000000;
 mov.u32 %j, %r_tid;
SM:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra SMD;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %off, %in, %off;
 add.u64 %off, %off, %row_off;
 ld.global.f32 %val, [%off];
 add.f32 %mean, %mean, %val;
 add.u32 %j, %j, %r_bdim;
 bra SM;
SMD:
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %mean;
 bar.sync 0;
 mov.u32 %half, %r_bdim;
MR:
 shr.u32 %half, %half, 1;
 setp.eq.u32 %rp, %half, 0;
 @%rp bra MRD;
 setp.ge.u32 %rp, %r_tid, %half;
 @%rp bra MRS;
 add.u32 %r_otid, %r_tid, %half;
 cvt.u64.u32 %off, %r_otid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %other_val, [%saddr];
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %mean, [%saddr];
 add.f32 %mean, %mean, %other_val;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %mean;
MRS:
 bar.sync 0;
 bra MR;
MRD:
 ld.shared.f32 %mean, [%sbase];
 div.approx.f32 %mean, %mean, %n_f;
 bar.sync 0;

 mov.f32 %var, 0f00000000;
 mov.u32 %j, %r_tid;
SV:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra SVD;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %off, %in, %off;
 add.u64 %off, %off, %row_off;
 ld.global.f32 %val, [%off];
 sub.f32 %diff, %val, %mean;
 fma.rn.f32 %var, %diff, %diff, %var;
 add.u32 %j, %j, %r_bdim;
 bra SV;
SVD:
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %var;
 bar.sync 0;
 mov.u32 %half, %r_bdim;
VR:
 shr.u32 %half, %half, 1;
 setp.eq.u32 %rp, %half, 0;
 @%rp bra VRD;
 setp.ge.u32 %rp, %r_tid, %half;
 @%rp bra VRS;
 add.u32 %r_otid, %r_tid, %half;
 cvt.u64.u32 %off, %r_otid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %other_val, [%saddr];
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %var, [%saddr];
 add.f32 %var, %var, %other_val;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %var;
VRS:
 bar.sync 0;
 bra VR;
VRD:
 ld.shared.f32 %var, [%sbase];
 div.approx.f32 %var, %var, %n_f;
 add.f32 %var, %var, %eps_r;
 sqrt.approx.f32 %inv_std, %var;
 rcp.approx.f32 %inv_std, %inv_std;
 bar.sync 0;

 mov.u32 %j, %r_tid;
NM:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra NMD;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %off, %in, %off;
 add.u64 %off, %off, %row_off;
 ld.global.f32 %val, [%off];
 sub.f32 %normed, %val, %mean;
 mul.f32 %normed, %normed, %inv_std;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %off, %w, %off;
 ld.global.f32 %wv, [%off];
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %off, %b, %off;
 ld.global.f32 %bv, [%off];
 fma.rn.f32 %result, %wv, %normed, %bv;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %off, %out, %off;
 add.u64 %off, %off, %row_off;
 st.global.f32 [%off], %result;
 add.u32 %j, %j, %r_bdim;
 bra NM;
NMD:

DONE:
 ret;
}
";
2233
/// PTX for `layernorm_backward_kernel`: backward pass of row-wise layer
/// normalization.
///
/// Per row r (one block per row, block-stride loops over columns):
/// 1. Recompute `mean_r` and biased `var_r` from the saved input (Phases
///    1-2, same reductions as the forward kernel).
/// 2. Phase 3: with `x_hat = (x - mean) * inv_std` and
///    `dl_dx_hat = grad_out * w`, reduce
///    `sum1 = sum_j dl_dx_hat[j]` and `sum2 = sum_j dl_dx_hat[j] * x_hat[j]`;
///    in the same loop accumulate `grad_w[j] += grad_out * x_hat` and
///    `grad_b[j] += grad_out` with `atom.global.add.f32` (rows race on
///    these buffers, hence the atomics).
/// 3. Phase 4:
///    `grad_in[r, j] = inv_std * (dl_dx_hat[j] - sum1/n - x_hat[j] * sum2/n)`.
///
/// Uses approx-division/sqrt/rcp like the forward kernel.
///
/// NOTE(review): the shared-memory tree reductions assume a power-of-two
/// blockDim.x <= 256 (sdata has 256 slots) — confirm the launch
/// configuration. The atomics only accumulate, so `grad_w` / `grad_b` are
/// presumably zero-initialized by the caller — confirm at call site.
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_backward_kernel(
 .param .u64 in_ptr,
 .param .u64 grad_out_ptr,
 .param .u64 w_ptr,
 .param .u64 grad_in_ptr,
 .param .u64 grad_w_ptr,
 .param .u64 grad_b_ptr,
 .param .u32 rows,
 .param .u32 cols,
 .param .f32 eps
) {
 .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
 .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
 .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
 .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
 .reg .pred %p, %lp, %rp;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %go, [grad_out_ptr];
 ld.param.u64 %w, [w_ptr];
 ld.param.u64 %gi, [grad_in_ptr];
 ld.param.u64 %gw, [grad_w_ptr];
 ld.param.u64 %gb, [grad_b_ptr];
 ld.param.u32 %rows_reg, [rows];
 ld.param.u32 %cols_reg, [cols];
 ld.param.f32 %eps_r, [eps];

 mov.u64 %sbase, sdata;

 mov.u32 %r_bid, %ctaid.x;
 mov.u32 %r_bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;

 setp.ge.u32 %p, %r_bid, %rows_reg;
 @%p bra LNB_DONE;

 // row_off = bid * cols * 4 (byte offset for this row)
 cvt.u64.u32 %row_off, %r_bid;
 cvt.u64.u32 %off, %cols_reg;
 mul.lo.u64 %row_off, %row_off, %off;
 shl.b64 %row_off, %row_off, 2;
 cvt.rn.f32.u32 %n_f, %cols_reg;

 // ===== Phase 1: Compute mean =====
 mov.f32 %mean, 0f00000000;
 mov.u32 %j, %r_tid;
LNB_SM:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra LNB_SMD;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %in, %off;
 add.u64 %addr, %addr, %row_off;
 ld.global.f32 %val, [%addr];
 add.f32 %mean, %mean, %val;
 add.u32 %j, %j, %r_bdim;
 bra LNB_SM;
LNB_SMD:
 // Shared memory reduce for mean
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %mean;
 bar.sync 0;
 mov.u32 %half, %r_bdim;
LNB_MR:
 shr.u32 %half, %half, 1;
 setp.eq.u32 %rp, %half, 0;
 @%rp bra LNB_MRD;
 setp.ge.u32 %rp, %r_tid, %half;
 @%rp bra LNB_MRS;
 add.u32 %r_otid, %r_tid, %half;
 cvt.u64.u32 %off, %r_otid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %other_val, [%saddr];
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %mean, [%saddr];
 add.f32 %mean, %mean, %other_val;
 st.shared.f32 [%saddr], %mean;
LNB_MRS:
 bar.sync 0;
 bra LNB_MR;
LNB_MRD:
 ld.shared.f32 %mean, [%sbase];
 div.approx.f32 %mean, %mean, %n_f;
 bar.sync 0;

 // ===== Phase 2: Compute variance =====
 mov.f32 %var, 0f00000000;
 mov.u32 %j, %r_tid;
LNB_SV:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra LNB_SVD;
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %in, %off;
 add.u64 %addr, %addr, %row_off;
 ld.global.f32 %val, [%addr];
 sub.f32 %diff, %val, %mean;
 fma.rn.f32 %var, %diff, %diff, %var;
 add.u32 %j, %j, %r_bdim;
 bra LNB_SV;
LNB_SVD:
 // Shared memory reduce for variance
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %var;
 bar.sync 0;
 mov.u32 %half, %r_bdim;
LNB_VR:
 shr.u32 %half, %half, 1;
 setp.eq.u32 %rp, %half, 0;
 @%rp bra LNB_VRD;
 setp.ge.u32 %rp, %r_tid, %half;
 @%rp bra LNB_VRS;
 add.u32 %r_otid, %r_tid, %half;
 cvt.u64.u32 %off, %r_otid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %other_val, [%saddr];
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %var, [%saddr];
 add.f32 %var, %var, %other_val;
 st.shared.f32 [%saddr], %var;
LNB_VRS:
 bar.sync 0;
 bra LNB_VR;
LNB_VRD:
 ld.shared.f32 %var, [%sbase];
 div.approx.f32 %var, %var, %n_f;
 add.f32 %var, %var, %eps_r;
 sqrt.approx.f32 %inv_std, %var;
 rcp.approx.f32 %inv_std, %inv_std;
 bar.sync 0;

 // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
 // Also accumulate grad_weight and grad_bias via atomicAdd
 mov.f32 %sum1, 0f00000000;
 mov.f32 %sum2, 0f00000000;
 mov.u32 %j, %r_tid;
LNB_S12:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra LNB_S12D;
 // Load input[row, j]
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %in, %off;
 add.u64 %addr, %addr, %row_off;
 ld.global.f32 %val, [%addr];
 // x_hat = (val - mean) * inv_std
 sub.f32 %x_hat, %val, %mean;
 mul.f32 %x_hat, %x_hat, %inv_std;
 // Load grad_output[row, j]
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %go, %off;
 add.u64 %addr, %addr, %row_off;
 ld.global.f32 %gov, [%addr];
 // Load weight[j]
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %w, %off;
 ld.global.f32 %wv, [%addr];
 // dl_dx_hat = grad_output * weight
 mul.f32 %dl_dx_hat, %gov, %wv;
 // Accumulate sums
 add.f32 %sum1, %sum1, %dl_dx_hat;
 fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
 // atomicAdd grad_weight[j] += grad_output * x_hat
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %gw, %off;
 mul.f32 %result, %gov, %x_hat;
 atom.global.add.f32 %result, [%addr], %result;
 // atomicAdd grad_bias[j] += grad_output
 add.u64 %addr, %gb, %off;
 atom.global.add.f32 %result, [%addr], %gov;
 add.u32 %j, %j, %r_bdim;
 bra LNB_S12;
LNB_S12D:
 // Reduce sum1 in shared memory
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %sum1;
 bar.sync 0;
 mov.u32 %half, %r_bdim;
LNB_R1:
 shr.u32 %half, %half, 1;
 setp.eq.u32 %rp, %half, 0;
 @%rp bra LNB_R1D;
 setp.ge.u32 %rp, %r_tid, %half;
 @%rp bra LNB_R1S;
 add.u32 %r_otid, %r_tid, %half;
 cvt.u64.u32 %off, %r_otid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %other_val, [%saddr];
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %sum1, [%saddr];
 add.f32 %sum1, %sum1, %other_val;
 st.shared.f32 [%saddr], %sum1;
LNB_R1S:
 bar.sync 0;
 bra LNB_R1;
LNB_R1D:
 ld.shared.f32 %sum1, [%sbase];
 // mean1 = sum1 / n
 div.approx.f32 %mean1, %sum1, %n_f;
 bar.sync 0;

 // Reduce sum2 in shared memory
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 st.shared.f32 [%saddr], %sum2;
 bar.sync 0;
 mov.u32 %half, %r_bdim;
LNB_R2:
 shr.u32 %half, %half, 1;
 setp.eq.u32 %rp, %half, 0;
 @%rp bra LNB_R2D;
 setp.ge.u32 %rp, %r_tid, %half;
 @%rp bra LNB_R2S;
 add.u32 %r_otid, %r_tid, %half;
 cvt.u64.u32 %off, %r_otid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %other_val, [%saddr];
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %saddr, %sbase, %off;
 ld.shared.f32 %sum2, [%saddr];
 add.f32 %sum2, %sum2, %other_val;
 st.shared.f32 [%saddr], %sum2;
LNB_R2S:
 bar.sync 0;
 bra LNB_R2;
LNB_R2D:
 ld.shared.f32 %sum2, [%sbase];
 // mean2 = sum2 / n
 div.approx.f32 %mean2, %sum2, %n_f;
 bar.sync 0;

 // ===== Phase 4: Compute grad_input =====
 // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
 mov.u32 %j, %r_tid;
LNB_GI:
 setp.ge.u32 %lp, %j, %cols_reg;
 @%lp bra LNB_GID;
 // Reload input to recompute x_hat
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %in, %off;
 add.u64 %addr, %addr, %row_off;
 ld.global.f32 %val, [%addr];
 sub.f32 %x_hat, %val, %mean;
 mul.f32 %x_hat, %x_hat, %inv_std;
 // Reload grad_output and weight to recompute dl_dx_hat
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %go, %off;
 add.u64 %addr, %addr, %row_off;
 ld.global.f32 %gov, [%addr];
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %w, %off;
 ld.global.f32 %wv, [%addr];
 mul.f32 %dl_dx_hat, %gov, %wv;
 // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
 sub.f32 %result, %dl_dx_hat, %mean1;
 mul.f32 %diff, %x_hat, %mean2;
 sub.f32 %result, %result, %diff;
 mul.f32 %result, %inv_std, %result;
 // Store grad_input[row, j]
 cvt.u64.u32 %off, %j;
 shl.b64 %off, %off, 2;
 add.u64 %addr, %gi, %off;
 add.u64 %addr, %addr, %row_off;
 st.global.f32 [%addr], %result;
 add.u32 %j, %j, %r_bdim;
 bra LNB_GI;
LNB_GID:

LNB_DONE:
 ret;
}
";
2562
2563#[cfg(feature = "cuda")]
2596pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
2597.version 7.0
2598.target sm_52
2599.address_size 64
2600
2601// Shared memory for block reduction
2602.shared .align 4 .f32 smem_sum[256];
2603.shared .align 4 .f32 smem_sq[256];
2604
2605.visible .entry batchnorm_forward_kernel(
2606 .param .u64 input_ptr,
2607 .param .u64 output_ptr,
2608 .param .u64 weight_ptr,
2609 .param .u64 bias_ptr,
2610 .param .u64 rmean_ptr,
2611 .param .u64 rvar_ptr,
2612 .param .u64 save_mean_ptr,
2613 .param .u64 save_invstd_ptr,
2614 .param .u32 channels,
2615 .param .u32 spatial,
2616 .param .f32 eps,
2617 .param .f32 momentum,
2618 .param .u32 total_per_ch,
2619 .param .u32 training
2620) {
2621 .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
2622 .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
2623 .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
2624 .reg .f32 %gamma, %beta, %eps_reg, %mom, %other;
2625 .reg .f32 %n_f, %one, %normalized;
2626 .reg .pred %p, %ptrain, %ptid0;
2627 .reg .u32 %half;
2628
2629 ld.param.u64 %in, [input_ptr];
2630 ld.param.u64 %out, [output_ptr];
2631 ld.param.u64 %w, [weight_ptr];
2632 ld.param.u64 %b, [bias_ptr];
2633 ld.param.u64 %rm, [rmean_ptr];
2634 ld.param.u64 %rv, [rvar_ptr];
2635 ld.param.u64 %sm, [save_mean_ptr];
2636 ld.param.u64 %si, [save_invstd_ptr];
2637 ld.param.u32 %n_ch, [channels];
2638 ld.param.u32 %sp, [spatial];
2639 ld.param.f32 %eps_reg, [eps];
2640 ld.param.f32 %mom, [momentum];
2641 ld.param.u32 %tpc, [total_per_ch];
2642 ld.param.u32 %train, [training];
2643
2644 mov.u32 %bid, %ctaid.x;
2645 mov.u32 %tid, %tid.x;
2646 mov.u32 %bdim, %ntid.x;
2647 mov.u32 %ch, %bid;
2648 mov.f32 %one, 0f3F800000;
2649
2650 setp.ge.u32 %p, %ch, %n_ch;
2651 @%p bra END;
2652
2653 setp.ne.u32 %ptrain, %train, 0;
2654
2655 // ---- Pass 1: compute sum and sum-of-squares for this channel ----
2656 mov.f32 %sum, 0f00000000;
2657 mov.f32 %sqsum, 0f00000000;
2658
2659 // Grid-stride loop over B*spatial for this channel
2660 mov.u32 %idx, %tid;
2661PASS1_LOOP:
2662 setp.ge.u32 %p, %idx, %tpc;
2663 @%p bra PASS1_DONE;
2664
2665 // Linear offset = (idx / spatial) * channels * spatial + ch * spatial + idx % spatial
2666 div.u32 %half, %idx, %sp;
2667 rem.u32 %half, %idx, %sp; // reuse half as spatial_idx
2668 // batch_offset = (idx / sp) * (n_ch * sp) + ch * sp + (idx % sp)
2669 div.u32 %half, %idx, %sp; // batch_idx
2670 mul.lo.u32 %half, %half, %n_ch;
2671 add.u32 %half, %half, %ch;
2672 mul.lo.u32 %half, %half, %sp;
2673 rem.u32 %idx, %idx, %sp; // spatial_idx
2674 add.u32 %half, %half, %idx;
2675
2676 cvt.u64.u32 %off64, %half;
2677 shl.b64 %off64, %off64, 2;
2678 add.u64 %tmp64, %in, %off64;
2679 ld.global.f32 %val, [%tmp64];
2680 add.f32 %sum, %sum, %val;
2681 fma.rn.f32 %sqsum, %val, %val, %sqsum;
2682
2683 // Restore idx for stride
2684 // Recompute idx from tid + iteration * bdim
2685 add.u32 %idx, %idx, %bdim; // This is wrong - need proper loop counter
2686 bra PASS1_LOOP;
2687
2688PASS1_DONE:
2689 // Store to shared memory for block reduction
2690 cvt.u64.u32 %off64, %tid;
2691 shl.b64 %off64, %off64, 2;
2692 st.shared.f32 [smem_sum + %off64], %sum;
2693 st.shared.f32 [smem_sq + %off64], %sqsum;
2694 bar.sync 0;
2695
2696 // Tree reduction
2697 mov.u32 %half, 128;
2698REDUCE_LOOP:
2699 setp.lt.u32 %p, %half, 1;
2700 @%p bra REDUCE_DONE;
2701 setp.ge.u32 %p, %tid, %half;
2702 @%p bra REDUCE_SKIP;
2703
2704 add.u32 %idx, %tid, %half;
2705 cvt.u64.u32 %off64, %idx;
2706 shl.b64 %off64, %off64, 2;
2707 ld.shared.f32 %other, [smem_sum + %off64];
2708 cvt.u64.u32 %tmp64, %tid;
2709 shl.b64 %tmp64, %tmp64, 2;
2710 ld.shared.f32 %sum, [smem_sum + %tmp64];
2711 add.f32 %sum, %sum, %other;
2712 st.shared.f32 [smem_sum + %tmp64], %sum;
2713
2714 ld.shared.f32 %other, [smem_sq + %off64];
2715 ld.shared.f32 %sqsum, [smem_sq + %tmp64];
2716 add.f32 %sqsum, %sqsum, %other;
2717 st.shared.f32 [smem_sq + %tmp64], %sqsum;
2718
2719REDUCE_SKIP:
2720 bar.sync 0;
2721 shr.u32 %half, %half, 1;
2722 bra REDUCE_LOOP;
2723
2724REDUCE_DONE:
2725 // Thread 0 computes mean and invstd
2726 setp.ne.u32 %ptid0, %tid, 0;
2727
2728 @%ptid0 bra WAIT_STATS;
2729
2730 ld.shared.f32 %sum, [smem_sum];
2731 ld.shared.f32 %sqsum, [smem_sq];
2732 cvt.rn.f32.u32 %n_f, %tpc;
2733 div.rn.f32 %mean, %sum, %n_f;
2734 // var = sqsum/n - mean^2
2735 div.rn.f32 %var, %sqsum, %n_f;
2736 fma.rn.f32 %var, %mean, %mean, %var; // This adds mean^2, need to subtract
2737 // Actually: var = E[x^2] - E[x]^2, so var = sqsum/n - mean^2
2738 // We had: var = sqsum/n, now subtract mean^2
2739 neg.f32 %other, %mean;
2740 fma.rn.f32 %var, %other, %mean, %var; // var = var + (-mean)*mean = sqsum/n - mean^2
2741
2742 // invstd = 1/sqrt(var + eps)
2743 add.f32 %other, %var, %eps_reg;
2744 sqrt.rn.f32 %other, %other;
2745 div.rn.f32 %invstd, %one, %other;
2746
2747 // Save mean and invstd
2748 cvt.u64.u32 %off64, %ch;
2749 shl.b64 %off64, %off64, 2;
2750 add.u64 %tmp64, %sm, %off64;
2751 st.global.f32 [%tmp64], %mean;
2752 add.u64 %tmp64, %si, %off64;
2753 st.global.f32 [%tmp64], %invstd;
2754
2755 // Store to shared for other threads
2756 st.shared.f32 [smem_sum], %mean;
2757 st.shared.f32 [smem_sq], %invstd;
2758
2759WAIT_STATS:
2760 bar.sync 0;
2761 // All threads read mean and invstd from shared
2762 ld.shared.f32 %mean, [smem_sum];
2763 ld.shared.f32 %invstd, [smem_sq];
2764
2765 // Load weight and bias for this channel
2766 cvt.u64.u32 %off64, %ch;
2767 shl.b64 %off64, %off64, 2;
2768 add.u64 %tmp64, %w, %off64;
2769 ld.global.f32 %gamma, [%tmp64];
2770 add.u64 %tmp64, %b, %off64;
2771 ld.global.f32 %beta, [%tmp64];
2772
2773 // ---- Pass 2: normalize + affine ----
2774 // For now this is a placeholder - the indexing needs to match pass 1
2775 // Each thread normalizes its elements
2776
2777END:
2778 ret;
2779}
2780";
2781
2782#[cfg(feature = "cuda")]
/// PTX for the 2-D max-pooling forward kernel over an NCHW `f32` tensor.
///
/// One output element per logical thread; a grid-stride loop
/// (`idx += blockDim.x * gridDim.x`) covers all
/// `total = B * C * H_out * W_out` outputs. Each flat output index is
/// decomposed from the right into `(b, c, oh, ow)`, the `kh x kw` window is
/// scanned with stride `(sh, sw)` and padding `(ph, pw)`, and the running
/// maximum (seeded with -inf, `0fFF800000`) is written back at the same
/// flat index. Out-of-bounds taps are rejected with a single unsigned
/// compare: `ih = oh*sh + i - ph` wraps to a huge value when it would be
/// negative, so `ih >= h_in` catches both underflow and overflow.
///
/// Fixes vs. the previous revision: removed a dead
/// `div.u32 %b_idx, %rem, %ch_reg` (its result was unconditionally
/// overwritten by the right-to-left decomposition below) and the unused
/// `%p_gt` predicate register.
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry maxpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %max_val, %cur_val, %neg_inf;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

    // -inf for max initialization
    mov.f32 %neg_inf, 0fFF800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow) from the right:
    // idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %max_val, %neg_inf;

    // Slide the kernel window
    mov.u32 %i, 0;
KH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra KH_DONE;

    mov.u32 %j, 0;
KW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra KW_DONE;

    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
    // Since unsigned, just check < h_in and < w_in
    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra KW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra KW_NEXT;

    // input_offset = b * C * H * W + c * H * W + ih * W + iw
    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    max.f32 %max_val, %max_val, %cur_val;

KW_NEXT:
    add.u32 %j, %j, 1;
    bra KW_LOOP;

KW_DONE:
    add.u32 %i, %i, 1;
    bra KH_LOOP;

KH_DONE:
    // Store output
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %max_val;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
2923
2924#[cfg(feature = "cuda")]
/// PTX for the 2-D average-pooling forward kernel over an NCHW `f32` tensor.
///
/// Mirrors the MaxPool2d kernel: a grid-stride loop covers all
/// `total = B * C * H_out * W_out` output elements, and each flat index is
/// decomposed from the right into `(b, c, oh, ow)`. The kernel sums only the
/// in-bounds taps of each `kh x kw` window (the unsigned-wrap bounds trick
/// skips padded positions) and divides by the number of taps that actually
/// contributed — i.e. `count_include_pad = false` semantics.
///
/// NOTE(review): a window containing zero valid input positions would divide
/// by zero and store NaN; presumably the host-side shape computation
/// guarantees every output window overlaps the input — confirm at the call
/// site.
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry avgpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %sum_val, 0f00000000;
    mov.u32 %count, 0;

    mov.u32 %i, 0;
AKH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra AKH_DONE;

    mov.u32 %j, 0;
AKW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra AKW_DONE;

    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra AKW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra AKW_NEXT;

    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    add.f32 %sum_val, %sum_val, %cur_val;
    add.u32 %count, %count, 1;

AKW_NEXT:
    add.u32 %j, %j, 1;
    bra AKW_LOOP;

AKW_DONE:
    add.u32 %i, %i, 1;
    bra AKH_LOOP;

AKH_DONE:
    // avg = sum / count (count_include_pad = false behavior)
    cvt.rn.f32.u32 %count_f, %count;
    div.rn.f32 %avg, %sum_val, %count_f;

    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %avg;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
3059
3060#[cfg(feature = "cuda")]
/// PTX for row-wise softmax over a `rows x cols` `f32` matrix: one CTA per
/// row, numerically stabilised by subtracting the row max.
///
/// Three phases, each a strided per-thread loop followed by (or preceded by)
/// a shared-memory tree reduction in `sdata`:
/// 1. per-thread running max over the row, reduced to the row max;
/// 2. per-thread sum of `exp(x - max)` — computed as
///    `ex2(x * log2(e))` with `log2(e) = 0f3FB8AA3B` — with the
///    un-normalised exponentials staged into `out`;
/// 3. every staged value multiplied by `rcp(sum)`.
///
/// The reductions repeatedly halve `%half` starting from `blockDim.x`, so
/// the launch must use a power-of-two block size of at most 256 (the
/// capacity of `sdata`).
///
/// Fixes vs. the previous revision: two lines were missing the `\n\`
/// escape-plus-continuation used everywhere else in this literal (they only
/// worked because the raw newline survived, and they leaked source
/// indentation into the PTX text), and two redundant recomputations of
/// `%saddr = %sbase + %off` before the reduction stores were dropped
/// (`%saddr` was already exactly that value).
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f32 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    rcp.approx.f32 %sum_val, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    mul.f32 %result, %val, %sum_val;\n\
    st.global.f32 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
3224
3225#[cfg(feature = "cuda")]
/// PTX for the dropout forward kernel: one element per thread (no
/// grid-stride loop, so the launch grid must cover all `n` elements).
///
/// The drop decision is a counter-based per-element hash: the thread index
/// is multiplied by the Knuth constant 2654435761, XORed with the
/// host-supplied `seed`, then scrambled by an xorshift sequence
/// (shifts 13 / 17 / 5). Elements whose hash is below `threshold` are
/// zeroed; survivors are multiplied by `scale` (the host passes the
/// inverted-dropout scale, presumably `1 / (1 - p)` — confirm at the call
/// site). Both predicated instructions use the same `%drop_p` predicate, so
/// exactly one of the two writes to `%val` executes.
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 n,\n\
    .param .u32 threshold,\n\
    .param .f32 scale,\n\
    .param .u32 seed\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
    .reg .u64 %in, %out, %off;\n\
    .reg .f32 %val, %scale_reg, %zero;\n\
    .reg .pred %p, %drop_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %n_reg, [n];\n\
    ld.param.u32 %thresh, [threshold];\n\
    ld.param.f32 %scale_reg, [scale];\n\
    ld.param.u32 %seed_reg, [seed];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %n_reg;\n\
    @%p bra DONE;\n\
\n\
    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
    xor.b32 %rng, %rng, %seed_reg;\n\
    shl.b32 %tmp, %rng, 13;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shr.b32 %tmp, %rng, 17;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shl.b32 %tmp, %rng, 5;\n\
    xor.b32 %rng, %rng, %tmp;\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %in, %in, %off;\n\
    add.u64 %out, %out, %off;\n\
    ld.global.f32 %val, [%in];\n\
\n\
    setp.lo.u32 %drop_p, %rng, %thresh;\n\
    mov.f32 %zero, 0f00000000;\n\
    @%drop_p mov.f32 %val, %zero;\n\
    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
    st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
3289
3290#[cfg(feature = "cuda")]
/// PTX for element-wise addition with broadcasting.
///
/// One thread per output element. The flat output index is peeled into N-d
/// coordinates by iterating dimensions from last to first against
/// `out_shape`; each coordinate is folded into per-input element indices
/// via the caller-provided `a_strides` / `b_strides` (a stride of 0
/// broadcasts that dimension — the same coordinate always maps to element
/// 0 along it). Shape and strides are `u32` arrays read from global memory,
/// so `ndim` is limited only by the loop, not hard-coded. The result is
/// stored contiguously at `out[tid]`.
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    // Global thread index.
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Decompose flat index into N-d coordinates and compute A/B indices.
    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;

LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;

    sub.u32 %d, %d, 1;

    // Byte offset for dimension d: d * 4.
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;

    // Load out_shape[d].
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];

    // Load a_strides[d] and b_strides[d].
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];

    // coord = remaining % shape_d; remaining /= shape_d.
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;

    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;

    bra LOOP;
END_LOOP:

    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];

    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    // Operation: add.
    add.f32 %vr, %va, %vb;

    // Store to out[tid].
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;

DONE:
    ret;
}
";
3415
3416#[cfg(feature = "cuda")]
/// PTX for element-wise subtraction (`a - b`) with broadcasting.
///
/// Identical index arithmetic to `BROADCAST_ADD_PTX` — the flat output
/// index is peeled dimension-by-dimension against `out_shape` and folded
/// through `a_strides` / `b_strides` (stride 0 broadcasts a dimension) —
/// with `sub.f32` as the combining operation.
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    sub.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
3499
3500#[cfg(feature = "cuda")]
/// PTX for element-wise multiplication with broadcasting.
///
/// Identical index arithmetic to `BROADCAST_ADD_PTX` — flat output index
/// decomposed against `out_shape`, folded through `a_strides` /
/// `b_strides` (stride 0 broadcasts a dimension) — with `mul.f32` as the
/// combining operation.
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    mul.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
3583
3584#[cfg(feature = "cuda")]
/// PTX for element-wise division (`a / b`) with broadcasting.
///
/// Identical index arithmetic to `BROADCAST_ADD_PTX` — flat output index
/// decomposed against `out_shape`, folded through `a_strides` /
/// `b_strides` (stride 0 broadcasts a dimension) — with `div.f32` as the
/// combining operation. Division by zero follows IEEE-754 `f32` semantics
/// (inf/NaN), matching the host-side behavior of `f32` division.
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    div.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
3668
3669#[cfg(feature = "cuda")]
/// PTX that gathers one split chunk out of a tensor along an axis.
///
/// The output chunk is indexed contiguously by the global thread id
/// (`n = outer * split_size * inner_size` elements; the launch must cover
/// all of them — there is no grid-stride loop). The source element is
/// computed as
/// `outer * total_along_axis * inner + split_offset * inner + within`,
/// where `within = tid % (split_size * inner)`.
/// Counterpart of `STRIDED_CAT_PTX`, which scatters in the opposite
/// direction.
///
/// Fix vs. the previous revision: dropped the `%tmp` register from the
/// `.reg .u32` declaration — it was never referenced (and the sibling cat
/// kernel never declared it).
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_split_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 split_offset,
    .param .u32 split_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %sp_off, [split_offset];
    ld.param.u32 %sp_sz, [split_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = split_size * inner_size
    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = split_offset * inner_size
    mul.lo.u32 %base_off, %sp_off, %inner_sz;

    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
    add.u32 %src_idx, %src_idx, %base_off;
    add.u32 %src_idx, %src_idx, %within;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
3748
3749#[cfg(feature = "cuda")]
/// PTX for concatenation along an axis: scatters one input part into its
/// slot of the larger output tensor.
///
/// Inverse addressing of `STRIDED_SPLIT_PTX`: the global thread id walks
/// the *input* part contiguously (`n = outer * part_size * inner_size`
/// elements; the launch must cover all of them), and the destination index
/// is `outer * total_along_axis * inner + cat_offset * inner + within`,
/// where `within = tid % (part_size * inner)`. The host calls this once per
/// part with the appropriate `cat_offset`.
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_cat_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 cat_offset,
    .param .u32 part_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %cat_off, [cat_offset];
    ld.param.u32 %part_sz, [part_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = part_size * inner_size
    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = cat_offset * inner_size
    mul.lo.u32 %base_off, %cat_off, %inner_sz;

    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
    add.u32 %dst_idx, %dst_idx, %base_off;
    add.u32 %dst_idx, %dst_idx, %within;

    // Load from in[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[dst_idx]
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
3829
/// PTX for `div_kernel`: element-wise `out[i] = a[i] / b[i]` over `n` f32
/// values using round-to-nearest division (`div.rn.f32`). One thread per
/// element; out-of-range threads exit via the bounds check.
#[cfg(feature = "cuda")]
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    div.rn.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3877
/// PTX for `exp_kernel`: element-wise `out[i] = exp(a[i])` computed via the
/// hardware `ex2.approx` instruction and the identity `exp(x) = 2^(x*log2(e))`.
/// Approximate precision, not IEEE round-to-nearest.
#[cfg(feature = "cuda")]
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
    // log2(e) = 1.4426950408889634
    mul.f32 %va, %va, 0f3FB8AA3B;
    ex2.approx.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3924
/// PTX for `log_kernel`: element-wise natural log `out[i] = ln(a[i])` via
/// `lg2.approx` and the identity `ln(x) = log2(x) * ln(2)`. Approximate
/// precision, not IEEE round-to-nearest.
#[cfg(feature = "cuda")]
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
    // 1/log2(e) = ln(2) = 0.6931471805599453
    lg2.approx.f32 %vr, %va;
    mul.f32 %vr, %vr, 0f3F317218;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3971
/// PTX for `sqrt_kernel`: element-wise `out[i] = sqrt(a[i])` using
/// round-to-nearest `sqrt.rn.f32`. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sqrt_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    sqrt.rn.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4015
/// PTX for `pow_kernel`: element-wise `out[i] = a[i]^exponent` with a scalar
/// exponent, computed as `2^(exponent * log2(x))` via `lg2.approx`/`ex2.approx`.
/// NOTE(review): the log2-based identity only holds for positive bases;
/// negative inputs presumably yield NaN — confirm callers guard for that.
#[cfg(feature = "cuda")]
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %exp, %lg;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %exp, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // x^e = 2^(e * log2(x))
    lg2.approx.f32 %lg, %va;
    mul.f32 %lg, %lg, %exp;
    ex2.approx.f32 %vr, %lg;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4065
/// PTX for `abs_kernel`: element-wise absolute value `out[i] = |a[i]|` using
/// the `abs.f32` instruction. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    abs.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4109
/// PTX for `sigmoid_kernel`: element-wise logistic sigmoid
/// `out[i] = 1 / (1 + exp(-a[i]))`, with `exp` built from `ex2.approx` and the
/// `2^(x*log2(e))` identity. Approximate precision.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f32 %neg, %va;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %vr, %one, %denom;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4160
/// PTX for `tanh_kernel`: element-wise hyperbolic tangent computed through the
/// identity `tanh(x) = 2*sigmoid(2x) - 1`, with the sigmoid's `exp` built from
/// `ex2.approx`. Approximate precision.
#[cfg(feature = "cuda")]
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // tanh(x) = 2*sigmoid(2x) - 1
    mov.f32 %two, 0f40000000;
    mul.f32 %neg2x, %va, %two;
    neg.f32 %neg2x, %neg2x;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg2x, %neg2x, %lg2e;
    ex2.approx.f32 %e, %neg2x;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %sig, %one, %denom;
    mul.f32 %vr, %two, %sig;
    sub.f32 %vr, %vr, %one;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4216
/// PTX for `fused_adam_kernel`: one fused Adam optimizer step per element.
/// Optionally applies L2 weight decay (only when `weight_decay > 0`), updates
/// the first/second moment buffers (`exp_avg`, `exp_avg_sq`) in place, and
/// writes the new parameter using caller-precomputed bias corrections
/// `bc1`/`bc2` (i.e. `1 - beta^t` factors are computed on the host).
#[cfg(feature = "cuda")]
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_adam_kernel(
    .param .u64 param_ptr,
    .param .u64 grad_ptr,
    .param .u64 exp_avg_ptr,
    .param .u64 exp_avg_sq_ptr,
    .param .f32 beta1,
    .param .f32 beta2,
    .param .f32 lr,
    .param .f32 eps,
    .param .f32 bc1,
    .param .f32 bc2,
    .param .f32 weight_decay,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %p, %g, %m, %v, %off;
    .reg .f32 %vp, %vg, %vm, %vv;
    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
    .reg .f32 %one;
    .reg .pred %p_bound, %p_wd;

    ld.param.u64 %p, [param_ptr];
    ld.param.u64 %g, [grad_ptr];
    ld.param.u64 %m, [exp_avg_ptr];
    ld.param.u64 %v, [exp_avg_sq_ptr];
    ld.param.f32 %b1, [beta1];
    ld.param.f32 %b2, [beta2];
    ld.param.f32 %f_lr, [lr];
    ld.param.f32 %f_eps, [eps];
    ld.param.f32 %f_bc1, [bc1];
    ld.param.f32 %f_bc2, [bc2];
    ld.param.f32 %f_wd, [weight_decay];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_bound, %r_tid, %n_reg;
    @%p_bound bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %p, %p, %off;
    add.u64 %g, %g, %off;
    add.u64 %m, %m, %off;
    add.u64 %v, %v, %off;

    ld.global.f32 %vp, [%p];
    ld.global.f32 %vg, [%g];
    ld.global.f32 %vm, [%m];
    ld.global.f32 %vv, [%v];

    // L2 weight decay: g = g + wd * p
    mov.f32 %one, 0f00000000;
    setp.gt.f32 %p_wd, %f_wd, %one;
    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;

    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
    mov.f32 %one, 0f3F800000;
    sub.f32 %t1, %one, %b1;
    mul.f32 %vm, %vm, %b1;
    fma.rn.f32 %vm, %t1, %vg, %vm;

    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    sub.f32 %t2, %one, %b2;
    mul.f32 %vv, %vv, %b2;
    mul.f32 %t1, %vg, %vg;
    fma.rn.f32 %vv, %t2, %t1, %vv;

    // m_hat = exp_avg / bc1
    div.rn.f32 %m_hat, %vm, %f_bc1;

    // v_hat = exp_avg_sq / bc2
    div.rn.f32 %v_hat, %vv, %f_bc2;

    // denom = sqrt(v_hat) + eps
    sqrt.rn.f32 %denom, %v_hat;
    add.f32 %denom, %denom, %f_eps;

    // param = param - lr * m_hat / denom
    div.rn.f32 %update, %m_hat, %denom;
    mul.f32 %update, %update, %f_lr;
    sub.f32 %vp, %vp, %update;

    st.global.f32 [%p], %vp;
    st.global.f32 [%m], %vm;
    st.global.f32 [%v], %vv;

DONE:
    ret;
}
";
4328
/// PTX for `fused_gru_forward_kernel`: fuses the per-element GRU cell math
/// once the two GEMMs have produced `input_gates`/`hidden_gates` ([B, 3*H]
/// layout, per the in-kernel comments). For each hidden unit it computes
/// r = sigmoid(ir + hr + b_ir + b_hr), z = sigmoid(ii + hi + b_ii + b_hi),
/// n = tanh(in + b_in + r*(hn + b_hn)), then hy = n + z*(hx - n); both
/// sigmoid and tanh are built from `ex2.approx`. Writes the next hidden state
/// to `hy` and stores [r, z, n, hx, hn+b_hn] per element into a [B, 5*H]
/// workspace for the backward pass. Iterates with a grid-stride loop over
/// `total` elements.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_gru_forward_kernel(
    .param .u64 input_gates_ptr,
    .param .u64 hidden_gates_ptr,
    .param .u64 bias_ih_ptr,
    .param .u64 bias_hh_ptr,
    .param .u64 hx_ptr,
    .param .u64 hy_ptr,
    .param .u64 workspace_ptr,
    .param .u32 hsz,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
    .reg .u64 %off64, %tmp64;
    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
    .reg .pred %p;

    ld.param.u64 %ig, [input_gates_ptr];
    ld.param.u64 %hg, [hidden_gates_ptr];
    ld.param.u64 %b1, [bias_ih_ptr];
    ld.param.u64 %b2, [bias_hh_ptr];
    ld.param.u64 %hx, [hx_ptr];
    ld.param.u64 %hy, [hy_ptr];
    ld.param.u64 %ws, [workspace_ptr];
    ld.param.u32 %hsz_reg, [hsz];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %one, 0f3F800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
    div.u32 %batch_idx, %idx, %hsz_reg;
    rem.u32 %hmod, %idx, %hsz_reg;
    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset3, %offset3, 3;
    add.u32 %offset3, %offset3, %hmod;

    // Load input gate components: ir, ii, in
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ig, %off64;
    ld.global.f32 %ir, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %ii, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %in, [%tmp64];

    // Load hidden gate components: hr, hi, hn
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hg, %off64;
    ld.global.f32 %hr, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hi, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hn, [%tmp64];

    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b1, %off64;
    ld.global.f32 %b1r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1n, [%tmp64];

    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b2, %off64;
    ld.global.f32 %b2r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2n, [%tmp64];

    // Load hx[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hx, %off64;
    ld.global.f32 %hx_val, [%tmp64];

    // r = sigmoid(ir + hr + b1r + b2r)
    add.f32 %rg, %ir, %hr;
    add.f32 %rg, %rg, %b1r;
    add.f32 %rg, %rg, %b2r;
    neg.f32 %tmp, %rg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %rg, %one, %denom;

    // z = sigmoid(ii + hi + b1i + b2i)
    add.f32 %zg, %ii, %hi;
    add.f32 %zg, %zg, %b1i;
    add.f32 %zg, %zg, %b2i;
    neg.f32 %tmp, %zg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %zg, %one, %denom;

    // n = tanh(in + b1n + r*(hn + b2n))
    add.f32 %tmp, %hn, %b2n;
    fma.rn.f32 %ng, %rg, %tmp, %in;
    add.f32 %ng, %ng, %b1n;
    // tanh via 2*sigmoid(2x)-1
    mul.f32 %tmp, %ng, 0f40000000;
    neg.f32 %tmp, %tmp;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %ng, %one, %denom;
    mul.f32 %ng, %ng, 0f40000000;
    sub.f32 %ng, %ng, %one;

    // hy = n + z * (hx - n)
    sub.f32 %tmp, %hx_val, %ng;
    fma.rn.f32 %hy_val, %zg, %tmp, %ng;

    // Store hy[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hy, %off64;
    st.global.f32 [%tmp64], %hy_val;

    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset5, %offset5, 5;
    add.u32 %offset5, %offset5, %hmod;

    cvt.u64.u32 %off64, %offset5;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ws, %off64;
    st.global.f32 [%tmp64], %rg;
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %zg;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %ng;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %hx_val;
    add.u64 %tmp64, %tmp64, %off64;
    add.f32 %tmp, %hn, %b2n;
    st.global.f32 [%tmp64], %tmp;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
4521
#[cfg(feature = "cuda")]
/// Builds a 1-D launch configuration covering `n` elements with 256-thread
/// blocks (grid is clamped to at least one block so `n == 0` still launches
/// a valid, no-op grid — every kernel bounds-checks against `n`).
///
/// # Errors
/// Returns `GpuError::ShapeMismatch` when `n` exceeds `u32::MAX`, because
/// kernels receive the element count as a `u32` parameter.
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    const BLOCK: u32 = 256;
    // Round up in u64: the previous `saturating_add(BLOCK - 1)` clamped at
    // u32::MAX, so for n within 255 of u32::MAX the grid was one block short
    // and the tail elements were never processed.
    let grid = ((n as u64 + (BLOCK as u64 - 1)) / BLOCK as u64) as u32;
    Ok(LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
4552
#[cfg(feature = "cuda")]
/// Checks that both buffers live on `device` and have equal lengths.
/// Returns `DeviceMismatch` or `LengthMismatch` accordingly.
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let want = device.ordinal();
    // Check placement of each operand before comparing shapes.
    for buf in [a, b] {
        let have = buf.device_ordinal();
        if have != want {
            return Err(GpuError::DeviceMismatch {
                expected: have,
                got: want,
            });
        }
    }
    if a.len() == b.len() {
        Ok(())
    } else {
        Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        })
    }
}
4580
#[cfg(feature = "cuda")]
/// Checks that `a` lives on `device`; errors with `DeviceMismatch` otherwise.
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let have = a.device_ordinal();
    let want = device.ordinal();
    if have == want {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: have,
            got: want,
        })
    }
}
4592
#[cfg(feature = "cuda")]
/// Compiles (or fetches from cache) a two-input element-wise kernel and runs
/// it over `a` and `b`, returning a freshly allocated output buffer.
/// Returns `Ok(None)` when the PTX cannot be compiled, so the caller can fall
/// back to a CPU implementation.
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let func = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        // Compilation failure is non-fatal: signal the caller to fall back.
        Err(_) => return Ok(None),
    };

    let len = a.len();
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;
    let mut result = alloc_zeros_f32(len, device)?;

    // SAFETY: the three pointers come from live device buffers of length
    // `len`, and the kernel is expected to bounds-check against `len_u32`.
    unsafe {
        device
            .stream()
            .launch_builder(&func)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
4646
#[cfg(feature = "cuda")]
/// Like `try_launch_binary`, but launches a float4-vectorized kernel: one
/// thread handles four consecutive elements, so `n/4` threads are launched.
/// Returns `Ok(None)` when compilation fails.
fn try_launch_binary_vec4(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let func = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => return Ok(None),
    };

    let total = a.len();
    // Number of 4-element groups; callers ensure `total % 4 == 0`.
    let quads = (total / 4) as u32;
    let cfg = launch_cfg(quads as usize)?;
    let mut result = alloc_zeros_f32(total, device)?;

    // SAFETY: buffers hold `total` f32s and the kernel bounds-checks its
    // group index against `quads`.
    unsafe {
        device
            .stream()
            .launch_builder(&func)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(&quads)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
4691
#[cfg(feature = "cuda")]
/// Compiles (or fetches from cache) a single-input element-wise kernel and
/// runs it over `a`, returning a freshly allocated output buffer.
/// Returns `Ok(None)` when compilation fails so the caller can fall back.
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let func = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => return Ok(None),
    };

    let len = a.len();
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;
    let mut result = alloc_zeros_f32(len, device)?;

    // SAFETY: both pointers come from live device buffers of length `len`,
    // and the kernel bounds-checks against `len_u32`.
    unsafe {
        device
            .stream()
            .launch_builder(&func)
            .arg(a.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
4735
#[cfg(feature = "cuda")]
/// Variant of `try_launch_binary` that writes into a caller-provided output
/// buffer instead of allocating one. Returns `Ok(false)` when the PTX cannot
/// be compiled (caller falls back), `Ok(true)` after a successful launch.
fn try_launch_binary_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let func = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => return Ok(false),
    };

    let len = a.len();
    let len_u32 = len as u32;
    let cfg = launch_cfg(len)?;

    // SAFETY: all pointers come from live device buffers and the kernel
    // bounds-checks against `len_u32`.
    unsafe {
        device
            .stream()
            .launch_builder(&func)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
4782
#[cfg(feature = "cuda")]
/// Variant of `try_launch_unary` that writes into a caller-provided output
/// buffer. Returns `Ok(false)` on compilation failure, `Ok(true)` on launch.
fn try_launch_unary_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let func = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => return Ok(false),
    };

    let len = a.len();
    let len_u32 = len as u32;
    let cfg = launch_cfg(len)?;

    // SAFETY: both pointers come from live device buffers and the kernel
    // bounds-checks against `len_u32`.
    unsafe {
        device
            .stream()
            .launch_builder(&func)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
4823
#[cfg(feature = "cuda")]
/// Launches a broadcasting two-input kernel: per-dimension strides for each
/// input plus the output shape are uploaded as small device buffers so the
/// kernel can map each flat output index back to its source elements.
/// Returns `Ok(None)` when the PTX cannot be compiled.
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let func = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => return Ok(None),
    };

    // Ship the indexing metadata to the device alongside the data.
    let a_stride_dev = cpu_to_gpu(a_strides, device)?;
    let b_stride_dev = cpu_to_gpu(b_strides, device)?;
    let shape_dev = cpu_to_gpu(out_shape, device)?;

    let mut result = alloc_zeros_f32(out_numel, device)?;
    let cfg = launch_cfg(out_numel)?;
    let numel_u32 = out_numel as u32;
    let rank_u32 = out_shape.len() as u32;

    // SAFETY: all pointers come from live device buffers sized for
    // `out_numel` / `rank_u32`, and the kernel bounds-checks against
    // `numel_u32`.
    unsafe {
        device
            .stream()
            .launch_builder(&func)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(a_stride_dev.inner())
            .arg(b_stride_dev.inner())
            .arg(shape_dev.inner())
            .arg(&numel_u32)
            .arg(&rank_u32)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
4887
4888#[cfg(feature = "cuda")]
4895fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
4896 let ndim = out_shape.len();
4897 let in_ndim = in_shape.len();
4898 let mut strides = vec![0u32; ndim];
4899
4900 let mut stride: u32 = 1;
4902 for d in (0..ndim).rev() {
4903 let in_d = if d + in_ndim >= ndim {
4904 d + in_ndim - ndim
4905 } else {
4906 strides[d] = 0;
4908 continue;
4909 };
4910
4911 if in_shape[in_d] == 1 {
4912 strides[d] = 0; } else {
4914 strides[d] = stride;
4915 }
4916 stride *= in_shape[in_d] as u32;
4917 }
4918
4919 strides
4920}
4921
#[cfg(feature = "cuda")]
/// CPU fallback for two-input element-wise ops: copies both buffers to host,
/// applies `op` pairwise, and uploads the result back to the device.
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut out = Vec::with_capacity(lhs.len());
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        out.push(op(x, y));
    }
    cpu_to_gpu(&out, device)
}
4944
#[cfg(feature = "cuda")]
/// CPU fallback for single-input element-wise ops: downloads the buffer,
/// maps `op` over it, and uploads the result.
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mapped: Vec<f32> = host.into_iter().map(op).collect();
    cpu_to_gpu(&mapped, device)
}
4956
#[cfg(feature = "cuda")]
/// Element-wise addition `a + b` on the GPU. Tries the float4-vectorized
/// kernel first when the length allows, then the scalar kernel, and finally
/// falls back to a host-side computation if neither PTX compiles.
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // Vectorized path needs a multiple-of-4 length; the >= 16 floor keeps
    // tiny inputs on the simpler scalar kernel.
    let len = a.len();
    if len % 4 == 0 && len >= 16 {
        let vec4 = try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?;
        if let Some(out) = vec4 {
            return Ok(out);
        }
    }

    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
4996
#[cfg(feature = "cuda")]
/// Element-wise subtraction `a - b` on the GPU, with a host-side fallback
/// when the kernel cannot be compiled.
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
5022
#[cfg(feature = "cuda")]
/// Element-wise multiplication `a * b` on the GPU. Tries the float4 kernel
/// first for suitably-sized inputs, then the scalar kernel, then the CPU.
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // Vectorized path needs a multiple-of-4 length; the >= 16 floor keeps
    // tiny inputs on the simpler scalar kernel.
    let len = a.len();
    if len % 4 == 0 && len >= 16 {
        let vec4 = try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?;
        if let Some(out) = vec4 {
            return Ok(out);
        }
    }

    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
5057
#[cfg(feature = "cuda")]
/// Broadcasting addition: `a` (shape `a_shape`) plus `b` (shape `b_shape`)
/// broadcast to `out_shape`. Falls back to a host-side loop when the
/// broadcast kernel cannot be compiled.
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &broadcast_strides(a_shape, out_shape),
        &broadcast_strides(b_shape, out_shape),
        &shape,
        numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x + y
        }),
    }
}
5101
#[cfg(feature = "cuda")]
/// Broadcasting subtraction of `b` from `a`, broadcast to `out_shape`, with
/// a host-side fallback when the kernel cannot be compiled.
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &broadcast_strides(a_shape, out_shape),
        &broadcast_strides(b_shape, out_shape),
        &shape,
        numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x - y
        }),
    }
}
5133
#[cfg(feature = "cuda")]
/// Broadcasting multiplication of `a` and `b`, broadcast to `out_shape`,
/// with a host-side fallback when the kernel cannot be compiled.
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &broadcast_strides(a_shape, out_shape),
        &broadcast_strides(b_shape, out_shape),
        &shape,
        numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x * y
        }),
    }
}
5165
#[cfg(feature = "cuda")]
/// Broadcasting division of `a` by `b`, broadcast to `out_shape`, with a
/// host-side fallback when the kernel cannot be compiled.
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let numel: usize = out_shape.iter().product();
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &broadcast_strides(a_shape, out_shape),
        &broadcast_strides(b_shape, out_shape),
        &shape,
        numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )?;

    match launched {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x / y
        }),
    }
}
5197
#[cfg(feature = "cuda")]
/// CPU fallback for broadcasting binary ops: downloads both inputs, walks
/// every flat output index decomposing it into per-dimension coordinates
/// (last dimension fastest), gathers via the broadcast strides, applies `op`,
/// and uploads the result.
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let a_strides = broadcast_strides(a_shape, out_shape);
    let b_strides = broadcast_strides(b_shape, out_shape);
    let numel: usize = out_shape.iter().product();

    let out: Vec<f32> = (0..numel)
        .map(|flat| {
            let mut rest = flat;
            let mut ai = 0usize;
            let mut bi = 0usize;
            for d in (0..out_shape.len()).rev() {
                let coord = rest % out_shape[d];
                rest /= out_shape[d];
                ai += coord * a_strides[d] as usize;
                bi += coord * b_strides[d] as usize;
            }
            op(lhs[ai], rhs[bi])
        })
        .collect();
    cpu_to_gpu(&out, device)
}
5232
#[cfg(feature = "cuda")]
/// Element-wise negation on the GPU, with a host-side fallback when the
/// kernel cannot be compiled.
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
5257
/// Element-wise ReLU (`max(x, 0)`), preferring the `relu_kernel` PTX kernel
/// and falling back to a host round-trip when the kernel is unavailable.
#[cfg(feature = "cuda")]
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
5278
/// Backward pass of ReLU: keeps `grad` where the forward *input* was
/// strictly positive and zeroes it elsewhere.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback: d/dx relu(x) = 1 for x > 0, else 0 (x == 0 gets 0).
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
        .collect();
    cpu_to_gpu(&result, device)
}
5308
/// Backward pass of the sigmoid-approximated GELU (`gelu(x) ≈ x * σ(1.702x)`).
///
/// The derivative used is `σ(kx) + k·x·σ(kx)·(1 − σ(kx))` with `k = 1.702`,
/// multiplied by the incoming gradient. The PTX kernel is assumed to
/// implement the same approximation — confirm against GELU_BACKWARD_PTX.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_PTX,
        "gelu_backward_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback: product rule on x * sigmoid(k*x).
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let k: f32 = 1.702;
            let sig = 1.0 / (1.0 + (-k * x).exp());
            g * (sig + k * x * sig * (1.0 - sig))
        })
        .collect();
    cpu_to_gpu(&result, device)
}
5343
/// Gathers `input[indices[i]]` for every `i`, producing `indices.len()`
/// elements. Indices travel as `f32` and are truncated to `usize`.
///
/// NOTE(review): indices are not bounds-checked here; the CPU fallback
/// panics on an out-of-range index, and the kernel's behavior for bad
/// indices is not visible from this file — confirm against
/// INDEX_SELECT_1D_PTX.
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    // Compile (or fetch cached) kernel; on failure, gather on the host.
    let f = match crate::module_cache::get_or_compile(
        ctx,
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let input_host = gpu_to_cpu(input, device)?;
            let indices_host = gpu_to_cpu(indices, device)?;
            let result: Vec<f32> = indices_host
                .iter()
                .map(|&idx_f| input_host[idx_f as usize])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5401
/// Scatter-add: accumulates `grad_output[i]` into position `indices[i]` of a
/// zero-initialized buffer of `input_len` elements (the gradient of
/// `gpu_index_select_1d`). Duplicate indices accumulate.
///
/// NOTE(review): `indices.len()` is assumed to equal `grad_output.len()`;
/// the kernel presumably uses atomic adds for duplicate indices — confirm
/// against SCATTER_ADD_1D_PTX. The CPU fallback panics on an out-of-range
/// index.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(grad_output, device)?;

    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: sequential accumulation, so duplicates are safe.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; input_len];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                result[idx_f as usize] += go_host[i];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    // Output is zero-initialized so the kernel only needs to add.
    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5463
/// Replaces elements of `input` with `value` wherever the corresponding
/// `mask` element is >= 0.5; other elements pass through unchanged.
///
/// The mask threshold is taken from the CPU fallback; the PTX kernel is
/// assumed to apply the same rule — confirm against MASKED_FILL_PTX.
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(input, mask, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: mask >= 0.5 selects `value`.
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f32> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m >= 0.5 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5524
/// Zeroes elements of `grad` wherever `mask` is >= 0.5 (the gradient
/// counterpart of `gpu_masked_fill`, since filled positions contribute no
/// gradient).
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;

    if let Some(out) = try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?
    {
        return Ok(out);
    }

    // CPU fallback mirrors the mask convention of `gpu_masked_fill`.
    let grad_host = gpu_to_cpu(grad, device)?;
    let mask_host = gpu_to_cpu(mask, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(mask_host.iter())
        .map(|(&g, &m)| if m >= 0.5 { 0.0 } else { g })
        .collect();
    cpu_to_gpu(&result, device)
}
5556
/// Backward pass of sigmoid expressed in terms of the forward *output* `o`:
/// `dx = g * o * (1 - o)`.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback: σ'(x) = σ(x)(1 − σ(x)), using the cached forward output.
    let grad_host = gpu_to_cpu(grad, device)?;
    let output_host = gpu_to_cpu(output, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(output_host.iter())
        .map(|(&g, &o)| g * o * (1.0 - o))
        .collect();
    cpu_to_gpu(&result, device)
}
5592
/// Backward pass of tanh expressed in terms of the forward *output* `o`:
/// `dx = g * (1 - o^2)`.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )? {
        return Ok(out);
    }

    // CPU fallback: tanh'(x) = 1 − tanh(x)^2, using the cached forward output.
    let grad_host = gpu_to_cpu(grad, device)?;
    let output_host = gpu_to_cpu(output, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(output_host.iter())
        .map(|(&g, &o)| g * (1.0 - o * o))
        .collect();
    cpu_to_gpu(&result, device)
}
5628
/// Backward pass of a row-wise softmax: per row, `dx = y * (dy - dot(dy, y))`
/// where `y` is the forward softmax output and `dy` the incoming gradient.
///
/// Launch shape: one block per row, 256 threads, with 256 * 4 bytes of
/// shared memory — presumably for the per-block dot-product reduction;
/// confirm against SOFTMAX_BACKWARD_PTX.
///
/// NOTE(review): `rows = total / cols` assumes `cols` divides `grad.len()`
/// exactly and panics if `cols == 0`.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: compute the row dot product first, then the
            // Jacobian-vector product for each element of the row.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads cooperate on that row's reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5707
/// Sums every element of `a` into a single-element buffer.
///
/// Strategy: one kernel pass produces up to 1024 per-block partial sums,
/// then depending on how many partials remain:
/// - exactly 1  -> return it directly;
/// - <= 256     -> sum the partials on the host;
/// - otherwise  -> recurse on the partial buffer.
///
/// NOTE(review): the grid is capped at 1024 blocks, which only covers all of
/// `n` if the kernel strides over the input — confirm against
/// REDUCE_SUM_PTX. `shared_mem_bytes` is 0, so any shared memory must be
/// statically declared in the PTX.
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // Empty input sums to 0.
    if n == 0 {
        return cpu_to_gpu(&[0.0f32], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: straight host-side sum.
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    const BLOCK: u32 = 256;
    // Ceiling division, capped at 1024 partial blocks.
    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
    let num_blocks = num_blocks.min(1024);

    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
    let n_u32 = n as u32;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0, };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    // A single partial IS the answer.
    if num_blocks <= 1 {
        return Ok(partials);
    }

    // Few partials: cheaper to finish on the host than launch again.
    if num_blocks <= 256 {
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // Many partials: reduce them recursively on the device.
    gpu_reduce_sum(&partials, device)
}
5790
/// Stub compiled when the `cuda` feature is disabled: always fails with
/// `GpuError::NoCudaFeature`.
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_sum(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
5799
/// Sums a tensor viewed as `(outer, axis_size, inner)` over its middle axis,
/// producing `outer * inner` elements.
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let total_output = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: each output (outer_idx, inner_idx) sums the
            // strided elements along the reduced axis.
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f32; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let mut sum = 0.0f32;
                for k in 0..axis_size {
                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
                }
                *out = sum;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5864
/// Copies one split of `input` along an axis into a fresh contiguous buffer.
///
/// `input` is treated as `(outer, total_along_axis, inner_size)`; the split
/// covers positions `[split_offset, split_offset + split_size)` of the
/// middle axis, and `n` is the number of output elements
/// (`outer * split_size * inner_size`).
///
/// Tries the `strided_split_kernel` PTX kernel; if it cannot be compiled,
/// the strided gather is performed on the host and uploaded back.
#[cfg(feature = "cuda")]
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather each output element from its strided
            // source position. (A dead `outer` count that was computed and
            // immediately discarded has been removed.)
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                let src_idx =
                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
                *out = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5946
/// Writes `input` (one concatenation part) into `output` at `cat_offset`
/// along the middle axis of `(outer, total_along_axis, inner_size)`; `n` is
/// the number of elements in `input` to copy. The inverse of
/// `gpu_strided_split`.
///
/// NOTE(review): the CPU fallback replaces `*output` with a freshly
/// uploaded buffer, so any other handles to the old allocation keep the old
/// contents.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: download both, scatter the part into place,
            // upload the patched output.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
6034
/// Multiplies every element of `a` by `scalar`.
#[cfg(feature = "cuda")]
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: host-side multiply, then upload.
            let host = gpu_to_cpu(a, device)?;
            let result: Vec<f32> = host.iter().map(|&x| x * scalar).collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6088
/// Row-wise softmax over a `(rows, cols)` view of `input`.
///
/// The CPU fallback uses the numerically stable max-subtraction form; the
/// kernel is assumed to do the same — confirm against SOFTMAX_PTX.
/// Launch shape: one block per row, 256 threads, 256 * 4 bytes of shared
/// memory for the row reductions.
///
/// NOTE(review): `rows * cols` is assumed to equal `input.len()`; this is
/// not checked here.
#[cfg(feature = "cuda")]
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: subtract the row max before exponentiating to
            // avoid overflow, then normalize by the row sum.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; host.len()];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f32;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads cooperate on that row's reductions.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4, };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6165
/// Deterministic (seeded) dropout: element `i` is zeroed when a per-index
/// hash falls below `threshold`, otherwise multiplied by `scale`
/// (presumably `1 / keep_prob` inverted-dropout scaling — confirm with the
/// caller).
///
/// The CPU fallback hashes `i * 2654435761 ^ seed` through an xorshift
/// (13/17/5) mix; the kernel is assumed to use the identical generator so
/// both paths drop the same elements — confirm against DROPOUT_PTX.
#[cfg(feature = "cuda")]
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f32> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    // Knuth multiplicative hash of the index, xor'd with the
                    // seed, then xorshift-mixed.
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
6246
/// Transposes a row-major `(m, n)` matrix into `(n, m)`:
/// `out[j][i] = in[i][j]`.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: naive element-by-element transpose.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i in 0..m {
                for j in 0..n {
                    out[j * m + i] = host[i * n + j];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6306
/// Permutes a contiguous `(d0, d1, d2, d3)` tensor to axis order
/// `(0, 2, 1, 3)` — i.e. swaps the two middle axes (the common
/// heads/sequence swap in attention layouts).
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: recompute each element's destination index with
            // d1 and d2 swapped in the output layout.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i0 in 0..d0 {
                for i1 in 0..d1 {
                    for i2 in 0..d2 {
                        for i3 in 0..d3 {
                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                            out[out_idx] = host[in_idx];
                        }
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6379
/// Matrix multiply `C(m,n) = A(m,k) * B(k,n)` via a lightweight custom
/// kernel; if the kernel cannot be compiled, delegates to the BLAS-backed
/// `crate::blas::gpu_matmul_f32` instead of a CPU round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Fallback: the general BLAS path rather than a host loop.
            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
        }
    };

    let mut c = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
6438
/// Batched matrix multiply. A single batch is routed to the cheap custom
/// kernel (`gpu_small_matmul`); anything larger goes through the
/// BLAS-backed batched path.
#[cfg(feature = "cuda")]
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        1 => gpu_small_matmul(a, b, m, k, n, device),
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
6465
/// Copies one `d`-element row of an embedding table: the row number is read
/// from `idx` (stored as `f32`, truncated to `usize`).
///
/// NOTE(review): the CPU fallback reads only `idx_host[0]`, i.e. this op
/// handles a single token; the kernel is assumed to do the same — confirm
/// against EMBED_LOOKUP_PTX. No bounds check on the row index here.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: slice the selected row out of the host copy.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let row = idx_host[0] as usize;
            let start = row * d;
            let out = weight_host[start..start + d].to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6519
/// Writes an `(n_batch, d)` tile `src` into time-slot `pos` of `dst`, where
/// `dst` is laid out `(n_batch, max_len, d)` — e.g. appending one step to a
/// KV-cache-style buffer.
///
/// NOTE(review): unlike sibling ops, no `validate_*` call guards the device
/// or lengths here — confirm whether that is intentional. The CPU fallback
/// replaces `*dst` with a freshly uploaded buffer rather than writing in
/// place.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: patch the host copy and re-upload.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    // Kernel's `n` argument is the element count of the tile, not n_batch.
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
6584
/// Reads the first `len` time-slots from a `(n_batch, max_len, d)` buffer
/// into a contiguous `(n_batch, len, d)` result (the read counterpart of
/// `gpu_slice_write`).
#[cfg(feature = "cuda")]
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy row-by-row, skipping the unused tail of
            // each batch's max_len-sized stride.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f32; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6647
/// Sigmoid-approximated GELU (`x * σ(1.702x)`), preferring the
/// `gelu_kernel` PTX kernel and falling back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;

    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            // x * sigmoid(1.702 * x)
            let s = 1.0 / (1.0 + (-1.702 * x).exp());
            x * s
        }),
    }
}
6664
/// Element-wise division `a / b` of two equal-length buffers. Division by
/// zero follows IEEE-754 f32 semantics (±inf/NaN) on the CPU path.
#[cfg(feature = "cuda")]
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    if let Some(out) = try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        return Ok(out);
    }

    // CPU fallback: zipped host-side divide, then upload.
    let a_host = gpu_to_cpu(a, device)?;
    let b_host = gpu_to_cpu(b, device)?;
    let result: Vec<f32> = a_host
        .iter()
        .zip(b_host.iter())
        .map(|(&x, &y)| x / y)
        .collect();
    cpu_to_gpu(&result, device)
}
6692
/// Element-wise `e^x`, preferring the `exp_kernel` PTX kernel and falling
/// back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.exp()),
    }
}
6702
/// Element-wise natural logarithm, preferring the `log_kernel` PTX kernel
/// and falling back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.ln()),
    }
}
6712
/// Element-wise square root, preferring the `sqrt_kernel` PTX kernel and
/// falling back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.sqrt()),
    }
}
6722
/// Raises every element of `a` to the power `exponent` (`x.powf(exponent)`
/// semantics on the CPU path; the kernel is assumed to match — confirm
/// against POW_PTX).
#[cfg(feature = "cuda")]
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: host-side powf, then upload.
            let host = gpu_to_cpu(a, device)?;
            let result: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6768
/// Element-wise absolute value, preferring the `abs_kernel` PTX kernel and
/// falling back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.abs()),
    }
}
6778
/// Element-wise logistic sigmoid `1 / (1 + e^-x)`, preferring the
/// `sigmoid_kernel` PTX kernel and falling back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
6788
/// Element-wise hyperbolic tangent, preferring the `tanh_kernel` PTX kernel
/// and falling back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.tanh()),
    }
}
6798
/// Fused Adam-style parameter update performed in place on `param`.
///
/// Per element (as in the CPU fallback; the kernel is assumed to match —
/// confirm against FUSED_ADAM_PTX):
/// 1. `g += weight_decay * p` when `weight_decay > 0` (L2-style decay
///    folded into the gradient, not decoupled AdamW decay);
/// 2. first/second moment EMAs with `beta1` / `beta2`;
/// 3. bias correction via the caller-supplied divisors `bc1` / `bc2`
///    (presumably `1 - beta^t` — confirm with the optimizer);
/// 4. step: `p -= lr * m_hat / (sqrt(v_hat) + eps)`.
///
/// On kernel-compile failure the update runs on the host and `param`,
/// `exp_avg`, and `exp_avg_sq` are re-uploaded.
///
/// # Errors
/// Returns `GpuError::LengthMismatch` when `grad`, `exp_avg`, or
/// `exp_avg_sq` differs in length from `param`; `b` reports the first
/// offending length (previously it always reported `grad.len()`, which was
/// misleading when a moment buffer was the mismatched one).
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    // Report the first buffer whose length disagrees with `param`.
    for len in [grad.len(), exp_avg.len(), exp_avg_sq.len()] {
        if len != n {
            return Err(GpuError::LengthMismatch { a: n, b: len });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

            for i in 0..n {
                let mut g = g_host[i];
                if weight_decay > 0.0 {
                    g += weight_decay * p_host[i];
                }
                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
                let m_hat = m_host[i] / bc1;
                let v_hat = v_host[i] / bc2;
                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
            }

            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
6893
/// CPU-only build stub: fused Adam requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
    _param: &mut CudaBuffer<f32>,
    _grad: &CudaBuffer<f32>,
    _exp_avg: &mut CudaBuffer<f32>,
    _exp_avg_sq: &mut CudaBuffer<f32>,
    _beta1: f32,
    _beta2: f32,
    _lr: f32,
    _eps: f32,
    _bc1: f32,
    _bc2: f32,
    _weight_decay: f32,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
6913
#[cfg(feature = "cuda")]
/// Fused GRU cell forward pass.
///
/// `input_gates`/`hidden_gates` hold the pre-computed gate projections,
/// `hx` the previous hidden state (`batch * hsz` elements). Returns the new
/// hidden state `hy` plus a `batch * 5 * hsz` workspace buffer of saved
/// intermediate activations (exact layout is defined by the kernel and
/// consumed by the backward pass).
///
/// There is no CPU fallback: a PTX compilation failure is surfaced as
/// `PtxCompileFailed`.
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    let total = hx.len();
    // Fixed: guard against hsz == 0 (previously a divide-by-zero panic) and
    // a hidden-state length that is not a whole number of batch rows
    // (previously silently truncated by integer division).
    if hsz == 0 || total % hsz != 0 {
        return Err(GpuError::ShapeMismatch {
            op: "fused_gru_forward",
            expected: vec![hsz],
            got: vec![total],
        });
    }
    let batch = total / hsz;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "fused_gru_forward_kernel",
            });
        }
    };

    let mut hy = alloc_zeros_f32(total, device)?;
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order/types match fused_gru_forward_kernel's
    // parameter list; hy and workspace were sized above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
6987
/// CPU-only build stub: the fused GRU forward requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_fused_gru_forward(
    _input_gates: &CudaBuffer<f32>,
    _hidden_gates: &CudaBuffer<f32>,
    _bias_ih: &CudaBuffer<f32>,
    _bias_hh: &CudaBuffer<f32>,
    _hx: &CudaBuffer<f32>,
    _hsz: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
7001
#[cfg(feature = "cuda")]
/// 2-D max pooling forward pass on an NCHW tensor.
///
/// Returns the pooled buffer together with its `[batch, channels, h_out,
/// w_out]` shape, where `h_out = (h_in + 2*ph - kh) / sh + 1` (same for
/// width). No CPU fallback: a kernel compile failure is an error.
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Fixed: a kernel larger than the padded input used to panic with a
    // usize underflow, and a zero stride with a divide-by-zero; reject both
    // up front with a descriptive error.
    if sh == 0 || sw == 0 || h_in + 2 * ph < kh || w_in + 2 * pw < kw {
        return Err(GpuError::ShapeMismatch {
            op: "maxpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in + 2 * ph, w_in + 2 * pw],
        });
    }
    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: argument order/types match maxpool2d_forward_kernel's
    // parameter list; `out` was sized to `total` above.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
7066
/// CPU-only build stub: max pooling requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
7078
#[cfg(feature = "cuda")]
/// 2-D average pooling forward pass on an NCHW tensor.
///
/// Returns the pooled buffer together with its `[batch, channels, h_out,
/// w_out]` shape, where `h_out = (h_in + 2*ph - kh) / sh + 1` (same for
/// width). No CPU fallback: a kernel compile failure is an error.
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Fixed: a kernel larger than the padded input used to panic with a
    // usize underflow, and a zero stride with a divide-by-zero; reject both
    // up front (matches gpu_maxpool2d).
    if sh == 0 || sw == 0 || h_in + 2 * ph < kh || w_in + 2 * pw < kw {
        return Err(GpuError::ShapeMismatch {
            op: "avgpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in + 2 * ph, w_in + 2 * pw],
        });
    }
    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: argument order/types match avgpool2d_forward_kernel's
    // parameter list; `out` was sized to `total` above.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
7139
/// CPU-only build stub: average pooling requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
7151
/// Batch-norm forward pass — currently a placeholder.
///
/// NOTE(review): this function only pre-compiles `batchnorm_forward_kernel`
/// (warming the module cache; the compile result is discarded) and then
/// unconditionally returns a dummy `ShapeMismatch { expected: [0], got: [1] }`
/// error — the GPU path is not implemented yet. Callers presumably treat any
/// `Err` as "use the CPU implementation"; confirm before depending on the
/// error variant or its payload.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    let ctx = device.context();
    // Warm the module cache so a future real implementation (or another
    // caller of this kernel) gets a cache hit; errors are intentionally
    // ignored here.
    let _f = crate::module_cache::get_or_compile(
        ctx,
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );
    // Placeholder error: signals "not implemented" to the caller.
    Err(GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    })
}
7189
/// CPU-only build stub: batch-norm forward requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
7208
#[cfg(feature = "cuda")]
/// Row-wise layer normalization: treats `input` as a `rows x cols` matrix,
/// normalizes each row to zero mean / unit variance (`eps` added inside the
/// sqrt), then applies the per-column affine `weight[c] * x_hat + bias[c]`.
///
/// Allocates and returns a fresh output buffer. If the PTX fails to compile,
/// the whole computation transparently falls back to the CPU.
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // Fixed: this error path previously also dumped the PTX source to
            // a hard-coded /tmp path (leftover debug code, unportable); now it
            // just logs once and runs the CPU fallback.
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
                let var: f32 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One 256-thread block per row with 256 * 4 bytes of dynamic shared
    // memory. NOTE(review): gpu_layernorm_into sizes shared memory as
    // `cols * 4` for the same kernel — confirm which sizing the PTX expects.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order/types match layernorm_kernel's parameter list;
    // `out` was sized to rows * cols above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
7293
#[cfg(feature = "cuda")]
/// Backward pass for row-wise layer normalization.
///
/// Given `input` (`rows x cols`), upstream `grad_output` and the affine
/// `weight`, returns `(grad_input, grad_weight, grad_bias)` where
/// `grad_weight`/`grad_bias` are length-`cols` reductions over all rows.
/// If the PTX fails to compile, the gradients are computed on the CPU.
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: standard layernorm backward. For each row, with
            // x_hat = (x - mean) * inv_std and dl = dy * w:
            //   dx = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat))
            // grad_weight/grad_bias accumulate dy * x_hat and dy across rows.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
                let var: f32 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // First pass: row sums of dl and dl * x_hat, plus the
                // per-column weight/bias gradient accumulation.
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                // Second pass: input gradient using the row means above.
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One 256-thread block per row, 256 * 4 bytes of dynamic shared memory
    // (same configuration as gpu_layernorm).
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
7405
/// CPU-only build stub: layernorm backward requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
7419
#[cfg(feature = "cuda")]
/// Element-wise addition into a caller-provided buffer: `out = a + b`.
///
/// `out` must hold at least `a.len()` elements. There is no CPU fallback —
/// a PTX compile failure is reported as `PtxCompileFailed`.
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    let needed = a.len();
    if out.len() < needed {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![needed],
            got: vec![out.len()],
        });
    }
    if !try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
        return Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        });
    }
    Ok(())
}
7451
#[cfg(feature = "cuda")]
/// Element-wise multiplication into a caller-provided buffer: `out = a * b`.
///
/// `out` must hold at least `a.len()` elements. There is no CPU fallback —
/// a PTX compile failure is reported as `PtxCompileFailed`.
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    let needed = a.len();
    if out.len() < needed {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![needed],
            got: vec![out.len()],
        });
    }
    if !try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
        return Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        });
    }
    Ok(())
}
7475
#[cfg(feature = "cuda")]
/// Scales `a` by a scalar into a caller-provided buffer:
/// `out[i] = a[i] * scalar`.
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    let n = a.len();
    // Fixed: reject an undersized output buffer before launching, matching
    // gpu_add_into/gpu_mul_into — the kernel is launched for n elements, so
    // a short `out` would mean an out-of-bounds device write.
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // SAFETY: argument order/types match scale_kernel's parameter list;
    // `out` was verified above to hold at least n elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
7511
#[cfg(feature = "cuda")]
/// Reports whether `a` contains any non-finite value (NaN or ±inf).
///
/// Downloads the whole buffer and scans on the host; an empty buffer is
/// trivially finite.
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    if a.len() == 0 {
        return Ok(false);
    }

    validate_unary(a, device)?;

    // `is_finite` is false for NaN as well as both infinities.
    let values: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
    Ok(!values.iter().all(|v| v.is_finite()))
}
7540
/// CPU-only build stub: the inf/NaN scan requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
7546
#[cfg(feature = "cuda")]
/// Applies GELU element-wise to `a`, writing into `out`. No CPU fallback —
/// a PTX compile failure is reported as `PtxCompileFailed`.
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;
    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "gelu_kernel",
        })
    }
}
7562
#[cfg(feature = "cuda")]
/// Single embedding lookup: copies one `d`-element row of `weight`, selected
/// on-device by `idx`, into `out` (row addressing is defined by the kernel).
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;

    let dim = d as u32;
    let cfg = launch_cfg(d)?;
    // SAFETY: argument order/types match embed_lookup_kernel's parameters.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&dim)
            .launch(cfg)?;
    }
    Ok(())
}
7597
#[cfg(feature = "cuda")]
/// Batched embedding lookup: for each of the `n` indices, copies the
/// corresponding `d`-element row of `weight` into a fresh `n * d` buffer.
///
/// Indices travel as f32 and are truncated to row numbers. Falls back to a
/// CPU gather when the PTX does not compile.
pub fn gpu_embed_lookup_batch(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f32(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            for &idx_f in &idx_host {
                let row = idx_f as usize;
                let start = row * d;
                // Fixed: an out-of-range index used to panic on the slice
                // below; surface it as an error instead.
                match weight_host.get(start..start + d) {
                    Some(row_slice) => out.extend_from_slice(row_slice),
                    None => {
                        return Err(GpuError::ShapeMismatch {
                            op: "embed_lookup_batch",
                            expected: vec![weight_host.len()],
                            got: vec![start + d],
                        });
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order/types match embed_lookup_batch_kernel's
    // parameter list; `out` was sized to `total` above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7662
#[cfg(feature = "cuda")]
/// Embedding backward: scatter-adds each `d`-element row of `grad_output`
/// into a zero-initialized `num_embeddings x d` gradient table at the row
/// selected by the corresponding entry of `indices`.
pub fn gpu_scatter_add_rows(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        return alloc_zeros_f32(num_embeddings * d, device);
    }

    // Fixed: validate grad_output up front — both the GPU kernel launch and
    // the CPU fallback read n * d elements from it; previously a short
    // buffer panicked in the fallback.
    if grad_output.len() < total {
        return Err(GpuError::LengthMismatch {
            a: total,
            b: grad_output.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                let row = idx_f as usize;
                // Fixed: an index beyond the table used to panic on the
                // write below; report it as an error instead.
                if row >= num_embeddings {
                    return Err(GpuError::ShapeMismatch {
                        op: "scatter_add_rows",
                        expected: vec![num_embeddings],
                        got: vec![row],
                    });
                }
                for j in 0..d {
                    result[row * d + j] += go_host[i * d + j];
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order/types match scatter_add_rows_kernel's
    // parameter list; `out` was sized to num_embeddings * d above.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7732
#[cfg(feature = "cuda")]
/// Transposes an `m x n` row-major matrix from `a` into `out`.
pub fn gpu_transpose_2d_into(
    a: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let element_count = m * n;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "transpose_2d_kernel",
    })?;

    let (rows, cols, count) = (m as u32, n as u32, element_count as u32);
    let cfg = launch_cfg(element_count)?;
    // SAFETY: argument order/types match transpose_2d_kernel's parameters.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows)
            .arg(&cols)
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
7771
#[cfg(feature = "cuda")]
/// Permutes a `(d0, d1, d2, d3)` tensor with the axis order `(0, 2, 1, 3)`
/// (per the kernel name) from `a` into `out`.
pub fn gpu_permute_0213_into(
    a: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let element_count = d0 * d1 * d2 * d3;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "permute_0213_kernel",
    })?;

    let dims = [d0 as u32, d1 as u32, d2 as u32, d3 as u32];
    let count = element_count as u32;
    let cfg = launch_cfg(element_count)?;
    // SAFETY: argument order/types match permute_0213_kernel's parameters.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&dims[0])
            .arg(&dims[1])
            .arg(&dims[2])
            .arg(&dims[3])
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
7812
#[cfg(feature = "cuda")]
/// Row-wise softmax: treats `a` as a `rows x cols` matrix and writes the
/// softmax of each row into `out` (same layout).
///
/// Launch shape: one 256-thread block per row, with `cols * 4` bytes of
/// dynamic shared memory (one f32 per column) — so very large `cols` may
/// exceed the device's per-block shared-memory limit; TODO confirm the
/// kernel's expectations against SOFTMAX_PTX.
pub fn gpu_softmax_into(
    a: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "softmax_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32; // one block per row
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(())
}
7854
#[cfg(feature = "cuda")]
/// Row-wise layer normalization into a caller-provided buffer (same
/// computation as `gpu_layernorm`, but no allocation and no CPU fallback —
/// a PTX compile failure is an error).
///
/// NOTE(review): this variant requests `cols * 4` bytes of dynamic shared
/// memory while `gpu_layernorm` requests a fixed `256 * 4` for the same
/// `layernorm_kernel` — confirm against the PTX which sizing is correct.
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "layernorm_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32; // one block per row
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(())
}
7903
#[cfg(feature = "cuda")]
/// Strided device-side read: launches `slice_read_kernel` over
/// `n_batch * len * d` elements, copying from `src` (sequence capacity
/// `max_len`) into the densely packed `out`. Exact addressing is defined by
/// the PTX.
pub fn gpu_slice_read_into(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let element_count = n_batch * len * d;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_read_kernel",
    })?;

    let count = element_count as u32;
    let dim = d as u32;
    let seq_len = len as u32;
    let capacity = max_len as u32;
    let cfg = launch_cfg(element_count)?;
    // SAFETY: argument order/types match slice_read_kernel's parameters.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&count)
            .arg(&dim)
            .arg(&seq_len)
            .arg(&capacity)
            .launch(cfg)?;
    }
    Ok(())
}
7947
#[cfg(feature = "cuda")]
/// Naive matrix multiply for small problems: `out (m x n) = a (m x k) *
/// b (k x n)`, one thread per output element.
pub fn gpu_small_matmul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let element_count = m * n;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "small_matmul_kernel",
    })?;

    let dims = [m as u32, k as u32, n as u32, element_count as u32];
    let cfg = launch_cfg(element_count)?;
    // SAFETY: argument order/types match small_matmul_kernel's parameters.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&dims[0])
            .arg(&dims[1])
            .arg(&dims[2])
            .arg(&dims[3])
            .launch(cfg)?;
    }
    Ok(())
}
7988
#[cfg(feature = "cuda")]
/// Device-side indirect write: launches `slice_write_indirect_kernel` over
/// `n_batch * d` elements, copying vectors from `src` into `dst` at a slot
/// read on-device from `pos_ptr` (avoiding a host round-trip for the write
/// position). Exact addressing — presumably `dst[batch * max_len * d +
/// pos * d + j]`, KV-cache style — is defined by the PTX; confirm there.
pub fn gpu_slice_write_indirect(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos_ptr: &cudarc::driver::CudaSlice<u32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_INDIRECT_PTX,
        "slice_write_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_write_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32; // NOTE: despite the name, this is n_batch * d
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(pos_ptr)
            .launch(cfg)?;
    }
    Ok(())
}
8035
#[cfg(feature = "cuda")]
/// Fills `out` (`n_head * max_pos` elements) with a causal attention mask
/// whose valid length is read on-device from `total_len_ptr` (no host
/// round-trip). Mask values are defined by the kernel.
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let element_count = n_head * max_pos;
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;

    let (width, count) = (max_pos as u32, element_count as u32);
    let cfg = launch_cfg(element_count)?;
    // SAFETY: argument order/types match causal_mask_indirect_kernel's
    // parameters.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&width)
            .arg(&count)
            .launch(cfg)?;
    }
    Ok(())
}
8074
#[cfg(feature = "cuda")]
/// Eagerly compiles every kernel used on the token-decode hot path, so the
/// first decode step pays no JIT latency. Results land in the module cache;
/// later launches are cache hits.
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    let ctx = device.context();
    ctx.bind_to_thread()?;
    let ord = device.ordinal() as u32;

    // (PTX source, entry-point name) pairs, compiled in sequence.
    let kernels: &[(&str, &str)] = &[
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
    ];
    for &(ptx, name) in kernels {
        crate::module_cache::get_or_compile(ctx, ptx, name, ord)
            .map(|_| ())
            .map_err(GpuError::Driver)?;
    }
    Ok(())
}
8110
/// CPU-only build stub: kernel precompilation requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
8116
/// CPU-only build stub: GPU GELU requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8126
/// CPU-only build stub: GPU division requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_div(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8136
/// CPU-only build stub: GPU exp requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8142
/// CPU-only build stub: GPU log requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8148
/// CPU-only build stub: GPU sqrt requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8154
/// CPU-only build stub: GPU pow requires the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub fn gpu_pow(
    _a: &CudaBuffer<f32>,
    _exponent: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8164
8165#[cfg(not(feature = "cuda"))]
8167pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8168 Err(GpuError::NoCudaFeature)
8169}
8170
8171#[cfg(not(feature = "cuda"))]
8173pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8174 Err(GpuError::NoCudaFeature)
8175}
8176
8177#[cfg(not(feature = "cuda"))]
8179pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8180 Err(GpuError::NoCudaFeature)
8181}
8182
8183#[cfg(not(feature = "cuda"))]
8185pub fn gpu_layernorm(
8186 _input: &CudaBuffer<f32>,
8187 _weight: &CudaBuffer<f32>,
8188 _bias: &CudaBuffer<f32>,
8189 _rows: usize,
8190 _cols: usize,
8191 _eps: f32,
8192 _device: &GpuDevice,
8193) -> GpuResult<CudaBuffer<f32>> {
8194 Err(GpuError::NoCudaFeature)
8195}
8196
/// Non-CUDA stub for `gpu_transpose_2d`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d(
    _input: &CudaBuffer<f32>,
    _m: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_add`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_sub`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_mul`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_neg`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_relu`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_scale`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale(
    _a: &CudaBuffer<f32>,
    _scalar: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_broadcast_add`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_broadcast_sub`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_broadcast_mul`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_softmax`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax(
    _input: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8309
/// Non-CUDA stub for `gpu_dropout`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout(
    _input: &CudaBuffer<f32>,
    _threshold: u32,
    _scale: f32,
    _seed: u32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_permute_0213`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213(
    _input: &CudaBuffer<f32>,
    _d0: usize,
    _d1: usize,
    _d2: usize,
    _d3: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_slice_write`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write(
    _src: &CudaBuffer<f32>,
    _dst: &mut CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _max_len: usize,
    _pos: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_slice_read`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read(
    _src: &CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _len: usize,
    _max_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_embed_lookup`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup(
    _idx: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_embed_lookup_batch`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch(
    _indices: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _n: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_scatter_add_rows`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _num_embeddings: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_relu_backward`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_gelu_backward`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_index_select_1d`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d(
    _input: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_scatter_add_1d`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _input_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
8437
/// Non-CUDA stub for `gpu_masked_fill`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill(
    _input: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _value: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_masked_zero`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero(
    _grad: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_sigmoid_backward`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_tanh_backward`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_softmax_backward`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_sum_axis`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis(
    _a: &CudaBuffer<f32>,
    _outer: usize,
    _axis_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_strided_split`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split(
    _input: &CudaBuffer<f32>,
    _total_along_axis: usize,
    _split_offset: usize,
    _split_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Non-CUDA stub for `gpu_strided_cat`; always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat(
    _input: &CudaBuffer<f32>,
    _output: &mut CudaBuffer<f32>,
    _total_along_axis: usize,
    _cat_offset: usize,
    _part_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
8530
/// Converts an f32 device buffer to raw f16 bits (`u16`) on the GPU.
///
/// Returns a freshly allocated device slice of `input.len()` half-precision
/// values encoded as `u16`.
///
/// # Errors
/// Fails if device allocation fails, if `launch_cfg` rejects the element
/// count, or if the conversion PTX cannot be compiled
/// ([`GpuError::PtxCompileFailed`]).
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_f16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let len = input.len();

    // Degenerate case: nothing to convert, hand back an empty allocation.
    if len == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }

    let stream = device.stream();

    // Fetch the cached (or freshly compiled) conversion kernel for this device.
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_F16_PTX,
        "f32_to_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_f16_kernel",
    })?;

    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;
    let mut out = stream.alloc_zeros::<u16>(len)?;

    // SAFETY: argument order mirrors the kernel's parameter list
    // (input ptr, output ptr, element count), and `out` is sized to hold
    // exactly `len` elements — presumably one per kernel thread; the launch
    // geometry comes from `launch_cfg(len)`.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut out)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8590
/// Non-CUDA stub for `gpu_f32_to_f16`; always returns [`GpuError::NoCudaFeature`].
///
/// NOTE(review): the return type here is `GpuResult<()>` while the CUDA
/// variant returns `GpuResult<cudarc::driver::CudaSlice<u16>>` — the cudarc
/// type cannot be named without the `cuda` feature. Confirm all callers are
/// themselves `cfg`-gated so the mismatch never bites.
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
8596
/// Converts an f32 device buffer to raw bfloat16 bits (`u16`) on the GPU.
///
/// Returns a freshly allocated device slice of `input.len()` bf16 values
/// encoded as `u16`.
///
/// # Errors
/// Fails if device allocation fails, if `launch_cfg` rejects the element
/// count, or if the conversion PTX cannot be compiled
/// ([`GpuError::PtxCompileFailed`]).
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_bf16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let len = input.len();

    // Degenerate case: nothing to convert, hand back an empty allocation.
    if len == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }

    let stream = device.stream();

    // Fetch the cached (or freshly compiled) conversion kernel for this device.
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_BF16_PTX,
        "f32_to_bf16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_bf16_kernel",
    })?;

    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;
    let mut out = stream.alloc_zeros::<u16>(len)?;

    // SAFETY: argument order mirrors the kernel's parameter list
    // (input ptr, output ptr, element count), and `out` is sized to hold
    // exactly `len` elements — presumably one per kernel thread; the launch
    // geometry comes from `launch_cfg(len)`.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut out)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
8642
/// Non-CUDA stub for `gpu_f32_to_bf16`; always returns [`GpuError::NoCudaFeature`].
///
/// NOTE(review): the return type here is `GpuResult<()>` while the CUDA
/// variant returns `GpuResult<cudarc::driver::CudaSlice<u16>>` — the cudarc
/// type cannot be named without the `cuda` feature. Confirm all callers are
/// themselves `cfg`-gated so the mismatch never bites.
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
8648
// Integration-style tests for the element-wise and small-matmul kernels.
// They require a working CUDA device 0, hence the double cfg gate.
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    // Opens CUDA device 0 and uploads `data`, returning both handles.
    // Panics (failing the test) if no device is available.
    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
        (dev, buf)
    }

    // Downloads `buf` and asserts element-wise equality with `expected`
    // within an absolute tolerance of 1e-6 (also catches NaNs, since any
    // comparison with NaN is false and trips the assert).
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(host.len(), expected.len(), "length mismatch");
        for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }

    // --- gpu_add -----------------------------------------------------------

    #[test]
    fn add_basic() {
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn add_empty() {
        let (dev, a) = setup(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(out.len(), 0);
    }

    // Exercises multiple thread blocks (100k elements spans many launches of
    // the per-block thread count).
    #[test]
    fn add_large() {
        let n = 100_000;
        let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_buf_eq(&out, &dev, &expected);
    }

    // Mismatched input lengths must be rejected before any launch.
    #[test]
    fn add_length_mismatch() {
        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        let err = gpu_add(&a, &b, &dev).unwrap_err();
        match err {
            GpuError::LengthMismatch { a: 3, b: 2 } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    // --- gpu_sub -----------------------------------------------------------

    #[test]
    fn sub_basic() {
        let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
        let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn sub_negative_result() {
        let a_data = vec![1.0f32, 2.0];
        let b_data = vec![5.0f32, 10.0];
        let expected: Vec<f32> = vec![-4.0, -8.0];

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&out, &dev, &expected);
    }

    // --- gpu_mul -----------------------------------------------------------

    #[test]
    fn mul_basic() {
        let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
        let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
        let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn mul_by_zero() {
        let a_data = vec![1.0f32, 2.0, 3.0];
        let b_data = vec![0.0f32, 0.0, 0.0];
        let expected = vec![0.0f32, 0.0, 0.0];

        let (dev, a) = setup(&a_data);
        let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
        let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&out, &dev, &expected);
    }

    // --- gpu_neg -----------------------------------------------------------

    #[test]
    fn neg_basic() {
        let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
        let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();

        let (dev, a) = setup(&a_data);
        let out = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_buf_eq(&out, &dev, &expected);
    }

    // Negating twice must round-trip exactly (sign-bit flip is lossless).
    #[test]
    fn neg_double_negation() {
        let a_data = vec![1.0f32, -2.0, 3.0];
        let (dev, a) = setup(&a_data);
        let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
        assert_buf_eq(&neg2, &dev, &a_data);
    }

    // --- gpu_relu ----------------------------------------------------------

    #[test]
    fn relu_basic() {
        let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
        let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];

        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn relu_all_negative() {
        let a_data = vec![-5.0f32, -0.1, -100.0];
        let expected = vec![0.0f32, 0.0, 0.0];

        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &expected);
    }

    #[test]
    fn relu_all_positive() {
        let a_data = vec![0.1f32, 1.0, 100.0];

        let (dev, a) = setup(&a_data);
        let out = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&out, &dev, &a_data);
    }

    #[test]
    fn relu_empty() {
        let (dev, a) = setup(&[]);
        let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(out.len(), 0);
    }

    // --- gpu_small_matmul --------------------------------------------------

    // [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]] (row-major).
    #[test]
    fn small_matmul_2x2() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
        let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
        assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
    }

    // Row vector (1x3) times 3x2 matrix -> 1x2 result.
    #[test]
    fn small_matmul_1xk_kxn() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
        let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
        assert_buf_eq(&c, &dev, &[4.0, 5.0]);
    }

    // Cross-checks the hand-rolled kernel against cuBLAS on a 1x64 * 64x64
    // product. Tolerance is loose (0.1) because the two paths may accumulate
    // in different orders.
    #[test]
    fn small_matmul_vs_cublas() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let m = 1;
        let k = 64;
        let n = 64;

        // Deterministic pseudo-random inputs in [0, 1).
        let a_data: Vec<f32> = (0..m * k)
            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
            .collect();
        let b_data: Vec<f32> = (0..k * n)
            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
            .collect();

        let a = cpu_to_gpu(&a_data, &dev).unwrap();
        let b = cpu_to_gpu(&b_data, &dev).unwrap();

        let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
        let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();

        let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
        let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();

        assert_eq!(cublas_result.len(), our_result.len());
        for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
            assert!(
                (cb - ours).abs() < 0.1,
                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
                (cb - ours).abs()
            );
        }
    }
}