1#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, cpu_to_gpu, gpu_to_cpu};
30
/// PTX kernel: element-wise f32 addition, `out[i] = a[i] + b[i]`.
///
/// Scalar (non-vectorized) variant: one thread per element, global thread
/// index computed as `ctaid.x * ntid.x + tid.x`; threads with index >= `n`
/// branch straight to `DONE`, so any grid over-launch is safe.
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    add.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
82
/// PTX kernel: vectorized f32 addition, 4 elements per thread via
/// `ld/st.global.v4.f32`. `n4` is the element count divided by 4; each
/// thread handles the 16-byte group at `tid * 16`.
///
/// NOTE(review): v4 global accesses require 16-byte-aligned addresses per
/// the PTX ISA — presumably the launcher only selects this kernel when the
/// buffers are aligned and the length is a multiple of 4; confirm at callers.
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    // Byte offset = tid * 16 (4 floats × 4 bytes)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    add.f32 %r0, %a0, %b0;
    add.f32 %r1, %a1, %b1;
    add.f32 %r2, %a2, %b2;
    add.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
139
/// PTX kernel: vectorized element-wise f32 multiply, 4 elements per thread
/// (`v4.f32` loads/stores). Same indexing scheme as [`ADD_VEC4_PTX`]:
/// `n4` = element count / 4, byte offset `tid * 16`.
///
/// NOTE(review): as with the other v4 kernels, assumes 16-byte-aligned
/// buffers and a length that is a multiple of 4 — confirm at the launcher.
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    mul.f32 %r0, %a0, %b0;
    mul.f32 %r1, %a1, %b1;
    mul.f32 %r2, %a2, %b2;
    mul.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
192
/// PTX kernel: element-wise f32 subtraction, `out[i] = a[i] - b[i]`.
/// One thread per element; out-of-range threads exit via the `%p` guard.
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    sub.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
240
/// PTX kernel: scalar element-wise f32 multiply, `out[i] = a[i] * b[i]`.
/// One thread per element; identical structure to [`ADD_PTX`] with `mul.f32`
/// as the arithmetic op.
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    mul.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
288
/// PTX kernel: element-wise f32 negation, `out[i] = -a[i]` (unary: one
/// input buffer, one output buffer). One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    neg.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
332
/// PTX kernel: element-wise ReLU, `out[i] = max(a[i], 0.0)`.
/// One thread per element; `0f00000000` is the f32 literal 0.0.
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %zero;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mov.f32 %zero, 0f00000000;
    max.f32 %vr, %va, %zero;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
377
/// PTX kernel: element-wise scale by a scalar, `out[i] = a[i] * scalar`.
/// The scalar is passed by value as an f32 kernel parameter; one thread
/// per element.
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %s;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %s, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mul.f32 %vr, %va, %s;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
423
/// PTX kernel: 2-D transpose of an `[M, N]` f32 matrix into an `[N, M]`
/// output. One thread per *output* element: the global thread index is
/// decomposed as `tid = out_row * M + out_col`, and the gathered input
/// element is `in[out_col * N + out_row]`.
///
/// Note: this string uses explicit `\n` escapes with line continuations
/// rather than a raw multi-line literal; the resulting PTX text is the same
/// shape as the other kernels.
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 M,\n\
    .param .u32 N,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
    .reg .u32 %out_row, %out_col, %in_idx;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %M_reg, [M];\n\
    ld.param.u32 %N_reg, [N];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
    div.u32 %out_row, %r_tid, %M_reg;\n\
    rem.u32 %out_col, %r_tid, %M_reg;\n\
    // Input index: out_col * N + out_row (transposed).\n\
    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
479
/// PTX kernel: 4-D axis permutation (0, 2, 1, 3) of an f32 tensor with
/// input shape `[d0, d1, d2, d3]` and output shape `[d0, d2, d1, d3]`.
///
/// One thread per output element: the thread index is decomposed into
/// output coordinates `(i0, i2, i1, i3)` using the output strides
/// `s_out1 = d2*d1*d3` and `s_out2 = d1*d3`, then the contiguous input
/// index `((i0*d1 + i1) * d2 + i2) * d3 + i3` is gathered.
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 d0,\n\
    .param .u32 d1,\n\
    .param .u32 d2,\n\
    .param .u32 d3,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %d0r, [d0];\n\
    ld.param.u32 %d1r, [d1];\n\
    ld.param.u32 %d2r, [d2];\n\
    ld.param.u32 %d3r, [d3];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape: [d0, d2, d1, d3]\n\
    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
    div.u32 %i0, %r_tid, %s_out1;\n\
    rem.u32 %rem, %r_tid, %s_out1;\n\
    div.u32 %i2, %rem, %s_out2;\n\
    rem.u32 %rem, %rem, %s_out2;\n\
    div.u32 %i1, %rem, %d3r;\n\
    rem.u32 %i3, %rem, %d3r;\n\
\n\
    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
    mul.lo.u32 %in_idx, %i0, %d1r;\n\
    add.u32 %in_idx, %in_idx, %i1;\n\
    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
    add.u32 %in_idx, %in_idx, %i3;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
560
/// PTX kernel: convert an f32 buffer to packed f16 (`half`) values using
/// hardware `cvt.rn.f16.f32` (round-to-nearest-even). One thread per
/// element; input offsets step by 4 bytes, output offsets by 2.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .b16 %vh;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Compute input offset: i * 4 (f32 = 4 bytes)
    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    // Compute output offset: i * 2 (f16 = 2 bytes)
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32, convert to f16 (round-to-nearest-even), store as u16
    ld.global.f32 %vf, [%in];
    cvt.rn.f16.f32 %vh, %vf;
    st.global.b16 [%out], %vh;

DONE:
    ret;
}
";
615
/// PTX kernel: convert an f32 buffer to packed bfloat16 by integer bit
/// manipulation: round-to-nearest-even is implemented as
/// `(bits + 0x7FFF + bit16(bits)) >> 16`, then the low 16 bits are stored.
/// sm_52 has no native bf16 `cvt`, hence the manual rounding.
///
/// NOTE(review): for NaN inputs with a small mantissa payload the rounding
/// add can carry into the exponent and yield an Inf bit pattern — confirm
/// whether NaN inputs are possible/acceptable for callers of this kernel.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .u32 %bits, %round, %lsb, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32 as raw bits
    ld.global.u32 %bits, [%in];

    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
    shr.u32 %lsb, %bits, 16;
    and.b32 %lsb, %lsb, 1;
    add.u32 %round, %bits, 0x7FFF;
    add.u32 %round, %round, %lsb;
    shr.u32 %result, %round, 16;

    // Store as u16
    st.global.u16 [%out], %result;

DONE:
    ret;
}
";
676
/// PTX kernel: naive row-major matmul `C[M,N] = A[M,K] * B[K,N]` in f32.
/// One thread per output element (`total = M * N`): each thread runs a
/// serial `fma.rn.f32` loop over the K dimension. No tiling or shared
/// memory — intended for small matrices, as the name suggests.
///
/// Register note: `%p` here is a plain u32 loop counter (the K index), not
/// a predicate; the predicates are `%bounds_p` and `%loop_p`.
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 K,
    .param .u32 N,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
    .reg .u32 %row, %col, %p, %idx;
    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
    .reg .f32 %sum, %va, %vb;
    .reg .pred %bounds_p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %c, [c_ptr];
    ld.param.u32 %M_reg, [M];
    ld.param.u32 %K_reg, [K];
    ld.param.u32 %N_reg, [N];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    div.u32 %row, %r_tid, %N_reg;
    rem.u32 %col, %r_tid, %N_reg;

    mov.f32 %sum, 0f00000000;
    mov.u32 %p, 0;
DOT:
    setp.ge.u32 %loop_p, %p, %K_reg;
    @%loop_p bra DOT_DONE;

    mad.lo.u32 %idx, %row, %K_reg, %p;
    cvt.u64.u32 %a_off, %idx;
    shl.b64 %a_off, %a_off, 2;
    add.u64 %a_off, %a, %a_off;
    ld.global.f32 %va, [%a_off];

    mad.lo.u32 %idx, %p, %N_reg, %col;
    cvt.u64.u32 %b_off, %idx;
    shl.b64 %b_off, %b_off, 2;
    add.u64 %b_off, %b, %b_off;
    ld.global.f32 %vb, [%b_off];

    fma.rn.f32 %sum, %va, %vb, %sum;
    add.u32 %p, %p, 1;
    bra DOT;
DOT_DONE:

    cvt.u64.u32 %c_off, %r_tid;
    shl.b64 %c_off, %c_off, 2;
    add.u64 %c_off, %c, %c_off;
    st.global.f32 [%c_off], %sum;

DONE:
    ret;
}
";
755
/// PTX kernel: scatter a `[B, D]` f32 source into row `pos` of a
/// `[B, max_len, D]` destination: `dst[b, pos, d] = src[b, d]` with
/// `n = B * D` threads. `pos` is passed by value as a u32 parameter
/// (contrast with [`SLICE_WRITE_INDIRECT_PTX`], which reads it from
/// device memory). Typical use looks like a KV-cache-style row write —
/// presumably; verify against callers.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u32 pos
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u32 %pos_reg, [pos];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
815
/// PTX kernel: identical scatter to [`SLICE_WRITE_PTX`]
/// (`dst[b, pos, d] = src[b, d]`), except the row index `pos` is read as a
/// u32 from device memory (`pos_ptr`) rather than passed by value —
/// allowing the position to be produced by a prior kernel without a
/// host round-trip.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u64 pos_ptr
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u64 %pos_p, [pos_ptr];
    ld.global.u32 %pos_reg, [%pos_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
876
/// PTX kernel: fill an attention-mask buffer using a sequence length read
/// from device memory. For each output element, the column index is
/// `tid % max_pos`; columns `< total_len` get 0.0 and columns beyond it get
/// -1.0e9 (a large negative value so masked logits vanish after softmax).
/// `total_len` is loaded from `total_len_ptr`, letting a prior kernel set
/// it without a host round-trip.
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
    .param .u64 total_len_ptr,
    .param .u64 out_ptr,
    .param .u32 max_pos,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
    .reg .u64 %out, %off, %tl_p;
    .reg .f32 %val;
    .reg .pred %bounds_p, %mask_p;

    ld.param.u64 %tl_p, [total_len_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %max_pos_reg, [max_pos];
    ld.param.u32 %total_reg, [total];

    ld.global.u32 %tlen, [%tl_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    rem.u32 %col, %r_tid, %max_pos_reg;
    setp.lt.u32 %mask_p, %col, %tlen;
    @%mask_p bra WRITE_ZERO;

    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
    // to effectively zero out masked positions after softmax.
    mov.f32 %val, 0fCE6E6B28;
    bra WRITE;

WRITE_ZERO:
    mov.f32 %val, 0f00000000;

WRITE:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %val;

DONE:
    ret;
}
";
936
/// PTX kernel: single-token embedding lookup. Reads one index (stored as
/// an f32 at `idx_ptr`, truncated to u32 via `cvt.rzi`) and copies the
/// corresponding row of the `[vocab, D]` weight matrix into `out[0..D]`.
/// One thread per output dimension (`D` threads total).
///
/// NOTE(review): the index is not bounds-checked against the vocabulary
/// size — an out-of-range index reads out of bounds; confirm callers
/// validate indices.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %D_reg;
    @%p bra DONE;

    ld.global.f32 %idx_f, [%idx_addr];
    cvt.rzi.u32.f32 %row, %idx_f;

    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
989
990#[cfg(feature = "cuda")]
/// PTX kernel: batched embedding lookup. For `total = num_tokens * D`
/// threads, each thread computes `row = tid / D`, `col = tid % D`, reads
/// `indices[row]` (stored as f32, truncated to u32 via `cvt.rzi`) and
/// copies `weight[indices[row] * D + col]` to `out[tid]`.
///
/// Fix: the thread-index register was previously declared as `%tid`, which
/// collides with the PTX *predefined* special register `%tid` (read below
/// as `%tid.x`). Renamed to `%r_tid`, matching every other kernel in this
/// file. The rename is semantically neutral.
///
/// NOTE(review): indices are not bounds-checked against the vocabulary
/// size — confirm callers validate them.
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %r_tid, %D_reg;
    rem.u32 %col, %r_tid, %D_reg;

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %src_idx, %idx_f;

    // src_idx = indices[row] * D + col
    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    // Write to out[tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1058
1059#[cfg(feature = "cuda")]
/// PTX kernel: embedding-backward scatter-add. For `total = num_tokens * D`
/// threads, each thread reads `grad_output[tid]`, resolves the target row
/// from `indices[tid / D]` (f32, truncated to u32), and atomically adds the
/// gradient into `grad_weight[indices[row] * D + col]`. The atomic is
/// required because multiple tokens may share the same index.
///
/// Fix: the thread-index register was previously declared as `%tid`, which
/// collides with the PTX *predefined* special register `%tid` (read below
/// as `%tid.x`). Renamed to `%r_tid`, matching every other kernel in this
/// file. The rename is semantically neutral.
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_weight_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %dst_idx;
    .reg .u64 %go, %idx_addr, %gw, %off;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %idx_addr, [indices_ptr];
    ld.param.u64 %gw, [grad_weight_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %r_tid, %D_reg;
    rem.u32 %col, %r_tid, %D_reg;

    // Read grad_output[tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %go, %off;
    ld.global.f32 %grad_val, [%off];

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %dst_idx, %idx_f;

    // dst_idx = indices[row] * D + col
    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %gw, %off;
    atom.global.add.f32 %dummy, [%off], %grad_val;

DONE:
    ret;
}
";
1127
/// PTX kernel: gather the first `len` rows of each batch from a padded
/// `[B, max_len, D]` f32 source into a contiguous `[B, len, D]` output
/// (`total = B * len * D` threads). The destination index is the thread id;
/// the source index re-bases each batch onto the `max_len` stride:
/// `src_idx = (batch * max_len + row) * D + col`.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 total,
    .param .u32 D,
    .param .u32 len,
    .param .u32 max_len
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
    .reg .u32 %len_d;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %total_reg, [total];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %len_reg, [len];
    ld.param.u32 %max_len_reg, [max_len];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // dst index = r_tid
    // batch_idx = r_tid / (len * D)
    // within = r_tid % (len * D)
    // row = within / D
    // col = within % D
    // src_idx = batch_idx * max_len * D + row * D + col
    mul.lo.u32 %len_d, %len_reg, %D_reg;
    div.u32 %batch_idx, %r_tid, %len_d;
    rem.u32 %within, %r_tid, %len_d;
    div.u32 %row, %within, %D_reg;
    rem.u32 %col, %within, %D_reg;

    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
    add.u32 %src_idx, %src_idx, %row;
    mul.lo.u32 %src_idx, %src_idx, %D_reg;
    add.u32 %src_idx, %src_idx, %col;

    cvt.u64.u32 %src_off, %src_idx;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src_off, %src, %src_off;
    ld.global.f32 %val, [%src_off];

    cvt.u64.u32 %dst_off, %r_tid;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst_off, %dst, %dst_off;
    st.global.f32 [%dst_off], %val;

DONE:
    ret;
}
";
1201
/// PTX kernel: GELU activation via the sigmoid approximation
/// `gelu(x) ≈ x * sigmoid(k * x)`, computing the sigmoid as
/// `1 / (1 + 2^(-k*x*log2(e)))` with the fast `ex2.approx` / `rcp.approx`
/// instructions. `0f3FDA2720` is the f32 constant k ≈ 1.704 and
/// `0f3FB8AA3B` is log2(e). Approximate — not bit-identical to a CPU GELU.
#[cfg(feature = "cuda")]
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    mov.f32 %k, 0f3FDA2720;
    mul.f32 %neg_kx, %k, %x;
    neg.f32 %neg_kx, %neg_kx;
    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;
    mul.f32 %result, %x, %sig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1261
/// PTX kernel: GELU activation via the tanh approximation
/// `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 x³)))`, with tanh
/// evaluated as `(e^{2y} - 1) / (e^{2y} + 1)` using `ex2.approx` and
/// `rcp.approx` (exp(2y) = 2^(2y·log2 e)).
///
/// NOTE(review): for large positive `y`, `e^{2y}` overflows to +Inf and
/// the ratio becomes Inf/Inf = NaN under this formulation — presumably
/// inputs stay in a moderate range; confirm if extreme activations are
/// possible.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // inner = sqrt(2/π) * (x + 0.044715 * x³)
    // sqrt(2/π) = 0.7978845608 = 0x3F4C422A
    // 0.044715 = 0x3D372713
    mul.f32 %x3, %x, %x;
    mul.f32 %x3, %x3, %x;
    mov.f32 %c, 0f3D372713;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
    // exp(2y) = 2^(2y * log2(e))
    mov.f32 %log2e, 0f3FB8AA3B;
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    mov.f32 %one, 0f3F800000;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f32 %th, %one, %th;
    mov.f32 %half, 0f3F000000;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %th;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1337
/// PTX kernel: GELU activation via the erf definition
/// `0.5 * x * (1 + erf(x / sqrt(2)))`, with erf evaluated by the
/// Abramowitz & Stegun 7.1.26-style rational polynomial:
/// `erf(|z|) ≈ 1 - poly(t) * exp(-z²)` where `t = 1 / (1 + p·|z|)`,
/// and the odd symmetry `erf(-z) = -erf(z)` restored by a predicated
/// negate. Uses `ex2.approx` / `rcp.approx`, so the result is approximate.
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %z, %ax, %one, %half, %log2e;
    .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mov.f32 %a5, 0f3E0AAAAB;
    mov.f32 %a4, 0fBEB3A903;
    mov.f32 %a3, 0f3FB506DD;
    mov.f32 %a2, 0fBF03C1E1;
    mov.f32 %a1, 0f3EA0D6BB;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // out = x * 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %erf_val, %one, %erf_val;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %erf_val;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1436
/// PTX kernel: backward pass of the tanh-approximated GELU.
///
/// Computes `out[i] = grad[i] * d/dx gelu_tanh(x)` where
/// `gelu_tanh(x) = 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))`, so
/// `d/dx = 0.5*(1+tanh(u)) + 0.5*x*(1-tanh(u)^2)*sqrt(2/pi)*(1+3*0.044715*x^2)`.
/// One thread per element; `tanh` is built from `ex2.approx`/`rcp.approx`
/// via `tanh(y) = (e^{2y}-1)/(e^{2y}+1)`, so results are approximate
/// (fast-math precision, not IEEE exact).
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mov.f32 %c, 0f3D372713;
    // 3 * 0.044715 = 0.134145 = 0x3E096B8C
    mov.f32 %c3, 0f3E096B8C;

    // u = sqrt(2/π) * (x + 0.044715 * x³)
    mul.f32 %x2, %x, %x;
    mul.f32 %x3, %x2, %x;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) via exp
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh²)*sqrt(2/π)*(1+3*0.044715*x²)
    // term1 = 0.5 * (1 + th)
    add.f32 %term1, %one, %th;
    mul.f32 %term1, %half, %term1;

    // (1 - th²)
    mul.f32 %th2, %th, %th;
    sub.f32 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/π) * (1 + 3*0.044715*x²)
    mul.f32 %d_inner, %c3, %x2;
    add.f32 %d_inner, %one, %d_inner;
    mul.f32 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th²) * d_inner
    mul.f32 %term2, %half, %x;
    mul.f32 %term2, %term2, %one_m_th2;
    mul.f32 %term2, %term2, %d_inner;

    add.f32 %d_gelu, %term1, %term2;
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1534
/// PTX kernel: ReLU backward, `out[i] = input[i] > 0 ? grad[i] : 0`.
///
/// One thread per element; the `> 0` comparison means the gradient at exactly
/// `x == 0` is 0 (the conventional subgradient choice). Branch-free element
/// body via `setp`/`selp`.
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %zero, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
1589
/// PTX kernel: backward pass of the sigmoid-approximated GELU
/// (`gelu(x) ≈ x * sigmoid(k*x)`), one thread per element:
/// `out[i] = grad[i] * (sig + k*x*sig*(1-sig))` with `sig = sigmoid(k*x)`.
///
/// `sigmoid` is built from `ex2.approx`/`rcp.approx`, so results are
/// approximate (fast-math precision).
///
/// NOTE(review): the inline comment says k = 1.702, but the hex literal
/// 0f3FDA2720 decodes to ≈1.70432, not 1.702 (= 0f3FD9D823). If the forward
/// sigmoid-GELU kernel uses a different k, forward and backward disagree —
/// verify against the forward kernel's constant before changing either.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(1.702 * x)
    mov.f32 %k, 0f3FDA2720;
    mul.f32 %kx, %k, %x;
    neg.f32 %neg_kx, %kx;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_kx, %neg_kx, %log2e;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;

    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
    sub.f32 %one_minus_sig, %one, %sig;
    mul.f32 %kx_sig_oms, %kx, %sig;
    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f32 %dsig, %sig, %kx_sig_oms;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %dsig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1665
