1#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, cpu_to_gpu, gpu_to_cpu};
30
31#[cfg(feature = "cuda")]
/// Derives the text of an f64 PTX kernel from its f32 counterpart.
///
/// The conversion is purely textual and happens in three steps:
///
/// 1. every occurrence of `f32_kernel_name` is replaced by `f64_kernel_name`;
/// 2. an ordered table of `f32 -> f64` rewrites is applied (register and
///    parameter declarations, loads/stores, arithmetic, comparisons,
///    conversions, the `shl.b64 %off, %off, 2` element-offset shift, the
///    global atomic add, and a handful of well-known constants — the
///    `log2(e)` and `ln(2)` entries map to full-precision f64 values rather
///    than to the widened f32 value);
/// 3. any `0fXXXXXXXX` single-precision hex immediate still present is
///    widened to the exactly equal `0dXXXXXXXXXXXXXXXX` double-precision
///    immediate, so kernels using constants that are not in the table do not
///    end up with f32 immediates inside f64 instructions.
///
/// # Parameters
/// * `f32_ptx` – PTX source of the f32 kernel.
/// * `f32_kernel_name` – entry-point name used in `f32_ptx`.
/// * `f64_kernel_name` – entry-point name to use in the result.
///
/// # Returns
/// The rewritten PTX source as a new `String`.
///
/// # Limitations
/// Only the patterns listed in the table are rewritten. In particular the
/// offset shift is matched only for a register literally named `%off`, and
/// vector (`.v4`) or `.approx` instructions are left untouched, so kernels
/// relying on those need a hand-written f64 variant.
pub(crate) fn ptx_f32_to_f64(f32_ptx: &str, f32_kernel_name: &str, f64_kernel_name: &str) -> String {
    // Ordered `(f32 text, f64 text)` rewrites. Order matters: it is the same
    // order in which the replacements have always been applied.
    const REWRITES: &[(&str, &str)] = &[
        // Declarations and memory traffic.
        (".reg .f32", ".reg .f64"),
        ("ld.global.f32", "ld.global.f64"),
        ("st.global.f32", "st.global.f64"),
        ("ld.shared.f32", "ld.shared.f64"),
        ("st.shared.f32", "st.shared.f64"),
        ("ld.param.f32", "ld.param.f64"),
        (".param .f32", ".param .f64"),
        (".shared .align 4 .f32", ".shared .align 8 .f64"),
        // Arithmetic.
        ("add.f32", "add.f64"),
        ("sub.f32", "sub.f64"),
        ("mul.f32", "mul.f64"),
        ("div.rn.f32", "div.rn.f64"),
        ("div.f32", "div.f64"),
        ("neg.f32", "neg.f64"),
        ("abs.f32", "abs.f64"),
        ("max.f32", "max.f64"),
        ("min.f32", "min.f64"),
        ("sqrt.rn.f32", "sqrt.rn.f64"),
        ("sqrt.f32", "sqrt.f64"),
        ("fma.rn.f32", "fma.rn.f64"),
        ("mov.f32", "mov.f64"),
        // Comparisons.
        ("setp.gt.f32", "setp.gt.f64"),
        ("setp.ge.f32", "setp.ge.f64"),
        ("setp.lt.f32", "setp.lt.f64"),
        ("setp.le.f32", "setp.le.f64"),
        ("setp.eq.f32", "setp.eq.f64"),
        ("setp.ne.f32", "setp.ne.f64"),
        // Integer -> float conversions and raw bit moves.
        ("cvt.rn.f32.u32", "cvt.rn.f64.u32"),
        ("cvt.rn.f32.s32", "cvt.rn.f64.s32"),
        ("mov.b32", "mov.b64"),
        // Element byte offset: 4-byte stride becomes 8-byte stride.
        ("shl.b64 %off, %off, 2", "shl.b64 %off, %off, 3"),
        ("atom.global.add.f32", "atom.global.add.f64"),
        // Well-known constants: 0, 1, -1, 2, 0.5, -inf, +inf, log2(e), ln(2).
        ("0f00000000", "0d0000000000000000"),
        ("0f3F800000", "0d3FF0000000000000"),
        ("0fBF800000", "0dBFF0000000000000"),
        ("0f40000000", "0d4000000000000000"),
        ("0f3F000000", "0d3FE0000000000000"),
        ("0fFF800000", "0dFFF0000000000000"),
        ("0f7F800000", "0d7FF0000000000000"),
        ("0f3FB8AA3B", "0d3FF71547652B82FE"),
        ("0f3F317218", "0d3FE62E42FEFA39EF"),
    ];

    // Widens every stand-alone `0f` + 8-hex-digit immediate that survived the
    // table to the `0d` + 16-hex-digit immediate holding the same value.
    fn widen_leftover_f32_literals(ptx: &str) -> String {
        const LITERAL_LEN: usize = 10; // "0f" followed by eight hex digits

        // Bytes that would make the candidate part of a longer token.
        fn is_word(b: u8) -> bool {
            b.is_ascii_alphanumeric() || matches!(b, b'_' | b'$' | b'%')
        }

        let bytes = ptx.as_bytes();
        let mut out = String::with_capacity(ptx.len() + 64);
        let mut copied_to = 0;
        let mut i = 0;
        while i + LITERAL_LEN <= bytes.len() {
            let end = i + LITERAL_LEN;
            let is_literal = bytes[i] == b'0'
                && matches!(bytes[i + 1], b'f' | b'F')
                && bytes[i + 2..end].iter().all(u8::is_ascii_hexdigit)
                && (i == 0 || !is_word(bytes[i - 1]))
                && (end == bytes.len() || !is_word(bytes[end]));
            if is_literal {
                // All ten bytes are ASCII, so `i` and `end` are char boundaries.
                let bits = u32::from_str_radix(&ptx[i + 2..end], 16)
                    .expect("eight hex digits always fit in a u32");
                // f32 -> f64 is exact, so the immediate keeps its value.
                let wide = f64::from(f32::from_bits(bits)).to_bits();
                out.push_str(&ptx[copied_to..i]);
                out.push_str(&format!("0d{:016X}", wide));
                copied_to = end;
                i = end;
            } else {
                i += 1;
            }
        }
        out.push_str(&ptx[copied_to..]);
        out
    }

    let renamed = f32_ptx.replace(f32_kernel_name, f64_kernel_name);
    let rewritten = REWRITES
        .iter()
        .fold(renamed, |ptx, &(from, to)| ptx.replace(from, to));
    widen_leftover_f32_literals(&rewritten)
}
99
100#[cfg(feature = "cuda")]
105pub(crate) fn get_f64_ptx<'a>(
106 cache: &'a std::sync::OnceLock<String>,
107 f32_ptx: &str,
108 f32_name: &str,
109 f64_name: &str,
110) -> &'a str {
111 cache.get_or_init(|| ptx_f32_to_f64(f32_ptx, f32_name, f64_name))
112}
113
/// PTX source for `add_kernel`: element-wise f32 addition.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `out_ptr` (64-bit addresses accessed
/// with `ld.global` / `st.global`) and `n` (u32 element count).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x`, exits when `i >= n`,
/// and otherwise stores `a[i] + b[i]` to `out[i]` using a 4-byte stride.
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
 .param .u64 a_ptr,
 .param .u64 b_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %b, %out, %off;
 .reg .f32 %va, %vb, %vr;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %b, %b, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %va, [%a];
 ld.global.f32 %vb, [%b];
 add.f32 %vr, %va, %vb;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
165
166
/// PTX source for `add_vec4_kernel`: f32 addition, four elements per thread.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `out_ptr` (64-bit global addresses)
/// and `n4`, the u32 number of 4-element groups the bounds check is made
/// against.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n4`) loads one
/// `v4.f32` from each input at byte offset `16 * i`, adds the four lanes
/// independently and stores the result with a single `st.global.v4.f32`.
///
/// NOTE(review): vector accesses presumably require 16-byte-aligned buffers,
/// and elements beyond `4 * n4` are not touched by this kernel — confirm the
/// launch site handles alignment and the tail.
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
 .param .u64 a_ptr,
 .param .u64 b_ptr,
 .param .u64 out_ptr,
 .param .u32 n4
) {
 .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
 .reg .u64 %a, %b, %out, %off;
 .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n4_reg, [n4];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n4_reg;
 @%p bra DONE;

 // Byte offset = tid * 16 (4 floats × 4 bytes)
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 4;

 add.u64 %a, %a, %off;
 add.u64 %b, %b, %off;
 add.u64 %out, %out, %off;

 ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
 ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

 add.f32 %r0, %a0, %b0;
 add.f32 %r1, %a1, %b1;
 add.f32 %r2, %a2, %b2;
 add.f32 %r3, %a3, %b3;

 st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
 ret;
}
";
223
/// PTX source for `mul_vec4_kernel`: f32 multiplication, four elements per
/// thread.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `out_ptr` (64-bit global addresses)
/// and `n4`, the u32 number of 4-element groups the bounds check is made
/// against.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n4`) loads one `v4.f32`
/// from each input at byte offset `16 * i`, multiplies the four lanes
/// independently and stores the products with one `st.global.v4.f32`.
///
/// NOTE(review): vector accesses presumably require 16-byte-aligned buffers,
/// and elements beyond `4 * n4` are not touched by this kernel — confirm the
/// launch site handles alignment and the tail.
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
 .param .u64 a_ptr,
 .param .u64 b_ptr,
 .param .u64 out_ptr,
 .param .u32 n4
) {
 .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
 .reg .u64 %a, %b, %out, %off;
 .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n4_reg, [n4];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n4_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 4;

 add.u64 %a, %a, %off;
 add.u64 %b, %b, %off;
 add.u64 %out, %out, %off;

 ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
 ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

 mul.f32 %r0, %a0, %b0;
 mul.f32 %r1, %a1, %b1;
 mul.f32 %r2, %a2, %b2;
 mul.f32 %r3, %a3, %b3;

 st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
 ret;
}
";
276
/// PTX source for `sub_kernel`: element-wise f32 subtraction.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `out_ptr` (64-bit global addresses)
/// and `n` (u32 element count).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x`, exits when `i >= n`,
/// and otherwise stores `a[i] - b[i]` to `out[i]` using a 4-byte stride.
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
 .param .u64 a_ptr,
 .param .u64 b_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %b, %out, %off;
 .reg .f32 %va, %vb, %vr;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %b, %b, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %va, [%a];
 ld.global.f32 %vb, [%b];
 sub.f32 %vr, %va, %vb;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
324
325
/// PTX source for `mul_kernel`: element-wise f32 multiplication.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `out_ptr` (64-bit global addresses)
/// and `n` (u32 element count).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x`, exits when `i >= n`,
/// and otherwise stores `a[i] * b[i]` to `out[i]` using a 4-byte stride.
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
 .param .u64 a_ptr,
 .param .u64 b_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %b, %out, %off;
 .reg .f32 %va, %vb, %vr;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %b, %b, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %va, [%a];
 ld.global.f32 %vb, [%b];
 mul.f32 %vr, %va, %vb;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
373
374
/// PTX source for `neg_kernel`: element-wise f32 negation.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (64-bit global addresses) and `n`
/// (u32 element count).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x`, exits when `i >= n`,
/// and otherwise stores `-a[i]` to `out[i]` using a 4-byte stride.
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f32 %va, %vr;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %va, [%a];
 neg.f32 %vr, %va;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
418
419
/// PTX source for `relu_kernel`: element-wise f32 ReLU.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (64-bit global addresses) and `n`
/// (u32 element count).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x`, exits when `i >= n`,
/// and otherwise stores `max.f32(a[i], 0.0)` to `out[i]` (the zero operand is
/// the immediate `0f00000000`).
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f32 %va, %vr, %zero;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %va, [%a];
 mov.f32 %zero, 0f00000000;
 max.f32 %vr, %va, %zero;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
464
465
/// PTX source for `scale_kernel`: multiplies every f32 element by a scalar.
///
/// Kernel parameters, in order: `a_ptr`, `out_ptr` (64-bit global
/// addresses), `scalar` (f32, passed by value) and `n` (u32 element count).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x`, exits when `i >= n`,
/// and otherwise stores `a[i] * scalar` to `out[i]` using a 4-byte stride.
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .f32 scalar,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f32 %va, %vr, %s;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.f32 %s, [scalar];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %va, [%a];
 mul.f32 %vr, %va, %s;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
511
512
/// PTX source for `transpose_2d_kernel`: out-of-place transpose of a 2-D f32
/// matrix.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit global addresses), `M`,
/// `N` (u32 dimensions) and `total` (u32 bound on the thread index).
///
/// One thread per output element. With `t = ctaid.x * ntid.x + tid.x` and
/// `t < total`, the kernel treats the output as `[N, M]`:
/// `out_row = t / M`, `out_col = t % M`, and copies
/// `in[out_col * N + out_row]` to `out[t]`.
///
/// Every line of the literal ends in `\n\`, so the string holds the kernel
/// text without any leading indentation.
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
 .param .u64 in_ptr,\n\
 .param .u64 out_ptr,\n\
 .param .u32 M,\n\
 .param .u32 N,\n\
 .param .u32 total\n\
) {\n\
 .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
 .reg .u32 %out_row, %out_col, %in_idx;\n\
 .reg .u64 %in, %out, %off_in, %off_out;\n\
 .reg .f32 %val;\n\
 .reg .pred %p;\n\
\n\
 ld.param.u64 %in, [in_ptr];\n\
 ld.param.u64 %out, [out_ptr];\n\
 ld.param.u32 %M_reg, [M];\n\
 ld.param.u32 %N_reg, [N];\n\
 ld.param.u32 %total_reg, [total];\n\
\n\
 mov.u32 %bid, %ctaid.x;\n\
 mov.u32 %bdim, %ntid.x;\n\
 mov.u32 %r_tid, %tid.x;\n\
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
 setp.ge.u32 %p, %r_tid, %total_reg;\n\
 @%p bra DONE;\n\
\n\
 // Output shape is [N, M]. tid = out_row * M + out_col.\n\
 div.u32 %out_row, %r_tid, %M_reg;\n\
 rem.u32 %out_col, %r_tid, %M_reg;\n\
 // Input index: out_col * N + out_row (transposed).\n\
 mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
 cvt.u64.u32 %off_in, %in_idx;\n\
 shl.b64 %off_in, %off_in, 2;\n\
 add.u64 %off_in, %in, %off_in;\n\
 ld.global.f32 %val, [%off_in];\n\
\n\
 cvt.u64.u32 %off_out, %r_tid;\n\
 shl.b64 %off_out, %off_out, 2;\n\
 add.u64 %off_out, %out, %off_out;\n\
 st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
 ret;\n\
}\n\
";
568
569
/// PTX source for `permute_0213_kernel`: swaps the two middle axes of a 4-D
/// f32 tensor, `[d0, d1, d2, d3] -> [d0, d2, d1, d3]`.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit global addresses), the
/// four u32 input dimensions `d0..d3`, and `total` (u32 bound on the thread
/// index). `d0` is loaded but not used in the index arithmetic.
///
/// One thread per output element. The linear output index `t` is decomposed
/// in output order as `(i0, i2, i1, i3)` using strides `d2*d1*d3`, `d1*d3`
/// and `d3`; the element is then read from
/// `in[((i0*d1 + i1) * d2*d3) + i2*d3 + i3]` and written to `out[t]`.
///
/// All index math is 32-bit (`mul.lo.u32` / `mad.lo.u32`).
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
 .param .u64 in_ptr,\n\
 .param .u64 out_ptr,\n\
 .param .u32 d0,\n\
 .param .u32 d1,\n\
 .param .u32 d2,\n\
 .param .u32 d3,\n\
 .param .u32 total\n\
) {\n\
 .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
 .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
 .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
 .reg .u32 %s_out2, %s_out1, %s_in1;\n\
 .reg .u64 %in, %out, %off_in, %off_out;\n\
 .reg .f32 %val;\n\
 .reg .pred %p;\n\
\n\
 ld.param.u64 %in, [in_ptr];\n\
 ld.param.u64 %out, [out_ptr];\n\
 ld.param.u32 %d0r, [d0];\n\
 ld.param.u32 %d1r, [d1];\n\
 ld.param.u32 %d2r, [d2];\n\
 ld.param.u32 %d3r, [d3];\n\
 ld.param.u32 %total_reg, [total];\n\
\n\
 mov.u32 %bid, %ctaid.x;\n\
 mov.u32 %bdim, %ntid.x;\n\
 mov.u32 %r_tid, %tid.x;\n\
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
 setp.ge.u32 %p, %r_tid, %total_reg;\n\
 @%p bra DONE;\n\
\n\
 // Output shape: [d0, d2, d1, d3]\n\
 // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
 mul.lo.u32 %s_out2, %d1r, %d3r;\n\
 mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
 div.u32 %i0, %r_tid, %s_out1;\n\
 rem.u32 %rem, %r_tid, %s_out1;\n\
 div.u32 %i2, %rem, %s_out2;\n\
 rem.u32 %rem, %rem, %s_out2;\n\
 div.u32 %i1, %rem, %d3r;\n\
 rem.u32 %i3, %rem, %d3r;\n\
\n\
 // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
 mul.lo.u32 %s_in1, %d2r, %d3r;\n\
 mul.lo.u32 %in_idx, %i0, %d1r;\n\
 add.u32 %in_idx, %in_idx, %i1;\n\
 mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
 mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
 add.u32 %in_idx, %in_idx, %i3;\n\
\n\
 cvt.u64.u32 %off_in, %in_idx;\n\
 shl.b64 %off_in, %off_in, 2;\n\
 add.u64 %off_in, %in, %off_in;\n\
 ld.global.f32 %val, [%off_in];\n\
\n\
 cvt.u64.u32 %off_out, %r_tid;\n\
 shl.b64 %off_out, %off_out, 2;\n\
 add.u64 %off_out, %out, %off_out;\n\
 st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
 ret;\n\
}\n\
";
650
651
/// PTX source for `f32_to_f16_kernel`: converts an f32 buffer to f16.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit global addresses) and `n`
/// (u32 element count).
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n`) loads the f32 at byte
/// offset `4 * i`, converts it with `cvt.rn.f16.f32` and stores the 16-bit
/// result at byte offset `2 * i` of the output via `st.global.b16`.
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %in, %out, %off_in, %off_out;
 .reg .f32 %vf;
 .reg .b16 %vh;
 .reg .pred %p;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 // Compute input offset: i * 4 (f32 = 4 bytes)
 cvt.u64.u32 %off_in, %r_tid;
 shl.b64 %off_in, %off_in, 2;
 add.u64 %in, %in, %off_in;

 // Compute output offset: i * 2 (f16 = 2 bytes)
 cvt.u64.u32 %off_out, %r_tid;
 shl.b64 %off_out, %off_out, 1;
 add.u64 %out, %out, %off_out;

 // Load f32, convert to f16 (round-to-nearest-even), store as u16
 ld.global.f32 %vf, [%in];
 cvt.rn.f16.f32 %vh, %vf;
 st.global.b16 [%out], %vh;

DONE:
 ret;
}
";
706
707#[cfg(feature = "cuda")]
/// PTX source for `f32_to_bf16_kernel`: converts an f32 buffer to bf16 using
/// integer bit manipulation.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit global addresses) and `n`
/// (u32 element count).
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n`) loads the raw 32 bits
/// at byte offset `4 * i`, rounds to nearest-even by adding
/// `0x7FFF + bit[16]` and shifting right by 16, and stores the low 16 bits
/// of the result at byte offset `2 * i`.
///
/// NaN inputs (`(bits & 0x7FFF_FFFF) > 0x7F80_0000`) skip the rounding add:
/// the add could otherwise carry a NaN payload into the exponent (turning
/// the value into an infinity) or wrap past the sign bit (turning it into a
/// zero). For NaNs the upper 16 bits are kept and the bf16 quiet bit
/// (`0x0040`) is forced, so the output is always a NaN with the input's sign.
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %in, %out, %off_in, %off_out;
 .reg .f32 %vf;
 .reg .u32 %bits, %round, %lsb, %result;
 .reg .u32 %mag, %nan_bits;
 .reg .pred %p, %nan_p;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off_in, %r_tid;
 shl.b64 %off_in, %off_in, 2;
 add.u64 %in, %in, %off_in;

 cvt.u64.u32 %off_out, %r_tid;
 shl.b64 %off_out, %off_out, 1;
 add.u64 %out, %out, %off_out;

 // Load f32 as raw bits
 ld.global.u32 %bits, [%in];

 // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
 shr.u32 %lsb, %bits, 16;
 and.b32 %lsb, %lsb, 1;
 add.u32 %round, %bits, 0x7FFF;
 add.u32 %round, %round, %lsb;
 shr.u32 %result, %round, 16;

 // NaN: keep the top 16 bits and force the quiet bit instead of rounding
 and.b32 %mag, %bits, 0x7FFFFFFF;
 setp.gt.u32 %nan_p, %mag, 0x7F800000;
 shr.u32 %nan_bits, %bits, 16;
 or.b32 %nan_bits, %nan_bits, 0x0040;
 selp.b32 %result, %nan_bits, %result, %nan_p;

 // Store as u16
 st.global.u16 [%out], %result;

DONE:
 ret;
}
";
767
/// PTX source for `small_matmul_kernel`: straightforward f32 matrix multiply
/// with one thread per output element.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `c_ptr` (64-bit global addresses),
/// `M`, `K`, `N` (u32 dimensions) and `total` (u32 bound on the thread
/// index). `M` is loaded but not used in the index arithmetic.
///
/// For thread index `t < total`: `row = t / N`, `col = t % N`, and the
/// kernel accumulates `sum += a[row * K + p] * b[p * N + col]` with
/// `fma.rn.f32` for `p` in `0..K`, then writes `sum` to `c[t]`. This matches
/// row-major `A[M, K] x B[K, N] -> C[M, N]` when `total == M * N`.
///
/// Note that `%p` is the u32 loop counter in this kernel; the predicates are
/// named `%bounds_p` and `%loop_p`.
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
 .param .u64 a_ptr,
 .param .u64 b_ptr,
 .param .u64 c_ptr,
 .param .u32 M,
 .param .u32 K,
 .param .u32 N,
 .param .u32 total
) {
 .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
 .reg .u32 %row, %col, %p, %idx;
 .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
 .reg .f32 %sum, %va, %vb;
 .reg .pred %bounds_p, %loop_p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %b, [b_ptr];
 ld.param.u64 %c, [c_ptr];
 ld.param.u32 %M_reg, [M];
 ld.param.u32 %K_reg, [K];
 ld.param.u32 %N_reg, [N];
 ld.param.u32 %total_reg, [total];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %bounds_p, %r_tid, %total_reg;
 @%bounds_p bra DONE;

 div.u32 %row, %r_tid, %N_reg;
 rem.u32 %col, %r_tid, %N_reg;

 mov.f32 %sum, 0f00000000;
 mov.u32 %p, 0;
DOT:
 setp.ge.u32 %loop_p, %p, %K_reg;
 @%loop_p bra DOT_DONE;

 mad.lo.u32 %idx, %row, %K_reg, %p;
 cvt.u64.u32 %a_off, %idx;
 shl.b64 %a_off, %a_off, 2;
 add.u64 %a_off, %a, %a_off;
 ld.global.f32 %va, [%a_off];

 mad.lo.u32 %idx, %p, %N_reg, %col;
 cvt.u64.u32 %b_off, %idx;
 shl.b64 %b_off, %b_off, 2;
 add.u64 %b_off, %b, %b_off;
 ld.global.f32 %vb, [%b_off];

 fma.rn.f32 %sum, %va, %vb, %sum;
 add.u32 %p, %p, 1;
 bra DOT;
DOT_DONE:

 cvt.u64.u32 %c_off, %r_tid;
 shl.b64 %c_off, %c_off, 2;
 add.u64 %c_off, %c, %c_off;
 st.global.f32 [%c_off], %sum;

DONE:
 ret;
}
";
846
/// PTX source for `slice_write_kernel`: scatters `D`-wide f32 rows into a
/// fixed position of a larger destination buffer.
///
/// Kernel parameters: `src_ptr`, `dst_ptr` (64-bit global addresses), `n`
/// (u32 number of source elements), `D`, `max_len` and `pos` (all u32,
/// passed by value).
///
/// For thread index `t < n`: `batch = t / D`, `d = t % D`, and the kernel
/// copies `src[t]` to `dst[(batch * max_len + pos) * D + d]`. In other words
/// source row `batch` lands in row `pos` of the `batch`-th block of
/// `max_len` rows in the destination.
///
/// NOTE(review): no check is made that `pos < max_len`; presumably the
/// caller validates it — confirm at the launch site.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
 .param .u64 src_ptr,
 .param .u64 dst_ptr,
 .param .u32 n,
 .param .u32 D,
 .param .u32 max_len,
 .param .u32 pos
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
 .reg .u32 %batch_idx, %d_idx, %dst_row;
 .reg .u64 %src, %dst, %src_off, %dst_off;
 .reg .f32 %val;
 .reg .pred %p;

 ld.param.u64 %src, [src_ptr];
 ld.param.u64 %dst, [dst_ptr];
 ld.param.u32 %n_reg, [n];
 ld.param.u32 %D_reg, [D];
 ld.param.u32 %max_len_reg, [max_len];
 ld.param.u32 %pos_reg, [pos];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %src_off, %r_tid;
 shl.b64 %src_off, %src_off, 2;
 add.u64 %src, %src, %src_off;
 ld.global.f32 %val, [%src];

 div.u32 %batch_idx, %r_tid, %D_reg;
 rem.u32 %d_idx, %r_tid, %D_reg;
 mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
 add.u32 %dst_row, %dst_row, %pos_reg;
 mul.lo.u32 %dst_row, %dst_row, %D_reg;
 add.u32 %dst_row, %dst_row, %d_idx;
 cvt.u64.u32 %dst_off, %dst_row;
 shl.b64 %dst_off, %dst_off, 2;
 add.u64 %dst, %dst, %dst_off;
 st.global.f32 [%dst], %val;

DONE:
 ret;
}
";
906
907
/// PTX source for `slice_write_indirect_kernel`: same copy as
/// `slice_write_kernel`, but the destination row is read from device memory.
///
/// Kernel parameters: `src_ptr`, `dst_ptr` (64-bit global addresses), `n`
/// (u32 number of source elements), `D`, `max_len` (u32) and `pos_ptr`, a
/// 64-bit global address from which a single u32 `pos` is loaded.
///
/// Every thread loads `pos` with `ld.global.u32` before the bounds check.
/// For thread index `t < n`: `batch = t / D`, `d = t % D`, and the kernel
/// copies `src[t]` to `dst[(batch * max_len + pos) * D + d]`.
///
/// NOTE(review): no check is made that `pos < max_len`; presumably the
/// producer of `*pos_ptr` guarantees it — confirm.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
 .param .u64 src_ptr,
 .param .u64 dst_ptr,
 .param .u32 n,
 .param .u32 D,
 .param .u32 max_len,
 .param .u64 pos_ptr
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
 .reg .u32 %batch_idx, %d_idx, %dst_row;
 .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
 .reg .f32 %val;
 .reg .pred %p;

 ld.param.u64 %src, [src_ptr];
 ld.param.u64 %dst, [dst_ptr];
 ld.param.u32 %n_reg, [n];
 ld.param.u32 %D_reg, [D];
 ld.param.u32 %max_len_reg, [max_len];
 ld.param.u64 %pos_p, [pos_ptr];
 ld.global.u32 %pos_reg, [%pos_p];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %src_off, %r_tid;
 shl.b64 %src_off, %src_off, 2;
 add.u64 %src, %src, %src_off;
 ld.global.f32 %val, [%src];

 div.u32 %batch_idx, %r_tid, %D_reg;
 rem.u32 %d_idx, %r_tid, %D_reg;
 mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
 add.u32 %dst_row, %dst_row, %pos_reg;
 mul.lo.u32 %dst_row, %dst_row, %D_reg;
 add.u32 %dst_row, %dst_row, %d_idx;
 cvt.u64.u32 %dst_off, %dst_row;
 shl.b64 %dst_off, %dst_off, 2;
 add.u64 %dst, %dst, %dst_off;
 st.global.f32 [%dst], %val;

DONE:
 ret;
}
";
968
/// PTX source for `causal_mask_indirect_kernel`: fills an f32 buffer with an
/// additive attention mask whose valid length is read from device memory.
///
/// Kernel parameters: `total_len_ptr` (64-bit global address of a single
/// u32), `out_ptr` (64-bit global address), `max_pos` and `total` (u32).
///
/// Every thread loads `total_len` with `ld.global.u32` before the bounds
/// check. For thread index `t < total`, `col = t % max_pos`; the kernel
/// writes `0.0` to `out[t]` when `col < total_len` and the immediate
/// `0fCE6E6B28` (-1.0e9) otherwise.
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
 .param .u64 total_len_ptr,
 .param .u64 out_ptr,
 .param .u32 max_pos,
 .param .u32 total
) {
 .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
 .reg .u64 %out, %off, %tl_p;
 .reg .f32 %val;
 .reg .pred %bounds_p, %mask_p;

 ld.param.u64 %tl_p, [total_len_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %max_pos_reg, [max_pos];
 ld.param.u32 %total_reg, [total];

 ld.global.u32 %tlen, [%tl_p];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %bounds_p, %r_tid, %total_reg;
 @%bounds_p bra DONE;

 rem.u32 %col, %r_tid, %max_pos_reg;
 setp.lt.u32 %mask_p, %col, %tlen;
 @%mask_p bra WRITE_ZERO;

 // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
 // to effectively zero out masked positions after softmax.
 mov.f32 %val, 0fCE6E6B28;
 bra WRITE;

WRITE_ZERO:
 mov.f32 %val, 0f00000000;

WRITE:
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %out, %out, %off;
 st.global.f32 [%out], %val;

DONE:
 ret;
}
";
1028
/// PTX source for `embed_lookup_kernel`: copies one row of an f32 weight
/// matrix to the output.
///
/// Kernel parameters: `idx_ptr`, `weight_ptr`, `out_ptr` (64-bit global
/// addresses) and `D` (u32 row width, also the bound on the thread index).
///
/// The row index is stored in device memory as an f32: each thread with
/// index `t < D` loads it from `idx_ptr`, converts it to u32 with
/// `cvt.rzi.u32.f32` (round toward zero), and copies
/// `weight[row * D + t]` to `out[t]`.
///
/// NOTE(review): the index is not range-checked against the number of rows
/// in the weight matrix — presumably validated by the caller; confirm.
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
 .param .u64 idx_ptr,
 .param .u64 weight_ptr,
 .param .u64 out_ptr,
 .param .u32 D
) {
 .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
 .reg .u64 %idx_addr, %w, %out, %off;
 .reg .f32 %idx_f, %val;
 .reg .pred %p;

 ld.param.u64 %idx_addr, [idx_ptr];
 ld.param.u64 %w, [weight_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %D_reg, [D];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %D_reg;
 @%p bra DONE;

 ld.global.f32 %idx_f, [%idx_addr];
 cvt.rzi.u32.f32 %row, %idx_f;

 mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
 cvt.u64.u32 %off, %src_idx;
 shl.b64 %off, %off, 2;
 add.u64 %off, %w, %off;
 ld.global.f32 %val, [%off];

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %off, %out, %off;
 st.global.f32 [%off], %val;

DONE:
 ret;
}
";
1081
1082
1083#[cfg(feature = "cuda")]
/// PTX source for `embed_lookup_batch_kernel`: gathers several rows of an
/// f32 weight matrix in one launch.
///
/// Kernel parameters: `idx_ptr`, `weight_ptr`, `out_ptr` (64-bit global
/// addresses), `D` (u32 row width) and `total` (u32 bound on the thread
/// index).
///
/// One thread per output element. For thread index `t < total`:
/// `row = t / D`, `col = t % D`; the kernel loads `indices[row]` (stored as
/// f32, converted to u32 with `cvt.rzi.u32.f32`, i.e. rounded toward zero)
/// and copies `weight[indices[row] * D + col]` to `out[t]`.
///
/// The thread-index register is named `%r_tid`, like in the other kernels of
/// this file, because `%tid` is the predefined PTX special register that the
/// kernel reads via `%tid.x` and must not be redeclared.
///
/// NOTE(review): indices are not range-checked against the number of weight
/// rows — presumably validated by the caller; confirm.
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
 .param .u64 idx_ptr,
 .param .u64 weight_ptr,
 .param .u64 out_ptr,
 .param .u32 D,
 .param .u32 total
) {
 .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
 .reg .u32 %row, %col, %src_idx;
 .reg .u64 %idx_addr, %w, %out, %off;
 .reg .f32 %idx_f, %val;
 .reg .pred %p;

 ld.param.u64 %idx_addr, [idx_ptr];
 ld.param.u64 %w, [weight_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %D_reg, [D];
 ld.param.u32 %total_reg, [total];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %total_reg;
 @%p bra DONE;

 // row = tid / D, col = tid % D
 div.u32 %row, %r_tid, %D_reg;
 rem.u32 %col, %r_tid, %D_reg;

 // Read indices[row] (f32 -> u32)
 cvt.u64.u32 %off, %row;
 shl.b64 %off, %off, 2;
 add.u64 %off, %idx_addr, %off;
 ld.global.f32 %idx_f, [%off];
 cvt.rzi.u32.f32 %src_idx, %idx_f;

 // src_idx = indices[row] * D + col
 mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
 cvt.u64.u32 %off, %src_idx;
 shl.b64 %off, %off, 2;
 add.u64 %off, %w, %off;
 ld.global.f32 %val, [%off];

 // Write to out[tid]
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %off, %out, %off;
 st.global.f32 [%off], %val;

DONE:
 ret;
}
";
1151
1152
1153#[cfg(feature = "cuda")]
/// PTX source for `scatter_add_rows_kernel`: accumulates rows of an f32
/// gradient into the rows of a weight-gradient buffer selected by an index
/// array.
///
/// Kernel parameters: `grad_output_ptr`, `indices_ptr`, `grad_weight_ptr`
/// (64-bit global addresses), `D` (u32 row width) and `total` (u32 bound on
/// the thread index).
///
/// One thread per gradient element. For thread index `t < total`:
/// `row = t / D`, `col = t % D`; the kernel loads `grad_output[t]` and
/// `indices[row]` (stored as f32, converted to u32 with `cvt.rzi.u32.f32`)
/// and performs `atom.global.add.f32` on
/// `grad_weight[indices[row] * D + col]`, so repeated indices accumulate.
/// The value returned by the atomic (`%dummy`) is discarded.
///
/// The thread-index register is named `%r_tid`, like in the other kernels of
/// this file, because `%tid` is the predefined PTX special register that the
/// kernel reads via `%tid.x` and must not be redeclared.
///
/// NOTE(review): the kernel only adds; presumably the caller zero-fills
/// `grad_weight` beforehand — confirm.
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
 .param .u64 grad_output_ptr,
 .param .u64 indices_ptr,
 .param .u64 grad_weight_ptr,
 .param .u32 D,
 .param .u32 total
) {
 .reg .u32 %r_tid, %bid, %bdim, %D_reg, %total_reg;
 .reg .u32 %row, %col, %dst_idx;
 .reg .u64 %go, %idx_addr, %gw, %off;
 .reg .f32 %idx_f, %grad_val, %dummy;
 .reg .pred %p;

 ld.param.u64 %go, [grad_output_ptr];
 ld.param.u64 %idx_addr, [indices_ptr];
 ld.param.u64 %gw, [grad_weight_ptr];
 ld.param.u32 %D_reg, [D];
 ld.param.u32 %total_reg, [total];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %total_reg;
 @%p bra DONE;

 // row = tid / D, col = tid % D
 div.u32 %row, %r_tid, %D_reg;
 rem.u32 %col, %r_tid, %D_reg;

 // Read grad_output[tid]
 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %off, %go, %off;
 ld.global.f32 %grad_val, [%off];

 // Read indices[row] (f32 -> u32)
 cvt.u64.u32 %off, %row;
 shl.b64 %off, %off, 2;
 add.u64 %off, %idx_addr, %off;
 ld.global.f32 %idx_f, [%off];
 cvt.rzi.u32.f32 %dst_idx, %idx_f;

 // dst_idx = indices[row] * D + col
 mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
 cvt.u64.u32 %off, %dst_idx;
 shl.b64 %off, %off, 2;
 add.u64 %off, %gw, %off;
 atom.global.add.f32 %dummy, [%off], %grad_val;

DONE:
 ret;
}
";
1221
1222
/// PTX source for `slice_read_kernel`: gathers the first `len` rows of every
/// `max_len`-row block of an f32 buffer into a dense output.
///
/// Kernel parameters: `src_ptr`, `dst_ptr` (64-bit global addresses),
/// `total` (u32 bound on the thread index), `D`, `len` and `max_len` (u32).
///
/// One thread per output element. For thread index `t < total`:
/// `batch = t / (len * D)`, `within = t % (len * D)`, `row = within / D`,
/// `col = within % D`, and the kernel copies
/// `src[(batch * max_len + row) * D + col]` to `dst[t]`.
///
/// NOTE(review): `len * D` is used as a divisor, so `len` and `D` are
/// presumably non-zero whenever the kernel is launched — confirm.
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
 .param .u64 src_ptr,
 .param .u64 dst_ptr,
 .param .u32 total,
 .param .u32 D,
 .param .u32 len,
 .param .u32 max_len
) {
 .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
 .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
 .reg .u32 %len_d;
 .reg .u64 %src, %dst, %src_off, %dst_off;
 .reg .f32 %val;
 .reg .pred %p;

 ld.param.u64 %src, [src_ptr];
 ld.param.u64 %dst, [dst_ptr];
 ld.param.u32 %total_reg, [total];
 ld.param.u32 %D_reg, [D];
 ld.param.u32 %len_reg, [len];
 ld.param.u32 %max_len_reg, [max_len];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %total_reg;
 @%p bra DONE;

 // dst index = r_tid
 // batch_idx = r_tid / (len * D)
 // within = r_tid % (len * D)
 // row = within / D
 // col = within % D
 // src_idx = batch_idx * max_len * D + row * D + col
 mul.lo.u32 %len_d, %len_reg, %D_reg;
 div.u32 %batch_idx, %r_tid, %len_d;
 rem.u32 %within, %r_tid, %len_d;
 div.u32 %row, %within, %D_reg;
 rem.u32 %col, %within, %D_reg;

 mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
 add.u32 %src_idx, %src_idx, %row;
 mul.lo.u32 %src_idx, %src_idx, %D_reg;
 add.u32 %src_idx, %src_idx, %col;

 cvt.u64.u32 %src_off, %src_idx;
 shl.b64 %src_off, %src_off, 2;
 add.u64 %src_off, %src, %src_off;
 ld.global.f32 %val, [%src_off];

 cvt.u64.u32 %dst_off, %r_tid;
 shl.b64 %dst_off, %dst_off, 2;
 add.u64 %dst_off, %dst, %dst_off;
 st.global.f32 [%dst_off], %val;

DONE:
 ret;
}
";
1296
1297
1298#[cfg(feature = "cuda")]
/// PTX source for `gelu_kernel`: f32 GELU using the sigmoid approximation
/// `gelu(x) ~= x * sigmoid(k * x)` with `k = 1.702`.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit global addresses) and `n`
/// (u32 element count).
///
/// For thread index `i < n` the kernel evaluates
/// `x / (1 + exp(-k * x))`, where `exp(t)` is computed as
/// `ex2.approx.f32(t * log2(e))` and the division as `rcp.approx.f32`, so
/// the result carries the accuracy of those approximate instructions.
///
/// Constants: `0f3FD9DB23` is 1.702 rounded to f32, `0f3FB8AA3B` is
/// `log2(e)`, `0f3F800000` is 1.0.
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %in, %out, %off;
 .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
 .reg .pred %p;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %in, %in, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %x, [%in];

 // k = 1.702
 mov.f32 %k, 0f3FD9DB23;
 mul.f32 %neg_kx, %k, %x;
 neg.f32 %neg_kx, %neg_kx;
 mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
 ex2.approx.f32 %exp_neg, %neg_kx;
 mov.f32 %one, 0f3F800000;
 add.f32 %denom, %one, %exp_neg;
 rcp.approx.f32 %sig, %denom;
 mul.f32 %result, %x, %sig;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
1357
1358#[cfg(feature = "cuda")]
/// PTX source for `gelu_f64_kernel`: f64 GELU using the sigmoid
/// approximation `gelu(x) ~= x * sigmoid(k * x)` with `k = 1.702`.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit global addresses) and `n`
/// (u32 element count). Elements are addressed with an 8-byte stride.
///
/// For thread index `i < n` the kernel evaluates `x / (1 + exp(-k * x))`.
/// `exp(t)` is computed in software:
///
/// 1. `nf = floor(t * log2(e) + 0.5)` picks the power of two;
/// 2. `r = t - nf * ln(2)` is formed with `ln(2)` split into a high part
///    (`0dBFE62E42FEFA3800`, negated) and a low correction
///    (`0dBD2EF35793C76730`, negated);
/// 3. `exp(r)` is a Taylor polynomial in Horner form with coefficients
///    `1/12!, 1/11!, ..., 1/3!, 1/2!, 1, 1`;
/// 4. the result is scaled by `2^nf`, built by writing `nf + 1023` into the
///    exponent field of a double.
///
/// The final division uses `div.rn.f64`.
///
/// NOTE(review): step 4 does not clamp `nf`. When `nf + 1023` falls outside
/// the valid exponent range (roughly `|k * x| > 708`) the constructed bit
/// pattern is not the intended scale factor — confirm callers never feed
/// such magnitudes, or add a clamp.
pub(crate) const GELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_f64_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %in, %out, %off;
 .reg .f64 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
 .reg .s32 %e_ni;
 .reg .s64 %e_ni64, %e_bits;
 .reg .pred %p;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 3;
 add.u64 %in, %in, %off;
 add.u64 %out, %out, %off;

 ld.global.f64 %x, [%in];
 mov.f64 %one, 0d3FF0000000000000;

 // k = 1.702
 mov.f64 %k, 0d3FFB3B645A1CAC08;
 mul.f64 %neg_kx, %k, %x;
 neg.f64 %neg_kx, %neg_kx;

 // --- exp(%neg_kx) via Cody-Waite + degree-12 Horner ---
 mov.f64 %e_half, 0d3FE0000000000000;
 fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
 cvt.rmi.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 // Taylor coefficients 1/12! ... 1/3!, then 1/2!, 1, 1
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %exp_neg, %exp_neg, %e_nf;
 // --- end exp ---

 add.f64 %denom, %one, %exp_neg;
 div.rn.f64 %sig, %one, %denom;
 mul.f64 %result, %x, %sig;
 st.global.f64 [%out], %result;

DONE:
 ret;
}
";
1439
1440#[cfg(feature = "cuda")]
/// PTX source for `gelu_tanh_kernel`: element-wise tanh-approximated GELU on
/// `f32`, `out[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Kernel parameters: `in_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f32` / `st.global.f32`, and `n` is the `u32` element
/// count. Each thread handles index `ctaid.x * ntid.x + tid.x` and returns
/// without touching memory when that index is `>= n`.
///
/// `tanh(y)` is evaluated as `(exp(2y) - 1) / (exp(2y) + 1)` using
/// `ex2.approx.f32` and `rcp.approx.f32`. The tanh argument is clamped to
/// `[-10, 10]` beforehand: `tanh` already rounds to `+/-1` in `f32` there,
/// and without the clamp `exp(2y)` overflows to `+inf` for large inputs and
/// the quotient becomes `inf * 0 = NaN`. A NaN input still yields NaN through
/// the final multiply by `x`.
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
    // sqrt(2/pi) = 0.7978845608 = 0x3F4C422A
    // 0.044715 = 0x3D372713
    mul.f32 %x3, %x, %x;
    mul.f32 %x3, %x3, %x;
    mov.f32 %c, 0f3D372713;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mul.f32 %y, %sqrt2pi, %inner;
    // tanh(y) rounds to +/-1 in f32 for |y| >= 10; clamping keeps exp(2y)
    // finite so the quotient below cannot become inf * 0 = NaN
    min.f32 %y, %y, 0f41200000;
    max.f32 %y, %y, 0fC1200000;

    // tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
    // exp(2y) = 2^(2y * log2(e))
    mov.f32 %log2e, 0f3FB8AA3B;
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    mov.f32 %one, 0f3F800000;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f32 %th, %one, %th;
    mov.f32 %half, 0f3F000000;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %th;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1515
1516#[cfg(feature = "cuda")]
/// PTX source for `gelu_tanh_f64_kernel`: element-wise tanh-approximated GELU
/// on `f64`, `out[i] = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Kernel parameters: `in_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f64` / `st.global.f64`, and `n` is the `u32` element
/// count. Each thread handles index `ctaid.x * ntid.x + tid.x` and returns
/// without touching memory when that index is `>= n`.
///
/// `tanh(y)` is `(exp(2y) - 1) / (exp(2y) + 1)` with `exp` computed in `f64`
/// (Cody-Waite reduction, degree-11 Taylor polynomial, `2^n` built in the
/// exponent field). `y` is clamped to `[-20, 20]`, where `tanh` already
/// rounds to `+/-1`, so the constructed exponent always stays in range. A
/// NaN input still yields NaN through the final multiply by `x`.
pub(crate) const GELU_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f64 %e2y_m1, %e2y_p1, %th, %one, %half, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;

    // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
    mul.f64 %x3, %x, %x;
    mul.f64 %x3, %x3, %x;
    mov.f64 %c, 0d3FA6E4E26D4801F7;
    mul.f64 %x3, %c, %x3;
    add.f64 %inner, %x, %x3;
    mov.f64 %sqrt2pi, 0d3FE9884533D43651;
    mul.f64 %y, %sqrt2pi, %inner;
    // tanh(y) rounds to +/-1 in f64 for |y| >= 20; clamping keeps the
    // exponent built for exp(2y) inside the normal range
    min.f64 %y, %y, 0d4034000000000000;
    max.f64 %y, %y, 0dC034000000000000;

    // tanh(y) = (exp(2y)-1)/(exp(2y)+1), exp(2y) in full f64
    add.f64 %two_y, %y, %y;

    // --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/11!, 1/10!, ..., 1/3!, then 1/2, 1, 1
    mov.f64 %e_p, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2y, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2y, %e2y, %e_nf;
    // --- end exp ---

    sub.f64 %e2y_m1, %e2y, %one;
    add.f64 %e2y_p1, %e2y, %one;
    div.rn.f64 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f64 %th, %one, %th;
    mov.f64 %half, 0d3FE0000000000000;
    mul.f64 %result, %half, %x;
    mul.f64 %result, %result, %th;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1611
1612#[cfg(feature = "cuda")]
/// PTX source for `gelu_erf_kernel`: element-wise erf-based GELU on `f32`,
/// `out[i] = 0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// Kernel parameters: `in_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f32` / `st.global.f32`, and `n` is the `u32` element
/// count. Each thread handles index `ctaid.x * ntid.x + tid.x` and returns
/// without touching memory when that index is `>= n`.
///
/// `erf` uses the rational approximation from Abramowitz & Stegun 7.1.26:
/// with `t = 1 / (1 + 0.3275911 * |z|)`,
/// `erf(|z|) = 1 - t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5)))) * exp(-z^2)`,
/// where `a1 = 0.254829592`, `a2 = -0.284496736`, `a3 = 1.421413741`,
/// `a4 = -1.453152027`, `a5 = 1.061405429`. The sign is restored afterwards
/// from `z < 0`. `exp` and the reciprocal use the `.approx` instructions.
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %z, %ax, %one, %half, %log2e;
    .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // a5 = 1.061405429, a4 = -1.453152027, a3 = 1.421413741,
    // a2 = -0.284496736, a1 = 0.254829592
    mov.f32 %a5, 0f3F87DC22;
    mov.f32 %a4, 0fBFBA00E3;
    mov.f32 %a3, 0f3FB5F0E3;
    mov.f32 %a2, 0fBE91A98E;
    mov.f32 %a1, 0f3E827906;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // out = x * 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %erf_val, %one, %erf_val;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %erf_val;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1710
1711#[cfg(feature = "cuda")]
/// PTX source for `gelu_erf_f64_kernel`: element-wise erf-based GELU on
/// `f64`, `out[i] = 0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// Kernel parameters: `in_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f64` / `st.global.f64`, and `n` is the `u32` element
/// count. Each thread handles index `ctaid.x * ntid.x + tid.x` and returns
/// without touching memory when that index is `>= n`.
///
/// `erf` uses the Abramowitz & Stegun 7.1.26 rational approximation
/// (`p = 0.3275911`, `a1 = 0.254829592`, `a2 = -0.284496736`,
/// `a3 = 1.421413741`, `a4 = -1.453152027`, `a5 = 1.061405429`); that
/// formula's published error bound of 1.5e-7 limits the accuracy of the
/// result even though all arithmetic is `f64`.
///
/// `exp(-z^2)` is computed in `f64` (Cody-Waite reduction, degree-11 Taylor
/// polynomial, `2^n` built in the exponent field). `-z^2` is clamped from
/// below at `-708` so the constructed exponent stays positive.
pub(crate) const GELU_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_erf_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %z, %ax, %one, %half;
    .reg .f64 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
    .reg .f64 %p, %a1, %a2, %a3, %a4, %a5, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f64 %z, 0d3FE6A09E667F3BCD;
    mul.f64 %z, %x, %z;

    abs.f64 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f64 %p, 0d3FD4F740A93D7B8C;
    mul.f64 %t, %p, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    // a5 = 1.061405429, a4 = -1.453152027, a3 = 1.421413741,
    // a2 = -0.284496736, a1 = 0.254829592
    mov.f64 %a5, 0d3FF0FB844255A12D;
    mov.f64 %a4, 0dBFF7401C57014C39;
    mov.f64 %a3, 0d3FF6BE1C55BAE157;
    mov.f64 %a2, 0dBFD23531CC3C1469;
    mov.f64 %a1, 0d3FD04F20C6EC5A7E;

    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // exp(-z^2) in full f64
    mul.f64 %z2, %ax, %ax;
    neg.f64 %neg_z2, %z2;
    // clamp at -708 so the 2^n scale below stays a normal f64
    max.f64 %neg_z2, %neg_z2, 0dC086200000000000;

    // --- exp(%neg_z2) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/11!, 1/10!, ..., 1/3!, then 1/2, 1, 1
    mov.f64 %e_p, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
    // --- end exp ---

    mul.f64 %erf_val, %pt, %exp_neg_z2;
    sub.f64 %erf_val, %one, %erf_val;

    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    add.f64 %erf_val, %one, %erf_val;
    mul.f64 %result, %half, %x;
    mul.f64 %result, %result, %erf_val;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1829
1830#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_tanh_kernel`: element-wise gradient of the
/// tanh-approximated GELU on `f32`.
///
/// With `u = sqrt(2/pi) * (x + 0.044715 * x^3)` and `th = tanh(u)`:
/// `out[i] = grad[i] * (0.5 * (1 + th)
///                      + 0.5 * x * (1 - th^2) * sqrt(2/pi) * (1 + 0.134145 * x^2))`.
///
/// Kernel parameters: `grad_ptr`, `input_ptr` and `out_ptr` are 64-bit
/// addresses accessed with `ld.global.f32` / `st.global.f32`, and `n` is the
/// `u32` element count. Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns without touching memory when that
/// index is `>= n`.
///
/// `tanh(u)` is `(exp(2u) - 1) / (exp(2u) + 1)` using `ex2.approx.f32` and
/// `rcp.approx.f32`. `u` is clamped to `[-10, 10]`, where `tanh` already
/// rounds to `+/-1` in `f32`, so `exp(2u)` cannot overflow to `+inf` and turn
/// the quotient into NaN. A NaN input still yields NaN through the
/// `0.5 * x * ...` term.
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mov.f32 %c, 0f3D372713;
    // 3 * 0.044715 = 0.134145 = 0x3E095D4F
    mov.f32 %c3, 0f3E095D4F;

    // u = sqrt(2/pi) * (x + 0.044715 * x^3)
    mul.f32 %x2, %x, %x;
    mul.f32 %x3, %x2, %x;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mul.f32 %y, %sqrt2pi, %inner;
    // tanh(y) rounds to +/-1 in f32 for |y| >= 10; clamping keeps exp(2y)
    // finite so the quotient below cannot become inf * 0 = NaN
    min.f32 %y, %y, 0f41200000;
    max.f32 %y, %y, 0fC1200000;

    // tanh(y) via exp
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh^2)*sqrt(2/pi)*(1+3*0.044715*x^2)
    // term1 = 0.5 * (1 + th)
    add.f32 %term1, %one, %th;
    mul.f32 %term1, %half, %term1;

    // (1 - th^2)
    mul.f32 %th2, %th, %th;
    sub.f32 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/pi) * (1 + 3*0.044715*x^2)
    mul.f32 %d_inner, %c3, %x2;
    add.f32 %d_inner, %one, %d_inner;
    mul.f32 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th^2) * d_inner
    mul.f32 %term2, %half, %x;
    mul.f32 %term2, %term2, %one_m_th2;
    mul.f32 %term2, %term2, %d_inner;

    add.f32 %d_gelu, %term1, %term2;
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1927
1928#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_tanh_f64_kernel`: element-wise gradient of
/// the tanh-approximated GELU on `f64`.
///
/// With `u = sqrt(2/pi) * (x + 0.044715 * x^3)` and `th = tanh(u)`:
/// `out[i] = grad[i] * (0.5 * (1 + th)
///                      + 0.5 * x * (1 - th^2) * sqrt(2/pi) * (1 + 0.134145 * x^2))`.
///
/// Kernel parameters: `grad_ptr`, `input_ptr` and `out_ptr` are 64-bit
/// addresses accessed with `ld.global.f64` / `st.global.f64`, and `n` is the
/// `u32` element count. Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns without touching memory when that
/// index is `>= n`.
///
/// `tanh(u)` is `(exp(2u) - 1) / (exp(2u) + 1)` with `exp` computed in `f64`
/// (Cody-Waite reduction, degree-11 Taylor polynomial, `2^n` built in the
/// exponent field). `u` is clamped to `[-20, 20]`, where `tanh` already
/// rounds to `+/-1`, so the constructed exponent always stays in range.
pub(crate) const GELU_BACKWARD_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f64 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half;
    .reg .f64 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;
    mov.f64 %sqrt2pi, 0d3FE9884533D43651;
    mov.f64 %c, 0d3FA6E4E26D4801F7;
    // 3 * 0.044715 = 0.134145
    mov.f64 %c3, 0d3FC12BA9D1F60179;

    mul.f64 %x2, %x, %x;
    mul.f64 %x3, %x2, %x;
    mul.f64 %x3, %c, %x3;
    add.f64 %inner, %x, %x3;
    mul.f64 %y, %sqrt2pi, %inner;
    // tanh(y) rounds to +/-1 in f64 for |y| >= 20; clamping keeps the
    // exponent built for exp(2y) inside the normal range
    min.f64 %y, %y, 0d4034000000000000;
    max.f64 %y, %y, 0dC034000000000000;

    // tanh(y) = (exp(2y)-1)/(exp(2y)+1) in full f64
    add.f64 %two_y, %y, %y;

    // --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/11!, 1/10!, ..., 1/3!, then 1/2, 1, 1
    mov.f64 %e_p, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2y, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2y, %e2y, %e_nf;
    // --- end exp ---

    sub.f64 %e2y_m1, %e2y, %one;
    add.f64 %e2y_p1, %e2y, %one;
    div.rn.f64 %th, %e2y_m1, %e2y_p1;

    // term1 = 0.5 * (1 + th)
    add.f64 %term1, %one, %th;
    mul.f64 %term1, %half, %term1;

    // (1 - th^2)
    mul.f64 %th2, %th, %th;
    sub.f64 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/pi) * (1 + 3*0.044715*x^2)
    mul.f64 %d_inner, %c3, %x2;
    add.f64 %d_inner, %one, %d_inner;
    mul.f64 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th^2) * d_inner
    mul.f64 %term2, %half, %x;
    mul.f64 %term2, %term2, %one_m_th2;
    mul.f64 %term2, %term2, %d_inner;

    add.f64 %d_gelu, %term1, %term2;
    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
2042
2043#[cfg(feature = "cuda")]
/// PTX source for `silu_kernel`: element-wise SiLU on `f32`,
/// `out[i] = x * sigmoid(x)` with `sigmoid(x) = 1 / (1 + exp(-x))`.
///
/// Kernel parameters: `a_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f32` / `st.global.f32`, and `n` is the `u32` element
/// count. Each thread handles index `ctaid.x * ntid.x + tid.x` and returns
/// without touching memory when that index is `>= n`.
///
/// `exp(-x)` is computed as `2^(-x * log2(e))` with `ex2.approx.f32`, and the
/// reciprocal with `rcp.approx.f32`, so the result carries the accuracy of
/// those approximate instructions.
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
 .reg .pred %p;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %x, [%a];
 // sigmoid(x) = 1 / (1 + exp(-x))
 // exp(-x) = 2^(-x * log2(e))
 mov.f32 %one, 0f3F800000;
 mov.f32 %lg2e, 0f3FB8AA3B;
 neg.f32 %neg, %x;
 mul.f32 %neg, %neg, %lg2e;
 ex2.approx.f32 %e, %neg;
 add.f32 %denom, %one, %e;
 rcp.approx.f32 %sig, %denom;
 // silu(x) = x * sigmoid(x)
 mul.f32 %vr, %x, %sig;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
2101
2102#[cfg(feature = "cuda")]
/// PTX source for `silu_f64_kernel`: element-wise SiLU on `f64`,
/// `out[i] = x * sigmoid(x)` with `sigmoid(x) = 1 / (1 + exp(-x))`.
///
/// Kernel parameters: `a_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f64` / `st.global.f64`, and `n` is the `u32` element
/// count. Each thread handles index `ctaid.x * ntid.x + tid.x` and returns
/// without touching memory when that index is `>= n`.
///
/// `exp(-x)` is computed in `f64` by splitting the argument as `n * ln2 + r`
/// (two-constant Cody-Waite reduction), evaluating the degree-11 Taylor
/// polynomial of `exp(r)` with `fma`, and scaling by `2^n` built directly in
/// the exponent field. The argument is clamped to `[-708, 708]` first so
/// that `n + 1023` always fits the exponent field; a NaN input still yields
/// NaN through the final multiply by `x`.
pub(crate) const SILU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %neg_x, %e, %denom, %sig, %vr, %one;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    neg.f64 %neg_x, %x;
    // clamp to [-708, 708] so the 2^n scale below stays a normal f64
    max.f64 %neg_x, %neg_x, 0dC086200000000000;
    min.f64 %neg_x, %neg_x, 0d4086200000000000;

    // --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/11!, 1/10!, ..., 1/3!, then 1/2, 1, 1
    mov.f64 %e_p, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e, %e, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %e;
    div.rn.f64 %sig, %one, %denom;
    mul.f64 %vr, %x, %sig;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2179
2180#[cfg(feature = "cuda")]
/// PTX source for `silu_backward_kernel`: element-wise SiLU gradient on
/// `f32`, `out[i] = grad[i] * (sig + x * sig * (1 - sig))` with
/// `sig = 1 / (1 + exp(-x))`.
///
/// Kernel parameters: `grad_ptr`, `input_ptr` and `out_ptr` are 64-bit
/// addresses accessed with `ld.global.f32` / `st.global.f32`, and `n` is the
/// `u32` element count. Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns without touching memory when that
/// index is `>= n`.
///
/// `exp(-x)` is computed as `2^(-x * log2(e))` with `ex2.approx.f32`, and the
/// reciprocal with `rcp.approx.f32`.
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
 .reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
 .reg .pred %p;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %x, [%input];

 // sig = sigmoid(x) = 1 / (1 + exp(-x))
 mov.f32 %one, 0f3F800000;
 mov.f32 %lg2e, 0f3FB8AA3B;
 neg.f32 %neg, %x;
 mul.f32 %neg, %neg, %lg2e;
 ex2.approx.f32 %e, %neg;
 add.f32 %denom, %one, %e;
 rcp.approx.f32 %sig, %denom;

 // deriv = sig + x * sig * (1 - sig)
 sub.f32 %one_m_sig, %one, %sig;
 mul.f32 %x_sig_omsig, %x, %sig;
 mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
 add.f32 %deriv, %sig, %x_sig_omsig;
 mul.f32 %result, %vg, %deriv;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
2243
2244#[cfg(feature = "cuda")]
/// PTX source for `silu_backward_f64_kernel`: element-wise SiLU gradient on
/// `f64`, `out[i] = grad[i] * (sig + x * sig * (1 - sig))` with
/// `sig = 1 / (1 + exp(-x))`.
///
/// Kernel parameters: `grad_ptr`, `input_ptr` and `out_ptr` are 64-bit
/// addresses accessed with `ld.global.f64` / `st.global.f64`, and `n` is the
/// `u32` element count. Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns without touching memory when that
/// index is `>= n`.
///
/// `exp(-x)` is computed in `f64` by splitting the argument as `n * ln2 + r`
/// (two-constant Cody-Waite reduction), evaluating the degree-11 Taylor
/// polynomial of `exp(r)` with `fma`, and scaling by `2^n` built directly in
/// the exponent field. The argument is clamped to `[-708, 708]` first so
/// that `n + 1023` always fits the exponent field; a NaN input still yields
/// NaN through the `x * sig * (1 - sig)` term.
pub(crate) const SILU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %neg_x, %e, %denom, %sig, %one;
    .reg .f64 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    neg.f64 %neg_x, %x;
    // clamp to [-708, 708] so the 2^n scale below stays a normal f64
    max.f64 %neg_x, %neg_x, 0dC086200000000000;
    min.f64 %neg_x, %neg_x, 0d4086200000000000;

    // --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/11!, 1/10!, ..., 1/3!, then 1/2, 1, 1
    mov.f64 %e_p, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e, %e, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %e;
    div.rn.f64 %sig, %one, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f64 %one_m_sig, %one, %sig;
    mul.f64 %x_sig_omsig, %x, %sig;
    mul.f64 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f64 %deriv, %sig, %x_sig_omsig;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
2332
2333#[cfg(feature = "cuda")]
/// PTX source for `elu_kernel`: element-wise ELU on `f32`,
/// `out[i] = x > 0 ? x : alpha * (exp(x) - 1)`.
///
/// Kernel parameters: `a_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f32` / `st.global.f32`, `n` is the `u32` element count,
/// and `alpha` is passed by value as an `f32`. Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns without touching memory when that
/// index is `>= n`.
///
/// Both branches are computed unconditionally and the result is chosen with
/// `selp` on `x > 0`. `exp(x)` is computed as `2^(x * log2(e))` with
/// `ex2.approx.f32`.
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .u32 n,
 .param .f32 alpha
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
 .reg .pred %p, %pos;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];
 ld.param.f32 %alpha_r, [alpha];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %x, [%a];
 mov.f32 %one, 0f3F800000;
 mov.f32 %lg2e, 0f3FB8AA3B;

 // exp(x) = 2^(x * log2(e))
 mul.f32 %ex, %x, %lg2e;
 ex2.approx.f32 %ex, %ex;
 sub.f32 %em1, %ex, %one;
 mul.f32 %neg_branch, %alpha_r, %em1;

 // x > 0 ? x : alpha*(exp(x)-1)
 mov.f32 %vr, 0f00000000;
 setp.gt.f32 %pos, %x, %vr;
 selp.f32 %vr, %x, %neg_branch, %pos;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
2391
2392#[cfg(feature = "cuda")]
/// PTX source for `elu_f64_kernel`: element-wise ELU on `f64`,
/// `out[i] = x > 0 ? x : alpha * (exp(x) - 1)`.
///
/// Kernel parameters: `a_ptr` and `out_ptr` are 64-bit addresses accessed
/// with `ld.global.f64` / `st.global.f64`, `n` is the `u32` element count,
/// and `alpha` is passed by value as an `f64`. Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns without touching memory when that
/// index is `>= n`.
///
/// `exp(x)` is computed in `f64` by splitting `x` as `n * ln2 + r`
/// (two-constant Cody-Waite reduction), evaluating the degree-11 Taylor
/// polynomial of `exp(r)` with `fma`, and scaling by `2^n` built directly in
/// the exponent field. For `x < -708` that constructed exponent is no longer
/// valid, so the kernel substitutes `exp(x) = 0` there (`exp(x) - 1` is `-1`
/// to `f64` precision in that range anyway). For large positive `x` the
/// polynomial result is never selected because the `x > 0` branch returns
/// `x` itself. NaN inputs fail both comparisons and propagate through the
/// polynomial into the output.
pub(crate) const ELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f64 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %alpha_r, %one, %ex, %em1, %neg_branch, %vr;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %pos, %uf;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f64 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // --- exp(%x) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/11!, 1/10!, ..., 1/3!, then 1/2, 1, 1
    mov.f64 %e_p, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // --- end exp ---

    // below -708 the exponent built above is out of range; exp(x) - 1 is
    // -1 to f64 precision there, so substitute exp(x) = 0
    mov.f64 %vr, 0d0000000000000000;
    setp.lt.f64 %uf, %x, 0dC086200000000000;
    selp.f64 %ex, %vr, %ex, %uf;

    sub.f64 %em1, %ex, %one;
    mul.f64 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha*(exp(x)-1)
    setp.gt.f64 %pos, %x, %vr;
    selp.f64 %vr, %x, %neg_branch, %pos;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2473
2474#[cfg(feature = "cuda")]
/// PTX source for `elu_backward_kernel`: element-wise ELU gradient on `f32`,
/// `out[i] = x > 0 ? grad[i] : grad[i] * alpha * exp(x)`.
///
/// Kernel parameters: `grad_ptr`, `input_ptr` and `out_ptr` are 64-bit
/// addresses accessed with `ld.global.f32` / `st.global.f32`, `n` is the
/// `u32` element count, and `alpha` is passed by value as an `f32`. Each
/// thread handles index `ctaid.x * ntid.x + tid.x` and returns without
/// touching memory when that index is `>= n`.
///
/// Both branches are computed unconditionally and the result is chosen with
/// `selp` on `x > 0`. `exp(x)` is computed as `2^(x * log2(e))` with
/// `ex2.approx.f32`.
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n,
 .param .f32 alpha
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
 .reg .pred %p, %pos;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];
 ld.param.f32 %alpha_r, [alpha];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %x, [%input];

 mov.f32 %lg2e, 0f3FB8AA3B;
 mov.f32 %zero, 0f00000000;

 // exp(x) = 2^(x * log2(e))
 mul.f32 %ex, %x, %lg2e;
 ex2.approx.f32 %ex, %ex;
 // negative branch: grad * alpha * exp(x)
 mul.f32 %neg_branch, %vg, %alpha_r;
 mul.f32 %neg_branch, %neg_branch, %ex;

 // x > 0 ? grad : grad * alpha * exp(x)
 setp.gt.f32 %pos, %x, %zero;
 selp.f32 %vr, %vg, %neg_branch, %pos;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
2537
2538#[cfg(feature = "cuda")]
/// PTX source for `elu_backward_f64_kernel`: element-wise ELU gradient on f64 data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses), `n` (u32 element count) and `alpha` (f64).
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores `grad` when
/// `x > 0` and `grad * alpha * exp(x)` when it is not.
///
/// `exp(x)` is computed in software: `n = floor(x * log2(e) + 0.5)`, a two-step
/// Cody-Waite reduction `r = x - n * ln(2)`, a degree-12 Taylor polynomial in
/// `r` evaluated with Horner FMAs (coefficients 1/12! down to 1/0!, none
/// skipped), and a final multiply by `2^n` built by shifting `n + 1023` into
/// the exponent field.
///
/// NOTE(review): the `2^n` construction has no range check, so the exp result
/// is only meaningful while `n + 1023` stays inside the normal exponent range
/// (inputs within roughly +/-708). Confirm callers never rely on the negative
/// branch for inputs below that.
pub(crate) const ELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_f64_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n,
 .param .f64 alpha
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f64 %vg, %x, %alpha_r, %ex, %neg_branch, %vr, %zero, %one;
 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
 .reg .s32 %e_ni;
 .reg .s64 %e_ni64, %e_bits;
 .reg .pred %p, %pos;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];
 ld.param.f64 %alpha_r, [alpha];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 3;
 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f64 %vg, [%grad];
 ld.global.f64 %x, [%input];

 mov.f64 %zero, 0d0000000000000000;
 mov.f64 %one, 0d3FF0000000000000;

 // --- exp(%x) via Cody-Waite + degree-12 Horner ---
 mov.f64 %e_half, 0d3FE0000000000000;
 fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
 cvt.rmi.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 // Taylor coefficients 1/12!, 1/11!, ..., 1/3!, then 1/2, 1, 1
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %ex, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %ex, %ex, %e_nf;
 // --- end exp ---

 mul.f64 %neg_branch, %vg, %alpha_r;
 mul.f64 %neg_branch, %neg_branch, %ex;

 setp.gt.f64 %pos, %x, %zero;
 selp.f64 %vr, %vg, %neg_branch, %pos;
 st.global.f64 [%out], %vr;

DONE:
 ret;
}
";
2624
2625#[cfg(feature = "cuda")]
/// PTX source for `mish_kernel`: element-wise `mish(x) = x * tanh(softplus(x))`
/// on f32 data.
///
/// Kernel parameters, in order: `a_ptr`, `out_ptr` (64-bit addresses) and `n`
/// (u32 element count). Each thread computes the global index
/// `ctaid.x * ntid.x + tid.x` and returns immediately when it is `>= n`.
///
/// * `x <= 20`: `softplus(x) = ln(1 + exp(x))` is built from `ex2.approx` and
///   `lg2.approx`, and `tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)` from
///   `ex2.approx` and `rcp.approx`, so the result is an approximation.
/// * `x > 20`: the kernel stores `x` unchanged. There `softplus(x)` is `x` and
///   `tanh(x)` rounds to 1.0 in f32; evaluating the exp-based tanh formula
///   instead would overflow for large `x` and yield NaN.
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
 .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
 .reg .f32 %threshold;
 .reg .pred %p, %large;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %x, [%a];
 mov.f32 %one, 0f3F800000;
 mov.f32 %lg2e, 0f3FB8AA3B;
 // threshold = 20.0 = 0x41A00000
 mov.f32 %threshold, 0f41A00000;

 // softplus(x) = ln(1 + exp(x))
 // For large x (> 20), softplus ~ x to avoid overflow
 setp.gt.f32 %large, %x, %threshold;
 @%large bra LARGE_X;

 // exp(x) = 2^(x * log2(e))
 mul.f32 %ex, %x, %lg2e;
 ex2.approx.f32 %ex, %ex;
 add.f32 %ep1, %ex, %one;
 // ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
 lg2.approx.f32 %lg_ep1, %ep1;
 // 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
 mul.f32 %sp, %lg_ep1, 0f3F317218;

 // tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
 add.f32 %two_sp, %sp, %sp;
 mul.f32 %two_sp, %two_sp, %lg2e;
 ex2.approx.f32 %e2sp, %two_sp;
 sub.f32 %e2sp_m1, %e2sp, %one;
 add.f32 %e2sp_p1, %e2sp, %one;
 rcp.approx.f32 %e2sp_p1, %e2sp_p1;
 mul.f32 %th, %e2sp_m1, %e2sp_p1;

 mul.f32 %vr, %x, %th;
 st.global.f32 [%out], %vr;
 bra DONE;

LARGE_X:
 // x > 20: softplus(x) ~ x and tanh(x) rounds to 1.0 in f32, so mish(x) = x.
 // Storing x directly also avoids the inf * 0 = NaN that the exp-based
 // tanh formula produces once exp(2x) overflows (x above roughly 44).
 st.global.f32 [%out], %x;

DONE:
 ret;
}
";
2714
2715#[cfg(feature = "cuda")]
/// PTX source for `mish_f64_kernel`: element-wise
/// `mish(x) = x * tanh(softplus(x))` on f64 data.
///
/// Kernel parameters, in order: `a_ptr`, `out_ptr` (64-bit addresses) and `n`
/// (u32 element count). Each thread computes the global index
/// `ctaid.x * ntid.x + tid.x` and returns immediately when it is `>= n`.
///
/// * `x <= 20`: `exp` is computed in software (round `x * log2(e)` to an
///   integer `n`, two-step Cody-Waite reduction, degree-12 Taylor polynomial
///   evaluated with Horner FMAs, scale by `2^n` through the exponent bits).
///   `ln(1 + exp(x))` splits the value into exponent and mantissa `m`, then
///   uses the series `2 * (f + f^3/3 + ... + f^11/11)` with
///   `f = (m - 1) / (m + 1)`. `tanh(sp)` is `(exp(2*sp) - 1) / (exp(2*sp) + 1)`.
/// * `x > 20`: the kernel stores `x` unchanged. There `softplus(x)` is `x` and
///   `tanh(x)` rounds to 1.0 in f64, and the software exp is not evaluated on
///   arguments that could overflow its exponent construction.
///
/// NOTE(review): the `2^n` construction has no range check, so very negative
/// inputs (below roughly -708) are outside the range where the software exp is
/// meaningful. Confirm whether callers can pass such values.
pub(crate) const MISH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_f64_kernel(
 .param .u64 a_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %a, %out, %off;
 .reg .f64 %x, %one, %ex, %ep1, %sp;
 .reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
 .reg .f64 %threshold;
 // exp subroutine regs
 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
 .reg .s32 %e_ni;
 .reg .s64 %e_ni64, %e_bits;
 // log subroutine regs
 .reg .u64 %l_xbits, %l_mbits, %l_bias;
 .reg .s64 %l_exp64;
 .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
 .reg .pred %p, %large;

 ld.param.u64 %a, [a_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 3;
 add.u64 %a, %a, %off;
 add.u64 %out, %out, %off;

 ld.global.f64 %x, [%a];
 mov.f64 %one, 0d3FF0000000000000;
 mov.f64 %threshold, 0d4034000000000000;

 setp.gt.f64 %large, %x, %threshold;
 @%large bra LARGE_X;

 // === softplus: sp = ln(1 + exp(x)) ===
 // exp(x): Taylor coefficients 1/12!, 1/11!, ..., 1/3!, then 1/2, 1, 1
 mov.f64 %e_half, 0d3FE0000000000000;
 fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
 cvt.rmi.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %ex, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %ex, %ex, %e_nf;

 // ep1 = 1 + exp(x)
 add.f64 %ep1, %ex, %one;

 // ln(ep1) via argument reduction
 mov.b64 %l_xbits, %ep1;
 shr.u64 %l_exp64, %l_xbits, 52;
 and.b64 %l_exp64, %l_exp64, 2047;
 sub.s64 %l_exp64, %l_exp64, 1023;
 cvt.rn.f64.s64 %l_nf, %l_exp64;
 mov.u64 %l_bias, 0x3FF0000000000000;
 and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
 or.b64 %l_mbits, %l_mbits, %l_bias;
 mov.b64 %l_m, %l_mbits;
 sub.f64 %l_f, %l_m, %one;
 add.f64 %l_s, %l_m, %one;
 div.rn.f64 %l_f, %l_f, %l_s;
 mul.f64 %l_f2, %l_f, %l_f;
 mov.f64 %l_p, 0d3FB745D1745D1746;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
 fma.rn.f64 %l_p, %l_p, %l_f2, %one;
 mul.f64 %l_p, %l_p, %l_f;
 add.f64 %l_p, %l_p, %l_p;
 mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
 fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;

 // === tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1) ===
 add.f64 %two_sp, %sp, %sp;
 fma.rn.f64 %e_nf, %two_sp, 0d3FF71547652B82FE, %e_half;
 cvt.rmi.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %e2sp, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %e2sp, %e2sp, %e_nf;

 sub.f64 %e2sp_m1, %e2sp, %one;
 add.f64 %e2sp_p1, %e2sp, %one;
 div.rn.f64 %th, %e2sp_m1, %e2sp_p1;

 mul.f64 %vr, %x, %th;
 st.global.f64 [%out], %vr;
 bra DONE;

LARGE_X:
 // x > 20: softplus(x) ~ x and tanh(x) rounds to 1.0 in f64, so mish(x) = x.
 // Storing x directly keeps the software exp away from arguments whose
 // scale factor 2^n would overflow the exponent field.
 st.global.f64 [%out], %x;

DONE:
 ret;
}
";
2892
2893#[cfg(feature = "cuda")]
/// PTX source for `mish_backward_kernel`: element-wise Mish gradient on f32 data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses) and `n` (u32 element count). Each thread computes the global
/// index `ctaid.x * ntid.x + tid.x` and returns immediately when it is `>= n`.
///
/// * `x <= 20`: with `sp = ln(1 + exp(x))`, `t = tanh(sp)` and
///   `sig = 1 / (1 + exp(-x))`, the kernel stores
///   `grad * (t + x * sig * (1 - t*t))`. All transcendental steps use the
///   `ex2.approx` / `lg2.approx` / `rcp.approx` instructions.
/// * `x > 20`: the kernel stores `grad` unchanged. There `t` rounds to 1.0 and
///   `1 - t*t` to 0, so the derivative is 1; evaluating the exp-based tanh
///   formula instead would overflow for large `x` and yield NaN.
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
 .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
 .reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
 .reg .f32 %threshold;
 .reg .pred %p, %large;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %x, [%input];

 mov.f32 %one, 0f3F800000;
 mov.f32 %lg2e, 0f3FB8AA3B;
 // threshold = 20.0
 mov.f32 %threshold, 0f41A00000;

 setp.gt.f32 %large, %x, %threshold;
 @%large bra LARGE_X;

 // --- Normal path ---
 // softplus: sp = ln(1 + exp(x))
 mul.f32 %ex, %x, %lg2e;
 ex2.approx.f32 %ex, %ex;
 add.f32 %ep1, %ex, %one;
 lg2.approx.f32 %lg_ep1, %ep1;
 // ln(2) = 0x3F317218
 mul.f32 %sp, %lg_ep1, 0f3F317218;

 // t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
 add.f32 %two_sp, %sp, %sp;
 mul.f32 %two_sp, %two_sp, %lg2e;
 ex2.approx.f32 %e2sp, %two_sp;
 sub.f32 %e2sp_m1, %e2sp, %one;
 add.f32 %e2sp_p1, %e2sp, %one;
 rcp.approx.f32 %e2sp_p1, %e2sp_p1;
 mul.f32 %t, %e2sp_m1, %e2sp_p1;

 // sig = sigmoid(x) = 1/(1+exp(-x))
 neg.f32 %neg, %x;
 mul.f32 %neg, %neg, %lg2e;
 ex2.approx.f32 %en, %neg;
 add.f32 %denom, %one, %en;
 rcp.approx.f32 %sig, %denom;

 // deriv = t + x * sig * (1 - t*t)
 mul.f32 %t2, %t, %t;
 sub.f32 %one_m_t2, %one, %t2;
 mul.f32 %x_sig_omt2, %x, %sig;
 mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
 add.f32 %deriv, %t, %x_sig_omt2;
 mul.f32 %result, %vg, %deriv;
 st.global.f32 [%out], %result;
 bra DONE;

LARGE_X:
 // x > 20: sig ~ 1, t = tanh(x) rounds to 1.0 and (1 - t*t) to 0 in f32,
 // so deriv = 1 and the output is the incoming gradient. Storing it
 // directly also avoids the inf * 0 = NaN that the exp-based tanh formula
 // produces once exp(2x) overflows (x above roughly 44).
 st.global.f32 [%out], %vg;

DONE:
 ret;
}
";
3009
3010#[cfg(feature = "cuda")]
/// PTX source for `mish_backward_f64_kernel`: element-wise Mish gradient on
/// f64 data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses) and `n` (u32 element count). Each thread computes the global
/// index `ctaid.x * ntid.x + tid.x` and returns immediately when it is `>= n`.
///
/// * `x <= 20`: with `sp = ln(1 + exp(x))`, `t = tanh(sp)` and
///   `sig = 1 / (1 + exp(-x))`, the kernel stores
///   `grad * (t + x * sig * (1 - t*t))`. Each `exp` is computed in software:
///   `x * log2(e)` is rounded to the nearest integer `n`, a two-step
///   Cody-Waite reduction gives `r`, a degree-12 Taylor polynomial in `r` is
///   evaluated with Horner FMAs, and the result is scaled by `2^n` through the
///   exponent bits. `ln` uses an exponent/mantissa split plus the series
///   `2 * (f + f^3/3 + ... + f^11/11)` with `f = (m - 1) / (m + 1)`.
/// * `x > 20`: the kernel stores `grad` unchanged, since `t` rounds to 1.0 and
///   `1 - t*t` to 0 there; the software exp is not evaluated on arguments that
///   could overflow its exponent construction.
///
/// NOTE(review): the `2^n` construction has no range check, so very negative
/// inputs (below roughly -708) are outside the range where the software exp is
/// meaningful. Confirm whether callers can pass such values.
pub(crate) const MISH_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_f64_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f64 %vg, %x, %one, %ex, %ep1, %sp;
 .reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
 .reg .f64 %neg_x, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
 .reg .f64 %threshold;
 // exp subroutine regs
 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
 .reg .s32 %e_ni;
 .reg .s64 %e_ni64, %e_bits;
 // log subroutine regs
 .reg .u64 %l_xbits, %l_mbits, %l_bias;
 .reg .s64 %l_exp64;
 .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
 .reg .pred %p, %large;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 3;
 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f64 %vg, [%grad];
 ld.global.f64 %x, [%input];

 mov.f64 %one, 0d3FF0000000000000;
 mov.f64 %threshold, 0d4034000000000000;

 setp.gt.f64 %large, %x, %threshold;
 @%large bra LARGE_X;

 // === softplus: sp = ln(1 + exp(x)) ===
 // exp(x): Taylor coefficients 1/12!, 1/11!, ..., 1/3!, then 1/2, 1, 1
 mov.f64 %e_half, 0d3FE0000000000000;
 mul.f64 %e_nf, %x, 0d3FF71547652B82FE;
 cvt.rni.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %ex, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %ex, %ex, %e_nf;

 add.f64 %ep1, %ex, %one;

 // ln(ep1) via argument reduction
 mov.b64 %l_xbits, %ep1;
 shr.u64 %l_exp64, %l_xbits, 52;
 and.b64 %l_exp64, %l_exp64, 2047;
 sub.s64 %l_exp64, %l_exp64, 1023;
 cvt.rn.f64.s64 %l_nf, %l_exp64;
 mov.u64 %l_bias, 0x3FF0000000000000;
 and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
 or.b64 %l_mbits, %l_mbits, %l_bias;
 mov.b64 %l_m, %l_mbits;
 sub.f64 %l_f, %l_m, %one;
 add.f64 %l_s, %l_m, %one;
 div.rn.f64 %l_f, %l_f, %l_s;
 mul.f64 %l_f2, %l_f, %l_f;
 mov.f64 %l_p, 0d3FB745D1745D1746;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
 fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
 fma.rn.f64 %l_p, %l_p, %l_f2, %one;
 mul.f64 %l_p, %l_p, %l_f;
 add.f64 %l_p, %l_p, %l_p;
 mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
 fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;

 // === tanh(sp) ===
 add.f64 %two_sp, %sp, %sp;
 mul.f64 %e_nf, %two_sp, 0d3FF71547652B82FE;
 cvt.rni.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %e2sp, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %e2sp, %e2sp, %e_nf;

 sub.f64 %e2sp_m1, %e2sp, %one;
 add.f64 %e2sp_p1, %e2sp, %one;
 div.rn.f64 %t, %e2sp_m1, %e2sp_p1;

 // === sigmoid(x) = 1/(1+exp(-x)) ===
 neg.f64 %neg_x, %x;
 mul.f64 %e_nf, %neg_x, 0d3FF71547652B82FE;
 cvt.rni.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %en, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %en, %en, %e_nf;

 add.f64 %denom, %one, %en;
 div.rn.f64 %sig, %one, %denom;

 // deriv = t + x * sig * (1 - t*t)
 mul.f64 %t2, %t, %t;
 sub.f64 %one_m_t2, %one, %t2;
 mul.f64 %x_sig_omt2, %x, %sig;
 mul.f64 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
 add.f64 %deriv, %t, %x_sig_omt2;
 mul.f64 %result, %vg, %deriv;
 st.global.f64 [%out], %result;
 bra DONE;

LARGE_X:
 // x > 20: sig ~ 1, t = tanh(x) rounds to 1.0 and (1 - t*t) to 0 in f64,
 // so deriv = 1 and the output is the incoming gradient. Storing it
 // directly keeps the software exp away from arguments whose scale
 // factor 2^n would overflow the exponent field.
 st.global.f64 [%out], %vg;

DONE:
 ret;
}
";
3231
#[cfg(feature = "cuda")]
/// PTX source for `clamp_kernel`: element-wise clamp of f32 data.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr` (64-bit addresses), `n`
/// (u32 element count), `min_val` and `max_val` (f32).
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores
/// `min(max(x, min_val), max_val)`. The lower bound is applied first, so
/// `max_val` wins if the two bounds are passed in the wrong order.
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry clamp_kernel(
 .param .u64 in_ptr,
 .param .u64 out_ptr,
 .param .u32 n,
 .param .f32 min_val,
 .param .f32 max_val
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %in, %out, %off;
 .reg .f32 %x, %mn, %mx, %result;
 .reg .pred %p;

 ld.param.u64 %in, [in_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];
 ld.param.f32 %mn, [min_val];
 ld.param.f32 %mx, [max_val];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %in, %in, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %x, [%in];
 max.f32 %result, %x, %mn;
 min.f32 %result, %result, %mx;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
3280
3281
#[cfg(feature = "cuda")]
/// PTX source for `fill_f32_kernel`: writes one f32 value to every element of
/// an output buffer.
///
/// Kernel parameters, in order: `out_ptr` (64-bit address), `scalar` (f32) and
/// `n` (u32 element count). Note that `scalar` comes before `n`.
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores `scalar` at
/// byte offset `index * 4` from `out_ptr`.
pub(crate) const FILL_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fill_f32_kernel(
 .param .u64 out_ptr,
 .param .f32 scalar,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %out, %off;
 .reg .f32 %v;
 .reg .pred %p;

 ld.param.u64 %out, [out_ptr];
 ld.param.f32 %v, [scalar];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;
 add.u64 %out, %out, %off;
 st.global.f32 [%out], %v;

DONE:
 ret;
}
";
3328
3329
#[cfg(feature = "cuda")]
/// PTX source for `abs_backward_kernel`: element-wise gradient of `abs` on f32
/// data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses) and `n` (u32 element count).
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores `grad` when the
/// input is `> 0`, `-grad` when it is `< 0`, and `0` in every other case
/// (both comparisons false).
pub(crate) const ABS_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f32 %vg, %vi, %zero, %neg_vg, %tmp, %vr;
 .reg .pred %p, %pos, %neg;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %vi, [%input];
 mov.f32 %zero, 0f00000000;

 neg.f32 %neg_vg, %vg;

 // tmp = (vi < 0) ? -vg : 0
 setp.lt.f32 %neg, %vi, %zero;
 selp.f32 %tmp, %neg_vg, %zero, %neg;
 // vr = (vi > 0) ? vg : tmp
 setp.gt.f32 %pos, %vi, %zero;
 selp.f32 %vr, %vg, %tmp, %pos;

 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
3392
3393
#[cfg(feature = "cuda")]
/// PTX source for `relu_backward_kernel`: element-wise ReLU gradient on f32
/// data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses) and `n` (u32 element count).
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores `grad` when the
/// input is `> 0` and `0` when it is not (so the gradient at exactly 0 is 0).
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f32 %vg, %vi, %zero, %vr;
 .reg .pred %p, %pos;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %vi, [%input];
 mov.f32 %zero, 0f00000000;
 setp.gt.f32 %pos, %vi, %zero;
 selp.f32 %vr, %vg, %zero, %pos;
 st.global.f32 [%out], %vr;

DONE:
 ret;
}
";
3444
3445
#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_kernel`: element-wise gradient of the
/// sigmoid-form GELU approximation `x * sigmoid(k * x)` on f32 data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses) and `n` (u32 element count).
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores
/// `grad * (sig + k * x * sig * (1 - sig))` with `sig = 1 / (1 + exp(-k * x))`.
/// The exponential and reciprocal use `ex2.approx.f32` and `rcp.approx.f32`.
///
/// NOTE(review): the literal loaded into `%k`, `0f3FDA2720`, decodes to about
/// 1.7043, while the inline PTX comment says 1.702 (which would be
/// `0f3FD9DB23`). Confirm which value is intended and that it matches the
/// forward kernel.
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
 .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
 .reg .pred %p;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 2;

 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f32 %vg, [%grad];
 ld.global.f32 %x, [%input];

 // sig = sigmoid(1.702 * x)
 mov.f32 %k, 0f3FDA2720;
 mul.f32 %kx, %k, %x;
 neg.f32 %neg_kx, %kx;
 mov.f32 %log2e, 0f3FB8AA3B;
 mul.f32 %neg_kx, %neg_kx, %log2e;
 ex2.approx.f32 %exp_neg, %neg_kx;
 mov.f32 %one, 0f3F800000;
 add.f32 %denom, %one, %exp_neg;
 rcp.approx.f32 %sig, %denom;

 // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
 sub.f32 %one_minus_sig, %one, %sig;
 mul.f32 %kx_sig_oms, %kx, %sig;
 mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
 add.f32 %dsig, %sig, %kx_sig_oms;

 // out = grad * d_gelu
 mul.f32 %result, %vg, %dsig;
 st.global.f32 [%out], %result;

DONE:
 ret;
}
";
3521
3522#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_f64_kernel`: element-wise gradient of the
/// sigmoid-form GELU approximation `x * sigmoid(k * x)` on f64 data.
///
/// Kernel parameters, in order: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit
/// addresses) and `n` (u32 element count).
///
/// Each thread computes the global index `ctaid.x * ntid.x + tid.x`, returns
/// immediately when that index is `>= n`, and otherwise stores
/// `grad * (sig + k * x * sig * (1 - sig))` with `sig = 1 / (1 + exp(-k * x))`.
///
/// `exp` is computed in software: `n = floor(y * log2(e) + 0.5)`, a two-step
/// Cody-Waite reduction `r = y - n * ln(2)`, a degree-12 Taylor polynomial in
/// `r` evaluated with Horner FMAs (coefficients 1/12! down to 1/0!, none
/// skipped), and a final multiply by `2^n` built by shifting `n + 1023` into
/// the exponent field.
///
/// NOTE(review): `%k` is loaded from `0d3FFB44E400000000`, which decodes to
/// about 1.7043 rather than 1.702. Confirm the intended constant and that it
/// matches the forward kernel.
///
/// NOTE(review): the `2^n` construction has no range check, so the exp result
/// is only meaningful while `|k * x|` stays below roughly 708. Confirm callers
/// never pass inputs outside that range.
pub(crate) const GELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_f64_kernel(
 .param .u64 grad_ptr,
 .param .u64 input_ptr,
 .param .u64 out_ptr,
 .param .u32 n
) {
 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
 .reg .u64 %grad, %input, %out, %off;
 .reg .f64 %vg, %x, %k, %kx, %neg_kx, %exp_neg, %one, %denom, %sig;
 .reg .f64 %one_minus_sig, %kx_sig_oms, %dsig, %result;
 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
 .reg .s32 %e_ni;
 .reg .s64 %e_ni64, %e_bits;
 .reg .pred %p;

 ld.param.u64 %grad, [grad_ptr];
 ld.param.u64 %input, [input_ptr];
 ld.param.u64 %out, [out_ptr];
 ld.param.u32 %n_reg, [n];

 mov.u32 %bid, %ctaid.x;
 mov.u32 %bdim, %ntid.x;
 mov.u32 %r_tid, %tid.x;
 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

 setp.ge.u32 %p, %r_tid, %n_reg;
 @%p bra DONE;

 cvt.u64.u32 %off, %r_tid;
 shl.b64 %off, %off, 3;
 add.u64 %grad, %grad, %off;
 add.u64 %input, %input, %off;
 add.u64 %out, %out, %off;

 ld.global.f64 %vg, [%grad];
 ld.global.f64 %x, [%input];

 mov.f64 %one, 0d3FF0000000000000;
 mov.f64 %k, 0d3FFB44E400000000;
 mul.f64 %kx, %k, %x;
 neg.f64 %neg_kx, %kx;

 // --- exp(%neg_kx) via Cody-Waite + degree-12 Horner ---
 mov.f64 %e_half, 0d3FE0000000000000;
 fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
 cvt.rmi.f64.f64 %e_nf, %e_nf;
 cvt.rni.s32.f64 %e_ni, %e_nf;
 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
 // Taylor coefficients 1/12!, 1/11!, ..., 1/3!, then 1/2, 1, 1
 mov.f64 %e_p, 0d3E21EED8EFF8D898;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
 fma.rn.f64 %e_p, %e_p, %e_r, %one;
 fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
 cvt.s64.s32 %e_ni64, %e_ni;
 add.s64 %e_ni64, %e_ni64, 1023;
 shl.b64 %e_bits, %e_ni64, 52;
 mov.b64 %e_nf, %e_bits;
 mul.f64 %exp_neg, %exp_neg, %e_nf;
 // --- end exp ---

 add.f64 %denom, %one, %exp_neg;
 div.rn.f64 %sig, %one, %denom;

 sub.f64 %one_minus_sig, %one, %sig;
 mul.f64 %kx_sig_oms, %kx, %sig;
 mul.f64 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
 add.f64 %dsig, %sig, %kx_sig_oms;

 mul.f64 %result, %vg, %dsig;
 st.global.f64 [%out], %result;

DONE:
 ret;
}
";
3613
3614#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_erf_kernel`: f32 backward pass of the exact
/// (erf-based) GELU.
///
/// Each thread handles one element `i = ctaid.x * ntid.x + tid.x`; threads with
/// `i >= n` exit immediately. For the others the kernel stores
/// `out[i] = grad[i] * (Φ(x) + x·φ(x))` with `x = input[i]`, where
/// `Φ(x) = 0.5·(1 + erf(x/√2))` and `φ(x) = exp(-x²/2) / √(2π)`.
///
/// * `erf(|z|)` uses the Abramowitz & Stegun 7.1.26 approximation
///   `1 - (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵)·exp(-z²)` with
///   `t = 1/(1 + p·|z|)`; the sign is restored afterwards from the sign of `z`.
/// * Both exponentials go through `ex2.approx.f32` and the reciprocal through
///   `rcp.approx.f32`, so the result is approximate by construction.
///
/// Kernel parameters: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit global
/// addresses, each advanced by `4 * i` bytes) and `n` (u32 element count).
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f32 %d_gelu, %result;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Abramowitz-Stegun 7.1.26 coefficients:
    //   a1 =  0.254829592   a2 = -0.284496736   a3 = 1.421413741
    //   a4 = -1.453152027   a5 =  1.061405429
    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mov.f32 %a5, 0f3F87DC22;
    mov.f32 %a4, 0fBFBA00E3;
    mov.f32 %a3, 0f3FB5F0E3;
    mov.f32 %a2, 0fBE91A98E;
    mov.f32 %a1, 0f3E827906;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // Phi(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %cdf, %one, %erf_val;
    mul.f32 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi)
    // exp(-x^2/2):
    mul.f32 %neg_x2h, %x, %x;
    mul.f32 %neg_x2h, %neg_x2h, %half;
    neg.f32 %neg_x2h, %neg_x2h;
    mul.f32 %neg_x2h, %neg_x2h, %log2e;
    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;

    // 1/sqrt(2*pi) = 0.39894228
    mov.f32 %inv_sqrt_2pi, 0f3ECC422A;
    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Phi(x) + x * phi(x)
    mul.f32 %x_pdf, %x, %pdf;
    add.f32 %d_gelu, %cdf, %x_pdf;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3741
3742#[cfg(feature = "cuda")]
/// PTX source for `gelu_backward_erf_f64_kernel`: f64 backward pass of the
/// exact (erf-based) GELU.
///
/// Each thread handles one element `i = ctaid.x * ntid.x + tid.x`; threads with
/// `i >= n` exit immediately. For the others the kernel stores
/// `out[i] = grad[i] * (Φ(x) + x·φ(x))` with `x = input[i]`, where
/// `Φ(x) = 0.5·(1 + erf(x/√2))` and `φ(x) = exp(-x²/2) / √(2π)`.
///
/// * `erf(|z|)` uses the Abramowitz & Stegun 7.1.26 approximation (same form
///   as the f32 kernel). That formula is only good to roughly 1e-7 absolute,
///   so `Φ` is *not* accurate to full double precision even though all
///   arithmetic is f64.
/// * `exp(y)` is computed in software: `n = round(y·log2 e)`, Cody-Waite
///   reduction `r = y - n·ln2` (split hi/lo), a degree-12 Taylor polynomial in
///   `r` evaluated by Horner, then scaling by `2^n` built directly from the
///   exponent bits. Both call sites pass `y <= 0`; arguments below `-708`
///   would produce an invalid exponent field and are flushed to `+0`.
///
/// Kernel parameters: `grad_ptr`, `input_ptr`, `out_ptr` (64-bit global
/// addresses, each advanced by `8 * i` bytes) and `n` (u32 element count).
pub(crate) const GELU_BACKWARD_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f64 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f64 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f64 %d_gelu, %result;
    .reg .f64 %p_coef, %a1, %a2, %a3, %a4, %a5;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg, %pred_uf;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;

    // z = x / sqrt(2)
    mov.f64 %z, 0d3FE6A09E667F3BCD;
    mul.f64 %z, %x, %z;
    abs.f64 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f64 %p_coef, 0d3FD4F740A93D7B8C;
    mul.f64 %t, %p_coef, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    // Abramowitz-Stegun 7.1.26 coefficients:
    //   a1 =  0.254829592   a2 = -0.284496736   a3 = 1.421413741
    //   a4 = -1.453152027   a5 =  1.061405429
    mov.f64 %a5, 0d3FF0FB844255A12D;
    mov.f64 %a4, 0dBFF7401C57014C39;
    mov.f64 %a3, 0d3FF6BE1C55BAE157;
    mov.f64 %a2, 0dBFD23531CC3C1469;
    mov.f64 %a1, 0d3FD04F20C6EC5A7E;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // exp(-z^2) in full f64
    mul.f64 %z2, %ax, %ax;
    neg.f64 %neg_z2, %z2;

    // --- exp(%neg_z2): Cody-Waite reduction + degree-12 Taylor (Horner) ---
    // Below -708 the biased exponent n + 1023 is no longer positive: flush to zero.
    setp.lt.f64 %pred_uf, %neg_z2, 0dC086200000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Coefficients 1/12!, 1/11!, ..., 1/3!, then 1/2, 1, 1.
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
    // Scale by 2^n: place n + 1023 in the exponent field.
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
    @%pred_uf mov.f64 %exp_neg_z2, 0d0000000000000000;
    // --- end exp ---

    // erf(|z|) = 1 - poly * exp(-z^2), then restore the sign of z
    mul.f64 %erf_val, %pt, %exp_neg_z2;
    sub.f64 %erf_val, %one, %erf_val;

    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    // Phi(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f64 %cdf, %one, %erf_val;
    mul.f64 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi)
    mul.f64 %neg_x2h, %x, %x;
    mul.f64 %neg_x2h, %neg_x2h, %half;
    neg.f64 %neg_x2h, %neg_x2h;

    // --- exp(%neg_x2h): same scheme as above ---
    setp.lt.f64 %pred_uf, %neg_x2h, 0dC086200000000000;
    fma.rn.f64 %e_nf, %neg_x2h, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x2h;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_x2h, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_x2h, %exp_neg_x2h, %e_nf;
    @%pred_uf mov.f64 %exp_neg_x2h, 0d0000000000000000;
    // --- end exp ---

    // 1/sqrt(2*pi) = 0.3989422804014327
    mov.f64 %inv_sqrt_2pi, 0d3FD9884533D43651;
    mul.f64 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Phi(x) + x * phi(x)
    mul.f64 %x_pdf, %x, %pdf;
    add.f64 %d_gelu, %cdf, %x_pdf;

    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3901
3902#[cfg(feature = "cuda")]
/// PTX source for `index_select_1d_kernel`: 1-D gather of f32 values.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n_indices`) performs
/// `out[i] = input[indices[i]]`.
///
/// Indices are stored as f32 and converted with `cvt.rzi.u32.f32`
/// (round toward zero). The kernel receives no input length, so it performs
/// no bounds check on the converted index; callers must pass valid indices.
///
/// Kernel parameters: `input_ptr`, `indices_ptr`, `out_ptr` (64-bit global
/// addresses of f32 arrays) and `n_indices` (u32 number of indices/outputs).
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
3961
3962
3963#[cfg(feature = "cuda")]
/// PTX source for `scatter_add_1d_kernel`: 1-D scatter-add of f32 gradients.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n_indices`) performs
/// `grad_input[indices[i]] += grad_output[i]` with `atom.global.add.f32`, so
/// repeated indices accumulate. The value returned by the atomic is discarded.
///
/// Indices are stored as f32 and converted with `cvt.rzi.u32.f32`
/// (round toward zero); no bounds check is performed. The kernel only
/// accumulates into `grad_input` and never initialises it.
///
/// Kernel parameters: `grad_output_ptr`, `indices_ptr`, `grad_input_ptr`
/// (64-bit global addresses of f32 arrays) and `n_indices` (u32 count).
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_input_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %go, %indices, %gi, %off, %addr;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %gi, [grad_input_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read grad_output[tid]
    add.u64 %addr, %go, %off;
    ld.global.f32 %grad_val, [%addr];

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Atomic add: grad_input[idx] += grad_val
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %gi, %addr;
    atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
    ret;
}
";
4024
4025
4026#[cfg(feature = "cuda")]
/// PTX source for `masked_fill_kernel`: element-wise masked fill on f32 data.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n`) stores
/// `out[i] = if mask[i] >= 0.5 { fill_value } else { input[i] }`.
///
/// The mask is read as f32 and compared with `setp.ge.f32` against `0.5`, so
/// a NaN mask entry compares false and lets `input[i]` through.
///
/// Kernel parameters: `input_ptr`, `mask_ptr`, `out_ptr` (64-bit global
/// addresses of f32 arrays), `fill_value` (f32) and `n` (u32 element count).
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
    .param .u64 input_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .f32 fill_value,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %input, %mask, %out, %off;
    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %fill, [fill_value];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %input, %input, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %in_val, [%input];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %fill, %in_val, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4082
4083
4084#[cfg(feature = "cuda")]
/// PTX source for `masked_zero_kernel`: zeroes f32 gradient entries under a mask.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n`) stores
/// `out[i] = if mask[i] >= 0.5 { 0.0 } else { grad[i] }`.
///
/// The mask is read as f32 and compared with `setp.ge.f32` against `0.5`, so
/// a NaN mask entry compares false and lets `grad[i]` through.
///
/// Kernel parameters: `grad_ptr`, `mask_ptr`, `out_ptr` (64-bit global
/// addresses of f32 arrays) and `n` (u32 element count).
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_zero_kernel(
    .param .u64 grad_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %mask, %out, %off;
    .reg .f32 %vg, %mask_val, %zero, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %zero, 0f00000000;
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %zero, %vg, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4139
4140
4141#[cfg(feature = "cuda")]
/// PTX source for `sigmoid_backward_kernel`: f32 backward pass of sigmoid.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n`) stores
/// `out[i] = grad[i] * o * (1 - o)` with `o = output[i]`, i.e. the derivative
/// is formed from the value read through `output_ptr` rather than from the
/// original input.
///
/// Kernel parameters: `grad_ptr`, `output_ptr`, `out_ptr` (64-bit global
/// addresses of f32 arrays) and `n` (u32 element count).
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    sub.f32 %one_minus_o, %one, %vo;
    mul.f32 %result, %vo, %one_minus_o;
    mul.f32 %result, %vg, %result;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4194
4195
4196#[cfg(feature = "cuda")]
/// PTX source for `tanh_backward_kernel`: f32 backward pass of tanh.
///
/// Thread `i = ctaid.x * ntid.x + tid.x` (for `i < n`) stores
/// `out[i] = grad[i] * (1 - o²)` with `o = output[i]`, i.e. the derivative is
/// formed from the value read through `output_ptr` rather than from the
/// original input.
///
/// Kernel parameters: `grad_ptr`, `output_ptr`, `out_ptr` (64-bit global
/// addresses of f32 arrays) and `n` (u32 element count).
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    mul.f32 %o_sq, %vo, %vo;
    sub.f32 %one_minus_sq, %one, %o_sq;
    mul.f32 %result, %vg, %one_minus_sq;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4249
4250
4251#[cfg(feature = "cuda")]
/// PTX source for `softmax_backward_kernel`: row-wise f32 softmax backward.
///
/// Thread block `ctaid.x` processes row `ctaid.x` of a row-major
/// `rows x cols` matrix; blocks with `ctaid.x >= rows` exit immediately.
/// Threads walk the row with stride `ntid.x`.
///
/// 1. Each thread accumulates a partial `dot = Σ grad[j] * output[j]`.
/// 2. The partials are combined in the shared array `sdata` by repeatedly
///    halving the active thread count, then broadcast from `sdata[0]`.
/// 3. Each thread writes `out[j] = output[j] * (grad[j] - dot)`.
///
/// Launch requirements implied by the code: `sdata` holds 256 floats, so
/// `ntid.x` must not exceed 256, and the halving reduction only folds in every
/// partial sum when `ntid.x` is a power of two.
///
/// Kernel parameters: `grad_ptr`, `output_ptr`, `out_ptr` (64-bit global
/// addresses of f32 arrays), `rows` and `cols` (u32).
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
    mov.f32 %dot, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra DOT_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
    add.u32 %j, %j, %bdim;\n\
    bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
    // Store partial dot into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %dot;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra DOT_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra DOT_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %dot, [%saddr];\n\
    add.f32 %dot, %dot, %other_val;\n\
    st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
    // Broadcast dot to all threads\n\
    ld.shared.f32 %dot, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    sub.f32 %diff, %vg, %dot;\n\
    mul.f32 %result, %vo, %diff;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4378
4379
4380#[cfg(feature = "cuda")]
/// PTX source for `log_softmax_kernel`: row-wise f32 log-softmax.
///
/// Thread block `ctaid.x` processes row `ctaid.x` of a row-major
/// `rows x cols` matrix; blocks with `ctaid.x >= rows` exit immediately.
/// Threads walk the row with stride `ntid.x`.
///
/// 1. Row maximum: per-thread `max.f32`, then a halving reduction in `sdata`.
/// 2. `sum = Σ exp(x[j] - max)`, with `exp` computed as
///    `ex2.approx.f32((x - max) * log2(e))`, reduced the same way.
/// 3. `log_sum_exp = max + lg2.approx.f32(sum) * ln(2)` and
///    `out[j] = x[j] - log_sum_exp`.
///
/// Launch requirements implied by the code: `sdata` holds 256 floats, so
/// `ntid.x` must not exceed 256, and the halving reductions only fold in every
/// partial result when `ntid.x` is a power of two.
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit global addresses of
/// f32 arrays), `rows` and `cols` (u32).
pub(crate) const LOG_SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: find max across row (grid-stride over columns)\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    // Shared-memory tree reduction for max\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    // Broadcast max to all threads\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: compute partial sum of exp(x[j] - max)\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    // exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    // Shared-memory tree reduction for sum\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
    // ln(2) = 0x3F317218\n\
    lg2.approx.f32 %log_sum_exp, %sum_val;\n\
    mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
    add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    // Phase 3: out[j] = x[j] - log_sum_exp\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %val, [%saddr];\n\
    sub.f32 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4565
4566#[cfg(feature = "cuda")]
/// PTX source for `log_softmax_f64_kernel`: row-wise f64 log-softmax.
///
/// Thread block `ctaid.x` processes row `ctaid.x` of a row-major
/// `rows x cols` matrix; blocks with `ctaid.x >= rows` exit immediately.
/// Threads walk the row with stride `ntid.x`.
///
/// 1. Row maximum: per-thread `max.f64`, then a halving reduction in `sdata`.
/// 2. `sum = Σ exp(x[j] - max)`, reduced the same way. `exp(y)` is computed in
///    software: `n = round(y·log2 e)`, Cody-Waite reduction `r = y - n·ln2`
///    (split hi/lo), a degree-12 Taylor polynomial in `r` evaluated by Horner,
///    then scaling by `2^n` built from the exponent bits. `y <= 0` here;
///    arguments below `-708` (including `-inf`) would yield an invalid
///    exponent field and are flushed to `+0`. NaN is not flushed and still
///    propagates into the sum.
/// 3. `log(sum)` splits `sum = m·2^e` via its bit pattern and evaluates
///    `ln m = 2·atanh(f)`, `f = (m-1)/(m+1)`, with the odd series truncated
///    after `f¹¹/11`; then `log_sum_exp = max + e·ln2 + ln m` and
///    `out[j] = x[j] - log_sum_exp`.
///    NOTE(review): with `m` in `[1, 2)` the truncated series is only good to
///    about 1e-7 near `m = 2`; tighten (range-reduce `m` or add terms) if full
///    double accuracy is required.
///
/// Launch requirements implied by the code: `sdata` holds 256 doubles, so
/// `ntid.x` must not exceed 256, and the halving reductions only fold in every
/// partial result when `ntid.x` is a power of two.
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit global addresses of
/// f64 arrays), `rows` and `cols` (u32).
pub(crate) const LOG_SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p, %exp_uf;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .u64 %l_xbits, %l_mbits, %l_bias;\n\
    .reg .s64 %l_exp64;\n\
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 8 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    // Phase 1: row maximum, starting from -inf\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: partial sum of exp(x[j] - max)\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    // Below -708 the biased exponent n + 1023 is no longer positive: flush to zero.\n\
    setp.lt.f64 %exp_uf, %val, 0dC086200000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    // n = round(val * log2(e)); r = val - n*ln2 (Cody-Waite hi/lo split)\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    // Horner with coefficients 1/12!, 1/11!, ..., 1/3!, then 1/2, 1, 1\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    // Scale by 2^n: place n + 1023 in the exponent field\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    @%exp_uf mov.f64 %exp_val, 0d0000000000000000;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // log(sum): split sum = m * 2^e with m in [1, 2)\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.b64 %l_xbits, %sum_val;\n\
    shr.u64 %l_exp64, %l_xbits, 52;\n\
    and.b64 %l_exp64, %l_exp64, 2047;\n\
    sub.s64 %l_exp64, %l_exp64, 1023;\n\
    cvt.rn.f64.s64 %l_nf, %l_exp64;\n\
    mov.u64 %l_bias, 0x3FF0000000000000;\n\
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;\n\
    or.b64 %l_mbits, %l_mbits, %l_bias;\n\
    mov.b64 %l_m, %l_mbits;\n\
    // ln(m) = 2 * (f + f^3/3 + f^5/5 + f^7/7 + f^9/9 + f^11/11), f = (m-1)/(m+1)\n\
    sub.f64 %l_f, %l_m, %e_one;\n\
    add.f64 %l_s, %l_m, %e_one;\n\
    div.rn.f64 %l_f, %l_f, %l_s;\n\
    mul.f64 %l_f2, %l_f, %l_f;\n\
    mov.f64 %l_p, 0d3FB745D1745D1746;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;\n\
    mul.f64 %l_p, %l_p, %l_f;\n\
    add.f64 %l_p, %l_p, %l_p;\n\
    // log_sum_exp = max + e*ln(2) + ln(m)\n\
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;\n\
    fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2, %l_p;\n\
    add.f64 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    // Phase 3: out[j] = x[j] - log_sum_exp\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %val, [%saddr];\n\
    sub.f64 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4780
#[cfg(feature = "cuda")]
/// PTX source for `log_softmax_backward_kernel` (f32).
///
/// Kernel parameters: `grad_ptr`, `output_ptr`, `out_ptr` (64-bit addresses)
/// followed by `rows` and `cols` (u32).
///
/// Each block whose `ctaid.x` is below `rows` handles row `ctaid.x`:
///
/// 1. Every thread walks the row's columns starting at `tid.x` in steps of
///    `ntid.x` and accumulates a partial sum of `grad`.
/// 2. The partial sums are combined through the 256-entry shared array
///    `sdata` with a halving loop that starts from `ntid.x`.
/// 3. Every thread then writes
///    `out[j] = grad[j] - exp(output[j]) * sum_grad`, where `exp` is
///    evaluated as `ex2.approx.f32(output[j] * log2(e))`.
///
/// Because the halving loop starts from `ntid.x` and `sdata` holds 256
/// floats, the reduction only covers every partial sum when the block size
/// is a power of two no larger than 256.
pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
    mov.f32 %sum_grad, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Store partial sum into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_grad, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum_grad to all threads\n\
    ld.shared.f32 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    // exp(log_softmax_output) = softmax probability\n\
    mul.f32 %vo, %vo, 0f3FB8AA3B;\n\
    ex2.approx.f32 %softmax_j, %vo;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f32 %result, %softmax_j, %sum_grad;\n\
    sub.f32 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4910
#[cfg(feature = "cuda")]
/// PTX source for `log_softmax_backward_f64_kernel`, the double-precision
/// counterpart of [`LOG_SOFTMAX_BACKWARD_PTX`].
///
/// Kernel parameters: `grad_ptr`, `output_ptr`, `out_ptr` (64-bit addresses)
/// followed by `rows` and `cols` (u32). Each block whose `ctaid.x` is below
/// `rows` handles row `ctaid.x`: threads stride over the columns by `ntid.x`,
/// reduce `sum(grad[row, :])` through the 256-entry shared array `sdata`,
/// then write `out[j] = grad[j] - exp(output[j]) * sum_grad`.
///
/// PTX has no double-precision `ex2`, so `exp(x)` is evaluated inline:
///
/// * `n = round(x * log2(e))`, `r = x - n * ln(2)` with `ln(2)` split into a
///   high and a low part;
/// * `exp(r)` is the degree-12 Taylor polynomial (coefficients `1/k!`) in
///   Horner form;
/// * the result is scaled by `2^n`, built by placing `n + 1023` in the
///   exponent field.
///
/// Inputs below `-708` (including `-inf`) are flushed to `0.0`: for them
/// `n + 1023` would leave the normal exponent range and the bit construction
/// would yield garbage or NaN. There is no matching guard for large positive
/// inputs; log-softmax outputs are expected to be `<= 0`.
///
/// The halving reduction starts from `ntid.x` and `sdata` holds 256 doubles,
/// so the block size must be a power of two no larger than 256.
pub(crate) const LOG_SOFTMAX_BACKWARD_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_f64_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p, %e_uf;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 8 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    // Phase 1: per-thread partial sum of grad over this row\n\
    mov.f64 %sum_grad, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Store partial sum into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_grad, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum_grad to all threads\n\
    ld.shared.f64 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vo, [%saddr];\n\
    // exp(log_softmax_output): inline f64 exp\n\
    // Remember whether the input underflows (x < -708, including -inf).\n\
    setp.lt.f64 %e_uf, %vo, 0dC086200000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    // n = round(x * log2(e)); r = x - n * ln2 (hi + lo split)\n\
    mul.f64 %e_nf, %vo, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %vo;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    // Horner evaluation of sum(r^k / k!), k = 12 down to 0\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %softmax_j, %e_p, %e_r, %e_one;\n\
    // Scale by 2^n: place (n + 1023) in the exponent field\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %softmax_j, %softmax_j, %e_nf;\n\
    // Underflowing inputs yield exactly 0\n\
    selp.f64 %softmax_j, 0d0000000000000000, %softmax_j, %e_uf;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f64 %result, %softmax_j, %sum_grad;\n\
    sub.f64 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
5051
#[cfg(feature = "cuda")]
/// PTX source for `reduce_sum_kernel` (f32): per-block partial sums of a
/// flat array.
///
/// Kernel parameters: `in_ptr`, `out_ptr` (64-bit addresses) and `n` (u32).
///
/// Each thread starts at `ctaid.x * ntid.x + tid.x` and adds `in[idx]` while
/// `idx < n`, advancing by `ntid.x * nctaid.x`. The per-thread sums are then
/// combined through the 256-entry shared array `sdata`, and thread 0 stores
/// the block's total at `out[ctaid.x]`. One value per block is produced;
/// combining those values is left to the caller.
///
/// The thread index lives in `%r_tid` because `%tid` is a predefined PTX
/// special register, and shared-memory accesses go through a register
/// address (`%sbase + offset`) because PTX has no `[variable + register]`
/// addressing form.
///
/// The halving reduction starts at `ntid.x / 2`, so the block size must be a
/// power of two no larger than 256.
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off, %sbase, %saddr;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %r_tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;
    mov.u64 %sbase, sdata;

    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %r_tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;

GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;

GRID_DONE:
    // Write thread's partial sum to shared memory.
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum;
    bar.sync 0;

    // Tree reduction in shared memory, starting at half the block size.
    shr.u32 %half, %bdim, 1;
TREE_LOOP:
    setp.lt.u32 %p, %half, 1;
    @%p bra TREE_DONE;

    setp.ge.u32 %ptid, %r_tid, %half;
    @%ptid bra TREE_SKIP;

    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %r_tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other, [%saddr];
    // Load own value, add, and store back to sdata[tid].
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum, [%saddr];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%saddr], %sum;

TREE_SKIP:
    bar.sync 0;
    shr.u32 %half, %half, 1;
    bra TREE_LOOP;

TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %r_tid, 0;
    @%ptid bra END;

    ld.shared.f32 %sum, [sdata];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;

END:
    ret;
}
";
5159
5160
#[cfg(feature = "cuda")]
/// PTX source for `sum_axis_kernel` (f32): sums along the middle axis of a
/// buffer addressed as `[outer, axis, inner]`.
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit addresses), then
/// `outer_size`, `axis_size`, `inner_size` and `total_output` (u32).
///
/// One thread handles one output element. With
/// `t = ctaid.x * ntid.x + tid.x`, threads with `t >= total_output` exit.
/// Otherwise `outer_idx = t / inner_size`, `inner_idx = t % inner_size`, and
/// the thread adds
/// `input[outer_idx * axis_size * inner_size + inner_idx + k * inner_size]`
/// for `k` in `0..axis_size`, then stores the sum at `output[t]`.
///
/// `outer_size` is loaded but not otherwise used by the kernel. All element
/// indices are computed in 32-bit arithmetic before being widened to byte
/// offsets.
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;

    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;

    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;

    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:

    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
";
5239
#[cfg(feature = "cuda")]
/// PTX source for `cumsum_kernel` (f32): running sum along the middle axis
/// of a buffer addressed as `[outer, dim, inner]`.
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit addresses), then
/// `outer_size`, `dim_size`, `inner_size` and `total` (u32).
///
/// One thread handles one `(outer, inner)` pair. With
/// `t = ctaid.x * ntid.x + tid.x`, threads with
/// `t >= outer_size * inner_size` exit. Each remaining thread walks
/// `k = 0..dim_size` sequentially over
/// `idx = outer_idx * dim_size * inner_size + inner_idx + k * inner_size`,
/// adding `input[idx]` to an accumulator that starts at `0.0` and storing
/// the accumulator to `output[idx]` after every step.
///
/// `total` is loaded but not otherwise used by the kernel. Element indices
/// are computed in 32-bit arithmetic.
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumsum_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    // total threads = outer * inner
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    add.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5324
5325
#[cfg(feature = "cuda")]
/// PTX source for `cumprod_kernel` (f32): running product along the middle
/// axis of a buffer addressed as `[outer, dim, inner]`.
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit addresses), then
/// `outer_size`, `dim_size`, `inner_size` and `total` (u32).
///
/// One thread handles one `(outer, inner)` pair; threads with
/// `ctaid.x * ntid.x + tid.x >= outer_size * inner_size` exit. Each
/// remaining thread walks `k = 0..dim_size` sequentially over
/// `idx = outer_idx * dim_size * inner_size + inner_idx + k * inner_size`,
/// multiplying an accumulator seeded with `1.0` by `input[idx]` and storing
/// it to `output[idx]` after every step.
///
/// `total` is loaded but not otherwise used by the kernel. Element indices
/// are computed in 32-bit arithmetic.
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumprod_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = 1.0
    mov.f32 %acc, 0f3F800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    mul.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5401
5402
#[cfg(feature = "cuda")]
/// PTX source for `cummax_kernel` (f32): running maximum, and the position
/// at which it was reached, along the middle axis of a buffer addressed as
/// `[outer, dim, inner]`.
///
/// Kernel parameters: `input_ptr`, `output_ptr`, `indices_ptr` (64-bit
/// addresses), then `outer_size`, `dim_size`, `inner_size` and `total` (u32).
///
/// One thread handles one `(outer, inner)` pair; threads with
/// `ctaid.x * ntid.x + tid.x >= outer_size * inner_size` exit. Each
/// remaining thread walks `k = 0..dim_size` sequentially. The accumulator is
/// seeded with `-inf`; `best_k` is updated only when the new value is
/// strictly greater than the accumulator, so ties keep the earlier position.
/// After every step the running maximum is stored to `output[idx]` and
/// `best_k`, converted to f32, is stored to `indices[idx]`.
///
/// The `-inf` seed is written as the float literal `0fFF800000` moved with
/// `mov.f32`; this is the spelling `ptx_f32_to_f64` rewrites (to
/// `mov.f64` / `0dFFF0000000000000`). An integer bit pattern moved with
/// `mov.b32` would only have its opcode widened by that helper, leaving a
/// wrong 64-bit seed.
///
/// `total` is loaded but not otherwise used by the kernel. Element indices
/// are computed in 32-bit arithmetic.
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummax_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_max;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.f32 %acc, 0fFF800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // Strict comparison: ties keep the earlier index.
    setp.gt.f32 %is_new_max, %val, %acc;
    @%is_new_max mov.u32 %best_k, %k;
    max.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    // The index is stored as a float in the same layout as the values.
    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5488
5489
#[cfg(feature = "cuda")]
/// PTX source for `cummin_kernel` (f32): running minimum, and the position
/// at which it was reached, along the middle axis of a buffer addressed as
/// `[outer, dim, inner]`.
///
/// Kernel parameters: `input_ptr`, `output_ptr`, `indices_ptr` (64-bit
/// addresses), then `outer_size`, `dim_size`, `inner_size` and `total` (u32).
///
/// One thread handles one `(outer, inner)` pair; threads with
/// `ctaid.x * ntid.x + tid.x >= outer_size * inner_size` exit. Each
/// remaining thread walks `k = 0..dim_size` sequentially. The accumulator is
/// seeded with `+inf`; `best_k` is updated only when the new value is
/// strictly smaller than the accumulator, so ties keep the earlier position.
/// After every step the running minimum is stored to `output[idx]` and
/// `best_k`, converted to f32, is stored to `indices[idx]`.
///
/// The `+inf` seed is written as the float literal `0f7F800000` moved with
/// `mov.f32`; this is the spelling `ptx_f32_to_f64` rewrites (to
/// `mov.f64` / `0d7FF0000000000000`). An integer bit pattern moved with
/// `mov.b32` would only have its opcode widened by that helper, leaving a
/// wrong 64-bit seed.
///
/// `total` is loaded but not otherwise used by the kernel. Element indices
/// are computed in 32-bit arithmetic.
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummin_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_min;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = +inf
    mov.f32 %acc, 0f7F800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // Strict comparison: ties keep the earlier index.
    setp.lt.f32 %is_new_min, %val, %acc;
    @%is_new_min mov.u32 %best_k, %k;
    min.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    // The index is stored as a float in the same layout as the values.
    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5573
5574
#[cfg(feature = "cuda")]
/// PTX source for `logcumsumexp_kernel` (f32): running `log(sum(exp(x)))`
/// along the middle axis of a buffer addressed as `[outer, dim, inner]`.
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit addresses), then
/// `outer_size`, `dim_size`, `inner_size` and `total` (u32).
///
/// One thread handles one `(outer, inner)` pair; threads with
/// `ctaid.x * ntid.x + tid.x >= outer_size * inner_size` exit. Each
/// remaining thread walks `k = 0..dim_size` sequentially with an accumulator
/// seeded with `-inf` and, per element `x`, computes
///
/// ```text
/// m   = max(acc, x)
/// acc = m + ln(exp(acc - m) + exp(x - m))
/// ```
///
/// storing `acc` to `output[idx]` after every step. `exp` and `ln` are
/// built from `ex2.approx.f32` / `lg2.approx.f32` together with the
/// constants `log2(e)` and `ln(2)`.
///
/// The constants and the `-inf` seed are loaded as integer bit patterns via
/// `mov.b32`. `total` is loaded but not otherwise used by the kernel.
/// Element indices are computed in 32-bit arithmetic.
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc, %m, %ea, %ev, %s, %ls, %log2e, %ln2;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    // log2(e) = 1.4426950408... -> 0x3FB8AA3B
    mov.b32 %log2e, 0x3FB8AA3B;
    // ln(2) = 0.6931471805... -> 0x3F317218
    mov.b32 %ln2, 0x3F317218;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b32 %acc, 0xFF800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // Numerically stable: m = max(acc, x)
    max.f32 %m, %acc, %val;
    // exp(acc - m): (acc - m) * log2(e) -> ex2
    sub.f32 %ea, %acc, %m;
    mul.f32 %ea, %ea, %log2e;
    ex2.approx.f32 %ea, %ea;
    // exp(x - m): (x - m) * log2(e) -> ex2
    sub.f32 %ev, %val, %m;
    mul.f32 %ev, %ev, %log2e;
    ex2.approx.f32 %ev, %ev;
    // sum
    add.f32 %s, %ea, %ev;
    // log(sum) = lg2(sum) * ln(2)
    lg2.approx.f32 %ls, %s;
    mul.f32 %ls, %ls, %ln2;
    // acc = m + log(sum)
    add.f32 %acc, %m, %ls;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5675
#[cfg(feature = "cuda")]
/// PTX source for `logcumsumexp_f64_kernel`, the double-precision
/// counterpart of [`LOGCUMSUMEXP_PTX`].
///
/// Kernel parameters: `input_ptr`, `output_ptr` (64-bit addresses), then
/// `outer_size`, `dim_size`, `inner_size` and `total` (u32).
///
/// One thread handles one `(outer, inner)` pair of a buffer addressed as
/// `[outer, dim, inner]`; threads with
/// `ctaid.x * ntid.x + tid.x >= outer_size * inner_size` exit. Each
/// remaining thread walks `k = 0..dim_size` sequentially with an accumulator
/// seeded with `-inf` and, per element `x`, computes
///
/// ```text
/// m   = max(acc, x)
/// acc = m + ln(exp(acc - m) + exp(x - m))
/// ```
///
/// storing `acc` to `output[idx]` after every step.
///
/// PTX has no double-precision `ex2`/`lg2`, so both functions are inlined:
///
/// * `exp(x)`: `n = round(x * log2(e))`, `r = x - n * ln(2)` (hi/lo split),
///   degree-12 Taylor polynomial in Horner form, scaled by `2^n` built from
///   the exponent bits. Arguments below `-708` (including the `-inf`
///   produced by the initial accumulator) are flushed to `0.0`, because the
///   range reduction and exponent construction are not defined for them.
/// * `ln(s)`: exponent and mantissa are split from the bit pattern, and
///   `ln(mantissa)` is `2 * atanh((m - 1) / (m + 1))` using six terms of the
///   odd series.
///   NOTE(review): the mantissa is taken from `[1, 2)` without further
///   range reduction, so the truncated series is least accurate near 2;
///   confirm this meets the required precision.
///
/// `total` is loaded but not otherwise used by the kernel. Element indices
/// are computed in 32-bit arithmetic.
pub(crate) const LOGCUMSUMEXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_f64_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f64 %val, %acc, %m, %ea, %ev, %s, %ls;
    .reg .pred %p, %lp, %e_uf;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b64 %acc, 0xFFF0000000000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 3;
    add.u64 %addr, %in, %off;
    ld.global.f64 %val, [%addr];

    // Numerically stable: m = max(acc, x)
    max.f64 %m, %acc, %val;
    mov.f64 %e_one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    // --- inline exp(acc - m) -> %ea ---
    sub.f64 %ea, %acc, %m;
    // Remember whether the argument underflows (x < -708, including -inf).
    setp.lt.f64 %e_uf, %ea, 0dC086200000000000;
    // n = round(x * log2(e)); r = x - n * ln2 (hi + lo split)
    mul.f64 %e_nf, %ea, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ea;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Horner evaluation of sum(r^k / k!), k = 12 down to 0
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
    fma.rn.f64 %ea, %e_p, %e_r, %e_one;
    // Scale by 2^n: place (n + 1023) in the exponent field
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ea, %ea, %e_nf;
    // Underflowing arguments yield exactly 0
    selp.f64 %ea, 0d0000000000000000, %ea, %e_uf;
    // --- inline exp(val - m) -> %ev ---
    sub.f64 %ev, %val, %m;
    setp.lt.f64 %e_uf, %ev, 0dC086200000000000;
    mul.f64 %e_nf, %ev, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ev;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
    fma.rn.f64 %ev, %e_p, %e_r, %e_one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ev, %ev, %e_nf;
    selp.f64 %ev, 0d0000000000000000, %ev, %e_uf;
    add.f64 %s, %ea, %ev;
    // --- inline ln(%s) -> %ls ---
    // Split s into unbiased exponent (%l_nf) and mantissa in [1, 2) (%l_m)
    mov.b64 %l_xbits, %s;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;
    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;
    // f = (m - 1) / (m + 1); ln(m) = 2 * f * (1 + f^2/3 + ... + f^10/11)
    sub.f64 %l_f, %l_m, %e_one;
    add.f64 %l_s, %l_m, %e_one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    // ln(s) = exponent * ln(2) + ln(mantissa)
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %ls, %l_nf, %l_ln2, %l_p;
    // acc = m + log(sum)
    add.f64 %acc, %m, %ls;

    add.u64 %addr, %out, %off;
    st.global.f64 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5830
5831#[cfg(feature = "cuda")]
/// PTX source for `layernorm_kernel`: row-wise layer normalisation followed by
/// an elementwise affine transform, operating on `f32` data.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `w_ptr`, `b_ptr` (64-bit
/// addresses), `rows`, `cols` (`u32`) and `eps` (`f32`).
///
/// The block index `%ctaid.x` selects the row (blocks with `%ctaid.x >= rows`
/// return immediately) and the row's byte offset is `ctaid.x * cols * 4`.
/// Threads walk the columns starting at `%tid.x` in steps of `%ntid.x`, in
/// three passes over the row:
///
/// 1. sum the elements, reduce the per-thread partials through `sdata`, and
///    divide by `cols` to get the mean;
/// 2. sum `(x - mean)^2`, reduce, divide by `cols`, add `eps`, then take
///    `inv_std = rcp(sqrt(..))`;
/// 3. store `w[j] * (x - mean) * inv_std + b[j]` to the output row.
///
/// The reduction halves `%ntid.x` at every step and indexes `sdata[256]` by
/// `%tid.x`, so it only sums every partial when the block size is a power of
/// two, and the block size must not exceed 256. Division, square root and
/// reciprocal use the `.approx` instruction variants.
5841pub(crate) const LAYERNORM_PTX: &str = "\
5842.version 7.0
5843.target sm_52
5844.address_size 64
5845
5846.shared .align 4 .f32 sdata[256];
5847
5848.visible .entry layernorm_kernel(
5849 .param .u64 in_ptr,
5850 .param .u64 out_ptr,
5851 .param .u64 w_ptr,
5852 .param .u64 b_ptr,
5853 .param .u32 rows,
5854 .param .u32 cols,
5855 .param .f32 eps
5856) {
5857 .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
5858 .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
5859 .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
5860 .reg .pred %p, %lp, %rp;
5861
5862 ld.param.u64 %in, [in_ptr];
5863 ld.param.u64 %out, [out_ptr];
5864 ld.param.u64 %w, [w_ptr];
5865 ld.param.u64 %b, [b_ptr];
5866 ld.param.u32 %rows_reg, [rows];
5867 ld.param.u32 %cols_reg, [cols];
5868 ld.param.f32 %eps_r, [eps];
5869
5870 mov.u64 %sbase, sdata;
5871
5872 mov.u32 %r_bid, %ctaid.x;
5873 mov.u32 %r_bdim, %ntid.x;
5874 mov.u32 %r_tid, %tid.x;
5875
5876 setp.ge.u32 %p, %r_bid, %rows_reg;
5877 @%p bra DONE;
5878
5879 cvt.u64.u32 %row_off, %r_bid;
5880 cvt.u64.u32 %off, %cols_reg;
5881 mul.lo.u64 %row_off, %row_off, %off;
5882 shl.b64 %row_off, %row_off, 2;
5883 cvt.rn.f32.u32 %n_f, %cols_reg;
5884
5885 mov.f32 %mean, 0f00000000;
5886 mov.u32 %j, %r_tid;
5887SM:
5888 setp.ge.u32 %lp, %j, %cols_reg;
5889 @%lp bra SMD;
5890 cvt.u64.u32 %off, %j;
5891 shl.b64 %off, %off, 2;
5892 add.u64 %off, %in, %off;
5893 add.u64 %off, %off, %row_off;
5894 ld.global.f32 %val, [%off];
5895 add.f32 %mean, %mean, %val;
5896 add.u32 %j, %j, %r_bdim;
5897 bra SM;
5898SMD:
5899 cvt.u64.u32 %off, %r_tid;
5900 shl.b64 %off, %off, 2;
5901 add.u64 %saddr, %sbase, %off;
5902 st.shared.f32 [%saddr], %mean;
5903 bar.sync 0;
5904 mov.u32 %half, %r_bdim;
5905MR:
5906 shr.u32 %half, %half, 1;
5907 setp.eq.u32 %rp, %half, 0;
5908 @%rp bra MRD;
5909 setp.ge.u32 %rp, %r_tid, %half;
5910 @%rp bra MRS;
5911 add.u32 %r_otid, %r_tid, %half;
5912 cvt.u64.u32 %off, %r_otid;
5913 shl.b64 %off, %off, 2;
5914 add.u64 %saddr, %sbase, %off;
5915 ld.shared.f32 %other_val, [%saddr];
5916 cvt.u64.u32 %off, %r_tid;
5917 shl.b64 %off, %off, 2;
5918 add.u64 %saddr, %sbase, %off;
5919 ld.shared.f32 %mean, [%saddr];
5920 add.f32 %mean, %mean, %other_val;
5921 add.u64 %saddr, %sbase, %off;
5922 st.shared.f32 [%saddr], %mean;
5923MRS:
5924 bar.sync 0;
5925 bra MR;
5926MRD:
5927 ld.shared.f32 %mean, [%sbase];
5928 div.approx.f32 %mean, %mean, %n_f;
5929 bar.sync 0;
5930
5931 mov.f32 %var, 0f00000000;
5932 mov.u32 %j, %r_tid;
5933SV:
5934 setp.ge.u32 %lp, %j, %cols_reg;
5935 @%lp bra SVD;
5936 cvt.u64.u32 %off, %j;
5937 shl.b64 %off, %off, 2;
5938 add.u64 %off, %in, %off;
5939 add.u64 %off, %off, %row_off;
5940 ld.global.f32 %val, [%off];
5941 sub.f32 %diff, %val, %mean;
5942 fma.rn.f32 %var, %diff, %diff, %var;
5943 add.u32 %j, %j, %r_bdim;
5944 bra SV;
5945SVD:
5946 cvt.u64.u32 %off, %r_tid;
5947 shl.b64 %off, %off, 2;
5948 add.u64 %saddr, %sbase, %off;
5949 st.shared.f32 [%saddr], %var;
5950 bar.sync 0;
5951 mov.u32 %half, %r_bdim;
5952VR:
5953 shr.u32 %half, %half, 1;
5954 setp.eq.u32 %rp, %half, 0;
5955 @%rp bra VRD;
5956 setp.ge.u32 %rp, %r_tid, %half;
5957 @%rp bra VRS;
5958 add.u32 %r_otid, %r_tid, %half;
5959 cvt.u64.u32 %off, %r_otid;
5960 shl.b64 %off, %off, 2;
5961 add.u64 %saddr, %sbase, %off;
5962 ld.shared.f32 %other_val, [%saddr];
5963 cvt.u64.u32 %off, %r_tid;
5964 shl.b64 %off, %off, 2;
5965 add.u64 %saddr, %sbase, %off;
5966 ld.shared.f32 %var, [%saddr];
5967 add.f32 %var, %var, %other_val;
5968 add.u64 %saddr, %sbase, %off;
5969 st.shared.f32 [%saddr], %var;
5970VRS:
5971 bar.sync 0;
5972 bra VR;
5973VRD:
5974 ld.shared.f32 %var, [%sbase];
5975 div.approx.f32 %var, %var, %n_f;
5976 add.f32 %var, %var, %eps_r;
5977 sqrt.approx.f32 %inv_std, %var;
5978 rcp.approx.f32 %inv_std, %inv_std;
5979 bar.sync 0;
5980
5981 mov.u32 %j, %r_tid;
5982NM:
5983 setp.ge.u32 %lp, %j, %cols_reg;
5984 @%lp bra NMD;
5985 cvt.u64.u32 %off, %j;
5986 shl.b64 %off, %off, 2;
5987 add.u64 %off, %in, %off;
5988 add.u64 %off, %off, %row_off;
5989 ld.global.f32 %val, [%off];
5990 sub.f32 %normed, %val, %mean;
5991 mul.f32 %normed, %normed, %inv_std;
5992 cvt.u64.u32 %off, %j;
5993 shl.b64 %off, %off, 2;
5994 add.u64 %off, %w, %off;
5995 ld.global.f32 %wv, [%off];
5996 cvt.u64.u32 %off, %j;
5997 shl.b64 %off, %off, 2;
5998 add.u64 %off, %b, %off;
5999 ld.global.f32 %bv, [%off];
6000 fma.rn.f32 %result, %wv, %normed, %bv;
6001 cvt.u64.u32 %off, %j;
6002 shl.b64 %off, %off, 2;
6003 add.u64 %off, %out, %off;
6004 add.u64 %off, %off, %row_off;
6005 st.global.f32 [%off], %result;
6006 add.u32 %j, %j, %r_bdim;
6007 bra NM;
6008NMD:
6009
6010DONE:
6011 ret;
6012}
6013";
6014
6015
6016#[cfg(feature = "cuda")]
/// PTX source for `layernorm_backward_kernel`: gradients of row-wise layer
/// normalisation with an elementwise weight, operating on `f32` data.
///
/// Kernel parameters, in order: `in_ptr`, `grad_out_ptr`, `w_ptr`,
/// `grad_in_ptr`, `grad_w_ptr`, `grad_b_ptr` (64-bit addresses), `rows`,
/// `cols` (`u32`) and `eps` (`f32`).
///
/// The block index `%ctaid.x` selects the row (blocks with `%ctaid.x >= rows`
/// return immediately); threads walk the columns from `%tid.x` in steps of
/// `%ntid.x`. The kernel recomputes the forward statistics instead of reading
/// saved ones:
///
/// 1. mean of the row (sum, shared-memory reduction, divide by `cols`);
/// 2. `inv_std = rcp(sqrt(sum((x - mean)^2) / cols + eps))`;
/// 3. per column: `x_hat = (x - mean) * inv_std`, `dl_dx_hat = go * w[j]`;
///    accumulates `sum1 = sum(dl_dx_hat)` and `sum2 = sum(dl_dx_hat * x_hat)`,
///    and atomically adds `go * x_hat` to `grad_w[j]` and `go` to `grad_b[j]`;
///    `sum1` and `sum2` are then reduced and divided by `cols`
///    (`mean1`, `mean2`);
/// 4. stores `inv_std * (dl_dx_hat - mean1 - x_hat * mean2)` to the
///    `grad_in` row.
///
/// `grad_w` and `grad_b` are only ever added to (`atom.global.add.f32`), so
/// they presumably have to be zeroed by the caller before launch — confirm
/// against the launcher.
///
/// All four reductions halve `%ntid.x` each step and index `sdata[256]` by
/// `%tid.x`: the block size should be a power of two no larger than 256.
6041pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
6042.version 7.0
6043.target sm_52
6044.address_size 64
6045
6046.shared .align 4 .f32 sdata[256];
6047
6048.visible .entry layernorm_backward_kernel(
6049 .param .u64 in_ptr,
6050 .param .u64 grad_out_ptr,
6051 .param .u64 w_ptr,
6052 .param .u64 grad_in_ptr,
6053 .param .u64 grad_w_ptr,
6054 .param .u64 grad_b_ptr,
6055 .param .u32 rows,
6056 .param .u32 cols,
6057 .param .f32 eps
6058) {
6059 .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
6060 .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
6061 .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
6062 .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
6063 .reg .pred %p, %lp, %rp;
6064
6065 ld.param.u64 %in, [in_ptr];
6066 ld.param.u64 %go, [grad_out_ptr];
6067 ld.param.u64 %w, [w_ptr];
6068 ld.param.u64 %gi, [grad_in_ptr];
6069 ld.param.u64 %gw, [grad_w_ptr];
6070 ld.param.u64 %gb, [grad_b_ptr];
6071 ld.param.u32 %rows_reg, [rows];
6072 ld.param.u32 %cols_reg, [cols];
6073 ld.param.f32 %eps_r, [eps];
6074
6075 mov.u64 %sbase, sdata;
6076
6077 mov.u32 %r_bid, %ctaid.x;
6078 mov.u32 %r_bdim, %ntid.x;
6079 mov.u32 %r_tid, %tid.x;
6080
6081 setp.ge.u32 %p, %r_bid, %rows_reg;
6082 @%p bra LNB_DONE;
6083
6084 // row_off = bid * cols * 4 (byte offset for this row)
6085 cvt.u64.u32 %row_off, %r_bid;
6086 cvt.u64.u32 %off, %cols_reg;
6087 mul.lo.u64 %row_off, %row_off, %off;
6088 shl.b64 %row_off, %row_off, 2;
6089 cvt.rn.f32.u32 %n_f, %cols_reg;
6090
6091 // ===== Phase 1: Compute mean =====
6092 mov.f32 %mean, 0f00000000;
6093 mov.u32 %j, %r_tid;
6094LNB_SM:
6095 setp.ge.u32 %lp, %j, %cols_reg;
6096 @%lp bra LNB_SMD;
6097 cvt.u64.u32 %off, %j;
6098 shl.b64 %off, %off, 2;
6099 add.u64 %addr, %in, %off;
6100 add.u64 %addr, %addr, %row_off;
6101 ld.global.f32 %val, [%addr];
6102 add.f32 %mean, %mean, %val;
6103 add.u32 %j, %j, %r_bdim;
6104 bra LNB_SM;
6105LNB_SMD:
6106 // Shared memory reduce for mean
6107 cvt.u64.u32 %off, %r_tid;
6108 shl.b64 %off, %off, 2;
6109 add.u64 %saddr, %sbase, %off;
6110 st.shared.f32 [%saddr], %mean;
6111 bar.sync 0;
6112 mov.u32 %half, %r_bdim;
6113LNB_MR:
6114 shr.u32 %half, %half, 1;
6115 setp.eq.u32 %rp, %half, 0;
6116 @%rp bra LNB_MRD;
6117 setp.ge.u32 %rp, %r_tid, %half;
6118 @%rp bra LNB_MRS;
6119 add.u32 %r_otid, %r_tid, %half;
6120 cvt.u64.u32 %off, %r_otid;
6121 shl.b64 %off, %off, 2;
6122 add.u64 %saddr, %sbase, %off;
6123 ld.shared.f32 %other_val, [%saddr];
6124 cvt.u64.u32 %off, %r_tid;
6125 shl.b64 %off, %off, 2;
6126 add.u64 %saddr, %sbase, %off;
6127 ld.shared.f32 %mean, [%saddr];
6128 add.f32 %mean, %mean, %other_val;
6129 st.shared.f32 [%saddr], %mean;
6130LNB_MRS:
6131 bar.sync 0;
6132 bra LNB_MR;
6133LNB_MRD:
6134 ld.shared.f32 %mean, [%sbase];
6135 div.approx.f32 %mean, %mean, %n_f;
6136 bar.sync 0;
6137
6138 // ===== Phase 2: Compute variance =====
6139 mov.f32 %var, 0f00000000;
6140 mov.u32 %j, %r_tid;
6141LNB_SV:
6142 setp.ge.u32 %lp, %j, %cols_reg;
6143 @%lp bra LNB_SVD;
6144 cvt.u64.u32 %off, %j;
6145 shl.b64 %off, %off, 2;
6146 add.u64 %addr, %in, %off;
6147 add.u64 %addr, %addr, %row_off;
6148 ld.global.f32 %val, [%addr];
6149 sub.f32 %diff, %val, %mean;
6150 fma.rn.f32 %var, %diff, %diff, %var;
6151 add.u32 %j, %j, %r_bdim;
6152 bra LNB_SV;
6153LNB_SVD:
6154 // Shared memory reduce for variance
6155 cvt.u64.u32 %off, %r_tid;
6156 shl.b64 %off, %off, 2;
6157 add.u64 %saddr, %sbase, %off;
6158 st.shared.f32 [%saddr], %var;
6159 bar.sync 0;
6160 mov.u32 %half, %r_bdim;
6161LNB_VR:
6162 shr.u32 %half, %half, 1;
6163 setp.eq.u32 %rp, %half, 0;
6164 @%rp bra LNB_VRD;
6165 setp.ge.u32 %rp, %r_tid, %half;
6166 @%rp bra LNB_VRS;
6167 add.u32 %r_otid, %r_tid, %half;
6168 cvt.u64.u32 %off, %r_otid;
6169 shl.b64 %off, %off, 2;
6170 add.u64 %saddr, %sbase, %off;
6171 ld.shared.f32 %other_val, [%saddr];
6172 cvt.u64.u32 %off, %r_tid;
6173 shl.b64 %off, %off, 2;
6174 add.u64 %saddr, %sbase, %off;
6175 ld.shared.f32 %var, [%saddr];
6176 add.f32 %var, %var, %other_val;
6177 st.shared.f32 [%saddr], %var;
6178LNB_VRS:
6179 bar.sync 0;
6180 bra LNB_VR;
6181LNB_VRD:
6182 ld.shared.f32 %var, [%sbase];
6183 div.approx.f32 %var, %var, %n_f;
6184 add.f32 %var, %var, %eps_r;
6185 sqrt.approx.f32 %inv_std, %var;
6186 rcp.approx.f32 %inv_std, %inv_std;
6187 bar.sync 0;
6188
6189 // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
6190 // Also accumulate grad_weight and grad_bias via atomicAdd
6191 mov.f32 %sum1, 0f00000000;
6192 mov.f32 %sum2, 0f00000000;
6193 mov.u32 %j, %r_tid;
6194LNB_S12:
6195 setp.ge.u32 %lp, %j, %cols_reg;
6196 @%lp bra LNB_S12D;
6197 // Load input[row, j]
6198 cvt.u64.u32 %off, %j;
6199 shl.b64 %off, %off, 2;
6200 add.u64 %addr, %in, %off;
6201 add.u64 %addr, %addr, %row_off;
6202 ld.global.f32 %val, [%addr];
6203 // x_hat = (val - mean) * inv_std
6204 sub.f32 %x_hat, %val, %mean;
6205 mul.f32 %x_hat, %x_hat, %inv_std;
6206 // Load grad_output[row, j]
6207 cvt.u64.u32 %off, %j;
6208 shl.b64 %off, %off, 2;
6209 add.u64 %addr, %go, %off;
6210 add.u64 %addr, %addr, %row_off;
6211 ld.global.f32 %gov, [%addr];
6212 // Load weight[j]
6213 cvt.u64.u32 %off, %j;
6214 shl.b64 %off, %off, 2;
6215 add.u64 %addr, %w, %off;
6216 ld.global.f32 %wv, [%addr];
6217 // dl_dx_hat = grad_output * weight
6218 mul.f32 %dl_dx_hat, %gov, %wv;
6219 // Accumulate sums
6220 add.f32 %sum1, %sum1, %dl_dx_hat;
6221 fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
6222 // atomicAdd grad_weight[j] += grad_output * x_hat
6223 cvt.u64.u32 %off, %j;
6224 shl.b64 %off, %off, 2;
6225 add.u64 %addr, %gw, %off;
6226 mul.f32 %result, %gov, %x_hat;
6227 atom.global.add.f32 %result, [%addr], %result;
6228 // atomicAdd grad_bias[j] += grad_output
6229 add.u64 %addr, %gb, %off;
6230 atom.global.add.f32 %result, [%addr], %gov;
6231 add.u32 %j, %j, %r_bdim;
6232 bra LNB_S12;
6233LNB_S12D:
6234 // Reduce sum1 in shared memory
6235 cvt.u64.u32 %off, %r_tid;
6236 shl.b64 %off, %off, 2;
6237 add.u64 %saddr, %sbase, %off;
6238 st.shared.f32 [%saddr], %sum1;
6239 bar.sync 0;
6240 mov.u32 %half, %r_bdim;
6241LNB_R1:
6242 shr.u32 %half, %half, 1;
6243 setp.eq.u32 %rp, %half, 0;
6244 @%rp bra LNB_R1D;
6245 setp.ge.u32 %rp, %r_tid, %half;
6246 @%rp bra LNB_R1S;
6247 add.u32 %r_otid, %r_tid, %half;
6248 cvt.u64.u32 %off, %r_otid;
6249 shl.b64 %off, %off, 2;
6250 add.u64 %saddr, %sbase, %off;
6251 ld.shared.f32 %other_val, [%saddr];
6252 cvt.u64.u32 %off, %r_tid;
6253 shl.b64 %off, %off, 2;
6254 add.u64 %saddr, %sbase, %off;
6255 ld.shared.f32 %sum1, [%saddr];
6256 add.f32 %sum1, %sum1, %other_val;
6257 st.shared.f32 [%saddr], %sum1;
6258LNB_R1S:
6259 bar.sync 0;
6260 bra LNB_R1;
6261LNB_R1D:
6262 ld.shared.f32 %sum1, [%sbase];
6263 // mean1 = sum1 / n
6264 div.approx.f32 %mean1, %sum1, %n_f;
6265 bar.sync 0;
6266
6267 // Reduce sum2 in shared memory
6268 cvt.u64.u32 %off, %r_tid;
6269 shl.b64 %off, %off, 2;
6270 add.u64 %saddr, %sbase, %off;
6271 st.shared.f32 [%saddr], %sum2;
6272 bar.sync 0;
6273 mov.u32 %half, %r_bdim;
6274LNB_R2:
6275 shr.u32 %half, %half, 1;
6276 setp.eq.u32 %rp, %half, 0;
6277 @%rp bra LNB_R2D;
6278 setp.ge.u32 %rp, %r_tid, %half;
6279 @%rp bra LNB_R2S;
6280 add.u32 %r_otid, %r_tid, %half;
6281 cvt.u64.u32 %off, %r_otid;
6282 shl.b64 %off, %off, 2;
6283 add.u64 %saddr, %sbase, %off;
6284 ld.shared.f32 %other_val, [%saddr];
6285 cvt.u64.u32 %off, %r_tid;
6286 shl.b64 %off, %off, 2;
6287 add.u64 %saddr, %sbase, %off;
6288 ld.shared.f32 %sum2, [%saddr];
6289 add.f32 %sum2, %sum2, %other_val;
6290 st.shared.f32 [%saddr], %sum2;
6291LNB_R2S:
6292 bar.sync 0;
6293 bra LNB_R2;
6294LNB_R2D:
6295 ld.shared.f32 %sum2, [%sbase];
6296 // mean2 = sum2 / n
6297 div.approx.f32 %mean2, %sum2, %n_f;
6298 bar.sync 0;
6299
6300 // ===== Phase 4: Compute grad_input =====
6301 // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
6302 mov.u32 %j, %r_tid;
6303LNB_GI:
6304 setp.ge.u32 %lp, %j, %cols_reg;
6305 @%lp bra LNB_GID;
6306 // Reload input to recompute x_hat
6307 cvt.u64.u32 %off, %j;
6308 shl.b64 %off, %off, 2;
6309 add.u64 %addr, %in, %off;
6310 add.u64 %addr, %addr, %row_off;
6311 ld.global.f32 %val, [%addr];
6312 sub.f32 %x_hat, %val, %mean;
6313 mul.f32 %x_hat, %x_hat, %inv_std;
6314 // Reload grad_output and weight to recompute dl_dx_hat
6315 cvt.u64.u32 %off, %j;
6316 shl.b64 %off, %off, 2;
6317 add.u64 %addr, %go, %off;
6318 add.u64 %addr, %addr, %row_off;
6319 ld.global.f32 %gov, [%addr];
6320 cvt.u64.u32 %off, %j;
6321 shl.b64 %off, %off, 2;
6322 add.u64 %addr, %w, %off;
6323 ld.global.f32 %wv, [%addr];
6324 mul.f32 %dl_dx_hat, %gov, %wv;
6325 // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
6326 sub.f32 %result, %dl_dx_hat, %mean1;
6327 mul.f32 %diff, %x_hat, %mean2;
6328 sub.f32 %result, %result, %diff;
6329 mul.f32 %result, %inv_std, %result;
6330 // Store grad_input[row, j]
6331 cvt.u64.u32 %off, %j;
6332 shl.b64 %off, %off, 2;
6333 add.u64 %addr, %gi, %off;
6334 add.u64 %addr, %addr, %row_off;
6335 st.global.f32 [%addr], %result;
6336 add.u32 %j, %j, %r_bdim;
6337 bra LNB_GI;
6338LNB_GID:
6339
6340LNB_DONE:
6341 ret;
6342}
6343";
6344
6345
6346#[cfg(feature = "cuda")]
/// PTX source for `rmsnorm_kernel`: row-wise RMS normalisation with an
/// elementwise weight, operating on `f32` data.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `w_ptr` (64-bit
/// addresses), `rows`, `cols` (`u32`) and `eps` (`f32`).
///
/// The block index `%ctaid.x` selects the row (blocks with `%ctaid.x >= rows`
/// return immediately); threads walk the columns from `%tid.x` in steps of
/// `%ntid.x`:
///
/// 1. sum `x^2` over the row, reduce the partials through `sdata`, then
///    `inv_rms = rcp(sqrt(sum / cols + eps))`;
/// 2. store `x * inv_rms * w[j]` to the output row.
///
/// The reduction halves `%ntid.x` each step and indexes `sdata[256]` by
/// `%tid.x`: the block size should be a power of two no larger than 256.
/// Division, square root and reciprocal use the `.approx` variants.
6359pub(crate) const RMSNORM_PTX: &str = "\
6360.version 7.0
6361.target sm_52
6362.address_size 64
6363
6364.shared .align 4 .f32 sdata[256];
6365
6366.visible .entry rmsnorm_kernel(
6367 .param .u64 in_ptr,
6368 .param .u64 out_ptr,
6369 .param .u64 w_ptr,
6370 .param .u32 rows,
6371 .param .u32 cols,
6372 .param .f32 eps
6373) {
6374 .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
6375 .reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
6376 .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
6377 .reg .pred %p, %lp, %rp;
6378
6379 ld.param.u64 %in, [in_ptr];
6380 ld.param.u64 %out, [out_ptr];
6381 ld.param.u64 %w, [w_ptr];
6382 ld.param.u32 %rows_reg, [rows];
6383 ld.param.u32 %cols_reg, [cols];
6384 ld.param.f32 %eps_r, [eps];
6385
6386 mov.u64 %sbase, sdata;
6387
6388 mov.u32 %r_bid, %ctaid.x;
6389 mov.u32 %r_bdim, %ntid.x;
6390 mov.u32 %r_tid, %tid.x;
6391
6392 setp.ge.u32 %p, %r_bid, %rows_reg;
6393 @%p bra DONE;
6394
6395 cvt.u64.u32 %row_off, %r_bid;
6396 cvt.u64.u32 %off, %cols_reg;
6397 mul.lo.u64 %row_off, %row_off, %off;
6398 shl.b64 %row_off, %row_off, 2;
6399 cvt.rn.f32.u32 %n_f, %cols_reg;
6400
6401 // ===== Phase 1: Compute sum(x^2) =====
6402 mov.f32 %sq_sum, 0f00000000;
6403 mov.u32 %j, %r_tid;
6404SS:
6405 setp.ge.u32 %lp, %j, %cols_reg;
6406 @%lp bra SSD;
6407 cvt.u64.u32 %off, %j;
6408 shl.b64 %off, %off, 2;
6409 add.u64 %off, %in, %off;
6410 add.u64 %off, %off, %row_off;
6411 ld.global.f32 %val, [%off];
6412 fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
6413 add.u32 %j, %j, %r_bdim;
6414 bra SS;
6415SSD:
6416 cvt.u64.u32 %off, %r_tid;
6417 shl.b64 %off, %off, 2;
6418 add.u64 %saddr, %sbase, %off;
6419 st.shared.f32 [%saddr], %sq_sum;
6420 bar.sync 0;
6421 mov.u32 %half, %r_bdim;
6422SR:
6423 shr.u32 %half, %half, 1;
6424 setp.eq.u32 %rp, %half, 0;
6425 @%rp bra SRD;
6426 setp.ge.u32 %rp, %r_tid, %half;
6427 @%rp bra SRS;
6428 add.u32 %r_otid, %r_tid, %half;
6429 cvt.u64.u32 %off, %r_otid;
6430 shl.b64 %off, %off, 2;
6431 add.u64 %saddr, %sbase, %off;
6432 ld.shared.f32 %other_val, [%saddr];
6433 cvt.u64.u32 %off, %r_tid;
6434 shl.b64 %off, %off, 2;
6435 add.u64 %saddr, %sbase, %off;
6436 ld.shared.f32 %sq_sum, [%saddr];
6437 add.f32 %sq_sum, %sq_sum, %other_val;
6438 add.u64 %saddr, %sbase, %off;
6439 st.shared.f32 [%saddr], %sq_sum;
6440SRS:
6441 bar.sync 0;
6442 bra SR;
6443SRD:
6444 ld.shared.f32 %sq_sum, [%sbase];
6445 div.approx.f32 %sq_sum, %sq_sum, %n_f;
6446 add.f32 %sq_sum, %sq_sum, %eps_r;
6447 sqrt.approx.f32 %inv_rms, %sq_sum;
6448 rcp.approx.f32 %inv_rms, %inv_rms;
6449 bar.sync 0;
6450
6451 // ===== Phase 2: Normalize and scale =====
6452 // out[j] = x[j] * inv_rms * weight[j]
6453 mov.u32 %j, %r_tid;
6454NM:
6455 setp.ge.u32 %lp, %j, %cols_reg;
6456 @%lp bra NMD;
6457 cvt.u64.u32 %off, %j;
6458 shl.b64 %off, %off, 2;
6459 add.u64 %off, %in, %off;
6460 add.u64 %off, %off, %row_off;
6461 ld.global.f32 %val, [%off];
6462 mul.f32 %result, %val, %inv_rms;
6463 cvt.u64.u32 %off, %j;
6464 shl.b64 %off, %off, 2;
6465 add.u64 %off, %w, %off;
6466 ld.global.f32 %wv, [%off];
6467 mul.f32 %result, %result, %wv;
6468 cvt.u64.u32 %off, %j;
6469 shl.b64 %off, %off, 2;
6470 add.u64 %off, %out, %off;
6471 add.u64 %off, %off, %row_off;
6472 st.global.f32 [%off], %result;
6473 add.u32 %j, %j, %r_bdim;
6474 bra NM;
6475NMD:
6476
6477DONE:
6478 ret;
6479}
6480";
6481
6482
6483#[cfg(feature = "cuda")]
/// PTX source for `rmsnorm_backward_kernel`: gradients of row-wise RMS
/// normalisation with an elementwise weight, operating on `f32` data.
///
/// Kernel parameters, in order: `in_ptr`, `grad_out_ptr`, `w_ptr`,
/// `grad_in_ptr`, `grad_w_ptr` (64-bit addresses), `rows`, `cols` (`u32`) and
/// `eps` (`f32`).
///
/// The block index `%ctaid.x` selects the row (blocks with `%ctaid.x >= rows`
/// return immediately); threads walk the columns from `%tid.x` in steps of
/// `%ntid.x`. The kernel recomputes the forward statistic instead of reading
/// a saved one:
///
/// 1. `inv_rms = rcp(sqrt(sum(x^2) / cols + eps))` and `inv_rms3 = inv_rms^3`;
/// 2. per column: accumulates `dot = sum(go * x * w[j])` and atomically adds
///    `go * x * inv_rms` to `grad_w[j]`; `dot` is then reduced and
///    `coeff = dot * inv_rms3 / cols`;
/// 3. stores `inv_rms * w[j] * go - x * coeff` to the `grad_in` row.
///
/// `grad_w` is only ever added to (`atom.global.add.f32`), so it presumably
/// has to be zeroed by the caller before launch — confirm against the
/// launcher.
///
/// Both reductions halve `%ntid.x` each step and index `sdata[256]` by
/// `%tid.x`: the block size should be a power of two no larger than 256.
6507pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
6508.version 7.0
6509.target sm_52
6510.address_size 64
6511
6512.shared .align 4 .f32 sdata[256];
6513
6514.visible .entry rmsnorm_backward_kernel(
6515 .param .u64 in_ptr,
6516 .param .u64 grad_out_ptr,
6517 .param .u64 w_ptr,
6518 .param .u64 grad_in_ptr,
6519 .param .u64 grad_w_ptr,
6520 .param .u32 rows,
6521 .param .u32 cols,
6522 .param .f32 eps
6523) {
6524 .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
6525 .reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
6526 .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
6527 .reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
6528 .reg .pred %p, %lp, %rp;
6529
6530 ld.param.u64 %in, [in_ptr];
6531 ld.param.u64 %go, [grad_out_ptr];
6532 ld.param.u64 %w, [w_ptr];
6533 ld.param.u64 %gi, [grad_in_ptr];
6534 ld.param.u64 %gw, [grad_w_ptr];
6535 ld.param.u32 %rows_reg, [rows];
6536 ld.param.u32 %cols_reg, [cols];
6537 ld.param.f32 %eps_r, [eps];
6538
6539 mov.u64 %sbase, sdata;
6540
6541 mov.u32 %r_bid, %ctaid.x;
6542 mov.u32 %r_bdim, %ntid.x;
6543 mov.u32 %r_tid, %tid.x;
6544
6545 setp.ge.u32 %p, %r_bid, %rows_reg;
6546 @%p bra RNB_DONE;
6547
6548 // row_off = bid * cols * 4 (byte offset for this row)
6549 cvt.u64.u32 %row_off, %r_bid;
6550 cvt.u64.u32 %off, %cols_reg;
6551 mul.lo.u64 %row_off, %row_off, %off;
6552 shl.b64 %row_off, %row_off, 2;
6553 cvt.rn.f32.u32 %n_f, %cols_reg;
6554
6555 // ===== Phase 1: Compute sum(x^2) -> inv_rms =====
6556 mov.f32 %sq_sum, 0f00000000;
6557 mov.u32 %j, %r_tid;
6558RNB_SS:
6559 setp.ge.u32 %lp, %j, %cols_reg;
6560 @%lp bra RNB_SSD;
6561 cvt.u64.u32 %off, %j;
6562 shl.b64 %off, %off, 2;
6563 add.u64 %addr, %in, %off;
6564 add.u64 %addr, %addr, %row_off;
6565 ld.global.f32 %val, [%addr];
6566 fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
6567 add.u32 %j, %j, %r_bdim;
6568 bra RNB_SS;
6569RNB_SSD:
6570 // Shared memory reduce for sum(x^2)
6571 cvt.u64.u32 %off, %r_tid;
6572 shl.b64 %off, %off, 2;
6573 add.u64 %saddr, %sbase, %off;
6574 st.shared.f32 [%saddr], %sq_sum;
6575 bar.sync 0;
6576 mov.u32 %half, %r_bdim;
6577RNB_SR:
6578 shr.u32 %half, %half, 1;
6579 setp.eq.u32 %rp, %half, 0;
6580 @%rp bra RNB_SRD;
6581 setp.ge.u32 %rp, %r_tid, %half;
6582 @%rp bra RNB_SRS;
6583 add.u32 %r_otid, %r_tid, %half;
6584 cvt.u64.u32 %off, %r_otid;
6585 shl.b64 %off, %off, 2;
6586 add.u64 %saddr, %sbase, %off;
6587 ld.shared.f32 %other_val, [%saddr];
6588 cvt.u64.u32 %off, %r_tid;
6589 shl.b64 %off, %off, 2;
6590 add.u64 %saddr, %sbase, %off;
6591 ld.shared.f32 %sq_sum, [%saddr];
6592 add.f32 %sq_sum, %sq_sum, %other_val;
6593 st.shared.f32 [%saddr], %sq_sum;
6594RNB_SRS:
6595 bar.sync 0;
6596 bra RNB_SR;
6597RNB_SRD:
6598 ld.shared.f32 %sq_sum, [%sbase];
6599 div.approx.f32 %sq_sum, %sq_sum, %n_f;
6600 add.f32 %sq_sum, %sq_sum, %eps_r;
6601 sqrt.approx.f32 %inv_rms, %sq_sum;
6602 rcp.approx.f32 %inv_rms, %inv_rms;
6603 // inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
6604 mul.f32 %inv_rms3, %inv_rms, %inv_rms;
6605 mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
6606 bar.sync 0;
6607
6608 // ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
6609 // Also accumulate grad_weight via atomicAdd
6610 mov.f32 %dot, 0f00000000;
6611 mov.u32 %j, %r_tid;
6612RNB_DOT:
6613 setp.ge.u32 %lp, %j, %cols_reg;
6614 @%lp bra RNB_DOTD;
6615 // Load input[row, j]
6616 cvt.u64.u32 %off, %j;
6617 shl.b64 %off, %off, 2;
6618 add.u64 %addr, %in, %off;
6619 add.u64 %addr, %addr, %row_off;
6620 ld.global.f32 %val, [%addr];
6621 // Load grad_output[row, j]
6622 cvt.u64.u32 %off, %j;
6623 shl.b64 %off, %off, 2;
6624 add.u64 %addr, %go, %off;
6625 add.u64 %addr, %addr, %row_off;
6626 ld.global.f32 %gov, [%addr];
6627 // Load weight[j]
6628 cvt.u64.u32 %off, %j;
6629 shl.b64 %off, %off, 2;
6630 add.u64 %addr, %w, %off;
6631 ld.global.f32 %wv, [%addr];
6632 // dot += go * x * w
6633 mul.f32 %tmp, %gov, %val;
6634 fma.rn.f32 %dot, %tmp, %wv, %dot;
6635 // atomicAdd grad_weight[j] += go * x * inv_rms
6636 cvt.u64.u32 %off, %j;
6637 shl.b64 %off, %off, 2;
6638 add.u64 %addr, %gw, %off;
6639 mul.f32 %result, %gov, %val;
6640 mul.f32 %result, %result, %inv_rms;
6641 atom.global.add.f32 %result, [%addr], %result;
6642 add.u32 %j, %j, %r_bdim;
6643 bra RNB_DOT;
6644RNB_DOTD:
6645 // Reduce dot in shared memory
6646 cvt.u64.u32 %off, %r_tid;
6647 shl.b64 %off, %off, 2;
6648 add.u64 %saddr, %sbase, %off;
6649 st.shared.f32 [%saddr], %dot;
6650 bar.sync 0;
6651 mov.u32 %half, %r_bdim;
6652RNB_DR:
6653 shr.u32 %half, %half, 1;
6654 setp.eq.u32 %rp, %half, 0;
6655 @%rp bra RNB_DRD;
6656 setp.ge.u32 %rp, %r_tid, %half;
6657 @%rp bra RNB_DRS;
6658 add.u32 %r_otid, %r_tid, %half;
6659 cvt.u64.u32 %off, %r_otid;
6660 shl.b64 %off, %off, 2;
6661 add.u64 %saddr, %sbase, %off;
6662 ld.shared.f32 %other_val, [%saddr];
6663 cvt.u64.u32 %off, %r_tid;
6664 shl.b64 %off, %off, 2;
6665 add.u64 %saddr, %sbase, %off;
6666 ld.shared.f32 %dot, [%saddr];
6667 add.f32 %dot, %dot, %other_val;
6668 st.shared.f32 [%saddr], %dot;
6669RNB_DRS:
6670 bar.sync 0;
6671 bra RNB_DR;
6672RNB_DRD:
6673 ld.shared.f32 %dot, [%sbase];
6674 // coeff = dot * inv_rms3 / n
6675 mul.f32 %coeff, %dot, %inv_rms3;
6676 div.approx.f32 %coeff, %coeff, %n_f;
6677 bar.sync 0;
6678
6679 // ===== Phase 3: Compute grad_input =====
6680 // grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
6681 mov.u32 %j, %r_tid;
6682RNB_GI:
6683 setp.ge.u32 %lp, %j, %cols_reg;
6684 @%lp bra RNB_GID;
6685 // Reload input
6686 cvt.u64.u32 %off, %j;
6687 shl.b64 %off, %off, 2;
6688 add.u64 %addr, %in, %off;
6689 add.u64 %addr, %addr, %row_off;
6690 ld.global.f32 %val, [%addr];
6691 // Reload grad_output and weight
6692 cvt.u64.u32 %off, %j;
6693 shl.b64 %off, %off, 2;
6694 add.u64 %addr, %go, %off;
6695 add.u64 %addr, %addr, %row_off;
6696 ld.global.f32 %gov, [%addr];
6697 cvt.u64.u32 %off, %j;
6698 shl.b64 %off, %off, 2;
6699 add.u64 %addr, %w, %off;
6700 ld.global.f32 %wv, [%addr];
6701 // result = inv_rms * w * go - x * coeff
6702 mul.f32 %result, %inv_rms, %wv;
6703 mul.f32 %result, %result, %gov;
6704 mul.f32 %tmp, %val, %coeff;
6705 sub.f32 %result, %result, %tmp;
6706 // Store grad_input[row, j]
6707 cvt.u64.u32 %off, %j;
6708 shl.b64 %off, %off, 2;
6709 add.u64 %addr, %gi, %off;
6710 add.u64 %addr, %addr, %row_off;
6711 st.global.f32 [%addr], %result;
6712 add.u32 %j, %j, %r_bdim;
6713 bra RNB_GI;
6714RNB_GID:
6715
6716RNB_DONE:
6717 ret;
6718}
6719";
6720
6721
6722#[cfg(feature = "cuda")]
/// PTX source for `batchnorm_forward_kernel` (`f32`): per-channel batch
/// statistics. As written the kernel is incomplete and never produces the
/// normalised output — see the review notes below before using it.
///
/// Kernel parameters, in order: `input_ptr`, `output_ptr`, `weight_ptr`,
/// `bias_ptr`, `rmean_ptr`, `rvar_ptr`, `save_mean_ptr`, `save_invstd_ptr`
/// (64-bit addresses), `channels`, `spatial` (`u32`), `eps`, `momentum`
/// (`f32`), `total_per_ch`, `training` (`u32`).
///
/// What the code does:
///
/// * the block index `%ctaid.x` selects the channel; blocks with
///   `%ctaid.x >= channels` return immediately;
/// * pass 1 accumulates a per-thread sum and sum of squares over the channel,
///   addressing element `idx` at
///   `((idx / spatial) * channels + ch) * spatial + idx % spatial`;
/// * the partials are reduced through `smem_sum` / `smem_sq`, starting from a
///   fixed half-width of 128, so slots up to 255 are read regardless of
///   `%ntid.x`;
/// * thread 0 computes `mean`, `var` and `invstd = 1 / sqrt(var + eps)`,
///   stores `mean` / `invstd` to `save_mean_ptr[ch]` / `save_invstd_ptr[ch]`
///   and broadcasts them to the block through shared memory;
/// * every thread loads `weight[ch]` and `bias[ch]`, then the kernel returns.
///
/// NOTE(review): defects visible in the PTX text:
///
/// * `%idx` is overwritten with `idx % spatial` before `%ntid.x` is added, so
///   the pass-1 loop does not step through `tid, tid + ntid, ..` (the inline
///   PTX comment already marks this line as wrong);
/// * the first `fma` on `%var` adds `mean^2` and the second subtracts it
///   again, leaving `var = sqsum / n` rather than `sqsum / n - mean^2`;
/// * pass 2 is only a placeholder comment: nothing is stored through
///   `output_ptr`;
/// * `rmean_ptr`, `rvar_ptr` and `momentum` are loaded but never used, and the
///   `%ptrain` predicate derived from `training` is never tested;
/// * the `[smem_sum + %off64]` operands combine a symbol with a register
///   offset — confirm that this addressing form is accepted by the PTX
///   assembler.
6755pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
6756.version 7.0
6757.target sm_52
6758.address_size 64
6759
6760// Shared memory for block reduction
6761.shared .align 4 .f32 smem_sum[256];
6762.shared .align 4 .f32 smem_sq[256];
6763
6764.visible .entry batchnorm_forward_kernel(
6765 .param .u64 input_ptr,
6766 .param .u64 output_ptr,
6767 .param .u64 weight_ptr,
6768 .param .u64 bias_ptr,
6769 .param .u64 rmean_ptr,
6770 .param .u64 rvar_ptr,
6771 .param .u64 save_mean_ptr,
6772 .param .u64 save_invstd_ptr,
6773 .param .u32 channels,
6774 .param .u32 spatial,
6775 .param .f32 eps,
6776 .param .f32 momentum,
6777 .param .u32 total_per_ch,
6778 .param .u32 training
6779) {
6780 .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
6781 .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
6782 .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
6783 .reg .f32 %gamma, %beta, %eps_reg, %mom, %other;
6784 .reg .f32 %n_f, %one, %normalized;
6785 .reg .pred %p, %ptrain, %ptid0;
6786 .reg .u32 %half;
6787
6788 ld.param.u64 %in, [input_ptr];
6789 ld.param.u64 %out, [output_ptr];
6790 ld.param.u64 %w, [weight_ptr];
6791 ld.param.u64 %b, [bias_ptr];
6792 ld.param.u64 %rm, [rmean_ptr];
6793 ld.param.u64 %rv, [rvar_ptr];
6794 ld.param.u64 %sm, [save_mean_ptr];
6795 ld.param.u64 %si, [save_invstd_ptr];
6796 ld.param.u32 %n_ch, [channels];
6797 ld.param.u32 %sp, [spatial];
6798 ld.param.f32 %eps_reg, [eps];
6799 ld.param.f32 %mom, [momentum];
6800 ld.param.u32 %tpc, [total_per_ch];
6801 ld.param.u32 %train, [training];
6802
6803 mov.u32 %bid, %ctaid.x;
6804 mov.u32 %tid, %tid.x;
6805 mov.u32 %bdim, %ntid.x;
6806 mov.u32 %ch, %bid;
6807 mov.f32 %one, 0f3F800000;
6808
6809 setp.ge.u32 %p, %ch, %n_ch;
6810 @%p bra END;
6811
6812 setp.ne.u32 %ptrain, %train, 0;
6813
6814 // ---- Pass 1: compute sum and sum-of-squares for this channel ----
6815 mov.f32 %sum, 0f00000000;
6816 mov.f32 %sqsum, 0f00000000;
6817
6818 // Grid-stride loop over B*spatial for this channel
6819 mov.u32 %idx, %tid;
6820PASS1_LOOP:
6821 setp.ge.u32 %p, %idx, %tpc;
6822 @%p bra PASS1_DONE;
6823
6824 // Linear offset = (idx / spatial) * channels * spatial + ch * spatial + idx % spatial
6825 div.u32 %half, %idx, %sp;
6826 rem.u32 %half, %idx, %sp; // reuse half as spatial_idx
6827 // batch_offset = (idx / sp) * (n_ch * sp) + ch * sp + (idx % sp)
6828 div.u32 %half, %idx, %sp; // batch_idx
6829 mul.lo.u32 %half, %half, %n_ch;
6830 add.u32 %half, %half, %ch;
6831 mul.lo.u32 %half, %half, %sp;
6832 rem.u32 %idx, %idx, %sp; // spatial_idx
6833 add.u32 %half, %half, %idx;
6834
6835 cvt.u64.u32 %off64, %half;
6836 shl.b64 %off64, %off64, 2;
6837 add.u64 %tmp64, %in, %off64;
6838 ld.global.f32 %val, [%tmp64];
6839 add.f32 %sum, %sum, %val;
6840 fma.rn.f32 %sqsum, %val, %val, %sqsum;
6841
6842 // Restore idx for stride
6843 // Recompute idx from tid + iteration * bdim
6844 add.u32 %idx, %idx, %bdim; // This is wrong - need proper loop counter
6845 bra PASS1_LOOP;
6846
6847PASS1_DONE:
6848 // Store to shared memory for block reduction
6849 cvt.u64.u32 %off64, %tid;
6850 shl.b64 %off64, %off64, 2;
6851 st.shared.f32 [smem_sum + %off64], %sum;
6852 st.shared.f32 [smem_sq + %off64], %sqsum;
6853 bar.sync 0;
6854
6855 // Tree reduction
6856 mov.u32 %half, 128;
6857REDUCE_LOOP:
6858 setp.lt.u32 %p, %half, 1;
6859 @%p bra REDUCE_DONE;
6860 setp.ge.u32 %p, %tid, %half;
6861 @%p bra REDUCE_SKIP;
6862
6863 add.u32 %idx, %tid, %half;
6864 cvt.u64.u32 %off64, %idx;
6865 shl.b64 %off64, %off64, 2;
6866 ld.shared.f32 %other, [smem_sum + %off64];
6867 cvt.u64.u32 %tmp64, %tid;
6868 shl.b64 %tmp64, %tmp64, 2;
6869 ld.shared.f32 %sum, [smem_sum + %tmp64];
6870 add.f32 %sum, %sum, %other;
6871 st.shared.f32 [smem_sum + %tmp64], %sum;
6872
6873 ld.shared.f32 %other, [smem_sq + %off64];
6874 ld.shared.f32 %sqsum, [smem_sq + %tmp64];
6875 add.f32 %sqsum, %sqsum, %other;
6876 st.shared.f32 [smem_sq + %tmp64], %sqsum;
6877
6878REDUCE_SKIP:
6879 bar.sync 0;
6880 shr.u32 %half, %half, 1;
6881 bra REDUCE_LOOP;
6882
6883REDUCE_DONE:
6884 // Thread 0 computes mean and invstd
6885 setp.ne.u32 %ptid0, %tid, 0;
6886
6887 @%ptid0 bra WAIT_STATS;
6888
6889 ld.shared.f32 %sum, [smem_sum];
6890 ld.shared.f32 %sqsum, [smem_sq];
6891 cvt.rn.f32.u32 %n_f, %tpc;
6892 div.rn.f32 %mean, %sum, %n_f;
6893 // var = sqsum/n - mean^2
6894 div.rn.f32 %var, %sqsum, %n_f;
6895 fma.rn.f32 %var, %mean, %mean, %var; // This adds mean^2, need to subtract
6896 // Actually: var = E[x^2] - E[x]^2, so var = sqsum/n - mean^2
6897 // We had: var = sqsum/n, now subtract mean^2
6898 neg.f32 %other, %mean;
6899 fma.rn.f32 %var, %other, %mean, %var; // var = var + (-mean)*mean = sqsum/n - mean^2
6900
6901 // invstd = 1/sqrt(var + eps)
6902 add.f32 %other, %var, %eps_reg;
6903 sqrt.rn.f32 %other, %other;
6904 div.rn.f32 %invstd, %one, %other;
6905
6906 // Save mean and invstd
6907 cvt.u64.u32 %off64, %ch;
6908 shl.b64 %off64, %off64, 2;
6909 add.u64 %tmp64, %sm, %off64;
6910 st.global.f32 [%tmp64], %mean;
6911 add.u64 %tmp64, %si, %off64;
6912 st.global.f32 [%tmp64], %invstd;
6913
6914 // Store to shared for other threads
6915 st.shared.f32 [smem_sum], %mean;
6916 st.shared.f32 [smem_sq], %invstd;
6917
6918WAIT_STATS:
6919 bar.sync 0;
6920 // All threads read mean and invstd from shared
6921 ld.shared.f32 %mean, [smem_sum];
6922 ld.shared.f32 %invstd, [smem_sq];
6923
6924 // Load weight and bias for this channel
6925 cvt.u64.u32 %off64, %ch;
6926 shl.b64 %off64, %off64, 2;
6927 add.u64 %tmp64, %w, %off64;
6928 ld.global.f32 %gamma, [%tmp64];
6929 add.u64 %tmp64, %b, %off64;
6930 ld.global.f32 %beta, [%tmp64];
6931
6932 // ---- Pass 2: normalize + affine ----
6933 // For now this is a placeholder - the indexing needs to match pass 1
6934 // Each thread normalizes its elements
6935
6936END:
6937 ret;
6938}
6939";
6940
6941
6942#[cfg(feature = "cuda")]
/// PTX source for `maxpool2d_forward_kernel`: 2-D max pooling over `f32` data
/// indexed as `((b * channels + c) * h + y) * w + x`.
///
/// Kernel parameters, in order: `input_ptr`, `output_ptr` (64-bit addresses),
/// then the `u32` values `batch`, `channels`, `h_in`, `w_in`, `h_out`,
/// `w_out`, `kh`, `kw`, `sh`, `sw`, `ph`, `pw` and `total`.
///
/// Each thread starts at output index `ctaid.x * ntid.x + tid.x` and advances
/// by `ntid.x * nctaid.x` until the index reaches `total`. For every output
/// index it:
///
/// 1. splits the index right-to-left into `(b, c, oh, ow)` using `w_out`,
///    `h_out` and `channels`;
/// 2. visits the `kh x kw` window at `ih = oh * sh + i - ph`,
///    `iw = ow * sw + j - pw`; positions that fall in the padding wrap around
///    in unsigned arithmetic and are rejected by the `>= h_in` / `>= w_in`
///    checks;
/// 3. keeps the running maximum, starting from negative infinity
///    (`0fFF800000`), and stores it at the same linear index of the output.
///
/// A window that lies entirely in the padding therefore yields negative
/// infinity. Input offsets are computed with 32-bit multiplies, so the input
/// element count has to fit in a `u32`.
///
/// NOTE(review): the first `div.u32 %b_idx, %rem, %ch_reg` is dead (its result
/// is overwritten after the right-to-left decomposition), and `batch` and
/// `%p_gt` are loaded/declared but not otherwise used.
6947pub(crate) const MAXPOOL2D_PTX: &str = "\
6948.version 7.0
6949.target sm_52
6950.address_size 64
6951
6952.visible .entry maxpool2d_forward_kernel(
6953 .param .u64 input_ptr,
6954 .param .u64 output_ptr,
6955 .param .u32 batch,
6956 .param .u32 channels,
6957 .param .u32 h_in,
6958 .param .u32 w_in,
6959 .param .u32 h_out,
6960 .param .u32 w_out,
6961 .param .u32 kh,
6962 .param .u32 kw,
6963 .param .u32 sh,
6964 .param .u32 sw,
6965 .param .u32 ph,
6966 .param .u32 pw,
6967 .param .u32 total
6968) {
6969 .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
6970 .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
6971 .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
6972 .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
6973 .reg .u32 %batch_reg, %ch_reg;
6974 .reg .u64 %in, %out, %off64, %tmp64;
6975 .reg .f32 %max_val, %cur_val, %neg_inf;
6976 .reg .pred %p, %p_bounds, %p_gt;
6977
6978 ld.param.u64 %in, [input_ptr];
6979 ld.param.u64 %out, [output_ptr];
6980 ld.param.u32 %batch_reg, [batch];
6981 ld.param.u32 %ch_reg, [channels];
6982 ld.param.u32 %h_in_reg, [h_in];
6983 ld.param.u32 %w_in_reg, [w_in];
6984 ld.param.u32 %h_out_reg, [h_out];
6985 ld.param.u32 %w_out_reg, [w_out];
6986 ld.param.u32 %kh_reg, [kh];
6987 ld.param.u32 %kw_reg, [kw];
6988 ld.param.u32 %sh_reg, [sh];
6989 ld.param.u32 %sw_reg, [sw];
6990 ld.param.u32 %ph_reg, [ph];
6991 ld.param.u32 %pw_reg, [pw];
6992 ld.param.u32 %total_reg, [total];
6993
6994 mov.u32 %bid, %ctaid.x;
6995 mov.u32 %bdim, %ntid.x;
6996 mov.u32 %tid, %tid.x;
6997 mov.u32 %gdim, %nctaid.x;
6998 mad.lo.u32 %idx, %bid, %bdim, %tid;
6999 mul.lo.u32 %stride, %bdim, %gdim;
7000
7001 // -inf for max initialization
7002 mov.f32 %neg_inf, 0fFF800000;
7003
7004LOOP:
7005 setp.ge.u32 %p, %idx, %total_reg;
7006 @%p bra END;
7007
7008 // Decompose idx into (b, c, oh, ow)
7009 mov.u32 %rem, %idx;
7010 div.u32 %b_idx, %rem, %ch_reg;
7011 // Actually need: idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow
7012 // So decompose from the right:
7013 rem.u32 %ow, %rem, %w_out_reg;
7014 div.u32 %rem, %rem, %w_out_reg;
7015 rem.u32 %oh, %rem, %h_out_reg;
7016 div.u32 %rem, %rem, %h_out_reg;
7017 rem.u32 %c_idx, %rem, %ch_reg;
7018 div.u32 %b_idx, %rem, %ch_reg;
7019
7020 mov.f32 %max_val, %neg_inf;
7021
7022 // Slide the kernel window
7023 mov.u32 %i, 0;
7024KH_LOOP:
7025 setp.ge.u32 %p, %i, %kh_reg;
7026 @%p bra KH_DONE;
7027
7028 mov.u32 %j, 0;
7029KW_LOOP:
7030 setp.ge.u32 %p, %j, %kw_reg;
7031 @%p bra KW_DONE;
7032
7033 // ih = oh * sh + i - ph, iw = ow * sw + j - pw
7034 mad.lo.u32 %ih, %oh, %sh_reg, %i;
7035 sub.u32 %ih, %ih, %ph_reg;
7036 mad.lo.u32 %iw, %ow, %sw_reg, %j;
7037 sub.u32 %iw, %iw, %pw_reg;
7038
7039 // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
7040 // Since unsigned, just check < h_in and < w_in
7041 setp.ge.u32 %p_bounds, %ih, %h_in_reg;
7042 @%p_bounds bra KW_NEXT;
7043 setp.ge.u32 %p_bounds, %iw, %w_in_reg;
7044 @%p_bounds bra KW_NEXT;
7045
7046 // input_offset = b * C * H * W + c * H * W + ih * W + iw
7047 mul.lo.u32 %tmp, %b_idx, %ch_reg;
7048 add.u32 %tmp, %tmp, %c_idx;
7049 mul.lo.u32 %tmp, %tmp, %h_in_reg;
7050 add.u32 %tmp, %tmp, %ih;
7051 mul.lo.u32 %tmp, %tmp, %w_in_reg;
7052 add.u32 %tmp, %tmp, %iw;
7053
7054 cvt.u64.u32 %off64, %tmp;
7055 shl.b64 %off64, %off64, 2;
7056 add.u64 %tmp64, %in, %off64;
7057 ld.global.f32 %cur_val, [%tmp64];
7058
7059 max.f32 %max_val, %max_val, %cur_val;
7060
7061KW_NEXT:
7062 add.u32 %j, %j, 1;
7063 bra KW_LOOP;
7064
7065KW_DONE:
7066 add.u32 %i, %i, 1;
7067 bra KH_LOOP;
7068
7069KH_DONE:
7070 // Store output
7071 cvt.u64.u32 %off64, %idx;
7072 shl.b64 %off64, %off64, 2;
7073 add.u64 %tmp64, %out, %off64;
7074 st.global.f32 [%tmp64], %max_val;
7075
7076 add.u32 %idx, %idx, %stride;
7077 bra LOOP;
7078
7079END:
7080 ret;
7081}
7082";
7083
7084
7085#[cfg(feature = "cuda")]
/// PTX source for `avgpool2d_forward_kernel`: forward 2-D average pooling on
/// `f32` data.
///
/// Kernel parameters, in order: input pointer, output pointer, `batch`,
/// `channels`, input size (`h_in`, `w_in`), output size (`h_out`, `w_out`),
/// window size (`kh`, `kw`), stride (`sh`, `sw`), padding (`ph`, `pw`) and
/// `total`, the number of output elements to produce. `batch` is loaded but
/// never used by the kernel body.
///
/// Each thread runs a grid-stride loop: it starts at
/// `ctaid.x * ntid.x + tid.x` and advances by `ntid.x * nctaid.x` until the
/// index reaches `total`. An output index is decomposed from the right as
/// `((b * channels + c) * h_out + oh) * w_out + ow`, and the input element
/// read for window tap `(i, j)` is
/// `((b * channels + c) * h_in + ih) * w_in + iw` with
/// `ih = oh * sh + i - ph` and `iw = ow * sw + j - pw`.
///
/// The subtraction of the padding is unsigned, so taps that fall into the
/// top/left padding wrap around to very large values and are rejected by the
/// same `>= h_in` / `>= w_in` comparison that rejects bottom/right taps.
///
/// The stored value is the sum of the in-bounds taps divided by the number of
/// in-bounds taps, i.e. padded positions are excluded from the divisor.
///
/// NOTE(review): a window with no in-bounds tap computes `0 / 0` and stores
/// NaN; presumably the launcher never produces such a window — confirm.
/// NOTE(review): all index arithmetic is 32-bit, which assumes element counts
/// fit in `u32` — confirm against the launcher.
7090pub(crate) const AVGPOOL2D_PTX: &str = "\
7091.version 7.0
7092.target sm_52
7093.address_size 64
7094
7095.visible .entry avgpool2d_forward_kernel(
7096 .param .u64 input_ptr,
7097 .param .u64 output_ptr,
7098 .param .u32 batch,
7099 .param .u32 channels,
7100 .param .u32 h_in,
7101 .param .u32 w_in,
7102 .param .u32 h_out,
7103 .param .u32 w_out,
7104 .param .u32 kh,
7105 .param .u32 kw,
7106 .param .u32 sh,
7107 .param .u32 sw,
7108 .param .u32 ph,
7109 .param .u32 pw,
7110 .param .u32 total
7111) {
7112 .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
7113 .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
7114 .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
7115 .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
7116 .reg .u32 %batch_reg, %ch_reg;
7117 .reg .u64 %in, %out, %off64, %tmp64;
7118 .reg .f32 %sum_val, %cur_val, %count_f, %avg;
7119 .reg .pred %p, %p_bounds;
7120
7121 ld.param.u64 %in, [input_ptr];
7122 ld.param.u64 %out, [output_ptr];
7123 ld.param.u32 %batch_reg, [batch];
7124 ld.param.u32 %ch_reg, [channels];
7125 ld.param.u32 %h_in_reg, [h_in];
7126 ld.param.u32 %w_in_reg, [w_in];
7127 ld.param.u32 %h_out_reg, [h_out];
7128 ld.param.u32 %w_out_reg, [w_out];
7129 ld.param.u32 %kh_reg, [kh];
7130 ld.param.u32 %kw_reg, [kw];
7131 ld.param.u32 %sh_reg, [sh];
7132 ld.param.u32 %sw_reg, [sw];
7133 ld.param.u32 %ph_reg, [ph];
7134 ld.param.u32 %pw_reg, [pw];
7135 ld.param.u32 %total_reg, [total];
7136
7137 mov.u32 %bid, %ctaid.x;
7138 mov.u32 %bdim, %ntid.x;
7139 mov.u32 %tid, %tid.x;
7140 mov.u32 %gdim, %nctaid.x;
7141 mad.lo.u32 %idx, %bid, %bdim, %tid;
7142 mul.lo.u32 %stride, %bdim, %gdim;
7143
7144LOOP:
7145 setp.ge.u32 %p, %idx, %total_reg;
7146 @%p bra END;
7147
7148 // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
7149 mov.u32 %rem, %idx;
7150 rem.u32 %ow, %rem, %w_out_reg;
7151 div.u32 %rem, %rem, %w_out_reg;
7152 rem.u32 %oh, %rem, %h_out_reg;
7153 div.u32 %rem, %rem, %h_out_reg;
7154 rem.u32 %c_idx, %rem, %ch_reg;
7155 div.u32 %b_idx, %rem, %ch_reg;
7156
7157 mov.f32 %sum_val, 0f00000000;
7158 mov.u32 %count, 0;
7159
7160 mov.u32 %i, 0;
7161AKH_LOOP:
7162 setp.ge.u32 %p, %i, %kh_reg;
7163 @%p bra AKH_DONE;
7164
7165 mov.u32 %j, 0;
7166AKW_LOOP:
7167 setp.ge.u32 %p, %j, %kw_reg;
7168 @%p bra AKW_DONE;
7169
7170 mad.lo.u32 %ih, %oh, %sh_reg, %i;
7171 sub.u32 %ih, %ih, %ph_reg;
7172 mad.lo.u32 %iw, %ow, %sw_reg, %j;
7173 sub.u32 %iw, %iw, %pw_reg;
7174
7175 setp.ge.u32 %p_bounds, %ih, %h_in_reg;
7176 @%p_bounds bra AKW_NEXT;
7177 setp.ge.u32 %p_bounds, %iw, %w_in_reg;
7178 @%p_bounds bra AKW_NEXT;
7179
7180 mul.lo.u32 %tmp, %b_idx, %ch_reg;
7181 add.u32 %tmp, %tmp, %c_idx;
7182 mul.lo.u32 %tmp, %tmp, %h_in_reg;
7183 add.u32 %tmp, %tmp, %ih;
7184 mul.lo.u32 %tmp, %tmp, %w_in_reg;
7185 add.u32 %tmp, %tmp, %iw;
7186
7187 cvt.u64.u32 %off64, %tmp;
7188 shl.b64 %off64, %off64, 2;
7189 add.u64 %tmp64, %in, %off64;
7190 ld.global.f32 %cur_val, [%tmp64];
7191
7192 add.f32 %sum_val, %sum_val, %cur_val;
7193 add.u32 %count, %count, 1;
7194
7195AKW_NEXT:
7196 add.u32 %j, %j, 1;
7197 bra AKW_LOOP;
7198
7199AKW_DONE:
7200 add.u32 %i, %i, 1;
7201 bra AKH_LOOP;
7202
7203AKH_DONE:
7204 // avg = sum / count (count_include_pad = false behavior)
7205 cvt.rn.f32.u32 %count_f, %count;
7206 div.rn.f32 %avg, %sum_val, %count_f;
7207
7208 cvt.u64.u32 %off64, %idx;
7209 shl.b64 %off64, %off64, 2;
7210 add.u64 %tmp64, %out, %off64;
7211 st.global.f32 [%tmp64], %avg;
7212
7213 add.u32 %idx, %idx, %stride;
7214 bra LOOP;
7215
7216END:
7217 ret;
7218}
7219";
7220
7221
7222#[cfg(feature = "cuda")]
/// PTX source for `softmax_kernel`: row-wise `f32` softmax over a row-major
/// `rows x cols` matrix.
///
/// One thread block handles one row: `ctaid.x` selects the row and blocks
/// with `ctaid.x >= rows` return immediately. Within a row, thread `tid.x`
/// visits columns `tid.x, tid.x + ntid.x, ...` in three passes:
///
/// 1. find the row maximum (per-thread maximum, then a halving reduction
///    through the shared array `sdata`);
/// 2. write `exp(x - max)` to the output, computed as
///    `ex2.approx((x - max) * 0f3FB8AA3B)` where the constant is `log2(e)`,
///    while accumulating the per-thread sum, reduced the same way;
/// 3. multiply every stored exponential by `rcp.approx` of the row sum.
///
/// Launch constraints implied by the code: `sdata` has 256 `f32` slots and is
/// indexed by `tid.x`, so the block size must not exceed 256; the reduction
/// repeatedly halves `ntid.x` and only folds slot `t + half` into slot `t`,
/// so every slot is only covered when the block size is a power of two.
///
/// Both the exponential and the reciprocal use the `.approx` instructions, so
/// results are not correctly rounded.
7223pub(crate) const SOFTMAX_PTX: &str = "\
7224.version 7.0\n\
7225.target sm_52\n\
7226.address_size 64\n\
7227\n\
7228.shared .align 4 .f32 sdata[256];\n\
7229\n\
7230.visible .entry softmax_kernel(\n\
7231 .param .u64 input_ptr,\n\
7232 .param .u64 output_ptr,\n\
7233 .param .u32 rows,\n\
7234 .param .u32 cols\n\
7235) {\n\
7236 .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
7237 .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
7238 .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
7239 .reg .pred %p, %loop_p;\n\
7240 .reg .u32 %half, %other_tid;\n\
7241 .reg .f32 %other_val;\n\
7242 .reg .pred %reduce_p;\n\
7243\n\
7244 ld.param.u64 %in, [input_ptr];\n\
7245 ld.param.u64 %out, [output_ptr];\n\
7246 ld.param.u32 %rows_reg, [rows];\n\
7247 ld.param.u32 %cols_reg, [cols];\n\
7248\n\
7249 mov.u32 %bid, %ctaid.x;\n\
7250 mov.u32 %bdim, %ntid.x;\n\
7251 mov.u32 %r_tid, %tid.x;\n\
7252 mov.u64 %sbase, sdata;\n\
7253\n\
7254 setp.ge.u32 %p, %bid, %rows_reg;\n\
7255 @%p bra DONE;\n\
7256\n\
7257 cvt.u64.u32 %row_off, %bid;\n\
7258 cvt.u64.u32 %off, %cols_reg;\n\
7259 mul.lo.u64 %row_off, %row_off, %off;\n\
7260 shl.b64 %row_off, %row_off, 2;\n\
7261\n\
7262 mov.f32 %max_val, 0fFF800000;\n\
7263 mov.u32 %j, %r_tid;\n\
7264FIND_MAX:\n\
7265 setp.ge.u32 %loop_p, %j, %cols_reg;\n\
7266 @%loop_p bra FIND_MAX_DONE;\n\
7267 cvt.u64.u32 %off, %j;\n\
7268 shl.b64 %off, %off, 2;\n\
7269 add.u64 %off, %in, %off;\n\
7270 add.u64 %off, %off, %row_off;\n\
7271 ld.global.f32 %val, [%off];\n\
7272 max.f32 %max_val, %max_val, %val;\n\
7273 add.u32 %j, %j, %bdim;\n\
7274 bra FIND_MAX;\n\
7275FIND_MAX_DONE:\n\
7276\n\
7277 cvt.u64.u32 %off, %r_tid;\n\
7278 shl.b64 %off, %off, 2;\n\
7279 add.u64 %saddr, %sbase, %off;\n\
7280 st.shared.f32 [%saddr], %max_val;\n\
7281 bar.sync 0;\n\
7282\n\
7283 mov.u32 %half, %bdim;\n\
7284MAX_REDUCE:\n\
7285 shr.u32 %half, %half, 1;\n\
7286 setp.eq.u32 %reduce_p, %half, 0;\n\
7287 @%reduce_p bra MAX_REDUCE_DONE;\n\
7288 setp.ge.u32 %reduce_p, %r_tid, %half;\n\
7289 @%reduce_p bra MAX_REDUCE_SKIP;\n\
7290 add.u32 %other_tid, %r_tid, %half;\n\
7291 cvt.u64.u32 %off, %other_tid;\n\
7292 shl.b64 %off, %off, 2;\n\
7293 add.u64 %saddr, %sbase, %off;
7294 ld.shared.f32 %other_val, [%saddr];\n\
7295 cvt.u64.u32 %off, %r_tid;\n\
7296 shl.b64 %off, %off, 2;\n\
7297 add.u64 %saddr, %sbase, %off;\n\
7298 ld.shared.f32 %max_val, [%saddr];\n\
7299 max.f32 %max_val, %max_val, %other_val;\n\
7300 add.u64 %saddr, %sbase, %off;\n\
7301 st.shared.f32 [%saddr], %max_val;\n\
7302MAX_REDUCE_SKIP:\n\
7303 bar.sync 0;\n\
7304 bra MAX_REDUCE;\n\
7305MAX_REDUCE_DONE:\n\
7306\n\
7307 ld.shared.f32 %max_val, [sdata];\n\
7308 bar.sync 0;\n\
7309\n\
7310 mov.f32 %sum_val, 0f00000000;\n\
7311 mov.u32 %j, %r_tid;\n\
7312SUM_EXP:\n\
7313 setp.ge.u32 %loop_p, %j, %cols_reg;\n\
7314 @%loop_p bra SUM_EXP_DONE;\n\
7315 cvt.u64.u32 %off, %j;\n\
7316 shl.b64 %off, %off, 2;\n\
7317 add.u64 %off, %in, %off;\n\
7318 add.u64 %off, %off, %row_off;\n\
7319 ld.global.f32 %val, [%off];\n\
7320 sub.f32 %val, %val, %max_val;\n\
7321 mul.f32 %val, %val, 0f3FB8AA3B;\n\
7322 ex2.approx.f32 %exp_val, %val;\n\
7323 add.f32 %sum_val, %sum_val, %exp_val;\n\
7324 cvt.u64.u32 %off, %j;\n\
7325 shl.b64 %off, %off, 2;\n\
7326 add.u64 %off, %out, %off;\n\
7327 add.u64 %off, %off, %row_off;\n\
7328 st.global.f32 [%off], %exp_val;\n\
7329 add.u32 %j, %j, %bdim;\n\
7330 bra SUM_EXP;\n\
7331SUM_EXP_DONE:\n\
7332\n\
7333 cvt.u64.u32 %off, %r_tid;\n\
7334 shl.b64 %off, %off, 2;\n\
7335 add.u64 %saddr, %sbase, %off;\n\
7336 st.shared.f32 [%saddr], %sum_val;\n\
7337 bar.sync 0;\n\
7338\n\
7339 mov.u32 %half, %bdim;\n\
7340SUM_REDUCE:\n\
7341 shr.u32 %half, %half, 1;\n\
7342 setp.eq.u32 %reduce_p, %half, 0;\n\
7343 @%reduce_p bra SUM_REDUCE_DONE;\n\
7344 setp.ge.u32 %reduce_p, %r_tid, %half;\n\
7345 @%reduce_p bra SUM_REDUCE_SKIP;\n\
7346 add.u32 %other_tid, %r_tid, %half;\n\
7347 cvt.u64.u32 %off, %other_tid;\n\
7348 shl.b64 %off, %off, 2;\n\
7349 add.u64 %saddr, %sbase, %off;
7350 ld.shared.f32 %other_val, [%saddr];\n\
7351 cvt.u64.u32 %off, %r_tid;\n\
7352 shl.b64 %off, %off, 2;\n\
7353 add.u64 %saddr, %sbase, %off;\n\
7354 ld.shared.f32 %sum_val, [%saddr];\n\
7355 add.f32 %sum_val, %sum_val, %other_val;\n\
7356 add.u64 %saddr, %sbase, %off;\n\
7357 st.shared.f32 [%saddr], %sum_val;\n\
7358SUM_REDUCE_SKIP:\n\
7359 bar.sync 0;\n\
7360 bra SUM_REDUCE;\n\
7361SUM_REDUCE_DONE:\n\
7362\n\
7363 ld.shared.f32 %sum_val, [sdata];\n\
7364 bar.sync 0;\n\
7365\n\
7366 rcp.approx.f32 %sum_val, %sum_val;\n\
7367 mov.u32 %j, %r_tid;\n\
7368NORMALIZE:\n\
7369 setp.ge.u32 %loop_p, %j, %cols_reg;\n\
7370 @%loop_p bra NORMALIZE_DONE;\n\
7371 cvt.u64.u32 %off, %j;\n\
7372 shl.b64 %off, %off, 2;\n\
7373 add.u64 %off, %out, %off;\n\
7374 add.u64 %off, %off, %row_off;\n\
7375 ld.global.f32 %val, [%off];\n\
7376 mul.f32 %result, %val, %sum_val;\n\
7377 st.global.f32 [%off], %result;\n\
7378 add.u32 %j, %j, %bdim;\n\
7379 bra NORMALIZE;\n\
7380NORMALIZE_DONE:\n\
7381\n\
7382DONE:\n\
7383 ret;\n\
7384}\n\
7385";
7386
7387#[cfg(feature = "cuda")]
/// PTX source for `softmax_f64_kernel`: row-wise `f64` softmax over a
/// row-major `rows x cols` matrix.
///
/// One thread block handles one row (`ctaid.x` selects the row; blocks with
/// `ctaid.x >= rows` return immediately). Thread `tid.x` visits columns
/// `tid.x, tid.x + ntid.x, ...` in three passes:
///
/// 1. row maximum (per-thread maximum, then a halving reduction in the shared
///    array `sdata`);
/// 2. `exp(x - max)` written to the output while the per-thread sum is
///    accumulated and reduced the same way;
/// 3. every stored exponential multiplied by `1 / sum` (`div.rn.f64`).
///
/// PTX has no `f64` `ex2`, so the exponential is evaluated in software:
/// `n = rint(x * log2(e))`, `r = x - n * ln(2)` using a two-part `ln(2)`
/// constant, a degree-12 Taylor polynomial in `r` evaluated with Horner's
/// scheme (coefficients `1/12!` down to `1/0!`), and a final multiplication
/// by `2^n`, which is built by writing `n + 1023` into the exponent field.
///
/// Arguments below `-700` are clamped before the range reduction and their
/// result is flushed to `0`: without this, `n + 1023` becomes non-positive
/// and the shifted bits form an unrelated (often huge, negative) scale
/// factor, and a `-inf` input would turn the reduction into NaN and poison
/// the whole row sum. NaN inputs compare false against the clamp and still
/// propagate.
///
/// Launch constraints implied by the code: `sdata` has 256 `f64` slots
/// indexed by `tid.x`, so the block size must not exceed 256, and the halving
/// reduction only covers every slot when the block size is a power of two.
pub(crate) const SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %result, %one;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .pred %e_uf;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
    mov.f64 %one, 0d3FF0000000000000;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    // Underflow guard: clamp to -700 so n + 1023 stays a valid biased\n\
    // exponent; the result is flushed to zero after the scaling below.\n\
    // NaN compares false here and keeps propagating.\n\
    setp.lt.f64 %e_uf, %val, 0dC085E00000000000;\n\
    @%e_uf mov.f64 %val, 0dC085E00000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    // n = rint(val * log2(e)); r = val - n * ln(2) with a hi/lo split of ln(2)\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    // Horner evaluation of sum r^k / k! for k = 12 .. 0\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    // 2^n: place the biased exponent n + 1023 into bits 62..52\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    @%e_uf mov.f64 %exp_val, 0d0000000000000000;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f64 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    div.rn.f64 %sum_val, %one, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    mul.f64 %result, %val, %sum_val;\n\
    st.global.f64 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7576
7577#[cfg(feature = "cuda")]
/// PTX source for `dropout_kernel`: element-wise `f32` dropout.
///
/// Kernel parameters, in order: input pointer, output pointer, element count
/// `n`, a 32-bit `threshold`, the `f32` `scale` applied to kept elements, and
/// a 32-bit `seed`.
///
/// Each thread handles exactly one element (index `ctaid.x * ntid.x + tid.x`;
/// threads with an index `>= n` return). It derives a 32-bit value from the
/// element index by multiplying with `2654435761`, XOR-ing with `seed`, and
/// applying three shift-XOR steps (`<< 13`, `>> 17`, `<< 5`). If that value
/// is (unsigned) lower than `threshold` the output is `0.0`; otherwise it is
/// `input * scale`.
///
/// NOTE(review): presumably the launcher passes `threshold` as the drop
/// probability scaled to the `u32` range and `scale = 1 / (1 - p)` — confirm
/// against the caller.
7582pub(crate) const DROPOUT_PTX: &str = "\
7583.version 7.0\n\
7584.target sm_52\n\
7585.address_size 64\n\
7586\n\
7587.visible .entry dropout_kernel(\n\
7588 .param .u64 input_ptr,\n\
7589 .param .u64 output_ptr,\n\
7590 .param .u32 n,\n\
7591 .param .u32 threshold,\n\
7592 .param .f32 scale,\n\
7593 .param .u32 seed\n\
7594) {\n\
7595 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
7596 .reg .u64 %in, %out, %off;\n\
7597 .reg .f32 %val, %scale_reg, %zero;\n\
7598 .reg .pred %p, %drop_p;\n\
7599\n\
7600 ld.param.u64 %in, [input_ptr];\n\
7601 ld.param.u64 %out, [output_ptr];\n\
7602 ld.param.u32 %n_reg, [n];\n\
7603 ld.param.u32 %thresh, [threshold];\n\
7604 ld.param.f32 %scale_reg, [scale];\n\
7605 ld.param.u32 %seed_reg, [seed];\n\
7606\n\
7607 mov.u32 %bid, %ctaid.x;\n\
7608 mov.u32 %bdim, %ntid.x;\n\
7609 mov.u32 %r_tid, %tid.x;\n\
7610 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
7611\n\
7612 setp.ge.u32 %p, %r_tid, %n_reg;\n\
7613 @%p bra DONE;\n\
7614\n\
7615 mul.lo.u32 %rng, %r_tid, 2654435761;\n\
7616 xor.b32 %rng, %rng, %seed_reg;\n\
7617 shl.b32 %tmp, %rng, 13;\n\
7618 xor.b32 %rng, %rng, %tmp;\n\
7619 shr.b32 %tmp, %rng, 17;\n\
7620 xor.b32 %rng, %rng, %tmp;\n\
7621 shl.b32 %tmp, %rng, 5;\n\
7622 xor.b32 %rng, %rng, %tmp;\n\
7623\n\
7624 cvt.u64.u32 %off, %r_tid;\n\
7625 shl.b64 %off, %off, 2;\n\
7626 add.u64 %in, %in, %off;\n\
7627 add.u64 %out, %out, %off;\n\
7628 ld.global.f32 %val, [%in];\n\
7629\n\
7630 setp.lo.u32 %drop_p, %rng, %thresh;\n\
7631 mov.f32 %zero, 0f00000000;\n\
7632 @%drop_p mov.f32 %val, %zero;\n\
7633 @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
7634\n\
7635 st.global.f32 [%out], %val;\n\
7636\n\
7637DONE:\n\
7638 ret;\n\
7639}\n\
7640";
7641
7642
7643#[cfg(feature = "cuda")]
/// PTX source for `broadcast_add_kernel`: element-wise `f32` addition of two
/// operands addressed through per-dimension strides.
///
/// Kernel parameters, in order: pointers to `a`, `b` and `out`; pointers to
/// three `u32` arrays in global memory holding `ndim` entries each
/// (`a_strides`, `b_strides`, `out_shape`); the output element count `n`; and
/// `ndim`.
///
/// Each thread handles one output element `i = ctaid.x * ntid.x + tid.x`
/// (threads with `i >= n` return). The flat index is peeled from the last
/// dimension to the first: `coord = remaining % out_shape[d]`,
/// `remaining /= out_shape[d]`, and `coord * a_strides[d]` /
/// `coord * b_strides[d]` are accumulated into the element indices of `a` and
/// `b`. The kernel then stores `a[a_idx] + b[b_idx]` to `out[i]`.
///
/// NOTE(review): presumably the launcher passes a stride of `0` for every
/// dimension along which an operand is broadcast, and strides are in
/// elements rather than bytes (the kernel shifts the final index by 2) —
/// confirm the stride-0 convention against the caller.
7666pub(crate) const BROADCAST_ADD_PTX: &str = "\
7667.version 7.0
7668.target sm_52
7669.address_size 64
7670
7671.visible .entry broadcast_add_kernel(
7672 .param .u64 a_ptr,
7673 .param .u64 b_ptr,
7674 .param .u64 out_ptr,
7675 .param .u64 a_strides_ptr,
7676 .param .u64 b_strides_ptr,
7677 .param .u64 out_shape_ptr,
7678 .param .u32 n,
7679 .param .u32 ndim
7680) {
7681 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
7682 .reg .u32 %remaining, %a_idx, %b_idx, %d;
7683 .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
7684 .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
7685 .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
7686 .reg .f32 %va, %vb, %vr;
7687 .reg .pred %p, %loop_p;
7688
7689 ld.param.u64 %a, [a_ptr];
7690 ld.param.u64 %b, [b_ptr];
7691 ld.param.u64 %out, [out_ptr];
7692 ld.param.u64 %a_str, [a_strides_ptr];
7693 ld.param.u64 %b_str, [b_strides_ptr];
7694 ld.param.u64 %oshape, [out_shape_ptr];
7695 ld.param.u32 %n_reg, [n];
7696 ld.param.u32 %ndim_reg, [ndim];
7697
7698 // Global thread index.
7699 mov.u32 %bid, %ctaid.x;
7700 mov.u32 %bdim, %ntid.x;
7701 mov.u32 %r_tid, %tid.x;
7702 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
7703
7704 setp.ge.u32 %p, %r_tid, %n_reg;
7705 @%p bra DONE;
7706
7707 // Decompose flat index into N-d coordinates and compute A/B indices.
7708 mov.u32 %remaining, %r_tid;
7709 mov.u32 %a_idx, 0;
7710 mov.u32 %b_idx, 0;
7711 mov.u32 %d, %ndim_reg;
7712
7713LOOP:
7714 setp.eq.u32 %loop_p, %d, 0;
7715 @%loop_p bra END_LOOP;
7716
7717 sub.u32 %d, %d, 1;
7718
7719 // Byte offset for dimension d: d * 4.
7720 cvt.u64.u32 %d64, %d;
7721 shl.b64 %d64, %d64, 2;
7722
7723 // Load out_shape[d].
7724 add.u64 %tmp, %oshape, %d64;
7725 ld.global.u32 %shape_d, [%tmp];
7726
7727 // Load a_strides[d] and b_strides[d].
7728 add.u64 %tmp, %a_str, %d64;
7729 ld.global.u32 %a_str_d, [%tmp];
7730 add.u64 %tmp, %b_str, %d64;
7731 ld.global.u32 %b_str_d, [%tmp];
7732
7733 // coord = remaining % shape_d; remaining /= shape_d.
7734 rem.u32 %coord, %remaining, %shape_d;
7735 div.u32 %remaining, %remaining, %shape_d;
7736
7737 // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
7738 mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
7739 mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
7740
7741 bra LOOP;
7742END_LOOP:
7743
7744 // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
7745 cvt.u64.u32 %off_a, %a_idx;
7746 shl.b64 %off_a, %off_a, 2;
7747 add.u64 %off_a, %a, %off_a;
7748 ld.global.f32 %va, [%off_a];
7749
7750 cvt.u64.u32 %off_b, %b_idx;
7751 shl.b64 %off_b, %off_b, 2;
7752 add.u64 %off_b, %b, %off_b;
7753 ld.global.f32 %vb, [%off_b];
7754
7755 // Operation: add.
7756 add.f32 %vr, %va, %vb;
7757
7758 // Store to out[tid].
7759 cvt.u64.u32 %off_out, %r_tid;
7760 shl.b64 %off_out, %off_out, 2;
7761 add.u64 %off_out, %out, %off_out;
7762 st.global.f32 [%off_out], %vr;
7763
7764DONE:
7765 ret;
7766}
7767";
7768
7769
7770#[cfg(feature = "cuda")]
/// PTX source for `broadcast_sub_kernel`: element-wise `f32` subtraction
/// (`a - b`) of two operands addressed through per-dimension strides.
///
/// Kernel parameters, in order: pointers to `a`, `b` and `out`; pointers to
/// three `u32` arrays in global memory holding `ndim` entries each
/// (`a_strides`, `b_strides`, `out_shape`); the output element count `n`; and
/// `ndim`.
///
/// Each thread handles one output element `i = ctaid.x * ntid.x + tid.x`
/// (threads with `i >= n` return). The flat index is peeled from the last
/// dimension to the first (`coord = remaining % out_shape[d]`,
/// `remaining /= out_shape[d]`) while `coord * a_strides[d]` and
/// `coord * b_strides[d]` are accumulated into the element indices of `a` and
/// `b`. The kernel then stores `a[a_idx] - b[b_idx]` to `out[i]`.
///
/// NOTE(review): presumably the launcher passes a stride of `0` for
/// broadcast dimensions — confirm against the caller.
7772pub(crate) const BROADCAST_SUB_PTX: &str = "\
7773.version 7.0
7774.target sm_52
7775.address_size 64
7776
7777.visible .entry broadcast_sub_kernel(
7778 .param .u64 a_ptr,
7779 .param .u64 b_ptr,
7780 .param .u64 out_ptr,
7781 .param .u64 a_strides_ptr,
7782 .param .u64 b_strides_ptr,
7783 .param .u64 out_shape_ptr,
7784 .param .u32 n,
7785 .param .u32 ndim
7786) {
7787 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
7788 .reg .u32 %remaining, %a_idx, %b_idx, %d;
7789 .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
7790 .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
7791 .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
7792 .reg .f32 %va, %vb, %vr;
7793 .reg .pred %p, %loop_p;
7794
7795 ld.param.u64 %a, [a_ptr];
7796 ld.param.u64 %b, [b_ptr];
7797 ld.param.u64 %out, [out_ptr];
7798 ld.param.u64 %a_str, [a_strides_ptr];
7799 ld.param.u64 %b_str, [b_strides_ptr];
7800 ld.param.u64 %oshape, [out_shape_ptr];
7801 ld.param.u32 %n_reg, [n];
7802 ld.param.u32 %ndim_reg, [ndim];
7803
7804 mov.u32 %bid, %ctaid.x;
7805 mov.u32 %bdim, %ntid.x;
7806 mov.u32 %r_tid, %tid.x;
7807 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
7808 setp.ge.u32 %p, %r_tid, %n_reg;
7809 @%p bra DONE;
7810
7811 mov.u32 %remaining, %r_tid;
7812 mov.u32 %a_idx, 0;
7813 mov.u32 %b_idx, 0;
7814 mov.u32 %d, %ndim_reg;
7815LOOP:
7816 setp.eq.u32 %loop_p, %d, 0;
7817 @%loop_p bra END_LOOP;
7818 sub.u32 %d, %d, 1;
7819 cvt.u64.u32 %d64, %d;
7820 shl.b64 %d64, %d64, 2;
7821 add.u64 %tmp, %oshape, %d64;
7822 ld.global.u32 %shape_d, [%tmp];
7823 add.u64 %tmp, %a_str, %d64;
7824 ld.global.u32 %a_str_d, [%tmp];
7825 add.u64 %tmp, %b_str, %d64;
7826 ld.global.u32 %b_str_d, [%tmp];
7827 rem.u32 %coord, %remaining, %shape_d;
7828 div.u32 %remaining, %remaining, %shape_d;
7829 mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
7830 mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
7831 bra LOOP;
7832END_LOOP:
7833
7834 cvt.u64.u32 %off_a, %a_idx;
7835 shl.b64 %off_a, %off_a, 2;
7836 add.u64 %off_a, %a, %off_a;
7837 ld.global.f32 %va, [%off_a];
7838 cvt.u64.u32 %off_b, %b_idx;
7839 shl.b64 %off_b, %off_b, 2;
7840 add.u64 %off_b, %b, %off_b;
7841 ld.global.f32 %vb, [%off_b];
7842
7843 sub.f32 %vr, %va, %vb;
7844
7845 cvt.u64.u32 %off_out, %r_tid;
7846 shl.b64 %off_out, %off_out, 2;
7847 add.u64 %off_out, %out, %off_out;
7848 st.global.f32 [%off_out], %vr;
7849DONE:
7850 ret;
7851}
7852";
7853
7854
7855#[cfg(feature = "cuda")]
/// PTX source for `broadcast_mul_kernel`: element-wise `f32` multiplication
/// of two operands addressed through per-dimension strides.
///
/// Kernel parameters, in order: pointers to `a`, `b` and `out`; pointers to
/// three `u32` arrays in global memory holding `ndim` entries each
/// (`a_strides`, `b_strides`, `out_shape`); the output element count `n`; and
/// `ndim`.
///
/// Each thread handles one output element `i = ctaid.x * ntid.x + tid.x`
/// (threads with `i >= n` return). The flat index is peeled from the last
/// dimension to the first (`coord = remaining % out_shape[d]`,
/// `remaining /= out_shape[d]`) while `coord * a_strides[d]` and
/// `coord * b_strides[d]` are accumulated into the element indices of `a` and
/// `b`. The kernel then stores `a[a_idx] * b[b_idx]` to `out[i]`.
///
/// NOTE(review): presumably the launcher passes a stride of `0` for
/// broadcast dimensions — confirm against the caller.
7857pub(crate) const BROADCAST_MUL_PTX: &str = "\
7858.version 7.0
7859.target sm_52
7860.address_size 64
7861
7862.visible .entry broadcast_mul_kernel(
7863 .param .u64 a_ptr,
7864 .param .u64 b_ptr,
7865 .param .u64 out_ptr,
7866 .param .u64 a_strides_ptr,
7867 .param .u64 b_strides_ptr,
7868 .param .u64 out_shape_ptr,
7869 .param .u32 n,
7870 .param .u32 ndim
7871) {
7872 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
7873 .reg .u32 %remaining, %a_idx, %b_idx, %d;
7874 .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
7875 .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
7876 .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
7877 .reg .f32 %va, %vb, %vr;
7878 .reg .pred %p, %loop_p;
7879
7880 ld.param.u64 %a, [a_ptr];
7881 ld.param.u64 %b, [b_ptr];
7882 ld.param.u64 %out, [out_ptr];
7883 ld.param.u64 %a_str, [a_strides_ptr];
7884 ld.param.u64 %b_str, [b_strides_ptr];
7885 ld.param.u64 %oshape, [out_shape_ptr];
7886 ld.param.u32 %n_reg, [n];
7887 ld.param.u32 %ndim_reg, [ndim];
7888
7889 mov.u32 %bid, %ctaid.x;
7890 mov.u32 %bdim, %ntid.x;
7891 mov.u32 %r_tid, %tid.x;
7892 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
7893 setp.ge.u32 %p, %r_tid, %n_reg;
7894 @%p bra DONE;
7895
7896 mov.u32 %remaining, %r_tid;
7897 mov.u32 %a_idx, 0;
7898 mov.u32 %b_idx, 0;
7899 mov.u32 %d, %ndim_reg;
7900LOOP:
7901 setp.eq.u32 %loop_p, %d, 0;
7902 @%loop_p bra END_LOOP;
7903 sub.u32 %d, %d, 1;
7904 cvt.u64.u32 %d64, %d;
7905 shl.b64 %d64, %d64, 2;
7906 add.u64 %tmp, %oshape, %d64;
7907 ld.global.u32 %shape_d, [%tmp];
7908 add.u64 %tmp, %a_str, %d64;
7909 ld.global.u32 %a_str_d, [%tmp];
7910 add.u64 %tmp, %b_str, %d64;
7911 ld.global.u32 %b_str_d, [%tmp];
7912 rem.u32 %coord, %remaining, %shape_d;
7913 div.u32 %remaining, %remaining, %shape_d;
7914 mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
7915 mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
7916 bra LOOP;
7917END_LOOP:
7918
7919 cvt.u64.u32 %off_a, %a_idx;
7920 shl.b64 %off_a, %off_a, 2;
7921 add.u64 %off_a, %a, %off_a;
7922 ld.global.f32 %va, [%off_a];
7923 cvt.u64.u32 %off_b, %b_idx;
7924 shl.b64 %off_b, %off_b, 2;
7925 add.u64 %off_b, %b, %off_b;
7926 ld.global.f32 %vb, [%off_b];
7927
7928 mul.f32 %vr, %va, %vb;
7929
7930 cvt.u64.u32 %off_out, %r_tid;
7931 shl.b64 %off_out, %off_out, 2;
7932 add.u64 %off_out, %out, %off_out;
7933 st.global.f32 [%off_out], %vr;
7934DONE:
7935 ret;
7936}
7937";
7938
7939
7940#[cfg(feature = "cuda")]
7943pub(crate) const BROADCAST_DIV_PTX: &str = "\
7944.version 7.0
7945.target sm_52
7946.address_size 64
7947
7948.visible .entry broadcast_div_kernel(
7949 .param .u64 a_ptr,
7950 .param .u64 b_ptr,
7951 .param .u64 out_ptr,
7952 .param .u64 a_strides_ptr,
7953 .param .u64 b_strides_ptr,
7954 .param .u64 out_shape_ptr,
7955 .param .u32 n,
7956 .param .u32 ndim
7957) {
7958 .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
7959 .reg .u32 %remaining, %a_idx, %b_idx, %d;
7960 .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
7961 .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
7962 .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
7963 .reg .f32 %va, %vb, %vr;
7964 .reg .pred %p, %loop_p;
7965
7966 ld.param.u64 %a, [a_ptr];
7967 ld.param.u64 %b, [b_ptr];
7968 ld.param.u64 %out, [out_ptr];
7969 ld.param.u64 %a_str, [a_strides_ptr];
7970 ld.param.u64 %b_str, [b_strides_ptr];
7971 ld.param.u64 %oshape, [out_shape_ptr];
7972 ld.param.u32 %n_reg, [n];
7973 ld.param.u32 %ndim_reg, [ndim];
7974
7975 mov.u32 %bid, %ctaid.x;
7976 mov.u32 %bdim, %ntid.x;
7977 mov.u32 %r_tid, %tid.x;
7978 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
7979 setp.ge.u32 %p, %r_tid, %n_reg;
7980 @%p bra DONE;
7981
7982 mov.u32 %remaining, %r_tid;
7983 mov.u32 %a_idx, 0;
7984 mov.u32 %b_idx, 0;
7985 mov.u32 %d, %ndim_reg;
7986LOOP:
7987 setp.eq.u32 %loop_p, %d, 0;
7988 @%loop_p bra END_LOOP;
7989 sub.u32 %d, %d, 1;
7990 cvt.u64.u32 %d64, %d;
7991 shl.b64 %d64, %d64, 2;
7992 add.u64 %tmp, %oshape, %d64;
7993 ld.global.u32 %shape_d, [%tmp];
7994 add.u64 %tmp, %a_str, %d64;
7995 ld.global.u32 %a_str_d, [%tmp];
7996 add.u64 %tmp, %b_str, %d64;
7997 ld.global.u32 %b_str_d, [%tmp];
7998 rem.u32 %coord, %remaining, %shape_d;
7999 div.u32 %remaining, %remaining, %shape_d;
8000 mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
8001 mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
8002 bra LOOP;
8003END_LOOP:
8004
8005 cvt.u64.u32 %off_a, %a_idx;
8006 shl.b64 %off_a, %off_a, 2;
8007 add.u64 %off_a, %a, %off_a;
8008 ld.global.f32 %va, [%off_a];
8009 cvt.u64.u32 %off_b, %b_idx;
8010 shl.b64 %off_b, %off_b, 2;
8011 add.u64 %off_b, %b, %off_b;
8012 ld.global.f32 %vb, [%off_b];
8013
8014 div.f32 %vr, %va, %vb;
8015
8016 cvt.u64.u32 %off_out, %r_tid;
8017 shl.b64 %off_out, %off_out, 2;
8018 add.u64 %off_out, %out, %off_out;
8019 st.global.f32 [%off_out], %vr;
8020DONE:
8021 ret;
8022}
8023";
8024
8025
8026#[cfg(feature = "cuda")]
/// PTX source for `strided_split_kernel`: copies one slice along an axis out
/// of an `f32` input into a densely packed output.
///
/// Kernel parameters, in order: input pointer, output pointer,
/// `total_along_axis` (input extent along the split axis), `split_offset`
/// (first position of the slice along that axis), `split_size` (slice extent
/// along that axis), `inner_size` (number of elements spanned by one step
/// along the axis) and `n`, the number of output elements.
///
/// Each thread handles one output element `i = ctaid.x * ntid.x + tid.x`
/// (threads with `i >= n` return) and performs
///
/// ```text
/// chunk  = split_size * inner_size
/// outer  = i / chunk
/// within = i % chunk
/// out[i] = in[outer * total_along_axis * inner_size
///             + split_offset * inner_size + within]
/// ```
///
/// All index arithmetic is 32-bit. The `%tmp` register is declared but not
/// used.
8034pub(crate) const STRIDED_SPLIT_PTX: &str = "\
8035.version 7.0
8036.target sm_52
8037.address_size 64
8038
8039.visible .entry strided_split_kernel(
8040 .param .u64 input_ptr,
8041 .param .u64 output_ptr,
8042 .param .u32 total_along_axis,
8043 .param .u32 split_offset,
8044 .param .u32 split_size,
8045 .param .u32 inner_size,
8046 .param .u32 n
8047) {
8048 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8049 .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
8050 .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
8051 .reg .u64 %in, %out, %off;
8052 .reg .f32 %val;
8053 .reg .pred %p;
8054
8055 ld.param.u64 %in, [input_ptr];
8056 ld.param.u64 %out, [output_ptr];
8057 ld.param.u32 %total_ax, [total_along_axis];
8058 ld.param.u32 %sp_off, [split_offset];
8059 ld.param.u32 %sp_sz, [split_size];
8060 ld.param.u32 %inner_sz, [inner_size];
8061 ld.param.u32 %n_reg, [n];
8062
8063 mov.u32 %bid, %ctaid.x;
8064 mov.u32 %bdim, %ntid.x;
8065 mov.u32 %r_tid, %tid.x;
8066 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8067
8068 setp.ge.u32 %p, %r_tid, %n_reg;
8069 @%p bra DONE;
8070
8071 // chunk_stride = split_size * inner_size
8072 mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;
8073
8074 // outer_idx = r_tid / chunk_stride
8075 div.u32 %outer_idx, %r_tid, %chunk_stride;
8076
8077 // within = r_tid % chunk_stride
8078 rem.u32 %within, %r_tid, %chunk_stride;
8079
8080 // base_off = split_offset * inner_size
8081 mul.lo.u32 %base_off, %sp_off, %inner_sz;
8082
8083 // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
8084 mul.lo.u32 %src_idx, %outer_idx, %total_ax;
8085 mul.lo.u32 %src_idx, %src_idx, %inner_sz;
8086 add.u32 %src_idx, %src_idx, %base_off;
8087 add.u32 %src_idx, %src_idx, %within;
8088
8089 // Load from in[src_idx]
8090 cvt.u64.u32 %off, %src_idx;
8091 shl.b64 %off, %off, 2;
8092 add.u64 %off, %in, %off;
8093 ld.global.f32 %val, [%off];
8094
8095 // Store to out[r_tid]
8096 cvt.u64.u32 %off, %r_tid;
8097 shl.b64 %off, %off, 2;
8098 add.u64 %off, %out, %off;
8099 st.global.f32 [%off], %val;
8100
8101DONE:
8102 ret;
8103}
8104";
8105
8106
8107#[cfg(feature = "cuda")]
/// PTX source for `strided_cat_kernel`: scatters a densely packed `f32` input
/// into its slot along an axis of a larger output.
///
/// Kernel parameters, in order: input pointer, output pointer,
/// `total_along_axis` (output extent along the concatenation axis),
/// `cat_offset` (first position of this part along that axis), `part_size`
/// (extent of this part along that axis), `inner_size` (number of elements
/// spanned by one step along the axis) and `n`, the number of input elements.
///
/// Each thread handles one input element `i = ctaid.x * ntid.x + tid.x`
/// (threads with `i >= n` return) and performs
///
/// ```text
/// chunk  = part_size * inner_size
/// outer  = i / chunk
/// within = i % chunk
/// out[outer * total_along_axis * inner_size
///     + cat_offset * inner_size + within] = in[i]
/// ```
///
/// All index arithmetic is 32-bit.
8116pub(crate) const STRIDED_CAT_PTX: &str = "\
8117.version 7.0
8118.target sm_52
8119.address_size 64
8120
8121.visible .entry strided_cat_kernel(
8122 .param .u64 input_ptr,
8123 .param .u64 output_ptr,
8124 .param .u32 total_along_axis,
8125 .param .u32 cat_offset,
8126 .param .u32 part_size,
8127 .param .u32 inner_size,
8128 .param .u32 n
8129) {
8130 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8131 .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
8132 .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
8133 .reg .u64 %in, %out, %off;
8134 .reg .f32 %val;
8135 .reg .pred %p;
8136
8137 ld.param.u64 %in, [input_ptr];
8138 ld.param.u64 %out, [output_ptr];
8139 ld.param.u32 %total_ax, [total_along_axis];
8140 ld.param.u32 %cat_off, [cat_offset];
8141 ld.param.u32 %part_sz, [part_size];
8142 ld.param.u32 %inner_sz, [inner_size];
8143 ld.param.u32 %n_reg, [n];
8144
8145 mov.u32 %bid, %ctaid.x;
8146 mov.u32 %bdim, %ntid.x;
8147 mov.u32 %r_tid, %tid.x;
8148 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8149
8150 setp.ge.u32 %p, %r_tid, %n_reg;
8151 @%p bra DONE;
8152
8153 // chunk_stride = part_size * inner_size
8154 mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;
8155
8156 // outer_idx = r_tid / chunk_stride
8157 div.u32 %outer_idx, %r_tid, %chunk_stride;
8158
8159 // within = r_tid % chunk_stride
8160 rem.u32 %within, %r_tid, %chunk_stride;
8161
8162 // base_off = cat_offset * inner_size
8163 mul.lo.u32 %base_off, %cat_off, %inner_sz;
8164
8165 // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
8166 mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
8167 mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
8168 add.u32 %dst_idx, %dst_idx, %base_off;
8169 add.u32 %dst_idx, %dst_idx, %within;
8170
8171 // Load from in[r_tid]
8172 cvt.u64.u32 %off, %r_tid;
8173 shl.b64 %off, %off, 2;
8174 add.u64 %off, %in, %off;
8175 ld.global.f32 %val, [%off];
8176
8177 // Store to out[dst_idx]
8178 cvt.u64.u32 %off, %dst_idx;
8179 shl.b64 %off, %off, 2;
8180 add.u64 %off, %out, %off;
8181 st.global.f32 [%off], %val;
8182
8183DONE:
8184 ret;
8185}
8186";
8187
8188
/// PTX source for `strided_copy_kernel`: copies `n` f32 elements from a
/// strided source buffer into consecutive positions of an output buffer.
///
/// Kernel parameters, in declaration order: `input_ptr` and `output_ptr`
/// (u64 addresses), `src_offset_base` and `n` (u32), then eight divisors
/// `os0..os7` and eight multipliers `ss0..ss7` (all u32).
///
/// Each thread computes `i = ctaid.x * ntid.x + tid.x` and returns when
/// `i >= n`. Starting with `flat = i` and `src_idx = src_offset_base`, it
/// runs the same step for dimensions 0 through 7:
/// `coord = flat / os_d; flat -= coord * os_d; src_idx += coord * ss_d`.
/// It then loads `in[src_idx]` and stores it to `out[i]`, using 4-byte
/// elements.
///
/// All index arithmetic is 32-bit. Every `os_d` is used as a divisor.
/// NOTE(review): callers presumably pad unused dimensions with a non-zero
/// `os` (e.g. 1) — confirm at the launch site.
8189#[cfg(feature = "cuda")]
8209pub(crate) const STRIDED_COPY_PTX: &str = "\
8210.version 7.0
8211.target sm_52
8212.address_size 64
8213
8214.visible .entry strided_copy_kernel(
8215 .param .u64 input_ptr,
8216 .param .u64 output_ptr,
8217 .param .u32 src_offset_base,
8218 .param .u32 n,
8219 .param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
8220 .param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
8221 .param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
8222 .param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
8223) {
8224 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8225 .reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
8226 .reg .u64 %in, %out, %off;
8227 .reg .f32 %val;
8228 .reg .pred %p;
8229
8230 ld.param.u64 %in, [input_ptr];
8231 ld.param.u64 %out, [output_ptr];
8232 ld.param.u32 %src_idx, [src_offset_base];
8233 ld.param.u32 %n_reg, [n];
8234
8235 mov.u32 %bid, %ctaid.x;
8236 mov.u32 %bdim, %ntid.x;
8237 mov.u32 %r_tid, %tid.x;
8238 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8239
8240 setp.ge.u32 %p, %r_tid, %n_reg;
8241 @%p bra DONE;
8242
8243 mov.u32 %flat, %r_tid;
8244
8245 // Dim 0
8246 ld.param.u32 %os, [os0];
8247 ld.param.u32 %ss, [ss0];
8248 div.u32 %coord, %flat, %os;
8249 mul.lo.u32 %tmp, %coord, %os;
8250 sub.u32 %flat, %flat, %tmp;
8251 mul.lo.u32 %tmp, %coord, %ss;
8252 add.u32 %src_idx, %src_idx, %tmp;
8253
8254 // Dim 1
8255 ld.param.u32 %os, [os1];
8256 ld.param.u32 %ss, [ss1];
8257 div.u32 %coord, %flat, %os;
8258 mul.lo.u32 %tmp, %coord, %os;
8259 sub.u32 %flat, %flat, %tmp;
8260 mul.lo.u32 %tmp, %coord, %ss;
8261 add.u32 %src_idx, %src_idx, %tmp;
8262
8263 // Dim 2
8264 ld.param.u32 %os, [os2];
8265 ld.param.u32 %ss, [ss2];
8266 div.u32 %coord, %flat, %os;
8267 mul.lo.u32 %tmp, %coord, %os;
8268 sub.u32 %flat, %flat, %tmp;
8269 mul.lo.u32 %tmp, %coord, %ss;
8270 add.u32 %src_idx, %src_idx, %tmp;
8271
8272 // Dim 3
8273 ld.param.u32 %os, [os3];
8274 ld.param.u32 %ss, [ss3];
8275 div.u32 %coord, %flat, %os;
8276 mul.lo.u32 %tmp, %coord, %os;
8277 sub.u32 %flat, %flat, %tmp;
8278 mul.lo.u32 %tmp, %coord, %ss;
8279 add.u32 %src_idx, %src_idx, %tmp;
8280
8281 // Dim 4
8282 ld.param.u32 %os, [os4];
8283 ld.param.u32 %ss, [ss4];
8284 div.u32 %coord, %flat, %os;
8285 mul.lo.u32 %tmp, %coord, %os;
8286 sub.u32 %flat, %flat, %tmp;
8287 mul.lo.u32 %tmp, %coord, %ss;
8288 add.u32 %src_idx, %src_idx, %tmp;
8289
8290 // Dim 5
8291 ld.param.u32 %os, [os5];
8292 ld.param.u32 %ss, [ss5];
8293 div.u32 %coord, %flat, %os;
8294 mul.lo.u32 %tmp, %coord, %os;
8295 sub.u32 %flat, %flat, %tmp;
8296 mul.lo.u32 %tmp, %coord, %ss;
8297 add.u32 %src_idx, %src_idx, %tmp;
8298
8299 // Dim 6
8300 ld.param.u32 %os, [os6];
8301 ld.param.u32 %ss, [ss6];
8302 div.u32 %coord, %flat, %os;
8303 mul.lo.u32 %tmp, %coord, %os;
8304 sub.u32 %flat, %flat, %tmp;
8305 mul.lo.u32 %tmp, %coord, %ss;
8306 add.u32 %src_idx, %src_idx, %tmp;
8307
8308 // Dim 7
8309 ld.param.u32 %os, [os7];
8310 ld.param.u32 %ss, [ss7];
8311 div.u32 %coord, %flat, %os;
8312 mul.lo.u32 %tmp, %coord, %os;
8313 sub.u32 %flat, %flat, %tmp;
8314 mul.lo.u32 %tmp, %coord, %ss;
8315 add.u32 %src_idx, %src_idx, %tmp;
8316
8317 // Load from in[src_idx]
8318 cvt.u64.u32 %off, %src_idx;
8319 shl.b64 %off, %off, 2;
8320 add.u64 %off, %in, %off;
8321 ld.global.f32 %val, [%off];
8322
8323 // Store to out[r_tid]
8324 cvt.u64.u32 %off, %r_tid;
8325 shl.b64 %off, %off, 2;
8326 add.u64 %off, %out, %off;
8327 st.global.f32 [%off], %val;
8328
8329DONE:
8330 ret;
8331}
8332";
8333
8334
/// PTX source for `div_kernel`: element-wise f32 division
/// `out[i] = a[i] / b[i]` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `b_ptr`, `out_ptr` (u64 addresses) and `n`
/// (u32 element count). Each thread handles the single index
/// `ctaid.x * ntid.x + tid.x` and returns immediately when that index is
/// `>= n`. The quotient is computed with `div.rn.f32`.
8335#[cfg(feature = "cuda")]
8337pub(crate) const DIV_PTX: &str = "\
8338.version 7.0
8339.target sm_52
8340.address_size 64
8341
8342.visible .entry div_kernel(
8343 .param .u64 a_ptr,
8344 .param .u64 b_ptr,
8345 .param .u64 out_ptr,
8346 .param .u32 n
8347) {
8348 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8349 .reg .u64 %a, %b, %out, %off;
8350 .reg .f32 %va, %vb, %vr;
8351 .reg .pred %p;
8352
8353 ld.param.u64 %a, [a_ptr];
8354 ld.param.u64 %b, [b_ptr];
8355 ld.param.u64 %out, [out_ptr];
8356 ld.param.u32 %n_reg, [n];
8357
8358 mov.u32 %bid, %ctaid.x;
8359 mov.u32 %bdim, %ntid.x;
8360 mov.u32 %r_tid, %tid.x;
8361 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8362
8363 setp.ge.u32 %p, %r_tid, %n_reg;
8364 @%p bra DONE;
8365
8366 cvt.u64.u32 %off, %r_tid;
8367 shl.b64 %off, %off, 2;
8368
8369 add.u64 %a, %a, %off;
8370 add.u64 %b, %b, %off;
8371 add.u64 %out, %out, %off;
8372
8373 ld.global.f32 %va, [%a];
8374 ld.global.f32 %vb, [%b];
8375 div.rn.f32 %vr, %va, %vb;
8376 st.global.f32 [%out], %vr;
8377
8378DONE:
8379 ret;
8380}
8381";
8382
8383
/// PTX source for `exp_kernel`: element-wise f32 `out[i] = exp(a[i])` for
/// `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`.
///
/// The value is computed as `ex2.approx.f32(x * log2(e))`, with `log2(e)`
/// supplied as the f32 literal `0f3FB8AA3B`. Because the approximate `ex2`
/// instruction is used, the result is an approximation of `exp`.
8384#[cfg(feature = "cuda")]
8386pub(crate) const EXP_PTX: &str = "\
8387.version 7.0
8388.target sm_52
8389.address_size 64
8390
8391.visible .entry exp_kernel(
8392 .param .u64 a_ptr,
8393 .param .u64 out_ptr,
8394 .param .u32 n
8395) {
8396 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8397 .reg .u64 %a, %out, %off;
8398 .reg .f32 %va, %vr;
8399 .reg .pred %p;
8400
8401 ld.param.u64 %a, [a_ptr];
8402 ld.param.u64 %out, [out_ptr];
8403 ld.param.u32 %n_reg, [n];
8404
8405 mov.u32 %bid, %ctaid.x;
8406 mov.u32 %bdim, %ntid.x;
8407 mov.u32 %r_tid, %tid.x;
8408 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8409
8410 setp.ge.u32 %p, %r_tid, %n_reg;
8411 @%p bra DONE;
8412
8413 cvt.u64.u32 %off, %r_tid;
8414 shl.b64 %off, %off, 2;
8415
8416 add.u64 %a, %a, %off;
8417 add.u64 %out, %out, %off;
8418
8419 ld.global.f32 %va, [%a];
8420 // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
8421 // log2(e) = 1.4426950408889634
8422 mul.f32 %va, %va, 0f3FB8AA3B;
8423 ex2.approx.f32 %vr, %va;
8424 st.global.f32 [%out], %vr;
8425
8426DONE:
8427 ret;
8428}
8429";
8430
/// PTX source for `exp_f64_kernel`: element-wise f64 `out[i] = exp(a[i])`
/// for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`. Elements are 8 bytes wide.
///
/// The computation stays in f64 throughout (no `.approx` instructions):
///
/// 1. `k = round_to_nearest(x * log2(e))`.
/// 2. `r = x - k * ln(2)`, subtracting `ln(2)` as a high part and a low part
///    in two fused multiply-adds so the reduction keeps extra precision.
/// 3. `exp(r)` is evaluated as the Taylor series through `r^12 / 12!` in
///    Horner form; every coefficient from `1/12!` down to `1/2!` appears
///    exactly once, followed by the two leading `1` terms.
/// 4. The result is multiplied by `2^k`, built by placing `k + 1023` in the
///    exponent field of a double (`(k + 1023) << 52`).
///
/// Every `0d` literal is exactly 16 hex digits, as the PTX grammar requires.
///
/// NOTE(review): step 4 only produces a valid power of two while `k + 1023`
/// fits the 11-bit normal exponent range; no clamping for very large or very
/// small inputs (or NaN) is performed here — confirm callers tolerate that.
#[cfg(feature = "cuda")]
pub(crate) const EXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %vr;
    .reg .f64 %log2e, %nf, %r;
    .reg .f64 %p, %one, %half;
    .reg .s32 %ni;
    .reg .s64 %ni64, %exp_bits;
    .reg .pred %p_tid;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];

    // Constants held in registers; the two halves of ln(2) are used below
    // as negated immediates, so they need no registers of their own.
    mov.f64 %log2e, 0d3FF71547652B82FE;   // log2(e) = 1.4426950408889634
    mov.f64 %one, 0d3FF0000000000000;     // 1.0
    mov.f64 %half, 0d3FE0000000000000;    // 0.5

    // n = round(x * log2(e))
    mul.f64 %nf, %x, %log2e;
    cvt.rni.f64.f64 %nf, %nf;             // round to nearest integer
    cvt.rni.s32.f64 %ni, %nf;             // integer n

    // r = x - n * ln2 (Cody-Waite two-step for precision)
    fma.rn.f64 %r, %nf, 0dBFE62E42FEFA3800, %x;   // r = x - n*ln2_hi
    fma.rn.f64 %r, %nf, 0dBD2EF35793C76730, %r;   // r -= n*ln2_lo

    // Horner polynomial for (exp(r) - 1 - r) / r^2
    //   = 1/2! + r*(1/3! + r*(1/4! + ... + r*(1/12!)))
    // p starts at 1/12! and accumulates down to 1/2!.
    mov.f64 %p, 0d3E21EED8EFF8D898;               // 1/12! = 2.088e-9
    fma.rn.f64 %p, %p, %r, 0d3E5AE64567F544E4;    // 1/11! = 2.505e-8
    fma.rn.f64 %p, %p, %r, 0d3E927E4FB7789F5C;    // 1/10! = 2.756e-7
    fma.rn.f64 %p, %p, %r, 0d3EC71DE3A556C734;    // 1/9!  = 2.756e-6
    fma.rn.f64 %p, %p, %r, 0d3EFA01A01A01A01A;    // 1/8!  = 2.480e-5
    fma.rn.f64 %p, %p, %r, 0d3F2A01A01A01A01A;    // 1/7!  = 1.984e-4
    fma.rn.f64 %p, %p, %r, 0d3F56C16C16C16C17;    // 1/6!  = 1.389e-3
    fma.rn.f64 %p, %p, %r, 0d3F81111111111111;    // 1/5!  = 8.333e-3
    fma.rn.f64 %p, %p, %r, 0d3FA5555555555555;    // 1/4!  = 4.167e-2
    fma.rn.f64 %p, %p, %r, 0d3FC5555555555555;    // 1/3!  = 1.667e-1
    fma.rn.f64 %p, %p, %r, %half;                 // 1/2!  = 5.000e-1

    // exp(r) = 1 + r + r^2 * p  =>  1 + r*(1 + r*p)
    fma.rn.f64 %p, %p, %r, %one;          // p = r*p + 1
    fma.rn.f64 %vr, %p, %r, %one;         // vr = p*r + 1 = exp(r)

    // Scale by 2^n: multiply by constructing the f64 bit pattern for 2^n.
    // IEEE 754 f64: 2^n has exponent field = n + 1023, no mantissa bits.
    // Bit pattern: (n + 1023) << 52.
    cvt.s64.s32 %ni64, %ni;
    add.s64 %ni64, %ni64, 1023;
    shl.b64 %exp_bits, %ni64, 52;
    mov.b64 %nf, %exp_bits;               // reinterpret as f64 = 2^n
    mul.f64 %vr, %vr, %nf;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8529
/// PTX source for `log_kernel`: element-wise f32 natural logarithm
/// `out[i] = ln(a[i])` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`.
///
/// The value is computed as `lg2.approx.f32(x) * ln(2)`, with `ln(2)`
/// supplied as the f32 literal `0f3F317218`. Because the approximate `lg2`
/// instruction is used, the result is an approximation of `ln`.
8530#[cfg(feature = "cuda")]
8532pub(crate) const LOG_PTX: &str = "\
8533.version 7.0
8534.target sm_52
8535.address_size 64
8536
8537.visible .entry log_kernel(
8538 .param .u64 a_ptr,
8539 .param .u64 out_ptr,
8540 .param .u32 n
8541) {
8542 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8543 .reg .u64 %a, %out, %off;
8544 .reg .f32 %va, %vr;
8545 .reg .pred %p;
8546
8547 ld.param.u64 %a, [a_ptr];
8548 ld.param.u64 %out, [out_ptr];
8549 ld.param.u32 %n_reg, [n];
8550
8551 mov.u32 %bid, %ctaid.x;
8552 mov.u32 %bdim, %ntid.x;
8553 mov.u32 %r_tid, %tid.x;
8554 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8555
8556 setp.ge.u32 %p, %r_tid, %n_reg;
8557 @%p bra DONE;
8558
8559 cvt.u64.u32 %off, %r_tid;
8560 shl.b64 %off, %off, 2;
8561
8562 add.u64 %a, %a, %off;
8563 add.u64 %out, %out, %off;
8564
8565 ld.global.f32 %va, [%a];
8566 // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
8567 // 1/log2(e) = ln(2) = 0.6931471805599453
8568 lg2.approx.f32 %vr, %va;
8569 mul.f32 %vr, %vr, 0f3F317218;
8570 st.global.f32 [%out], %vr;
8571
8572DONE:
8573 ret;
8574}
8575";
8576
/// PTX source for `log_f64_kernel`: element-wise f64 natural logarithm
/// `out[i] = ln(a[i])` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`. Elements are 8 bytes wide.
///
/// Algorithm:
///
/// 1. Split the input bits into the 11-bit exponent field (giving
///    `k = field - 1023`) and the 52 mantissa bits; re-biasing the mantissa
///    with exponent 1023 yields `m` in `[1, 2)`.
/// 2. `f = (m - 1) / (m + 1)`.
/// 3. `ln(m) = 2*f*(1 + f^2/3 + f^4/5 + f^6/7 + f^8/9 + f^10/11)`, evaluated
///    in Horner form on `f^2`. The series is truncated after the `f^11`
///    term.
/// 4. `ln(x) = k * ln(2) + ln(m)`.
///
/// NOTE(review): the sign bit is discarded in step 1 and the exponent field
/// is used as-is, so zero, negative, subnormal, infinite and NaN inputs get
/// no special treatment here — confirm callers do not rely on IEEE results
/// for those.
#[cfg(feature = "cuda")]
pub(crate) const LOG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .u64 %xbits, %mantissa_bits, %bias_bits;
    .reg .f64 %x, %vr, %m, %f, %f2, %s, %p;
    .reg .f64 %ln2_hi, %ln2_lo, %one, %two;
    .reg .s32 %exp_i;
    .reg .s64 %exp64;
    .reg .f64 %nf;
    .reg .pred %p_tid;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];

    mov.f64 %ln2_hi, 0d3FE62E42FEFA39EF;   // ln(2) = 0.6931471805599453
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;

    // Extract exponent: n = exponent_field - 1023
    mov.b64 %xbits, %x;
    shr.u64 %exp64, %xbits, 52;
    and.b64 %exp64, %exp64, 2047;          // 11-bit exponent field
    sub.s64 %exp64, %exp64, 1023;
    cvt.rn.f64.s64 %nf, %exp64;            // n as f64

    // Extract mantissa m: set exponent to 1023 (so m is in [1, 2))
    mov.u64 %bias_bits, 0x3FF0000000000000;                // exponent = 1023
    and.b64 %mantissa_bits, %xbits, 0x000FFFFFFFFFFFFF;    // mantissa bits
    or.b64 %mantissa_bits, %mantissa_bits, %bias_bits;
    mov.b64 %m, %mantissa_bits;            // m in [1.0, 2.0)

    // f = (m - 1) / (m + 1), which maps [1,2) to [0, 1/3)
    sub.f64 %f, %m, %one;
    add.f64 %s, %m, %one;
    div.rn.f64 %f, %f, %s;

    // ln(m) = 2*f + 2*f^3/3 + 2*f^5/5 + 2*f^7/7 + 2*f^9/9 + 2*f^11/11
    // Horner: ln(m) = 2*f*(1 + f^2*(1/3 + f^2*(1/5 + f^2*(1/7 + f^2*(1/9 + f^2/11)))))
    mul.f64 %f2, %f, %f;

    // p = 1/11
    mov.f64 %p, 0d3FB745D1745D1746;
    // p = p*f2 + 1/9
    fma.rn.f64 %p, %p, %f2, 0d3FBC71C71C71C71C;
    // p = p*f2 + 1/7
    fma.rn.f64 %p, %p, %f2, 0d3FC2492492492492;
    // p = p*f2 + 1/5
    fma.rn.f64 %p, %p, %f2, 0d3FC999999999999A;
    // p = p*f2 + 1/3
    fma.rn.f64 %p, %p, %f2, 0d3FD5555555555555;
    // p = p*f2 + 1
    fma.rn.f64 %p, %p, %f2, %one;

    // ln(m) = 2*f*p
    mul.f64 %p, %p, %f;
    add.f64 %p, %p, %p;                    // * 2

    // ln(x) = n*ln(2) + ln(m)
    fma.rn.f64 %vr, %nf, %ln2_hi, %p;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8680
/// PTX source for `sqrt_kernel`: element-wise f32 square root
/// `out[i] = sqrt(a[i])` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`. The root is computed with `sqrt.rn.f32`.
8681#[cfg(feature = "cuda")]
8683pub(crate) const SQRT_PTX: &str = "\
8684.version 7.0
8685.target sm_52
8686.address_size 64
8687
8688.visible .entry sqrt_kernel(
8689 .param .u64 a_ptr,
8690 .param .u64 out_ptr,
8691 .param .u32 n
8692) {
8693 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8694 .reg .u64 %a, %out, %off;
8695 .reg .f32 %va, %vr;
8696 .reg .pred %p;
8697
8698 ld.param.u64 %a, [a_ptr];
8699 ld.param.u64 %out, [out_ptr];
8700 ld.param.u32 %n_reg, [n];
8701
8702 mov.u32 %bid, %ctaid.x;
8703 mov.u32 %bdim, %ntid.x;
8704 mov.u32 %r_tid, %tid.x;
8705 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8706
8707 setp.ge.u32 %p, %r_tid, %n_reg;
8708 @%p bra DONE;
8709
8710 cvt.u64.u32 %off, %r_tid;
8711 shl.b64 %off, %off, 2;
8712
8713 add.u64 %a, %a, %off;
8714 add.u64 %out, %out, %off;
8715
8716 ld.global.f32 %va, [%a];
8717 sqrt.rn.f32 %vr, %va;
8718 st.global.f32 [%out], %vr;
8719
8720DONE:
8721 ret;
8722}
8723";
8724
8725
/// PTX source for `pow_kernel`: raises each f32 element to a scalar power,
/// `out[i] = a[i] ^ exponent` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses), `exponent` (f32,
/// shared by all elements) and `n` (u32). Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns when it is `>= n`.
///
/// The power is computed as `ex2.approx.f32(exponent * lg2.approx.f32(x))`,
/// so the result is an approximation.
///
/// NOTE(review): because the computation goes through `log2(x)`, negative
/// bases presumably do not produce a real-valued result even for integer
/// exponents — confirm whether callers handle that case separately.
8726#[cfg(feature = "cuda")]
8729pub(crate) const POW_PTX: &str = "\
8730.version 7.0
8731.target sm_52
8732.address_size 64
8733
8734.visible .entry pow_kernel(
8735 .param .u64 a_ptr,
8736 .param .u64 out_ptr,
8737 .param .f32 exponent,
8738 .param .u32 n
8739) {
8740 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8741 .reg .u64 %a, %out, %off;
8742 .reg .f32 %va, %vr, %exp, %lg;
8743 .reg .pred %p;
8744
8745 ld.param.u64 %a, [a_ptr];
8746 ld.param.u64 %out, [out_ptr];
8747 ld.param.f32 %exp, [exponent];
8748 ld.param.u32 %n_reg, [n];
8749
8750 mov.u32 %bid, %ctaid.x;
8751 mov.u32 %bdim, %ntid.x;
8752 mov.u32 %r_tid, %tid.x;
8753 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8754
8755 setp.ge.u32 %p, %r_tid, %n_reg;
8756 @%p bra DONE;
8757
8758 cvt.u64.u32 %off, %r_tid;
8759 shl.b64 %off, %off, 2;
8760
8761 add.u64 %a, %a, %off;
8762 add.u64 %out, %out, %off;
8763
8764 ld.global.f32 %va, [%a];
8765 // x^e = 2^(e * log2(x))
8766 lg2.approx.f32 %lg, %va;
8767 mul.f32 %lg, %lg, %exp;
8768 ex2.approx.f32 %vr, %lg;
8769 st.global.f32 [%out], %vr;
8770
8771DONE:
8772 ret;
8773}
8774";
8775
/// PTX source for `pow_f64_kernel`: raises each f64 element to a scalar
/// power, `out[i] = a[i] ^ exponent` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses), `exponent` (f64,
/// shared by all elements) and `n` (u32). Each thread handles index
/// `ctaid.x * ntid.x + tid.x` and returns when it is `>= n`. Elements are
/// 8 bytes wide.
///
/// The power is computed as `exp(exponent * ln(x))`, entirely in f64:
///
/// * `ln(x)`: the input is split into exponent `k` and mantissa `m` in
///   `[1, 2)`; with `f = (m - 1) / (m + 1)`,
///   `ln(m) = 2*f*(1 + f^2/3 + f^4/5 + f^6/7 + f^8/9 + f^10/11)` and
///   `ln(x) = k*ln(2) + ln(m)`.
/// * `exp(z)`: `k = floor(z*log2(e) + 0.5)`, `r = z - k*ln(2)` via a hi/lo
///   split of `ln(2)`, then the Taylor series of `exp(r)` through
///   `r^12 / 12!` in Horner form, scaled by `2^k` built as
///   `(k + 1023) << 52`.
///
/// Every `0d` literal is exactly 16 hex digits, as the PTX grammar requires.
///
/// NOTE(review): the logarithm ignores the sign bit and does not special-case
/// zero, subnormal, infinite or NaN inputs, and the `2^k` construction is not
/// clamped — confirm callers do not depend on IEEE `pow` edge-case results.
#[cfg(feature = "cuda")]
pub(crate) const POW_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f64 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %exp64, %one, %two;
    // log registers
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2, %l_lnx;
    // exp registers
    .reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f64 %exp64, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;

    // === ln(va) via argument reduction ===
    // Decompose va = 2^n * m, m in [1,2), ln(va) = n*ln(2) + ln(m)
    mov.b64 %l_xbits, %va;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;

    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;

    // f = (m-1)/(m+1)
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;

    // Horner: p = 1 + f2*(1/3 + f2*(1/5 + f2*(1/7 + f2*(1/9 + f2*(1/11)))))
    mov.f64 %l_p, 0d3FB745D1745D1746;                     // 1/11
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;     // 1/9
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;     // 1/7
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;     // 1/5
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;     // 1/3
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;

    // ln(m) = 2*f*p
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;

    // ln(x) = n*ln(2) + ln(m)
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %l_lnx, %l_nf, %l_ln2, %l_p;

    // === exp(exponent * ln(x)) ===
    mul.f64 %e_z, %exp64, %l_lnx;

    // n = floor(z*log2(e) + 0.5); r = z - n*ln2 (hi part, then lo part)
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Taylor coefficients 1/12! down to 1/2!, then the two leading ones
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                     // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;      // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;      // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;      // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;      // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;      // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;      // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;      // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;      // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;      // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;                 // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %vr, %e_p, %e_r, %one;
    // Scale by 2^n via the bit pattern (n + 1023) << 52
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %vr, %vr, %e_nf;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8896
/// PTX source for `abs_kernel`: element-wise f32 absolute value
/// `out[i] = |a[i]|` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`. The value is computed with `abs.f32`.
8897#[cfg(feature = "cuda")]
8899pub(crate) const ABS_PTX: &str = "\
8900.version 7.0
8901.target sm_52
8902.address_size 64
8903
8904.visible .entry abs_kernel(
8905 .param .u64 a_ptr,
8906 .param .u64 out_ptr,
8907 .param .u32 n
8908) {
8909 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8910 .reg .u64 %a, %out, %off;
8911 .reg .f32 %va, %vr;
8912 .reg .pred %p;
8913
8914 ld.param.u64 %a, [a_ptr];
8915 ld.param.u64 %out, [out_ptr];
8916 ld.param.u32 %n_reg, [n];
8917
8918 mov.u32 %bid, %ctaid.x;
8919 mov.u32 %bdim, %ntid.x;
8920 mov.u32 %r_tid, %tid.x;
8921 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8922
8923 setp.ge.u32 %p, %r_tid, %n_reg;
8924 @%p bra DONE;
8925
8926 cvt.u64.u32 %off, %r_tid;
8927 shl.b64 %off, %off, 2;
8928
8929 add.u64 %a, %a, %off;
8930 add.u64 %out, %out, %off;
8931
8932 ld.global.f32 %va, [%a];
8933 abs.f32 %vr, %va;
8934 st.global.f32 [%out], %vr;
8935
8936DONE:
8937 ret;
8938}
8939";
8940
8941
/// PTX source for `sigmoid_kernel`: element-wise f32 logistic function
/// `out[i] = 1 / (1 + exp(-a[i]))` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`.
///
/// `exp(-x)` is obtained as `ex2.approx.f32(-x * log2(e))`, with `log2(e)`
/// given by the literal `0f3FB8AA3B`; the final reciprocal uses
/// `div.rn.f32`. The exponential is therefore approximate.
8942#[cfg(feature = "cuda")]
8944pub(crate) const SIGMOID_PTX: &str = "\
8945.version 7.0
8946.target sm_52
8947.address_size 64
8948
8949.visible .entry sigmoid_kernel(
8950 .param .u64 a_ptr,
8951 .param .u64 out_ptr,
8952 .param .u32 n
8953) {
8954 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
8955 .reg .u64 %a, %out, %off;
8956 .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
8957 .reg .pred %p;
8958
8959 ld.param.u64 %a, [a_ptr];
8960 ld.param.u64 %out, [out_ptr];
8961 ld.param.u32 %n_reg, [n];
8962
8963 mov.u32 %bid, %ctaid.x;
8964 mov.u32 %bdim, %ntid.x;
8965 mov.u32 %r_tid, %tid.x;
8966 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
8967
8968 setp.ge.u32 %p, %r_tid, %n_reg;
8969 @%p bra DONE;
8970
8971 cvt.u64.u32 %off, %r_tid;
8972 shl.b64 %off, %off, 2;
8973
8974 add.u64 %a, %a, %off;
8975 add.u64 %out, %out, %off;
8976
8977 ld.global.f32 %va, [%a];
8978 // sigmoid(x) = 1 / (1 + exp(-x))
8979 neg.f32 %neg, %va;
8980 mov.f32 %lg2e, 0f3FB8AA3B;
8981 mul.f32 %neg, %neg, %lg2e;
8982 ex2.approx.f32 %e, %neg;
8983 mov.f32 %one, 0f3F800000;
8984 add.f32 %denom, %one, %e;
8985 div.rn.f32 %vr, %one, %denom;
8986 st.global.f32 [%out], %vr;
8987
8988DONE:
8989 ret;
8990}
8991";
8992
/// PTX source for `sigmoid_f64_kernel`: element-wise f64 logistic function
/// `out[i] = 1 / (1 + exp(-a[i]))` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`. Elements are 8 bytes wide.
///
/// `exp(-x)` is computed in f64 without `.approx` instructions:
/// `k = floor(z*log2(e) + 0.5)`, `r = z - k*ln(2)` via a hi/lo split of
/// `ln(2)`, the Taylor series of `exp(r)` through `r^12 / 12!` in Horner
/// form, then a multiply by `2^k` built as `(k + 1023) << 52`. The final
/// reciprocal uses `div.rn.f64`.
///
/// Every `0d` literal is exactly 16 hex digits, as the PTX grammar requires.
///
/// NOTE(review): the `2^k` construction is not clamped, so inputs of very
/// large magnitude may not saturate cleanly to 0 or 1 — confirm the expected
/// input range with callers.
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %denom, %one, %neg_x;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f64 %neg_x, %va;

    // --- exp(%neg_x) via Cody-Waite + degree-12 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                     // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;      // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;      // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;      // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;      // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;      // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;      // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;      // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;      // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;      // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;                 // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e64, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e64, %e64, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %e64;
    div.rn.f64 %vr, %one, %denom;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
9073
/// PTX source for `tanh_kernel`: element-wise f32 hyperbolic tangent
/// `out[i] = tanh(a[i])` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`.
///
/// The kernel uses the identity `tanh(x) = 2 * sigmoid(2x) - 1`, where
/// `sigmoid(2x) = 1 / (1 + exp(-2x))` and `exp(-2x)` is obtained as
/// `ex2.approx.f32(-2x * log2(e))`. The exponential is therefore
/// approximate.
9074#[cfg(feature = "cuda")]
9077pub(crate) const TANH_PTX: &str = "\
9078.version 7.0
9079.target sm_52
9080.address_size 64
9081
9082.visible .entry tanh_kernel(
9083 .param .u64 a_ptr,
9084 .param .u64 out_ptr,
9085 .param .u32 n
9086) {
9087 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
9088 .reg .u64 %a, %out, %off;
9089 .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
9090 .reg .pred %p;
9091
9092 ld.param.u64 %a, [a_ptr];
9093 ld.param.u64 %out, [out_ptr];
9094 ld.param.u32 %n_reg, [n];
9095
9096 mov.u32 %bid, %ctaid.x;
9097 mov.u32 %bdim, %ntid.x;
9098 mov.u32 %r_tid, %tid.x;
9099 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
9100
9101 setp.ge.u32 %p, %r_tid, %n_reg;
9102 @%p bra DONE;
9103
9104 cvt.u64.u32 %off, %r_tid;
9105 shl.b64 %off, %off, 2;
9106
9107 add.u64 %a, %a, %off;
9108 add.u64 %out, %out, %off;
9109
9110 ld.global.f32 %va, [%a];
9111 // tanh(x) = 2*sigmoid(2x) - 1
9112 mov.f32 %two, 0f40000000;
9113 mul.f32 %neg2x, %va, %two;
9114 neg.f32 %neg2x, %neg2x;
9115 mov.f32 %lg2e, 0f3FB8AA3B;
9116 mul.f32 %neg2x, %neg2x, %lg2e;
9117 ex2.approx.f32 %e, %neg2x;
9118 mov.f32 %one, 0f3F800000;
9119 add.f32 %denom, %one, %e;
9120 div.rn.f32 %sig, %one, %denom;
9121 mul.f32 %vr, %two, %sig;
9122 sub.f32 %vr, %vr, %one;
9123 st.global.f32 [%out], %vr;
9124
9125DONE:
9126 ret;
9127}
9128";
9129
/// PTX source for `tanh_f64_kernel`: element-wise f64 hyperbolic tangent
/// `out[i] = tanh(a[i])` for `i < n`.
///
/// Kernel parameters: `a_ptr`, `out_ptr` (u64 addresses) and `n` (u32).
/// Each thread handles index `ctaid.x * ntid.x + tid.x` and returns when it
/// is `>= n`. Elements are 8 bytes wide.
///
/// The kernel uses `tanh(x) = (1 - e) / (1 + e)` with `e = exp(-2x)`.
/// `exp` is computed in f64 without `.approx` instructions:
/// `k = floor(z*log2(e) + 0.5)`, `r = z - k*ln(2)` via a hi/lo split of
/// `ln(2)`, the Taylor series of `exp(r)` through `r^12 / 12!` in Horner
/// form, then a multiply by `2^k` built as `(k + 1023) << 52`.
///
/// Every `0d` literal is exactly 16 hex digits, as the PTX grammar requires.
///
/// NOTE(review): the `2^k` construction is not clamped, so inputs of very
/// large magnitude may not saturate cleanly to -1 or 1 — confirm the
/// expected input range with callers.
#[cfg(feature = "cuda")]
pub(crate) const TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %num, %denom, %one, %two, %neg2x;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;

    // tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
    mul.f64 %neg2x, %va, %two;
    neg.f64 %neg2x, %neg2x;

    // --- exp(%neg2x) via Cody-Waite + degree-12 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg2x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg2x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                     // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;      // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;      // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;      // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;      // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;      // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;      // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;      // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;      // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;      // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;                 // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e64, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e64, %e64, %e_nf;
    // --- end exp ---

    sub.f64 %num, %one, %e64;
    add.f64 %denom, %one, %e64;
    div.rn.f64 %vr, %num, %denom;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
9213
/// PTX source for `fused_adam_kernel`: one Adam-style optimizer step applied
/// element-wise to `n` f32 values in a single kernel.
///
/// Kernel parameters, in declaration order: `param_ptr`, `grad_ptr`,
/// `exp_avg_ptr`, `exp_avg_sq_ptr` (u64 addresses); `beta1`, `beta2`, `lr`,
/// `eps`, `bc1`, `bc2`, `weight_decay` (f32); `n` (u32).
///
/// Each thread handles index `ctaid.x * ntid.x + tid.x`, returns when it is
/// `>= n`, and otherwise performs, for its element:
///
/// 1. if `weight_decay > 0`: `g = g + weight_decay * p`
/// 2. `m = beta1 * m + (1 - beta1) * g`
/// 3. `v = beta2 * v + (1 - beta2) * g * g`
/// 4. `m_hat = m / bc1`, `v_hat = v / bc2`
/// 5. `p = p - lr * m_hat / (sqrt(v_hat) + eps)`
///
/// The updated `p`, `m` and `v` are stored back to the addresses they were
/// loaded from; the gradient buffer is only read.
///
/// The `%one` register is first loaded with 0.0 for the weight-decay
/// comparison and only afterwards set to 1.0.
///
/// NOTE(review): `bc1` and `bc2` are presumably the bias-correction factors
/// `1 - beta^step` computed on the host — confirm against the launch code.
9214#[cfg(feature = "cuda")]
9224pub(crate) const FUSED_ADAM_PTX: &str = "\
9225.version 7.0
9226.target sm_52
9227.address_size 64
9228
9229.visible .entry fused_adam_kernel(
9230 .param .u64 param_ptr,
9231 .param .u64 grad_ptr,
9232 .param .u64 exp_avg_ptr,
9233 .param .u64 exp_avg_sq_ptr,
9234 .param .f32 beta1,
9235 .param .f32 beta2,
9236 .param .f32 lr,
9237 .param .f32 eps,
9238 .param .f32 bc1,
9239 .param .f32 bc2,
9240 .param .f32 weight_decay,
9241 .param .u32 n
9242) {
9243 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
9244 .reg .u64 %p, %g, %m, %v, %off;
9245 .reg .f32 %vp, %vg, %vm, %vv;
9246 .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
9247 .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
9248 .reg .f32 %one;
9249 .reg .pred %p_bound, %p_wd;
9250
9251 ld.param.u64 %p, [param_ptr];
9252 ld.param.u64 %g, [grad_ptr];
9253 ld.param.u64 %m, [exp_avg_ptr];
9254 ld.param.u64 %v, [exp_avg_sq_ptr];
9255 ld.param.f32 %b1, [beta1];
9256 ld.param.f32 %b2, [beta2];
9257 ld.param.f32 %f_lr, [lr];
9258 ld.param.f32 %f_eps, [eps];
9259 ld.param.f32 %f_bc1, [bc1];
9260 ld.param.f32 %f_bc2, [bc2];
9261 ld.param.f32 %f_wd, [weight_decay];
9262 ld.param.u32 %n_reg, [n];
9263
9264 mov.u32 %bid, %ctaid.x;
9265 mov.u32 %bdim, %ntid.x;
9266 mov.u32 %r_tid, %tid.x;
9267 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
9268
9269 setp.ge.u32 %p_bound, %r_tid, %n_reg;
9270 @%p_bound bra DONE;
9271
9272 cvt.u64.u32 %off, %r_tid;
9273 shl.b64 %off, %off, 2;
9274
9275 add.u64 %p, %p, %off;
9276 add.u64 %g, %g, %off;
9277 add.u64 %m, %m, %off;
9278 add.u64 %v, %v, %off;
9279
9280 ld.global.f32 %vp, [%p];
9281 ld.global.f32 %vg, [%g];
9282 ld.global.f32 %vm, [%m];
9283 ld.global.f32 %vv, [%v];
9284
9285 // L2 weight decay: g = g + wd * p
9286 mov.f32 %one, 0f00000000;
9287 setp.gt.f32 %p_wd, %f_wd, %one;
9288 @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;
9289
9290 // exp_avg = beta1 * exp_avg + (1 - beta1) * g
9291 mov.f32 %one, 0f3F800000;
9292 sub.f32 %t1, %one, %b1;
9293 mul.f32 %vm, %vm, %b1;
9294 fma.rn.f32 %vm, %t1, %vg, %vm;
9295
9296 // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
9297 sub.f32 %t2, %one, %b2;
9298 mul.f32 %vv, %vv, %b2;
9299 mul.f32 %t1, %vg, %vg;
9300 fma.rn.f32 %vv, %t2, %t1, %vv;
9301
9302 // m_hat = exp_avg / bc1
9303 div.rn.f32 %m_hat, %vm, %f_bc1;
9304
9305 // v_hat = exp_avg_sq / bc2
9306 div.rn.f32 %v_hat, %vv, %f_bc2;
9307
9308 // denom = sqrt(v_hat) + eps
9309 sqrt.rn.f32 %denom, %v_hat;
9310 add.f32 %denom, %denom, %f_eps;
9311
9312 // param = param - lr * m_hat / denom
9313 div.rn.f32 %update, %m_hat, %denom;
9314 mul.f32 %update, %update, %f_lr;
9315 sub.f32 %vp, %vp, %update;
9316
9317 st.global.f32 [%p], %vp;
9318 st.global.f32 [%m], %vm;
9319 st.global.f32 [%v], %vv;
9320
9321DONE:
9322 ret;
9323}
9324";
9325
9326#[cfg(feature = "cuda")]
9338pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
9339.version 7.0
9340.target sm_52
9341.address_size 64
9342
9343.visible .entry fused_gru_forward_kernel(
9344 .param .u64 input_gates_ptr,
9345 .param .u64 hidden_gates_ptr,
9346 .param .u64 bias_ih_ptr,
9347 .param .u64 bias_hh_ptr,
9348 .param .u64 hx_ptr,
9349 .param .u64 hy_ptr,
9350 .param .u64 workspace_ptr,
9351 .param .u32 hsz,
9352 .param .u32 total
9353) {
9354 .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
9355 .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
9356 .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
9357 .reg .u64 %off64, %tmp64;
9358 .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
9359 .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
9360 .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
9361 .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
9362 .reg .pred %p;
9363
9364 ld.param.u64 %ig, [input_gates_ptr];
9365 ld.param.u64 %hg, [hidden_gates_ptr];
9366 ld.param.u64 %b1, [bias_ih_ptr];
9367 ld.param.u64 %b2, [bias_hh_ptr];
9368 ld.param.u64 %hx, [hx_ptr];
9369 ld.param.u64 %hy, [hy_ptr];
9370 ld.param.u64 %ws, [workspace_ptr];
9371 ld.param.u32 %hsz_reg, [hsz];
9372 ld.param.u32 %total_reg, [total];
9373
9374 mov.u32 %bid, %ctaid.x;
9375 mov.u32 %bdim, %ntid.x;
9376 mov.u32 %tid, %tid.x;
9377 mov.u32 %gdim, %nctaid.x;
9378 mad.lo.u32 %idx, %bid, %bdim, %tid;
9379 mul.lo.u32 %stride, %bdim, %gdim;
9380 mov.f32 %one, 0f3F800000;
9381
9382LOOP:
9383 setp.ge.u32 %p, %idx, %total_reg;
9384 @%p bra END;
9385
9386 // offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
9387 div.u32 %batch_idx, %idx, %hsz_reg;
9388 rem.u32 %hmod, %idx, %hsz_reg;
9389 mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
9390 mul.lo.u32 %offset3, %offset3, 3;
9391 add.u32 %offset3, %offset3, %hmod;
9392
9393 // Load input gate components: ir, ii, in
9394 cvt.u64.u32 %off64, %offset3;
9395 shl.b64 %off64, %off64, 2;
9396 add.u64 %tmp64, %ig, %off64;
9397 ld.global.f32 %ir, [%tmp64];
9398 cvt.u64.u32 %off64, %hsz_reg;
9399 shl.b64 %off64, %off64, 2;
9400 add.u64 %tmp64, %tmp64, %off64;
9401 ld.global.f32 %ii, [%tmp64];
9402 add.u64 %tmp64, %tmp64, %off64;
9403 ld.global.f32 %in, [%tmp64];
9404
9405 // Load hidden gate components: hr, hi, hn
9406 cvt.u64.u32 %off64, %offset3;
9407 shl.b64 %off64, %off64, 2;
9408 add.u64 %tmp64, %hg, %off64;
9409 ld.global.f32 %hr, [%tmp64];
9410 cvt.u64.u32 %off64, %hsz_reg;
9411 shl.b64 %off64, %off64, 2;
9412 add.u64 %tmp64, %tmp64, %off64;
9413 ld.global.f32 %hi, [%tmp64];
9414 add.u64 %tmp64, %tmp64, %off64;
9415 ld.global.f32 %hn, [%tmp64];
9416
9417 // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
9418 cvt.u64.u32 %off64, %hmod;
9419 shl.b64 %off64, %off64, 2;
9420 add.u64 %tmp64, %b1, %off64;
9421 ld.global.f32 %b1r, [%tmp64];
9422 cvt.u64.u32 %off64, %hsz_reg;
9423 shl.b64 %off64, %off64, 2;
9424 add.u64 %tmp64, %tmp64, %off64;
9425 ld.global.f32 %b1i, [%tmp64];
9426 add.u64 %tmp64, %tmp64, %off64;
9427 ld.global.f32 %b1n, [%tmp64];
9428
9429 cvt.u64.u32 %off64, %hmod;
9430 shl.b64 %off64, %off64, 2;
9431 add.u64 %tmp64, %b2, %off64;
9432 ld.global.f32 %b2r, [%tmp64];
9433 cvt.u64.u32 %off64, %hsz_reg;
9434 shl.b64 %off64, %off64, 2;
9435 add.u64 %tmp64, %tmp64, %off64;
9436 ld.global.f32 %b2i, [%tmp64];
9437 add.u64 %tmp64, %tmp64, %off64;
9438 ld.global.f32 %b2n, [%tmp64];
9439
9440 // Load hx[idx]
9441 cvt.u64.u32 %off64, %idx;
9442 shl.b64 %off64, %off64, 2;
9443 add.u64 %tmp64, %hx, %off64;
9444 ld.global.f32 %hx_val, [%tmp64];
9445
9446 // r = sigmoid(ir + hr + b1r + b2r)
9447 add.f32 %rg, %ir, %hr;
9448 add.f32 %rg, %rg, %b1r;
9449 add.f32 %rg, %rg, %b2r;
9450 neg.f32 %tmp, %rg;
9451 mul.f32 %tmp, %tmp, 0f3FB8AA3B;
9452 ex2.approx.f32 %exp_val, %tmp;
9453 add.f32 %denom, %one, %exp_val;
9454 div.rn.f32 %rg, %one, %denom;
9455
9456 // z = sigmoid(ii + hi + b1i + b2i)
9457 add.f32 %zg, %ii, %hi;
9458 add.f32 %zg, %zg, %b1i;
9459 add.f32 %zg, %zg, %b2i;
9460 neg.f32 %tmp, %zg;
9461 mul.f32 %tmp, %tmp, 0f3FB8AA3B;
9462 ex2.approx.f32 %exp_val, %tmp;
9463 add.f32 %denom, %one, %exp_val;
9464 div.rn.f32 %zg, %one, %denom;
9465
9466 // n = tanh(in + b1n + r*(hn + b2n))
9467 add.f32 %tmp, %hn, %b2n;
9468 fma.rn.f32 %ng, %rg, %tmp, %in;
9469 add.f32 %ng, %ng, %b1n;
9470 // tanh via 2*sigmoid(2x)-1
9471 mul.f32 %tmp, %ng, 0f40000000;
9472 neg.f32 %tmp, %tmp;
9473 mul.f32 %tmp, %tmp, 0f3FB8AA3B;
9474 ex2.approx.f32 %exp_val, %tmp;
9475 add.f32 %denom, %one, %exp_val;
9476 div.rn.f32 %ng, %one, %denom;
9477 mul.f32 %ng, %ng, 0f40000000;
9478 sub.f32 %ng, %ng, %one;
9479
9480 // hy = n + z * (hx - n)
9481 sub.f32 %tmp, %hx_val, %ng;
9482 fma.rn.f32 %hy_val, %zg, %tmp, %ng;
9483
9484 // Store hy[idx]
9485 cvt.u64.u32 %off64, %idx;
9486 shl.b64 %off64, %off64, 2;
9487 add.u64 %tmp64, %hy, %off64;
9488 st.global.f32 [%tmp64], %hy_val;
9489
9490 // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
9491 mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
9492 mul.lo.u32 %offset5, %offset5, 5;
9493 add.u32 %offset5, %offset5, %hmod;
9494
9495 cvt.u64.u32 %off64, %offset5;
9496 shl.b64 %off64, %off64, 2;
9497 add.u64 %tmp64, %ws, %off64;
9498 st.global.f32 [%tmp64], %rg;
9499 cvt.u64.u32 %off64, %hsz_reg;
9500 shl.b64 %off64, %off64, 2;
9501 add.u64 %tmp64, %tmp64, %off64;
9502 st.global.f32 [%tmp64], %zg;
9503 add.u64 %tmp64, %tmp64, %off64;
9504 st.global.f32 [%tmp64], %ng;
9505 add.u64 %tmp64, %tmp64, %off64;
9506 st.global.f32 [%tmp64], %hx_val;
9507 add.u64 %tmp64, %tmp64, %off64;
9508 add.f32 %tmp, %hn, %b2n;
9509 st.global.f32 [%tmp64], %tmp;
9510
9511 add.u32 %idx, %idx, %stride;
9512 bra LOOP;
9513
9514END:
9515 ret;
9516}
9517";
9518
9519#[cfg(feature = "cuda")]
9533fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
9534 if n > u32::MAX as usize {
9535 return Err(GpuError::ShapeMismatch {
9536 op: "kernel_launch",
9537 expected: vec![u32::MAX as usize],
9538 got: vec![n],
9539 });
9540 }
9541 const BLOCK: u32 = 256;
9542 let grid = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
9543 Ok(LaunchConfig {
9544 grid_dim: (grid.max(1), 1, 1),
9545 block_dim: (BLOCK, 1, 1),
9546 shared_mem_bytes: 0,
9547 })
9548}
9549
9550#[cfg(feature = "cuda")]
9556fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
9557 if a.device_ordinal() != device.ordinal() {
9558 return Err(GpuError::DeviceMismatch {
9559 expected: a.device_ordinal(),
9560 got: device.ordinal(),
9561 });
9562 }
9563 if b.device_ordinal() != device.ordinal() {
9564 return Err(GpuError::DeviceMismatch {
9565 expected: b.device_ordinal(),
9566 got: device.ordinal(),
9567 });
9568 }
9569 if a.len() != b.len() {
9570 return Err(GpuError::LengthMismatch {
9571 a: a.len(),
9572 b: b.len(),
9573 });
9574 }
9575 Ok(())
9576}
9577
9578#[cfg(feature = "cuda")]
9580fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
9581 if a.device_ordinal() != device.ordinal() {
9582 return Err(GpuError::DeviceMismatch {
9583 expected: a.device_ordinal(),
9584 got: device.ordinal(),
9585 });
9586 }
9587 Ok(())
9588}
9589
9590#[cfg(feature = "cuda")]
9592fn validate_device<T>(a: &CudaBuffer<T>, device: &GpuDevice) -> GpuResult<()> {
9593 if a.device_ordinal() != device.ordinal() {
9594 return Err(GpuError::DeviceMismatch {
9595 expected: a.device_ordinal(),
9596 got: device.ordinal(),
9597 });
9598 }
9599 Ok(())
9600}
9601
9602#[cfg(feature = "cuda")]
9610fn try_launch_binary(
9611 a: &CudaBuffer<f32>,
9612 b: &CudaBuffer<f32>,
9613 device: &GpuDevice,
9614 ptx_src: &'static str,
9615 kernel_name: &'static str,
9616) -> GpuResult<Option<CudaBuffer<f32>>> {
9617 use cudarc::driver::PushKernelArg;
9618
9619 let n = a.len();
9620 let ctx = device.context();
9621 let stream = device.stream();
9622
9623 let f = match crate::module_cache::get_or_compile(
9627 ctx,
9628 ptx_src,
9629 kernel_name,
9630 device.ordinal() as u32,
9631 ) {
9632 Ok(f) => f,
9633 Err(_) => return Ok(None),
9634 };
9635
9636 let mut out = alloc_zeros_f32(n, device)?;
9637 let cfg = launch_cfg(n)?;
9638 let n_u32 = n as u32;
9639
9640 unsafe {
9644 stream
9645 .launch_builder(&f)
9646 .arg(a.inner())
9647 .arg(b.inner())
9648 .arg(out.inner_mut())
9649 .arg(&n_u32)
9650 .launch(cfg)?;
9651 }
9652
9653 Ok(Some(out))
9654}
9655
9656#[cfg(feature = "cuda")]
9661fn try_launch_binary_vec4(
9662 a: &CudaBuffer<f32>,
9663 b: &CudaBuffer<f32>,
9664 device: &GpuDevice,
9665 ptx_src: &'static str,
9666 kernel_name: &'static str,
9667) -> GpuResult<Option<CudaBuffer<f32>>> {
9668 use cudarc::driver::PushKernelArg;
9669
9670 let n = a.len();
9671 let n4 = (n / 4) as u32;
9672 let ctx = device.context();
9673 let stream = device.stream();
9674
9675 let f = match crate::module_cache::get_or_compile(
9676 ctx,
9677 ptx_src,
9678 kernel_name,
9679 device.ordinal() as u32,
9680 ) {
9681 Ok(f) => f,
9682 Err(_) => return Ok(None),
9683 };
9684
9685 let mut out = alloc_zeros_f32(n, device)?;
9686 let cfg = launch_cfg(n4 as usize)?;
9687
9688 unsafe {
9689 stream
9690 .launch_builder(&f)
9691 .arg(a.inner())
9692 .arg(b.inner())
9693 .arg(out.inner_mut())
9694 .arg(&n4)
9695 .launch(cfg)?;
9696 }
9697
9698 Ok(Some(out))
9699}
9700
9701#[cfg(feature = "cuda")]
9704fn try_launch_unary(
9705 a: &CudaBuffer<f32>,
9706 device: &GpuDevice,
9707 ptx_src: &'static str,
9708 kernel_name: &'static str,
9709) -> GpuResult<Option<CudaBuffer<f32>>> {
9710 use cudarc::driver::PushKernelArg;
9711
9712 let n = a.len();
9713 let ctx = device.context();
9714 let stream = device.stream();
9715
9716 let f = match crate::module_cache::get_or_compile(
9718 ctx,
9719 ptx_src,
9720 kernel_name,
9721 device.ordinal() as u32,
9722 ) {
9723 Ok(f) => f,
9724 Err(_) => return Ok(None),
9725 };
9726
9727 let mut out = alloc_zeros_f32(n, device)?;
9728 let cfg = launch_cfg(n)?;
9729 let n_u32 = n as u32;
9730
9731 unsafe {
9734 stream
9735 .launch_builder(&f)
9736 .arg(a.inner())
9737 .arg(out.inner_mut())
9738 .arg(&n_u32)
9739 .launch(cfg)?;
9740 }
9741
9742 Ok(Some(out))
9743}
9744
9745#[cfg(feature = "cuda")]
9752fn try_launch_binary_into(
9753 a: &CudaBuffer<f32>,
9754 b: &CudaBuffer<f32>,
9755 out: &mut CudaBuffer<f32>,
9756 device: &GpuDevice,
9757 ptx_src: &'static str,
9758 kernel_name: &'static str,
9759) -> GpuResult<bool> {
9760 use cudarc::driver::PushKernelArg;
9761
9762 let n = a.len();
9763 let ctx = device.context();
9764 let stream = device.stream();
9765
9766 let f = match crate::module_cache::get_or_compile(
9767 ctx,
9768 ptx_src,
9769 kernel_name,
9770 device.ordinal() as u32,
9771 ) {
9772 Ok(f) => f,
9773 Err(_) => return Ok(false),
9774 };
9775
9776 let cfg = launch_cfg(n)?;
9777 let n_u32 = n as u32;
9778
9779 unsafe {
9780 stream
9781 .launch_builder(&f)
9782 .arg(a.inner())
9783 .arg(b.inner())
9784 .arg(out.inner_mut())
9785 .arg(&n_u32)
9786 .launch(cfg)?;
9787 }
9788
9789 Ok(true)
9790}
9791
9792#[cfg(feature = "cuda")]
9795fn try_launch_unary_into(
9796 a: &CudaBuffer<f32>,
9797 out: &mut CudaBuffer<f32>,
9798 device: &GpuDevice,
9799 ptx_src: &'static str,
9800 kernel_name: &'static str,
9801) -> GpuResult<bool> {
9802 use cudarc::driver::PushKernelArg;
9803
9804 let n = a.len();
9805 let ctx = device.context();
9806 let stream = device.stream();
9807
9808 let f = match crate::module_cache::get_or_compile(
9809 ctx,
9810 ptx_src,
9811 kernel_name,
9812 device.ordinal() as u32,
9813 ) {
9814 Ok(f) => f,
9815 Err(_) => return Ok(false),
9816 };
9817
9818 let cfg = launch_cfg(n)?;
9819 let n_u32 = n as u32;
9820
9821 unsafe {
9822 stream
9823 .launch_builder(&f)
9824 .arg(a.inner())
9825 .arg(out.inner_mut())
9826 .arg(&n_u32)
9827 .launch(cfg)?;
9828 }
9829
9830 Ok(true)
9831}
9832
9833#[cfg(feature = "cuda")]
9839fn try_launch_binary_f64(
9840 a: &CudaBuffer<f64>,
9841 b: &CudaBuffer<f64>,
9842 device: &GpuDevice,
9843 ptx_src: &'static str,
9844 kernel_name: &'static str,
9845) -> GpuResult<Option<CudaBuffer<f64>>> {
9846 use cudarc::driver::PushKernelArg;
9847
9848 let n = a.len();
9849 let ctx = device.context();
9850 let stream = device.stream();
9851
9852 let f = match crate::module_cache::get_or_compile(
9853 ctx, ptx_src, kernel_name, device.ordinal() as u32,
9854 ) {
9855 Ok(f) => f,
9856 Err(_) => return Ok(None),
9857 };
9858
9859 let mut out = alloc_zeros_f64(n, device)?;
9860 let cfg = launch_cfg(n)?;
9861 let n_u32 = n as u32;
9862
9863 unsafe {
9864 stream
9865 .launch_builder(&f)
9866 .arg(a.inner())
9867 .arg(b.inner())
9868 .arg(out.inner_mut())
9869 .arg(&n_u32)
9870 .launch(cfg)?;
9871 }
9872 Ok(Some(out))
9873}
9874
9875#[cfg(feature = "cuda")]
9877fn try_launch_unary_f64(
9878 a: &CudaBuffer<f64>,
9879 device: &GpuDevice,
9880 ptx_src: &'static str,
9881 kernel_name: &'static str,
9882) -> GpuResult<Option<CudaBuffer<f64>>> {
9883 use cudarc::driver::PushKernelArg;
9884
9885 let n = a.len();
9886 let ctx = device.context();
9887 let stream = device.stream();
9888
9889 let f = match crate::module_cache::get_or_compile(
9890 ctx, ptx_src, kernel_name, device.ordinal() as u32,
9891 ) {
9892 Ok(f) => f,
9893 Err(_) => return Ok(None),
9894 };
9895
9896 let mut out = alloc_zeros_f64(n, device)?;
9897 let cfg = launch_cfg(n)?;
9898 let n_u32 = n as u32;
9899
9900 unsafe {
9901 stream
9902 .launch_builder(&f)
9903 .arg(a.inner())
9904 .arg(out.inner_mut())
9905 .arg(&n_u32)
9906 .launch(cfg)?;
9907 }
9908 Ok(Some(out))
9909}
9910
9911#[cfg(feature = "cuda")]
9913fn cpu_fallback_binary_f64(
9914 a: &CudaBuffer<f64>,
9915 b: &CudaBuffer<f64>,
9916 device: &GpuDevice,
9917 op: fn(f64, f64) -> f64,
9918) -> GpuResult<CudaBuffer<f64>> {
9919 let a_host = gpu_to_cpu(a, device)?;
9920 let b_host = gpu_to_cpu(b, device)?;
9921 let result: Vec<f64> = a_host.iter().zip(b_host.iter()).map(|(&x, &y)| op(x, y)).collect();
9922 cpu_to_gpu(&result, device)
9923}
9924
9925#[cfg(feature = "cuda")]
9927fn cpu_fallback_unary_f64(
9928 a: &CudaBuffer<f64>,
9929 device: &GpuDevice,
9930 op: fn(f64) -> f64,
9931) -> GpuResult<CudaBuffer<f64>> {
9932 let a_host = gpu_to_cpu(a, device)?;
9933 let result: Vec<f64> = a_host.iter().map(|&x| op(x)).collect();
9934 cpu_to_gpu(&result, device)
9935}
9936
9937#[cfg(feature = "cuda")]
9941#[allow(clippy::too_many_arguments)]
9942fn try_launch_broadcast_binary_f64(
9943 a: &CudaBuffer<f64>,
9944 b: &CudaBuffer<f64>,
9945 a_strides: &[u32],
9946 b_strides: &[u32],
9947 out_shape: &[u32],
9948 out_numel: usize,
9949 device: &GpuDevice,
9950 ptx_src: &'static str,
9951 kernel_name: &'static str,
9952) -> GpuResult<Option<CudaBuffer<f64>>> {
9953 use cudarc::driver::PushKernelArg;
9954
9955 let ndim = out_shape.len();
9956 let ctx = device.context();
9957 let stream = device.stream();
9958
9959 let f = match crate::module_cache::get_or_compile(
9960 ctx,
9961 ptx_src,
9962 kernel_name,
9963 device.ordinal() as u32,
9964 ) {
9965 Ok(f) => f,
9966 Err(_) => return Ok(None),
9967 };
9968
9969 let a_str_buf = cpu_to_gpu(a_strides, device)?;
9971 let b_str_buf = cpu_to_gpu(b_strides, device)?;
9972 let shape_buf = cpu_to_gpu(out_shape, device)?;
9973
9974 let mut out = alloc_zeros_f64(out_numel, device)?;
9975 let cfg = launch_cfg(out_numel)?;
9976 let n_u32 = out_numel as u32;
9977 let ndim_u32 = ndim as u32;
9978
9979 unsafe {
9980 stream
9981 .launch_builder(&f)
9982 .arg(a.inner())
9983 .arg(b.inner())
9984 .arg(out.inner_mut())
9985 .arg(a_str_buf.inner())
9986 .arg(b_str_buf.inner())
9987 .arg(shape_buf.inner())
9988 .arg(&n_u32)
9989 .arg(&ndim_u32)
9990 .launch(cfg)?;
9991 }
9992
9993 Ok(Some(out))
9994}
9995
9996#[cfg(feature = "cuda")]
9998fn cpu_fallback_broadcast_binary_f64(
9999 a: &CudaBuffer<f64>,
10000 b: &CudaBuffer<f64>,
10001 a_shape: &[usize],
10002 b_shape: &[usize],
10003 out_shape: &[usize],
10004 device: &GpuDevice,
10005 op: fn(f64, f64) -> f64,
10006) -> GpuResult<CudaBuffer<f64>> {
10007 let a_host = gpu_to_cpu(a, device)?;
10008 let b_host = gpu_to_cpu(b, device)?;
10009 let out_numel: usize = out_shape.iter().product();
10010
10011 let a_str = broadcast_strides(a_shape, out_shape);
10012 let b_str = broadcast_strides(b_shape, out_shape);
10013
10014 let mut result = Vec::with_capacity(out_numel);
10015 for i in 0..out_numel {
10016 let mut remaining = i;
10017 let mut a_idx = 0usize;
10018 let mut b_idx = 0usize;
10019 for d in (0..out_shape.len()).rev() {
10020 let coord = remaining % out_shape[d];
10021 remaining /= out_shape[d];
10022 a_idx += coord * a_str[d] as usize;
10023 b_idx += coord * b_str[d] as usize;
10024 }
10025 result.push(op(a_host[a_idx], b_host[b_idx]));
10026 }
10027 cpu_to_gpu(&result, device)
10028}
10029
10030#[cfg(feature = "cuda")]
10037#[allow(clippy::too_many_arguments)]
10038fn try_launch_broadcast_binary(
10039 a: &CudaBuffer<f32>,
10040 b: &CudaBuffer<f32>,
10041 a_strides: &[u32],
10042 b_strides: &[u32],
10043 out_shape: &[u32],
10044 out_numel: usize,
10045 device: &GpuDevice,
10046 ptx_src: &'static str,
10047 kernel_name: &'static str,
10048) -> GpuResult<Option<CudaBuffer<f32>>> {
10049 use cudarc::driver::PushKernelArg;
10050
10051 let ndim = out_shape.len();
10052 let ctx = device.context();
10053 let stream = device.stream();
10054
10055 let f = match crate::module_cache::get_or_compile(
10056 ctx,
10057 ptx_src,
10058 kernel_name,
10059 device.ordinal() as u32,
10060 ) {
10061 Ok(f) => f,
10062 Err(_) => return Ok(None),
10063 };
10064
10065 let a_str_buf = cpu_to_gpu(a_strides, device)?;
10067 let b_str_buf = cpu_to_gpu(b_strides, device)?;
10068 let shape_buf = cpu_to_gpu(out_shape, device)?;
10069
10070 let mut out = alloc_zeros_f32(out_numel, device)?;
10071 let cfg = launch_cfg(out_numel)?;
10072 let n_u32 = out_numel as u32;
10073 let ndim_u32 = ndim as u32;
10074
10075 unsafe {
10078 stream
10079 .launch_builder(&f)
10080 .arg(a.inner())
10081 .arg(b.inner())
10082 .arg(out.inner_mut())
10083 .arg(a_str_buf.inner())
10084 .arg(b_str_buf.inner())
10085 .arg(shape_buf.inner())
10086 .arg(&n_u32)
10087 .arg(&ndim_u32)
10088 .launch(cfg)?;
10089 }
10090
10091 Ok(Some(out))
10092}
10093
10094#[cfg(feature = "cuda")]
/// Computes per-output-axis element strides for reading an input of shape
/// `in_shape` as if it had been broadcast to `out_shape`.
///
/// Shapes are aligned at their trailing axes. For each output axis the result is:
///
/// * `0` when the input has no corresponding axis (it has fewer dimensions),
/// * `0` when the corresponding input extent is `1` (the value is repeated),
/// * otherwise the product of the input extents to its right, i.e. the
///   row-major stride of that axis in the input.
///
/// The returned vector has `out_shape.len()` entries. Extents are not checked
/// for broadcast compatibility with `out_shape`.
///
/// The running product is kept in 64 bits so that the final multiplication —
/// whose result is never stored — cannot overflow. A stride that does not fit
/// in `u32` is clamped to `u32::MAX`; NOTE(review): such an input holds more
/// than `u32::MAX` elements, which `launch_cfg` rejects as a launch size —
/// confirm no caller relies on strides for inputs that large.
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    let ndim = out_shape.len();
    let in_ndim = in_shape.len();
    let mut strides = vec![0u32; ndim];

    let mut stride: u64 = 1;
    for d in (0..ndim).rev() {
        // Trailing-axis alignment: output axis `d` maps to input axis
        // `d - (ndim - in_ndim)`, which does not exist for the leading axes of
        // a lower-rank input. Those keep their initial stride of 0.
        let in_d = match (d + in_ndim).checked_sub(ndim) {
            Some(in_d) => in_d,
            None => continue,
        };

        let extent = in_shape[in_d];
        if extent != 1 {
            strides[d] = u32::try_from(stride).unwrap_or(u32::MAX);
        }
        stride = stride.saturating_mul(extent as u64);
    }

    strides
}
10127
10128#[cfg(feature = "cuda")]
10135fn cpu_fallback_binary(
10136 a: &CudaBuffer<f32>,
10137 b: &CudaBuffer<f32>,
10138 device: &GpuDevice,
10139 op: fn(f32, f32) -> f32,
10140) -> GpuResult<CudaBuffer<f32>> {
10141 let a_host = gpu_to_cpu(a, device)?;
10142 let b_host = gpu_to_cpu(b, device)?;
10143 let result: Vec<f32> = a_host
10144 .iter()
10145 .zip(b_host.iter())
10146 .map(|(&x, &y)| op(x, y))
10147 .collect();
10148 cpu_to_gpu(&result, device)
10149}
10150
10151#[cfg(feature = "cuda")]
10153fn cpu_fallback_unary(
10154 a: &CudaBuffer<f32>,
10155 device: &GpuDevice,
10156 op: fn(f32) -> f32,
10157) -> GpuResult<CudaBuffer<f32>> {
10158 let a_host = gpu_to_cpu(a, device)?;
10159 let result: Vec<f32> = a_host.iter().map(|&x| op(x)).collect();
10160 cpu_to_gpu(&result, device)
10161}
10162
10163#[cfg(feature = "cuda")]
10179pub fn gpu_add(
10180 a: &CudaBuffer<f32>,
10181 b: &CudaBuffer<f32>,
10182 device: &GpuDevice,
10183) -> GpuResult<CudaBuffer<f32>> {
10184 validate_binary(a, b, device)?;
10185
10186 let n = a.len();
10188 if n >= 16 && n % 4 == 0 {
10189 if let Some(out) = try_launch_binary_vec4(
10190 a, b, device, ADD_VEC4_PTX, "add_vec4_kernel",
10191 )? {
10192 return Ok(out);
10193 }
10194 }
10195
10196 if let Some(out) = try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
10197 return Ok(out);
10198 }
10199
10200 cpu_fallback_binary(a, b, device, |x, y| x + y)
10201}
10202
10203#[cfg(feature = "cuda")]
10215pub fn gpu_sub(
10216 a: &CudaBuffer<f32>,
10217 b: &CudaBuffer<f32>,
10218 device: &GpuDevice,
10219) -> GpuResult<CudaBuffer<f32>> {
10220 validate_binary(a, b, device)?;
10221
10222 if let Some(out) = try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
10223 return Ok(out);
10224 }
10225
10226 cpu_fallback_binary(a, b, device, |x, y| x - y)
10227}
10228
10229#[cfg(feature = "cuda")]
10241pub fn gpu_mul(
10242 a: &CudaBuffer<f32>,
10243 b: &CudaBuffer<f32>,
10244 device: &GpuDevice,
10245) -> GpuResult<CudaBuffer<f32>> {
10246 validate_binary(a, b, device)?;
10247
10248 let n = a.len();
10249 if n >= 16 && n % 4 == 0 {
10250 if let Some(out) = try_launch_binary_vec4(
10251 a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel",
10252 )? {
10253 return Ok(out);
10254 }
10255 }
10256
10257 if let Some(out) = try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
10258 return Ok(out);
10259 }
10260
10261 cpu_fallback_binary(a, b, device, |x, y| x * y)
10262}
10263
10264#[cfg(feature = "cuda")]
10277pub fn gpu_broadcast_add(
10278 a: &CudaBuffer<f32>,
10279 b: &CudaBuffer<f32>,
10280 a_shape: &[usize],
10281 b_shape: &[usize],
10282 out_shape: &[usize],
10283 device: &GpuDevice,
10284) -> GpuResult<CudaBuffer<f32>> {
10285 let a_str = broadcast_strides(a_shape, out_shape);
10286 let b_str = broadcast_strides(b_shape, out_shape);
10287 let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
10288 let out_numel: usize = out_shape.iter().product();
10289
10290 if let Some(out) = try_launch_broadcast_binary(
10291 a,
10292 b,
10293 &a_str,
10294 &b_str,
10295 &shape_u32,
10296 out_numel,
10297 device,
10298 BROADCAST_ADD_PTX,
10299 "broadcast_add_kernel",
10300 )? {
10301 return Ok(out);
10302 }
10303
10304 cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y)
10306}
10307
10308#[cfg(feature = "cuda")]
10310pub fn gpu_broadcast_sub(
10311 a: &CudaBuffer<f32>,
10312 b: &CudaBuffer<f32>,
10313 a_shape: &[usize],
10314 b_shape: &[usize],
10315 out_shape: &[usize],
10316 device: &GpuDevice,
10317) -> GpuResult<CudaBuffer<f32>> {
10318 let a_str = broadcast_strides(a_shape, out_shape);
10319 let b_str = broadcast_strides(b_shape, out_shape);
10320 let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
10321 let out_numel: usize = out_shape.iter().product();
10322
10323 if let Some(out) = try_launch_broadcast_binary(
10324 a,
10325 b,
10326 &a_str,
10327 &b_str,
10328 &shape_u32,
10329 out_numel,
10330 device,
10331 BROADCAST_SUB_PTX,
10332 "broadcast_sub_kernel",
10333 )? {
10334 return Ok(out);
10335 }
10336
10337 cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y)
10338}
10339
10340#[cfg(feature = "cuda")]
10342pub fn gpu_broadcast_mul(
10343 a: &CudaBuffer<f32>,
10344 b: &CudaBuffer<f32>,
10345 a_shape: &[usize],
10346 b_shape: &[usize],
10347 out_shape: &[usize],
10348 device: &GpuDevice,
10349) -> GpuResult<CudaBuffer<f32>> {
10350 let a_str = broadcast_strides(a_shape, out_shape);
10351 let b_str = broadcast_strides(b_shape, out_shape);
10352 let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
10353 let out_numel: usize = out_shape.iter().product();
10354
10355 if let Some(out) = try_launch_broadcast_binary(
10356 a,
10357 b,
10358 &a_str,
10359 &b_str,
10360 &shape_u32,
10361 out_numel,
10362 device,
10363 BROADCAST_MUL_PTX,
10364 "broadcast_mul_kernel",
10365 )? {
10366 return Ok(out);
10367 }
10368
10369 cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y)
10370}
10371
10372#[cfg(feature = "cuda")]
10374pub fn gpu_broadcast_div(
10375 a: &CudaBuffer<f32>,
10376 b: &CudaBuffer<f32>,
10377 a_shape: &[usize],
10378 b_shape: &[usize],
10379 out_shape: &[usize],
10380 device: &GpuDevice,
10381) -> GpuResult<CudaBuffer<f32>> {
10382 let a_str = broadcast_strides(a_shape, out_shape);
10383 let b_str = broadcast_strides(b_shape, out_shape);
10384 let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
10385 let out_numel: usize = out_shape.iter().product();
10386
10387 if let Some(out) = try_launch_broadcast_binary(
10388 a,
10389 b,
10390 &a_str,
10391 &b_str,
10392 &shape_u32,
10393 out_numel,
10394 device,
10395 BROADCAST_DIV_PTX,
10396 "broadcast_div_kernel",
10397 )? {
10398 return Ok(out);
10399 }
10400
10401 cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
10402}
10403
10404#[cfg(feature = "cuda")]
10407fn cpu_fallback_broadcast_binary(
10408 a: &CudaBuffer<f32>,
10409 b: &CudaBuffer<f32>,
10410 a_shape: &[usize],
10411 b_shape: &[usize],
10412 out_shape: &[usize],
10413 device: &GpuDevice,
10414 op: fn(f32, f32) -> f32,
10415) -> GpuResult<CudaBuffer<f32>> {
10416 let a_host = gpu_to_cpu(a, device)?;
10417 let b_host = gpu_to_cpu(b, device)?;
10418 let out_numel: usize = out_shape.iter().product();
10419
10420 let a_str = broadcast_strides(a_shape, out_shape);
10421 let b_str = broadcast_strides(b_shape, out_shape);
10422
10423 let mut result = Vec::with_capacity(out_numel);
10424 for i in 0..out_numel {
10425 let mut remaining = i;
10426 let mut a_idx = 0usize;
10427 let mut b_idx = 0usize;
10428 for d in (0..out_shape.len()).rev() {
10429 let coord = remaining % out_shape[d];
10430 remaining /= out_shape[d];
10431 a_idx += coord * a_str[d] as usize;
10432 b_idx += coord * b_str[d] as usize;
10433 }
10434 result.push(op(a_host[a_idx], b_host[b_idx]));
10435 }
10436 cpu_to_gpu(&result, device)
10437}
10438
10439#[cfg(feature = "cuda")]
10454pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
10455 validate_unary(a, device)?;
10456
10457 if let Some(out) = try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
10458 return Ok(out);
10459 }
10460
10461 cpu_fallback_unary(a, device, |x| -x)
10462}
10463
10464#[cfg(feature = "cuda")]
10475pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
10476 validate_unary(a, device)?;
10477
10478 if let Some(out) = try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
10479 return Ok(out);
10480 }
10481
10482 cpu_fallback_unary(a, device, |x| x.max(0.0))
10483}
10484
10485#[cfg(feature = "cuda")]
10487pub fn gpu_relu_backward(
10488 grad: &CudaBuffer<f32>,
10489 input: &CudaBuffer<f32>,
10490 device: &GpuDevice,
10491) -> GpuResult<CudaBuffer<f32>> {
10492 validate_binary(grad, input, device)?;
10493
10494 if let Some(out) = try_launch_binary(
10495 grad,
10496 input,
10497 device,
10498 RELU_BACKWARD_PTX,
10499 "relu_backward_kernel",
10500 )? {
10501 return Ok(out);
10502 }
10503
10504 let grad_host = gpu_to_cpu(grad, device)?;
10506 let input_host = gpu_to_cpu(input, device)?;
10507 let result: Vec<f32> = grad_host
10508 .iter()
10509 .zip(input_host.iter())
10510 .map(|(&g, &x)| if x > 0.0 { g } else { 0.0 })
10511 .collect();
10512 cpu_to_gpu(&result, device)
10513}
10514
10515#[cfg(feature = "cuda")]
10518pub fn gpu_abs_backward(
10519 grad: &CudaBuffer<f32>,
10520 input: &CudaBuffer<f32>,
10521 device: &GpuDevice,
10522) -> GpuResult<CudaBuffer<f32>> {
10523 validate_binary(grad, input, device)?;
10524
10525 if let Some(out) = try_launch_binary(
10526 grad,
10527 input,
10528 device,
10529 ABS_BACKWARD_PTX,
10530 "abs_backward_kernel",
10531 )? {
10532 return Ok(out);
10533 }
10534
10535 let grad_host = gpu_to_cpu(grad, device)?;
10537 let input_host = gpu_to_cpu(input, device)?;
10538 let result: Vec<f32> = grad_host
10539 .iter()
10540 .zip(input_host.iter())
10541 .map(|(&g, &x)| {
10542 if x > 0.0 {
10543 g
10544 } else if x < 0.0 {
10545 -g
10546 } else {
10547 0.0
10548 }
10549 })
10550 .collect();
10551 cpu_to_gpu(&result, device)
10552}
10553
10554#[cfg(feature = "cuda")]
10557pub fn gpu_gelu_backward(
10558 grad: &CudaBuffer<f32>,
10559 input: &CudaBuffer<f32>,
10560 device: &GpuDevice,
10561) -> GpuResult<CudaBuffer<f32>> {
10562 validate_binary(grad, input, device)?;
10563
10564 if let Some(out) = try_launch_binary(
10565 grad,
10566 input,
10567 device,
10568 GELU_BACKWARD_PTX,
10569 "gelu_backward_kernel",
10570 )? {
10571 return Ok(out);
10572 }
10573
10574 let grad_host = gpu_to_cpu(grad, device)?;
10576 let input_host = gpu_to_cpu(input, device)?;
10577 let result: Vec<f32> = grad_host
10578 .iter()
10579 .zip(input_host.iter())
10580 .map(|(&g, &x)| {
10581 let k: f32 = 1.702;
10582 let sig = 1.0 / (1.0 + (-k * x).exp());
10583 g * (sig + k * x * sig * (1.0 - sig))
10584 })
10585 .collect();
10586 cpu_to_gpu(&result, device)
10587}
10588
10589#[cfg(feature = "cuda")]
10593pub fn gpu_gelu_backward_erf(
10594 grad: &CudaBuffer<f32>,
10595 input: &CudaBuffer<f32>,
10596 device: &GpuDevice,
10597) -> GpuResult<CudaBuffer<f32>> {
10598 validate_binary(grad, input, device)?;
10599
10600 if let Some(out) = try_launch_binary(
10601 grad,
10602 input,
10603 device,
10604 GELU_BACKWARD_ERF_PTX,
10605 "gelu_backward_erf_kernel",
10606 )? {
10607 return Ok(out);
10608 }
10609
10610 let grad_host = gpu_to_cpu(grad, device)?;
10612 let input_host = gpu_to_cpu(input, device)?;
10613 let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
10614 let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
10615 let result: Vec<f32> = grad_host
10616 .iter()
10617 .zip(input_host.iter())
10618 .map(|(&g, &x)| {
10619 let z = x * inv_sqrt_2;
10620 let az = z.abs();
10621 let t = 1.0 / (1.0 + 0.3275911 * az);
10622 let poly = t * (0.2548296 + t * (-0.2844967 + t * (1.4214137 + t * (-1.453_152 + t * 0.3275911))));
10623 let erf_abs = 1.0 - poly * (-az * az).exp();
10624 let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
10625 let cdf = 0.5 * (1.0 + erf_val);
10626 let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
10627 g * (cdf + x * pdf)
10628 })
10629 .collect();
10630 cpu_to_gpu(&result, device)
10631}
10632
10633#[cfg(feature = "cuda")]
10642pub fn gpu_index_select_1d(
10643 input: &CudaBuffer<f32>,
10644 indices: &CudaBuffer<f32>,
10645 device: &GpuDevice,
10646) -> GpuResult<CudaBuffer<f32>> {
10647 use cudarc::driver::PushKernelArg;
10648
10649 validate_unary(input, device)?;
10650
10651 let n = indices.len();
10652 let ctx = device.context();
10653 let stream = device.stream();
10654
10655 let f = match crate::module_cache::get_or_compile(
10656 ctx,
10657 INDEX_SELECT_1D_PTX,
10658 "index_select_1d_kernel",
10659 device.ordinal() as u32,
10660 ) {
10661 Ok(f) => f,
10662 Err(_) => {
10663 let input_host = gpu_to_cpu(input, device)?;
10665 let indices_host = gpu_to_cpu(indices, device)?;
10666 let result: Vec<f32> = indices_host
10667 .iter()
10668 .map(|&idx_f| input_host[idx_f as usize])
10669 .collect();
10670 return cpu_to_gpu(&result, device);
10671 }
10672 };
10673
10674 let mut out = alloc_zeros_f32(n, device)?;
10675 let cfg = launch_cfg(n)?;
10676 let n_u32 = n as u32;
10677
10678 unsafe {
10679 stream
10680 .launch_builder(&f)
10681 .arg(input.inner())
10682 .arg(indices.inner())
10683 .arg(out.inner_mut())
10684 .arg(&n_u32)
10685 .launch(cfg)?;
10686 }
10687
10688 Ok(out)
10689}
10690
10691#[cfg(feature = "cuda")]
10703pub fn gpu_scatter_add_1d(
10704 grad_output: &CudaBuffer<f32>,
10705 indices: &CudaBuffer<f32>,
10706 input_len: usize,
10707 device: &GpuDevice,
10708) -> GpuResult<CudaBuffer<f32>> {
10709 use cudarc::driver::PushKernelArg;
10710
10711 validate_unary(grad_output, device)?;
10712
10713 let n = grad_output.len();
10714 let ctx = device.context();
10715 let stream = device.stream();
10716
10717 let f = match crate::module_cache::get_or_compile(
10718 ctx,
10719 SCATTER_ADD_1D_PTX,
10720 "scatter_add_1d_kernel",
10721 device.ordinal() as u32,
10722 ) {
10723 Ok(f) => f,
10724 Err(_) => {
10725 let go_host = gpu_to_cpu(grad_output, device)?;
10727 let idx_host = gpu_to_cpu(indices, device)?;
10728 let mut result = vec![0.0f32; input_len];
10729 for (i, &idx_f) in idx_host.iter().enumerate() {
10730 result[idx_f as usize] += go_host[i];
10731 }
10732 return cpu_to_gpu(&result, device);
10733 }
10734 };
10735
10736 let mut out = alloc_zeros_f32(input_len, device)?;
10737 let cfg = launch_cfg(n)?;
10738 let n_u32 = n as u32;
10739
10740 unsafe {
10741 stream
10742 .launch_builder(&f)
10743 .arg(grad_output.inner())
10744 .arg(indices.inner())
10745 .arg(out.inner_mut())
10746 .arg(&n_u32)
10747 .launch(cfg)?;
10748 }
10749
10750 Ok(out)
10751}
10752
10753#[cfg(feature = "cuda")]
10762pub fn gpu_masked_fill(
10763 input: &CudaBuffer<f32>,
10764 mask: &CudaBuffer<f32>,
10765 value: f32,
10766 device: &GpuDevice,
10767) -> GpuResult<CudaBuffer<f32>> {
10768 use cudarc::driver::PushKernelArg;
10769
10770 validate_binary(input, mask, device)?;
10771
10772 let n = input.len();
10773 let ctx = device.context();
10774 let stream = device.stream();
10775
10776 let f = match crate::module_cache::get_or_compile(
10777 ctx,
10778 MASKED_FILL_PTX,
10779 "masked_fill_kernel",
10780 device.ordinal() as u32,
10781 ) {
10782 Ok(f) => f,
10783 Err(_) => {
10784 let input_host = gpu_to_cpu(input, device)?;
10786 let mask_host = gpu_to_cpu(mask, device)?;
10787 let result: Vec<f32> = input_host
10788 .iter()
10789 .zip(mask_host.iter())
10790 .map(|(&x, &m)| if m >= 0.5 { value } else { x })
10791 .collect();
10792 return cpu_to_gpu(&result, device);
10793 }
10794 };
10795
10796 let mut out = alloc_zeros_f32(n, device)?;
10797 let cfg = launch_cfg(n)?;
10798 let n_u32 = n as u32;
10799
10800 unsafe {
10801 stream
10802 .launch_builder(&f)
10803 .arg(input.inner())
10804 .arg(mask.inner())
10805 .arg(out.inner_mut())
10806 .arg(&value)
10807 .arg(&n_u32)
10808 .launch(cfg)?;
10809 }
10810
10811 Ok(out)
10812}
10813
10814#[cfg(feature = "cuda")]
10823pub fn gpu_masked_zero(
10824 grad: &CudaBuffer<f32>,
10825 mask: &CudaBuffer<f32>,
10826 device: &GpuDevice,
10827) -> GpuResult<CudaBuffer<f32>> {
10828 validate_binary(grad, mask, device)?;
10829
10830 if let Some(out) = try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?
10831 {
10832 return Ok(out);
10833 }
10834
10835 let grad_host = gpu_to_cpu(grad, device)?;
10837 let mask_host = gpu_to_cpu(mask, device)?;
10838 let result: Vec<f32> = grad_host
10839 .iter()
10840 .zip(mask_host.iter())
10841 .map(|(&g, &m)| if m >= 0.5 { 0.0 } else { g })
10842 .collect();
10843 cpu_to_gpu(&result, device)
10844}
10845
#[cfg(feature = "cuda")]
/// Backward pass of the sigmoid activation.
///
/// Combines the upstream gradient `grad` with the forward `output`
/// element-wise as `grad * output * (1 - output)`.
///
/// `sigmoid_backward_kernel` is tried first via `try_launch_binary`; when that
/// helper yields `None` the formula is evaluated on the host and the result
/// uploaded.
///
/// # Errors
/// Propagates any error from `validate_binary`, `try_launch_binary`, or the
/// host/device transfers.
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    let launched = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )?;

    match launched {
        Some(on_device) => Ok(on_device),
        None => {
            let upstream = gpu_to_cpu(grad, device)?;
            let activations = gpu_to_cpu(output, device)?;
            let mut grad_input: Vec<f32> =
                Vec::with_capacity(upstream.len().min(activations.len()));
            for (&g, &o) in upstream.iter().zip(activations.iter()) {
                grad_input.push(g * o * (1.0 - o));
            }
            cpu_to_gpu(&grad_input, device)
        }
    }
}
10881
#[cfg(feature = "cuda")]
/// Backward pass of the tanh activation.
///
/// Combines the upstream gradient `grad` with the forward `output`
/// element-wise as `grad * (1 - output * output)`.
///
/// `tanh_backward_kernel` is tried first via `try_launch_binary`; when that
/// helper yields `None` the formula is evaluated on the host and the result
/// uploaded.
///
/// # Errors
/// Propagates any error from `validate_binary`, `try_launch_binary`, or the
/// host/device transfers.
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    let launched = try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )?;
    if let Some(on_device) = launched {
        return Ok(on_device);
    }

    // Host path: only the common prefix of the two buffers is processed.
    let upstream = gpu_to_cpu(grad, device)?;
    let activations = gpu_to_cpu(output, device)?;
    let paired = upstream.len().min(activations.len());
    let grad_input: Vec<f32> = (0..paired)
        .map(|i| {
            let o = activations[i];
            upstream[i] * (1.0 - o * o)
        })
        .collect();
    cpu_to_gpu(&grad_input, device)
}
10917
#[cfg(feature = "cuda")]
/// Row-wise softmax backward pass.
///
/// `grad` and `output` are read as `grad.len() / cols` consecutive rows of
/// `cols` elements each. For every row the result is
/// `output * (grad - sum(grad * output))`.
///
/// When `cols == 0` there are no row elements to process; a zero-filled buffer
/// of `grad.len()` elements is returned instead of dividing by zero.
///
/// `softmax_backward_kernel` is launched on a grid of `rows` blocks (at least
/// one) of 256 threads with 1 KiB of dynamic shared memory. If the kernel
/// cannot be obtained from the module cache, the same formula is evaluated on
/// the host and the result is uploaded. Elements past the last complete row
/// are left at zero by the host path.
///
/// # Errors
/// Propagates errors from `validate_binary`, buffer allocation, host/device
/// transfers, and the kernel launch.
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    if cols == 0 {
        // `total / cols` would panic; zero-width rows leave nothing to compute.
        return alloc_zeros_f32(total, device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: evaluate the same formula on the host.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                // Per-row dot product of gradient and softmax output.
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // `.max(1)` keeps the grid non-empty when there are no complete rows.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: relies on the argument order and types below matching the
    // kernel's parameter list (PTX defined elsewhere in this file) and on the
    // buffers covering `rows * cols` elements. NOTE(review): not re-checked here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10996
#[cfg(feature = "cuda")]
/// Row-wise log-softmax.
///
/// `input` is read as `input.len() / cols` consecutive rows of `cols`
/// elements. For every row the result is `x - (max + ln(sum(exp(x - max))))`,
/// where `max` is the row maximum.
///
/// When `cols == 0` there are no row elements to process; a zero-filled buffer
/// of `input.len()` elements is returned instead of dividing by zero.
///
/// `log_softmax_kernel` is launched on a grid of `rows` blocks (at least one)
/// of 256 threads with 1 KiB of dynamic shared memory. If the kernel cannot be
/// obtained from the module cache, the same formula is evaluated on the host
/// and the result is uploaded. Elements past the last complete row are left at
/// zero by the host path.
///
/// # Errors
/// Propagates errors from `validate_unary`, buffer allocation, host/device
/// transfers, and the kernel launch.
pub fn gpu_log_softmax(
    input: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = input.len();
    if cols == 0 {
        // `total / cols` would panic; zero-width rows leave nothing to compute.
        return alloc_zeros_f32(total, device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_PTX,
        "log_softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: evaluate the same formula on the host.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                // Subtract the row maximum before exponentiating.
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f32;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // `.max(1)` keeps the grid non-empty when there are no complete rows.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: relies on the argument order and types below matching the
    // kernel's parameter list (PTX defined elsewhere in this file) and on the
    // buffers covering `rows * cols` elements. NOTE(review): not re-checked here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11076
#[cfg(feature = "cuda")]
/// Row-wise log-softmax backward pass.
///
/// `grad` and `output` are read as `grad.len() / cols` consecutive rows of
/// `cols` elements each. For every row the result is
/// `grad - exp(output) * sum(grad)`.
///
/// When `cols == 0` there are no row elements to process; a zero-filled buffer
/// of `grad.len()` elements is returned instead of dividing by zero.
///
/// `log_softmax_backward_kernel` is launched on a grid of `rows` blocks (at
/// least one) of 256 threads with 1 KiB of dynamic shared memory. If the
/// kernel cannot be obtained from the module cache, the same formula is
/// evaluated on the host and the result is uploaded. Elements past the last
/// complete row are left at zero by the host path.
///
/// # Errors
/// Propagates errors from `validate_binary`, buffer allocation, host/device
/// transfers, and the kernel launch.
pub fn gpu_log_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    if cols == 0 {
        // `total / cols` would panic; zero-width rows leave nothing to compute.
        return alloc_zeros_f32(total, device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_PTX,
        "log_softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: evaluate the same formula on the host.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f32;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // `.max(1)` keeps the grid non-empty when there are no complete rows.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: relies on the argument order and types below matching the
    // kernel's parameter list (PTX defined elsewhere in this file) and on the
    // buffers covering `rows * cols` elements. NOTE(review): not re-checked here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11152
#[cfg(feature = "cuda")]
/// Sums every element of `a` into a one-element device buffer.
///
/// * An empty input yields a buffer holding `0.0`.
/// * Otherwise `reduce_sum_kernel` is launched with 256-thread blocks, capped
///   at 1024 blocks, writing one partial per block. A single partial is
///   returned as is; up to 256 partials are summed on the host; more than that
///   are reduced by calling this function again on the partials.
/// * If the kernel cannot be obtained from the module cache, or the input has
///   more elements than fit in the `u32` count passed to the kernel, the whole
///   buffer is summed on the host instead. The latter avoids silently
///   truncating the element count.
///
/// # Errors
/// Propagates errors from buffer allocation, host/device transfers, and the
/// kernel launch.
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    /// Threads per block used for the launch.
    const BLOCK: u32 = 256;
    /// Upper bound on the number of blocks (and therefore partial sums).
    const MAX_BLOCKS: u32 = 1024;
    /// Up to this many partials are summed on the host rather than re-reduced.
    const HOST_FINISH_LIMIT: u32 = 256;

    let n = a.len();
    if n == 0 {
        return cpu_to_gpu(&[0.0f32], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    // The kernel receives the element count as a `u32`. Inputs that do not fit
    // are routed to the host path so the count is never truncated.
    let kernel = if n > u32::MAX as usize {
        None
    } else {
        crate::module_cache::get_or_compile(
            ctx,
            REDUCE_SUM_PTX,
            "reduce_sum_kernel",
            device.ordinal() as u32,
        )
        .ok()
    };

    let f = match kernel {
        Some(f) => f,
        None => {
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    // Lossless: `n <= u32::MAX` was established above.
    let n_u32 = n as u32;
    let num_blocks = (n_u32.saturating_add(BLOCK - 1) / BLOCK).min(MAX_BLOCKS);

    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        // No dynamic shared memory is requested for this kernel.
        shared_mem_bytes: 0,
    };

    // SAFETY: relies on the argument order and types below matching the
    // kernel's parameter list (PTX defined elsewhere in this file);
    // `partials` holds one slot per launched block.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    if num_blocks <= 1 {
        // A single block already produced the final sum.
        return Ok(partials);
    }

    if num_blocks <= HOST_FINISH_LIMIT {
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // Many partials: reduce them with another pass.
    gpu_reduce_sum(&partials, device)
}
11235
11236#[cfg(not(feature = "cuda"))]
11238pub fn gpu_reduce_sum(
11239 _a: &CudaBuffer<f32>,
11240 _device: &GpuDevice,
11241) -> GpuResult<CudaBuffer<f32>> {
11242 Err(GpuError::NoCudaFeature)
11243}
11244
#[cfg(feature = "cuda")]
/// Sums `a` along the middle axis of a logical `[outer, axis_size, inner]`
/// layout, producing `outer * inner` values.
///
/// `sum_axis_kernel` is launched with a configuration sized for the output
/// length. If the kernel cannot be obtained from the module cache, the sums
/// are computed on the host and uploaded.
///
/// The caller is expected to pass dimensions consistent with `a.len()`;
/// NOTE(review): this function does not check that itself.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let out_len = outer * inner;
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut reduced = alloc_zeros_f32(out_len, device)?;
        let launch = launch_cfg(out_len)?;
        let outer_arg = outer as u32;
        let axis_arg = axis_size as u32;
        let inner_arg = inner as u32;
        let len_arg = out_len as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(a.inner())
                .arg(reduced.inner_mut())
                .arg(&outer_arg)
                .arg(&axis_arg)
                .arg(&inner_arg)
                .arg(&len_arg)
                .launch(launch)?;
        }

        return Ok(reduced);
    }

    // Host path: walk the output in row-major (outer, inner) order.
    let host = gpu_to_cpu(a, device)?;
    let mut sums: Vec<f32> = Vec::with_capacity(out_len);
    for o in 0..outer {
        for j in 0..inner {
            let mut acc = 0.0f32;
            for k in 0..axis_size {
                acc += host[o * axis_size * inner + k * inner + j];
            }
            sums.push(acc);
        }
    }
    cpu_to_gpu(&sums, device)
}
11309
#[cfg(feature = "cuda")]
/// Inclusive cumulative sum along the middle axis of a logical
/// `[outer, dim_size, inner]` layout.
///
/// `cumsum_kernel` is launched with a configuration sized for `outer * inner`
/// work items. If the kernel cannot be obtained from the module cache, the
/// scan is performed on the host and uploaded.
///
/// The caller is expected to pass dimensions consistent with `input.len()`;
/// NOTE(review): this function does not check that itself.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_cumsum(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let lanes = outer * inner;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        CUMSUM_PTX,
        "cumsum_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut scanned = alloc_zeros_f32(total, device)?;
        let launch = launch_cfg(lanes)?;
        let outer_arg = outer as u32;
        let dim_arg = dim_size as u32;
        let inner_arg = inner as u32;
        let total_arg = total as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(scanned.inner_mut())
                .arg(&outer_arg)
                .arg(&dim_arg)
                .arg(&inner_arg)
                .arg(&total_arg)
                .launch(launch)?;
        }

        return Ok(scanned);
    }

    // Host path: one independent scan per (outer, inner) lane.
    let host = gpu_to_cpu(input, device)?;
    let mut scanned = vec![0.0f32; total];
    for lane in 0..lanes {
        // `lanes > 0` implies `inner > 0`, so the division and stride are valid.
        let start = (lane / inner) * dim_size * inner + lane % inner;
        let mut running = 0.0f32;
        for step in 0..dim_size {
            let pos = start + step * inner;
            running += host[pos];
            scanned[pos] = running;
        }
    }
    cpu_to_gpu(&scanned, device)
}
11390
#[cfg(feature = "cuda")]
/// Inclusive cumulative product along the middle axis of a logical
/// `[outer, dim_size, inner]` layout.
///
/// `cumprod_kernel` is launched with a configuration sized for
/// `outer * inner` work items. If the kernel cannot be obtained from the
/// module cache, the scan is performed on the host and uploaded.
///
/// The caller is expected to pass dimensions consistent with `input.len()`;
/// NOTE(review): this function does not check that itself.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_cumprod(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let lanes = outer * inner;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        CUMPROD_PTX,
        "cumprod_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut scanned = alloc_zeros_f32(total, device)?;
        let launch = launch_cfg(lanes)?;
        let outer_arg = outer as u32;
        let dim_arg = dim_size as u32;
        let inner_arg = inner as u32;
        let total_arg = total as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(scanned.inner_mut())
                .arg(&outer_arg)
                .arg(&dim_arg)
                .arg(&inner_arg)
                .arg(&total_arg)
                .launch(launch)?;
        }

        return Ok(scanned);
    }

    // Host path: one independent scan per (outer, inner) pair.
    let host = gpu_to_cpu(input, device)?;
    let mut scanned = vec![0.0f32; total];
    for o in 0..outer {
        for j in 0..inner {
            let start = o * dim_size * inner + j;
            let mut running = 1.0f32;
            for step in 0..dim_size {
                let pos = start + step * inner;
                running *= host[pos];
                scanned[pos] = running;
            }
        }
    }
    cpu_to_gpu(&scanned, device)
}
11464
#[cfg(feature = "cuda")]
/// Running maximum along the middle axis of a logical
/// `[outer, dim_size, inner]` layout.
///
/// Returns `(values, indices)`. Both buffers hold `outer * dim_size * inner`
/// `f32` elements; `indices` stores, as a float, the position along the axis
/// at which the running maximum was attained.
///
/// `cummax_kernel` is launched with a configuration sized for `outer * inner`
/// work items. If the kernel cannot be obtained from the module cache, the
/// scan is performed on the host. On the host path a new element replaces the
/// running maximum only when it is strictly greater, so the earliest position
/// wins ties.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_cummax(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let lanes = outer * inner;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        CUMMAX_PTX,
        "cummax_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut values = alloc_zeros_f32(total, device)?;
        let mut positions = alloc_zeros_f32(total, device)?;
        let launch = launch_cfg(lanes)?;
        let outer_arg = outer as u32;
        let dim_arg = dim_size as u32;
        let inner_arg = inner as u32;
        let total_arg = total as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(values.inner_mut())
                .arg(positions.inner_mut())
                .arg(&outer_arg)
                .arg(&dim_arg)
                .arg(&inner_arg)
                .arg(&total_arg)
                .launch(launch)?;
        }

        return Ok((values, positions));
    }

    // Host path: one independent scan per (outer, inner) pair.
    let host = gpu_to_cpu(input, device)?;
    let mut values = vec![0.0f32; total];
    let mut positions = vec![0.0f32; total];
    for o in 0..outer {
        for j in 0..inner {
            let start = o * dim_size * inner + j;
            let mut running = f32::NEG_INFINITY;
            let mut arg = 0u32;
            for step in 0..dim_size {
                let pos = start + step * inner;
                let candidate = host[pos];
                if candidate > running {
                    running = candidate;
                    arg = step as u32;
                }
                values[pos] = running;
                positions[pos] = arg as f32;
            }
        }
    }
    let values_dev = cpu_to_gpu(&values, device)?;
    let positions_dev = cpu_to_gpu(&positions, device)?;
    Ok((values_dev, positions_dev))
}
11545
#[cfg(feature = "cuda")]
/// Running minimum along the middle axis of a logical
/// `[outer, dim_size, inner]` layout.
///
/// Returns `(values, indices)`. Both buffers hold `outer * dim_size * inner`
/// `f32` elements; `indices` stores, as a float, the position along the axis
/// at which the running minimum was attained.
///
/// `cummin_kernel` is launched with a configuration sized for `outer * inner`
/// work items. If the kernel cannot be obtained from the module cache, the
/// scan is performed on the host. On the host path a new element replaces the
/// running minimum only when it is strictly smaller, so the earliest position
/// wins ties.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_cummin(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let lanes = outer * inner;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        CUMMIN_PTX,
        "cummin_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut values = alloc_zeros_f32(total, device)?;
        let mut positions = alloc_zeros_f32(total, device)?;
        let launch = launch_cfg(lanes)?;
        let outer_arg = outer as u32;
        let dim_arg = dim_size as u32;
        let inner_arg = inner as u32;
        let total_arg = total as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(values.inner_mut())
                .arg(positions.inner_mut())
                .arg(&outer_arg)
                .arg(&dim_arg)
                .arg(&inner_arg)
                .arg(&total_arg)
                .launch(launch)?;
        }

        return Ok((values, positions));
    }

    // Host path: one independent scan per (outer, inner) lane.
    let host = gpu_to_cpu(input, device)?;
    let mut values = vec![0.0f32; total];
    let mut positions = vec![0.0f32; total];
    for lane in 0..lanes {
        // `lanes > 0` implies `inner > 0`, so the division is valid.
        let start = (lane / inner) * dim_size * inner + lane % inner;
        let mut running = f32::INFINITY;
        let mut arg = 0u32;
        for step in 0..dim_size {
            let pos = start + step * inner;
            let candidate = host[pos];
            if candidate < running {
                running = candidate;
                arg = step as u32;
            }
            values[pos] = running;
            positions[pos] = arg as f32;
        }
    }
    let values_dev = cpu_to_gpu(&values, device)?;
    let positions_dev = cpu_to_gpu(&positions, device)?;
    Ok((values_dev, positions_dev))
}
11626
#[cfg(feature = "cuda")]
/// Running `log(sum(exp(x)))` along the middle axis of a logical
/// `[outer, dim_size, inner]` layout.
///
/// `logcumsumexp_kernel` is launched with a configuration sized for
/// `outer * inner` work items. If the kernel cannot be obtained from the
/// module cache, the scan is performed on the host. There the accumulator
/// starts at negative infinity and each step combines it with the next element
/// after subtracting the larger of the two.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_logcumsumexp(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let lanes = outer * inner;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        LOGCUMSUMEXP_PTX,
        "logcumsumexp_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut scanned = alloc_zeros_f32(total, device)?;
        let launch = launch_cfg(lanes)?;
        let outer_arg = outer as u32;
        let dim_arg = dim_size as u32;
        let inner_arg = inner as u32;
        let total_arg = total as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(scanned.inner_mut())
                .arg(&outer_arg)
                .arg(&dim_arg)
                .arg(&inner_arg)
                .arg(&total_arg)
                .launch(launch)?;
        }

        return Ok(scanned);
    }

    // Host path: one independent scan per (outer, inner) lane.
    let host = gpu_to_cpu(input, device)?;
    let mut scanned = vec![0.0f32; total];
    for lane in 0..lanes {
        // `lanes > 0` implies `inner > 0`, so the division is valid.
        let start = (lane / inner) * dim_size * inner + lane % inner;
        let mut running = f32::NEG_INFINITY;
        for step in 0..dim_size {
            let pos = start + step * inner;
            let x = host[pos];
            let pivot = running.max(x);
            running = pivot + ((running - pivot).exp() + (x - pivot).exp()).ln();
            scanned[pos] = running;
        }
    }
    cpu_to_gpu(&scanned, device)
}
11702
#[cfg(feature = "cuda")]
/// Extracts one slice along an axis into a new contiguous buffer of `n`
/// elements.
///
/// Output element `i` is read from
/// `(i / chunk) * total_along_axis * inner_size + split_offset * inner_size + i % chunk`,
/// where `chunk = split_size * inner_size`.
///
/// * `total_along_axis` - extent of the split axis in `input`.
/// * `split_offset` - first index along that axis belonging to the slice.
/// * `split_size` - number of indices along that axis in the slice.
/// * `inner_size` - number of elements per index of the split axis.
/// * `n` - number of output elements.
///
/// `strided_split_kernel` is launched with a configuration sized for `n`. If
/// the kernel cannot be obtained from the module cache, the gather is done on
/// the host and uploaded.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: gather on the host.
            let host = gpu_to_cpu(input, device)?;
            // Elements per outer index in the output. It is only divided by
            // inside the loop, so an empty output never divides by zero.
            let chunk = split_size * inner_size;
            let mut result = vec![0.0f32; n];
            for (i, slot) in result.iter_mut().enumerate() {
                let outer_idx = i / chunk;
                let within = i % chunk;
                let src_idx =
                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
                *slot = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: relies on the argument order and types below matching the
    // kernel's parameter list (PTX defined elsewhere in this file) and on
    // `input` covering every source index. NOTE(review): not re-checked here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11784
#[cfg(feature = "cuda")]
// The argument list is fixed by existing callers, so the lint is silenced
// rather than bundling the parameters.
#[allow(clippy::too_many_arguments)]
/// Writes `input` into a slice of `output` along an axis.
///
/// Input element `i` (for `i < n`) is stored at
/// `(i / chunk) * total_along_axis * inner_size + cat_offset * inner_size + i % chunk`
/// in `output`, where `chunk = part_size * inner_size`.
///
/// * `total_along_axis` - extent of the concatenation axis in `output`.
/// * `cat_offset` - first index along that axis where this part is placed.
/// * `part_size` - extent of the axis in `input`.
/// * `inner_size` - number of elements per index of the axis.
/// * `n` - number of input elements to copy.
///
/// `strided_cat_kernel` is launched with a configuration sized for `n`. If the
/// kernel cannot be obtained from the module cache, both buffers are copied to
/// the host, merged there, and `output` is replaced by a fresh upload.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, host/device
/// transfers, and the kernel launch.
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let launch = launch_cfg(n)?;
        let axis_arg = total_along_axis as u32;
        let offset_arg = cat_offset as u32;
        let part_arg = part_size as u32;
        let inner_arg = inner_size as u32;
        let count_arg = n as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file) and on
        // `output` covering every destination index. NOTE(review): not
        // re-checked here.
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(output.inner_mut())
                .arg(&axis_arg)
                .arg(&offset_arg)
                .arg(&part_arg)
                .arg(&inner_arg)
                .arg(&count_arg)
                .launch(launch)?;
        }

        return Ok(());
    }

    // Host path: scatter into a host copy of `output`, then re-upload it.
    let src = gpu_to_cpu(input, device)?;
    let mut dst = gpu_to_cpu(output, device)?;
    let count = n.min(src.len());
    for i in 0..count {
        let outer_idx = i / (part_size * inner_size);
        let within = i % (part_size * inner_size);
        let dst_idx = outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
        dst[dst_idx] = src[i];
    }
    *output = cpu_to_gpu(&dst, device)?;
    Ok(())
}
11872
/// Maximum tensor rank accepted by the strided-copy routines.
///
/// Shape and stride arrays passed to the strided-copy kernel are padded to
/// exactly this many entries.
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
11880
#[cfg(feature = "cuda")]
/// Builds the fixed-size stride arrays used by the strided-copy routines.
///
/// Returns `(out_stride, src_stride)`, both of length
/// [`STRIDED_COPY_MAX_DIMS`]:
///
/// * `out_stride[d]` for `d < rank` is the contiguous row-major stride of
///   dimension `d` of `out_shape`. Entries past `rank` are filled with
///   `n + 1` (saturating, never zero), presumably so that dividing any flat
///   index below `n` by them yields coordinate 0.
/// * `src_stride[d]` for `d < rank` is `src_strides[d]` as `u32`. Entries past
///   `rank` are zero.
///
/// # Errors
/// Returns `GpuError::ShapeMismatch` when the two slices differ in length, the
/// rank exceeds [`STRIDED_COPY_MAX_DIMS`], any source stride is negative, or
/// any stride does not fit in a `u32`.
fn pad_strided_copy_params(
    out_shape: &[usize],
    src_strides: &[isize],
    n: usize,
) -> GpuResult<([u32; STRIDED_COPY_MAX_DIMS], [u32; STRIDED_COPY_MAX_DIMS])> {
    let rank = out_shape.len();

    if rank != src_strides.len() {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![rank],
            got: vec![src_strides.len()],
        });
    }
    if rank > STRIDED_COPY_MAX_DIMS {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![STRIDED_COPY_MAX_DIMS],
            got: vec![rank],
        });
    }
    // Negative strides cannot be represented in the u32 kernel arguments.
    if let Some(&negative) = src_strides.iter().find(|&&s| s < 0) {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad_negative_stride",
            expected: vec![0],
            got: vec![negative.unsigned_abs()],
        });
    }

    // Start with every slot holding the padding value; the first `rank` slots
    // are overwritten below.
    let pad_val = (n as u32).saturating_add(1).max(1);
    let mut out_stride = [pad_val; STRIDED_COPY_MAX_DIMS];

    // Row-major strides, innermost dimension first.
    let mut running: usize = 1;
    for (slot, &extent) in out_stride[..rank].iter_mut().zip(out_shape.iter()).rev() {
        if running > u32::MAX as usize {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_stride_overflow",
                expected: vec![u32::MAX as usize],
                got: vec![running],
            });
        }
        *slot = running as u32;
        running = running.saturating_mul(extent);
    }

    let mut src_stride_out = [0u32; STRIDED_COPY_MAX_DIMS];
    for (slot, &stride) in src_stride_out.iter_mut().zip(src_strides.iter()) {
        if stride as usize > u32::MAX as usize {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_src_stride_overflow",
                expected: vec![u32::MAX as usize],
                got: vec![stride as usize],
            });
        }
        *slot = stride as u32;
    }

    Ok((out_stride, src_stride_out))
}
11966
#[cfg(feature = "cuda")]
/// Gathers a strided view of `input` into a new contiguous `f32` buffer.
///
/// The output has `out_shape.iter().product()` elements. Each flat output
/// index is decomposed into per-dimension coordinates using the padded output
/// strides, and the source element is read at
/// `src_offset + sum(coord[d] * src_strides[d])`.
///
/// An empty output returns an empty buffer without launching anything.
/// Otherwise `strided_copy_kernel` is launched with a configuration sized for
/// the output. If the kernel cannot be obtained from the module cache, the
/// gather is done on the host using the same `u32` index arithmetic.
///
/// # Errors
/// Propagates errors from `validate_unary`, `pad_strided_copy_params`,
/// `launch_cfg`, buffer allocation, host/device transfers, and the kernel
/// launch.
pub fn gpu_strided_copy(
    input: &CudaBuffer<f32>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        return alloc_zeros_f32(0, device);
    }

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut gathered = alloc_zeros_f32(n, device)?;
        let launch = launch_cfg(n)?;
        let offset_arg = src_offset as u32;
        let count_arg = n as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file) and on
        // every computed source index lying inside `input`. NOTE(review): not
        // re-checked here.
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(gathered.inner_mut())
                .arg(&offset_arg)
                .arg(&count_arg)
                .arg(&out_stride[0])
                .arg(&out_stride[1])
                .arg(&out_stride[2])
                .arg(&out_stride[3])
                .arg(&out_stride[4])
                .arg(&out_stride[5])
                .arg(&out_stride[6])
                .arg(&out_stride[7])
                .arg(&src_stride[0])
                .arg(&src_stride[1])
                .arg(&src_stride[2])
                .arg(&src_stride[3])
                .arg(&src_stride[4])
                .arg(&src_stride[5])
                .arg(&src_stride[6])
                .arg(&src_stride[7])
                .launch(launch)?;
        }

        return Ok(gathered);
    }

    // Host path: decompose each flat output index into coordinates.
    let host = gpu_to_cpu(input, device)?;
    let mut gathered: Vec<f32> = Vec::with_capacity(n);
    for i in 0..n {
        let mut remainder = i as u32;
        let mut src_idx = src_offset as u32;
        for (&os, &ss) in out_stride.iter().zip(src_stride.iter()) {
            let coord = remainder / os;
            remainder -= coord * os;
            src_idx += coord * ss;
        }
        gathered.push(host[src_idx as usize]);
    }
    cpu_to_gpu(&gathered, device)
}
12073
#[cfg(feature = "cuda")]
/// `f64` counterpart of the strided gather: copies a strided view of `input`
/// into a new contiguous `f64` buffer.
///
/// The output has `out_shape.iter().product()` elements. Each flat output
/// index is decomposed into per-dimension coordinates using the padded output
/// strides, and the source element is read at
/// `src_offset + sum(coord[d] * src_strides[d])`.
///
/// The kernel source is obtained from `get_f64_ptx`, which is given the `f32`
/// strided-copy PTX and both kernel names, and is cached in a function-local
/// `OnceLock`. If `strided_copy_f64_kernel` cannot be obtained from the module
/// cache, the gather is done on the host using the same `u32` index
/// arithmetic.
///
/// # Errors
/// Propagates errors from `validate_device`, `pad_strided_copy_params`,
/// `launch_cfg`, buffer allocation, host/device transfers, and the kernel
/// launch.
pub fn gpu_strided_copy_f64(
    input: &CudaBuffer<f64>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        return alloc_zeros_f64(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        "strided_copy_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_copy_f64_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut gathered = alloc_zeros_f64(n, device)?;
        let launch = launch_cfg(n)?;
        let offset_arg = src_offset as u32;
        let count_arg = n as u32;

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (derived from PTX defined elsewhere in this
        // file) and on every computed source index lying inside `input`.
        // NOTE(review): not re-checked here.
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(gathered.inner_mut())
                .arg(&offset_arg)
                .arg(&count_arg)
                .arg(&out_stride[0])
                .arg(&out_stride[1])
                .arg(&out_stride[2])
                .arg(&out_stride[3])
                .arg(&out_stride[4])
                .arg(&out_stride[5])
                .arg(&out_stride[6])
                .arg(&out_stride[7])
                .arg(&src_stride[0])
                .arg(&src_stride[1])
                .arg(&src_stride[2])
                .arg(&src_stride[3])
                .arg(&src_stride[4])
                .arg(&src_stride[5])
                .arg(&src_stride[6])
                .arg(&src_stride[7])
                .launch(launch)?;
        }

        return Ok(gathered);
    }

    // Host path: decompose each flat output index into coordinates.
    let host = gpu_to_cpu(input, device)?;
    let mut gathered: Vec<f64> = Vec::with_capacity(n);
    for i in 0..n {
        let mut remainder = i as u32;
        let mut src_idx = src_offset as u32;
        for (&os, &ss) in out_stride.iter().zip(src_stride.iter()) {
            let coord = remainder / os;
            remainder -= coord * os;
            src_idx += coord * ss;
        }
        gathered.push(host[src_idx as usize]);
    }
    cpu_to_gpu(&gathered, device)
}
12163
#[cfg(feature = "cuda")]
/// Multiplies every element of `a` by `scalar`, returning a new buffer.
///
/// `scale_kernel` is launched with a configuration sized for `a.len()`. If the
/// kernel cannot be obtained from the module cache, the multiplication is done
/// on the host and the result uploaded.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut scaled = alloc_zeros_f32(len, device)?;
        let launch = launch_cfg(len)?;
        let len_arg = len as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(a.inner())
                .arg(scaled.inner_mut())
                .arg(&scalar)
                .arg(&len_arg)
                .launch(launch)?;
        }

        return Ok(scaled);
    }

    // Host path.
    let host = gpu_to_cpu(a, device)?;
    let mut scaled: Vec<f32> = Vec::with_capacity(host.len());
    for &x in host.iter() {
        scaled.push(x * scalar);
    }
    cpu_to_gpu(&scaled, device)
}
12217
#[cfg(feature = "cuda")]
/// Row-wise softmax over a `rows x cols` row-major matrix.
///
/// For every row the result is `exp(x - max) / sum(exp(x - max))`, where `max`
/// is the row maximum.
///
/// `softmax_kernel` is launched on a grid of `rows` blocks (at least one) of
/// 256 threads with 1 KiB of dynamic shared memory, writing into a fresh
/// buffer of `rows * cols` elements. If the kernel cannot be obtained from the
/// module cache, the same formula is evaluated on the host into a buffer of
/// `input.len()` elements.
///
/// The caller is expected to pass `rows * cols` consistent with `input.len()`;
/// NOTE(review): this function does not check that itself.
///
/// # Errors
/// Propagates errors from `validate_unary`, buffer allocation, host/device
/// transfers, and the kernel launch.
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut probs = alloc_zeros_f32(rows * cols, device)?;
        let rows_arg = rows as u32;
        let cols_arg = cols as u32;
        let launch = LaunchConfig {
            // `.max(1)` keeps the grid non-empty when `rows` is zero.
            grid_dim: ((rows as u32).max(1), 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 256 * 4,
        };
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file) and on
        // `input` covering `rows * cols` elements. NOTE(review): not
        // re-checked here.
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(probs.inner_mut())
                .arg(&rows_arg)
                .arg(&cols_arg)
                .launch(launch)?;
        }

        return Ok(probs);
    }

    // Host path.
    let host = gpu_to_cpu(input, device)?;
    let mut probs = vec![0.0f32; host.len()];
    for r in 0..rows {
        let start = r * cols;

        // Subtract the row maximum before exponentiating.
        let peak = (0..cols).fold(f32::NEG_INFINITY, |m, c| m.max(host[start + c]));

        let mut denom = 0.0f32;
        for c in 0..cols {
            let e = (host[start + c] - peak).exp();
            probs[start + c] = e;
            denom += e;
        }

        let inv = 1.0 / denom;
        for c in 0..cols {
            probs[start + c] *= inv;
        }
    }
    cpu_to_gpu(&probs, device)
}
12294
#[cfg(feature = "cuda")]
/// Applies dropout to `input`, returning a new buffer.
///
/// * `threshold` - an element is zeroed when its per-element random `u32` is
///   below this value.
/// * `scale` - factor applied to every element that is kept.
/// * `seed` - mixed into the per-element random value.
///
/// `dropout_kernel` is launched with a configuration sized for `input.len()`.
/// If the kernel cannot be obtained from the module cache, the mask is
/// generated on the host: the element index is multiplied by `2654435761`
/// (wrapping), XOR-ed with `seed`, and passed through three shift-XOR steps
/// (left 13, right 17, left 5). NOTE(review): presumably this mirrors the
/// kernel's generator; confirm against the PTX.
///
/// # Errors
/// Propagates errors from `validate_unary`, `launch_cfg`, buffer allocation,
/// host/device transfers, and the kernel launch.
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let compiled = crate::module_cache::get_or_compile(
        device.context(),
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    );

    if let Ok(kernel) = compiled {
        let mut dropped = alloc_zeros_f32(len, device)?;
        let launch = launch_cfg(len)?;
        let len_arg = len as u32;
        let stream = device.stream();

        // SAFETY: relies on the argument order and types below matching the
        // kernel's parameter list (PTX defined elsewhere in this file).
        unsafe {
            stream
                .launch_builder(&kernel)
                .arg(input.inner())
                .arg(dropped.inner_mut())
                .arg(&len_arg)
                .arg(&threshold)
                .arg(&scale)
                .arg(&seed)
                .launch(launch)?;
        }

        return Ok(dropped);
    }

    // Host path: derive one pseudo-random word per element index.
    let host = gpu_to_cpu(input, device)?;
    let mut dropped: Vec<f32> = Vec::with_capacity(host.len());
    for (i, &x) in host.iter().enumerate() {
        let mut bits = (i as u32).wrapping_mul(2654435761) ^ seed;
        bits ^= bits << 13;
        bits ^= bits >> 17;
        bits ^= bits << 5;
        dropped.push(if bits < threshold { 0.0 } else { x * scale });
    }
    cpu_to_gpu(&dropped, device)
}
12375
#[cfg(feature = "cuda")]
/// Applies dropout to an `f64` buffer and returns the result in a new buffer.
///
/// The PTX is obtained from `get_f64_ptx` applied to `DROPOUT_PTX` (kernel name
/// `dropout_kernel` mapped to `dropout_f64_kernel`) and memoised in a function-local
/// `OnceLock`. When the kernel cannot be obtained, the host computes the result: element `i`
/// is hashed from `i` and `seed` (multiply by 2654435761, XOR with `seed`, xorshift
/// 13 / 17 / 5) and becomes `0.0` when the hash is below `threshold`, `x * scale` otherwise.
///
/// # Errors
/// Propagates failures from host/device transfers, allocation, `launch_cfg`, and the launch.
pub fn gpu_dropout_f64(
    input: &CudaBuffer<f64>,
    threshold: u32,
    scale: f64,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, DROPOUT_PTX, "dropout_kernel", "dropout_f64_kernel");
    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "dropout_f64_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback: per-element hash decides whether the value is dropped.
        let host = gpu_to_cpu(input, device)?;
        let mut kept: Vec<f64> = Vec::with_capacity(host.len());
        for (i, &x) in host.iter().enumerate() {
            let mut bits = (i as u32).wrapping_mul(2654435761) ^ seed;
            bits ^= bits << 13;
            bits ^= bits >> 17;
            bits ^= bits << 5;
            kept.push(if bits < threshold { 0.0 } else { x * scale });
        }
        return cpu_to_gpu(&kept, device);
    };

    let mut out = alloc_zeros_f64(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of
    // `dropout_f64_kernel` (NOTE(review): confirm against the generated PTX); `out` was
    // allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
12432
12433#[cfg(not(feature = "cuda"))]
12434pub fn gpu_dropout_f64(_input: &CudaBuffer<f64>, _threshold: u32, _scale: f64, _seed: u32, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
12435
#[cfg(feature = "cuda")]
/// Transposes a row-major `m x n` matrix into a row-major `n x m` matrix.
///
/// Launches `transpose_2d_kernel` from `TRANSPOSE_2D_PTX` with `(input, out, m, n, m * n)`.
/// When the kernel cannot be obtained from the module cache, the transpose is done on the
/// host with `out[j * m + i] = input[i * n + j]`.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `input` holds fewer than `m * n` elements (`a` is the
///   actual length, `b` the required one). Without this check the kernel would read past the
///   end of the device buffer and the host fallback would panic on indexing.
/// * Any failure from `validate_unary`, transfers, allocation, `launch_cfg`, or the launch.
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    // Saturating so that an overflowing shape is rejected by the length check below instead
    // of wrapping to a small element count.
    let total = m.saturating_mul(n);
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i in 0..m {
                for j in 0..n {
                    out[j * m + i] = host[i * n + j];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: `input` holds at least `total` elements (checked above) and `out` was allocated
    // with exactly `total`; the argument list is assumed to match `transpose_2d_kernel`
    // (NOTE(review): confirm against `TRANSPOSE_2D_PTX`).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12495
#[cfg(feature = "cuda")]
/// Permutes a row-major `[d0, d1, d2, d3]` tensor to `[d0, d2, d1, d3]` (axes 1 and 2 swapped).
///
/// Launches `permute_0213_kernel` from `PERMUTE_0213_PTX` with
/// `(input, out, d0, d1, d2, d3, total)`. When the kernel cannot be obtained, the permutation
/// is done on the host: the element at input offset `((i0 * d1 + i1) * d2 + i2) * d3 + i3` is
/// stored at output offset `((i0 * d2 + i2) * d1 + i1) * d3 + i3`.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `input` holds fewer than `d0 * d1 * d2 * d3` elements
///   (`a` is the actual length, `b` the required one). Without this check the kernel would
///   read past the end of the device buffer and the host fallback would panic on indexing.
/// * Any failure from `validate_unary`, transfers, allocation, `launch_cfg`, or the launch.
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    // Saturating so that an overflowing shape is rejected by the length check below instead
    // of wrapping to a small element count.
    let total = d0
        .saturating_mul(d1)
        .saturating_mul(d2)
        .saturating_mul(d3);
    if input.len() < total {
        return Err(GpuError::LengthMismatch { a: input.len(), b: total });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for i0 in 0..d0 {
                for i1 in 0..d1 {
                    for i2 in 0..d2 {
                        for i3 in 0..d3 {
                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                            out[out_idx] = host[in_idx];
                        }
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    // SAFETY: `input` holds at least `total` elements (checked above) and `out` was allocated
    // with exactly `total`; the argument list is assumed to match `permute_0213_kernel`
    // (NOTE(review): confirm against `PERMUTE_0213_PTX`).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12568
#[cfg(feature = "cuda")]
/// Multiplies `a` by `b` with the custom `small_matmul_kernel`, producing `m * n` outputs.
///
/// The kernel from `SMALL_MATMUL_PTX` is launched with `(a, b, c, m, k, n, m * n)`; `a` is
/// presumably a row-major `m x k` matrix and `b` a row-major `k x n` matrix (NOTE(review):
/// the layout is defined by the PTX, not visible here). When the kernel cannot be obtained
/// from the module cache the call is delegated to `crate::blas::gpu_matmul_f32`.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `a` holds fewer than `m * k` elements or `b` fewer than
///   `k * n` (`a` field = actual length, `b` field = required length). The raw kernel launch
///   performs no bounds checks of its own, so short buffers would otherwise be read out of
///   bounds on the device.
/// * Any failure from allocation, `launch_cfg`, the launch, or the BLAS fallback.
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Saturating products: an overflowing shape can never be satisfied by a real buffer, so
    // it is rejected by the length checks instead of wrapping.
    let a_needed = m.saturating_mul(k);
    if a.len() < a_needed {
        return Err(GpuError::LengthMismatch { a: a.len(), b: a_needed });
    }
    let b_needed = k.saturating_mul(n);
    if b.len() < b_needed {
        return Err(GpuError::LengthMismatch { a: b.len(), b: b_needed });
    }

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // No custom kernel available: hand the product to the BLAS-backed path.
            return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
        }
    };

    let mut c = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: `a` and `b` hold at least `m * k` and `k * n` elements (checked above) and `c`
    // was allocated with `m * n`; the argument list is assumed to match
    // `small_matmul_kernel` (NOTE(review): confirm against `SMALL_MATMUL_PTX`).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
12627
#[cfg(feature = "cuda")]
/// Batched matrix multiply that routes the single-batch case to `gpu_small_matmul`.
///
/// With `batch == 1` the call is forwarded to `gpu_small_matmul(a, b, m, k, n, device)`;
/// every other batch count is forwarded unchanged to `crate::blas::gpu_bmm_f32`.
///
/// # Errors
/// Returns whatever error the selected implementation reports.
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        1 => gpu_small_matmul(a, b, m, k, n, device),
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
12654
#[cfg(feature = "cuda")]
/// Copies one `d`-element row of `weight` into a new buffer, selected by the first value of `idx`.
///
/// Launches `embed_lookup_kernel` from `EMBED_LOOKUP_PTX` with `(idx, weight, out, d)`. When
/// the kernel cannot be obtained, the lookup runs on the host: `idx[0]` is cast to `usize`
/// to get the row, and `weight[row * d .. row * d + d]` is copied out.
///
/// The row index lives on the device, so the kernel path cannot range-check it without a
/// readback; only the host fallback validates the selected row.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `idx` is empty, when `weight` holds fewer than `d`
///   elements, or (host fallback only) when the selected row lies outside `weight`. The
///   `a` field carries the actual length and `b` the required one. Previously the first two
///   cases read out of bounds on the device and all three panicked in the fallback.
/// * Any failure from transfers, allocation, `launch_cfg`, or the launch.
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Both paths read `idx[0]` and at least one `d`-wide row of `weight`.
    if idx.len() == 0 {
        return Err(GpuError::LengthMismatch { a: 0, b: 1 });
    }
    if weight.len() < d {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: d });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback with checked indexing instead of panicking slices.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let first = idx_host
                .first()
                .copied()
                .ok_or(GpuError::LengthMismatch { a: 0, b: 1 })?;
            let row = first as usize;
            let start = row.saturating_mul(d);
            let end = start.saturating_add(d);
            let out = weight_host
                .get(start..end)
                .ok_or(GpuError::LengthMismatch { a: weight_host.len(), b: end })?
                .to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // SAFETY: `idx` is non-empty, `weight` holds at least `d` elements and `out` was
    // allocated with `d`; the row selected on the device is NOT range-checked here, and the
    // argument list is assumed to match `embed_lookup_kernel` (NOTE(review): confirm against
    // `EMBED_LOOKUP_PTX`).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12708
#[cfg(feature = "cuda")]
/// Writes `src` (`n_batch` rows of `d` values) into row `pos` of every batch of `dst`.
///
/// `dst` is addressed as `[n_batch, max_len, d]`: the value at `src[b * d + di]` is stored
/// at `dst[b * max_len * d + pos * d + di]`. Launches `slice_write_kernel` from
/// `SLICE_WRITE_PTX` with `(src, dst, n_batch * d, d, max_len, pos)`; when the kernel cannot
/// be obtained, `dst` is copied to the host, patched there, and replaced by a freshly
/// uploaded buffer.
///
/// Note: `pos >= max_len` is not rejected as such; it is only refused when the resulting
/// write would fall outside `dst`.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `src` holds fewer than `n_batch * d` elements or when
///   `dst` is too short to contain the highest index written (`a` = actual length,
///   `b` = required length). Previously this wrote out of bounds on the device or panicked
///   in the host fallback.
/// * Any failure from transfers, `launch_cfg`, or the launch.
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    // Saturating arithmetic: an overflowing shape can never be satisfied by a real buffer,
    // so it is rejected by the length checks instead of wrapping.
    let total = n_batch.saturating_mul(d);
    if src.len() < total {
        return Err(GpuError::LengthMismatch { a: src.len(), b: total });
    }
    if total > 0 {
        // One past the highest destination index: last batch, row `pos`, last column.
        let dst_needed = (n_batch - 1)
            .saturating_mul(max_len)
            .saturating_mul(d)
            .saturating_add(pos.saturating_add(1).saturating_mul(d));
        if dst.len() < dst_needed {
            return Err(GpuError::LengthMismatch { a: dst.len(), b: dst_needed });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: round-trip `dst`, patch the row, and swap in the new buffer.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // SAFETY: `src` holds at least `total` elements and `dst` covers the highest index
    // written (both checked above); the argument list is assumed to match
    // `slice_write_kernel` (NOTE(review): confirm against `SLICE_WRITE_PTX`).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
12773
#[cfg(feature = "cuda")]
/// Copies the first `len` rows of every batch of `src` into a compact new buffer.
///
/// `src` is addressed as `[n_batch, max_len, d]` and the result as `[n_batch, len, d]`:
/// `out[b * len * d + r * d + di] = src[b * max_len * d + r * d + di]`. Launches
/// `slice_read_kernel` from `SLICE_READ_PTX` with `(src, out, total, d, len, max_len)`; when
/// the kernel cannot be obtained, the copy is done on the host.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `src` is too short to contain the highest index read
///   (`a` = actual length, `b` = required length). Previously this read out of bounds on the
///   device or panicked in the host fallback.
/// * Any failure from transfers, allocation, `launch_cfg`, or the launch.
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    // Saturating arithmetic: an overflowing shape can never be satisfied by a real buffer,
    // so it is rejected by the length check instead of wrapping.
    let total = n_batch.saturating_mul(len).saturating_mul(d);
    if total > 0 {
        // One past the highest source index: last batch, row `len - 1`, last column.
        let src_needed = (n_batch - 1)
            .saturating_mul(max_len)
            .saturating_mul(d)
            .saturating_add(len.saturating_mul(d));
        if src.len() < src_needed {
            return Err(GpuError::LengthMismatch { a: src.len(), b: src_needed });
        }
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f32; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // SAFETY: `src` covers the highest index read (checked above) and `out` was allocated
    // with `total` elements; the argument list is assumed to match `slice_read_kernel`
    // (NOTE(review): confirm against `SLICE_READ_PTX`).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12836
#[cfg(feature = "cuda")]
/// Element-wise GELU using the sigmoid approximation `x * sigmoid(1.702 * x)`.
///
/// Tries `gelu_kernel` from `GELU_PTX` via `try_launch_unary`; when that yields `None`, the
/// formula above is evaluated on the host through `cpu_fallback_unary`.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(activated) => Ok(activated),
        None => cpu_fallback_unary(input, device, |x| {
            let gate = 1.0 / (1.0 + (-1.702 * x).exp());
            x * gate
        }),
    }
}
12853
#[cfg(feature = "cuda")]
/// Element-wise GELU using the tanh approximation.
///
/// Computes `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`. Tries
/// `gelu_tanh_kernel` from `GELU_TANH_PTX` first and evaluates the formula on the host when
/// `try_launch_unary` yields `None`.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    // sqrt(2 / pi) rounded to f32, and the cubic coefficient of the approximation.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const CUBIC_COEFF: f32 = 0.044715;

    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
        Some(activated) => Ok(activated),
        None => cpu_fallback_unary(input, device, |x| {
            let arg = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x);
            0.5 * x * (1.0 + arg.tanh())
        }),
    }
}
12871
#[cfg(feature = "cuda")]
/// Element-wise GELU in its erf form, `0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// Tries `gelu_erf_kernel` from `GELU_ERF_PTX` first. The host fallback approximates `erf`
/// with a five-term polynomial in `t = 1 / (1 + 0.3275911 * |z|)` multiplied by
/// `exp(-z^2)` (the coefficients match the Abramowitz–Stegun 7.1.26 approximation), and
/// restores the sign for negative `z`.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    // Polynomial coefficients of the erf approximation, lowest order first.
    const A1: f32 = 0.254_829_6;
    const A2: f32 = -0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = -1.453_152_1;
    const A5: f32 = 1.061_405_4;

    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
        Some(activated) => Ok(activated),
        None => cpu_fallback_unary(input, device, |x| {
            let z = x * std::f32::consts::FRAC_1_SQRT_2;
            let magnitude = z.abs();
            let t = 1.0 / (1.0 + 0.3275911 * magnitude);
            // Horner evaluation, innermost term first.
            let mut horner = t * A5;
            horner = t * (A4 + horner);
            horner = t * (A3 + horner);
            horner = t * (A2 + horner);
            let poly = t * (A1 + horner);
            let erf_abs = 1.0 - poly * (-magnitude * magnitude).exp();
            let erf_val = if z < 0.0 { -erf_abs } else { erf_abs };
            x * 0.5 * (1.0 + erf_val)
        }),
    }
}
12893
#[cfg(feature = "cuda")]
/// Backward pass of the tanh-approximated GELU: `grad * d/dx gelu_tanh(x)`.
///
/// Tries `gelu_backward_tanh_kernel` from `GELU_BACKWARD_TANH_PTX` via `try_launch_binary`.
/// The host fallback evaluates, with `u = sqrt(2/pi) * (x + 0.044715 * x^3)` and
/// `t = tanh(u)`: `g * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * sqrt(2/pi) * (1 + 0.134145 * x^2))`.
///
/// # Errors
/// Propagates failures from `validate_binary`, the launch attempt, and host transfers.
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    /// Per-element gradient used by the host fallback.
    fn grad_at(g: f32, x: f32) -> f32 {
        let sqrt_2_over_pi: f32 = 0.797_884_6;
        let cubic: f32 = 0.044715;
        // 3 * 0.044715, from differentiating the cubic term.
        let cubic_times_3: f32 = 0.134145;
        let u = sqrt_2_over_pi * (x + cubic * x * x * x);
        let t = u.tanh();
        let sech_sq = 1.0 - t * t;
        let du_dx = sqrt_2_over_pi * (1.0 + cubic_times_3 * x * x);
        g * (0.5 * (1.0 + t) + 0.5 * x * sech_sq * du_dx)
    }

    validate_binary(grad, input, device)?;
    match try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )? {
        Some(grad_in) => Ok(grad_in),
        None => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let input_host = gpu_to_cpu(input, device)?;
            let grad_in: Vec<f32> = grad_host
                .iter()
                .zip(input_host.iter())
                .map(|(&g, &x)| grad_at(g, x))
                .collect();
            cpu_to_gpu(&grad_in, device)
        }
    }
}
12932
#[cfg(feature = "cuda")]
/// Element-wise SiLU, `x * sigmoid(x)`.
///
/// Tries `silu_kernel` from `SILU_PTX` via `try_launch_unary` and evaluates the formula on
/// the host through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, SILU_PTX, "silu_kernel")? {
        Some(activated) => Ok(activated),
        None => cpu_fallback_unary(input, device, |x| {
            let sigmoid = 1.0 / (1.0 + (-x).exp());
            x * sigmoid
        }),
    }
}
12949
#[cfg(feature = "cuda")]
/// Backward pass of SiLU: `grad * (s + x * s * (1 - s))` with `s = sigmoid(x)`.
///
/// Tries `silu_backward_kernel` from `SILU_BACKWARD_PTX` via `try_launch_binary`; when that
/// yields `None`, both buffers are copied to the host and the formula is evaluated there.
///
/// # Errors
/// Propagates failures from `validate_binary`, the launch attempt, and host transfers.
pub fn gpu_silu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    match try_launch_binary(
        grad,
        input,
        device,
        SILU_BACKWARD_PTX,
        "silu_backward_kernel",
    )? {
        Some(grad_in) => Ok(grad_in),
        None => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let input_host = gpu_to_cpu(input, device)?;
            let mut grad_in: Vec<f32> = Vec::new();
            for (&g, &x) in grad_host.iter().zip(input_host.iter()) {
                let s = 1.0 / (1.0 + (-x).exp());
                grad_in.push(g * (s + x * s * (1.0 - s)));
            }
            cpu_to_gpu(&grad_in, device)
        }
    }
}
12983
#[cfg(feature = "cuda")]
/// Element-wise ELU: `x` for `x > 0`, otherwise `alpha * (exp(x) - 1)`.
///
/// Launches `elu_kernel` from `ELU_PTX` with `(input, out, n, alpha)`. When the kernel cannot
/// be obtained from the module cache, the formula is evaluated on the host.
///
/// # Errors
/// Propagates failures from `validate_unary`, transfers, allocation, `launch_cfg`, and the
/// launch.
pub fn gpu_elu(
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ELU_PTX,
        "elu_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(input, device)?;
        let mut activated: Vec<f32> = Vec::with_capacity(host.len());
        for &x in host.iter() {
            activated.push(if x > 0.0 { x } else { alpha * (x.exp() - 1.0) });
        }
        return cpu_to_gpu(&activated, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of `elu_kernel`
    // (NOTE(review): confirm against `ELU_PTX`); `out` was allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(out)
}
13038
#[cfg(feature = "cuda")]
/// Backward pass of ELU: `grad` for `x > 0`, otherwise `grad * alpha * exp(x)`.
///
/// Launches `elu_backward_kernel` from `ELU_BACKWARD_PTX` with `(grad, input, out, n, alpha)`
/// where `n` is `grad.len()`. When the kernel cannot be obtained from the module cache, the
/// formula is evaluated on the host.
///
/// # Errors
/// Propagates failures from `validate_binary`, transfers, allocation, `launch_cfg`, and the
/// launch.
pub fn gpu_elu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, input, device)?;

    let len = grad.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_PTX,
        "elu_backward_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let grad_host = gpu_to_cpu(grad, device)?;
        let input_host = gpu_to_cpu(input, device)?;
        let mut grad_in: Vec<f32> = Vec::new();
        for (&g, &x) in grad_host.iter().zip(input_host.iter()) {
            grad_in.push(if x > 0.0 { g } else { g * alpha * x.exp() });
        }
        return cpu_to_gpu(&grad_in, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of
    // `elu_backward_kernel` (NOTE(review): confirm against `ELU_BACKWARD_PTX`); `out` was
    // allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad.inner())
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(out)
}
13093
#[cfg(feature = "cuda")]
/// Element-wise Mish, `x * tanh(softplus(x))`.
///
/// Tries `mish_kernel` from `MISH_PTX` via `try_launch_unary`. In the host fallback,
/// softplus is `ln(1 + exp(x))`, replaced by `x` itself once `x > 20`.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, MISH_PTX, "mish_kernel")? {
        Some(activated) => Ok(activated),
        None => cpu_fallback_unary(input, device, |x| {
            let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
            x * softplus.tanh()
        }),
    }
}
13110
#[cfg(feature = "cuda")]
/// Backward pass of Mish: `grad * (t + x * sigmoid(x) * (1 - t^2))` with `t = tanh(softplus(x))`.
///
/// Tries `mish_backward_kernel` from `MISH_BACKWARD_PTX` via `try_launch_binary`; when that
/// yields `None`, both buffers are copied to the host and the formula is evaluated there,
/// with softplus replaced by `x` once `x > 20`.
///
/// # Errors
/// Propagates failures from `validate_binary`, the launch attempt, and host transfers.
pub fn gpu_mish_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    /// Per-element gradient used by the host fallback.
    fn grad_at(g: f32, x: f32) -> f32 {
        let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
        let t = softplus.tanh();
        let sigmoid = 1.0 / (1.0 + (-x).exp());
        g * (t + x * sigmoid * (1.0 - t * t))
    }

    validate_binary(grad, input, device)?;

    match try_launch_binary(
        grad,
        input,
        device,
        MISH_BACKWARD_PTX,
        "mish_backward_kernel",
    )? {
        Some(grad_in) => Ok(grad_in),
        None => {
            let grad_host = gpu_to_cpu(grad, device)?;
            let input_host = gpu_to_cpu(input, device)?;
            let grad_in: Vec<f32> = grad_host
                .iter()
                .zip(input_host.iter())
                .map(|(&g, &x)| grad_at(g, x))
                .collect();
            cpu_to_gpu(&grad_in, device)
        }
    }
}
13147
#[cfg(feature = "cuda")]
/// Clamps every element of `input` to the range given by `min_val` and `max_val`.
///
/// Launches `clamp_kernel` from `CLAMP_PTX` with `(input, out, n, min_val, max_val)`. When the
/// kernel cannot be obtained from the module cache, the host computes
/// `x.max(min_val).min(max_val)` per element.
///
/// # Errors
/// Propagates failures from `validate_unary`, transfers, allocation, `launch_cfg`, and the
/// launch.
pub fn gpu_clamp(
    input: &CudaBuffer<f32>,
    min_val: f32,
    max_val: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        CLAMP_PTX,
        "clamp_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback. `max` then `min` is kept as-is: unlike `f32::clamp` it does not
        // panic when `min_val > max_val`.
        let host = gpu_to_cpu(input, device)?;
        let mut clamped: Vec<f32> = Vec::with_capacity(host.len());
        for &x in host.iter() {
            let raised = x.max(min_val);
            clamped.push(raised.min(max_val));
        }
        return cpu_to_gpu(&clamped, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of `clamp_kernel`
    // (NOTE(review): confirm against `CLAMP_PTX`); `out` was allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&min_val)
            .arg(&max_val)
            .launch(cfg)?;
    }

    Ok(out)
}
13200
#[cfg(feature = "cuda")]
/// Element-wise division `a / b`.
///
/// Tries `div_kernel` from `DIV_PTX` via `try_launch_binary`; when that yields `None`, both
/// buffers are copied to the host, divided pairwise, and the quotient is uploaded.
///
/// # Errors
/// Propagates failures from `validate_binary`, the launch attempt, and host transfers.
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        Some(quotient) => Ok(quotient),
        None => {
            let numerators = gpu_to_cpu(a, device)?;
            let denominators = gpu_to_cpu(b, device)?;
            let mut quotient: Vec<f32> = Vec::new();
            for (&num, &den) in numerators.iter().zip(denominators.iter()) {
                quotient.push(num / den);
            }
            cpu_to_gpu(&quotient, device)
        }
    }
}
13228
#[cfg(feature = "cuda")]
/// Element-wise natural exponential.
///
/// Tries `exp_kernel` from `EXP_PTX` via `try_launch_unary` and applies `f32::exp` on the
/// host through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, f32::exp),
    }
}
13238
#[cfg(feature = "cuda")]
/// Element-wise natural logarithm.
///
/// Tries `log_kernel` from `LOG_PTX` via `try_launch_unary` and applies `f32::ln` on the host
/// through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, f32::ln),
    }
}
13248
#[cfg(feature = "cuda")]
/// Element-wise square root.
///
/// Tries `sqrt_kernel` from `SQRT_PTX` via `try_launch_unary` and applies `f32::sqrt` on the
/// host through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, f32::sqrt),
    }
}
13258
#[cfg(feature = "cuda")]
/// Raises every element of `a` to the power `exponent`.
///
/// Launches `pow_kernel` from `POW_PTX` with `(a, out, exponent, n)`; note that the scalar
/// precedes the element count in this kernel's argument list. When the kernel cannot be
/// obtained from the module cache, `f32::powf` is applied on the host.
///
/// # Errors
/// Propagates failures from `validate_unary`, transfers, allocation, `launch_cfg`, and the
/// launch.
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(a, device)?;
        let mut powered: Vec<f32> = Vec::with_capacity(host.len());
        for &base in host.iter() {
            powered.push(base.powf(exponent));
        }
        return cpu_to_gpu(&powered, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of `pow_kernel`
    // (NOTE(review): confirm against `POW_PTX`); `out` was allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13304
#[cfg(feature = "cuda")]
/// Element-wise absolute value.
///
/// Tries `abs_kernel` from `ABS_PTX` via `try_launch_unary` and applies `f32::abs` on the
/// host through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, f32::abs),
    }
}
13314
#[cfg(feature = "cuda")]
/// Element-wise logistic sigmoid, `1 / (1 + exp(-x))`.
///
/// Tries `sigmoid_kernel` from `SIGMOID_PTX` via `try_launch_unary` and evaluates the formula
/// on the host through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, |value| {
            let denom = 1.0 + (-value).exp();
            1.0 / denom
        }),
    }
}
13324
#[cfg(feature = "cuda")]
/// Element-wise hyperbolic tangent.
///
/// Tries `tanh_kernel` from `TANH_PTX` via `try_launch_unary` and applies `f32::tanh` on the
/// host through `cpu_fallback_unary` when no kernel result is returned.
///
/// # Errors
/// Propagates failures from `validate_unary`, the launch attempt, and the fallback.
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary(a, device, f32::tanh),
    }
}
13334
#[cfg(feature = "cuda")]
/// Element-wise `a + b` for `f64` buffers of equal length.
///
/// The PTX is obtained from `get_f64_ptx` applied to `ADD_PTX` (kernel `add_kernel` mapped to
/// `add_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_binary_f64` yields `None`, the sum is computed through
/// `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when the lengths differ; otherwise any failure from the launch
/// attempt or the fallback.
pub fn gpu_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&CACHE, ADD_PTX, "add_kernel", "add_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "add_f64_kernel")? {
        Some(sum) => Ok(sum),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs + rhs),
    }
}
13356
#[cfg(feature = "cuda")]
/// Element-wise `a - b` for `f64` buffers of equal length.
///
/// The PTX is obtained from `get_f64_ptx` applied to `SUB_PTX` (kernel `sub_kernel` mapped to
/// `sub_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_binary_f64` yields `None`, the difference is computed through
/// `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when the lengths differ; otherwise any failure from the launch
/// attempt or the fallback.
pub fn gpu_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&CACHE, SUB_PTX, "sub_kernel", "sub_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "sub_f64_kernel")? {
        Some(difference) => Ok(difference),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs - rhs),
    }
}
13374
#[cfg(feature = "cuda")]
/// Element-wise `a * b` for `f64` buffers of equal length.
///
/// The PTX is obtained from `get_f64_ptx` applied to `MUL_PTX` (kernel `mul_kernel` mapped to
/// `mul_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_binary_f64` yields `None`, the product is computed through
/// `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when the lengths differ; otherwise any failure from the launch
/// attempt or the fallback.
pub fn gpu_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&CACHE, MUL_PTX, "mul_kernel", "mul_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "mul_f64_kernel")? {
        Some(product) => Ok(product),
        None => cpu_fallback_binary_f64(a, b, device, |lhs, rhs| lhs * rhs),
    }
}
13392
#[cfg(feature = "cuda")]
/// Element-wise `a / b` for `f64` buffers of equal length.
///
/// The PTX is obtained from `get_f64_ptx` applied to `DIV_PTX` (kernel `div_kernel` mapped to
/// `div_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_binary_f64` yields `None`, the quotient is computed through
/// `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when the lengths differ; otherwise any failure from the launch
/// attempt or the fallback.
pub fn gpu_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }

    let ptx = get_f64_ptx(&CACHE, DIV_PTX, "div_kernel", "div_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "div_f64_kernel")? {
        Some(quotient) => Ok(quotient),
        None => cpu_fallback_binary_f64(a, b, device, |num, den| num / den),
    }
}
13410
#[cfg(feature = "cuda")]
/// Element-wise negation of an `f64` buffer.
///
/// The PTX is obtained from `get_f64_ptx` applied to `NEG_PTX` (kernel `neg_kernel` mapped to
/// `neg_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_unary_f64` yields `None`, the negation runs through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_neg_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&CACHE, NEG_PTX, "neg_kernel", "neg_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "neg_f64_kernel")? {
        Some(negated) => Ok(negated),
        None => cpu_fallback_unary_f64(a, device, |value| -value),
    }
}
13421
#[cfg(feature = "cuda")]
/// Element-wise ReLU (`max(x, 0)`) of an `f64` buffer.
///
/// The PTX is obtained from `get_f64_ptx` applied to `RELU_PTX` (kernel `relu_kernel` mapped
/// to `relu_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_unary_f64` yields `None`, `x.max(0.0)` is applied through
/// `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_relu_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&CACHE, RELU_PTX, "relu_kernel", "relu_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "relu_f64_kernel")? {
        Some(rectified) => Ok(rectified),
        None => cpu_fallback_unary_f64(a, device, |value| value.max(0.0)),
    }
}
13432
#[cfg(feature = "cuda")]
/// Multiplies every element of an `f64` buffer by `scalar`.
///
/// The PTX is obtained from `get_f64_ptx` applied to `SCALE_PTX` (kernel `scale_kernel`
/// mapped to `scale_f64_kernel`) and memoised in a function-local `OnceLock`. The kernel is
/// launched with `(a, out, scalar, n)`; when it cannot be obtained from the module cache,
/// the multiplication is done on the host.
///
/// # Errors
/// Propagates failures from transfers, allocation, `launch_cfg`, and the launch.
pub fn gpu_scale_f64(
    a: &CudaBuffer<f64>,
    scalar: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCALE_PTX, "scale_kernel", "scale_f64_kernel");
    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scale_f64_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(a, device)?;
        let mut scaled: Vec<f64> = Vec::with_capacity(host.len());
        for &value in host.iter() {
            scaled.push(value * scalar);
        }
        return cpu_to_gpu(&scaled, device);
    };

    let mut out = alloc_zeros_f64(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of
    // `scale_f64_kernel` (NOTE(review): confirm against the generated PTX); `out` was
    // allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13471
#[cfg(feature = "cuda")]
/// Element-wise natural exponential of an `f64` buffer.
///
/// Tries `exp_f64_kernel` from the dedicated `EXP_F64_PTX` via `try_launch_unary_f64`; when
/// that yields `None`, `exp` is applied through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_exp_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, EXP_F64_PTX, "exp_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |value| value.exp()),
    }
}
13480
#[cfg(feature = "cuda")]
/// Element-wise natural logarithm of an `f64` buffer.
///
/// Tries `log_f64_kernel` from the dedicated `LOG_F64_PTX` via `try_launch_unary_f64`; when
/// that yields `None`, `ln` is applied through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_log_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, LOG_F64_PTX, "log_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |value| value.ln()),
    }
}
13489
#[cfg(feature = "cuda")]
/// Element-wise square root of an `f64` buffer.
///
/// The PTX is obtained from `get_f64_ptx` applied to `SQRT_PTX` (kernel `sqrt_kernel` mapped
/// to `sqrt_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_unary_f64` yields `None`, `sqrt` is applied through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_sqrt_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&CACHE, SQRT_PTX, "sqrt_kernel", "sqrt_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "sqrt_f64_kernel")? {
        Some(roots) => Ok(roots),
        None => cpu_fallback_unary_f64(a, device, |value| value.sqrt()),
    }
}
13500
#[cfg(feature = "cuda")]
/// Raises every element of an `f64` buffer to the power `exponent`.
///
/// Launches `pow_f64_kernel` from the dedicated `POW_F64_PTX` with `(a, out, exponent, n)`.
/// When the kernel cannot be obtained from the module cache, `f64::powf` is applied on the
/// host.
///
/// # Errors
/// Propagates failures from transfers, allocation, `launch_cfg`, and the launch.
pub fn gpu_pow_f64(
    a: &CudaBuffer<f64>,
    exponent: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(kernel) = crate::module_cache::get_or_compile(
        ctx,
        POW_F64_PTX,
        "pow_f64_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(a, device)?;
        let mut powered: Vec<f64> = Vec::with_capacity(host.len());
        for &base in host.iter() {
            powered.push(base.powf(exponent));
        }
        return cpu_to_gpu(&powered, device);
    };

    let mut out = alloc_zeros_f64(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // SAFETY: the argument list is assumed to match the parameter list of `pow_f64_kernel`
    // (NOTE(review): confirm against `POW_F64_PTX`); `out` was allocated with `len` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13537
#[cfg(feature = "cuda")]
/// Element-wise absolute value of an `f64` buffer.
///
/// The PTX is obtained from `get_f64_ptx` applied to `ABS_PTX` (kernel `abs_kernel` mapped to
/// `abs_f64_kernel`) and memoised in a function-local `OnceLock`. When
/// `try_launch_unary_f64` yields `None`, `abs` is applied through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_abs_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ptx = get_f64_ptx(&CACHE, ABS_PTX, "abs_kernel", "abs_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "abs_f64_kernel")? {
        Some(magnitudes) => Ok(magnitudes),
        None => cpu_fallback_unary_f64(a, device, |value| value.abs()),
    }
}
13548
#[cfg(feature = "cuda")]
/// Element-wise logistic sigmoid, `1 / (1 + exp(-x))`, of an `f64` buffer.
///
/// Tries `sigmoid_f64_kernel` from the dedicated `SIGMOID_F64_PTX` via
/// `try_launch_unary_f64`; when that yields `None`, the formula is evaluated through
/// `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_sigmoid_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, SIGMOID_F64_PTX, "sigmoid_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |value| 1.0 / (1.0 + (-value).exp())),
    }
}
13557
#[cfg(feature = "cuda")]
/// Element-wise hyperbolic tangent of an `f64` buffer.
///
/// Tries `tanh_f64_kernel` from the dedicated `TANH_F64_PTX` via `try_launch_unary_f64`; when
/// that yields `None`, `tanh` is applied through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates failures from the launch attempt and the fallback.
pub fn gpu_tanh_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, TANH_F64_PTX, "tanh_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(a, device, |value| value.tanh()),
    }
}
13566
#[cfg(feature = "cuda")]
/// Backward pass of ReLU for `f64`: passes `grad` through where `input > 0`, zero elsewhere.
///
/// The PTX is obtained from `get_f64_ptx` applied to `RELU_BACKWARD_PTX` (kernel
/// `relu_backward_kernel` mapped to `relu_backward_f64_kernel`) and memoised in a
/// function-local `OnceLock`. When `try_launch_binary_f64` yields `None`, the rule is applied
/// through `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length; otherwise any
/// failure from the launch attempt or the fallback.
pub fn gpu_relu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }

    let ptx = get_f64_ptx(
        &CACHE,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
        "relu_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, input, device, ptx, "relu_backward_f64_kernel")? {
        Some(grad_in) => Ok(grad_in),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, x| {
            if x > 0.0 {
                upstream
            } else {
                0.0
            }
        }),
    }
}
13594
#[cfg(feature = "cuda")]
/// Backward pass of sigmoid for `f64`, expressed in the forward output: `grad * o * (1 - o)`.
///
/// The PTX is obtained from `get_f64_ptx` applied to `SIGMOID_BACKWARD_PTX` (kernel
/// `sigmoid_backward_kernel` mapped to `sigmoid_backward_f64_kernel`) and memoised in a
/// function-local `OnceLock`. When `try_launch_binary_f64` yields `None`, the formula is
/// evaluated through `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `output` differ in length; otherwise any
/// failure from the launch attempt or the fallback.
pub fn gpu_sigmoid_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (grad_len, output_len) = (grad.len(), output.len());
    if grad_len != output_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: output_len });
    }

    let ptx = get_f64_ptx(
        &CACHE,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
        "sigmoid_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, output, device, ptx, "sigmoid_backward_f64_kernel")? {
        Some(grad_in) => Ok(grad_in),
        None => cpu_fallback_binary_f64(grad, output, device, |upstream, out_val| {
            upstream * out_val * (1.0 - out_val)
        }),
    }
}
13618
#[cfg(feature = "cuda")]
/// Backward pass of tanh for `f64`, expressed in the forward output: `grad * (1 - o^2)`.
///
/// The PTX is obtained from `get_f64_ptx` applied to `TANH_BACKWARD_PTX` (kernel
/// `tanh_backward_kernel` mapped to `tanh_backward_f64_kernel`) and memoised in a
/// function-local `OnceLock`. When `try_launch_binary_f64` yields `None`, the formula is
/// evaluated through `cpu_fallback_binary_f64`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `output` differ in length; otherwise any
/// failure from the launch attempt or the fallback.
pub fn gpu_tanh_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let (grad_len, output_len) = (grad.len(), output.len());
    if grad_len != output_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: output_len });
    }

    let ptx = get_f64_ptx(
        &CACHE,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
        "tanh_backward_f64_kernel",
    );
    match try_launch_binary_f64(grad, output, device, ptx, "tanh_backward_f64_kernel")? {
        Some(grad_in) => Ok(grad_in),
        None => cpu_fallback_binary_f64(grad, output, device, |upstream, out_val| {
            upstream * (1.0 - out_val * out_val)
        }),
    }
}
13642
#[cfg(feature = "cuda")]
/// Broadcasting element-wise `a + b` for `f64` buffers.
///
/// Strides for both operands are derived from their shapes relative to
/// `out_shape` via `broadcast_strides`; the output element count is the
/// product of `out_shape`. The work is offered to
/// `try_launch_broadcast_binary_f64` first; if that returns `None`, the
/// result comes from `cpu_fallback_broadcast_binary_f64`.
///
/// # Errors
/// Propagates any error from the launch helper or the fallback.
pub fn gpu_broadcast_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    // NOTE(review): `as u32` narrows each dimension without a range check.
    let dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(
        &CACHE,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
        "broadcast_add_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        ptx,
        "broadcast_add_f64_kernel",
    )?;

    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x + y,
        ),
    }
}
13680
#[cfg(feature = "cuda")]
/// Broadcasting element-wise `a - b` for `f64` buffers.
///
/// Strides for both operands are derived from their shapes relative to
/// `out_shape` via `broadcast_strides`; the output element count is the
/// product of `out_shape`. The work is offered to
/// `try_launch_broadcast_binary_f64` first; if that returns `None`, the
/// result comes from `cpu_fallback_broadcast_binary_f64`.
///
/// # Errors
/// Propagates any error from the launch helper or the fallback.
pub fn gpu_broadcast_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    // NOTE(review): `as u32` narrows each dimension without a range check.
    let dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(
        &CACHE,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
        "broadcast_sub_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        ptx,
        "broadcast_sub_f64_kernel",
    )?;

    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x - y,
        ),
    }
}
13714
#[cfg(feature = "cuda")]
/// Broadcasting element-wise `a * b` for `f64` buffers.
///
/// Strides for both operands are derived from their shapes relative to
/// `out_shape` via `broadcast_strides`; the output element count is the
/// product of `out_shape`. The work is offered to
/// `try_launch_broadcast_binary_f64` first; if that returns `None`, the
/// result comes from `cpu_fallback_broadcast_binary_f64`.
///
/// # Errors
/// Propagates any error from the launch helper or the fallback.
pub fn gpu_broadcast_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    // NOTE(review): `as u32` narrows each dimension without a range check.
    let dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(
        &CACHE,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
        "broadcast_mul_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        ptx,
        "broadcast_mul_f64_kernel",
    )?;

    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x * y,
        ),
    }
}
13748
#[cfg(feature = "cuda")]
/// Broadcasting element-wise `a / b` for `f64` buffers.
///
/// Strides for both operands are derived from their shapes relative to
/// `out_shape` via `broadcast_strides`; the output element count is the
/// product of `out_shape`. The work is offered to
/// `try_launch_broadcast_binary_f64` first; if that returns `None`, the
/// result comes from `cpu_fallback_broadcast_binary_f64`.
///
/// # Errors
/// Propagates any error from the launch helper or the fallback.
pub fn gpu_broadcast_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    // NOTE(review): `as u32` narrows each dimension without a range check.
    let dims: Vec<u32> = out_shape.iter().copied().map(|dim| dim as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(
        &CACHE,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
        "broadcast_div_f64_kernel",
    );
    let launched = try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        ptx,
        "broadcast_div_f64_kernel",
    )?;

    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_broadcast_binary_f64(
            a,
            b,
            a_shape,
            b_shape,
            out_shape,
            device,
            |x, y| x / y,
        ),
    }
}
13782
#[cfg(feature = "cuda")]
/// Reduces all elements of `a` to a single `f64` sum stored in a one-element
/// device buffer.
///
/// Flow:
/// * an empty input yields a buffer containing `0.0`;
/// * if the element count does not fit the kernel's `u32` length argument, or
///   the kernel cannot be obtained from the module cache, the buffer is
///   copied to the host and summed there;
/// * otherwise the kernel is launched with at most 1024 blocks of 256 threads
///   and a partials buffer holding one slot per block. A single partial is
///   returned as-is, up to 256 partials are summed on the host, and larger
///   partial buffers are reduced by calling this function again.
///
/// # Errors
/// Propagates failures from host/device transfers, allocation and the kernel
/// launch.
pub fn gpu_reduce_sum_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    /// Threads per block used for the launch.
    const BLOCK: u32 = 256;
    /// Upper bound on the number of launched blocks.
    const MAX_BLOCKS: u32 = 1024;
    /// Largest partials buffer that is finished on the host.
    const HOST_FINISH_MAX: u32 = 256;

    let n = a.len();
    if n == 0 {
        return cpu_to_gpu(&[0.0f64], device);
    }

    // Download a buffer, add it up on the host and upload the scalar result.
    let sum_on_host = |buf: &CudaBuffer<f64>| -> GpuResult<CudaBuffer<f64>> {
        let host = gpu_to_cpu(buf, device)?;
        let total: f64 = host.iter().sum();
        cpu_to_gpu(&[total], device)
    };

    // The kernel receives the element count as a `u32`. A plain `as u32` cast
    // would wrap for very large buffers (an input of exactly 2^32 elements
    // would produce an empty result), so such inputs are summed on the host.
    let Ok(n_u32) = u32::try_from(n) else {
        return sum_on_host(a);
    };

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        "reduce_sum_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "reduce_sum_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Kernel unavailable: deliberate best-effort host reduction.
        return sum_on_host(a);
    };

    // `n_u32 >= 1` here, so the ceiling division yields at least one block.
    // NOTE(review): capping the grid presumes the kernel strides over the
    // input when there are more elements than threads; confirm against
    // REDUCE_SUM_PTX.
    let num_blocks = (n_u32.saturating_add(BLOCK - 1) / BLOCK).min(MAX_BLOCKS);

    let mut partials = alloc_zeros_f64(num_blocks as usize, device)?;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks, 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };

    // SAFETY: relies on the argument list (source buffer, partials buffer,
    // `u32` element count) matching the kernel's parameter order and types;
    // `partials` holds `num_blocks` elements, one per launched block.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    if num_blocks <= 1 {
        // A single block leaves the final sum in the only partial slot.
        return Ok(partials);
    }

    if num_blocks <= HOST_FINISH_MAX {
        return sum_on_host(&partials);
    }

    gpu_reduce_sum_f64(&partials, device)
}
13853
#[cfg(feature = "cuda")]
/// Sums `a` along its middle axis, treating it as an
/// `outer x axis_size x inner` row-major volume (this is the layout the host
/// fallback indexes with), and returns `outer * inner` sums.
///
/// When the kernel cannot be obtained from the module cache, the data is
/// downloaded, reduced on the host and uploaded again.
///
/// # Errors
/// Propagates failures from transfers, allocation, `launch_cfg` and the
/// kernel launch.
pub fn gpu_sum_axis_f64(
    a: &CudaBuffer<f64>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let out_len = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SUM_AXIS_PTX, "sum_axis_kernel", "sum_axis_f64_kernel");
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "sum_axis_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: accumulate each output slot in axis order from 0.0.
        let host = gpu_to_cpu(a, device)?;
        let sums: Vec<f64> = (0..out_len)
            .map(|flat| {
                let outer_idx = flat / inner;
                let inner_idx = flat % inner;
                (0..axis_size).fold(0.0f64, |acc, k| {
                    acc + host[outer_idx * axis_size * inner + k * inner + inner_idx]
                })
            })
            .collect();
        return cpu_to_gpu(&sums, device);
    };

    let mut out = alloc_zeros_f64(out_len, device)?;
    let cfg = launch_cfg(out_len)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let (outer_arg, axis_arg, inner_arg, total_arg) = (
        outer as u32,
        axis_size as u32,
        inner as u32,
        out_len as u32,
    );

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on `a` covering `outer * axis_size * inner`
    // elements; neither is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_arg)
            .arg(&axis_arg)
            .arg(&inner_arg)
            .arg(&total_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
13915
13916#[cfg(not(feature = "cuda"))]
13917pub fn gpu_reduce_sum_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13918#[cfg(not(feature = "cuda"))]
13919pub fn gpu_sum_axis_f64(_a: &CudaBuffer<f64>, _outer: usize, _axis_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13920
#[cfg(feature = "cuda")]
/// Transposes an `m x n` row-major `f64` matrix into a new `n x m` buffer.
///
/// `input` is first checked against `device` via `validate_device`. When the
/// kernel cannot be obtained from the module cache, the transpose is done on
/// the host (`out[j * m + i] = in[i * n + j]`) and uploaded.
///
/// # Errors
/// Propagates failures from validation, transfers, allocation, `launch_cfg`
/// and the kernel launch.
pub fn gpu_transpose_2d_f64(
    input: &CudaBuffer<f64>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let numel = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        "transpose_2d_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "transpose_2d_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback.
        let src = gpu_to_cpu(input, device)?;
        let mut dst = vec![0.0f64; numel];
        for row in 0..m {
            for col in 0..n {
                dst[col * m + row] = src[row * n + col];
            }
        }
        return cpu_to_gpu(&dst, device);
    };

    let mut out = alloc_zeros_f64(numel, device)?;
    let cfg = launch_cfg(numel)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let (rows_arg, cols_arg, numel_arg) = (m as u32, n as u32, numel as u32);

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on `input` covering `m * n` elements; neither is
    // verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_arg)
            .arg(&cols_arg)
            .arg(&numel_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
13981
#[cfg(feature = "cuda")]
/// Permutes a 4-D `f64` tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]`
/// (axes 1 and 2 swapped), returning a new buffer.
///
/// `input` is first checked against `device` via `validate_device`. When the
/// kernel cannot be obtained from the module cache, the permutation is done
/// on the host using row-major indexing and uploaded.
///
/// # Errors
/// Propagates failures from validation, transfers, allocation, `launch_cfg`
/// and the kernel launch.
pub fn gpu_permute_0213_f64(
    input: &CudaBuffer<f64>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let numel = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        "permute_0213_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "permute_0213_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: walk the source in order and scatter to the
        // permuted position.
        let src = gpu_to_cpu(input, device)?;
        let mut dst = vec![0.0f64; numel];
        for a0 in 0..d0 {
            for a1 in 0..d1 {
                for a2 in 0..d2 {
                    for a3 in 0..d3 {
                        let from = ((a0 * d1 + a1) * d2 + a2) * d3 + a3;
                        let to = ((a0 * d2 + a2) * d1 + a1) * d3 + a3;
                        dst[to] = src[from];
                    }
                }
            }
        }
        return cpu_to_gpu(&dst, device);
    };

    let mut out = alloc_zeros_f64(numel, device)?;
    let cfg = launch_cfg(numel)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let (d0_arg, d1_arg, d2_arg, d3_arg) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32);
    let numel_arg = numel as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on `input` covering `d0 * d1 * d2 * d3` elements;
    // neither is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_arg)
            .arg(&d1_arg)
            .arg(&d2_arg)
            .arg(&d3_arg)
            .arg(&numel_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
14050
#[cfg(feature = "cuda")]
/// Extracts a slice along one axis of `input` into a new buffer of `n`
/// elements.
///
/// As indexed by the host fallback, output element `i` is read from
/// `outer * total_along_axis * inner_size + split_offset * inner_size + within`,
/// where `outer = i / (split_size * inner_size)` and
/// `within = i % (split_size * inner_size)`.
///
/// `input` is first checked against `device` via `validate_device`. The host
/// fallback runs when the kernel cannot be obtained from the module cache.
///
/// # Errors
/// Propagates failures from validation, transfers, allocation, `launch_cfg`
/// and the kernel launch.
pub fn gpu_strided_split_f64(
    input: &CudaBuffer<f64>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        "strided_split_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_split_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: gather each output element from its source offset.
        let src = gpu_to_cpu(input, device)?;
        let gathered: Vec<f64> = (0..n)
            .map(|i| {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                src[outer_idx * total_along_axis * inner_size + split_offset * inner_size + within]
            })
            .collect();
        return cpu_to_gpu(&gathered, device);
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let axis_arg = total_along_axis as u32;
    let offset_arg = split_offset as u32;
    let split_arg = split_size as u32;
    let inner_arg = inner_size as u32;
    let count_arg = n as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on every computed source offset lying inside
    // `input`; neither is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&axis_arg)
            .arg(&offset_arg)
            .arg(&split_arg)
            .arg(&inner_arg)
            .arg(&count_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
14115
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Writes the first `n` elements of `input` into `output` at an offset along
/// one axis (the write-side counterpart of a strided split).
///
/// As indexed by the host fallback, input element `i` goes to
/// `outer * total_along_axis * inner_size + cat_offset * inner_size + within`,
/// where `outer = i / (part_size * inner_size)` and
/// `within = i % (part_size * inner_size)`.
///
/// Only `input` is checked against `device` via `validate_device`. When the
/// kernel cannot be obtained from the module cache, both buffers are
/// downloaded, merged on the host, and `*output` is replaced with a freshly
/// uploaded buffer.
///
/// # Errors
/// Propagates failures from validation, transfers, `launch_cfg` and the
/// kernel launch.
pub fn gpu_strided_cat_f64(
    input: &CudaBuffer<f64>,
    output: &mut CudaBuffer<f64>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        "strided_cat_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_cat_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: scatter the part into a host copy of the output.
        let part = gpu_to_cpu(input, device)?;
        let mut merged = gpu_to_cpu(output, device)?;
        for (i, &value) in part.iter().take(n).enumerate() {
            let outer_idx = i / (part_size * inner_size);
            let within = i % (part_size * inner_size);
            let dst_idx =
                outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
            merged[dst_idx] = value;
        }
        *output = cpu_to_gpu(&merged, device)?;
        return Ok(());
    };

    let cfg = launch_cfg(n)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let axis_arg = total_along_axis as u32;
    let offset_arg = cat_offset as u32;
    let part_arg = part_size as u32;
    let inner_arg = inner_size as u32;
    let count_arg = n as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on every computed destination offset lying inside
    // `output`; neither is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&axis_arg)
            .arg(&offset_arg)
            .arg(&part_arg)
            .arg(&inner_arg)
            .arg(&count_arg)
            .launch(cfg)?;
    }

    Ok(())
}
14182
#[cfg(feature = "cuda")]
/// Gathers `input[indices[i]]` for every index, returning a buffer with
/// `indices.len()` elements.
///
/// Indices are stored as `f32` and converted with `as usize` in the host
/// fallback. `input` is checked against `device` via `validate_device`.
/// The host fallback runs when the kernel cannot be obtained from the module
/// cache; it panics on an index outside `input`.
///
/// # Errors
/// Propagates failures from validation, transfers, allocation, `launch_cfg`
/// and the kernel launch.
pub fn gpu_index_select_1d_f64(
    input: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let count = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        "index_select_1d_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "index_select_1d_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: plain gather.
        let values = gpu_to_cpu(input, device)?;
        let positions = gpu_to_cpu(indices, device)?;
        let mut gathered: Vec<f64> = Vec::with_capacity(positions.len());
        for &pos in &positions {
            gathered.push(values[pos as usize]);
        }
        return cpu_to_gpu(&gathered, device);
    };

    let mut out = alloc_zeros_f64(count, device)?;
    let cfg = launch_cfg(count)?;
    let count_arg = count as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on every index addressing an element of `input`;
    // neither is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&count_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
14238
#[cfg(feature = "cuda")]
/// Scatter-adds `grad_output[i]` into position `indices[i]` of a zeroed
/// buffer of `input_len` elements (the backward of a 1-D index select).
///
/// Indices are stored as `f32` and converted with `as usize` in the host
/// fallback. `grad_output` is checked against `device` via `validate_device`.
/// Both paths process exactly `grad_output.len()` elements; `indices` must
/// supply at least that many entries. The host fallback runs when the kernel
/// cannot be obtained from the module cache; it panics on an index that is
/// not below `input_len`.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] when `indices` is shorter than
/// `grad_output`, and propagates failures from validation, transfers,
/// allocation, `launch_cfg` and the kernel launch.
pub fn gpu_scatter_add_1d_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad_output, device)?;

    let n = grad_output.len();
    // The launch covers `n` elements and pairs each with one index, so a
    // shorter index buffer cannot serve it.
    if indices.len() < n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: indices.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        "scatter_add_1d_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_1d_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback. `zip` stops after `n` pairs, matching the extent of
        // the GPU launch instead of indexing past the end of `go_host` when
        // surplus indices are present.
        let go_host = gpu_to_cpu(grad_output, device)?;
        let idx_host = gpu_to_cpu(indices, device)?;
        let mut result = vec![0.0f64; input_len];
        for (&idx_f, &g) in idx_host.iter().zip(go_host.iter()) {
            result[idx_f as usize] += g;
        }
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f64(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on every index being below `input_len`; `indices`
    // is known to hold at least `n` entries.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14293
#[cfg(feature = "cuda")]
/// Returns a copy of `input` in which every element whose mask byte is
/// non-zero is replaced by `value`.
///
/// `input` is checked against `device` via `validate_device`. `mask` must
/// provide at least one byte per input element. When the kernel cannot be
/// obtained from the module cache, the same selection is computed on the
/// host.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] when `mask` is shorter than `input`,
/// and propagates failures from validation, transfers, allocation,
/// `launch_cfg` and the kernel launch.
pub fn gpu_masked_fill_f64(
    input: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    value: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n = input.len();
    // The launch covers `n` elements and the host fallback pairs elements
    // with mask bytes one-to-one. A shorter mask would leave the launch
    // without enough mask bytes and make the fallback return a truncated
    // buffer.
    if mask.len() < n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: mask.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        "masked_fill_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_fill_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback.
        let input_host = gpu_to_cpu(input, device)?;
        let mask_host = gpu_to_cpu(mask, device)?;
        let result: Vec<f64> = input_host
            .iter()
            .zip(mask_host.iter())
            .map(|(&x, &m)| if m != 0 { value } else { x })
            .collect();
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types; `mask` is known to hold at least `n` bytes and `out`
    // holds `n` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14351
#[cfg(feature = "cuda")]
/// Returns a copy of `grad` in which every element whose mask byte is
/// non-zero is set to `0.0`.
///
/// `grad` is checked against `device` via `validate_device`. `mask` must
/// provide at least one byte per gradient element. When the kernel cannot be
/// obtained from the module cache, the same selection is computed on the
/// host.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] when `mask` is shorter than `grad`,
/// and propagates failures from validation, transfers, allocation,
/// `launch_cfg` and the kernel launch.
pub fn gpu_masked_zero_f64(
    grad: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad, device)?;

    let n = grad.len();
    // The launch covers `n` elements and the host fallback pairs elements
    // with mask bytes one-to-one. A shorter mask would leave the launch
    // without enough mask bytes and make the fallback return a truncated
    // buffer.
    if mask.len() < n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: mask.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        MASKED_ZERO_PTX,
        "masked_zero_kernel",
        "masked_zero_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_zero_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback.
        let grad_host = gpu_to_cpu(grad, device)?;
        let mask_host = gpu_to_cpu(mask, device)?;
        let result: Vec<f64> = grad_host
            .iter()
            .zip(mask_host.iter())
            .map(|(&g, &m)| if m != 0 { 0.0 } else { g })
            .collect();
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types; `mask` is known to hold at least `n` bytes and `out`
    // holds `n` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14406
#[cfg(feature = "cuda")]
/// Writes one row of `d` values per batch from `src` into row `pos` of `dst`.
///
/// As indexed by the host fallback, `src` is `[n_batch, d]` and `dst` is
/// `[n_batch, max_len, d]`, both row-major:
/// `dst[b * max_len * d + pos * d + di] = src[b * d + di]`.
///
/// When the kernel cannot be obtained from the module cache, both buffers are
/// downloaded, updated on the host, and `*dst` is replaced with a freshly
/// uploaded buffer.
///
/// # Errors
/// Propagates failures from transfers, `launch_cfg` and the kernel launch.
pub fn gpu_slice_write_f64(
    src: &CudaBuffer<f64>,
    dst: &mut CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let elems = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        "slice_write_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_write_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: patch a host copy of `dst`, then swap it in.
        let rows = gpu_to_cpu(src, device)?;
        let mut cache = gpu_to_cpu(dst, device)?;
        for batch in 0..n_batch {
            for col in 0..d {
                cache[batch * max_len * d + pos * d + col] = rows[batch * d + col];
            }
        }
        *dst = cpu_to_gpu(&cache, device)?;
        return Ok(());
    };

    let cfg = launch_cfg(elems)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let elems_arg = elems as u32;
    let d_arg = d as u32;
    let max_len_arg = max_len as u32;
    let pos_arg = pos as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on `pos < max_len` with `dst` large enough for
    // the addressed rows; neither is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&elems_arg)
            .arg(&d_arg)
            .arg(&max_len_arg)
            .arg(&pos_arg)
            .launch(cfg)?;
    }

    Ok(())
}
14467
#[cfg(feature = "cuda")]
/// Copies the first `len` rows of every batch from `src` into a new, densely
/// packed buffer of `n_batch * len * d` elements.
///
/// As indexed by the host fallback, `src` is `[n_batch, max_len, d]`
/// row-major and the result is `[n_batch, len, d]`:
/// `out[b * len * d + r * d + di] = src[b * max_len * d + r * d + di]`.
///
/// The host fallback runs when the kernel cannot be obtained from the module
/// cache.
///
/// # Errors
/// Propagates failures from transfers, allocation, `launch_cfg` and the
/// kernel launch.
pub fn gpu_slice_read_f64(
    src: &CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let elems = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SLICE_READ_PTX,
        "slice_read_kernel",
        "slice_read_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_read_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: compact the leading rows of each batch.
        let cache = gpu_to_cpu(src, device)?;
        let mut packed = vec![0.0f64; elems];
        for batch in 0..n_batch {
            for row in 0..len {
                for col in 0..d {
                    packed[batch * len * d + row * d + col] =
                        cache[batch * max_len * d + row * d + col];
                }
            }
        }
        return cpu_to_gpu(&packed, device);
    };

    let mut out = alloc_zeros_f64(elems, device)?;
    let cfg = launch_cfg(elems)?;
    // NOTE(review): these `as u32` casts narrow without a range check.
    let elems_arg = elems as u32;
    let d_arg = d as u32;
    let len_arg = len as u32;
    let max_len_arg = max_len as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on `src` covering the addressed rows; neither is
    // verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&elems_arg)
            .arg(&d_arg)
            .arg(&len_arg)
            .arg(&max_len_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
14528
#[cfg(feature = "cuda")]
/// Looks up a single embedding row of width `d` from `weight`.
///
/// The row number is taken from the first element of `idx`, stored as `f32`
/// and converted with `as usize` in the host fallback. The result is a new
/// buffer of `d` elements. The host fallback runs when the kernel cannot be
/// obtained from the module cache; it panics when `idx` is empty or the row
/// lies outside `weight`.
///
/// # Errors
/// Propagates failures from transfers, allocation, `launch_cfg` and the
/// kernel launch.
pub fn gpu_embed_lookup_f64(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        "embed_lookup_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: slice the requested row out of the table.
        let idx_host = gpu_to_cpu(idx, device)?;
        let table = gpu_to_cpu(weight, device)?;
        let first = idx_host[0] as usize * d;
        let row = &table[first..first + d];
        return cpu_to_gpu(row, device);
    };

    let mut out = alloc_zeros_f64(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_arg = d as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, on `idx` holding at least one element and on the
    // selected row lying inside `weight`; none of this is verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_arg)
            .launch(cfg)?;
    }

    Ok(out)
}
14581
#[cfg(feature = "cuda")]
/// Looks up `n` embedding rows of width `d` from `weight`, returning a new
/// buffer of `n * d` elements.
///
/// Row numbers come from the first `n` entries of `indices`, stored as `f32`
/// and converted with `as usize` in the host fallback. An empty request
/// (`n * d == 0`) returns an empty buffer. The host fallback runs when the
/// kernel cannot be obtained from the module cache; it panics on a row
/// outside `weight`.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] when `indices` holds fewer than `n`
/// entries, and propagates failures from transfers, allocation, `launch_cfg`
/// and the kernel launch.
pub fn gpu_embed_lookup_batch_f64(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f64(0, device);
    }

    // The output is sized for `n` rows, so `indices` must name that many.
    if indices.len() < n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: indices.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        "embed_lookup_batch_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_batch_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback. Only the first `n` indices are used so the result
        // has `n * d` elements, the same size the GPU path allocates.
        let idx_host = gpu_to_cpu(indices, device)?;
        let weight_host = gpu_to_cpu(weight, device)?;
        let mut out = Vec::with_capacity(total);
        for &idx_f in idx_host.iter().take(n) {
            let start = idx_f as usize * d;
            out.extend_from_slice(&weight_host[start..start + d]);
        }
        return cpu_to_gpu(&out, device);
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on every selected row lying inside `weight`;
    // `indices` is known to hold at least `n` entries.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14641
#[cfg(feature = "cuda")]
/// Scatter-adds rows of `grad_output` into a zeroed
/// `num_embeddings * d` buffer (the backward of a batched embedding lookup).
///
/// Row `i` of `grad_output` (width `d`) is added to row `indices[i]` of the
/// result. Indices are stored as `f32` and converted with `as usize` in the
/// host fallback. With no indices or `d == 0` the zeroed buffer is returned
/// directly. The host fallback runs when the kernel cannot be obtained from
/// the module cache; it panics on a row that is not below `num_embeddings`.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] when `grad_output` holds fewer than
/// `indices.len() * d` elements, and propagates failures from transfers,
/// allocation, `launch_cfg` and the kernel launch.
pub fn gpu_scatter_add_rows_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        return alloc_zeros_f64(num_embeddings * d, device);
    }

    // Both paths consume `total` gradient values; a shorter buffer would make
    // the host fallback panic and leave the launch without enough data.
    if grad_output.len() < total {
        return Err(GpuError::LengthMismatch {
            a: total,
            b: grad_output.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        "scatter_add_rows_f64_kernel",
    );
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_rows_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback. `d > 0` here because `total != 0`, so `chunks_exact`
        // is well-defined; `zip` limits the walk to the `n` indexed rows.
        let go_host = gpu_to_cpu(grad_output, device)?;
        let idx_host = gpu_to_cpu(indices, device)?;
        let mut result = vec![0.0f64; num_embeddings * d];
        for (&idx_f, grad_row) in idx_host.iter().zip(go_host.chunks_exact(d)) {
            let start = idx_f as usize * d;
            for (acc, &g) in result[start..start + d].iter_mut().zip(grad_row) {
                *acc += g;
            }
        }
        return cpu_to_gpu(&result, device);
    };

    let mut out = alloc_zeros_f64(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on every index being below `num_embeddings`;
    // `grad_output` is known to hold at least `total` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14706
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
/// Applies one Adam step in place to `param`, `exp_avg` and `exp_avg_sq`.
///
/// Per element, as spelled out by the host fallback:
/// `g = grad (+ weight_decay * param when weight_decay > 0)`,
/// `m = beta1 * m + (1 - beta1) * g`,
/// `v = beta2 * v + (1 - beta2) * g * g`,
/// `param -= lr * (m / bc1) / (sqrt(v / bc2) + eps)`.
/// `bc1` and `bc2` are the divisors applied to `m` and `v`.
///
/// All four buffers must have the same length. When the kernel cannot be
/// obtained from the module cache, the update runs on the host and the three
/// mutable buffers are replaced with freshly uploaded ones.
///
/// # Errors
/// Returns [`GpuError::LengthMismatch`] carrying `param.len()` and the length
/// of the first buffer that differs from it, and propagates failures from
/// transfers, `launch_cfg` and the kernel launch.
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    // Report the length that actually disagrees. Always reporting
    // `grad.len()` could produce an error with `a == b` when one of the
    // moment buffers is the culprit.
    let mismatched = [grad.len(), exp_avg.len(), exp_avg_sq.len()]
        .iter()
        .copied()
        .find(|&len| len != n);
    if let Some(other) = mismatched {
        return Err(GpuError::LengthMismatch { a: n, b: other });
    }

    let ctx = device.context();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Host fallback: same update, element by element.
        let mut p_host = gpu_to_cpu(param, device)?;
        let g_host = gpu_to_cpu(grad, device)?;
        let mut m_host = gpu_to_cpu(exp_avg, device)?;
        let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

        let state = p_host
            .iter_mut()
            .zip(g_host.iter())
            .zip(m_host.iter_mut())
            .zip(v_host.iter_mut());
        for (((p, &g_raw), m), v) in state {
            let mut g = g_raw;
            if weight_decay > 0.0 {
                g += weight_decay * *p;
            }
            *m = beta1 * *m + (1.0 - beta1) * g;
            *v = beta2 * *v + (1.0 - beta2) * g * g;
            let m_hat = *m / bc1;
            let v_hat = *v / bc2;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }

        *param = cpu_to_gpu(&p_host, device)?;
        *exp_avg = cpu_to_gpu(&m_host, device)?;
        *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
        return Ok(());
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types; all four buffers were checked to hold `n` elements.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14801
14802#[cfg(not(feature = "cuda"))]
14804#[allow(clippy::too_many_arguments)]
14805pub fn gpu_fused_adam(
14806 _param: &mut CudaBuffer<f32>,
14807 _grad: &CudaBuffer<f32>,
14808 _exp_avg: &mut CudaBuffer<f32>,
14809 _exp_avg_sq: &mut CudaBuffer<f32>,
14810 _beta1: f32,
14811 _beta2: f32,
14812 _lr: f32,
14813 _eps: f32,
14814 _bc1: f32,
14815 _bc2: f32,
14816 _weight_decay: f32,
14817 _device: &GpuDevice,
14818) -> GpuResult<()> {
14819 Err(GpuError::NoCudaFeature)
14820}
14821
#[cfg(feature = "cuda")]
/// Runs the fused GRU forward kernel and returns `(hy, workspace)`.
///
/// `hy` has the same number of elements as `hx`. `workspace` holds
/// `batch * 5 * hsz` elements, where `batch = hx.len() / hsz`.
/// NOTE(review): the expected lengths of the gate and bias buffers are not
/// checked here — confirm them against the kernel.
///
/// There is no host fallback for this operation.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] when `hsz` is zero, since the batch size
///   cannot be derived from `hx`.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be obtained from
///   the module cache.
/// * Failures from allocation, `launch_cfg` and the kernel launch are
///   propagated.
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    let total = hx.len();
    // `total / hsz` below would panic with a division by zero; report the
    // bad hidden size as an error instead.
    if hsz == 0 {
        return Err(GpuError::LengthMismatch { a: total, b: hsz });
    }
    let batch = total / hsz;

    let ctx = device.context();
    let stream = device.stream();

    let kernel = crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "fused_gru_forward_kernel",
    })?;

    let mut hy = alloc_zeros_f32(total, device)?;
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    // SAFETY: relies on the argument list matching the kernel's parameter
    // order and types, and on the gate/bias buffers being large enough for
    // `total` elements; the latter is not verified here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
14895
14896#[cfg(not(feature = "cuda"))]
14898pub fn gpu_fused_gru_forward(
14899 _input_gates: &CudaBuffer<f32>,
14900 _hidden_gates: &CudaBuffer<f32>,
14901 _bias_ih: &CudaBuffer<f32>,
14902 _bias_hh: &CudaBuffer<f32>,
14903 _hx: &CudaBuffer<f32>,
14904 _hsz: usize,
14905 _device: &GpuDevice,
14906) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
14907 Err(GpuError::NoCudaFeature)
14908}
14909
#[cfg(feature = "cuda")]
/// Launches `maxpool2d_forward_kernel` and returns the pooled buffer together
/// with its `[batch, channels, h_out, w_out]` shape.
///
/// Each output extent is `(size + 2 * pad - kernel) / stride + 1`.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when a kernel or stride extent is zero, when
///   the window is larger than the padded input, when a size computation
///   overflows `usize`, or when `input` holds fewer than
///   `batch * channels * h_in * w_in` elements. The first three previously
///   panicked or wrapped around.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by allocation, `launch_cfg`, or the launch itself.
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Output extent along one axis; `None` for a zero kernel/stride, a window
    // that does not fit the padded input, or arithmetic overflow.
    fn pooled_extent(size: usize, pad: usize, kernel: usize, stride: usize) -> Option<usize> {
        if kernel == 0 || stride == 0 {
            return None;
        }
        let padded = pad.checked_mul(2)?.checked_add(size)?;
        let span = padded.checked_sub(kernel)?;
        Some(span / stride + 1)
    }

    // Product of four extents; `None` on overflow.
    fn volume(dims: [usize; 4]) -> Option<usize> {
        dims.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    }

    let (h_out, w_out) = match (
        pooled_extent(h_in, ph, kh, sh),
        pooled_extent(w_in, pw, kw, sw),
    ) {
        (Some(h), Some(w)) => (h, w),
        _ => {
            // `expected` carries the window, `got` the unpadded input extent.
            return Err(GpuError::ShapeMismatch {
                op: "maxpool2d",
                expected: vec![kh, kw],
                got: vec![h_in, w_in],
            });
        }
    };

    let in_dims = [batch, channels, h_in, w_in];
    let out_dims = [batch, channels, h_out, w_out];
    let total = match (volume(in_dims), volume(out_dims)) {
        // The kernel indexes `input` by these dimensions, so it must be at
        // least that large before the unsafe launch.
        (Some(needed), Some(total)) if input.len() >= needed => total,
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "maxpool2d",
                expected: in_dims.to_vec(),
                got: vec![input.len()],
            });
        }
    };

    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        device.context(),
        MAXPOOL2D_PTX,
        "maxpool2d_forward_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "maxpool2d_forward_kernel",
    })?;

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: `input` and `out` were sized against the dimensions passed to
    // the kernel; the argument order must match the PTX signature, which is
    // not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32)
            .arg(&ch_u32)
            .arg(&h_in_u32)
            .arg(&w_in_u32)
            .arg(&h_out_u32)
            .arg(&w_out_u32)
            .arg(&kh_u32)
            .arg(&kw_u32)
            .arg(&sh_u32)
            .arg(&sw_u32)
            .arg(&ph_u32)
            .arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_dims))
}
14974
14975#[cfg(not(feature = "cuda"))]
14977#[allow(clippy::too_many_arguments)]
14978pub fn gpu_maxpool2d(
14979 _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
14980 _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
14981 _sh: usize, _sw: usize, _ph: usize, _pw: usize,
14982 _device: &GpuDevice,
14983) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
14984 Err(GpuError::NoCudaFeature)
14985}
14986
#[cfg(feature = "cuda")]
/// Launches `avgpool2d_forward_kernel` and returns the pooled buffer together
/// with its `[batch, channels, h_out, w_out]` shape.
///
/// Each output extent is `(size + 2 * pad - kernel) / stride + 1`.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when a kernel or stride extent is zero, when
///   the window is larger than the padded input, when a size computation
///   overflows `usize`, or when `input` holds fewer than
///   `batch * channels * h_in * w_in` elements. The first three previously
///   panicked or wrapped around.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by allocation, `launch_cfg`, or the launch itself.
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Output extent along one axis; `None` for a zero kernel/stride, a window
    // that does not fit the padded input, or arithmetic overflow.
    fn pooled_extent(size: usize, pad: usize, kernel: usize, stride: usize) -> Option<usize> {
        if kernel == 0 || stride == 0 {
            return None;
        }
        let padded = pad.checked_mul(2)?.checked_add(size)?;
        let span = padded.checked_sub(kernel)?;
        Some(span / stride + 1)
    }

    // Product of four extents; `None` on overflow.
    fn volume(dims: [usize; 4]) -> Option<usize> {
        dims.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    }

    let (h_out, w_out) = match (
        pooled_extent(h_in, ph, kh, sh),
        pooled_extent(w_in, pw, kw, sw),
    ) {
        (Some(h), Some(w)) => (h, w),
        _ => {
            // `expected` carries the window, `got` the unpadded input extent.
            return Err(GpuError::ShapeMismatch {
                op: "avgpool2d",
                expected: vec![kh, kw],
                got: vec![h_in, w_in],
            });
        }
    };

    let in_dims = [batch, channels, h_in, w_in];
    let out_dims = [batch, channels, h_out, w_out];
    let total = match (volume(in_dims), volume(out_dims)) {
        // The kernel indexes `input` by these dimensions, so it must be at
        // least that large before the unsafe launch.
        (Some(needed), Some(total)) if input.len() >= needed => total,
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "avgpool2d",
                expected: in_dims.to_vec(),
                got: vec![input.len()],
            });
        }
    };

    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        device.context(),
        AVGPOOL2D_PTX,
        "avgpool2d_forward_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "avgpool2d_forward_kernel",
    })?;

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: `input` and `out` were sized against the dimensions passed to
    // the kernel; the argument order must match the PTX signature, which is
    // not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32)
            .arg(&ch_u32)
            .arg(&h_in_u32)
            .arg(&w_in_u32)
            .arg(&h_out_u32)
            .arg(&w_out_u32)
            .arg(&kh_u32)
            .arg(&kw_u32)
            .arg(&sh_u32)
            .arg(&sw_u32)
            .arg(&ph_u32)
            .arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_dims))
}
15047
15048#[cfg(not(feature = "cuda"))]
15050#[allow(clippy::too_many_arguments)]
15051pub fn gpu_avgpool2d(
15052 _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
15053 _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
15054 _sh: usize, _sw: usize, _ph: usize, _pw: usize,
15055 _device: &GpuDevice,
15056) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
15057 Err(GpuError::NoCudaFeature)
15058}
15059
#[cfg(feature = "cuda")]
/// Placeholder for the GPU batch-norm forward pass.
///
/// The body asks the module cache to compile `batchnorm_forward_kernel`,
/// never inspects the outcome, and then fails unconditionally. No kernel is
/// launched and none of the tensor arguments are read.
///
/// # Errors
///
/// Always returns [`GpuError::ShapeMismatch`] with the placeholder dimensions
/// `expected: [0]`, `got: [1]`.
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    // The compile result is bound but deliberately left unexamined, exactly
    // as before; it is dropped when the function returns.
    let _compiled = crate::module_cache::get_or_compile(
        device.context(),
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );

    let placeholder = GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    };
    Err(placeholder)
}
15097
15098#[cfg(not(feature = "cuda"))]
15100#[allow(clippy::too_many_arguments)]
15101pub fn gpu_batchnorm_forward(
15102 _input: &CudaBuffer<f32>,
15103 _weight: &CudaBuffer<f32>,
15104 _bias: &CudaBuffer<f32>,
15105 _running_mean: &mut CudaBuffer<f32>,
15106 _running_var: &mut CudaBuffer<f32>,
15107 _channels: usize,
15108 _spatial: usize,
15109 _eps: f32,
15110 _momentum: f32,
15111 _training: bool,
15112 _device: &GpuDevice,
15113) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
15114 Err(GpuError::NoCudaFeature)
15115}
15116
#[cfg(feature = "cuda")]
/// Row-wise layer normalisation of a `rows x cols` buffer.
///
/// The CPU fallback below computes, for every row,
/// `out[c] = weight[c] * (x[c] - mean) / sqrt(var + eps) + bias[c]` with the
/// biased variance; the GPU path launches `layernorm_kernel` with one block
/// of 256 threads per row.
///
/// If the PTX fails to compile, the failure is logged to stderr, the PTX is
/// dumped to `/tmp/layernorm_debug.ptx` on a best-effort basis, and the
/// result is computed on the host and uploaded.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `input` holds fewer than `rows * cols`
///   elements or `weight`/`bias` hold fewer than `cols` (previously a panic in
///   the fallback and an unchecked read in the kernel).
/// * Any error produced by `validate_unary`, transfers, allocation, or launch.
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = match rows.checked_mul(cols) {
        Some(t) if input.len() >= t => t,
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "layernorm",
                expected: vec![rows, cols],
                got: vec![input.len()],
            });
        }
    };
    if weight.len() < cols || bias.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm",
            expected: vec![cols, cols],
            got: vec![weight.len(), bias.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            // Only claim the dump happened when the write actually succeeded.
            if std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).is_ok() {
                eprintln!(
                    "ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
                    LAYERNORM_PTX.len()
                );
            }

            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let x_row = &h_in[base..base + cols];
                let mean: f32 = x_row.iter().sum::<f32>() / cols as f32;
                let var: f32 =
                    x_row.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (x_row[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // A zero grid dimension is rejected by the driver, hence `max(1)`.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: buffer lengths were checked against `rows`/`cols` above; the
    // argument order must match the PTX signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15201
#[cfg(feature = "cuda")]
/// Backward pass of row-wise layer normalisation.
///
/// Returns `(grad_input, grad_weight, grad_bias)` with lengths `rows * cols`,
/// `cols` and `cols`. The GPU path launches `layernorm_backward_kernel` with
/// one block of 256 threads per row; if the PTX fails to compile the same
/// gradients are computed on the host and uploaded.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `input` or `grad_output` hold fewer than
///   `rows * cols` elements, or `weight` holds fewer than `cols` (previously a
///   panic in the fallback and an unchecked read in the kernel).
/// * Any error produced by `validate_unary`, transfers, allocation, or launch.
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = match rows.checked_mul(cols) {
        Some(t) if input.len() >= t && grad_output.len() >= t => t,
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "layernorm_backward",
                expected: vec![rows, cols],
                got: vec![input.len(), grad_output.len()],
            });
        }
    };
    if weight.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm_backward",
            expected: vec![cols],
            got: vec![weight.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: same math as the kernel is expected to perform.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; total];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
                let var: f32 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();

                // First sweep: per-row reductions plus the parameter
                // gradients, which accumulate across rows.
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }

                // Second sweep: input gradient from the row means.
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(total, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // A zero grid dimension is rejected by the driver, hence `max(1)`.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: buffer lengths were checked against `rows`/`cols` above; the
    // argument order must match the PTX signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
15313
15314#[cfg(not(feature = "cuda"))]
15316pub fn gpu_layernorm_backward(
15317 _input: &CudaBuffer<f32>,
15318 _grad_output: &CudaBuffer<f32>,
15319 _weight: &CudaBuffer<f32>,
15320 _rows: usize,
15321 _cols: usize,
15322 _eps: f32,
15323 _device: &GpuDevice,
15324) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
15325 Err(GpuError::NoCudaFeature)
15326}
15327
#[cfg(feature = "cuda")]
/// Row-wise RMS normalisation of a `rows x cols` buffer.
///
/// The CPU fallback below computes, for every row,
/// `out[c] = x[c] / sqrt(mean(x^2) + eps) * weight[c]`; the GPU path launches
/// `rmsnorm_kernel` with one block of 256 threads per row.
///
/// If the PTX fails to compile, the failure is logged to stderr, the PTX is
/// dumped to `/tmp/rmsnorm_debug.ptx` on a best-effort basis, and the result
/// is computed on the host and uploaded.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `input` holds fewer than `rows * cols`
///   elements or `weight` holds fewer than `cols` (previously a panic in the
///   fallback and an unchecked read in the kernel).
/// * Any error produced by `validate_unary`, transfers, allocation, or launch.
pub fn gpu_rmsnorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = match rows.checked_mul(cols) {
        Some(t) if input.len() >= t => t,
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "rmsnorm",
                expected: vec![rows, cols],
                got: vec![input.len()],
            });
        }
    };
    if weight.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm",
            expected: vec![cols],
            got: vec![weight.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_PTX,
        "rmsnorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            eprintln!("ferrotorch-gpu: RMSNorm PTX compilation failed ({e:?}), CPU fallback");
            // Only claim the dump happened when the write actually succeeded.
            if std::fs::write("/tmp/rmsnorm_debug.ptx", RMSNORM_PTX).is_ok() {
                eprintln!(
                    "ferrotorch-gpu: dumped PTX to /tmp/rmsnorm_debug.ptx ({} bytes)",
                    RMSNORM_PTX.len()
                );
            }

            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let x_row = &h_in[base..base + cols];
                let sq_mean: f32 = x_row.iter().map(|&x| x * x).sum::<f32>() / cols as f32;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = x_row[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // A zero grid dimension is rejected by the driver, hence `max(1)`.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: buffer lengths were checked against `rows`/`cols` above; the
    // argument order must match the PTX signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15410
#[cfg(feature = "cuda")]
/// Backward pass of row-wise RMS normalisation.
///
/// Returns `(grad_input, grad_weight)` with lengths `rows * cols` and `cols`.
/// The GPU path launches `rmsnorm_backward_kernel` with one block of 256
/// threads per row; if the PTX fails to compile the same gradients are
/// computed on the host and uploaded.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `input` or `grad_output` hold fewer than
///   `rows * cols` elements, or `weight` holds fewer than `cols` (previously a
///   panic in the fallback and an unchecked read in the kernel).
/// * Any error produced by `validate_unary`, transfers, allocation, or launch.
pub fn gpu_rmsnorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = match rows.checked_mul(cols) {
        Some(t) if input.len() >= t && grad_output.len() >= t => t,
        _ => {
            return Err(GpuError::ShapeMismatch {
                op: "rmsnorm_backward",
                expected: vec![rows, cols],
                got: vec![input.len(), grad_output.len()],
            });
        }
    };
    if weight.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "rmsnorm_backward",
            expected: vec![cols],
            got: vec![weight.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_BACKWARD_PTX,
        "rmsnorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: same math as the kernel is expected to perform.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; total];
            let mut grad_weight = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f32 = x_slice.iter().map(|&x| x * x).sum::<f32>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;

                // Row reduction plus the weight gradient, which accumulates
                // across rows.
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }

                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] = inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    let mut grad_in = alloc_zeros_f32(total, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        // A zero grid dimension is rejected by the driver, hence `max(1)`.
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: buffer lengths were checked against `rows`/`cols` above; the
    // argument order must match the PTX signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
15508
15509#[cfg(not(feature = "cuda"))]
15511pub fn gpu_rmsnorm(
15512 _input: &CudaBuffer<f32>,
15513 _weight: &CudaBuffer<f32>,
15514 _rows: usize,
15515 _cols: usize,
15516 _eps: f32,
15517 _device: &GpuDevice,
15518) -> GpuResult<CudaBuffer<f32>> {
15519 Err(GpuError::NoCudaFeature)
15520}
15521
15522#[cfg(not(feature = "cuda"))]
15524pub fn gpu_rmsnorm_backward(
15525 _input: &CudaBuffer<f32>,
15526 _grad_output: &CudaBuffer<f32>,
15527 _weight: &CudaBuffer<f32>,
15528 _rows: usize,
15529 _cols: usize,
15530 _eps: f32,
15531 _device: &GpuDevice,
15532) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
15533 Err(GpuError::NoCudaFeature)
15534}
15535
#[cfg(feature = "cuda")]
/// Runs `add_kernel` on `a` and `b`, writing into the caller-supplied `out`.
///
/// # Errors
///
/// * Whatever `validate_binary` reports for `a`, `b` and `device`.
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`.
/// * [`GpuError::PtxCompileFailed`] when `try_launch_binary_into` reports
///   that nothing was launched.
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;

    // The destination may be larger than the operands, never smaller.
    let needed = a.len();
    let available = out.len();
    if available < needed {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![needed],
            got: vec![available],
        });
    }

    let launched = try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")?;
    if launched {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        })
    }
}
15567
#[cfg(feature = "cuda")]
/// Runs `mul_kernel` on `a` and `b`, writing into the caller-supplied `out`.
///
/// # Errors
///
/// * Whatever `validate_binary` reports for `a`, `b` and `device`.
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`.
/// * [`GpuError::PtxCompileFailed`] when `try_launch_binary_into` reports
///   that nothing was launched.
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;

    // The destination may be larger than the operands, never smaller.
    let needed = a.len();
    let available = out.len();
    if available < needed {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![needed],
            got: vec![available],
        });
    }

    let launched = try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")?;
    if launched {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        })
    }
}
15591
#[cfg(feature = "cuda")]
/// Runs `scale_kernel` on `a` with `scalar`, writing into the caller-supplied
/// `out`. The launch is sized by `a.len()`.
///
/// # Errors
///
/// * Whatever `validate_unary` reports for `a` and `device`.
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`. This guard
///   mirrors `gpu_add_into`/`gpu_mul_into`; without it the kernel would be
///   launched over a destination that is too small.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let n = a.len();
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `out` holds at least `n` elements (checked above); the argument
    // order must match the PTX signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15627
#[cfg(feature = "cuda")]
/// Allocates an `n`-element buffer and runs `fill_f32_kernel` on it with
/// `scalar`.
///
/// The kernel is compiled before anything else, so a compile failure is
/// reported even when `n` is zero. For `n == 0` the freshly allocated buffer
/// is returned without launching.
///
/// # Errors
///
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by allocation, `launch_cfg`, or the launch itself.
pub fn gpu_fill_f32(
    n: usize,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        FILL_F32_PTX,
        "fill_f32_kernel",
        device.ordinal() as u32,
    );
    let f = match compiled {
        Ok(func) => func,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "fill_f32_kernel",
            });
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    if n != 0 {
        let cfg = launch_cfg(n)?;
        let count = n as u32;
        // SAFETY: `out` was allocated with exactly `n` elements; the argument
        // order must match the PTX signature, which is not visible here.
        unsafe {
            stream
                .launch_builder(&f)
                .arg(out.inner_mut())
                .arg(&scalar)
                .arg(&count)
                .launch(cfg)?;
        }
    }
    Ok(out)
}
15670
#[cfg(feature = "cuda")]
/// Reports whether `a` contains any non-finite value (infinity or NaN).
///
/// The buffer is copied to the host and scanned there. An empty buffer yields
/// `Ok(false)` before any validation or transfer takes place.
///
/// # Errors
///
/// Propagates errors from `validate_unary` and from the device-to-host copy.
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    if a.len() == 0 {
        return Ok(false);
    }

    validate_unary(a, device)?;

    let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
    let all_finite = host.iter().all(|v| v.is_finite());
    Ok(!all_finite)
}
15699
15700#[cfg(not(feature = "cuda"))]
15702pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
15703 Err(GpuError::NoCudaFeature)
15704}
15705
#[cfg(feature = "cuda")]
/// Runs `gelu_kernel` on `a`, writing into the caller-supplied `out`.
///
/// # Errors
///
/// * Whatever `validate_unary` reports for `a` and `device`.
/// * [`GpuError::ShapeMismatch`] when `out` is shorter than `a`. This guard
///   mirrors `gpu_add_into`/`gpu_mul_into`, which check the destination
///   before delegating to their launch helper.
/// * [`GpuError::PtxCompileFailed`] when `try_launch_unary_into` reports that
///   nothing was launched.
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;

    if out.len() < a.len() {
        return Err(GpuError::ShapeMismatch {
            op: "gelu_into",
            expected: vec![a.len()],
            got: vec![out.len()],
        });
    }

    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "gelu_kernel",
        })
    }
}
15721
#[cfg(feature = "cuda")]
/// Runs `embed_lookup_kernel`, sized by `d`, with `idx`, `weight` and `out`.
///
/// NOTE(review): presumably copies the `d`-wide row of `weight` selected by
/// `idx[0]` into `out` — confirm against the PTX. The index value itself
/// lives on the device and is not range-checked here.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `idx` is empty or `out` holds fewer
///   than `d` elements; both would otherwise reach the kernel unchecked.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    if idx.len() == 0 {
        return Err(GpuError::ShapeMismatch {
            op: "embed_lookup_into",
            expected: vec![1],
            got: vec![0],
        });
    }
    if out.len() < d {
        return Err(GpuError::ShapeMismatch {
            op: "embed_lookup_into",
            expected: vec![d],
            got: vec![out.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;

    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // SAFETY: `idx` is non-empty and `out` holds at least `d` elements
    // (checked above); the argument order must match the PTX signature,
    // which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15756
#[cfg(feature = "cuda")]
/// Gathers `n` rows of width `d` from `weight`, selected by `indices`, into a
/// new `n * d` buffer.
///
/// Indices are stored as `f32` and converted with `as usize` in the host
/// fallback. The GPU path launches `embed_lookup_batch_kernel`; index values
/// live on the device there and are not range-checked by this wrapper.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `indices` holds fewer than `n` entries,
///   or — in the host fallback — when an index selects a row outside
///   `weight` (previously a slice panic).
/// * Any error produced by transfers, allocation, `launch_cfg`, or launch.
pub fn gpu_embed_lookup_batch(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f32(0, device);
    }

    if indices.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "embed_lookup_batch",
            expected: vec![n],
            got: vec![indices.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            // Only the first `n` indices take part, so the result has exactly
            // `total` elements like the GPU path's output.
            for &idx_f in &idx_host[..n] {
                let row = idx_f as usize;
                let src = row
                    .checked_mul(d)
                    .and_then(|start| weight_host.get(start..start.checked_add(d)?));
                match src {
                    Some(values) => out.extend_from_slice(values),
                    None => {
                        // `d` is non-zero here because `total` is non-zero.
                        return Err(GpuError::ShapeMismatch {
                            op: "embed_lookup_batch",
                            expected: vec![weight_host.len() / d],
                            got: vec![row],
                        });
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: `indices` holds at least `n` entries and `out` exactly `total`
    // elements; the argument order must match the PTX signature, which is not
    // visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
15821
#[cfg(feature = "cuda")]
/// Accumulates the rows of `grad_output` into a zero-initialised
/// `num_embeddings * d` buffer, row `i` being added to the row selected by
/// `indices[i]`.
///
/// Indices are stored as `f32` and converted with `as usize` in the host
/// fallback. The GPU path launches `scatter_add_rows_kernel`; index values
/// live on the device there and are not range-checked by this wrapper.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `grad_output` holds fewer than
///   `indices.len() * d` elements, or — in the host fallback — when an index
///   is not below `num_embeddings` (previously an indexing panic).
/// * Any error produced by transfers, allocation, `launch_cfg`, or launch.
pub fn gpu_scatter_add_rows(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        return alloc_zeros_f32(num_embeddings * d, device);
    }

    if grad_output.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "scatter_add_rows",
            expected: vec![n, d],
            got: vec![grad_output.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                let row = idx_f as usize;
                if row >= num_embeddings {
                    return Err(GpuError::ShapeMismatch {
                        op: "scatter_add_rows",
                        expected: vec![num_embeddings],
                        got: vec![row],
                    });
                }
                let dst = &mut result[row * d..(row + 1) * d];
                let src = &go_host[i * d..(i + 1) * d];
                for (acc, &g) in dst.iter_mut().zip(src) {
                    *acc += g;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: `grad_output` holds at least `total` elements (checked above);
    // the argument order must match the PTX signature, which is not visible
    // here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
15891
#[cfg(feature = "cuda")]
/// Runs `transpose_2d_kernel` over the `m * n` elements of `a`, writing into
/// the caller-supplied `out`.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `a` or `out` hold fewer than `m * n`
///   elements; the launch previously went ahead unchecked.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_transpose_2d_into(
    a: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = m * n;
    if a.len() < total || out.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "transpose_2d_into",
            expected: vec![total, total],
            got: vec![a.len(), out.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "transpose_2d_kernel",
    })?;

    let cfg = launch_cfg(total)?;
    let (m_u32, n_u32, total_u32) = (m as u32, n as u32, total as u32);

    // SAFETY: both buffers hold at least `total` elements (checked above);
    // the argument order must match the PTX signature, which is not visible
    // here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15930
#[cfg(feature = "cuda")]
/// Runs `permute_0213_kernel` over the `d0 * d1 * d2 * d3` elements of `a`,
/// writing into the caller-supplied `out`.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `a` or `out` hold fewer elements than
///   the product of the four dimensions; the launch previously went ahead
///   unchecked.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_permute_0213_into(
    a: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = d0 * d1 * d2 * d3;
    if a.len() < total || out.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "permute_0213_into",
            expected: vec![total, total],
            got: vec![a.len(), out.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "permute_0213_kernel",
    })?;

    let cfg = launch_cfg(total)?;
    let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);

    // SAFETY: both buffers hold at least `total` elements (checked above);
    // the argument order must match the PTX signature, which is not visible
    // here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&d0u)
            .arg(&d1u)
            .arg(&d2u)
            .arg(&d3u)
            .arg(&tu)
            .launch(cfg)?;
    }
    Ok(())
}
15971
#[cfg(feature = "cuda")]
/// Runs `softmax_kernel` with one block of 256 threads per row of a
/// `rows x cols` buffer, writing into the caller-supplied `out`.
///
/// With `rows == 0` there is nothing to compute and the function returns
/// `Ok(())` after compiling the kernel; previously it attempted a launch with
/// a zero-sized grid, which the CUDA driver rejects.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `a` or `out` hold fewer than
///   `rows * cols` elements; the launch previously went ahead unchecked.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by the launch itself.
pub fn gpu_softmax_into(
    a: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = rows * cols;
    if a.len() < total || out.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "softmax_into",
            expected: vec![total, total],
            got: vec![a.len(), out.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "softmax_kernel",
    })?;

    if rows == 0 {
        return Ok(());
    }

    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    let cfg = LaunchConfig {
        grid_dim: (rows_u32, 1, 1),
        block_dim: (256, 1, 1),
        // NOTE(review): shared memory scales with `cols` here — confirm this
        // is what the kernel expects.
        shared_mem_bytes: cols_u32 * 4,
    };

    // SAFETY: both buffers hold at least `rows * cols` elements (checked
    // above); the argument order must match the PTX signature, which is not
    // visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(())
}
16013
#[cfg(feature = "cuda")]
/// Runs `layernorm_kernel` with one block of 256 threads per row of a
/// `rows x cols` buffer, writing into the caller-supplied `out`.
///
/// The launch configuration is aligned with `gpu_layernorm`, which launches
/// the same kernel: the grid is clamped to at least one block (a zero-sized
/// grid is rejected by the driver) and the dynamic shared memory request is
/// never smaller than the `256 * 4` bytes used there.
/// NOTE(review): the original request here was `cols * 4` bytes; which size
/// the kernel actually needs is defined by the PTX — confirm.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `input` or `out` hold fewer than
///   `rows * cols` elements, or `weight`/`bias` hold fewer than `cols`; the
///   launch previously went ahead unchecked.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by the launch itself.
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = rows * cols;
    if input.len() < total || out.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm_into",
            expected: vec![total, total],
            got: vec![input.len(), out.len()],
        });
    }
    if weight.len() < cols || bias.len() < cols {
        return Err(GpuError::ShapeMismatch {
            op: "layernorm_into",
            expected: vec![cols, cols],
            got: vec![weight.len(), bias.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "layernorm_kernel",
    })?;

    let block_size = 256u32;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: cols_u32.max(block_size) * 4,
    };

    // SAFETY: buffer lengths were checked against `rows`/`cols` above; the
    // argument order must match the PTX signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(())
}
16062
#[cfg(feature = "cuda")]
/// Runs `slice_read_kernel` over `n_batch * len * d` elements, reading from
/// `src` and writing into the caller-supplied `out`.
///
/// The kernel receives the element count, `d`, `len` and `max_len`.
/// NOTE(review): presumably `src` is laid out as `[n_batch, max_len, d]` and
/// the first `len` positions of each batch are copied — confirm against the
/// PTX. Buffer lengths are not validated here.
///
/// # Errors
///
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_slice_read_into(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let elem_count = n_batch * len * d;
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "slice_read_kernel",
            });
        }
    };

    let cfg = launch_cfg(elem_count)?;
    let (count_arg, d_arg, len_arg, max_len_arg) =
        (elem_count as u32, d as u32, len as u32, max_len as u32);

    // SAFETY: not established locally — the caller must supply buffers large
    // enough for the kernel's accesses; the argument order must match the PTX
    // signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&count_arg)
            .arg(&d_arg)
            .arg(&len_arg)
            .arg(&max_len_arg)
            .launch(cfg)?;
    }
    Ok(())
}
16106
#[cfg(feature = "cuda")]
/// Runs `small_matmul_kernel`, sized by the `m * n` output elements, with
/// `a`, `b` and the caller-supplied `out`.
///
/// NOTE(review): the operands are taken to be `m x k` and `k x n`, giving an
/// `m x n` result; the storage order is defined by the PTX — confirm.
///
/// # Errors
///
/// * [`GpuError::ShapeMismatch`] when `a` holds fewer than `m * k` elements,
///   `b` fewer than `k * n`, or `out` fewer than `m * n`; the launch
///   previously went ahead unchecked.
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_small_matmul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = m * n;
    let (a_needed, b_needed) = (m * k, k * n);
    if a.len() < a_needed || b.len() < b_needed || out.len() < total {
        return Err(GpuError::ShapeMismatch {
            op: "small_matmul_into",
            expected: vec![a_needed, b_needed, total],
            got: vec![a.len(), b.len(), out.len()],
        });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "small_matmul_kernel",
    })?;

    let cfg = launch_cfg(total)?;
    let (m_u32, k_u32, n_u32, total_u32) = (m as u32, k as u32, n as u32, total as u32);

    // SAFETY: all three buffers were checked against `m`, `k` and `n` above;
    // the argument order must match the PTX signature, which is not visible
    // here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
16147
#[cfg(feature = "cuda")]
/// Runs `slice_write_indirect_kernel` over `n_batch * d` elements, passing
/// `src`, `dst`, the element count, `d`, `max_len` and the device-resident
/// `pos_ptr`.
///
/// NOTE(review): presumably the kernel reads the write position from
/// `pos_ptr` on the device instead of taking it as a host scalar — confirm
/// against the PTX. Buffer lengths are not validated here.
///
/// # Errors
///
/// * [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
/// * Any error produced by `launch_cfg` or the launch itself.
pub fn gpu_slice_write_indirect(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos_ptr: &cudarc::driver::CudaSlice<u32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let elem_count = n_batch * d;
    let stream = device.stream();

    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        SLICE_WRITE_INDIRECT_PTX,
        "slice_write_indirect_kernel",
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "slice_write_indirect_kernel",
            });
        }
    };

    let cfg = launch_cfg(elem_count)?;
    let (count_arg, d_arg, max_len_arg) = (elem_count as u32, d as u32, max_len as u32);

    // SAFETY: not established locally — the caller must supply buffers large
    // enough for the kernel's accesses; the argument order must match the PTX
    // signature, which is not visible here.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&count_arg)
            .arg(&d_arg)
            .arg(&max_len_arg)
            .arg(pos_ptr)
            .launch(cfg)?;
    }
    Ok(())
}
16194
#[cfg(feature = "cuda")]
/// Launches `causal_mask_indirect_kernel` to fill `out`.
///
/// The kernel is called with `(total_len_ptr, out, max_pos, n_head * max_pos)`.
/// `total_len_ptr` is handed over as a device slice, so the length it holds is
/// read by the kernel rather than by the host.
/// NOTE(review): presumably `out` holds `n_head * max_pos` mask values and
/// entries beyond the device-side length are masked out; confirm against
/// `CAUSAL_MASK_INDIRECT_PTX`. The length of `out` is not validated here.
///
/// When `n_head * max_pos == 0` there is nothing to fill and the function
/// returns `Ok(())` without compiling or launching anything.
///
/// # Errors
///
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained from the
///   module cache.
/// * Any error returned by `launch_cfg` or by the kernel launch.
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_head * max_pos;
    // Nothing to fill; also avoids asking CUDA for a zero-sized grid.
    if total == 0 {
        return Ok(());
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    // The kernel takes its sizes as 32-bit values.
    let max_pos_u32 = max_pos as u32;
    let total_u32 = total as u32;
    // SAFETY: the arguments are pushed in the order the kernel is expected to
    // declare them (NOTE(review): confirm against `CAUSAL_MASK_INDIRECT_PTX`).
    // The caller is responsible for `out` being large enough for
    // `n_head * max_pos` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&max_pos_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
16233
#[cfg(feature = "cuda")]
/// Compiles (or fetches from the module cache) a fixed list of kernels up front.
///
/// The device context is bound to the calling thread first; each kernel is
/// then requested from `crate::module_cache::get_or_compile` in the order
/// listed below, stopping at the first failure.
///
/// # Errors
///
/// Returns the error from `bind_to_thread`, or [`GpuError::Driver`] wrapping
/// the first compile failure.
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    // (PTX source, entry point) pairs, requested in exactly this order.
    let kernels: &[(&'static str, &'static str)] = &[
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
    ];

    let context = device.context();
    context.bind_to_thread()?;
    let ordinal = device.ordinal() as u32;

    for &(ptx, entry_point) in kernels {
        // Only the side effect of populating the cache matters; the returned
        // function handle is dropped.
        crate::module_cache::get_or_compile(context, ptx, entry_point, ordinal)
            .map_err(GpuError::Driver)?;
    }
    Ok(())
}
16269
#[cfg(not(feature = "cuda"))]
/// Stub for `precompile_decode_kernels` compiled when the `cuda` feature is disabled.
///
/// Ignores `_device` and always returns `Err(GpuError::NoCudaFeature)`.
pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
16275
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_gelu` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16285
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_gelu_tanh` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_gelu_tanh(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16294
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_gelu_erf` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_gelu_erf(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16303
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_gelu_backward_tanh` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_gelu_backward_tanh(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16313
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_silu` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16319
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_silu_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_silu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16329
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_elu` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_elu(
    _input: &CudaBuffer<f32>,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16339
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_elu_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_elu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _alpha: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16350
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_mish` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16356
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_mish_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_mish_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16366
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_clamp` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_clamp(
    _input: &CudaBuffer<f32>,
    _min_val: f32,
    _max_val: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16377
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_div` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_div(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16387
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_exp` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16393
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_log` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16399
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_sqrt` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16405
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_pow` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_pow(
    _a: &CudaBuffer<f32>,
    _exponent: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16415
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_abs` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16421
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_sigmoid` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16427
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_tanh` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16433
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_layernorm` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_layernorm(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16447
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_transpose_2d` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_transpose_2d(
    _input: &CudaBuffer<f32>,
    _m: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16458
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_add` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16468
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_sub` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16478
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_mul` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16488
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_neg` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16494
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_relu` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16500
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_scale` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_scale(
    _a: &CudaBuffer<f32>,
    _scalar: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16510
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_broadcast_add` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16523
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_broadcast_sub` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16536
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_broadcast_mul` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16549
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_softmax` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_softmax(
    _input: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16560
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_dropout` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_dropout(
    _input: &CudaBuffer<f32>,
    _threshold: u32,
    _scale: f32,
    _seed: u32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16572
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_permute_0213` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_permute_0213(
    _input: &CudaBuffer<f32>,
    _d0: usize,
    _d1: usize,
    _d2: usize,
    _d3: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16585
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_slice_write` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments (leaving `_dst` untouched) and always returns
/// `Err(GpuError::NoCudaFeature)`.
pub fn gpu_slice_write(
    _src: &CudaBuffer<f32>,
    _dst: &mut CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _max_len: usize,
    _pos: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
16599
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_slice_read` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_slice_read(
    _src: &CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _len: usize,
    _max_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16612
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_embed_lookup` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_embed_lookup(
    _idx: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16623
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_embed_lookup_batch` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_embed_lookup_batch(
    _indices: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _n: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16635
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_scatter_add_rows` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_scatter_add_rows(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _num_embeddings: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16647
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_relu_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_relu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16657
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_abs_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_abs_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16667
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_fill_f32` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_fill_f32(
    _n: usize,
    _scalar: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16677
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_gelu_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_gelu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16687
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_index_select_1d` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_index_select_1d(
    _input: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16697
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_scatter_add_1d` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_scatter_add_1d(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _input_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16708
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_masked_fill` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_masked_fill(
    _input: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _value: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16719
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_masked_zero` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_masked_zero(
    _grad: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16729
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_sigmoid_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sigmoid_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16739
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_tanh_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_tanh_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16749
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_softmax_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16760
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_log_softmax` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_log_softmax(
    _input: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16770
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_log_softmax_backward` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_log_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16781
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_sum_axis` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sum_axis(
    _a: &CudaBuffer<f32>,
    _outer: usize,
    _axis_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16793
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_cumsum` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_cumsum(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16805
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_cumprod` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_cumprod(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16817
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_cummax` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`;
/// the success type is a pair of buffers, which this stub never produces.
pub fn gpu_cummax(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
16829
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_cummin` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`;
/// the success type is a pair of buffers, which this stub never produces.
pub fn gpu_cummin(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
16841
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_logcumsumexp` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_logcumsumexp(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16853
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_strided_split` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_strided_split(
    _input: &CudaBuffer<f32>,
    _total_along_axis: usize,
    _split_offset: usize,
    _split_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16867
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_strided_cat` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments (leaving `_output` untouched) and always returns
/// `Err(GpuError::NoCudaFeature)`.
pub fn gpu_strided_cat(
    _input: &CudaBuffer<f32>,
    _output: &mut CudaBuffer<f32>,
    _total_along_axis: usize,
    _cat_offset: usize,
    _part_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
16882
#[cfg(not(feature = "cuda"))]
/// Definition of `STRIDED_COPY_MAX_DIMS` used when the `cuda` feature is disabled.
///
/// NOTE(review): presumably the maximum number of dimensions accepted by
/// `gpu_strided_copy`, mirroring the CUDA build's constant; confirm.
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
16887
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_strided_copy` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_strided_copy(
    _input: &CudaBuffer<f32>,
    _out_shape: &[usize],
    _src_strides: &[isize],
    _src_offset: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16899
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_strided_copy_f64` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_strided_copy_f64(
    _input: &CudaBuffer<f64>,
    _out_shape: &[usize],
    _src_strides: &[isize],
    _src_offset: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
16911
#[cfg(feature = "cuda")]
/// Converts an `f32` buffer into 16-bit values by launching `f32_to_f16_kernel`.
///
/// Returns a newly allocated `CudaSlice<u16>` with one element per input
/// element. NOTE(review): presumably each `u16` is an IEEE half-precision bit
/// pattern; confirm against `F32_TO_F16_PTX`. An empty input yields an empty
/// slice without compiling or launching the kernel.
///
/// # Errors
///
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained from the
///   module cache.
/// * Any error from allocation, `launch_cfg`, or the kernel launch.
pub(crate) fn gpu_f32_to_f16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let len = input.len();
    if len == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }

    let stream = device.stream();
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_F16_PTX,
        "f32_to_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_f16_kernel",
    })?;

    let mut half_bits = stream.alloc_zeros::<u16>(len)?;
    let grid = launch_cfg(len)?;
    // The kernel takes the element count as a 32-bit value.
    let len_u32 = len as u32;

    {
        // Scoped so the builder's borrow of `half_bits` ends before it is returned.
        let mut launch = stream.launch_builder(&kernel);
        launch.arg(input.inner()).arg(&mut half_bits).arg(&len_u32);
        // SAFETY: `half_bits` was allocated with `len` elements, matching the
        // input length passed as the element count. The argument order is
        // expected to match the kernel's declaration (NOTE(review): confirm
        // against `F32_TO_F16_PTX`).
        unsafe { launch.launch(grid) }?;
    }

    Ok(half_bits)
}
16971
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_f32_to_f16` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
/// The success type is `()` here, whereas the CUDA build returns a
/// `cudarc::driver::CudaSlice<u16>`.
pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
16977
#[cfg(feature = "cuda")]
/// Converts an `f32` buffer into 16-bit values by launching `f32_to_bf16_kernel`.
///
/// Returns a newly allocated `CudaSlice<u16>` with one element per input
/// element. NOTE(review): presumably each `u16` is a bfloat16 bit pattern;
/// confirm against `F32_TO_BF16_PTX`. An empty input yields an empty slice
/// without compiling or launching the kernel.
///
/// # Errors
///
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained from the
///   module cache.
/// * Any error from allocation, `launch_cfg`, or the kernel launch.
pub(crate) fn gpu_f32_to_bf16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let len = input.len();
    if len == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }

    let stream = device.stream();
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_BF16_PTX,
        "f32_to_bf16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_bf16_kernel",
    })?;

    let mut bf16_bits = stream.alloc_zeros::<u16>(len)?;
    let grid = launch_cfg(len)?;
    // The kernel takes the element count as a 32-bit value.
    let len_u32 = len as u32;

    {
        // Scoped so the builder's borrow of `bf16_bits` ends before it is returned.
        let mut launch = stream.launch_builder(&kernel);
        launch.arg(input.inner()).arg(&mut bf16_bits).arg(&len_u32);
        // SAFETY: `bf16_bits` was allocated with `len` elements, matching the
        // input length passed as the element count. The argument order is
        // expected to match the kernel's declaration (NOTE(review): confirm
        // against `F32_TO_BF16_PTX`).
        unsafe { launch.launch(grid) }?;
    }

    Ok(bf16_bits)
}
17023
#[cfg(not(feature = "cuda"))]
/// Stub for `gpu_f32_to_bf16` compiled when the `cuda` feature is disabled.
///
/// Ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
/// The success type is `()` here, whereas the CUDA build returns a
/// `cudarc::driver::CudaSlice<u16>`.
pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
17029
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_add_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_sub_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_mul_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_div_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_neg_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_neg_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_relu_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_relu_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_scale_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_scale_f64(_a: &CudaBuffer<f64>, _scalar: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_exp_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_exp_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_log_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_log_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_sqrt_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sqrt_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_pow_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_pow_f64(_a: &CudaBuffer<f64>, _exponent: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_abs_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_abs_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_sigmoid_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sigmoid_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_tanh_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_tanh_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_relu_backward_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_relu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_sigmoid_backward_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_sigmoid_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_tanh_backward_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_tanh_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_broadcast_add_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_broadcast_sub_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_broadcast_mul_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_broadcast_div_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_broadcast_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_transpose_2d_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_transpose_2d_f64(_input: &CudaBuffer<f64>, _m: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_permute_0213_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_permute_0213_f64(_input: &CudaBuffer<f64>, _d0: usize, _d1: usize, _d2: usize, _d3: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_strided_split_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_strided_split_f64(_input: &CudaBuffer<f64>, _total_along_axis: usize, _split_offset: usize, _split_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_strided_cat_f64`: ignores its arguments (leaving `_output` untouched) and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_strided_cat_f64(_input: &CudaBuffer<f64>, _output: &mut CudaBuffer<f64>, _total_along_axis: usize, _cat_offset: usize, _part_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_index_select_1d_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_index_select_1d_f64(_input: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_scatter_add_1d_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_scatter_add_1d_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _input_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_masked_fill_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_masked_fill_f64(_input: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _value: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_masked_zero_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_masked_zero_f64(_grad: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_slice_write_f64`: ignores its arguments (leaving `_dst` untouched) and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_slice_write_f64(_src: &CudaBuffer<f64>, _dst: &mut CudaBuffer<f64>, _n_batch: usize, _d: usize, _max_len: usize, _pos: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_slice_read_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_slice_read_f64(_src: &CudaBuffer<f64>, _n_batch: usize, _d: usize, _len: usize, _max_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_embed_lookup_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_embed_lookup_f64(_idx: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_embed_lookup_batch_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_embed_lookup_batch_f64(_indices: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _n: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
/// Non-CUDA stub of `gpu_scatter_add_rows_f64`: ignores its arguments and always returns `Err(GpuError::NoCudaFeature)`.
pub fn gpu_scatter_add_rows_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _num_embeddings: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
17102
17103
#[cfg(feature = "cuda")]
/// Applies the sigmoid-approximated GELU, `x * sigmoid(1.702 * x)`, to an `f64` buffer.
///
/// `gelu_f64_kernel` is tried first via `try_launch_unary_f64`; when that
/// returns `None`, the same formula is evaluated through
/// `cpu_fallback_unary_f64`.
///
/// # Errors
///
/// Propagates any error from the launch attempt or the fallback.
pub fn gpu_gelu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_F64_PTX, "gelu_f64_kernel")? {
        Some(on_gpu) => Ok(on_gpu),
        None => cpu_fallback_unary_f64(input, device, |v| {
            v * (1.0 / (1.0 + (-1.702 * v).exp()))
        }),
    }
}
17116
#[cfg(feature = "cuda")]
/// Applies the tanh-approximated GELU to an `f64` buffer:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// `gelu_tanh_f64_kernel` is tried first via `try_launch_unary_f64`; when that
/// returns `None`, the formula is evaluated through `cpu_fallback_unary_f64`.
///
/// # Errors
///
/// Propagates any error from the launch attempt or the fallback.
pub fn gpu_gelu_tanh_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    let launched =
        try_launch_unary_f64(input, device, GELU_TANH_F64_PTX, "gelu_tanh_f64_kernel")?;
    match launched {
        Some(on_gpu) => Ok(on_gpu),
        None => cpu_fallback_unary_f64(input, device, |v| {
            let tanh_arg = (2.0_f64 / std::f64::consts::PI).sqrt() * (v + 0.044715 * v * v * v);
            0.5 * v * (1.0 + tanh_arg.tanh())
        }),
    }
}
17128
#[cfg(feature = "cuda")]
/// Applies the erf-based GELU, `0.5 * x * (1 + erf(x / sqrt(2)))`, to an `f64` buffer.
///
/// `gelu_erf_f64_kernel` is tried first via `try_launch_unary_f64`; when that
/// returns `None`, `cpu_fallback_unary_f64` evaluates the formula with a
/// five-term polynomial approximation of `erf` (so the fallback is only as
/// accurate as that approximation).
///
/// # Errors
///
/// Propagates any error from the launch attempt or the fallback.
pub fn gpu_gelu_erf_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_ERF_F64_PTX, "gelu_erf_f64_kernel")? {
        Some(on_gpu) => Ok(on_gpu),
        None => cpu_fallback_unary_f64(input, device, |v| {
            // Polynomial coefficients, lowest power of `t` first.
            const COEFFS: [f64; 5] = [
                0.254829592,
                -0.284496736,
                1.421413741,
                -1.453152027,
                1.061405429,
            ];
            let scaled = v * std::f64::consts::FRAC_1_SQRT_2;
            let magnitude = scaled.abs();
            let t = 1.0 / (1.0 + 0.3275911 * magnitude);
            // Horner evaluation from the highest coefficient down.
            let inner = COEFFS[..4]
                .iter()
                .rev()
                .fold(COEFFS[4], |acc, &coef| coef + t * acc);
            let poly = t * inner;
            let erf_magnitude = 1.0 - poly * (-magnitude * magnitude).exp();
            // erf is odd: restore the sign of the argument.
            let erf = if scaled >= 0.0 { erf_magnitude } else { -erf_magnitude };
            v * 0.5 * (1.0 + erf)
        }),
    }
}
17146
#[cfg(feature = "cuda")]
/// Backward pass of the sigmoid-approximated GELU for `f64` buffers.
///
/// Per element: `grad * (s + 1.702 * x * s * (1 - s))` with
/// `s = sigmoid(1.702 * x)`. `gelu_backward_f64_kernel` is tried first via
/// `try_launch_binary_f64`; when that returns `None`, the formula is
/// evaluated through `cpu_fallback_binary_f64`.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// * Any error from the launch attempt or the fallback.
pub fn gpu_gelu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_F64_PTX,
        "gelu_backward_f64_kernel",
    )?;
    match launched {
        Some(on_gpu) => Ok(on_gpu),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, v| {
            let s = 1.0 / (1.0 + (-1.702 * v).exp());
            upstream * (s + 1.702 * v * s * (1.0 - s))
        }),
    }
}
17165
#[cfg(feature = "cuda")]
/// Backward pass of the tanh-approximated GELU for `f64` buffers.
///
/// With `u = sqrt(2/pi) * (x + c * x^3)`, `c = 0.044715` and `t = tanh(u)`,
/// each element is
/// `grad * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * sqrt(2/pi) * (1 + 3 * c * x^2))`.
/// `gelu_backward_tanh_f64_kernel` is tried first via
/// `try_launch_binary_f64`; when that returns `None`, the formula is
/// evaluated through `cpu_fallback_binary_f64`.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// * Any error from the launch attempt or the fallback.
pub fn gpu_gelu_backward_tanh_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_F64_PTX,
        "gelu_backward_tanh_f64_kernel",
    )?;
    match launched {
        Some(on_gpu) => Ok(on_gpu),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, v| {
            let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
            let cubic = 0.044715_f64;
            let tanh_arg = sqrt_2_over_pi * (v + cubic * v * v * v);
            let th = tanh_arg.tanh();
            let slope = 0.5 * (1.0 + th)
                + 0.5 * v * (1.0 - th * th) * sqrt_2_over_pi * (1.0 + 3.0 * cubic * v * v);
            upstream * slope
        }),
    }
}
17188
#[cfg(feature = "cuda")]
/// Backward pass of the erf-based GELU for `f64` buffers.
///
/// Per element: `grad * (cdf(x) + x * pdf(x))`, where `cdf` and `pdf` are
/// those of the standard normal distribution. `gelu_backward_erf_f64_kernel`
/// is tried first via `try_launch_binary_f64`; when that returns `None`,
/// `cpu_fallback_binary_f64` evaluates the formula using a five-term
/// polynomial approximation of `erf` for the CDF.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// * Any error from the launch attempt or the fallback.
pub fn gpu_gelu_backward_erf_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_F64_PTX,
        "gelu_backward_erf_f64_kernel",
    )?;
    match launched {
        Some(on_gpu) => Ok(on_gpu),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, v| {
            // Polynomial coefficients, lowest power of `t` first.
            const COEFFS: [f64; 5] = [
                0.254829592,
                -0.284496736,
                1.421413741,
                -1.453152027,
                1.061405429,
            ];
            let scaled = v * std::f64::consts::FRAC_1_SQRT_2;
            let magnitude = scaled.abs();
            let t = 1.0 / (1.0 + 0.3275911 * magnitude);
            // Horner evaluation from the highest coefficient down.
            let inner = COEFFS[..4]
                .iter()
                .rev()
                .fold(COEFFS[4], |acc, &coef| coef + t * acc);
            let poly = t * inner;
            let erf_magnitude = 1.0 - poly * (-magnitude * magnitude).exp();
            // erf is odd: restore the sign of the argument.
            let erf = if scaled >= 0.0 { erf_magnitude } else { -erf_magnitude };
            let cdf = 0.5 * (1.0 + erf);
            let pdf = (-v * v / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt();
            upstream * (cdf + v * pdf)
        }),
    }
}
17214
#[cfg(feature = "cuda")]
/// SiLU activation (`x * sigmoid(x)`) over an `f64` buffer.
///
/// Tries `silu_f64_kernel` via `try_launch_unary_f64`; if that returns
/// `None`, evaluates `x / (1 + e^-x)` through `cpu_fallback_unary_f64`.
///
/// # Errors
/// Propagates any error from the launch or fallback helper.
pub fn gpu_silu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, SILU_F64_PTX, "silu_f64_kernel")? {
        Some(activated) => Ok(activated),
        None => cpu_fallback_unary_f64(input, device, |x| x / (1.0 + (-x).exp())),
    }
}
17223
#[cfg(feature = "cuda")]
/// Gradient of SiLU for `f64` buffers.
///
/// Computes `grad * (s + x * s * (1 - s))` with `s = sigmoid(x)`.
/// `silu_backward_f64_kernel` is tried through `try_launch_binary_f64`; if
/// that returns `None`, the closure below runs via `cpu_fallback_binary_f64`.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// * Any error propagated from the launch or fallback helper.
pub fn gpu_silu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (grad_len, input_len) = (grad.len(), input.len());
    if grad_len != input_len {
        return Err(GpuError::LengthMismatch { a: grad_len, b: input_len });
    }

    let launched = try_launch_binary_f64(
        grad,
        input,
        device,
        SILU_BACKWARD_F64_PTX,
        "silu_backward_f64_kernel",
    )?;
    match launched {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(grad, input, device, |upstream, x| {
            let sigmoid = 1.0 / (1.0 + (-x).exp());
            upstream * (sigmoid + x * sigmoid * (1.0 - sigmoid))
        }),
    }
}
17242
#[cfg(feature = "cuda")]
/// ELU activation over an `f64` buffer: `x` for `x > 0`, otherwise
/// `alpha * (e^x - 1)`.
///
/// An empty input yields an empty buffer without touching the kernel. If
/// `elu_f64_kernel` can be obtained from the module cache it is launched on
/// the device stream; otherwise the data is copied to the host, transformed
/// there and uploaded again.
///
/// The host path uses `exp_m1` rather than `exp(x) - 1.0`: the subtraction
/// cancels catastrophically for small `|x|` (it returns exactly `0` once
/// `e^x` rounds to `1`), whereas `exp_m1` keeps full precision.
///
/// # Errors
/// Propagates allocation, launch-configuration, launch and transfer errors.
pub fn gpu_elu_f64(
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let n = input.len();
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ELU_F64_PTX, "elu_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        // NOTE(review): `as u32` truncates silently for n > u32::MAX;
        // presumably `launch_cfg` rejects such sizes - confirm.
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: arguments are pushed as (input, out, n, alpha), which is
        // assumed to match the `elu_f64_kernel` signature in `ELU_F64_PTX`
        // (not visible here). `out` holds `n` elements, the count passed.
        unsafe {
            stream
                .launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&alpha)
                .launch(cfg)?;
        }
        return Ok(out);
    }

    // Kernel unavailable: compute on the host.
    let host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = host
        .iter()
        .map(|&x| if x > 0.0 { x } else { alpha * x.exp_m1() })
        .collect();
    cpu_to_gpu(&result, device)
}
17273
#[cfg(feature = "cuda")]
/// Gradient of ELU for `f64` buffers: `grad` where `input > 0`, otherwise
/// `grad * alpha * e^input`.
///
/// Empty inputs produce an empty buffer. When `elu_backward_f64_kernel` is
/// available from the module cache it is launched; otherwise both buffers are
/// copied to the host and the gradient is computed there.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// * Allocation, launch-configuration, launch and transfer errors.
pub fn gpu_elu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let len = grad.len();
    if len != input.len() {
        return Err(GpuError::LengthMismatch { a: len, b: input.len() });
    }
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let compiled = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_F64_PTX,
        "elu_backward_f64_kernel",
        device.ordinal() as u32,
    );
    match compiled {
        Ok(kernel) => {
            let mut grad_input = alloc_zeros_f64(len, device)?;
            let count = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: arguments are pushed as (grad, input, out, n, alpha),
            // assumed to match the kernel signature in the PTX (not visible
            // here); all three buffers hold `len` elements.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(grad.inner())
                    .arg(input.inner())
                    .arg(grad_input.inner_mut())
                    .arg(&count)
                    .arg(&alpha)
                    .launch(cfg)?;
            }
            Ok(grad_input)
        }
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let upstream = gpu_to_cpu(grad, device)?;
            let xs = gpu_to_cpu(input, device)?;
            let mut host_out: Vec<f64> = Vec::with_capacity(upstream.len());
            for (&g, &x) in upstream.iter().zip(xs.iter()) {
                let value = if x > 0.0 { g } else { g * alpha * x.exp() };
                host_out.push(value);
            }
            cpu_to_gpu(&host_out, device)
        }
    }
}
17310
#[cfg(feature = "cuda")]
/// Mish activation (`x * tanh(softplus(x))`) over an `f64` buffer.
///
/// Tries `mish_f64_kernel` via `try_launch_unary_f64`; if that returns
/// `None`, the closure below runs through `cpu_fallback_unary_f64`.
///
/// The host path evaluates softplus as `ln_1p(e^x)` instead of
/// `ln(1 + e^x)`: for very negative `x` the sum `1 + e^x` rounds to `1` and
/// the plain logarithm returns exactly `0`, losing the result entirely.
///
/// # Errors
/// Propagates any error from the launch or fallback helper.
pub fn gpu_mish_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    if let Some(out) = try_launch_unary_f64(input, device, MISH_F64_PTX, "mish_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_unary_f64(input, device, |x| x * x.exp().ln_1p().tanh())
}
17319
#[cfg(feature = "cuda")]
/// Gradient of Mish for `f64` buffers.
///
/// Computes `grad * (t + x * sigmoid(x) * (1 - t^2))` with
/// `t = tanh(softplus(x))`. `mish_backward_f64_kernel` is tried through
/// `try_launch_binary_f64`; if that returns `None`, the closure below runs
/// via `cpu_fallback_binary_f64`.
///
/// The host path evaluates softplus as `ln_1p(e^x)` instead of
/// `ln(1 + e^x)`, which would round to exactly `0` for very negative `x`
/// and drop the `t` term of the derivative.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `grad` and `input` differ in length.
/// * Any error propagated from the launch or fallback helper.
pub fn gpu_mish_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, MISH_BACKWARD_F64_PTX, "mish_backward_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let softplus = x.exp().ln_1p();
        let t = softplus.tanh();
        let sig = 1.0 / (1.0 + (-x).exp());
        g * (t + x * sig * (1.0 - t * t))
    })
}
17340
#[cfg(feature = "cuda")]
/// Clamps every element of an `f64` buffer to `[min_val, max_val]`.
///
/// The `f64` PTX is obtained from `CLAMP_PTX` through `get_f64_ptx` and kept
/// in a function-local `OnceLock`. Empty inputs return an empty buffer. If
/// the kernel cannot be obtained from the module cache, the clamp is applied
/// on the host as `x.max(min_val).min(max_val)`.
///
/// # Errors
/// Propagates allocation, launch-configuration, launch and transfer errors.
pub fn gpu_clamp_f64(
    input: &CudaBuffer<f64>,
    min_val: f64,
    max_val: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static PTX_CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = input.len();
    if len == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&PTX_CACHE, CLAMP_PTX, "clamp_kernel", "clamp_f64_kernel");
    let compiled = crate::module_cache::get_or_compile(ctx, ptx, "clamp_f64_kernel", device.ordinal() as u32);
    match compiled {
        Ok(kernel) => {
            let mut clamped = alloc_zeros_f64(len, device)?;
            let count = len as u32;
            let cfg = launch_cfg(len)?;
            // SAFETY: arguments are pushed as (input, out, n, min, max),
            // assumed to match the kernel signature in the PTX (not visible
            // here); `clamped` holds `len` elements, the count passed.
            unsafe {
                stream
                    .launch_builder(&kernel)
                    .arg(input.inner())
                    .arg(clamped.inner_mut())
                    .arg(&count)
                    .arg(&min_val)
                    .arg(&max_val)
                    .launch(cfg)?;
            }
            Ok(clamped)
        }
        Err(_) => {
            // Kernel unavailable: clamp on the host.
            let host = gpu_to_cpu(input, device)?;
            let mut host_out: Vec<f64> = Vec::with_capacity(host.len());
            for &x in host.iter() {
                host_out.push(x.max(min_val).min(max_val));
            }
            cpu_to_gpu(&host_out, device)
        }
    }
}
17375
#[cfg(feature = "cuda")]
/// Cumulative sum over an `f64` buffer described by the extents
/// `(outer, dim_size, inner)`.
///
/// The output has `outer * dim_size * inner` elements and the launch is sized
/// for `outer * inner` work items. Presumably the scan runs along the
/// `dim_size` axis of a row-major view - the kernel source is not visible
/// here, confirm against `CUMSUM_PTX`.
///
/// There is no host fallback for this operation.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. The kernel is handed those extents
///   unchecked, so a shorter buffer would be read past its end.
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained.
/// * Allocation, launch-configuration and launch errors.
pub fn gpu_cumsum_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Refuse to launch on a buffer smaller than the extents describe.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMSUM_PTX, "cumsum_kernel", "cumsum_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cumsum_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cumsum_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: arguments are pushed as (input, out, outer, dim_size, inner,
    // total), assumed to match the kernel signature in the PTX (not visible
    // here). `input` was checked to hold at least `n` elements and `out` was
    // allocated with exactly `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok(out)
}
17411
#[cfg(feature = "cuda")]
/// Cumulative product over an `f64` buffer described by the extents
/// `(outer, dim_size, inner)`.
///
/// The output has `outer * dim_size * inner` elements and the launch is sized
/// for `outer * inner` work items. Presumably the scan runs along the
/// `dim_size` axis of a row-major view - the kernel source is not visible
/// here, confirm against `CUMPROD_PTX`.
///
/// There is no host fallback for this operation.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. The kernel is handed those extents
///   unchecked, so a shorter buffer would be read past its end.
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained.
/// * Allocation, launch-configuration and launch errors.
pub fn gpu_cumprod_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Refuse to launch on a buffer smaller than the extents describe.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMPROD_PTX, "cumprod_kernel", "cumprod_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cumprod_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cumprod_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: arguments are pushed as (input, out, outer, dim_size, inner,
    // total), assumed to match the kernel signature in the PTX (not visible
    // here). `input` was checked to hold at least `n` elements and `out` was
    // allocated with exactly `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok(out)
}
17447
#[cfg(feature = "cuda")]
/// Cumulative maximum over an `f64` buffer described by the extents
/// `(outer, dim_size, inner)`.
///
/// Returns two buffers of `outer * dim_size * inner` elements each: the
/// running values and a second buffer written by the kernel. The second
/// buffer presumably holds the arg-max positions encoded as `f64` - confirm
/// against `CUMMAX_PTX`, which is not visible here.
///
/// There is no host fallback for this operation.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. The kernel is handed those extents
///   unchecked, so a shorter buffer would be read past its end.
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained.
/// * Allocation, launch-configuration and launch errors.
pub fn gpu_cummax_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }
    // Refuse to launch on a buffer smaller than the extents describe.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMAX_PTX, "cummax_kernel", "cummax_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummax_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummax_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: arguments are pushed as (input, out, ind, outer, dim_size,
    // inner, total), assumed to match the kernel signature in the PTX (not
    // visible here). `input` was checked to hold at least `n` elements; `out`
    // and `ind` were each allocated with exactly `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
17487
#[cfg(feature = "cuda")]
/// Cumulative minimum over an `f64` buffer described by the extents
/// `(outer, dim_size, inner)`.
///
/// Returns two buffers of `outer * dim_size * inner` elements each: the
/// running values and a second buffer written by the kernel. The second
/// buffer presumably holds the arg-min positions encoded as `f64` - confirm
/// against `CUMMIN_PTX`, which is not visible here.
///
/// There is no host fallback for this operation.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. The kernel is handed those extents
///   unchecked, so a shorter buffer would be read past its end.
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained.
/// * Allocation, launch-configuration and launch errors.
pub fn gpu_cummin_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        let e: &[f64] = &[];
        return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
    }
    // Refuse to launch on a buffer smaller than the extents describe.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }

    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMMIN_PTX, "cummin_kernel", "cummin_f64_kernel");
    let f = crate::module_cache::get_or_compile(ctx, ptx, "cummin_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "cummin_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let mut ind = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: arguments are pushed as (input, out, ind, outer, dim_size,
    // inner, total), assumed to match the kernel signature in the PTX (not
    // visible here). `input` was checked to hold at least `n` elements; `out`
    // and `ind` were each allocated with exactly `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(ind.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok((out, ind))
}
17527
#[cfg(feature = "cuda")]
/// Running log-sum-exp over an `f64` buffer described by the extents
/// `(outer, dim_size, inner)`.
///
/// The output has `outer * dim_size * inner` elements and the launch is sized
/// for `outer * inner` work items. Presumably the scan runs along the
/// `dim_size` axis of a row-major view - the kernel source is not visible
/// here, confirm against `LOGCUMSUMEXP_F64_PTX`.
///
/// There is no host fallback for this operation.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `input` holds fewer than
///   `outer * dim_size * inner` elements. The kernel is handed those extents
///   unchecked, so a shorter buffer would be read past its end.
/// * [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained.
/// * Allocation, launch-configuration and launch errors.
pub fn gpu_logcumsumexp_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Refuse to launch on a buffer smaller than the extents describe.
    if input.len() < n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }

    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(ctx, LOGCUMSUMEXP_F64_PTX, "logcumsumexp_f64_kernel", device.ordinal() as u32)
        .map_err(|_| GpuError::PtxCompileFailed { kernel: "logcumsumexp_f64_kernel" })?;

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(total)?;
    let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
    // SAFETY: arguments are pushed as (input, out, outer, dim_size, inner,
    // total), assumed to match the kernel signature in the PTX (not visible
    // here). `input` was checked to hold at least `n` elements and `out` was
    // allocated with exactly `n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&o)
            .arg(&d)
            .arg(&i)
            .arg(&t)
            .launch(cfg)?;
    }
    Ok(out)
}
17561
#[cfg(feature = "cuda")]
/// Row-wise softmax over an `f64` buffer viewed as `rows x cols`.
///
/// `input` is first checked with `validate_device`. If `softmax_f64_kernel`
/// can be obtained from the module cache, one block of 256 threads is
/// launched per row (at least one block) with `256 * 8` bytes of shared
/// memory. Otherwise a host fallback subtracts the row maximum before
/// exponentiating and normalises by the row sum.
///
/// NOTE(review): `input.len()` is not compared with `rows * cols` here. The
/// host fallback indexes by row and panics on a short buffer; the device path
/// performs no such check.
///
/// # Errors
/// Propagates validation, allocation, launch and transfer errors.
pub fn gpu_softmax_f64(
    input: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let compiled = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_F64_PTX,
        "softmax_f64_kernel",
        device.ordinal() as u32,
    );
    let Ok(kernel) = compiled else {
        // Kernel unavailable: numerically stable softmax on the host.
        let host = gpu_to_cpu(input, device)?;
        let mut probs = vec![0.0f64; host.len()];
        for row_idx in 0..rows {
            let start = row_idx * cols;
            let row = &host[start..start + cols];
            let out_row = &mut probs[start..start + cols];

            let peak = row.iter().fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));

            let mut denom = 0.0f64;
            for (dst, &v) in out_row.iter_mut().zip(row) {
                let e = (v - peak).exp();
                *dst = e;
                denom += e;
            }

            let scale = 1.0 / denom;
            for dst in out_row.iter_mut() {
                *dst *= scale;
            }
        }
        return cpu_to_gpu(&probs, device);
    };

    let mut result = alloc_zeros_f64(rows * cols, device)?;
    let row_count = rows as u32;
    let col_count = cols as u32;

    let cfg = LaunchConfig {
        // A grid dimension of zero is avoided when `rows == 0`.
        grid_dim: (row_count.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (input, out, rows, cols), assumed to
    // match the kernel signature in the PTX (not visible here). `result`
    // holds `rows * cols` elements; the length of `input` is not verified.
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(result.inner_mut())
            .arg(&row_count)
            .arg(&col_count)
            .launch(cfg)?;
    }

    Ok(result)
}
17637
#[cfg(feature = "cuda")]
/// Gradient of row-wise softmax for `f64` buffers.
///
/// Given the upstream gradient `grad` and the forward result `output`, both
/// viewed as rows of `cols` elements, computes
/// `output * (grad - sum(grad * output))` per row. The `f64` PTX is obtained
/// from `SOFTMAX_BACKWARD_PTX` through `get_f64_ptx`; if the kernel cannot be
/// obtained from the module cache the same formula runs on the host.
///
/// Both buffers are passed through `validate_device`. Empty inputs return an
/// empty buffer without dividing by `cols`, so a zero-width empty tensor no
/// longer panics.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `grad` and `output` differ in length, or
///   if the length is not a whole number of rows of `cols` elements (which
///   includes `cols == 0` with a non-empty buffer). In the latter case `b`
///   reports the number of elements covered by whole rows.
/// * Validation, allocation, launch and transfer errors.
pub fn gpu_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    validate_device(output, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Previously `total / cols` panicked for `cols == 0`, and a ragged tail
    // was silently left at zero.
    if cols == 0 || total % cols != 0 {
        let covered = if cols == 0 { 0 } else { (total / cols) * cols };
        return Err(GpuError::LengthMismatch { a: total, b: covered });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, SOFTMAX_BACKWARD_PTX, "softmax_backward_kernel", "softmax_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let g_row = &grad_host[base..base + cols];
                let y_row = &output_host[base..base + cols];
                let mut dot = 0.0f64;
                for (&g, &y) in g_row.iter().zip(y_row) {
                    dot += g * y;
                }
                for ((dst, &g), &y) in result[base..base + cols].iter_mut().zip(g_row).zip(y_row) {
                    *dst = y * (g - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (grad, output, out, rows, cols),
    // assumed to match the kernel signature in the PTX (not visible here).
    // All three buffers hold `total == rows * cols` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17711
#[cfg(feature = "cuda")]
/// Row-wise log-softmax over an `f64` buffer viewed as rows of `cols`
/// elements.
///
/// `input` is checked with `validate_device`. If `log_softmax_f64_kernel` can
/// be obtained from the module cache, one block of 256 threads is launched
/// per row; otherwise the host fallback computes
/// `x - (max + ln(sum(exp(x - max))))` per row.
///
/// An empty input returns an empty buffer before the row count is derived, so
/// a zero-width empty tensor (`cols == 0`) no longer panics on `0 / 0`.
///
/// # Panics
/// Panics on division by zero if `cols == 0` while `input` is non-empty.
///
/// # Errors
/// Propagates validation, allocation, launch and transfer errors.
pub fn gpu_log_softmax_f64(
    input: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let total = input.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_F64_PTX,
        "log_softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: numerically stable log-softmax on the host.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let row = &host[base..base + cols];
                let max_v = row.iter().fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
                let mut sum_exp = 0.0f64;
                for &v in row {
                    sum_exp += (v - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for (dst, &v) in out[base..base + cols].iter_mut().zip(row) {
                    *dst = v - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (input, out, rows, cols), assumed to
    // match the kernel signature in the PTX (not visible here). `out` holds
    // `total` elements and `rows * cols <= total`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17782
#[cfg(feature = "cuda")]
/// Gradient of row-wise log-softmax for `f64` buffers.
///
/// Given the upstream gradient `grad` and the forward result `output`
/// (log-probabilities), both viewed as rows of `cols` elements, computes
/// `grad - exp(output) * sum(grad)` per row. If
/// `log_softmax_backward_f64_kernel` cannot be obtained from the module cache
/// the same formula runs on the host.
///
/// Both buffers are passed through `validate_device`. Empty inputs return an
/// empty buffer without dividing by `cols`, so a zero-width empty tensor no
/// longer panics.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] if `grad` and `output` differ in length, or
///   if the length is not a whole number of rows of `cols` elements (which
///   includes `cols == 0` with a non-empty buffer). In the latter case `b`
///   reports the number of elements covered by whole rows.
/// * Validation, allocation, launch and transfer errors.
pub fn gpu_log_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    validate_device(output, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Previously `total / cols` panicked for `cols == 0`, and a ragged tail
    // was silently left at zero.
    if cols == 0 || total % cols != 0 {
        let covered = if cols == 0 { 0 } else { (total / cols) * cols };
        return Err(GpuError::LengthMismatch { a: total, b: covered });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_F64_PTX,
        "log_softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let g_row = &grad_host[base..base + cols];
                let y_row = &output_host[base..base + cols];
                let mut sum_grad = 0.0f64;
                for &g in g_row {
                    sum_grad += g;
                }
                for ((dst, &g), &y) in result[base..base + cols].iter_mut().zip(g_row).zip(y_row) {
                    *dst = g - y.exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (grad, output, out, rows, cols),
    // assumed to match the kernel signature in the PTX (not visible here).
    // All three buffers hold `total == rows * cols` elements.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17857
#[cfg(feature = "cuda")]
/// Layer normalisation over an `f64` buffer viewed as `rows x cols`.
///
/// Each row is normalised with its own mean and biased variance (divided by
/// `cols`), then scaled by `weight[c]` and shifted by `bias[c]`. The `f64`
/// PTX is obtained from `LAYERNORM_PTX` through `get_f64_ptx`; if the kernel
/// cannot be obtained from the module cache the same computation runs on the
/// host.
///
/// `input`, `weight` and `bias` are all passed through `validate_device`
/// before anything is handed to the kernel; previously only `input` was.
///
/// NOTE(review): buffer lengths are not compared with `rows * cols` / `cols`.
/// The host fallback panics on short buffers; the device path performs no
/// such check.
///
/// # Errors
/// Propagates validation, allocation, launch and transfer errors.
pub fn gpu_layernorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    bias: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(weight, device)?;
    validate_device(bias, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_PTX, "layernorm_kernel", "layernorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f64; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f64 = slice.iter().sum::<f64>() / cols as f64;
                let var: f64 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / cols as f64;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (input, out, weight, bias, rows, cols,
    // eps), assumed to match the kernel signature in the PTX (not visible
    // here). `out` holds `rows * cols` elements; the lengths of the input
    // buffers are not verified.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
17934
#[cfg(feature = "cuda")]
/// Gradient of layer normalisation for `f64` buffers viewed as `rows x cols`.
///
/// Returns `(grad_input, grad_weight, grad_bias)` with `rows * cols`, `cols`
/// and `cols` elements respectively. Row statistics (mean, biased variance)
/// are recomputed from `input`. The `f64` PTX is obtained from
/// `LAYERNORM_BACKWARD_PTX` through `get_f64_ptx`; if the kernel cannot be
/// obtained from the module cache the gradients are computed on the host.
///
/// `input`, `grad_output` and `weight` are all passed through
/// `validate_device` before anything is handed to the kernel; previously only
/// `input` was.
///
/// NOTE(review): buffer lengths are not compared with `rows * cols` / `cols`.
/// The host fallback panics on short buffers; the device path performs no
/// such check.
///
/// # Errors
/// Propagates validation, allocation, launch and transfer errors.
pub fn gpu_layernorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(grad_output, device)?;
    validate_device(weight, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_BACKWARD_PTX, "layernorm_backward_kernel", "layernorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; rows * cols];
            let mut grad_weight = vec![0.0f64; cols];
            let mut grad_bias = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f64 = x_slice.iter().sum::<f64>() / n_f;
                let var: f64 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f64>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();

                // First pass: row sums needed by grad_input, plus the
                // per-column parameter gradients accumulated across rows.
                let mut sum1 = 0.0f64;
                let mut sum2 = 0.0f64;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }

                // Second pass: dx = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat)).
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let mut grad_b = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (input, grad_output, weight, grad_in,
    // grad_w, grad_b, rows, cols, eps), assumed to match the kernel signature
    // in the PTX (not visible here). The three outputs hold `rows * cols`,
    // `cols` and `cols` elements; the lengths of the inputs are not verified.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
18037
#[cfg(feature = "cuda")]
/// RMS normalisation over an `f64` buffer viewed as `rows x cols`.
///
/// Each row is scaled by `1 / sqrt(mean(x^2) + eps)` and multiplied by
/// `weight[c]`. The `f64` PTX is obtained from `RMSNORM_PTX` through
/// `get_f64_ptx`; if the kernel cannot be obtained from the module cache the
/// same computation runs on the host.
///
/// `input` and `weight` are both passed through `validate_device` before
/// anything is handed to the kernel; previously only `input` was.
///
/// NOTE(review): buffer lengths are not compared with `rows * cols` / `cols`.
/// The host fallback panics on short buffers; the device path performs no
/// such check.
///
/// # Errors
/// Propagates validation, allocation, launch and transfer errors.
pub fn gpu_rmsnorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(weight, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_PTX, "rmsnorm_kernel", "rmsnorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f64; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f64 =
                    slice.iter().map(|&x| x * x).sum::<f64>() / cols as f64;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (input, out, weight, rows, cols, eps),
    // assumed to match the kernel signature in the PTX (not visible here).
    // `out` holds `rows * cols` elements; the lengths of the input buffers
    // are not verified.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
18109
#[cfg(feature = "cuda")]
/// Gradient of RMS normalisation for `f64` buffers viewed as `rows x cols`.
///
/// Returns `(grad_input, grad_weight)` with `rows * cols` and `cols` elements
/// respectively. The row scale `1 / sqrt(mean(x^2) + eps)` is recomputed from
/// `input`. The `f64` PTX is obtained from `RMSNORM_BACKWARD_PTX` through
/// `get_f64_ptx`; if the kernel cannot be obtained from the module cache the
/// gradients are computed on the host.
///
/// `input`, `grad_output` and `weight` are all passed through
/// `validate_device` before anything is handed to the kernel; previously only
/// `input` was.
///
/// NOTE(review): buffer lengths are not compared with `rows * cols` / `cols`.
/// The host fallback panics on short buffers; the device path performs no
/// such check.
///
/// # Errors
/// Propagates validation, allocation, launch and transfer errors.
pub fn gpu_rmsnorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;
    validate_device(grad_output, device)?;
    validate_device(weight, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_BACKWARD_PTX, "rmsnorm_backward_kernel", "rmsnorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Kernel unavailable: compute on the host.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; rows * cols];
            let mut grad_weight = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f64 =
                    x_slice.iter().map(|&x| x * x).sum::<f64>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;

                // Row dot product sum(go * x * w), plus the per-column weight
                // gradient accumulated across rows.
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }

                // dx = inv_rms * w * go - x * dot * inv_rms^3 / cols
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    let cfg = LaunchConfig {
        grid_dim: (rows_u32.max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    // SAFETY: arguments are pushed as (input, grad_output, weight, grad_in,
    // grad_w, rows, cols, eps), assumed to match the kernel signature in the
    // PTX (not visible here). The outputs hold `rows * cols` and `cols`
    // elements; the lengths of the inputs are not verified.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
18198
18199#[cfg(not(feature = "cuda"))]
18204pub fn gpu_softmax_f64(_input: &CudaBuffer<f64>, _rows: usize, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18205#[cfg(not(feature = "cuda"))]
18206pub fn gpu_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18207#[cfg(not(feature = "cuda"))]
18208pub fn gpu_log_softmax_f64(_input: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18209#[cfg(not(feature = "cuda"))]
18210pub fn gpu_log_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18211#[cfg(not(feature = "cuda"))]
18212pub fn gpu_layernorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _bias: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18213#[cfg(not(feature = "cuda"))]
18214pub fn gpu_layernorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18215#[cfg(not(feature = "cuda"))]
18216pub fn gpu_rmsnorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18217#[cfg(not(feature = "cuda"))]
18218pub fn gpu_rmsnorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18219
18220#[cfg(not(feature = "cuda"))]
18225pub fn gpu_gelu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18226#[cfg(not(feature = "cuda"))]
18227pub fn gpu_gelu_tanh_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18228#[cfg(not(feature = "cuda"))]
18229pub fn gpu_gelu_erf_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18230#[cfg(not(feature = "cuda"))]
18231pub fn gpu_gelu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18232#[cfg(not(feature = "cuda"))]
18233pub fn gpu_gelu_backward_tanh_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18234#[cfg(not(feature = "cuda"))]
18235pub fn gpu_gelu_backward_erf_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18236#[cfg(not(feature = "cuda"))]
18237pub fn gpu_silu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18238#[cfg(not(feature = "cuda"))]
18239pub fn gpu_silu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18240#[cfg(not(feature = "cuda"))]
18241pub fn gpu_elu_f64(_input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18242#[cfg(not(feature = "cuda"))]
18243pub fn gpu_elu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18244#[cfg(not(feature = "cuda"))]
18245pub fn gpu_mish_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18246#[cfg(not(feature = "cuda"))]
18247pub fn gpu_mish_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18248#[cfg(not(feature = "cuda"))]
18249pub fn gpu_clamp_f64(_input: &CudaBuffer<f64>, _min: f64, _max: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18250#[cfg(not(feature = "cuda"))]
18251pub fn gpu_cumsum_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18252#[cfg(not(feature = "cuda"))]
18253pub fn gpu_cumprod_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18254#[cfg(not(feature = "cuda"))]
18255pub fn gpu_cummax_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18256#[cfg(not(feature = "cuda"))]
18257pub fn gpu_cummin_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18258#[cfg(not(feature = "cuda"))]
18259pub fn gpu_logcumsumexp_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18260
/// GPU tests for the f32 element-wise kernels (`gpu_add`, `gpu_sub`, `gpu_mul`,
/// `gpu_neg`, `gpu_relu`), `gpu_small_matmul`, and the strided-copy kernels.
///
/// Every test opens CUDA device 0, so the module is only built for test runs
/// with the `cuda` feature enabled.
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
    use super::*;

    /// Opens CUDA device 0 and uploads `data` to it, returning both handles.
    ///
    /// Panics if the device cannot be opened or the upload fails.
    fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
        let device = GpuDevice::new(0).expect("CUDA device 0");
        let uploaded = cpu_to_gpu(data, &device).expect("cpu_to_gpu");
        (device, uploaded)
    }

    /// Downloads `buf` and asserts that it has the same length as `expected`
    /// and that every element differs from its counterpart by less than `1e-6`
    /// (absolute difference).
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let downloaded = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(downloaded.len(), expected.len(), "length mismatch");

        let pairs = downloaded.iter().copied().zip(expected.iter().copied());
        for (i, (got, exp)) in pairs.enumerate() {
            let delta = (got - exp).abs();
            assert!(delta < 1e-6, "element {i}: got {got}, expected {exp}");
        }
    }

    /// `gpu_add` on two 4-element buffers matches the host-side sum.
    #[test]
    fn add_basic() {
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [10.0f32, 20.0, 30.0, 40.0];
        let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(l, r)| l + r).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add");
        assert_buf_eq(&sum, &dev, &want);
    }

    /// `gpu_add` on two empty buffers succeeds and yields an empty buffer.
    #[test]
    fn add_empty() {
        let (dev, a) = setup(&[]);
        let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add empty");
        assert_eq!(sum.len(), 0);
    }

    /// `gpu_add` on 100 000 elements matches the host-side sum.
    #[test]
    fn add_large() {
        let n = 100_000;
        let lhs: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let rhs: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
        let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(l, r)| l + r).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let sum = gpu_add(&a, &b, &dev).expect("gpu_add large");
        assert_buf_eq(&sum, &dev, &want);
    }

    /// `gpu_add` reports `LengthMismatch { a: 3, b: 2 }` for buffers of
    /// length 3 and 2.
    #[test]
    fn add_length_mismatch() {
        let (dev, a) = setup(&[1.0, 2.0, 3.0]);
        let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
        match gpu_add(&a, &b, &dev).unwrap_err() {
            GpuError::LengthMismatch { a: 3, b: 2 } => {}
            other => panic!("unexpected error: {other}"),
        }
    }

    /// `gpu_sub` on two 4-element buffers matches the host-side difference.
    #[test]
    fn sub_basic() {
        let lhs = [10.0f32, 20.0, 30.0, 40.0];
        let rhs = [1.0f32, 2.0, 3.0, 4.0];
        let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(l, r)| l - r).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&diff, &dev, &want);
    }

    /// `gpu_sub` produces negative values when the subtrahend is larger.
    #[test]
    fn sub_negative_result() {
        let lhs = [1.0f32, 2.0];
        let rhs = [5.0f32, 10.0];
        let want = [-4.0f32, -8.0];

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
        assert_buf_eq(&diff, &dev, &want);
    }

    /// `gpu_mul` on two 4-element buffers matches the host-side product.
    #[test]
    fn mul_basic() {
        let lhs = [2.0f32, 3.0, 4.0, 5.0];
        let rhs = [10.0f32, 10.0, 10.0, 10.0];
        let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(l, r)| l * r).collect();

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
        let product = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&product, &dev, &want);
    }

    /// Multiplying by an all-zero buffer yields all zeros.
    #[test]
    fn mul_by_zero() {
        let lhs = [1.0f32, 2.0, 3.0];
        let zeros = [0.0f32, 0.0, 0.0];
        let want = [0.0f32, 0.0, 0.0];

        let (dev, a) = setup(&lhs);
        let b = cpu_to_gpu(&zeros, &dev).expect("cpu_to_gpu b");
        let product = gpu_mul(&a, &b, &dev).expect("gpu_mul");
        assert_buf_eq(&product, &dev, &want);
    }

    /// `gpu_neg` flips the sign of positive, negative and zero inputs.
    #[test]
    fn neg_basic() {
        let values = [1.0f32, -2.0, 3.0, 0.0, -5.5];
        let want: Vec<f32> = values.iter().map(|v| -v).collect();

        let (dev, a) = setup(&values);
        let negated = gpu_neg(&a, &dev).expect("gpu_neg");
        assert_buf_eq(&negated, &dev, &want);
    }

    /// Applying `gpu_neg` twice returns the original values.
    #[test]
    fn neg_double_negation() {
        let values = [1.0f32, -2.0, 3.0];
        let (dev, a) = setup(&values);
        let once = gpu_neg(&a, &dev).expect("gpu_neg 1");
        let twice = gpu_neg(&once, &dev).expect("gpu_neg 2");
        assert_buf_eq(&twice, &dev, &values);
    }

    /// `gpu_relu` zeroes negative inputs and passes zero/positive ones through.
    #[test]
    fn relu_basic() {
        let values = [-3.0f32, -1.0, 0.0, 1.0, 3.0];
        let want = [0.0f32, 0.0, 0.0, 1.0, 3.0];

        let (dev, a) = setup(&values);
        let activated = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&activated, &dev, &want);
    }

    /// `gpu_relu` maps an all-negative buffer to all zeros.
    #[test]
    fn relu_all_negative() {
        let values = [-5.0f32, -0.1, -100.0];
        let want = [0.0f32, 0.0, 0.0];

        let (dev, a) = setup(&values);
        let activated = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&activated, &dev, &want);
    }

    /// `gpu_relu` leaves an all-positive buffer unchanged.
    #[test]
    fn relu_all_positive() {
        let values = [0.1f32, 1.0, 100.0];

        let (dev, a) = setup(&values);
        let activated = gpu_relu(&a, &dev).expect("gpu_relu");
        assert_buf_eq(&activated, &dev, &values);
    }

    /// `gpu_relu` on an empty buffer succeeds and yields an empty buffer.
    #[test]
    fn relu_empty() {
        let (dev, a) = setup(&[]);
        let activated = gpu_relu(&a, &dev).expect("gpu_relu empty");
        assert_eq!(activated.len(), 0);
    }

    /// 2x2 by 2x2 product: `[[1, 2], [3, 4]] * [[5, 6], [7, 8]]` read row by
    /// row gives `[19, 22, 43, 50]`.
    #[test]
    fn small_matmul_2x2() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let lhs = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
        let rhs = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
        let product = gpu_small_matmul(&lhs, &rhs, 2, 2, 2, &dev).unwrap();
        assert_buf_eq(&product, &dev, &[19.0, 22.0, 43.0, 50.0]);
    }

    /// 1x3 by 3x2 product: a single-row left operand gives `[4, 5]`.
    #[test]
    fn small_matmul_1xk_kxn() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let lhs = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let rhs = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
        let product = gpu_small_matmul(&lhs, &rhs, 1, 3, 2, &dev).unwrap();
        assert_buf_eq(&product, &dev, &[4.0, 5.0]);
    }

    /// For a 1x64 by 64x64 product, `gpu_small_matmul` agrees with
    /// `crate::blas::gpu_matmul_f32` to within an absolute difference of 0.1
    /// per element.
    #[test]
    fn small_matmul_vs_cublas() {
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let (m, k, n) = (1, 64, 64);

        // Deterministic pseudo-varied inputs in [0, 1).
        let lhs_host: Vec<f32> = (0..m * k)
            .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
            .collect();
        let rhs_host: Vec<f32> = (0..k * n)
            .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
            .collect();

        let lhs = cpu_to_gpu(&lhs_host, &dev).unwrap();
        let rhs = cpu_to_gpu(&rhs_host, &dev).unwrap();

        let reference_buf = crate::blas::gpu_matmul_f32(&lhs, &rhs, m, k, n, &dev).unwrap();
        let cublas_result = gpu_to_cpu(&reference_buf, &dev).unwrap();

        let candidate_buf = gpu_small_matmul(&lhs, &rhs, m, k, n, &dev).unwrap();
        let our_result = gpu_to_cpu(&candidate_buf, &dev).unwrap();

        assert_eq!(cublas_result.len(), our_result.len());
        let pairs = cublas_result.iter().copied().zip(our_result.iter().copied());
        for (i, (cb, ours)) in pairs.enumerate() {
            let diff = (cb - ours).abs();
            assert!(
                diff < 0.1,
                "element {i}: cuBLAS={cb}, ours={ours}, diff={}",
                diff
            );
        }
    }

    /// Shape `[2, 3]` with strides `[3, 1]` and offset 0 reproduces the six
    /// input elements in order.
    #[test]
    fn strided_copy_identity_contiguous_2d() {
        let host: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let (dev, input) = setup(&host);
        let copied = gpu_strided_copy(&input, &[2, 3], &[3, 1], 0, &dev)
            .expect("strided_copy identity");
        assert_buf_eq(&copied, &dev, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// Shape `[3, 2]` with strides `[1, 3]` over `0..6` yields
    /// `[0, 3, 1, 4, 2, 5]`.
    #[test]
    fn strided_copy_transpose_2d() {
        let host: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let (dev, input) = setup(&host);
        let copied = gpu_strided_copy(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy transpose");
        assert_buf_eq(&copied, &dev, &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }

    /// Shape `[3]` with stride `[4]` and offset 2 over `0..12` picks elements
    /// 2, 6 and 10.
    #[test]
    fn strided_copy_sliced_column() {
        let host: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let (dev, input) = setup(&host);
        let copied = gpu_strided_copy(&input, &[3], &[4], 2, &dev)
            .expect("strided_copy col slice");
        assert_buf_eq(&copied, &dev, &[2.0, 6.0, 10.0]);
    }

    /// Shape `[2, 4, 3]` with strides `[12, 1, 4]` over `0..24`: output
    /// element `(b, i, j)` must equal input element `b * 12 + j * 4 + i`.
    #[test]
    fn strided_copy_3d_permute() {
        let data: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let (dev, input) = setup(&data);
        let copied =
            gpu_strided_copy(&input, &[2, 4, 3], &[12, 1, 4], 0, &dev).expect("strided_copy 3d");

        // Walk the output linearly and decompose each index into (b, i, j)
        // for the output shape [2, 4, 3].
        let mut expected = vec![0.0f32; 24];
        for (dst, slot) in expected.iter_mut().enumerate() {
            let b = dst / 12;
            let i = (dst % 12) / 3;
            let j = dst % 3;
            *slot = data[b * 12 + j * 4 + i];
        }
        assert_buf_eq(&copied, &dev, &expected);
    }

    /// A 4-D shape `[2, 3, 2, 2]` with strides `[12, 4, 2, 1]` is accepted and
    /// returns the input unchanged.
    #[test]
    fn strided_copy_4d_max_rank_supported() {
        let shape = [2usize, 3, 2, 2];
        let total: usize = shape.iter().product();
        let host: Vec<f32> = (0..total).map(|v| v as f32).collect();
        let (dev, input) = setup(&host);
        let copied = gpu_strided_copy(&input, &shape, &[12, 4, 2, 1], 0, &dev)
            .expect("strided_copy 4d");
        assert_buf_eq(&copied, &dev, &host);
    }

    /// A 9-dimensional shape is rejected with an error.
    #[test]
    fn strided_copy_rejects_too_many_dims() {
        let (dev, input) = setup(&[0.0f32; 16]);
        let result = gpu_strided_copy(&input, &[1, 1, 1, 1, 1, 1, 1, 1, 16], &[1; 9], 0, &dev);
        assert!(result.is_err());
    }

    /// A 2-element shape paired with a 3-element stride list is rejected.
    #[test]
    fn strided_copy_rejects_shape_stride_length_mismatch() {
        let (dev, input) = setup(&[0.0f32; 12]);
        let result = gpu_strided_copy(&input, &[3, 4], &[4, 1, 1], 0, &dev);
        assert!(result.is_err());
    }

    /// A stride of `-1` is rejected.
    #[test]
    fn strided_copy_rejects_negative_stride() {
        let (dev, input) = setup(&[0.0f32; 6]);
        let result = gpu_strided_copy(&input, &[2, 3], &[3, -1], 0, &dev);
        assert!(result.is_err());
    }

    /// A shape containing a zero dimension succeeds with an empty output.
    #[test]
    fn strided_copy_empty_output() {
        let (dev, input) = setup(&[1.0f32, 2.0, 3.0]);
        let copied = gpu_strided_copy(&input, &[0, 3], &[3, 1], 0, &dev)
            .expect("strided_copy empty");
        assert_eq!(copied.len(), 0);
    }

    /// `gpu_strided_copy_f64` with shape `[3, 2]` and strides `[1, 3]` over
    /// `0..6` yields the same ordering as the f32 transpose test.
    #[test]
    fn strided_copy_f64_transpose_matches_f32() {
        let host_in: Vec<f64> = (0..6).map(|v| v as f64).collect();
        let dev = GpuDevice::new(0).expect("CUDA device 0");
        let input = cpu_to_gpu(&host_in, &dev).expect("cpu_to_gpu f64");
        let copied = gpu_strided_copy_f64(&input, &[3, 2], &[1, 3], 0, &dev)
            .expect("strided_copy_f64 transpose");
        let host_out = gpu_to_cpu(&copied, &dev).expect("gpu_to_cpu f64");
        assert_eq!(host_out, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
    }
}