/// PTX kernel: backward pass of the exact-erf GELU, one thread per element:
/// `out[i] = grad[i] * (Φ(x) + x*φ(x))` where Φ is the standard normal CDF
/// `0.5*(1 + erf(x/sqrt(2)))` and φ the PDF `exp(-x²/2)/sqrt(2π)`.
///
/// erf uses the same rational-polynomial scheme as `GELU_ERF_PTX`
/// (`1 - poly(t)*exp(-z²)` with `t = 1/(1 + p*|z|)`, sign-corrected), built
/// on `ex2.approx`/`rcp.approx` — approximate, not IEEE exact.
///
/// NOTE(review): the a1..a5 hex constants match `GELU_ERF_PTX` (internally
/// consistent) but differ from the published Abramowitz–Stegun 7.1.26
/// coefficients — presumably a custom refit; confirm before relying on
/// tail-region accuracy.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f32 %d_gelu, %result;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mov.f32 %a5, 0f3E0AAAAB;
    mov.f32 %a4, 0fBEB3A903;
    mov.f32 %a3, 0f3FB506DD;
    mov.f32 %a2, 0fBF03C1E1;
    mov.f32 %a1, 0f3EA0D6BB;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // Φ(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %cdf, %one, %erf_val;
    mul.f32 %cdf, %half, %cdf;

    // φ(x) = exp(-x²/2) / sqrt(2π)
    // exp(-x²/2):
    mul.f32 %neg_x2h, %x, %x;
    mul.f32 %neg_x2h, %neg_x2h, %half;
    neg.f32 %neg_x2h, %neg_x2h;
    mul.f32 %neg_x2h, %neg_x2h, %log2e;
    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;

    // 1/sqrt(2π) = 0.39894228
    mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Φ(x) + x * φ(x)
    mul.f32 %x_pdf, %x, %pdf;
    add.f32 %d_gelu, %cdf, %x_pdf;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1793
/// PTX kernel: 1-D gather, `out[i] = input[indices[i]]` for `i < n_indices`.
///
/// Indices are stored as f32 and truncated to u32 with `cvt.rzi.u32.f32`
/// (round-toward-zero), so fractional or negative index values are silently
/// mangled. There is NO bounds check on the gathered index — an out-of-range
/// value reads out-of-bounds global memory; callers must validate indices.
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
1853
/// PTX kernel: 1-D scatter-add (embedding backward),
/// `grad_input[indices[i]] += grad_output[i]` via `atom.global.add.f32`.
///
/// Duplicate indices are handled correctly by the atomic, but the summation
/// order is unordered, so results for duplicates are not bit-reproducible
/// across runs (f32 addition is non-associative). Indices are f32 truncated
/// with `cvt.rzi.u32.f32` and are NOT bounds-checked — callers must validate.
/// `grad_input` is assumed pre-zeroed by the caller (this kernel only adds).
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_input_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %go, %indices, %gi, %off, %addr;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %gi, [grad_input_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read grad_output[tid]
    add.u64 %addr, %go, %off;
    ld.global.f32 %grad_val, [%addr];

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Atomic add: grad_input[idx] += grad_val
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %gi, %addr;
    atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
    ret;
}
";
1915
/// PTX kernel: masked fill, `out[i] = mask[i] >= 0.5 ? fill_value : input[i]`.
///
/// The mask is an f32 buffer thresholded at 0.5, so 0.0/1.0 masks behave as
/// booleans while being robust to slight numeric noise. One thread per
/// element; branch-free selection via `setp`/`selp`.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
    .param .u64 input_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .f32 fill_value,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %input, %mask, %out, %off;
    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %fill, [fill_value];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %input, %input, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %in_val, [%input];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %fill, %in_val, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1972
/// PTX kernel: masked-fill backward, `out[i] = mask[i] >= 0.5 ? 0 : grad[i]`.
///
/// Mirror of `MASKED_FILL_PTX`: positions where the (f32, 0.5-thresholded)
/// mask was set received the fill constant in the forward pass, so their
/// gradient is zeroed; all other gradients pass through unchanged.
#[cfg(feature = "cuda")]
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_zero_kernel(
    .param .u64 grad_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %mask, %out, %off;
    .reg .f32 %vg, %mask_val, %zero, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %zero, 0f00000000;
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %zero, %vg, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2028
/// PTX kernel: sigmoid backward, `out[i] = grad[i] * o * (1 - o)` where `o`
/// is the saved FORWARD OUTPUT (not the pre-activation input) — the standard
/// output-based form of the sigmoid derivative. One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    sub.f32 %one_minus_o, %one, %vo;
    mul.f32 %result, %vo, %one_minus_o;
    mul.f32 %result, %vg, %result;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2082
/// PTX kernel: tanh backward, `out[i] = grad[i] * (1 - o²)` where `o` is the
/// saved FORWARD OUTPUT — the output-based form of the tanh derivative.
/// One thread per element.
#[cfg(feature = "cuda")]
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    mul.f32 %o_sq, %vo, %vo;
    sub.f32 %one_minus_sq, %one, %o_sq;
    mul.f32 %result, %vg, %one_minus_sq;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2136
/// PTX kernel: row-wise softmax backward,
/// `out[r][j] = output[r][j] * (grad[r][j] - dot)` with
/// `dot = Σ_j grad[r][j] * output[r][j]` (`output` is the saved softmax).
///
/// Launch contract: ONE BLOCK PER ROW (`ctaid.x` indexes the row); threads
/// cooperate across columns in a stride-`ntid.x` loop, then tree-reduce the
/// per-thread partial dots through the 256-slot shared array. The halving
/// reduction (`%half = ntid.x`, then `shr` until 0) only covers every thread
/// when blockDim.x is a power of two and ≤ 256 — assumed launch config,
/// TODO confirm at the launch site.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
    mov.f32 %dot, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra DOT_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
    add.u32 %j, %j, %bdim;\n\
    bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
    // Store partial dot into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %dot;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra DOT_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra DOT_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %dot, [%saddr];\n\
    add.f32 %dot, %dot, %other_val;\n\
    st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
    // Broadcast dot to all threads\n\
    ld.shared.f32 %dot, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    sub.f32 %diff, %vg, %dot;\n\
    mul.f32 %result, %vo, %diff;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
2264
/// PTX kernel: partial sum reduction. Each block writes one f32 partial sum
/// to `out[ctaid.x]`; the caller must finish the reduction (re-launch or sum
/// the partials on the host).
///
/// Phase 1 is a grid-stride loop (`stride = ntid.x * nctaid.x`), so any grid
/// size covers any `n`. Phase 2 is a shared-memory tree reduction whose first
/// step is hard-coded to 128 partners — this assumes EXACTLY 256 threads per
/// block (matching `sdata[256]`); fewer threads would read stale/unwritten
/// shared slots. TODO confirm the launch site always uses blockDim.x == 256.
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;

    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;

GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;

GRID_DONE:
    // Write thread's partial sum to shared memory.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    st.shared.f32 [sdata + %off], %sum;
    bar.sync 0;

    // Tree reduction in shared memory.
    mov.u32 %half, 128;
TREE_LOOP:
    setp.lt.u32 %p, %half, 1;
    @%p bra TREE_DONE;

    setp.ge.u32 %ptid, %tid, %half;
    @%ptid bra TREE_SKIP;

    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    ld.shared.f32 %other, [sdata + %off];
    // Load own value.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    ld.shared.f32 %sum, [sdata + %off];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [sdata + %off], %sum;

TREE_SKIP:
    bar.sync 0;
    shr.u32 %half, %half, 1;
    bra TREE_LOOP;

TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %tid, 0;
    @%ptid bra END;

    ld.shared.f32 %sum, [sdata];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;

END:
    ret;
}
";
2372
/// PTX kernel: sum over one axis of a tensor viewed as
/// `(outer_size, axis_size, inner_size)`, producing `(outer_size, inner_size)`.
///
/// One thread per OUTPUT element (`total_output = outer_size * inner_size`):
/// each thread decomposes its id into `(outer_idx, inner_idx)` via
/// `div`/`rem`, then serially accumulates the `axis_size` strided input
/// elements. Simple but serial in the reduced axis — fine for small
/// `axis_size`, slow for long reductions. Assumes `inner_size > 0`
/// (u32 div/rem by zero otherwise).
#[cfg(feature = "cuda")]
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;

    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;

    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;

    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:

    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
";
2450
/// PTX kernel: LayerNorm forward over the last axis of a `(rows, cols)` view:
/// `out[r][j] = w[j] * (x[r][j] - mean_r) / sqrt(var_r + eps) + b[j]`.
///
/// Launch contract: ONE BLOCK PER ROW. Three cooperative phases, each a
/// stride-`ntid.x` column loop followed by (for the first two) a shared-memory
/// tree reduction:
///   1. mean  (labels SM/SMD, reduce MR/MRS/MRD),
///   2. biased variance, divided by `cols` not `cols-1` (SV.., reduce VR..),
///   3. normalize + affine (NM..), `1/std` built from `sqrt.approx` +
///      `rcp.approx` so results are fast-math approximate.
/// The halving reduction (`%half = ntid.x`, `shr` until 0) together with
/// `sdata[256]` assumes a power-of-two blockDim.x ≤ 256 — TODO confirm at
/// the launch site. `w`/`b` are indexed by column only (shared across rows).
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u64 b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra SM;
SMD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
MRS:
    bar.sync 0;
    bra MR;
MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra SV;
SVD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
VRS:
    bar.sync 0;
    bra VR;
VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %normed, %val, %mean;
    mul.f32 %normed, %normed, %inv_std;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %b, %off;
    ld.global.f32 %bv, [%off];
    fma.rn.f32 %result, %wv, %normed, %bv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
2634
2635#[cfg(feature = "cuda")]
/// PTX for a row-parallel LayerNorm backward kernel (`layernorm_backward_kernel`).
///
/// One thread block handles one row (`%ctaid.x` = row); threads stride over the
/// columns. The kernel recomputes mean and variance of the input row (two
/// shared-memory tree reductions), then reduces `sum1 = Σ dL/dx_hat` and
/// `sum2 = Σ (dL/dx_hat · x_hat)` while accumulating `grad_weight[j] += go·x_hat`
/// and `grad_bias[j] += go` across rows via `atom.global.add.f32`, and finally
/// writes `grad_input = inv_std · (dL/dx_hat − mean1 − x_hat · mean2)`.
///
/// NOTE(review): the tree reductions repeatedly halve `%ntid.x`, and partials
/// are staged in `sdata[256]`, so the launch must use a power-of-two block size
/// of at most 256 — confirm against the launch site. `grad_w_ptr`/`grad_b_ptr`
/// buffers are accumulated into, so the caller presumably zero-initializes
/// them — verify.
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u64 grad_b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u64 %gb, [grad_b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra LNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute mean =====
    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SM;
LNB_SMD:
    // Shared memory reduce for mean
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    st.shared.f32 [%saddr], %mean;
LNB_MRS:
    bar.sync 0;
    bra LNB_MR;
LNB_MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    // ===== Phase 2: Compute variance =====
    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SV;
LNB_SVD:
    // Shared memory reduce for variance
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    st.shared.f32 [%saddr], %var;
LNB_VRS:
    bar.sync 0;
    bra LNB_VR;
LNB_VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
    // Also accumulate grad_weight and grad_bias via atomicAdd
    mov.f32 %sum1, 0f00000000;
    mov.f32 %sum2, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_S12:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_S12D;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // x_hat = (val - mean) * inv_std
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dl_dx_hat = grad_output * weight
    mul.f32 %dl_dx_hat, %gov, %wv;
    // Accumulate sums
    add.f32 %sum1, %sum1, %dl_dx_hat;
    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
    // atomicAdd grad_weight[j] += grad_output * x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %x_hat;
    atom.global.add.f32 %result, [%addr], %result;
    // atomicAdd grad_bias[j] += grad_output
    add.u64 %addr, %gb, %off;
    atom.global.add.f32 %result, [%addr], %gov;
    add.u32 %j, %j, %r_bdim;
    bra LNB_S12;
LNB_S12D:
    // Reduce sum1 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum1;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R1:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R1D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R1S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum1, [%saddr];
    add.f32 %sum1, %sum1, %other_val;
    st.shared.f32 [%saddr], %sum1;
LNB_R1S:
    bar.sync 0;
    bra LNB_R1;
LNB_R1D:
    ld.shared.f32 %sum1, [%sbase];
    // mean1 = sum1 / n
    div.approx.f32 %mean1, %sum1, %n_f;
    bar.sync 0;

    // Reduce sum2 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum2;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R2:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R2D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R2S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum2, [%saddr];
    add.f32 %sum2, %sum2, %other_val;
    st.shared.f32 [%saddr], %sum2;
LNB_R2S:
    bar.sync 0;
    bra LNB_R2;
LNB_R2D:
    ld.shared.f32 %sum2, [%sbase];
    // mean2 = sum2 / n
    div.approx.f32 %mean2, %sum2, %n_f;
    bar.sync 0;

    // ===== Phase 4: Compute grad_input =====
    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
    mov.u32 %j, %r_tid;
LNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_GID;
    // Reload input to recompute x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Reload grad_output and weight to recompute dl_dx_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    mul.f32 %dl_dx_hat, %gov, %wv;
    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
    sub.f32 %result, %dl_dx_hat, %mean1;
    mul.f32 %diff, %x_hat, %mean2;
    sub.f32 %result, %result, %diff;
    mul.f32 %result, %inv_std, %result;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra LNB_GI;
LNB_GID:

LNB_DONE:
    ret;
}
";
2963
2964#[cfg(feature = "cuda")]
/// PTX for a per-channel BatchNorm2d forward kernel (`batchnorm_forward_kernel`).
///
/// One thread block handles one channel (`%ctaid.x` = channel) of an NCHW
/// tensor. Pass 1 is a block-stride loop over the `total_per_ch` elements of
/// that channel, accumulating sum and sum-of-squares, followed by a
/// shared-memory tree reduction. Thread 0 then derives the statistics:
///   * `training != 0`: `mean = sum/n`, `var = E[x^2] - E[x]^2` (biased), the
///     running buffers are updated as `r += momentum * (batch - r)`, and
///     mean/invstd are written to `save_mean`/`save_invstd` for the backward
///     pass;
///   * `training == 0`: mean and var are read from the running buffers
///     instead (batch statistics are discarded).
/// Pass 2 writes `out = gamma * (x - mean) * invstd + beta`.
///
/// Fixes over the previous revision: the pass-1 loop no longer clobbers its
/// own counter with the spatial remainder; the variance no longer has a stray
/// `fma` that cancelled the mean^2 correction; pass 2 (previously a
/// placeholder that never wrote output) is implemented; the running-stats
/// buffers and `training` flag are actually used; shared-memory addressing
/// uses the mov-base + add pattern used by the other kernels in this file.
///
/// NOTE(review): the tree reduction repeatedly halves `%ntid.x` and stages
/// partials in 256-element shared arrays, so the launch must use a
/// power-of-two block size of at most 256 — confirm against the launch site.
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for block reduction
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];

.visible .entry batchnorm_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 rmean_ptr,
    .param .u64 rvar_ptr,
    .param .u64 save_mean_ptr,
    .param .u64 save_invstd_ptr,
    .param .u32 channels,
    .param .u32 spatial,
    .param .f32 eps,
    .param .f32 momentum,
    .param .u32 total_per_ch,
    .param .u32 training
) {
    .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
    .reg .u32 %lin, %sp_i, %half;
    .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
    .reg .u64 %sum_base, %sq_base, %saddr;
    .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
    .reg .f32 %gamma, %beta, %eps_reg, %mom, %other;
    .reg .f32 %n_f, %one, %normalized;
    .reg .pred %p, %ptrain, %ptid0;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %b, [bias_ptr];
    ld.param.u64 %rm, [rmean_ptr];
    ld.param.u64 %rv, [rvar_ptr];
    ld.param.u64 %sm, [save_mean_ptr];
    ld.param.u64 %si, [save_invstd_ptr];
    ld.param.u32 %n_ch, [channels];
    ld.param.u32 %sp, [spatial];
    ld.param.f32 %eps_reg, [eps];
    ld.param.f32 %mom, [momentum];
    ld.param.u32 %tpc, [total_per_ch];
    ld.param.u32 %train, [training];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %ch, %bid;
    mov.f32 %one, 0f3F800000;
    mov.u64 %sum_base, smem_sum;
    mov.u64 %sq_base, smem_sq;

    setp.ge.u32 %p, %ch, %n_ch;
    @%p bra END;

    setp.ne.u32 %ptrain, %train, 0;

    // ---- Pass 1: per-thread partial sum and sum-of-squares ----
    mov.f32 %sum, 0f00000000;
    mov.f32 %sqsum, 0f00000000;

    mov.u32 %idx, %tid;
PASS1_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra PASS1_DONE;

    // linear = ((idx / spatial) * channels + ch) * spatial + idx % spatial
    div.u32 %lin, %idx, %sp;
    rem.u32 %sp_i, %idx, %sp;
    mul.lo.u32 %lin, %lin, %n_ch;
    add.u32 %lin, %lin, %ch;
    mul.lo.u32 %lin, %lin, %sp;
    add.u32 %lin, %lin, %sp_i;

    cvt.u64.u32 %off64, %lin;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %val, [%tmp64];
    add.f32 %sum, %sum, %val;
    fma.rn.f32 %sqsum, %val, %val, %sqsum;

    add.u32 %idx, %idx, %bdim;
    bra PASS1_LOOP;

PASS1_DONE:
    // Store partials to shared memory for the block reduction
    cvt.u64.u32 %off64, %tid;
    shl.b64 %off64, %off64, 2;
    add.u64 %saddr, %sum_base, %off64;
    st.shared.f32 [%saddr], %sum;
    add.u64 %saddr, %sq_base, %off64;
    st.shared.f32 [%saddr], %sqsum;
    bar.sync 0;

    // Tree reduction over the (power-of-two) block
    mov.u32 %half, %bdim;
REDUCE_LOOP:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %p, %half, 0;
    @%p bra REDUCE_DONE;
    setp.ge.u32 %p, %tid, %half;
    @%p bra REDUCE_SKIP;

    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    cvt.u64.u32 %tmp64, %tid;
    shl.b64 %tmp64, %tmp64, 2;

    add.u64 %saddr, %sum_base, %off64;
    ld.shared.f32 %other, [%saddr];
    add.u64 %saddr, %sum_base, %tmp64;
    ld.shared.f32 %sum, [%saddr];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%saddr], %sum;

    add.u64 %saddr, %sq_base, %off64;
    ld.shared.f32 %other, [%saddr];
    add.u64 %saddr, %sq_base, %tmp64;
    ld.shared.f32 %sqsum, [%saddr];
    add.f32 %sqsum, %sqsum, %other;
    st.shared.f32 [%saddr], %sqsum;

REDUCE_SKIP:
    bar.sync 0;
    bra REDUCE_LOOP;

REDUCE_DONE:
    // Thread 0 computes (training) or loads (inference) mean and invstd
    setp.ne.u32 %ptid0, %tid, 0;
    @%ptid0 bra WAIT_STATS;

    @!%ptrain bra EVAL_STATS;

    // Training: batch statistics
    ld.shared.f32 %sum, [%sum_base];
    ld.shared.f32 %sqsum, [%sq_base];
    cvt.rn.f32.u32 %n_f, %tpc;
    div.rn.f32 %mean, %sum, %n_f;
    // var = E[x^2] - E[x]^2 (biased)
    div.rn.f32 %var, %sqsum, %n_f;
    neg.f32 %other, %mean;
    fma.rn.f32 %var, %other, %mean, %var;

    // invstd = 1 / sqrt(var + eps)
    add.f32 %other, %var, %eps_reg;
    sqrt.rn.f32 %other, %other;
    div.rn.f32 %invstd, %one, %other;

    // running = running + momentum * (batch - running)
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %rm, %off64;
    ld.global.f32 %val, [%tmp64];
    sub.f32 %other, %mean, %val;
    fma.rn.f32 %val, %mom, %other, %val;
    st.global.f32 [%tmp64], %val;
    add.u64 %tmp64, %rv, %off64;
    ld.global.f32 %val, [%tmp64];
    sub.f32 %other, %var, %val;
    fma.rn.f32 %val, %mom, %other, %val;
    st.global.f32 [%tmp64], %val;
    bra SAVE_STATS;

EVAL_STATS:
    // Inference: use the running statistics instead of batch statistics
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %rm, %off64;
    ld.global.f32 %mean, [%tmp64];
    add.u64 %tmp64, %rv, %off64;
    ld.global.f32 %var, [%tmp64];
    add.f32 %other, %var, %eps_reg;
    sqrt.rn.f32 %other, %other;
    div.rn.f32 %invstd, %one, %other;

SAVE_STATS:
    // Save mean/invstd for the backward pass
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %sm, %off64;
    st.global.f32 [%tmp64], %mean;
    add.u64 %tmp64, %si, %off64;
    st.global.f32 [%tmp64], %invstd;

    // Broadcast to the rest of the block through shared memory
    st.shared.f32 [%sum_base], %mean;
    st.shared.f32 [%sq_base], %invstd;

WAIT_STATS:
    bar.sync 0;
    // All threads read mean and invstd from shared
    ld.shared.f32 %mean, [%sum_base];
    ld.shared.f32 %invstd, [%sq_base];

    // Per-channel affine parameters
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %w, %off64;
    ld.global.f32 %gamma, [%tmp64];
    add.u64 %tmp64, %b, %off64;
    ld.global.f32 %beta, [%tmp64];

    // ---- Pass 2: out = gamma * (x - mean) * invstd + beta ----
    mov.u32 %idx, %tid;
PASS2_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra END;

    // Same indexing as pass 1
    div.u32 %lin, %idx, %sp;
    rem.u32 %sp_i, %idx, %sp;
    mul.lo.u32 %lin, %lin, %n_ch;
    add.u32 %lin, %lin, %ch;
    mul.lo.u32 %lin, %lin, %sp;
    add.u32 %lin, %lin, %sp_i;

    cvt.u64.u32 %off64, %lin;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %val, [%tmp64];
    sub.f32 %normalized, %val, %mean;
    mul.f32 %normalized, %normalized, %invstd;
    fma.rn.f32 %normalized, %normalized, %gamma, %beta;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %normalized;

    add.u32 %idx, %idx, %bdim;
    bra PASS2_LOOP;

END:
    ret;
}
";
3182
3183#[cfg(feature = "cuda")]
/// PTX for a MaxPool2d forward kernel (`maxpool2d_forward_kernel`).
///
/// Grid-stride loop: each thread produces output elements at flat NCHW
/// indices `idx = ((b*C + c)*H_out + oh)*W_out + ow`, scanning its kh x kw
/// window with stride (sh, sw) and padding (ph, pw). Window taps that fall in
/// the padding are skipped via a single unsigned comparison (a negative
/// coordinate wraps to a large u32, so `ih >= h_in` covers both sides), and
/// the running maximum is seeded with -inf.
///
/// Fix over the previous revision: removed a dead `div.u32 %b_idx` whose
/// result was unconditionally overwritten by the right-to-left decomposition
/// below it (and the misleading comments around it), plus the unused `%p_gt`
/// predicate declaration.
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry maxpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %max_val, %cur_val, %neg_inf;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

    // -inf for max initialization
    mov.f32 %neg_inf, 0fFF800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx = ((b * C + c) * H_out + oh) * W_out + ow, from the right
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %max_val, %neg_inf;

    // Slide the kernel window
    mov.u32 %i, 0;
KH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra KH_DONE;

    mov.u32 %j, 0;
KW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra KW_DONE;

    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
    // Since unsigned, just check < h_in and < w_in
    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra KW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra KW_NEXT;

    // input_offset = b * C * H * W + c * H * W + ih * W + iw
    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    max.f32 %max_val, %max_val, %cur_val;

KW_NEXT:
    add.u32 %j, %j, 1;
    bra KW_LOOP;

KW_DONE:
    add.u32 %i, %i, 1;
    bra KH_LOOP;

KH_DONE:
    // Store output
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %max_val;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
3324
3325#[cfg(feature = "cuda")]
/// PTX for an AvgPool2d forward kernel (`avgpool2d_forward_kernel`).
///
/// Grid-stride loop with the same NCHW index decomposition and padding
/// handling as `MAXPOOL2D_PTX`. Only in-bounds taps are summed and counted,
/// i.e. `count_include_pad = false` behavior: the divisor is the number of
/// valid elements in the window, not kh*kw.
///
/// Fix over the previous revision: the divisor is clamped to at least 1 so a
/// window that lies entirely in the padding region (possible with large
/// padding) writes 0.0 instead of NaN from 0/0.
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry avgpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %sum_val, 0f00000000;
    mov.u32 %count, 0;

    mov.u32 %i, 0;
AKH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra AKH_DONE;

    mov.u32 %j, 0;
AKW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra AKW_DONE;

    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra AKW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra AKW_NEXT;

    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    add.f32 %sum_val, %sum_val, %cur_val;
    add.u32 %count, %count, 1;

AKW_NEXT:
    add.u32 %j, %j, 1;
    bra AKW_LOOP;

AKW_DONE:
    add.u32 %i, %i, 1;
    bra AKH_LOOP;

AKH_DONE:
    // avg = sum / count (count_include_pad = false behavior)
    // Clamp count to >= 1 so an all-padding window yields 0.0, not NaN
    max.u32 %count, %count, 1;
    cvt.rn.f32.u32 %count_f, %count;
    div.rn.f32 %avg, %sum_val, %count_f;

    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %avg;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
3460
3461#[cfg(feature = "cuda")]
/// PTX for a numerically-stable row-wise softmax kernel (`softmax_kernel`).
///
/// One thread block per row. Three phases, each followed by a shared-memory
/// tree reduction: (1) find the row max, (2) accumulate exp(x - max) —
/// computed as ex2((x - max) * log2(e)), with the unnormalized exponentials
/// staged into the output buffer — and (3) scale the staged values by the
/// reciprocal of the sum.
///
/// Fixes over the previous revision: two lines of the literal were missing
/// the `\n\` escape + continuation the rest of the string uses (leaving a raw
/// newline plus source indentation embedded in the PTX), and two redundant
/// recomputations of `%saddr` (already `%sbase + %off`) were dropped.
///
/// NOTE(review): the reductions repeatedly halve `%ntid.x` and stage partials
/// in `sdata[256]`, so the launch must use a power-of-two block size of at
/// most 256 — confirm against the launch site.
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f32 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    rcp.approx.f32 %sum_val, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    mul.f32 %result, %val, %sum_val;\n\
    st.global.f32 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
3625
3626#[cfg(feature = "cuda")]
/// PTX for an elementwise dropout kernel (`dropout_kernel`).
///
/// One thread per element. A per-element 32-bit value is produced by
/// multiplying the global thread id by 2654435761, XORing with `seed`, and
/// applying an xorshift-style mix (13/17/5). Elements whose value is below
/// `threshold` (unsigned compare) are zeroed; all others are multiplied by
/// `scale`.
///
/// NOTE(review): `threshold` is presumably `p * u32::MAX` and `scale` is
/// presumably the inverted-dropout factor `1/(1-p)` — confirm at the call
/// site; the kernel itself only applies them.
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 n,\n\
    .param .u32 threshold,\n\
    .param .f32 scale,\n\
    .param .u32 seed\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
    .reg .u64 %in, %out, %off;\n\
    .reg .f32 %val, %scale_reg, %zero;\n\
    .reg .pred %p, %drop_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %n_reg, [n];\n\
    ld.param.u32 %thresh, [threshold];\n\
    ld.param.f32 %scale_reg, [scale];\n\
    ld.param.u32 %seed_reg, [seed];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %n_reg;\n\
    @%p bra DONE;\n\
\n\
    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
    xor.b32 %rng, %rng, %seed_reg;\n\
    shl.b32 %tmp, %rng, 13;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shr.b32 %tmp, %rng, 17;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shl.b32 %tmp, %rng, 5;\n\
    xor.b32 %rng, %rng, %tmp;\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %in, %in, %off;\n\
    add.u64 %out, %out, %off;\n\
    ld.global.f32 %val, [%in];\n\
\n\
    setp.lo.u32 %drop_p, %rng, %thresh;\n\
    mov.f32 %zero, 0f00000000;\n\
    @%drop_p mov.f32 %val, %zero;\n\
    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
    st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
3690
3691#[cfg(feature = "cuda")]
/// PTX for an N-dimensional broadcasting elementwise add (`broadcast_add_kernel`).
///
/// Each thread computes one output element: its flat output index is
/// decomposed against `out_shape` from the innermost dimension outward, and
/// the per-dimension `a_strides`/`b_strides` (element strides; presumably 0
/// on broadcast dimensions — confirm at the call site) map the coordinates to
/// the source element indices. The output is written densely at the thread's
/// flat index.
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    // Global thread index.
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Decompose flat index into N-d coordinates and compute A/B indices.
    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;

LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;

    sub.u32 %d, %d, 1;

    // Byte offset for dimension d: d * 4.
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;

    // Load out_shape[d].
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];

    // Load a_strides[d] and b_strides[d].
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];

    // coord = remaining % shape_d; remaining /= shape_d.
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;

    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;

    bra LOOP;
END_LOOP:

    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];

    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    // Operation: add.
    add.f32 %vr, %va, %vb;

    // Store to out[tid].
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;

DONE:
    ret;
}
";
3816
3817#[cfg(feature = "cuda")]
/// PTX for an N-dimensional broadcasting elementwise subtract
/// (`broadcast_sub_kernel`): `out = a - b`.
///
/// Identical structure to `BROADCAST_ADD_PTX` — per-thread flat-index
/// decomposition against `out_shape` plus stride-based gather from `a` and
/// `b` — with `sub.f32` as the combining operation.
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    sub.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
3900
/// PTX for `broadcast_mul_kernel`: elementwise `out = a * b` with
/// stride-based broadcasting. Same index-decomposition scheme as
/// `BROADCAST_SUB_PTX`: one thread per output element, coordinates recovered
/// from `out_shape`, input element indices accumulated from
/// `a_strides` / `b_strides` (zero stride = broadcast along that dimension).
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    mul.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
3984
/// PTX for `broadcast_div_kernel`: elementwise `out = a / b` with
/// stride-based broadcasting. Same index-decomposition scheme as
/// `BROADCAST_SUB_PTX` / `BROADCAST_MUL_PTX`.
///
/// FIX: the division was written as bare `div.f32`, which is rejected by
/// PTX ISA >= 1.4 (and this module declares `.version 7.0`) — f32 division
/// requires an explicit `.approx` / `.full` / rounding modifier, so the
/// module failed to assemble and callers silently fell back through the
/// `Ok(None)` compile-failure path. Now uses IEEE round-to-nearest
/// `div.rn.f32`, consistent with `DIV_PTX` below.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    div.rn.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
4069
/// PTX for `strided_split_kernel`: copies one slice of a tensor along an
/// axis into a compact output. The output is indexed contiguously by thread
/// id; the matching source element is
/// `src = outer * total_along_axis * inner + split_offset * inner + within`,
/// where `outer`/`within` come from dividing the thread id by
/// `split_size * inner_size`. `n` is the output element count.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_split_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 split_offset,
    .param .u32 split_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %sp_off, [split_offset];
    ld.param.u32 %sp_sz, [split_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = split_size * inner_size
    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = split_offset * inner_size
    mul.lo.u32 %base_off, %sp_off, %inner_sz;

    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
    add.u32 %src_idx, %src_idx, %base_off;
    add.u32 %src_idx, %src_idx, %within;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
4149
/// PTX for `strided_cat_kernel`: the inverse of `strided_split_kernel`.
/// Reads the compact input contiguously by thread id and scatters into the
/// concatenated output at
/// `dst = outer * total_along_axis * inner + cat_offset * inner + within`.
/// One launch per concatenated part; `n` is this part's element count.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_cat_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 cat_offset,
    .param .u32 part_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %cat_off, [cat_offset];
    ld.param.u32 %part_sz, [part_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = part_size * inner_size
    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = cat_offset * inner_size
    mul.lo.u32 %base_off, %cat_off, %inner_sz;

    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
    add.u32 %dst_idx, %dst_idx, %base_off;
    add.u32 %dst_idx, %dst_idx, %within;

    // Load from in[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[dst_idx]
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
4230
/// PTX for `div_kernel`: contiguous elementwise `out[i] = a[i] / b[i]`
/// over `n` f32 values, one thread per element, using IEEE
/// round-to-nearest division (`div.rn.f32`).
#[cfg(feature = "cuda")]
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    div.rn.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4278
/// PTX for `exp_kernel`: elementwise `out[i] = exp(a[i])` computed as
/// `2^(x * log2(e))` with the hardware `ex2.approx` instruction
/// (0f3FB8AA3B == log2(e) ≈ 1.4426950408889634). Approximate, not
/// correctly-rounded — adequate for ML workloads.
#[cfg(feature = "cuda")]
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
    // log2(e) = 1.4426950408889634
    mul.f32 %va, %va, 0f3FB8AA3B;
    ex2.approx.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4325
/// PTX for `log_kernel`: elementwise natural log, computed as
/// `log2(x) * ln(2)` with `lg2.approx` (0f3F317218 == ln(2) ≈
/// 0.6931471805599453). Approximate precision; non-positive inputs follow
/// `lg2.approx` semantics (NaN / -inf).
#[cfg(feature = "cuda")]
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
    // 1/log2(e) = ln(2) = 0.6931471805599453
    lg2.approx.f32 %vr, %va;
    mul.f32 %vr, %vr, 0f3F317218;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4372
/// PTX for `sqrt_kernel`: elementwise square root using the
/// correctly-rounded `sqrt.rn.f32` (IEEE round-to-nearest), unlike the
/// approximate transcendental kernels in this file.
#[cfg(feature = "cuda")]
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sqrt_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    sqrt.rn.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4416
/// PTX for `pow_kernel`: elementwise `out[i] = a[i]^exponent` (scalar
/// exponent) via the identity `x^e = 2^(e * log2(x))` using
/// `lg2.approx` + `ex2.approx`. NOTE(review): `lg2` of a non-positive base
/// yields NaN / -inf, so this matches `powf` only for positive bases —
/// confirm callers only use it that way.
#[cfg(feature = "cuda")]
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %exp, %lg;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %exp, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // x^e = 2^(e * log2(x))
    lg2.approx.f32 %lg, %va;
    mul.f32 %lg, %lg, %exp;
    ex2.approx.f32 %vr, %lg;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4466
/// PTX for `abs_kernel`: elementwise absolute value (`abs.f32`), one thread
/// per element over `n` contiguous f32 values.
#[cfg(feature = "cuda")]
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    abs.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4510
/// PTX for `sigmoid_kernel`: elementwise `1 / (1 + exp(-x))`, with exp(-x)
/// built from `ex2.approx` and the log2(e) constant (0f3FB8AA3B), and an
/// IEEE `div.rn.f32` for the final reciprocal. Approximate to `ex2`
/// precision.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f32 %neg, %va;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %vr, %one, %denom;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4561
/// PTX for `tanh_kernel`: elementwise tanh via the identity
/// `tanh(x) = 2*sigmoid(2x) - 1`, with the inner sigmoid built from
/// `ex2.approx` (constants: 0f40000000 = 2.0, 0f3F800000 = 1.0,
/// 0f3FB8AA3B = log2(e)). Approximate to `ex2` precision.
#[cfg(feature = "cuda")]
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // tanh(x) = 2*sigmoid(2x) - 1
    mov.f32 %two, 0f40000000;
    mul.f32 %neg2x, %va, %two;
    neg.f32 %neg2x, %neg2x;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg2x, %neg2x, %lg2e;
    ex2.approx.f32 %e, %neg2x;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %sig, %one, %denom;
    mul.f32 %vr, %two, %sig;
    sub.f32 %vr, %vr, %one;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
4617
/// PTX for `fused_adam_kernel`: a full Adam optimizer step fused into one
/// pass, updating `param`, `exp_avg` (m) and `exp_avg_sq` (v) in place.
/// Per element:
///   g  += weight_decay * p                 (L2-style; only when wd > 0)
///   m   = beta1*m + (1-beta1)*g
///   v   = beta2*v + (1-beta2)*g*g
///   p  -= lr * (m/bc1) / (sqrt(v/bc2) + eps)
/// `bc1`/`bc2` are divided into m/v as bias-correction factors — presumably
/// `1 - beta^t` precomputed on the host; confirm against the caller.
/// Note the `%one` register temporarily holds 0.0 for the wd > 0 test
/// before being reloaded with 1.0.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_adam_kernel(
    .param .u64 param_ptr,
    .param .u64 grad_ptr,
    .param .u64 exp_avg_ptr,
    .param .u64 exp_avg_sq_ptr,
    .param .f32 beta1,
    .param .f32 beta2,
    .param .f32 lr,
    .param .f32 eps,
    .param .f32 bc1,
    .param .f32 bc2,
    .param .f32 weight_decay,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %p, %g, %m, %v, %off;
    .reg .f32 %vp, %vg, %vm, %vv;
    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
    .reg .f32 %one;
    .reg .pred %p_bound, %p_wd;

    ld.param.u64 %p, [param_ptr];
    ld.param.u64 %g, [grad_ptr];
    ld.param.u64 %m, [exp_avg_ptr];
    ld.param.u64 %v, [exp_avg_sq_ptr];
    ld.param.f32 %b1, [beta1];
    ld.param.f32 %b2, [beta2];
    ld.param.f32 %f_lr, [lr];
    ld.param.f32 %f_eps, [eps];
    ld.param.f32 %f_bc1, [bc1];
    ld.param.f32 %f_bc2, [bc2];
    ld.param.f32 %f_wd, [weight_decay];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_bound, %r_tid, %n_reg;
    @%p_bound bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %p, %p, %off;
    add.u64 %g, %g, %off;
    add.u64 %m, %m, %off;
    add.u64 %v, %v, %off;

    ld.global.f32 %vp, [%p];
    ld.global.f32 %vg, [%g];
    ld.global.f32 %vm, [%m];
    ld.global.f32 %vv, [%v];

    // L2 weight decay: g = g + wd * p
    mov.f32 %one, 0f00000000;
    setp.gt.f32 %p_wd, %f_wd, %one;
    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;

    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
    mov.f32 %one, 0f3F800000;
    sub.f32 %t1, %one, %b1;
    mul.f32 %vm, %vm, %b1;
    fma.rn.f32 %vm, %t1, %vg, %vm;

    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    sub.f32 %t2, %one, %b2;
    mul.f32 %vv, %vv, %b2;
    mul.f32 %t1, %vg, %vg;
    fma.rn.f32 %vv, %t2, %t1, %vv;

    // m_hat = exp_avg / bc1
    div.rn.f32 %m_hat, %vm, %f_bc1;

    // v_hat = exp_avg_sq / bc2
    div.rn.f32 %v_hat, %vv, %f_bc2;

    // denom = sqrt(v_hat) + eps
    sqrt.rn.f32 %denom, %v_hat;
    add.f32 %denom, %denom, %f_eps;

    // param = param - lr * m_hat / denom
    div.rn.f32 %update, %m_hat, %denom;
    mul.f32 %update, %update, %f_lr;
    sub.f32 %vp, %vp, %update;

    st.global.f32 [%p], %vp;
    st.global.f32 [%m], %vm;
    st.global.f32 [%v], %vv;

DONE:
    ret;
}
";
4729
/// PTX for `fused_gru_forward_kernel`: fuses all pointwise work of a GRU
/// cell after the two GEMMs have produced `input_gates`/`hidden_gates`,
/// each laid out as [B, 3*H] with the r, z (named `i` here), n sub-blocks
/// each of width H; biases are length 3*H. Per element (grid-stride loop
/// over `total`, which indexes the [B, H] hidden state):
///   r  = sigmoid(ir + hr + b1r + b2r)
///   z  = sigmoid(ii + hi + b1i + b2i)
///   n  = tanh(in + b1n + r * (hn + b2n))
///   hy = n + z * (hx - n)
/// sigmoid/tanh are built from `ex2.approx` with log2(e) = 0f3FB8AA3B.
/// Also writes a [B, 5*H] workspace of [r, z, n, hx, hn+b2n] per element —
/// presumably consumed by the backward pass; confirm at the call site.
#[cfg(feature = "cuda")]
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_gru_forward_kernel(
    .param .u64 input_gates_ptr,
    .param .u64 hidden_gates_ptr,
    .param .u64 bias_ih_ptr,
    .param .u64 bias_hh_ptr,
    .param .u64 hx_ptr,
    .param .u64 hy_ptr,
    .param .u64 workspace_ptr,
    .param .u32 hsz,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
    .reg .u64 %off64, %tmp64;
    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
    .reg .pred %p;

    ld.param.u64 %ig, [input_gates_ptr];
    ld.param.u64 %hg, [hidden_gates_ptr];
    ld.param.u64 %b1, [bias_ih_ptr];
    ld.param.u64 %b2, [bias_hh_ptr];
    ld.param.u64 %hx, [hx_ptr];
    ld.param.u64 %hy, [hy_ptr];
    ld.param.u64 %ws, [workspace_ptr];
    ld.param.u32 %hsz_reg, [hsz];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %one, 0f3F800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
    div.u32 %batch_idx, %idx, %hsz_reg;
    rem.u32 %hmod, %idx, %hsz_reg;
    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset3, %offset3, 3;
    add.u32 %offset3, %offset3, %hmod;

    // Load input gate components: ir, ii, in
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ig, %off64;
    ld.global.f32 %ir, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %ii, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %in, [%tmp64];

    // Load hidden gate components: hr, hi, hn
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hg, %off64;
    ld.global.f32 %hr, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hi, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hn, [%tmp64];

    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b1, %off64;
    ld.global.f32 %b1r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1n, [%tmp64];

    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b2, %off64;
    ld.global.f32 %b2r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2n, [%tmp64];

    // Load hx[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hx, %off64;
    ld.global.f32 %hx_val, [%tmp64];

    // r = sigmoid(ir + hr + b1r + b2r)
    add.f32 %rg, %ir, %hr;
    add.f32 %rg, %rg, %b1r;
    add.f32 %rg, %rg, %b2r;
    neg.f32 %tmp, %rg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %rg, %one, %denom;

    // z = sigmoid(ii + hi + b1i + b2i)
    add.f32 %zg, %ii, %hi;
    add.f32 %zg, %zg, %b1i;
    add.f32 %zg, %zg, %b2i;
    neg.f32 %tmp, %zg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %zg, %one, %denom;

    // n = tanh(in + b1n + r*(hn + b2n))
    add.f32 %tmp, %hn, %b2n;
    fma.rn.f32 %ng, %rg, %tmp, %in;
    add.f32 %ng, %ng, %b1n;
    // tanh via 2*sigmoid(2x)-1
    mul.f32 %tmp, %ng, 0f40000000;
    neg.f32 %tmp, %tmp;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %ng, %one, %denom;
    mul.f32 %ng, %ng, 0f40000000;
    sub.f32 %ng, %ng, %one;

    // hy = n + z * (hx - n)
    sub.f32 %tmp, %hx_val, %ng;
    fma.rn.f32 %hy_val, %zg, %tmp, %ng;

    // Store hy[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hy, %off64;
    st.global.f32 [%tmp64], %hy_val;

    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset5, %offset5, 5;
    add.u32 %offset5, %offset5, %hmod;

    cvt.u64.u32 %off64, %offset5;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ws, %off64;
    st.global.f32 [%tmp64], %rg;
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %zg;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %ng;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %hx_val;
    add.u64 %tmp64, %tmp64, %off64;
    add.f32 %tmp, %hn, %b2n;
    st.global.f32 [%tmp64], %tmp;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
4922
/// Builds a 1-D launch configuration for `n` elements using 256-thread
/// blocks. Returns an error when `n` cannot be represented as the kernels'
/// `u32` element count (reusing `ShapeMismatch` as the "too large" error).
/// `n == 0` still yields a single block so a launch is always well-formed;
/// the kernels' own `tid < n` guard makes it a no-op.
#[cfg(feature = "cuda")]
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    const BLOCK: u32 = 256;

    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }

    // Ceiling division; saturating_add guards the n ~ u32::MAX edge.
    let blocks = (n as u32).saturating_add(BLOCK - 1) / BLOCK;
    let grid_x = blocks.max(1);

    Ok(LaunchConfig {
        grid_dim: (grid_x, 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
4953
/// Checks that both operands of a binary elementwise op live on `device`
/// and have matching lengths. Errors are reported in order: `a`'s device,
/// then `b`'s device, then the length comparison.
#[cfg(feature = "cuda")]
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    // Device check for each operand, preserving a-before-b error order.
    for buf in [a, b] {
        if buf.device_ordinal() != device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: buf.device_ordinal(),
                got: device.ordinal(),
            });
        }
    }

    if a.len() == b.len() {
        Ok(())
    } else {
        Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        })
    }
}
4981
/// Checks that the single operand of a unary elementwise op lives on
/// `device`, returning `DeviceMismatch` otherwise.
#[cfg(feature = "cuda")]
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let (have, want) = (a.device_ordinal(), device.ordinal());
    if have == want {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: have,
            got: want,
        })
    }
}
4993
/// Compiles (or fetches from the module cache) a binary elementwise kernel
/// from `ptx_src` and launches it over `a`/`b`, returning a freshly
/// allocated output buffer. Returns `Ok(None)` when the PTX cannot be
/// compiled/loaded so the caller can fall back to a CPU path.
/// NOTE(review): the element count is taken from `a` alone — callers are
/// expected to have run `validate_binary` first; confirm at call sites.
#[cfg(feature = "cuda")]
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    // A compile/load failure is non-fatal: signal "no GPU result".
    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => return Ok(None),
    };

    let len = a.len();
    let cfg = launch_cfg(len)?;
    let mut result = alloc_zeros_f32(len, device)?;
    let len_u32 = len as u32;

    // SAFETY: all device pointers come from live buffers on this device and
    // the kernels guard every thread with a `tid < n` bounds check.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
5047
/// Vectorized (float4) variant of `try_launch_binary`: launches `n / 4`
/// threads, each handling four consecutive f32 values. Returns `Ok(None)`
/// when the PTX cannot be compiled/loaded.
/// NOTE(review): only the first `(n / 4) * 4` elements are computed; any
/// `n % 4` tail is left at the zero fill from `alloc_zeros_f32` — confirm
/// callers either guarantee divisibility or handle the remainder.
#[cfg(feature = "cuda")]
fn try_launch_binary_vec4(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    // A compile/load failure is non-fatal: signal "no GPU result".
    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => return Ok(None),
    };

    let len = a.len();
    let quads = (len / 4) as u32;
    let cfg = launch_cfg(quads as usize)?;
    let mut result = alloc_zeros_f32(len, device)?;

    // SAFETY: all device pointers come from live buffers on this device and
    // the kernel guards every thread with a `tid < n4` bounds check.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(b.inner())
            .arg(result.inner_mut())
            .arg(&quads)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
5092
/// Compiles (or fetches from the module cache) a unary elementwise kernel
/// from `ptx_src` and launches it over `a`, returning a freshly allocated
/// output buffer. Returns `Ok(None)` when the PTX cannot be
/// compiled/loaded so the caller can fall back to a CPU path.
#[cfg(feature = "cuda")]
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    // A compile/load failure is non-fatal: signal "no GPU result".
    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => return Ok(None),
    };

    let len = a.len();
    let cfg = launch_cfg(len)?;
    let mut result = alloc_zeros_f32(len, device)?;
    let len_u32 = len as u32;

    // SAFETY: all device pointers come from live buffers on this device and
    // the kernels guard every thread with a `tid < n` bounds check.
    unsafe {
        device
            .stream()
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(result.inner_mut())
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(Some(result))
}
5136
#[cfg(feature = "cuda")]
/// Attempts to launch an element-wise binary kernel writing into a
/// caller-provided `out` buffer (no allocation).
///
/// Returns `Ok(false)` when PTX compilation fails — `out` is left untouched
/// and the caller is expected to fall back to another path.
fn try_launch_binary_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    // Non-fatal: `false` means "kernel unavailable".
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (a_ptr, b_ptr, out_ptr, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
5183
#[cfg(feature = "cuda")]
/// Attempts to launch an element-wise unary kernel writing into a
/// caller-provided `out` buffer (no allocation).
///
/// Returns `Ok(false)` when PTX compilation fails; `out` is left untouched.
fn try_launch_unary_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    // Non-fatal: `false` means "kernel unavailable".
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (in_ptr, out_ptr, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
5224
#[cfg(feature = "cuda")]
/// Attempts to launch a broadcasting binary kernel.
///
/// `a_strides` / `b_strides` carry 0 on broadcast dimensions (see
/// `broadcast_strides`); `out_shape` and both stride arrays are uploaded to
/// device memory so the kernel can decompose each flat output index.
/// Returns `Ok(None)` when PTX compilation fails.
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let ndim = out_shape.len();
    let ctx = device.context();
    let stream = device.stream();

    // Non-fatal: signal "kernel unavailable" on compile/load error.
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    // Small metadata arrays (strides + shape) copied host -> device.
    let a_str_buf = cpu_to_gpu(a_strides, device)?;
    let b_str_buf = cpu_to_gpu(b_strides, device)?;
    let shape_buf = cpu_to_gpu(out_shape, device)?;

    let mut out = alloc_zeros_f32(out_numel, device)?;
    let cfg = launch_cfg(out_numel)?;
    let n_u32 = out_numel as u32;
    let ndim_u32 = ndim as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (a, b, out, a_strides, b_strides, shape, n, ndim); all buffers —
    // including the temporary metadata buffers — outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(a_str_buf.inner())
            .arg(b_str_buf.inner())
            .arg(shape_buf.inner())
            .arg(&n_u32)
            .arg(&ndim_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
5288
5289#[cfg(feature = "cuda")]
/// Computes per-output-dimension element strides for `in_shape` when it is
/// broadcast (NumPy-style, right-aligned) against `out_shape`.
///
/// A stride of 0 marks a dimension along which the input does not advance:
/// either a size-1 input dimension or a leading output dimension with no
/// input counterpart.
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    // One slot per output dimension; unmatched leading dims stay 0.
    let mut strides = vec![0u32; out_shape.len()];
    let mut running: u32 = 1;

    // Walk both shapes right-aligned. `zip` stops at the shorter side, which
    // is exactly the broadcasting alignment rule. A size-1 input dimension
    // broadcasts, so its stride is forced to 0; otherwise the slot gets the
    // running row-major stride of the input tensor.
    for (&dim, slot) in in_shape.iter().rev().zip(strides.iter_mut().rev()) {
        *slot = if dim == 1 { 0 } else { running };
        running *= dim as u32;
    }

    strides
}
5322
#[cfg(feature = "cuda")]
/// CPU fallback for element-wise binary ops: downloads both operands,
/// applies `op` pairwise on the host, and uploads the result.
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut merged = Vec::with_capacity(lhs.len());
    for (x, y) in lhs.into_iter().zip(rhs.into_iter()) {
        merged.push(op(x, y));
    }
    cpu_to_gpu(&merged, device)
}
5345
#[cfg(feature = "cuda")]
/// CPU fallback for element-wise unary ops: downloads the operand, applies
/// `op` per element on the host, and uploads the result.
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mapped: Vec<f32> = host.into_iter().map(op).collect();
    cpu_to_gpu(&mapped, device)
}
5357
#[cfg(feature = "cuda")]
/// Element-wise addition of two equal-length device buffers.
///
/// Tries the vectorized float4 kernel first (when the length allows), then
/// the scalar kernel, and finally a host-side loop if PTX compilation is
/// unavailable on this device.
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // Fast path: the float4 kernel needs a multiple-of-4 length and enough
    // elements to make the vectorized launch worthwhile.
    let len = a.len();
    if len % 4 == 0 && len >= 16 {
        if let Some(result) =
            try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?
        {
            return Ok(result);
        }
    }

    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
5397
#[cfg(feature = "cuda")]
/// Element-wise subtraction (`a - b`) of two equal-length device buffers,
/// with a host-side fallback when the kernel cannot be compiled.
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
5423
#[cfg(feature = "cuda")]
/// Element-wise multiplication of two equal-length device buffers.
///
/// Tries the vectorized float4 kernel first (when the length allows), then
/// the scalar kernel, and finally a host-side loop.
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    // Fast path: same eligibility rule as `gpu_add`.
    let len = a.len();
    if len % 4 == 0 && len >= 16 {
        if let Some(result) =
            try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?
        {
            return Ok(result);
        }
    }

    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
5458
#[cfg(feature = "cuda")]
/// Broadcasting addition: `a` (shape `a_shape`) + `b` (shape `b_shape`)
/// broadcast to `out_shape`, with a host fallback if the kernel is
/// unavailable.
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Per-input strides (0 on broadcast dims) plus the output geometry the
    // kernel needs to decompose flat indices.
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape,
        numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
}
5502
#[cfg(feature = "cuda")]
/// Broadcasting subtraction (`a - b`) to `out_shape`, with a host fallback
/// if the kernel is unavailable.
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Per-input strides (0 on broadcast dims) plus the output geometry.
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape,
        numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
}
5534
#[cfg(feature = "cuda")]
/// Broadcasting multiplication to `out_shape`, with a host fallback if the
/// kernel is unavailable.
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Per-input strides (0 on broadcast dims) plus the output geometry.
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape,
        numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
}
5566
#[cfg(feature = "cuda")]
/// Broadcasting division (`a / b`) to `out_shape`, with a host fallback if
/// the kernel is unavailable. Division by zero follows IEEE-754 f32
/// semantics (inf / NaN), matching the element-wise operator.
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    // Per-input strides (0 on broadcast dims) plus the output geometry.
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);
    let shape: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let launched = try_launch_broadcast_binary(
        a,
        b,
        &lhs_strides,
        &rhs_strides,
        &shape,
        numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )?;
    if let Some(result) = launched {
        return Ok(result);
    }

    cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
}
5598
#[cfg(feature = "cuda")]
/// CPU fallback for broadcasting binary ops: materializes both operands on
/// the host and computes each output element via stride arithmetic.
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let numel: usize = out_shape.iter().product();

    // A stride of 0 on a broadcast dimension makes that coordinate a no-op
    // when accumulating each input's flat offset.
    let lhs_strides = broadcast_strides(a_shape, out_shape);
    let rhs_strides = broadcast_strides(b_shape, out_shape);

    let mut out = Vec::with_capacity(numel);
    for flat in 0..numel {
        // Decompose `flat` into per-dimension coordinates (row-major, last
        // dimension fastest) while accumulating both input offsets.
        let mut rest = flat;
        let mut lhs_off = 0usize;
        let mut rhs_off = 0usize;
        for d in (0..out_shape.len()).rev() {
            let coord = rest % out_shape[d];
            rest /= out_shape[d];
            lhs_off += coord * lhs_strides[d] as usize;
            rhs_off += coord * rhs_strides[d] as usize;
        }
        out.push(op(lhs[lhs_off], rhs[rhs_off]));
    }
    cpu_to_gpu(&out, device)
}
5633
#[cfg(feature = "cuda")]
/// Element-wise negation, with a host fallback when the kernel is
/// unavailable.
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
5658
#[cfg(feature = "cuda")]
/// Element-wise ReLU (`max(x, 0)`), with a host fallback when the kernel is
/// unavailable.
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
5679
#[cfg(feature = "cuda")]
/// ReLU backward: the incoming gradient passes through only where the
/// forward input was strictly positive.
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(result) = try_launch_binary(
        grad,
        input,
        device,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
    )? {
        return Ok(result);
    }

    // Host fallback mirroring the kernel's gating rule (x > 0).
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut gated = Vec::with_capacity(g_host.len());
    for (g, x) in g_host.into_iter().zip(x_host.into_iter()) {
        gated.push(if x > 0.0 { g } else { 0.0 });
    }
    cpu_to_gpu(&gated, device)
}
5709
#[cfg(feature = "cuda")]
/// GELU backward using the sigmoid approximation
/// `gelu(x) ≈ x * sigmoid(1.702 * x)`.
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(result) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_PTX,
        "gelu_backward_kernel",
    )? {
        return Ok(result);
    }

    // Host fallback. With s = sigmoid(k*x):
    //   d/dx [x * s] = s + k * x * s * (1 - s).
    const K: f32 = 1.702;
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (g, x) in g_host.into_iter().zip(x_host.into_iter()) {
        let sig = 1.0 / (1.0 + (-K * x).exp());
        out.push(g * (sig + K * x * sig * (1.0 - sig)));
    }
    cpu_to_gpu(&out, device)
}
5744
#[cfg(feature = "cuda")]
/// GELU backward using the exact (erf-based) formulation:
/// `d/dx gelu(x) = cdf(x) + x * pdf(x)` with the standard normal CDF/PDF.
///
/// Falls back to a host-side loop when the PTX kernel is unavailable; the
/// fallback evaluates erf with the Abramowitz & Stegun 7.1.26 polynomial
/// approximation (|error| <= 1.5e-7).
pub fn gpu_gelu_backward_erf(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_PTX,
        "gelu_backward_erf_kernel",
    )? {
        return Ok(out);
    }

    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let z = x * inv_sqrt_2;
            let az = z.abs();
            let t = 1.0 / (1.0 + 0.3275911 * az);
            // A&S 7.1.26 coefficients a1..a5 in Horner form. BUG FIX: the
            // innermost coefficient is a5 = 1.0614054; the previous code
            // reused p = 0.3275911 here, corrupting erf (and the gradient)
            // for moderate |x|.
            let poly = t * (0.2548296 + t * (-0.2844967 + t * (1.4214137 + t * (-1.4531520 + t * 1.0614054))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            // The approximation is for z >= 0; erf is odd, so mirror sign.
            let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
            let cdf = 0.5 * (1.0 + erf_val);
            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
            g * (cdf + x * pdf)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
5788
#[cfg(feature = "cuda")]
/// Gathers `input[indices[i]]` for each i into a new buffer of
/// `indices.len()` elements.
///
/// Indices are stored as f32 values and truncated to usize on the CPU path;
/// presumably the kernel does the same conversion — confirm against the PTX.
/// No bounds checking is done on the CPU path: an out-of-range index panics.
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    // Output length equals the number of indices, not the input length.
    let n = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: gather on the CPU when the kernel is unavailable.
        Err(_) => {
            let input_host = gpu_to_cpu(input, device)?;
            let indices_host = gpu_to_cpu(indices, device)?;
            let result: Vec<f32> = indices_host
                .iter()
                .map(|&idx_f| input_host[idx_f as usize])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, indices, out, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5846
#[cfg(feature = "cuda")]
/// Scatter-add: accumulates `grad_output[i]` into `out[indices[i]]`,
/// producing a buffer of `input_len` elements (the gradient of an
/// index-select / embedding lookup).
///
/// Indices are f32 values truncated to usize on the CPU path. NOTE(review):
/// duplicate indices require atomic adds in the kernel for correctness —
/// presumably the PTX uses atom.add; confirm.
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(grad_output, device)?;

    // One thread per gradient element; the output has `input_len` slots.
    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: sequential scatter-add (duplicates accumulate).
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; input_len];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                result[idx_f as usize] += go_host[i];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    // Zero-initialized so the kernel can accumulate into it.
    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (grad_output, indices, out, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5908
#[cfg(feature = "cuda")]
/// Replaces `input[i]` with `value` wherever `mask[i] >= 0.5` (on the CPU
/// path; presumably the kernel uses the same threshold — confirm against
/// the PTX). Used e.g. for attention masking.
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(input, mask, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback when the kernel is unavailable.
        Err(_) => {
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f32> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m >= 0.5 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, mask, out, value, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
5969
#[cfg(feature = "cuda")]
/// Zeros the gradient wherever the mask is set (`mask >= 0.5` on the CPU
/// path) — the backward counterpart of `gpu_masked_fill`.
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;

    if let Some(result) =
        try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?
    {
        return Ok(result);
    }

    // Host fallback mirroring the masking rule.
    let g_host = gpu_to_cpu(grad, device)?;
    let m_host = gpu_to_cpu(mask, device)?;
    let mut kept = Vec::with_capacity(g_host.len());
    for (g, m) in g_host.into_iter().zip(m_host.into_iter()) {
        kept.push(if m >= 0.5 { 0.0 } else { g });
    }
    cpu_to_gpu(&kept, device)
}
6001
#[cfg(feature = "cuda")]
/// Sigmoid backward: `grad * output * (1 - output)`, where `output` is the
/// forward activation.
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    if let Some(result) = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )? {
        return Ok(result);
    }

    // Host fallback using the same derivative formula.
    let g_host = gpu_to_cpu(grad, device)?;
    let o_host = gpu_to_cpu(output, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (g, o) in g_host.into_iter().zip(o_host.into_iter()) {
        out.push(g * o * (1.0 - o));
    }
    cpu_to_gpu(&out, device)
}
6037
#[cfg(feature = "cuda")]
/// Tanh backward: `grad * (1 - output^2)`, where `output` is the forward
/// activation.
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    if let Some(result) = try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )? {
        return Ok(result);
    }

    // Host fallback using the same derivative formula.
    let g_host = gpu_to_cpu(grad, device)?;
    let o_host = gpu_to_cpu(output, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (g, o) in g_host.into_iter().zip(o_host.into_iter()) {
        out.push(g * (1.0 - o * o));
    }
    cpu_to_gpu(&out, device)
}
6073
#[cfg(feature = "cuda")]
/// Row-wise softmax backward for a `(rows, cols)` row-major matrix:
/// `dx[r,c] = y[r,c] * (dy[r,c] - sum_c(dy[r,c] * y[r,c]))`.
///
/// `grad.len()` must be a multiple of `cols`; rows is derived as
/// `len / cols`. One CUDA block is launched per row.
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: per row, compute dot = <dy, y>, then
        // dx = y * (dy - dot).
        Err(_) => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, with 256 floats of dynamic
    // shared memory (presumably for the per-row dot-product reduction —
    // confirm against the PTX).
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order must match the kernel's parameter list
    // (grad, output, out, rows, cols); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6152
#[cfg(feature = "cuda")]
/// Sums all elements of `a` into a single-element device buffer.
///
/// Strategy: one kernel pass produces up to 1024 per-block partial sums;
/// small partial arrays are finished on the CPU, larger ones recurse.
/// NOTE(review): the grid is capped at 1024 blocks, so for
/// n > 1024 * 256 the kernel must stride over extra elements per thread —
/// assumed from the cap here; confirm against REDUCE_SUM_PTX.
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // Sum of an empty buffer is defined as 0.
    if n == 0 {
        return cpu_to_gpu(&[0.0f32], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: plain sequential sum.
        Err(_) => {
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    const BLOCK: u32 = 256;
    // Ceil-divide, capped at 1024 partials regardless of n.
    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
    let num_blocks = num_blocks.min(1024);

    // One partial sum per block.
    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
    let n_u32 = n as u32;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0, };

    // SAFETY: argument order must match the kernel's parameter list
    // (input, partials, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    // A single partial already is the final scalar result.
    if num_blocks <= 1 {
        return Ok(partials);
    }

    // Few partials: cheaper to finish on the CPU than to relaunch.
    if num_blocks <= 256 {
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // 257..=1024 partials: recurse once; the next level hits the CPU path.
    gpu_reduce_sum(&partials, device)
}
6235
#[cfg(not(feature = "cuda"))]
/// Stub used when the crate is built without the `cuda` feature: always
/// fails with [`GpuError::NoCudaFeature`].
pub fn gpu_reduce_sum(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
6244
#[cfg(feature = "cuda")]
/// Sums a tensor along one axis. The input is viewed as a row-major
/// `(outer, axis_size, inner)` tensor; the output has `outer * inner`
/// elements with the middle axis reduced away.
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    // One output element per (outer, inner) pair.
    let total_output = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: for output index i = outer_idx * inner + inner_idx,
        // sum input[outer_idx * axis_size * inner + k * inner + inner_idx]
        // over k.
        Err(_) => {
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f32; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let mut sum = 0.0f32;
                for k in 0..axis_size {
                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
                }
                *out = sum;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, out, outer, axis_size, inner, total); buffers outlive the
    // launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6309
#[cfg(feature = "cuda")]
/// Extracts one split of size `split_size` starting at `split_offset` along
/// an axis of size `total_along_axis`, for a tensor viewed row-major as
/// `(outer, total_along_axis, inner_size)`. `n` is the number of output
/// elements (`outer * split_size * inner_size`).
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: gather output element i from
        // outer_idx * total * inner + offset * inner + within.
        // (The previously computed-but-unused `outer` variable was removed.)
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                let src_idx =
                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
                *out = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, out, total_along_axis, offset, split_size, inner, n);
    // buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6391
#[cfg(feature = "cuda")]
/// Writes `input` (one part of a concatenation) into the right slots of a
/// pre-sized `output`, for tensors viewed row-major as
/// `(outer, total_along_axis, inner_size)`. `n` is the number of elements
/// in `input` to copy; the inverse of `gpu_strided_split`.
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: round-trip the output through host memory, scatter
        // the input part into place, and upload the merged buffer back.
        Err(_) => {
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            // Replaces the caller's buffer wholesale with the merged copy.
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, output, total_along_axis, offset, part_size, inner, n);
    // buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
6479
#[cfg(feature = "cuda")]
/// Multiplies every element of `a` by the scalar `scalar`, with a host
/// fallback when the kernel is unavailable.
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: simple per-element multiply.
        Err(_) => {
            let host = gpu_to_cpu(a, device)?;
            let result: Vec<f32> = host.iter().map(|&x| x * scalar).collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, out, scalar, n); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6533
#[cfg(feature = "cuda")]
/// Row-wise softmax over a `(rows, cols)` row-major matrix. One CUDA block
/// is launched per row.
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: numerically stable softmax — subtract the row max
        // before exponentiating, then normalize by the row sum.
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; host.len()];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f32;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads, 256 floats of dynamic shared memory
    // (presumably for the per-row max/sum reductions — confirm in the PTX).
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4, };

    // SAFETY: argument order must match the kernel's parameter list
    // (input, out, rows, cols); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6610
#[cfg(feature = "cuda")]
/// Dropout: zeros each element whose per-element hash is below `threshold`,
/// scaling survivors by `scale` (inverted dropout). The RNG is a stateless
/// per-index hash (Knuth multiplicative hash then xorshift), seeded by
/// `seed`, so results are deterministic for a given (seed, index).
///
/// NOTE(review): the CPU fallback must produce the same mask as the kernel;
/// presumably the PTX implements the identical hash — confirm, otherwise
/// fallback runs diverge from GPU runs.
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback replicating the kernel's hash-based mask.
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f32> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    // Multiplicative hash of the index, mixed with the seed,
                    // then a 32-bit xorshift scramble.
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, out, n, threshold, scale, seed); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
6691
#[cfg(feature = "cuda")]
/// Transposes a row-major `(m, n)` matrix into a row-major `(n, m)` matrix.
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        // Host fallback: out[j][i] = in[i][j].
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i in 0..m {
                for j in 0..n {
                    out[j * m + i] = host[i * n + j];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel's parameter list
    // (input, out, m, n, total); buffers outlive the launch.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6751
#[cfg(feature = "cuda")]
/// Permutes a contiguous 4-D tensor of shape `(d0, d1, d2, d3)` into axis
/// order `(0, 2, 1, 3)` (the common attention-head reshuffle), returning a
/// new contiguous buffer of shape `(d0, d2, d1, d3)`.
///
/// Falls back to an explicit quadruple-loop copy on the host when PTX
/// compilation is unavailable.
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: compute source/destination linear offsets directly.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i0 in 0..d0 {
                for i1 in 0..d1 {
                    for i2 in 0..d2 {
                        for i3 in 0..d3 {
                            // Input laid out as (d0, d1, d2, d3); output as (d0, d2, d1, d3).
                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                            out[out_idx] = host[in_idx];
                        }
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6824
#[cfg(feature = "cuda")]
/// Multiplies an `m x k` matrix `a` by a `k x n` matrix `b` using a simple
/// one-thread-per-output-element kernel intended for small problems.
///
/// If the PTX kernel cannot be compiled, delegates to the general BLAS path
/// (`crate::blas::gpu_matmul_f32`).
/// NOTE(review): unlike the sibling ops, this performs no `validate_*`
/// device/length check on `a`/`b` before launching — confirm callers
/// guarantee both buffers live on `device` with lengths `m*k` and `k*n`.
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Fallback: reuse the general cuBLAS-backed matmul.
            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
        }
    };

    let mut c = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
6883
#[cfg(feature = "cuda")]
/// Batched matrix multiply for small problem sizes.
///
/// A batch of one is routed to [`gpu_small_matmul`]; anything larger is
/// handed to the general batched BLAS path.
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        // Single batch: identical to an ordinary small matmul.
        1 => gpu_small_matmul(a, b, m, k, n, device),
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
6910
#[cfg(feature = "cuda")]
/// Looks up a single embedding row: reads an index from `idx` and copies the
/// corresponding `d`-element row out of `weight`.
///
/// NOTE(review): the CPU fallback reads `idx_host[0]` without checking that
/// `idx` is non-empty (it would panic on an empty buffer), and performs no
/// bounds check on `row * d + d <= weight.len()` — confirm callers validate
/// these. Presumably the GPU kernel handles exactly one index the same way.
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: fetch index and weight table, slice out one row.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            // Index is stored as f32 and truncated to usize here.
            let row = idx_host[0] as usize;
            let start = row * d;
            let out = weight_host[start..start + d].to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
6964
#[cfg(feature = "cuda")]
/// Writes one time-step `src` (shape `n_batch x d`) into position `pos` of a
/// padded sequence buffer `dst` (shape `n_batch x max_len x d`) in place.
///
/// The CPU fallback round-trips both buffers through the host and then
/// replaces `dst` with a freshly uploaded buffer.
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy src rows into the pos-th slot of each batch.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            // Replace dst wholesale rather than writing in place.
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
7029
#[cfg(feature = "cuda")]
/// Reads the first `len` time steps out of a padded sequence buffer `src`
/// (shape `n_batch x max_len x d`), returning a dense `n_batch x len x d`
/// buffer.
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather the leading `len` rows of each batch.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f32; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7092
#[cfg(feature = "cuda")]
/// GELU activation (sigmoid approximation) over every element of `input`.
///
/// Launches the GPU kernel when available; otherwise computes
/// `x * sigmoid(1.702 * x)` element-by-element on the host.
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            // Sigmoid-approximation GELU: x * sigmoid(1.702 * x).
            let s = 1.0 / (1.0 + (-1.702 * x).exp());
            x * s
        }),
    }
}
7109
#[cfg(feature = "cuda")]
/// GELU activation (tanh approximation) over every element of `input`.
///
/// GPU kernel when available; otherwise the host computes
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            // Constants of the standard tanh-GELU approximation.
            let sqrt_2_over_pi: f32 = 0.7978845608;
            let c: f32 = 0.044715;
            let inner = sqrt_2_over_pi * (x + c * x * x * x);
            0.5 * x * (1.0 + inner.tanh())
        }),
    }
}
7127
#[cfg(feature = "cuda")]
/// Exact-form GELU (`x * 0.5 * (1 + erf(x / sqrt(2)))`) over every element.
///
/// The host fallback evaluates erf via the Abramowitz–Stegun 7.1.26
/// polynomial approximation.
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let z = x * std::f32::consts::FRAC_1_SQRT_2;
            let az = z.abs();
            // Abramowitz–Stegun 7.1.26 rational approximation of erf(|z|).
            let t = 1.0 / (1.0 + 0.3275911 * az);
            let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            // erf is odd: restore the sign of z.
            let erf_val = if z < 0.0 { -erf_abs } else { erf_abs };
            x * 0.5 * (1.0 + erf_val)
        }),
    }
}
7149
#[cfg(feature = "cuda")]
/// Backward pass of the tanh-approximation GELU: given upstream gradient
/// `grad` and the forward input `input`, returns `grad * dGELU/dx`.
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )? {
        return Ok(out);
    }
    // CPU fallback: analytic derivative of 0.5*x*(1 + tanh(u)),
    // u = sqrt(2/pi) * (x + 0.044715 x^3).
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let sqrt_2_over_pi: f32 = 0.7978845608;
            let c: f32 = 0.044715;
            // c3 = 3 * c, coefficient of d(x + c x^3)/dx.
            let c3: f32 = 0.134145;
            let u = sqrt_2_over_pi * (x + c * x * x * x);
            let t = u.tanh();
            // sech^2(u) = 1 - tanh^2(u).
            let dt = 1.0 - t * t;
            let d_inner = sqrt_2_over_pi * (1.0 + c3 * x * x);
            // Product rule: 0.5*(1+t) + 0.5*x * sech^2(u) * u'.
            g * (0.5 * (1.0 + t) + 0.5 * x * dt * d_inner)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
7188
#[cfg(feature = "cuda")]
/// Elementwise division `a / b` on the GPU, with a host fallback when the
/// PTX kernel is unavailable.
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    if let Some(out) = try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        return Ok(out);
    }

    // CPU fallback: download both operands and divide pairwise.
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut quotient = Vec::with_capacity(lhs.len());
    for (x, y) in lhs.iter().zip(rhs.iter()) {
        quotient.push(x / y);
    }
    cpu_to_gpu(&quotient, device)
}
7216
#[cfg(feature = "cuda")]
/// Elementwise `exp(x)` with host fallback.
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.exp()),
    }
}
7226
#[cfg(feature = "cuda")]
/// Elementwise natural logarithm with host fallback.
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.ln()),
    }
}
7236
#[cfg(feature = "cuda")]
/// Elementwise square root with host fallback.
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.sqrt()),
    }
}
7246
#[cfg(feature = "cuda")]
/// Raises every element of `a` to the power `exponent`, with a host
/// `powf` fallback when PTX compilation is unavailable.
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: elementwise powf.
            let host = gpu_to_cpu(a, device)?;
            let result: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
7292
#[cfg(feature = "cuda")]
/// Elementwise absolute value with host fallback.
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.abs()),
    }
}
7302
#[cfg(feature = "cuda")]
/// Elementwise logistic sigmoid with host fallback.
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
7312
#[cfg(feature = "cuda")]
/// Elementwise hyperbolic tangent with host fallback.
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.tanh()),
    }
}
7322
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Fused Adam/AdamW optimizer step, updating `param`, `exp_avg` (first
/// moment) and `exp_avg_sq` (second moment) in place.
///
/// `bc1`/`bc2` are the bias-correction denominators (`1 - beta1^t`,
/// `1 - beta2^t`); `weight_decay > 0.0` adds `weight_decay * param` to the
/// gradient before the moment updates.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] when any of `grad`, `exp_avg` or
/// `exp_avg_sq` differs in length from `param`. Falls back to an equivalent
/// host-side update when PTX compilation is unavailable.
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    // FIX: report the length of the buffer that actually disagrees with
    // `param`. Previously `b` was always `grad.len()`, which can equal `n`
    // when the mismatch is in one of the moment buffers, yielding a
    // nonsensical "n != n" error.
    if let Some(b) = [grad.len(), exp_avg.len(), exp_avg_sq.len()]
        .into_iter()
        .find(|&len| len != n)
    {
        return Err(GpuError::LengthMismatch { a: n, b });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: standard Adam update, then re-upload all state.
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

            for i in 0..n {
                let mut g = g_host[i];
                if weight_decay > 0.0 {
                    g += weight_decay * p_host[i];
                }
                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
                // Bias-corrected moment estimates.
                let m_hat = m_host[i] / bc1;
                let v_hat = v_host[i] / bc2;
                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
            }

            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
7417
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`]. Signature mirrors the CUDA implementation.
pub fn gpu_fused_adam(
    _param: &mut CudaBuffer<f32>,
    _grad: &CudaBuffer<f32>,
    _exp_avg: &mut CudaBuffer<f32>,
    _exp_avg_sq: &mut CudaBuffer<f32>,
    _beta1: f32,
    _beta2: f32,
    _lr: f32,
    _eps: f32,
    _bc1: f32,
    _bc2: f32,
    _weight_decay: f32,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
7437
#[cfg(feature = "cuda")]
/// Fused GRU cell forward pass.
///
/// Combines the precomputed input-side (`input_gates`) and hidden-side
/// (`hidden_gates`) gate projections with their biases and the previous
/// hidden state `hx`, producing the new hidden state `hy` plus a workspace
/// buffer (`batch * 5 * hsz` floats) of intermediates for the backward pass.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] when `hsz == 0`, and
/// [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled
/// (no CPU fallback is provided).
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    // FIX: guard against hsz == 0, which previously caused a
    // divide-by-zero panic computing `batch` below.
    if hsz == 0 {
        return Err(GpuError::ShapeMismatch {
            op: "fused_gru_forward",
            expected: vec![1],
            got: vec![hsz],
        });
    }

    let total = hx.len();
    let batch = total / hsz;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "fused_gru_forward_kernel",
            });
        }
    };

    let mut hy = alloc_zeros_f32(total, device)?;
    // Workspace keeps 5 intermediate values per hidden unit for backward.
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
7511
#[cfg(not(feature = "cuda"))]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_fused_gru_forward(
    _input_gates: &CudaBuffer<f32>,
    _hidden_gates: &CudaBuffer<f32>,
    _bias_ih: &CudaBuffer<f32>,
    _bias_hh: &CudaBuffer<f32>,
    _hx: &CudaBuffer<f32>,
    _hsz: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
7525
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// 2-D max pooling over an NCHW tensor with kernel `(kh, kw)`, stride
/// `(sh, sw)` and zero padding `(ph, pw)`. Returns the pooled buffer plus
/// its `[batch, channels, h_out, w_out]` shape.
///
/// NOTE(review): `h_in + 2*ph - kh` (and the `w` analogue) underflows and
/// panics in usize arithmetic when the kernel is larger than the padded
/// input — confirm callers validate pooling geometry first. No CPU
/// fallback: PTX compile failure is returned as an error.
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Standard pooling output-size formula (floor division).
    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
7590
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_maxpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
7602
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// 2-D average pooling over an NCHW tensor with kernel `(kh, kw)`, stride
/// `(sh, sw)` and zero padding `(ph, pw)`. Returns the pooled buffer plus
/// its `[batch, channels, h_out, w_out]` shape.
///
/// NOTE(review): same usize-underflow hazard as `gpu_maxpool2d` when the
/// kernel exceeds the padded input — confirm callers validate geometry.
/// No CPU fallback: PTX compile failure is returned as an error.
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Standard pooling output-size formula (floor division).
    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
7663
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_avgpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
7675
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Batch-norm forward pass — **placeholder, not yet implemented**.
///
/// Currently warms the module cache by attempting to compile the
/// batchnorm kernel, discards the result, and always returns an error.
/// NOTE(review): the `ShapeMismatch { expected: [0], got: [1] }` payload is
/// a dummy value, not a real shape diagnostic — callers should treat this
/// op as unavailable until implemented.
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    let ctx = device.context();
    // Compile (and cache) the kernel even though it is not launched yet.
    let _f = crate::module_cache::get_or_compile(
        ctx,
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );
    Err(GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    })
}
7713
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
7732
#[cfg(feature = "cuda")]
/// Layer normalization over each of `rows` rows of `cols` elements:
/// `out = weight * (x - mean) / sqrt(var + eps) + bias`, with per-row
/// mean/variance. Launches one block per row with 256 threads and a
/// 256-float shared-memory scratch area.
///
/// NOTE(review): on PTX compile failure this dumps the kernel source to
/// /tmp/layernorm_debug.ptx and logs to stderr — a side effect worth
/// confirming is intentional in library code — then falls back to a
/// host-side implementation.
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            // Debug aid: persist the failing PTX for offline inspection.
            std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).ok();
            eprintln!(
                "ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
                LAYERNORM_PTX.len()
            );
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                // Per-row statistics (biased variance, divisor = cols).
                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
                let var: f32 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads; 256 floats of shared memory for the
    // block-level reduction.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
7817
#[cfg(feature = "cuda")]
/// Layer-norm backward pass. Given the forward `input`, upstream
/// `grad_output` and the `weight` (gamma) vector, returns
/// `(grad_input, grad_weight, grad_bias)` where grad_weight/grad_bias are
/// accumulated over all rows.
///
/// Falls back to a host implementation of the standard layer-norm gradient
/// when PTX compilation is unavailable.
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: recompute per-row statistics, then apply the
            // closed-form layer-norm gradient.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                // Same biased mean/variance as the forward pass.
                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
                let var: f32 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                // First pass: row sums of dL (= go * gamma) and dL * x_hat,
                // plus accumulation of the parameter gradients.
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                // Second pass: grad_x = inv_std * (dl - mean(dl) - x_hat * mean(dl*x_hat)).
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads; 256 floats of shared memory, matching
    // the forward-pass launch shape.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
7929
#[cfg(not(feature = "cuda"))]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_layernorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
7943
#[cfg(feature = "cuda")]
/// Elementwise addition written into a caller-provided output buffer
/// (`out` may be longer than `a`; only the first `a.len()` slots are used).
///
/// Unlike the allocating variants there is no CPU fallback: a PTX compile
/// failure is reported as an error.
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    let needed = a.len();
    if out.len() < needed {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![needed],
            got: vec![out.len()],
        });
    }
    match try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
        true => Ok(()),
        false => Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        }),
    }
}
7975
#[cfg(feature = "cuda")]
/// Elementwise multiplication written into a caller-provided output buffer
/// (`out` may be longer than `a`; only the first `a.len()` slots are used).
///
/// No CPU fallback: a PTX compile failure is reported as an error.
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;
    let needed = a.len();
    if out.len() < needed {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![needed],
            got: vec![out.len()],
        });
    }
    match try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
        true => Ok(()),
        false => Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        }),
    }
}
7999
#[cfg(feature = "cuda")]
/// Multiplies every element of `a` by `scalar`, writing the result into the
/// caller-provided `out` buffer.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] when `out` is shorter than `a`, and
/// [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled
/// (no CPU fallback).
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    let n = a.len();
    // FIX: reject an undersized output buffer, matching the checks in
    // gpu_add_into/gpu_mul_into. Previously the kernel could be launched
    // with out.len() < n, writing past the output allocation.
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
8035
#[cfg(feature = "cuda")]
/// Returns `true` if any element of `a` is NaN or infinite.
///
/// Implemented by downloading the buffer and scanning on the host; an
/// empty buffer trivially contains no non-finite values.
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    if a.len() == 0 {
        return Ok(false);
    }

    validate_unary(a, device)?;

    // `is_finite` is false for NaN, +inf and -inf alike.
    let values: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
    let all_finite = values.iter().all(|v| v.is_finite());
    Ok(!all_finite)
}
8064
#[cfg(not(feature = "cuda"))]
/// Stub used when the `cuda` feature is disabled; always fails with
/// [`GpuError::NoCudaFeature`].
pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
8070
#[cfg(feature = "cuda")]
/// GELU activation written into a caller-provided output buffer.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] when `out` is shorter than `a`, and
/// [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled
/// (no CPU fallback).
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;
    // FIX: reject an undersized output buffer, matching the checks in
    // gpu_add_into/gpu_mul_into. Previously the kernel could be launched
    // with out.len() < a.len(), writing past the output allocation.
    if out.len() < a.len() {
        return Err(GpuError::ShapeMismatch {
            op: "gelu_into",
            expected: vec![a.len()],
            got: vec![out.len()],
        });
    }
    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
        return Ok(());
    }
    Err(GpuError::PtxCompileFailed {
        kernel: "gelu_kernel",
    })
}
8086
#[cfg(feature = "cuda")]
/// Single-row embedding lookup written into a caller-provided output buffer
/// of at least `d` elements.
///
/// NOTE(review): unlike the `*_into` arithmetic ops there is no check that
/// `out.len() >= d`, and no `validate_*` call on `idx`/`weight` — confirm
/// callers guarantee both. No CPU fallback: PTX compile failure is an error.
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;
    // Argument order must match the kernel's parameter list.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }
    Ok(())
}
8121
8122#[cfg(feature = "cuda")]
8130pub fn gpu_embed_lookup_batch(
8131 indices: &CudaBuffer<f32>,
8132 weight: &CudaBuffer<f32>,
8133 n: usize,
8134 d: usize,
8135 device: &GpuDevice,
8136) -> GpuResult<CudaBuffer<f32>> {
8137 use cudarc::driver::PushKernelArg;
8138
8139 let total = n * d;
8140 if total == 0 {
8141 return alloc_zeros_f32(0, device);
8142 }
8143
8144 let ctx = device.context();
8145 let stream = device.stream();
8146
8147 let f = match crate::module_cache::get_or_compile(
8148 ctx,
8149 EMBED_LOOKUP_BATCH_PTX,
8150 "embed_lookup_batch_kernel",
8151 device.ordinal() as u32,
8152 ) {
8153 Ok(f) => f,
8154 Err(_) => {
8155 let idx_host = gpu_to_cpu(indices, device)?;
8157 let weight_host = gpu_to_cpu(weight, device)?;
8158 let mut out = Vec::with_capacity(total);
8159 for &idx_f in &idx_host {
8160 let row = idx_f as usize;
8161 let start = row * d;
8162 out.extend_from_slice(&weight_host[start..start + d]);
8163 }
8164 return cpu_to_gpu(&out, device);
8165 }
8166 };
8167
8168 let mut out = alloc_zeros_f32(total, device)?;
8169 let cfg = launch_cfg(total)?;
8170 let d_u32 = d as u32;
8171 let total_u32 = total as u32;
8172
8173 unsafe {
8174 stream
8175 .launch_builder(&f)
8176 .arg(indices.inner())
8177 .arg(weight.inner())
8178 .arg(out.inner_mut())
8179 .arg(&d_u32)
8180 .arg(&total_u32)
8181 .launch(cfg)?;
8182 }
8183
8184 Ok(out)
8185}
8186
8187#[cfg(feature = "cuda")]
8197pub fn gpu_scatter_add_rows(
8198 grad_output: &CudaBuffer<f32>,
8199 indices: &CudaBuffer<f32>,
8200 num_embeddings: usize,
8201 d: usize,
8202 device: &GpuDevice,
8203) -> GpuResult<CudaBuffer<f32>> {
8204 use cudarc::driver::PushKernelArg;
8205
8206 let n = indices.len();
8207 let total = n * d;
8208
8209 if total == 0 {
8210 return alloc_zeros_f32(num_embeddings * d, device);
8211 }
8212
8213 let ctx = device.context();
8214 let stream = device.stream();
8215
8216 let f = match crate::module_cache::get_or_compile(
8217 ctx,
8218 SCATTER_ADD_ROWS_PTX,
8219 "scatter_add_rows_kernel",
8220 device.ordinal() as u32,
8221 ) {
8222 Ok(f) => f,
8223 Err(_) => {
8224 let go_host = gpu_to_cpu(grad_output, device)?;
8226 let idx_host = gpu_to_cpu(indices, device)?;
8227 let mut result = vec![0.0f32; num_embeddings * d];
8228 for (i, &idx_f) in idx_host.iter().enumerate() {
8229 let row = idx_f as usize;
8230 for j in 0..d {
8231 result[row * d + j] += go_host[i * d + j];
8232 }
8233 }
8234 return cpu_to_gpu(&result, device);
8235 }
8236 };
8237
8238 let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
8239 let cfg = launch_cfg(total)?;
8240 let d_u32 = d as u32;
8241 let total_u32 = total as u32;
8242
8243 unsafe {
8244 stream
8245 .launch_builder(&f)
8246 .arg(grad_output.inner())
8247 .arg(indices.inner())
8248 .arg(out.inner_mut())
8249 .arg(&d_u32)
8250 .arg(&total_u32)
8251 .launch(cfg)?;
8252 }
8253
8254 Ok(out)
8255}
8256
8257#[cfg(feature = "cuda")]
8259pub fn gpu_transpose_2d_into(
8260 a: &CudaBuffer<f32>,
8261 m: usize,
8262 n: usize,
8263 out: &mut CudaBuffer<f32>,
8264 device: &GpuDevice,
8265) -> GpuResult<()> {
8266 use cudarc::driver::PushKernelArg;
8267 let total = m * n;
8268 let ctx = device.context();
8269 let stream = device.stream();
8270 let f = crate::module_cache::get_or_compile(
8271 ctx,
8272 TRANSPOSE_2D_PTX,
8273 "transpose_2d_kernel",
8274 device.ordinal() as u32,
8275 )
8276 .map_err(|_| GpuError::PtxCompileFailed {
8277 kernel: "transpose_2d_kernel",
8278 })?;
8279 let cfg = launch_cfg(total)?;
8280 let m_u32 = m as u32;
8281 let n_u32 = n as u32;
8282 let total_u32 = total as u32;
8283 unsafe {
8284 stream
8285 .launch_builder(&f)
8286 .arg(a.inner())
8287 .arg(out.inner_mut())
8288 .arg(&m_u32)
8289 .arg(&n_u32)
8290 .arg(&total_u32)
8291 .launch(cfg)?;
8292 }
8293 Ok(())
8294}
8295
8296#[cfg(feature = "cuda")]
8298pub fn gpu_permute_0213_into(
8299 a: &CudaBuffer<f32>,
8300 d0: usize,
8301 d1: usize,
8302 d2: usize,
8303 d3: usize,
8304 out: &mut CudaBuffer<f32>,
8305 device: &GpuDevice,
8306) -> GpuResult<()> {
8307 use cudarc::driver::PushKernelArg;
8308 let total = d0 * d1 * d2 * d3;
8309 let ctx = device.context();
8310 let stream = device.stream();
8311 let f = crate::module_cache::get_or_compile(
8312 ctx,
8313 PERMUTE_0213_PTX,
8314 "permute_0213_kernel",
8315 device.ordinal() as u32,
8316 )
8317 .map_err(|_| GpuError::PtxCompileFailed {
8318 kernel: "permute_0213_kernel",
8319 })?;
8320 let cfg = launch_cfg(total)?;
8321 let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);
8322 unsafe {
8323 stream
8324 .launch_builder(&f)
8325 .arg(a.inner())
8326 .arg(out.inner_mut())
8327 .arg(&d0u)
8328 .arg(&d1u)
8329 .arg(&d2u)
8330 .arg(&d3u)
8331 .arg(&tu)
8332 .launch(cfg)?;
8333 }
8334 Ok(())
8335}
8336
8337#[cfg(feature = "cuda")]
8339pub fn gpu_softmax_into(
8340 a: &CudaBuffer<f32>,
8341 rows: usize,
8342 cols: usize,
8343 out: &mut CudaBuffer<f32>,
8344 device: &GpuDevice,
8345) -> GpuResult<()> {
8346 use cudarc::driver::PushKernelArg;
8347 let ctx = device.context();
8348 let stream = device.stream();
8349 let f = crate::module_cache::get_or_compile(
8350 ctx,
8351 SOFTMAX_PTX,
8352 "softmax_kernel",
8353 device.ordinal() as u32,
8354 )
8355 .map_err(|_| GpuError::PtxCompileFailed {
8356 kernel: "softmax_kernel",
8357 })?;
8358 let block_size = 256u32;
8359 let grid_size = rows as u32;
8360 let cfg = LaunchConfig {
8361 grid_dim: (grid_size, 1, 1),
8362 block_dim: (block_size, 1, 1),
8363 shared_mem_bytes: (cols as u32) * 4,
8364 };
8365 let rows_u32 = rows as u32;
8366 let cols_u32 = cols as u32;
8367 unsafe {
8368 stream
8369 .launch_builder(&f)
8370 .arg(a.inner())
8371 .arg(out.inner_mut())
8372 .arg(&rows_u32)
8373 .arg(&cols_u32)
8374 .launch(cfg)?;
8375 }
8376 Ok(())
8377}
8378
8379#[cfg(feature = "cuda")]
8381#[allow(clippy::too_many_arguments)]
8382pub fn gpu_layernorm_into(
8383 input: &CudaBuffer<f32>,
8384 weight: &CudaBuffer<f32>,
8385 bias: &CudaBuffer<f32>,
8386 rows: usize,
8387 cols: usize,
8388 eps: f32,
8389 out: &mut CudaBuffer<f32>,
8390 device: &GpuDevice,
8391) -> GpuResult<()> {
8392 use cudarc::driver::PushKernelArg;
8393 let ctx = device.context();
8394 let stream = device.stream();
8395 let f = crate::module_cache::get_or_compile(
8396 ctx,
8397 LAYERNORM_PTX,
8398 "layernorm_kernel",
8399 device.ordinal() as u32,
8400 )
8401 .map_err(|_| GpuError::PtxCompileFailed {
8402 kernel: "layernorm_kernel",
8403 })?;
8404 let block_size = 256u32;
8405 let grid_size = rows as u32;
8406 let cfg = LaunchConfig {
8407 grid_dim: (grid_size, 1, 1),
8408 block_dim: (block_size, 1, 1),
8409 shared_mem_bytes: (cols as u32) * 4,
8410 };
8411 let rows_u32 = rows as u32;
8412 let cols_u32 = cols as u32;
8413 unsafe {
8414 stream
8415 .launch_builder(&f)
8416 .arg(input.inner())
8417 .arg(out.inner_mut())
8418 .arg(weight.inner())
8419 .arg(bias.inner())
8420 .arg(&rows_u32)
8421 .arg(&cols_u32)
8422 .arg(&eps)
8423 .launch(cfg)?;
8424 }
8425 Ok(())
8426}
8427
8428#[cfg(feature = "cuda")]
8431pub fn gpu_slice_read_into(
8432 src: &CudaBuffer<f32>,
8433 n_batch: usize,
8434 d: usize,
8435 len: usize,
8436 max_len: usize,
8437 out: &mut CudaBuffer<f32>,
8438 device: &GpuDevice,
8439) -> GpuResult<()> {
8440 use cudarc::driver::PushKernelArg;
8441 let total = n_batch * len * d;
8442 let ctx = device.context();
8443 let stream = device.stream();
8444 let f = crate::module_cache::get_or_compile(
8445 ctx,
8446 SLICE_READ_PTX,
8447 "slice_read_kernel",
8448 device.ordinal() as u32,
8449 )
8450 .map_err(|_| GpuError::PtxCompileFailed {
8451 kernel: "slice_read_kernel",
8452 })?;
8453 let cfg = launch_cfg(total)?;
8454 let total_u32 = total as u32;
8455 let d_u32 = d as u32;
8456 let len_u32 = len as u32;
8457 let max_len_u32 = max_len as u32;
8458 unsafe {
8459 stream
8460 .launch_builder(&f)
8461 .arg(src.inner())
8462 .arg(out.inner_mut())
8463 .arg(&total_u32)
8464 .arg(&d_u32)
8465 .arg(&len_u32)
8466 .arg(&max_len_u32)
8467 .launch(cfg)?;
8468 }
8469 Ok(())
8470}
8471
8472#[cfg(feature = "cuda")]
8474pub fn gpu_small_matmul_into(
8475 a: &CudaBuffer<f32>,
8476 b: &CudaBuffer<f32>,
8477 m: usize,
8478 k: usize,
8479 n: usize,
8480 out: &mut CudaBuffer<f32>,
8481 device: &GpuDevice,
8482) -> GpuResult<()> {
8483 use cudarc::driver::PushKernelArg;
8484 let total = m * n;
8485 let ctx = device.context();
8486 let stream = device.stream();
8487 let f = crate::module_cache::get_or_compile(
8488 ctx,
8489 SMALL_MATMUL_PTX,
8490 "small_matmul_kernel",
8491 device.ordinal() as u32,
8492 )
8493 .map_err(|_| GpuError::PtxCompileFailed {
8494 kernel: "small_matmul_kernel",
8495 })?;
8496 let cfg = launch_cfg(total)?;
8497 let (m_u32, k_u32, n_u32, total_u32) = (m as u32, k as u32, n as u32, total as u32);
8498 unsafe {
8499 stream
8500 .launch_builder(&f)
8501 .arg(a.inner())
8502 .arg(b.inner())
8503 .arg(out.inner_mut())
8504 .arg(&m_u32)
8505 .arg(&k_u32)
8506 .arg(&n_u32)
8507 .arg(&total_u32)
8508 .launch(cfg)?;
8509 }
8510 Ok(())
8511}
8512
8513#[cfg(feature = "cuda")]
8520pub fn gpu_slice_write_indirect(
8521 src: &CudaBuffer<f32>,
8522 dst: &mut CudaBuffer<f32>,
8523 n_batch: usize,
8524 d: usize,
8525 max_len: usize,
8526 pos_ptr: &cudarc::driver::CudaSlice<u32>,
8527 device: &GpuDevice,
8528) -> GpuResult<()> {
8529 use cudarc::driver::PushKernelArg;
8530 let total = n_batch * d;
8531 let ctx = device.context();
8532 let stream = device.stream();
8533 let f = crate::module_cache::get_or_compile(
8534 ctx,
8535 SLICE_WRITE_INDIRECT_PTX,
8536 "slice_write_indirect_kernel",
8537 device.ordinal() as u32,
8538 )
8539 .map_err(|_| GpuError::PtxCompileFailed {
8540 kernel: "slice_write_indirect_kernel",
8541 })?;
8542 let cfg = launch_cfg(total)?;
8543 let n_u32 = total as u32;
8544 let d_u32 = d as u32;
8545 let max_len_u32 = max_len as u32;
8546 unsafe {
8547 stream
8548 .launch_builder(&f)
8549 .arg(src.inner())
8550 .arg(dst.inner_mut())
8551 .arg(&n_u32)
8552 .arg(&d_u32)
8553 .arg(&max_len_u32)
8554 .arg(pos_ptr)
8555 .launch(cfg)?;
8556 }
8557 Ok(())
8558}
8559
8560#[cfg(feature = "cuda")]
8564pub fn gpu_causal_mask_indirect(
8565 total_len_ptr: &cudarc::driver::CudaSlice<u32>,
8566 n_head: usize,
8567 max_pos: usize,
8568 out: &mut CudaBuffer<f32>,
8569 device: &GpuDevice,
8570) -> GpuResult<()> {
8571 use cudarc::driver::PushKernelArg;
8572 let total = n_head * max_pos;
8573 let ctx = device.context();
8574 let stream = device.stream();
8575 let f = crate::module_cache::get_or_compile(
8576 ctx,
8577 CAUSAL_MASK_INDIRECT_PTX,
8578 "causal_mask_indirect_kernel",
8579 device.ordinal() as u32,
8580 )
8581 .map_err(|_| GpuError::PtxCompileFailed {
8582 kernel: "causal_mask_indirect_kernel",
8583 })?;
8584 let cfg = launch_cfg(total)?;
8585 let max_pos_u32 = max_pos as u32;
8586 let total_u32 = total as u32;
8587 unsafe {
8588 stream
8589 .launch_builder(&f)
8590 .arg(total_len_ptr)
8591 .arg(out.inner_mut())
8592 .arg(&max_pos_u32)
8593 .arg(&total_u32)
8594 .launch(cfg)?;
8595 }
8596 Ok(())
8597}
8598
8599#[cfg(feature = "cuda")]
8607pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
8608 let ctx = device.context();
8609 ctx.bind_to_thread()?;
8610 let ord = device.ordinal() as u32;
8611 let compile = |ptx: &'static str, name: &'static str| -> GpuResult<()> {
8612 crate::module_cache::get_or_compile(ctx, ptx, name, ord)
8613 .map(|_| ())
8614 .map_err(GpuError::Driver)
8615 };
8616 compile(ADD_PTX, "add_kernel")?;
8617 compile(MUL_PTX, "mul_kernel")?;
8618 compile(SCALE_PTX, "scale_kernel")?;
8619 compile(GELU_PTX, "gelu_kernel")?;
8620 compile(SOFTMAX_PTX, "softmax_kernel")?;
8621 compile(LAYERNORM_PTX, "layernorm_kernel")?;
8622 compile(PERMUTE_0213_PTX, "permute_0213_kernel")?;
8623 compile(EMBED_LOOKUP_PTX, "embed_lookup_kernel")?;
8624 compile(EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel")?;
8625 compile(SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel")?;
8626 compile(SMALL_MATMUL_PTX, "small_matmul_kernel")?;
8627 compile(SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel")?;
8628 compile(CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel")?;
8629 compile(SLICE_READ_PTX, "slice_read_kernel")?;
8630 compile(RELU_BACKWARD_PTX, "relu_backward_kernel")?;
8631 compile(GELU_BACKWARD_PTX, "gelu_backward_kernel")?;
8632 Ok(())
8633}
8634
8635#[cfg(not(feature = "cuda"))]
8637pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
8638 Err(GpuError::NoCudaFeature)
8639}
8640
8641#[cfg(not(feature = "cuda"))]
8647pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8648 Err(GpuError::NoCudaFeature)
8649}
8650
8651#[cfg(not(feature = "cuda"))]
8653pub fn gpu_gelu_tanh(
8654 _input: &CudaBuffer<f32>,
8655 _device: &GpuDevice,
8656) -> GpuResult<CudaBuffer<f32>> {
8657 Err(GpuError::NoCudaFeature)
8658}
8659
8660#[cfg(not(feature = "cuda"))]
8662pub fn gpu_gelu_erf(
8663 _input: &CudaBuffer<f32>,
8664 _device: &GpuDevice,
8665) -> GpuResult<CudaBuffer<f32>> {
8666 Err(GpuError::NoCudaFeature)
8667}
8668
8669#[cfg(not(feature = "cuda"))]
8671pub fn gpu_gelu_backward_tanh(
8672 _grad: &CudaBuffer<f32>,
8673 _input: &CudaBuffer<f32>,
8674 _device: &GpuDevice,
8675) -> GpuResult<CudaBuffer<f32>> {
8676 Err(GpuError::NoCudaFeature)
8677}
8678
8679#[cfg(not(feature = "cuda"))]
8681pub fn gpu_div(
8682 _a: &CudaBuffer<f32>,
8683 _b: &CudaBuffer<f32>,
8684 _device: &GpuDevice,
8685) -> GpuResult<CudaBuffer<f32>> {
8686 Err(GpuError::NoCudaFeature)
8687}
8688
8689#[cfg(not(feature = "cuda"))]
8691pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8692 Err(GpuError::NoCudaFeature)
8693}
8694
8695#[cfg(not(feature = "cuda"))]
8697pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8698 Err(GpuError::NoCudaFeature)
8699}
8700
8701#[cfg(not(feature = "cuda"))]
8703pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8704 Err(GpuError::NoCudaFeature)
8705}
8706
8707#[cfg(not(feature = "cuda"))]
8709pub fn gpu_pow(
8710 _a: &CudaBuffer<f32>,
8711 _exponent: f32,
8712 _device: &GpuDevice,
8713) -> GpuResult<CudaBuffer<f32>> {
8714 Err(GpuError::NoCudaFeature)
8715}
8716
8717#[cfg(not(feature = "cuda"))]
8719pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8720 Err(GpuError::NoCudaFeature)
8721}
8722
8723#[cfg(not(feature = "cuda"))]
8725pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8726 Err(GpuError::NoCudaFeature)
8727}
8728
8729#[cfg(not(feature = "cuda"))]
8731pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8732 Err(GpuError::NoCudaFeature)
8733}
8734
8735#[cfg(not(feature = "cuda"))]
8737pub fn gpu_layernorm(
8738 _input: &CudaBuffer<f32>,
8739 _weight: &CudaBuffer<f32>,
8740 _bias: &CudaBuffer<f32>,
8741 _rows: usize,
8742 _cols: usize,
8743 _eps: f32,
8744 _device: &GpuDevice,
8745) -> GpuResult<CudaBuffer<f32>> {
8746 Err(GpuError::NoCudaFeature)
8747}
8748
8749#[cfg(not(feature = "cuda"))]
8751pub fn gpu_transpose_2d(
8752 _input: &CudaBuffer<f32>,
8753 _m: usize,
8754 _n: usize,
8755 _device: &GpuDevice,
8756) -> GpuResult<CudaBuffer<f32>> {
8757 Err(GpuError::NoCudaFeature)
8758}
8759
8760#[cfg(not(feature = "cuda"))]
8762pub fn gpu_add(
8763 _a: &CudaBuffer<f32>,
8764 _b: &CudaBuffer<f32>,
8765 _device: &GpuDevice,
8766) -> GpuResult<CudaBuffer<f32>> {
8767 Err(GpuError::NoCudaFeature)
8768}
8769
8770#[cfg(not(feature = "cuda"))]
8772pub fn gpu_sub(
8773 _a: &CudaBuffer<f32>,
8774 _b: &CudaBuffer<f32>,
8775 _device: &GpuDevice,
8776) -> GpuResult<CudaBuffer<f32>> {
8777 Err(GpuError::NoCudaFeature)
8778}
8779
8780#[cfg(not(feature = "cuda"))]
8782pub fn gpu_mul(
8783 _a: &CudaBuffer<f32>,
8784 _b: &CudaBuffer<f32>,
8785 _device: &GpuDevice,
8786) -> GpuResult<CudaBuffer<f32>> {
8787 Err(GpuError::NoCudaFeature)
8788}
8789
8790#[cfg(not(feature = "cuda"))]
8792pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8793 Err(GpuError::NoCudaFeature)
8794}
8795
8796#[cfg(not(feature = "cuda"))]
8798pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
8799 Err(GpuError::NoCudaFeature)
8800}
8801
8802#[cfg(not(feature = "cuda"))]
8804pub fn gpu_scale(
8805 _a: &CudaBuffer<f32>,
8806 _scalar: f32,
8807 _device: &GpuDevice,
8808) -> GpuResult<CudaBuffer<f32>> {
8809 Err(GpuError::NoCudaFeature)
8810}
8811
8812#[cfg(not(feature = "cuda"))]
8814pub fn gpu_broadcast_add(
8815 _a: &CudaBuffer<f32>,
8816 _b: &CudaBuffer<f32>,
8817 _a_shape: &[usize],
8818 _b_shape: &[usize],
8819 _out_shape: &[usize],
8820 _device: &GpuDevice,
8821) -> GpuResult<CudaBuffer<f32>> {
8822 Err(GpuError::NoCudaFeature)
8823}
8824
8825#[cfg(not(feature = "cuda"))]
8827pub fn gpu_broadcast_sub(
8828 _a: &CudaBuffer<f32>,
8829 _b: &CudaBuffer<f32>,
8830 _a_shape: &[usize],
8831 _b_shape: &[usize],
8832 _out_shape: &[usize],
8833 _device: &GpuDevice,
8834) -> GpuResult<CudaBuffer<f32>> {
8835 Err(GpuError::NoCudaFeature)
8836}
8837
8838#[cfg(not(feature = "cuda"))]
8840pub fn gpu_broadcast_mul(
8841 _a: &CudaBuffer<f32>,
8842 _b: &CudaBuffer<f32>,
8843 _a_shape: &[usize],
8844 _b_shape: &[usize],
8845 _out_shape: &[usize],
8846 _device: &GpuDevice,
8847) -> GpuResult<CudaBuffer<f32>> {
8848 Err(GpuError::NoCudaFeature)
8849}
8850
8851#[cfg(not(feature = "cuda"))]
8853pub fn gpu_softmax(
8854 _input: &CudaBuffer<f32>,
8855 _rows: usize,
8856 _cols: usize,
8857 _device: &GpuDevice,
8858) -> GpuResult<CudaBuffer<f32>> {
8859 Err(GpuError::NoCudaFeature)
8860}
8861
8862#[cfg(not(feature = "cuda"))]
8864pub fn gpu_dropout(
8865 _input: &CudaBuffer<f32>,
8866 _threshold: u32,
8867 _scale: f32,
8868 _seed: u32,
8869 _device: &GpuDevice,
8870) -> GpuResult<CudaBuffer<f32>> {
8871 Err(GpuError::NoCudaFeature)
8872}
8873
8874#[cfg(not(feature = "cuda"))]
8876pub fn gpu_permute_0213(
8877 _input: &CudaBuffer<f32>,
8878 _d0: usize,
8879 _d1: usize,
8880 _d2: usize,
8881 _d3: usize,
8882 _device: &GpuDevice,
8883) -> GpuResult<CudaBuffer<f32>> {
8884 Err(GpuError::NoCudaFeature)
8885}
8886
8887#[cfg(not(feature = "cuda"))]
8889pub fn gpu_slice_write(
8890 _src: &CudaBuffer<f32>,
8891 _dst: &mut CudaBuffer<f32>,
8892 _n_batch: usize,
8893 _d: usize,
8894 _max_len: usize,
8895 _pos: usize,
8896 _device: &GpuDevice,
8897) -> GpuResult<()> {
8898 Err(GpuError::NoCudaFeature)
8899}
8900
8901#[cfg(not(feature = "cuda"))]
8903pub fn gpu_slice_read(
8904 _src: &CudaBuffer<f32>,
8905 _n_batch: usize,
8906 _d: usize,
8907 _len: usize,
8908 _max_len: usize,
8909 _device: &GpuDevice,
8910) -> GpuResult<CudaBuffer<f32>> {
8911 Err(GpuError::NoCudaFeature)
8912}
8913
8914#[cfg(not(feature = "cuda"))]
8916pub fn gpu_embed_lookup(
8917 _idx: &CudaBuffer<f32>,
8918 _weight: &CudaBuffer<f32>,
8919 _d: usize,
8920 _device: &GpuDevice,
8921) -> GpuResult<CudaBuffer<f32>> {
8922 Err(GpuError::NoCudaFeature)
8923}
8924
8925#[cfg(not(feature = "cuda"))]
8927pub fn gpu_embed_lookup_batch(
8928 _indices: &CudaBuffer<f32>,
8929 _weight: &CudaBuffer<f32>,
8930 _n: usize,
8931 _d: usize,
8932 _device: &GpuDevice,
8933) -> GpuResult<CudaBuffer<f32>> {
8934 Err(GpuError::NoCudaFeature)
8935}
8936
8937#[cfg(not(feature = "cuda"))]
8939pub fn gpu_scatter_add_rows(
8940 _grad_output: &CudaBuffer<f32>,
8941 _indices: &CudaBuffer<f32>,
8942 _num_embeddings: usize,
8943 _d: usize,
8944 _device: &GpuDevice,
8945) -> GpuResult<CudaBuffer<f32>> {
8946 Err(GpuError::NoCudaFeature)
8947}
8948
8949#[cfg(not(feature = "cuda"))]
8951pub fn gpu_relu_backward(
8952 _grad: &CudaBuffer<f32>,
8953 _input: &CudaBuffer<f32>,
8954 _device: &GpuDevice,
8955) -> GpuResult<CudaBuffer<f32>> {
8956 Err(GpuError::NoCudaFeature)
8957}
8958
8959#[cfg(not(feature = "cuda"))]
8961pub fn gpu_gelu_backward(
8962 _grad: &CudaBuffer<f32>,
8963 _input: &CudaBuffer<f32>,
8964 _device: &GpuDevice,
8965) -> GpuResult<CudaBuffer<f32>> {
8966 Err(GpuError::NoCudaFeature)
8967}
8968
8969#[cfg(not(feature = "cuda"))]
8971pub fn gpu_index_select_1d(
8972 _input: &CudaBuffer<f32>,
8973 _indices: &CudaBuffer<f32>,
8974 _device: &GpuDevice,
8975) -> GpuResult<CudaBuffer<f32>> {
8976 Err(GpuError::NoCudaFeature)
8977}
8978
8979#[cfg(not(feature = "cuda"))]
8981pub fn gpu_scatter_add_1d(
8982 _grad_output: &CudaBuffer<f32>,
8983 _indices: &CudaBuffer<f32>,
8984 _input_len: usize,
8985 _device: &GpuDevice,
8986) -> GpuResult<CudaBuffer<f32>> {
8987 Err(GpuError::NoCudaFeature)
8988}
8989
8990#[cfg(not(feature = "cuda"))]
8992pub fn gpu_masked_fill(
8993 _input: &CudaBuffer<f32>,
8994 _mask: &CudaBuffer<f32>,
8995 _value: f32,
8996 _device: &GpuDevice,
8997) -> GpuResult<CudaBuffer<f32>> {
8998 Err(GpuError::NoCudaFeature)
8999}
9000
9001#[cfg(not(feature = "cuda"))]
9003pub fn gpu_masked_zero(
9004 _grad: &CudaBuffer<f32>,
9005 _mask: &CudaBuffer<f32>,
9006 _device: &GpuDevice,
9007) -> GpuResult<CudaBuffer<f32>> {
9008 Err(GpuError::NoCudaFeature)
9009}
9010
9011#[cfg(not(feature = "cuda"))]
9013pub fn gpu_sigmoid_backward(
9014 _grad: &CudaBuffer<f32>,
9015 _output: &CudaBuffer<f32>,
9016 _device: &GpuDevice,
9017) -> GpuResult<CudaBuffer<f32>> {
9018 Err(GpuError::NoCudaFeature)
9019}
9020
9021#[cfg(not(feature = "cuda"))]
9023pub fn gpu_tanh_backward(
9024 _grad: &CudaBuffer<f32>,
9025 _output: &CudaBuffer<f32>,
9026 _device: &GpuDevice,
9027) -> GpuResult<CudaBuffer<f32>> {
9028 Err(GpuError::NoCudaFeature)
9029}
9030
9031#[cfg(not(feature = "cuda"))]
9033pub fn gpu_softmax_backward(
9034 _grad: &CudaBuffer<f32>,
9035 _output: &CudaBuffer<f32>,
9036 _cols: usize,
9037 _device: &GpuDevice,
9038) -> GpuResult<CudaBuffer<f32>> {
9039 Err(GpuError::NoCudaFeature)
9040}
9041
9042#[cfg(not(feature = "cuda"))]
9044pub fn gpu_sum_axis(
9045 _a: &CudaBuffer<f32>,
9046 _outer: usize,
9047 _axis_size: usize,
9048 _inner: usize,
9049 _device: &GpuDevice,
9050) -> GpuResult<CudaBuffer<f32>> {
9051 Err(GpuError::NoCudaFeature)
9052}
9053
9054#[cfg(not(feature = "cuda"))]
9056pub fn gpu_strided_split(
9057 _input: &CudaBuffer<f32>,
9058 _total_along_axis: usize,
9059 _split_offset: usize,
9060 _split_size: usize,
9061 _inner_size: usize,
9062 _n: usize,
9063 _device: &GpuDevice,
9064) -> GpuResult<CudaBuffer<f32>> {
9065 Err(GpuError::NoCudaFeature)
9066}
9067
9068#[cfg(not(feature = "cuda"))]
9070pub fn gpu_strided_cat(
9071 _input: &CudaBuffer<f32>,
9072 _output: &mut CudaBuffer<f32>,
9073 _total_along_axis: usize,
9074 _cat_offset: usize,
9075 _part_size: usize,
9076 _inner_size: usize,
9077 _n: usize,
9078 _device: &GpuDevice,
9079) -> GpuResult<()> {
9080 Err(GpuError::NoCudaFeature)
9081}
9082
9083#[cfg(feature = "cuda")]
9099pub(crate) fn gpu_f32_to_f16(
9100 input: &CudaBuffer<f32>,
9101 device: &GpuDevice,
9102) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
9103 use cudarc::driver::PushKernelArg;
9104
9105 let n = input.len();
9106 if n == 0 {
9107 let empty = device.stream().alloc_zeros::<u16>(0)?;
9108 return Ok(empty);
9109 }
9110
9111 let ctx = device.context();
9112 let stream = device.stream();
9113
9114 let f = crate::module_cache::get_or_compile(
9115 ctx,
9116 F32_TO_F16_PTX,
9117 "f32_to_f16_kernel",
9118 device.ordinal() as u32,
9119 )
9120 .map_err(|_| GpuError::PtxCompileFailed {
9121 kernel: "f32_to_f16_kernel",
9122 })?;
9123
9124 let mut out = stream.alloc_zeros::<u16>(n)?;
9125 let cfg = launch_cfg(n)?;
9126 let n_u32 = n as u32;
9127
9128 unsafe {
9132 stream
9133 .launch_builder(&f)
9134 .arg(input.inner())
9135 .arg(&mut out)
9136 .arg(&n_u32)
9137 .launch(cfg)?;
9138 }
9139
9140 Ok(out)
9141}
9142
9143#[cfg(not(feature = "cuda"))]
9145pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
9146 Err(GpuError::NoCudaFeature)
9147}
9148
9149#[cfg(feature = "cuda")]
9154pub(crate) fn gpu_f32_to_bf16(
9155 input: &CudaBuffer<f32>,
9156 device: &GpuDevice,
9157) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
9158 use cudarc::driver::PushKernelArg;
9159
9160 let n = input.len();
9161 if n == 0 {
9162 let empty = device.stream().alloc_zeros::<u16>(0)?;
9163 return Ok(empty);
9164 }
9165
9166 let ctx = device.context();
9167 let stream = device.stream();
9168
9169 let f = crate::module_cache::get_or_compile(
9170 ctx,
9171 F32_TO_BF16_PTX,
9172 "f32_to_bf16_kernel",
9173 device.ordinal() as u32,
9174 )
9175 .map_err(|_| GpuError::PtxCompileFailed {
9176 kernel: "f32_to_bf16_kernel",
9177 })?;
9178
9179 let mut out = stream.alloc_zeros::<u16>(n)?;
9180 let cfg = launch_cfg(n)?;
9181 let n_u32 = n as u32;
9182
9183 unsafe {
9184 stream
9185 .launch_builder(&f)
9186 .arg(input.inner())
9187 .arg(&mut out)
9188 .arg(&n_u32)
9189 .launch(cfg)?;
9190 }
9191
9192 Ok(out)
9193}
9194
9195#[cfg(not(feature = "cuda"))]
9197pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
9198 Err(GpuError::NoCudaFeature)
9199}
9200
9201#[cfg(test)]
9206#[cfg(feature = "cuda")]
9207mod tests {
9208 use super::*;
9209
9210 fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
9212 let dev = GpuDevice::new(0).expect("CUDA device 0");
9213 let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
9214 (dev, buf)
9215 }
9216
9217 fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
9220 let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
9221 assert_eq!(host.len(), expected.len(), "length mismatch");
9222 for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
9223 assert!(
9224 (got - exp).abs() < 1e-6,
9225 "element {i}: got {got}, expected {exp}",
9226 );
9227 }
9228 }
9229
9230 #[test]
9233 fn add_basic() {
9234 let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
9235 let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
9236 let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
9237
9238 let (dev, a) = setup(&a_data);
9239 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9240 let out = gpu_add(&a, &b, &dev).expect("gpu_add");
9241 assert_buf_eq(&out, &dev, &expected);
9242 }
9243
9244 #[test]
9245 fn add_empty() {
9246 let (dev, a) = setup(&[]);
9247 let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
9248 let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
9249 assert_eq!(out.len(), 0);
9250 }
9251
9252 #[test]
9253 fn add_large() {
9254 let n = 100_000;
9255 let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
9256 let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
9257 let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
9258
9259 let (dev, a) = setup(&a_data);
9260 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9261 let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
9262 assert_buf_eq(&out, &dev, &expected);
9263 }
9264
9265 #[test]
9266 fn add_length_mismatch() {
9267 let (dev, a) = setup(&[1.0, 2.0, 3.0]);
9268 let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
9269 let err = gpu_add(&a, &b, &dev).unwrap_err();
9270 match err {
9271 GpuError::LengthMismatch { a: 3, b: 2 } => {}
9272 other => panic!("unexpected error: {other}"),
9273 }
9274 }
9275
9276 #[test]
9279 fn sub_basic() {
9280 let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
9281 let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
9282 let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();
9283
9284 let (dev, a) = setup(&a_data);
9285 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9286 let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
9287 assert_buf_eq(&out, &dev, &expected);
9288 }
9289
9290 #[test]
9291 fn sub_negative_result() {
9292 let a_data = vec![1.0f32, 2.0];
9293 let b_data = vec![5.0f32, 10.0];
9294 let expected: Vec<f32> = vec![-4.0, -8.0];
9295
9296 let (dev, a) = setup(&a_data);
9297 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9298 let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
9299 assert_buf_eq(&out, &dev, &expected);
9300 }
9301
9302 #[test]
9305 fn mul_basic() {
9306 let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
9307 let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
9308 let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();
9309
9310 let (dev, a) = setup(&a_data);
9311 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9312 let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
9313 assert_buf_eq(&out, &dev, &expected);
9314 }
9315
9316 #[test]
9317 fn mul_by_zero() {
9318 let a_data = vec![1.0f32, 2.0, 3.0];
9319 let b_data = vec![0.0f32, 0.0, 0.0];
9320 let expected = vec![0.0f32, 0.0, 0.0];
9321
9322 let (dev, a) = setup(&a_data);
9323 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
9324 let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
9325 assert_buf_eq(&out, &dev, &expected);
9326 }
9327
9328 #[test]
9331 fn neg_basic() {
9332 let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
9333 let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();
9334
9335 let (dev, a) = setup(&a_data);
9336 let out = gpu_neg(&a, &dev).expect("gpu_neg");
9337 assert_buf_eq(&out, &dev, &expected);
9338 }
9339
9340 #[test]
9341 fn neg_double_negation() {
9342 let a_data = vec![1.0f32, -2.0, 3.0];
9343 let (dev, a) = setup(&a_data);
9344 let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
9345 let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
9346 assert_buf_eq(&neg2, &dev, &a_data);
9347 }
9348
9349 #[test]
9352 fn relu_basic() {
9353 let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
9354 let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];
9355
9356 let (dev, a) = setup(&a_data);
9357 let out = gpu_relu(&a, &dev).expect("gpu_relu");
9358 assert_buf_eq(&out, &dev, &expected);
9359 }
9360
9361 #[test]
9362 fn relu_all_negative() {
9363 let a_data = vec![-5.0f32, -0.1, -100.0];
9364 let expected = vec![0.0f32, 0.0, 0.0];
9365
9366 let (dev, a) = setup(&a_data);
9367 let out = gpu_relu(&a, &dev).expect("gpu_relu");
9368 assert_buf_eq(&out, &dev, &expected);
9369 }
9370
9371 #[test]
9372 fn relu_all_positive() {
9373 let a_data = vec![0.1f32, 1.0, 100.0];
9374
9375 let (dev, a) = setup(&a_data);
9376 let out = gpu_relu(&a, &dev).expect("gpu_relu");
9377 assert_buf_eq(&out, &dev, &a_data);
9378 }
9379
9380 #[test]
9381 fn relu_empty() {
9382 let (dev, a) = setup(&[]);
9383 let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
9384 assert_eq!(out.len(), 0);
9385 }
9386
9387 #[test]
9388 fn small_matmul_2x2() {
9389 let dev = GpuDevice::new(0).expect("CUDA device 0");
9390 let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
9393 let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
9394 let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
9395 assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
9396 }
9397
9398 #[test]
9399 fn small_matmul_1xk_kxn() {
9400 let dev = GpuDevice::new(0).expect("CUDA device 0");
9401 let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
9404 let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
9405 let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
9406 assert_buf_eq(&c, &dev, &[4.0, 5.0]);
9407 }
9408
9409 #[test]
9410 fn small_matmul_vs_cublas() {
9411 let dev = GpuDevice::new(0).expect("CUDA device 0");
9414 let m = 1;
9415 let k = 64;
9416 let n = 64;
9417
9418 let a_data: Vec<f32> = (0..m * k)
9420 .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
9421 .collect();
9422 let b_data: Vec<f32> = (0..k * n)
9423 .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
9424 .collect();
9425
9426 let a = cpu_to_gpu(&a_data, &dev).unwrap();
9427 let b = cpu_to_gpu(&b_data, &dev).unwrap();
9428
9429 let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
9431 let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();
9432
9433 let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
9435 let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();
9436
9437 assert_eq!(cublas_result.len(), our_result.len());
9438 for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
9439 assert!(
9440 (cb - ours).abs() < 0.1,
9441 "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
9442 (cb - ours).abs()
9443 );
9444 }
9445 }
9446}