1#[cfg(feature = "cuda")]
23use cudarc::driver::LaunchConfig;
24
25use crate::buffer::CudaBuffer;
26use crate::device::GpuDevice;
27use crate::error::{GpuError, GpuResult};
28#[cfg(feature = "cuda")]
29use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, cpu_to_gpu, gpu_to_cpu};
30
#[cfg(feature = "cuda")]
/// Derive an f64 PTX module from its f32 twin by textual substitution.
///
/// The kernel entry-point name is rewritten first, then every f32 opcode,
/// register/parameter declaration, and immediate constant is mapped to its
/// f64 equivalent. Substitution order is significant: more specific patterns
/// must run before their prefixes (`div.rn.f32` before `div.f32`,
/// `sqrt.rn.f32` before `sqrt.f32`).
///
/// NOTE(review): the element-stride rewrite only matches the literal
/// `shl.b64 %off, %off, 2`, so kernels that compute offsets in differently
/// named registers (e.g. `%off_in`, `%a_off`) would keep 4-byte strides —
/// confirm this helper is only applied to kernels that use `%off`.
pub(crate) fn ptx_f32_to_f64(f32_ptx: &str, f32_kernel_name: &str, f64_kernel_name: &str) -> String {
    // Ordered (pattern, replacement) table, applied top to bottom.
    const SUBS: &[(&str, &str)] = &[
        // Register declarations and parameter/memory access types.
        (".reg .f32", ".reg .f64"),
        ("ld.global.f32", "ld.global.f64"),
        ("st.global.f32", "st.global.f64"),
        ("ld.shared.f32", "ld.shared.f64"),
        ("st.shared.f32", "st.shared.f64"),
        ("ld.param.f32", "ld.param.f64"),
        (".param .f32", ".param .f64"),
        (".shared .align 4 .f32", ".shared .align 8 .f64"),
        // Arithmetic opcodes (rounded forms before their generic prefixes).
        ("add.f32", "add.f64"),
        ("sub.f32", "sub.f64"),
        ("mul.f32", "mul.f64"),
        ("div.rn.f32", "div.rn.f64"),
        ("div.f32", "div.f64"),
        ("neg.f32", "neg.f64"),
        ("abs.f32", "abs.f64"),
        ("max.f32", "max.f64"),
        ("min.f32", "min.f64"),
        ("sqrt.rn.f32", "sqrt.rn.f64"),
        ("sqrt.f32", "sqrt.f64"),
        ("fma.rn.f32", "fma.rn.f64"),
        ("mov.f32", "mov.f64"),
        // Predicate comparisons.
        ("setp.gt.f32", "setp.gt.f64"),
        ("setp.ge.f32", "setp.ge.f64"),
        ("setp.lt.f32", "setp.lt.f64"),
        ("setp.le.f32", "setp.le.f64"),
        ("setp.eq.f32", "setp.eq.f64"),
        ("setp.ne.f32", "setp.ne.f64"),
        // Integer-to-float conversions and raw-bit moves.
        ("cvt.rn.f32.u32", "cvt.rn.f64.u32"),
        ("cvt.rn.f32.s32", "cvt.rn.f64.s32"),
        ("mov.b32", "mov.b64"),
        // Element byte stride: <<2 (4-byte f32) becomes <<3 (8-byte f64).
        ("shl.b64 %off, %off, 2", "shl.b64 %off, %off, 3"),
        ("atom.global.add.f32", "atom.global.add.f64"),
        // Immediate constants: `0f` (f32 hex) -> `0d` (f64 hex).
        ("0f00000000", "0d0000000000000000"), // 0.0
        ("0f3F800000", "0d3FF0000000000000"), // 1.0
        ("0fBF800000", "0dBFF0000000000000"), // -1.0
        ("0f40000000", "0d4000000000000000"), // 2.0
        ("0f3F000000", "0d3FE0000000000000"), // 0.5
        ("0fFF800000", "0dFFF0000000000000"), // -inf
        ("0f7F800000", "0d7FF0000000000000"), // +inf
        ("0f3FB8AA3B", "0d3FF71547652B82FE"), // log2(e)
        ("0f3F317218", "0d3FE62E42FEFA39EF"), // ln(2)
    ];

    let mut ptx = f32_ptx.replace(f32_kernel_name, f64_kernel_name);
    for &(pattern, replacement) in SUBS {
        ptx = ptx.replace(pattern, replacement);
    }
    ptx
}
99
#[cfg(feature = "cuda")]
/// Returns the cached f64 translation of `f32_ptx`, generating it on first use.
///
/// The translated PTX is memoized in the caller-supplied `OnceLock`, so the
/// string substitution in `ptx_f32_to_f64` runs at most once per kernel for
/// the lifetime of the cache.
pub(crate) fn get_f64_ptx<'a>(
    cache: &'a std::sync::OnceLock<String>,
    f32_ptx: &str,
    f32_name: &str,
    f64_name: &str,
) -> &'a str {
    cache.get_or_init(|| ptx_f32_to_f64(f32_ptx, f32_name, f64_name))
}
113
#[cfg(feature = "cuda")]
/// PTX: elementwise f32 addition, `out[i] = a[i] + b[i]`.
///
/// One thread per element; threads with global id >= `n` branch to DONE.
/// Offsets are 4-byte f32 strides (`shl.b64 %off, %off, 2`).
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    add.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
165
166
#[cfg(feature = "cuda")]
/// PTX: vectorized f32 addition, 4 elements per thread via `v4.f32` loads/stores.
///
/// `n4` is the number of 4-element groups; each thread handles one group at a
/// 16-byte offset (`shl.b64 %off, %off, 4`).
/// NOTE(review): `ld/st.global.v4` presumably requires 16-byte-aligned
/// buffers — confirm the allocator guarantees this.
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry add_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    // Byte offset = tid * 16 (4 floats × 4 bytes)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    add.f32 %r0, %a0, %b0;
    add.f32 %r1, %a1, %b1;
    add.f32 %r2, %a2, %b2;
    add.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
223
#[cfg(feature = "cuda")]
/// PTX: vectorized f32 multiplication, 4 elements per thread (`v4.f32`).
///
/// Same structure as `ADD_VEC4_PTX` with `mul.f32` instead of `add.f32`;
/// `n4` counts 4-element groups, offsets are 16-byte strides.
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_vec4_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n4
) {
    .reg .u32 %r_tid, %bid, %bdim, %n4_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n4_reg, [n4];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n4_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 4;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
    ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];

    mul.f32 %r0, %a0, %b0;
    mul.f32 %r1, %a1, %b1;
    mul.f32 %r2, %a2, %b2;
    mul.f32 %r3, %a3, %b3;

    st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};

DONE:
    ret;
}
";
276
#[cfg(feature = "cuda")]
/// PTX: elementwise f32 subtraction, `out[i] = a[i] - b[i]`.
///
/// One thread per element, bounds-checked against `n`.
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    sub.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
324
325
#[cfg(feature = "cuda")]
/// PTX: elementwise f32 multiplication, `out[i] = a[i] * b[i]`.
///
/// One thread per element, bounds-checked against `n`.
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    mul.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
373
374
#[cfg(feature = "cuda")]
/// PTX: elementwise f32 negation, `out[i] = -a[i]`.
///
/// One thread per element, bounds-checked against `n`.
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry neg_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    neg.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
418
419
#[cfg(feature = "cuda")]
/// PTX: elementwise ReLU, `out[i] = max(a[i], 0.0)`.
///
/// One thread per element, bounds-checked against `n`.
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %zero;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mov.f32 %zero, 0f00000000;
    max.f32 %vr, %va, %zero;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
464
465
#[cfg(feature = "cuda")]
/// PTX: scalar multiply, `out[i] = a[i] * scalar`.
///
/// `scalar` is passed as an f32 kernel parameter; one thread per element.
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scale_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 scalar,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %s;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %s, [scalar];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    mul.f32 %vr, %va, %s;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
511
512
#[cfg(feature = "cuda")]
/// PTX: 2-D transpose of an [M, N] f32 matrix into an [N, M] output.
///
/// One thread per output element: tid = out_row * M + out_col, reading input
/// index out_col * N + out_row. `total` = M * N bounds the grid.
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 M,\n\
    .param .u32 N,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
    .reg .u32 %out_row, %out_col, %in_idx;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %M_reg, [M];\n\
    ld.param.u32 %N_reg, [N];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape is [N, M]. tid = out_row * M + out_col.\n\
    div.u32 %out_row, %r_tid, %M_reg;\n\
    rem.u32 %out_col, %r_tid, %M_reg;\n\
    // Input index: out_col * N + out_row (transposed).\n\
    mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
568
569
#[cfg(feature = "cuda")]
/// PTX: axis permutation (0, 2, 1, 3) of a [d0, d1, d2, d3] f32 tensor.
///
/// One thread per output element of the [d0, d2, d1, d3] result. The thread
/// id is decomposed in output layout (i0, i2, i1, i3) and mapped back to the
/// contiguous input index i0*(d1*d2*d3) + i1*(d2*d3) + i2*d3 + i3.
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
    .param .u64 in_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 d0,\n\
    .param .u32 d1,\n\
    .param .u32 d2,\n\
    .param .u32 d3,\n\
    .param .u32 total\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
    .reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
    .reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
    .reg .u32 %s_out2, %s_out1, %s_in1;\n\
    .reg .u64 %in, %out, %off_in, %off_out;\n\
    .reg .f32 %val;\n\
    .reg .pred %p;\n\
\n\
    ld.param.u64 %in, [in_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %d0r, [d0];\n\
    ld.param.u32 %d1r, [d1];\n\
    ld.param.u32 %d2r, [d2];\n\
    ld.param.u32 %d3r, [d3];\n\
    ld.param.u32 %total_reg, [total];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %total_reg;\n\
    @%p bra DONE;\n\
\n\
    // Output shape: [d0, d2, d1, d3]\n\
    // Decompose tid into (i0, i2, i1, i3) in output layout.\n\
    mul.lo.u32 %s_out2, %d1r, %d3r;\n\
    mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
    div.u32 %i0, %r_tid, %s_out1;\n\
    rem.u32 %rem, %r_tid, %s_out1;\n\
    div.u32 %i2, %rem, %s_out2;\n\
    rem.u32 %rem, %rem, %s_out2;\n\
    div.u32 %i1, %rem, %d3r;\n\
    rem.u32 %i3, %rem, %d3r;\n\
\n\
    // Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
    mul.lo.u32 %s_in1, %d2r, %d3r;\n\
    mul.lo.u32 %in_idx, %i0, %d1r;\n\
    add.u32 %in_idx, %in_idx, %i1;\n\
    mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
    mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
    add.u32 %in_idx, %in_idx, %i3;\n\
\n\
    cvt.u64.u32 %off_in, %in_idx;\n\
    shl.b64 %off_in, %off_in, 2;\n\
    add.u64 %off_in, %in, %off_in;\n\
    ld.global.f32 %val, [%off_in];\n\
\n\
    cvt.u64.u32 %off_out, %r_tid;\n\
    shl.b64 %off_out, %off_out, 2;\n\
    add.u64 %off_out, %out, %off_out;\n\
    st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
650
651
#[cfg(feature = "cuda")]
/// PTX: f32 -> f16 conversion via `cvt.rn.f16.f32` (round-to-nearest-even).
///
/// One thread per element; input stride 4 bytes, output stride 2 bytes,
/// result stored as a raw 16-bit value.
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_f16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .b16 %vh;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Compute input offset: i * 4 (f32 = 4 bytes)
    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    // Compute output offset: i * 2 (f16 = 2 bytes)
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32, convert to f16 (round-to-nearest-even), store as u16
    ld.global.f32 %vf, [%in];
    cvt.rn.f16.f32 %vh, %vf;
    st.global.b16 [%out], %vh;

DONE:
    ret;
}
";
706
#[cfg(feature = "cuda")]
/// PTX: f32 -> bf16 conversion using the integer round-to-nearest-even trick:
/// add (0x7FFF + bit 16 of the input) to the raw bits, then shift right 16.
///
/// NOTE(review): the rounding add can carry a low-mantissa NaN (e.g.
/// 0x7F800001) into the infinity pattern — confirm NaN inputs are acceptable
/// to lose here.
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry f32_to_bf16_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off_in, %off_out;
    .reg .f32 %vf;
    .reg .u32 %bits, %round, %lsb, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off_in, %r_tid;
    shl.b64 %off_in, %off_in, 2;
    add.u64 %in, %in, %off_in;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 1;
    add.u64 %out, %out, %off_out;

    // Load f32 as raw bits
    ld.global.u32 %bits, [%in];

    // Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
    shr.u32 %lsb, %bits, 16;
    and.b32 %lsb, %lsb, 1;
    add.u32 %round, %bits, 0x7FFF;
    add.u32 %round, %round, %lsb;
    shr.u32 %result, %round, 16;

    // Store as u16
    st.global.u16 [%out], %result;

DONE:
    ret;
}
";
767
#[cfg(feature = "cuda")]
/// PTX: naive matmul C[M,N] = A[M,K] * B[K,N], one thread per output element.
///
/// Each thread runs a sequential K-length dot product accumulated with
/// `fma.rn.f32`; `total` = M * N bounds the grid. No tiling or shared
/// memory — intended for small problem sizes, per the `small_` prefix.
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry small_matmul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 c_ptr,
    .param .u32 M,
    .param .u32 K,
    .param .u32 N,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
    .reg .u32 %row, %col, %p, %idx;
    .reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
    .reg .f32 %sum, %va, %vb;
    .reg .pred %bounds_p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %c, [c_ptr];
    ld.param.u32 %M_reg, [M];
    ld.param.u32 %K_reg, [K];
    ld.param.u32 %N_reg, [N];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    div.u32 %row, %r_tid, %N_reg;
    rem.u32 %col, %r_tid, %N_reg;

    mov.f32 %sum, 0f00000000;
    mov.u32 %p, 0;
DOT:
    setp.ge.u32 %loop_p, %p, %K_reg;
    @%loop_p bra DOT_DONE;

    mad.lo.u32 %idx, %row, %K_reg, %p;
    cvt.u64.u32 %a_off, %idx;
    shl.b64 %a_off, %a_off, 2;
    add.u64 %a_off, %a, %a_off;
    ld.global.f32 %va, [%a_off];

    mad.lo.u32 %idx, %p, %N_reg, %col;
    cvt.u64.u32 %b_off, %idx;
    shl.b64 %b_off, %b_off, 2;
    add.u64 %b_off, %b, %b_off;
    ld.global.f32 %vb, [%b_off];

    fma.rn.f32 %sum, %va, %vb, %sum;
    add.u32 %p, %p, 1;
    bra DOT;
DOT_DONE:

    cvt.u64.u32 %c_off, %r_tid;
    shl.b64 %c_off, %c_off, 2;
    add.u64 %c_off, %c, %c_off;
    st.global.f32 [%c_off], %sum;

DONE:
    ret;
}
";
846
#[cfg(feature = "cuda")]
/// PTX: scatter a packed [B, D] source into slot `pos` of a [B, max_len, D]
/// destination (KV-cache-style write; n = B * D).
///
/// dst index = (batch * max_len + pos) * D + d, with batch = tid / D and
/// d = tid % D.
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u32 pos
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u32 %pos_reg, [pos];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
906
907
#[cfg(feature = "cuda")]
/// PTX: same scatter as `SLICE_WRITE_PTX`, but the write position is read
/// from device memory (`pos_ptr`) instead of being a kernel parameter —
/// lets the position live on-GPU without a host round-trip.
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_write_indirect_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 n,
    .param .u32 D,
    .param .u32 max_len,
    .param .u64 pos_ptr
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
    .reg .u32 %batch_idx, %d_idx, %dst_row;
    .reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %max_len_reg, [max_len];
    ld.param.u64 %pos_p, [pos_ptr];
    ld.global.u32 %pos_reg, [%pos_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %src_off, %r_tid;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src, %src, %src_off;
    ld.global.f32 %val, [%src];

    div.u32 %batch_idx, %r_tid, %D_reg;
    rem.u32 %d_idx, %r_tid, %D_reg;
    mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
    add.u32 %dst_row, %dst_row, %pos_reg;
    mul.lo.u32 %dst_row, %dst_row, %D_reg;
    add.u32 %dst_row, %dst_row, %d_idx;
    cvt.u64.u32 %dst_off, %dst_row;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst, %dst, %dst_off;
    st.global.f32 [%dst], %val;

DONE:
    ret;
}
";
968
#[cfg(feature = "cuda")]
/// PTX: fill an attention mask where col = tid % max_pos. Columns below the
/// device-resident `total_len` get 0.0; the rest get -1.0e9 so they vanish
/// after softmax. The length is read from `total_len_ptr` on-GPU.
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry causal_mask_indirect_kernel(
    .param .u64 total_len_ptr,
    .param .u64 out_ptr,
    .param .u32 max_pos,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
    .reg .u64 %out, %off, %tl_p;
    .reg .f32 %val;
    .reg .pred %bounds_p, %mask_p;

    ld.param.u64 %tl_p, [total_len_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %max_pos_reg, [max_pos];
    ld.param.u32 %total_reg, [total];

    ld.global.u32 %tlen, [%tl_p];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %bounds_p, %r_tid, %total_reg;
    @%bounds_p bra DONE;

    rem.u32 %col, %r_tid, %max_pos_reg;
    setp.lt.u32 %mask_p, %col, %tlen;
    @%mask_p bra WRITE_ZERO;

    // 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
    // to effectively zero out masked positions after softmax.
    mov.f32 %val, 0fCE6E6B28;
    bra WRITE;

WRITE_ZERO:
    mov.f32 %val, 0f00000000;

WRITE:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %val;

DONE:
    ret;
}
";
1028
#[cfg(feature = "cuda")]
/// PTX: single-token embedding lookup, `out[d] = weight[row * D + d]`.
///
/// The row index is stored as an f32 at `idx_ptr` and truncated to u32 via
/// `cvt.rzi`. One thread per feature dimension (grid bounded by D).
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D
) {
    .reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %D_reg;
    @%p bra DONE;

    ld.global.f32 %idx_f, [%idx_addr];
    cvt.rzi.u32.f32 %row, %idx_f;

    mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1081
1082
#[cfg(feature = "cuda")]
/// PTX: batched embedding lookup, `out[r * D + c] = weight[indices[r] * D + c]`.
///
/// Indices are stored as f32 and truncated to u32. `total` = rows * D; each
/// thread copies one output scalar.
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry embed_lookup_batch_kernel(
    .param .u64 idx_ptr,
    .param .u64 weight_ptr,
    .param .u64 out_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %src_idx;
    .reg .u64 %idx_addr, %w, %out, %off;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %idx_addr, [idx_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mad.lo.u32 %tid, %bid, %bdim, %tid;

    setp.ge.u32 %p, %tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %tid, %D_reg;
    rem.u32 %col, %tid, %D_reg;

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %src_idx, %idx_f;

    // src_idx = indices[row] * D + col
    mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %val, [%off];

    // Write to out[tid]
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
1151
1152
#[cfg(feature = "cuda")]
/// PTX: embedding-backward scatter-add,
/// `grad_weight[indices[r] * D + c] += grad_output[r * D + c]`.
///
/// Uses `atom.global.add.f32` so duplicate indices accumulate correctly
/// across threads; the atomic's returned old value goes to a dummy register.
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_rows_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_weight_ptr,
    .param .u32 D,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %D_reg, %total_reg;
    .reg .u32 %row, %col, %dst_idx;
    .reg .u64 %go, %idx_addr, %gw, %off;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %idx_addr, [indices_ptr];
    ld.param.u64 %gw, [grad_weight_ptr];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mad.lo.u32 %tid, %bid, %bdim, %tid;

    setp.ge.u32 %p, %tid, %total_reg;
    @%p bra DONE;

    // row = tid / D, col = tid % D
    div.u32 %row, %tid, %D_reg;
    rem.u32 %col, %tid, %D_reg;

    // Read grad_output[tid]
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %go, %off;
    ld.global.f32 %grad_val, [%off];

    // Read indices[row] (f32 -> u32)
    cvt.u64.u32 %off, %row;
    shl.b64 %off, %off, 2;
    add.u64 %off, %idx_addr, %off;
    ld.global.f32 %idx_f, [%off];
    cvt.rzi.u32.f32 %dst_idx, %idx_f;

    // dst_idx = indices[row] * D + col
    mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %gw, %off;
    atom.global.add.f32 %dummy, [%off], %grad_val;

DONE:
    ret;
}
";
1221
1222
#[cfg(feature = "cuda")]
/// PTX: gather the first `len` rows per batch from a [B, max_len, D] buffer
/// into a packed [B, len, D] destination (inverse of the slice-write kernels).
///
/// src index = (batch * max_len + row) * D + col where the thread id is
/// decomposed against len * D; `total` = B * len * D.
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry slice_read_kernel(
    .param .u64 src_ptr,
    .param .u64 dst_ptr,
    .param .u32 total,
    .param .u32 D,
    .param .u32 len,
    .param .u32 max_len
) {
    .reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
    .reg .u32 %batch_idx, %within, %row, %col, %src_idx;
    .reg .u32 %len_d;
    .reg .u64 %src, %dst, %src_off, %dst_off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %src, [src_ptr];
    ld.param.u64 %dst, [dst_ptr];
    ld.param.u32 %total_reg, [total];
    ld.param.u32 %D_reg, [D];
    ld.param.u32 %len_reg, [len];
    ld.param.u32 %max_len_reg, [max_len];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %total_reg;
    @%p bra DONE;

    // dst index = r_tid
    // batch_idx = r_tid / (len * D)
    // within = r_tid % (len * D)
    // row = within / D
    // col = within % D
    // src_idx = batch_idx * max_len * D + row * D + col
    mul.lo.u32 %len_d, %len_reg, %D_reg;
    div.u32 %batch_idx, %r_tid, %len_d;
    rem.u32 %within, %r_tid, %len_d;
    div.u32 %row, %within, %D_reg;
    rem.u32 %col, %within, %D_reg;

    mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
    add.u32 %src_idx, %src_idx, %row;
    mul.lo.u32 %src_idx, %src_idx, %D_reg;
    add.u32 %src_idx, %src_idx, %col;

    cvt.u64.u32 %src_off, %src_idx;
    shl.b64 %src_off, %src_off, 2;
    add.u64 %src_off, %src, %src_off;
    ld.global.f32 %val, [%src_off];

    cvt.u64.u32 %dst_off, %r_tid;
    shl.b64 %dst_off, %dst_off, 2;
    add.u64 %dst_off, %dst, %dst_off;
    st.global.f32 [%dst_off], %val;

DONE:
    ret;
}
";
1296
1297
#[cfg(feature = "cuda")]
/// PTX: sigmoid-approximated GELU, `out = x * sigmoid(k * x)`, computing
/// exp(-k*x) as ex2.approx(-k*x * log2(e)) and the reciprocal via rcp.approx.
///
/// NOTE(review): the constant 0f3FDA2720 decodes to ~1.7043, not the
/// canonical 1.702 of the sigmoid GELU approximation; it matches the k used
/// by `GELU_F64_PTX`, so it is presumably a deliberate fit — confirm.
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    mov.f32 %k, 0f3FDA2720;
    mul.f32 %neg_kx, %k, %x;
    neg.f32 %neg_kx, %neg_kx;
    mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;
    mul.f32 %result, %x, %sig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1357
#[cfg(feature = "cuda")]
/// PTX: f64 sigmoid-approximated GELU, `out = x / (1 + exp(-k*x)) * x`-style
/// (i.e. x * sigmoid(k*x)), with exp computed in-line via Cody-Waite argument
/// reduction (split ln(2)) and a degree-11 Horner polynomial, then scaled by
/// 2^n through direct exponent-bit construction.
///
/// Fix: the 1/5! Horner coefficient was written with 17 hex digits
/// (`0d3F811111111111111`) — a PTX `0d` literal takes exactly 16 — so ptxas
/// would reject the module. Corrected to `0d3F81111111111111` (1/120),
/// consistent with the neighboring 1/9!..1/6! coefficients.
pub(crate) const GELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_f64_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f64 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%in];
    mov.f64 %one, 0d3FF0000000000000;

    // k = 1.702
    mov.f64 %k, 0d3FFB44E400000000;
    mul.f64 %neg_kx, %k, %x;
    neg.f64 %neg_kx, %neg_kx;

    // --- exp(%neg_kx) via Cody-Waite + degree-11 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg, %exp_neg, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %exp_neg;
    div.rn.f64 %sig, %one, %denom;
    mul.f64 %result, %x, %sig;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
1439
/// PTX for `gelu_tanh_kernel`: element-wise tanh-approximated GELU in f32,
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// One thread per element (1-D grid); threads with tid >= n branch straight
/// to DONE. tanh is evaluated as `(exp(2y)-1)/(exp(2y)+1)` using the
/// reduced-precision `ex2.approx` / `rcp.approx` instructions.
///
/// NOTE(review): for large |x|, `ex2.approx` overflows f32 to +inf and the
/// `(inf-1) * rcp(inf+1)` path can produce NaN instead of saturating tanh
/// at +/-1 — confirm the expected input range.
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_tanh_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
    .reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];

    // inner = sqrt(2/π) * (x + 0.044715 * x³)
    // sqrt(2/π) = 0.7978845608 = 0x3F4C422A
    // 0.044715 = 0x3D372713
    mul.f32 %x3, %x, %x;
    mul.f32 %x3, %x3, %x;
    mov.f32 %c, 0f3D372713;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
    // exp(2y) = 2^(2y * log2(e))
    mov.f32 %log2e, 0f3FB8AA3B;
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    mov.f32 %one, 0f3F800000;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // out = 0.5 * x * (1 + tanh)
    add.f32 %th, %one, %th;
    mov.f32 %half, 0f3F000000;
    mul.f32 %result, %half, %x;
    mul.f32 %result, %result, %th;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1515
1516#[cfg(feature = "cuda")]
1519pub(crate) const GELU_TANH_F64_PTX: &str = "\
1520.version 7.0
1521.target sm_52
1522.address_size 64
1523
1524.visible .entry gelu_tanh_f64_kernel(
1525 .param .u64 in_ptr,
1526 .param .u64 out_ptr,
1527 .param .u32 n
1528) {
1529 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1530 .reg .u64 %in, %out, %off;
1531 .reg .f64 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
1532 .reg .f64 %e2y_m1, %e2y_p1, %th, %one, %half, %result;
1533 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
1534 .reg .s32 %e_ni;
1535 .reg .s64 %e_ni64, %e_bits;
1536 .reg .pred %p;
1537
1538 ld.param.u64 %in, [in_ptr];
1539 ld.param.u64 %out, [out_ptr];
1540 ld.param.u32 %n_reg, [n];
1541
1542 mov.u32 %bid, %ctaid.x;
1543 mov.u32 %bdim, %ntid.x;
1544 mov.u32 %r_tid, %tid.x;
1545 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1546
1547 setp.ge.u32 %p, %r_tid, %n_reg;
1548 @%p bra DONE;
1549
1550 cvt.u64.u32 %off, %r_tid;
1551 shl.b64 %off, %off, 3;
1552 add.u64 %in, %in, %off;
1553 add.u64 %out, %out, %off;
1554
1555 ld.global.f64 %x, [%in];
1556 mov.f64 %one, 0d3FF0000000000000;
1557
1558 // inner = sqrt(2/pi) * (x + 0.044715 * x^3)
1559 mul.f64 %x3, %x, %x;
1560 mul.f64 %x3, %x3, %x;
1561 mov.f64 %c, 0d3FA6E4E260000000;
1562 mul.f64 %x3, %c, %x3;
1563 add.f64 %inner, %x, %x3;
1564 mov.f64 %sqrt2pi, 0d3FE9884540000000;
1565 mul.f64 %y, %sqrt2pi, %inner;
1566
1567 // tanh(y) = (exp(2y)-1)/(exp(2y)+1), exp(2y) in full f64
1568 add.f64 %two_y, %y, %y;
1569
1570 // --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
1571 mov.f64 %e_half, 0d3FE0000000000000;
1572 fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
1573 cvt.rmi.f64.f64 %e_nf, %e_nf;
1574 cvt.rni.s32.f64 %e_ni, %e_nf;
1575 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
1576 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
1577 mov.f64 %e_p, 0d3E21EED8EFF8D898;
1578 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
1579 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
1580 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
1581 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
1582 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
1583 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
1584 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;
1585 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
1586 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
1587 fma.rn.f64 %e_p, %e_p, %e_r, %one;
1588 fma.rn.f64 %e2y, %e_p, %e_r, %one;
1589 cvt.s64.s32 %e_ni64, %e_ni;
1590 add.s64 %e_ni64, %e_ni64, 1023;
1591 shl.b64 %e_bits, %e_ni64, 52;
1592 mov.b64 %e_nf, %e_bits;
1593 mul.f64 %e2y, %e2y, %e_nf;
1594 // --- end exp ---
1595
1596 sub.f64 %e2y_m1, %e2y, %one;
1597 add.f64 %e2y_p1, %e2y, %one;
1598 div.rn.f64 %th, %e2y_m1, %e2y_p1;
1599
1600 // out = 0.5 * x * (1 + tanh)
1601 add.f64 %th, %one, %th;
1602 mov.f64 %half, 0d3FE0000000000000;
1603 mul.f64 %result, %half, %x;
1604 mul.f64 %result, %result, %th;
1605 st.global.f64 [%out], %result;
1606
1607DONE:
1608 ret;
1609}
1610";
1611
1612#[cfg(feature = "cuda")]
1617pub(crate) const GELU_ERF_PTX: &str = "\
1618.version 7.0
1619.target sm_52
1620.address_size 64
1621
1622.visible .entry gelu_erf_kernel(
1623 .param .u64 in_ptr,
1624 .param .u64 out_ptr,
1625 .param .u32 n
1626) {
1627 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1628 .reg .u64 %in, %out, %off;
1629 .reg .f32 %x, %z, %ax, %one, %half, %log2e;
1630 .reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
1631 .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
1632 .reg .pred %pred_ge, %pred_neg;
1633
1634 ld.param.u64 %in, [in_ptr];
1635 ld.param.u64 %out, [out_ptr];
1636 ld.param.u32 %n_reg, [n];
1637
1638 mov.u32 %bid, %ctaid.x;
1639 mov.u32 %bdim, %ntid.x;
1640 mov.u32 %r_tid, %tid.x;
1641 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1642
1643 setp.ge.u32 %pred_ge, %r_tid, %n_reg;
1644 @%pred_ge bra DONE;
1645
1646 cvt.u64.u32 %off, %r_tid;
1647 shl.b64 %off, %off, 2;
1648 add.u64 %in, %in, %off;
1649 add.u64 %out, %out, %off;
1650
1651 ld.global.f32 %x, [%in];
1652 mov.f32 %one, 0f3F800000;
1653 mov.f32 %half, 0f3F000000;
1654 mov.f32 %log2e, 0f3FB8AA3B;
1655
1656 // z = x / sqrt(2) = x * 0.70710678
1657 mov.f32 %z, 0f3F3504F3;
1658 mul.f32 %z, %x, %z;
1659
1660 // |z| for erf(|z|)
1661 abs.f32 %ax, %z;
1662
1663 // t = 1 / (1 + 0.3275911 * |z|)
1664 mov.f32 %p, 0f3EA7BA05;
1665 mul.f32 %t, %p, %ax;
1666 add.f32 %t, %one, %t;
1667 rcp.approx.f32 %t, %t;
1668
1669 // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
1670 mov.f32 %a5, 0f3E0AAAAB;
1671 mov.f32 %a4, 0fBEB3A903;
1672 mov.f32 %a3, 0f3FB506DD;
1673 mov.f32 %a2, 0fBF03C1E1;
1674 mov.f32 %a1, 0f3EA0D6BB;
1675
1676 mul.f32 %pt, %t, %a5;
1677 add.f32 %pt, %pt, %a4;
1678 mul.f32 %pt, %pt, %t;
1679 add.f32 %pt, %pt, %a3;
1680 mul.f32 %pt, %pt, %t;
1681 add.f32 %pt, %pt, %a2;
1682 mul.f32 %pt, %pt, %t;
1683 add.f32 %pt, %pt, %a1;
1684 mul.f32 %pt, %pt, %t;
1685
1686 // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
1687 mul.f32 %z2, %ax, %ax;
1688 neg.f32 %neg_z2, %z2;
1689 mul.f32 %neg_z2, %neg_z2, %log2e;
1690 ex2.approx.f32 %exp_neg_z2, %neg_z2;
1691
1692 // erf(|z|) = 1 - poly * exp(-z^2)
1693 mul.f32 %erf_val, %pt, %exp_neg_z2;
1694 sub.f32 %erf_val, %one, %erf_val;
1695
1696 // erf(-z) = -erf(z), so sign-correct
1697 setp.lt.f32 %pred_neg, %z, 0f00000000;
1698 @%pred_neg neg.f32 %erf_val, %erf_val;
1699
1700 // out = x * 0.5 * (1 + erf(x/sqrt(2)))
1701 add.f32 %erf_val, %one, %erf_val;
1702 mul.f32 %result, %half, %x;
1703 mul.f32 %result, %result, %erf_val;
1704 st.global.f32 [%out], %result;
1705
1706DONE:
1707 ret;
1708}
1709";
1710
1711#[cfg(feature = "cuda")]
1714pub(crate) const GELU_ERF_F64_PTX: &str = "\
1715.version 7.0
1716.target sm_52
1717.address_size 64
1718
1719.visible .entry gelu_erf_f64_kernel(
1720 .param .u64 in_ptr,
1721 .param .u64 out_ptr,
1722 .param .u32 n
1723) {
1724 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1725 .reg .u64 %in, %out, %off;
1726 .reg .f64 %x, %z, %ax, %one, %half;
1727 .reg .f64 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
1728 .reg .f64 %p, %a1, %a2, %a3, %a4, %a5, %result;
1729 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
1730 .reg .s32 %e_ni;
1731 .reg .s64 %e_ni64, %e_bits;
1732 .reg .pred %pred_ge, %pred_neg;
1733
1734 ld.param.u64 %in, [in_ptr];
1735 ld.param.u64 %out, [out_ptr];
1736 ld.param.u32 %n_reg, [n];
1737
1738 mov.u32 %bid, %ctaid.x;
1739 mov.u32 %bdim, %ntid.x;
1740 mov.u32 %r_tid, %tid.x;
1741 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1742
1743 setp.ge.u32 %pred_ge, %r_tid, %n_reg;
1744 @%pred_ge bra DONE;
1745
1746 cvt.u64.u32 %off, %r_tid;
1747 shl.b64 %off, %off, 3;
1748 add.u64 %in, %in, %off;
1749 add.u64 %out, %out, %off;
1750
1751 ld.global.f64 %x, [%in];
1752 mov.f64 %one, 0d3FF0000000000000;
1753 mov.f64 %half, 0d3FE0000000000000;
1754
1755 // z = x / sqrt(2) = x * 0.70710678
1756 mov.f64 %z, 0d3FE6A09E60000000;
1757 mul.f64 %z, %x, %z;
1758
1759 abs.f64 %ax, %z;
1760
1761 // t = 1 / (1 + 0.3275911 * |z|)
1762 mov.f64 %p, 0d3FD4F740A0000000;
1763 mul.f64 %t, %p, %ax;
1764 add.f64 %t, %one, %t;
1765 div.rn.f64 %t, %one, %t;
1766
1767 // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
1768 mov.f64 %a5, 0d3FC1555560000000;
1769 mov.f64 %a4, 0dBFD6752060000000;
1770 mov.f64 %a3, 0d3FF6A0DBA0000000;
1771 mov.f64 %a2, 0dBFE0783C20000000;
1772 mov.f64 %a1, 0d3FD41AD760000000;
1773
1774 mul.f64 %pt, %t, %a5;
1775 add.f64 %pt, %pt, %a4;
1776 mul.f64 %pt, %pt, %t;
1777 add.f64 %pt, %pt, %a3;
1778 mul.f64 %pt, %pt, %t;
1779 add.f64 %pt, %pt, %a2;
1780 mul.f64 %pt, %pt, %t;
1781 add.f64 %pt, %pt, %a1;
1782 mul.f64 %pt, %pt, %t;
1783
1784 // exp(-z^2) in full f64
1785 mul.f64 %z2, %ax, %ax;
1786 neg.f64 %neg_z2, %z2;
1787
1788 // --- exp(%neg_z2) via Cody-Waite + degree-11 Horner ---
1789 mov.f64 %e_half, 0d3FE0000000000000;
1790 fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
1791 cvt.rmi.f64.f64 %e_nf, %e_nf;
1792 cvt.rni.s32.f64 %e_ni, %e_nf;
1793 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
1794 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
1795 mov.f64 %e_p, 0d3E21EED8EFF8D898;
1796 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
1797 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
1798 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
1799 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
1800 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
1801 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
1802 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;
1803 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
1804 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
1805 fma.rn.f64 %e_p, %e_p, %e_r, %one;
1806 fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
1807 cvt.s64.s32 %e_ni64, %e_ni;
1808 add.s64 %e_ni64, %e_ni64, 1023;
1809 shl.b64 %e_bits, %e_ni64, 52;
1810 mov.b64 %e_nf, %e_bits;
1811 mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
1812 // --- end exp ---
1813
1814 mul.f64 %erf_val, %pt, %exp_neg_z2;
1815 sub.f64 %erf_val, %one, %erf_val;
1816
1817 setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
1818 @%pred_neg neg.f64 %erf_val, %erf_val;
1819
1820 add.f64 %erf_val, %one, %erf_val;
1821 mul.f64 %result, %half, %x;
1822 mul.f64 %result, %result, %erf_val;
1823 st.global.f64 [%out], %result;
1824
1825DONE:
1826 ret;
1827}
1828";
1829
/// PTX for `gelu_backward_tanh_kernel`: f32 gradient of the tanh-approximated
/// GELU. With `u = sqrt(2/pi) * (x + 0.044715 x^3)` it computes
/// `out = grad * (0.5*(1 + tanh(u)) + 0.5*x*(1 - tanh(u)^2) * du/dx)` where
/// `du/dx = sqrt(2/pi) * (1 + 3*0.044715*x^2)`.
///
/// One thread per element (1-D grid); tanh goes through the
/// reduced-precision `ex2.approx` / `rcp.approx` path, like the forward
/// kernel.
/// NOTE(review): the same large-|x| overflow caveat as the forward f32
/// kernel applies (exp(2y) -> inf can yield NaN) — confirm input range.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_tanh_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
    .reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
    .reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;
    mov.f32 %log2e, 0f3FB8AA3B;
    mov.f32 %sqrt2pi, 0f3F4C422A;
    mov.f32 %c, 0f3D372713;
    // 3 * 0.044715 = 0.134145 = 0x3E096B8C
    mov.f32 %c3, 0f3E096B8C;

    // u = sqrt(2/π) * (x + 0.044715 * x³)
    mul.f32 %x2, %x, %x;
    mul.f32 %x3, %x2, %x;
    mul.f32 %x3, %c, %x3;
    add.f32 %inner, %x, %x3;
    mul.f32 %y, %sqrt2pi, %inner;

    // tanh(y) via exp
    add.f32 %two_y, %y, %y;
    mul.f32 %two_y, %two_y, %log2e;
    ex2.approx.f32 %e2y, %two_y;
    sub.f32 %e2y_m1, %e2y, %one;
    add.f32 %e2y_p1, %e2y, %one;
    rcp.approx.f32 %e2y_p1, %e2y_p1;
    mul.f32 %th, %e2y_m1, %e2y_p1;

    // d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh²)*sqrt(2/π)*(1+3*0.044715*x²)
    // term1 = 0.5 * (1 + th)
    add.f32 %term1, %one, %th;
    mul.f32 %term1, %half, %term1;

    // (1 - th²)
    mul.f32 %th2, %th, %th;
    sub.f32 %one_m_th2, %one, %th2;

    // d_inner = sqrt(2/π) * (1 + 3*0.044715*x²)
    mul.f32 %d_inner, %c3, %x2;
    add.f32 %d_inner, %one, %d_inner;
    mul.f32 %d_inner, %sqrt2pi, %d_inner;

    // term2 = 0.5 * x * (1-th²) * d_inner
    mul.f32 %term2, %half, %x;
    mul.f32 %term2, %term2, %one_m_th2;
    mul.f32 %term2, %term2, %d_inner;

    add.f32 %d_gelu, %term1, %term2;
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
1927
1928#[cfg(feature = "cuda")]
1931pub(crate) const GELU_BACKWARD_TANH_F64_PTX: &str = "\
1932.version 7.0
1933.target sm_52
1934.address_size 64
1935
1936.visible .entry gelu_backward_tanh_f64_kernel(
1937 .param .u64 grad_ptr,
1938 .param .u64 input_ptr,
1939 .param .u64 out_ptr,
1940 .param .u32 n
1941) {
1942 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
1943 .reg .u64 %grad, %input, %out, %off;
1944 .reg .f64 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
1945 .reg .f64 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half;
1946 .reg .f64 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
1947 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
1948 .reg .s32 %e_ni;
1949 .reg .s64 %e_ni64, %e_bits;
1950 .reg .pred %p;
1951
1952 ld.param.u64 %grad, [grad_ptr];
1953 ld.param.u64 %input, [input_ptr];
1954 ld.param.u64 %out, [out_ptr];
1955 ld.param.u32 %n_reg, [n];
1956
1957 mov.u32 %bid, %ctaid.x;
1958 mov.u32 %bdim, %ntid.x;
1959 mov.u32 %r_tid, %tid.x;
1960 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
1961
1962 setp.ge.u32 %p, %r_tid, %n_reg;
1963 @%p bra DONE;
1964
1965 cvt.u64.u32 %off, %r_tid;
1966 shl.b64 %off, %off, 3;
1967 add.u64 %grad, %grad, %off;
1968 add.u64 %input, %input, %off;
1969 add.u64 %out, %out, %off;
1970
1971 ld.global.f64 %vg, [%grad];
1972 ld.global.f64 %x, [%input];
1973
1974 mov.f64 %one, 0d3FF0000000000000;
1975 mov.f64 %half, 0d3FE0000000000000;
1976 mov.f64 %sqrt2pi, 0d3FE9884540000000;
1977 mov.f64 %c, 0d3FA6E4E260000000;
1978 // 3 * 0.044715 = 0.134145
1979 mov.f64 %c3, 0d3FC12D7180000000;
1980
1981 mul.f64 %x2, %x, %x;
1982 mul.f64 %x3, %x2, %x;
1983 mul.f64 %x3, %c, %x3;
1984 add.f64 %inner, %x, %x3;
1985 mul.f64 %y, %sqrt2pi, %inner;
1986
1987 // tanh(y) = (exp(2y)-1)/(exp(2y)+1) in full f64
1988 add.f64 %two_y, %y, %y;
1989
1990 // --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
1991 mov.f64 %e_half, 0d3FE0000000000000;
1992 fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
1993 cvt.rmi.f64.f64 %e_nf, %e_nf;
1994 cvt.rni.s32.f64 %e_ni, %e_nf;
1995 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
1996 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
1997 mov.f64 %e_p, 0d3E21EED8EFF8D898;
1998 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
1999 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
2000 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
2001 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
2002 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
2003 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
2004 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;
2005 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
2006 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
2007 fma.rn.f64 %e_p, %e_p, %e_r, %one;
2008 fma.rn.f64 %e2y, %e_p, %e_r, %one;
2009 cvt.s64.s32 %e_ni64, %e_ni;
2010 add.s64 %e_ni64, %e_ni64, 1023;
2011 shl.b64 %e_bits, %e_ni64, 52;
2012 mov.b64 %e_nf, %e_bits;
2013 mul.f64 %e2y, %e2y, %e_nf;
2014 // --- end exp ---
2015
2016 sub.f64 %e2y_m1, %e2y, %one;
2017 add.f64 %e2y_p1, %e2y, %one;
2018 div.rn.f64 %th, %e2y_m1, %e2y_p1;
2019
2020 add.f64 %term1, %one, %th;
2021 mul.f64 %term1, %half, %term1;
2022
2023 mul.f64 %th2, %th, %th;
2024 sub.f64 %one_m_th2, %one, %th2;
2025
2026 mul.f64 %d_inner, %c3, %x2;
2027 add.f64 %d_inner, %one, %d_inner;
2028 mul.f64 %d_inner, %sqrt2pi, %d_inner;
2029
2030 mul.f64 %term2, %half, %x;
2031 mul.f64 %term2, %term2, %one_m_th2;
2032 mul.f64 %term2, %term2, %d_inner;
2033
2034 add.f64 %d_gelu, %term1, %term2;
2035 mul.f64 %result, %vg, %d_gelu;
2036 st.global.f64 [%out], %result;
2037
2038DONE:
2039 ret;
2040}
2041";
2042
/// PTX for `silu_kernel`: element-wise SiLU (swish) in f32,
/// `silu(x) = x * sigmoid(x)` with `sigmoid(x) = 1/(1 + exp(-x))`.
///
/// One thread per element (1-D grid); threads with tid >= n branch to DONE.
/// `exp(-x)` is computed as `2^(-x * log2(e))` with `ex2.approx`, and the
/// division with `rcp.approx`, so results carry approx-instruction
/// precision.
#[cfg(feature = "cuda")]
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    // exp(-x) = 2^(-x * log2(e))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;
    // silu(x) = x * sigmoid(x)
    mul.f32 %vr, %x, %sig;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2101
2102#[cfg(feature = "cuda")]
2105pub(crate) const SILU_F64_PTX: &str = "\
2106.version 7.0
2107.target sm_52
2108.address_size 64
2109
2110.visible .entry silu_f64_kernel(
2111 .param .u64 a_ptr,
2112 .param .u64 out_ptr,
2113 .param .u32 n
2114) {
2115 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
2116 .reg .u64 %a, %out, %off;
2117 .reg .f64 %x, %neg_x, %e, %denom, %sig, %vr, %one;
2118 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
2119 .reg .s32 %e_ni;
2120 .reg .s64 %e_ni64, %e_bits;
2121 .reg .pred %p;
2122
2123 ld.param.u64 %a, [a_ptr];
2124 ld.param.u64 %out, [out_ptr];
2125 ld.param.u32 %n_reg, [n];
2126
2127 mov.u32 %bid, %ctaid.x;
2128 mov.u32 %bdim, %ntid.x;
2129 mov.u32 %r_tid, %tid.x;
2130 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2131
2132 setp.ge.u32 %p, %r_tid, %n_reg;
2133 @%p bra DONE;
2134
2135 cvt.u64.u32 %off, %r_tid;
2136 shl.b64 %off, %off, 3;
2137 add.u64 %a, %a, %off;
2138 add.u64 %out, %out, %off;
2139
2140 ld.global.f64 %x, [%a];
2141 mov.f64 %one, 0d3FF0000000000000;
2142 neg.f64 %neg_x, %x;
2143
2144 // --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
2145 mov.f64 %e_half, 0d3FE0000000000000;
2146 fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
2147 cvt.rmi.f64.f64 %e_nf, %e_nf;
2148 cvt.rni.s32.f64 %e_ni, %e_nf;
2149 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
2150 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
2151 mov.f64 %e_p, 0d3E21EED8EFF8D898;
2152 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
2153 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
2154 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
2155 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
2156 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
2157 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
2158 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;
2159 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
2160 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
2161 fma.rn.f64 %e_p, %e_p, %e_r, %one;
2162 fma.rn.f64 %e, %e_p, %e_r, %one;
2163 cvt.s64.s32 %e_ni64, %e_ni;
2164 add.s64 %e_ni64, %e_ni64, 1023;
2165 shl.b64 %e_bits, %e_ni64, 52;
2166 mov.b64 %e_nf, %e_bits;
2167 mul.f64 %e, %e, %e_nf;
2168 // --- end exp ---
2169
2170 add.f64 %denom, %one, %e;
2171 div.rn.f64 %sig, %one, %denom;
2172 mul.f64 %vr, %x, %sig;
2173 st.global.f64 [%out], %vr;
2174
2175DONE:
2176 ret;
2177}
2178";
2179
/// PTX for `silu_backward_kernel`: f32 gradient of SiLU,
/// `out = grad * (sig + x * sig * (1 - sig))` with `sig = sigmoid(x)`
/// recomputed from the saved input.
///
/// One thread per element (1-D grid); sigmoid uses the reduced-precision
/// `ex2.approx` / `rcp.approx` path, matching the forward kernel.
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry silu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
    .reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(x) = 1 / (1 + exp(-x))
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    add.f32 %denom, %one, %e;
    rcp.approx.f32 %sig, %denom;

    // deriv = sig + x * sig * (1 - sig)
    sub.f32 %one_m_sig, %one, %sig;
    mul.f32 %x_sig_omsig, %x, %sig;
    mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
    add.f32 %deriv, %sig, %x_sig_omsig;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
2243
2244#[cfg(feature = "cuda")]
2247pub(crate) const SILU_BACKWARD_F64_PTX: &str = "\
2248.version 7.0
2249.target sm_52
2250.address_size 64
2251
2252.visible .entry silu_backward_f64_kernel(
2253 .param .u64 grad_ptr,
2254 .param .u64 input_ptr,
2255 .param .u64 out_ptr,
2256 .param .u32 n
2257) {
2258 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
2259 .reg .u64 %grad, %input, %out, %off;
2260 .reg .f64 %vg, %x, %neg_x, %e, %denom, %sig, %one;
2261 .reg .f64 %one_m_sig, %x_sig_omsig, %deriv, %result;
2262 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
2263 .reg .s32 %e_ni;
2264 .reg .s64 %e_ni64, %e_bits;
2265 .reg .pred %p;
2266
2267 ld.param.u64 %grad, [grad_ptr];
2268 ld.param.u64 %input, [input_ptr];
2269 ld.param.u64 %out, [out_ptr];
2270 ld.param.u32 %n_reg, [n];
2271
2272 mov.u32 %bid, %ctaid.x;
2273 mov.u32 %bdim, %ntid.x;
2274 mov.u32 %r_tid, %tid.x;
2275 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2276
2277 setp.ge.u32 %p, %r_tid, %n_reg;
2278 @%p bra DONE;
2279
2280 cvt.u64.u32 %off, %r_tid;
2281 shl.b64 %off, %off, 3;
2282 add.u64 %grad, %grad, %off;
2283 add.u64 %input, %input, %off;
2284 add.u64 %out, %out, %off;
2285
2286 ld.global.f64 %vg, [%grad];
2287 ld.global.f64 %x, [%input];
2288
2289 mov.f64 %one, 0d3FF0000000000000;
2290 neg.f64 %neg_x, %x;
2291
2292 // --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
2293 mov.f64 %e_half, 0d3FE0000000000000;
2294 fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
2295 cvt.rmi.f64.f64 %e_nf, %e_nf;
2296 cvt.rni.s32.f64 %e_ni, %e_nf;
2297 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
2298 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
2299 mov.f64 %e_p, 0d3E21EED8EFF8D898;
2300 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
2301 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
2302 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
2303 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
2304 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
2305 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
2306 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;
2307 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
2308 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
2309 fma.rn.f64 %e_p, %e_p, %e_r, %one;
2310 fma.rn.f64 %e, %e_p, %e_r, %one;
2311 cvt.s64.s32 %e_ni64, %e_ni;
2312 add.s64 %e_ni64, %e_ni64, 1023;
2313 shl.b64 %e_bits, %e_ni64, 52;
2314 mov.b64 %e_nf, %e_bits;
2315 mul.f64 %e, %e, %e_nf;
2316 // --- end exp ---
2317
2318 add.f64 %denom, %one, %e;
2319 div.rn.f64 %sig, %one, %denom;
2320
2321 sub.f64 %one_m_sig, %one, %sig;
2322 mul.f64 %x_sig_omsig, %x, %sig;
2323 mul.f64 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
2324 add.f64 %deriv, %sig, %x_sig_omsig;
2325 mul.f64 %result, %vg, %deriv;
2326 st.global.f64 [%out], %result;
2327
2328DONE:
2329 ret;
2330}
2331";
2332
/// PTX for `elu_kernel`: element-wise ELU in f32,
/// `elu(x) = x if x > 0 else alpha * (exp(x) - 1)`, with `alpha` passed as a
/// kernel parameter.
///
/// One thread per element (1-D grid); `exp(x)` uses the reduced-precision
/// `ex2.approx` path. The branch is implemented without divergence via
/// `selp.f32` (both sides are always computed).
#[cfg(feature = "cuda")]
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    sub.f32 %em1, %ex, %one;
    mul.f32 %neg_branch, %alpha_r, %em1;

    // x > 0 ? x : alpha*(exp(x)-1)
    mov.f32 %vr, 0f00000000;
    setp.gt.f32 %pos, %x, %vr;
    selp.f32 %vr, %x, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2391
2392#[cfg(feature = "cuda")]
2395pub(crate) const ELU_F64_PTX: &str = "\
2396.version 7.0
2397.target sm_52
2398.address_size 64
2399
2400.visible .entry elu_f64_kernel(
2401 .param .u64 a_ptr,
2402 .param .u64 out_ptr,
2403 .param .u32 n,
2404 .param .f64 alpha
2405) {
2406 .reg .u32 %r_tid, %bid, %bdim, %n_reg;
2407 .reg .u64 %a, %out, %off;
2408 .reg .f64 %x, %alpha_r, %one, %ex, %em1, %neg_branch, %vr;
2409 .reg .f64 %e_nf, %e_r, %e_p, %e_half;
2410 .reg .s32 %e_ni;
2411 .reg .s64 %e_ni64, %e_bits;
2412 .reg .pred %p, %pos;
2413
2414 ld.param.u64 %a, [a_ptr];
2415 ld.param.u64 %out, [out_ptr];
2416 ld.param.u32 %n_reg, [n];
2417 ld.param.f64 %alpha_r, [alpha];
2418
2419 mov.u32 %bid, %ctaid.x;
2420 mov.u32 %bdim, %ntid.x;
2421 mov.u32 %r_tid, %tid.x;
2422 mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
2423
2424 setp.ge.u32 %p, %r_tid, %n_reg;
2425 @%p bra DONE;
2426
2427 cvt.u64.u32 %off, %r_tid;
2428 shl.b64 %off, %off, 3;
2429 add.u64 %a, %a, %off;
2430 add.u64 %out, %out, %off;
2431
2432 ld.global.f64 %x, [%a];
2433 mov.f64 %one, 0d3FF0000000000000;
2434
2435 // --- exp(%x) via Cody-Waite + degree-11 Horner ---
2436 mov.f64 %e_half, 0d3FE0000000000000;
2437 fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
2438 cvt.rmi.f64.f64 %e_nf, %e_nf;
2439 cvt.rni.s32.f64 %e_ni, %e_nf;
2440 fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
2441 fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
2442 mov.f64 %e_p, 0d3E21EED8EFF8D898;
2443 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
2444 fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
2445 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
2446 fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
2447 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
2448 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
2449 fma.rn.f64 %e_p, %e_p, %e_r, 0d3F811111111111111;
2450 fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
2451 fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
2452 fma.rn.f64 %e_p, %e_p, %e_r, %one;
2453 fma.rn.f64 %ex, %e_p, %e_r, %one;
2454 cvt.s64.s32 %e_ni64, %e_ni;
2455 add.s64 %e_ni64, %e_ni64, 1023;
2456 shl.b64 %e_bits, %e_ni64, 52;
2457 mov.b64 %e_nf, %e_bits;
2458 mul.f64 %ex, %ex, %e_nf;
2459 // --- end exp ---
2460
2461 sub.f64 %em1, %ex, %one;
2462 mul.f64 %neg_branch, %alpha_r, %em1;
2463
2464 mov.f64 %vr, 0d0000000000000000;
2465 setp.gt.f64 %pos, %x, %vr;
2466 selp.f64 %vr, %x, %neg_branch, %pos;
2467 st.global.f64 [%out], %vr;
2468
2469DONE:
2470 ret;
2471}
2472";
2473
/// PTX for `elu_backward_kernel`: f32 gradient of ELU,
/// `out = grad if x > 0 else grad * alpha * exp(x)`, recomputing `exp(x)`
/// from the saved input; `alpha` is passed as a kernel parameter.
///
/// One thread per element (1-D grid); `exp(x)` uses the reduced-precision
/// `ex2.approx` path and the branch is selected without divergence via
/// `selp.f32`.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %lg2e, 0f3FB8AA3B;
    mov.f32 %zero, 0f00000000;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    // negative branch: grad * alpha * exp(x)
    mul.f32 %neg_branch, %vg, %alpha_r;
    mul.f32 %neg_branch, %neg_branch, %ex;

    // x > 0 ? grad : grad * alpha * exp(x)
    setp.gt.f32 %pos, %x, %zero;
    selp.f32 %vr, %vg, %neg_branch, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2537
/// PTX (f64) for the ELU backward pass:
/// `out[i] = input[i] > 0 ? grad[i] : grad[i] * alpha * exp(input[i])`.
///
/// exp() is computed in-kernel: Cody–Waite range reduction (x = n*ln2 + r,
/// with the ln2 split across two constants), a Horner polynomial in r with
/// Taylor coefficients 1/12! .. 1/3!, 1/2, 1, 1, then a 2^n scale built by
/// writing n+1023 into the f64 exponent field.
///
/// Fixes vs. the previous revision:
/// * the 1/5! coefficient literal had 17 hex digits — PTX `0d` literals
///   require exactly 16, so ptxas would reject the module;
/// * the 1/3! term (0d3FC5555555555555) was missing from the Horner chain,
///   which shifted every coefficient from r^3 upward one factorial too far.
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry elu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f64 alpha
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %alpha_r, %ex, %neg_branch, %vr, %zero, %one;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f64 %alpha_r, [alpha];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %zero, 0d0000000000000000;
    mov.f64 %one, 0d3FF0000000000000;

    // --- exp(%x) via Cody-Waite + Horner polynomial ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;
    // --- end exp ---

    mul.f64 %neg_branch, %vg, %alpha_r;
    mul.f64 %neg_branch, %neg_branch, %ex;

    setp.gt.f64 %pos, %x, %zero;
    selp.f64 %vr, %vg, %neg_branch, %pos;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2624
/// PTX (f32) for the Mish activation: `out[i] = x * tanh(softplus(x))`
/// where `softplus(x) = ln(1 + exp(x))`.
///
/// For x > 20 the kernel branches to `LARGE_X` and uses softplus(x) ~ x,
/// computing x * tanh(x) instead, so exp(x) cannot overflow the
/// intermediate. Both exp and log use the approximate hardware
/// transcendentals (`ex2.approx`, `lg2.approx`, `rcp.approx`).
#[cfg(feature = "cuda")]
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%a];
    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0 = 0x41A00000
    mov.f32 %threshold, 0f41A00000;

    // softplus(x) = ln(1 + exp(x))
    // For large x (> 20), softplus ~ x to avoid overflow
    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // exp(x) = 2^(x * log2(e))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    // ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
    lg2.approx.f32 %lg_ep1, %ep1;
    // 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %th, %e2sp_m1, %e2sp_p1;

    mul.f32 %vr, %x, %th;
    st.global.f32 [%out], %vr;
    bra DONE;

LARGE_X:
    // softplus ~ x, mish ~ x * tanh(x)
    // tanh(x) = (exp(2x)-1)/(exp(2x)+1)
    add.f32 %two_sp, %x, %x;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %th, %e2sp_m1, %e2sp_p1;
    mul.f32 %vr, %x, %th;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
2714
/// PTX (f64) for the Mish activation: `out[i] = x * tanh(softplus(x))`
/// with `softplus(x) = ln(1 + exp(x))`.
///
/// For x > 20 the kernel branches to `LARGE_X` (softplus(x) ~ x) and
/// computes x * tanh(x) so exp(x) cannot overflow the intermediate.
/// exp() uses Cody–Waite reduction plus a Horner polynomial with Taylor
/// coefficients 1/12! .. 1/3!, 1/2, 1, 1; ln() extracts the exponent field
/// and evaluates the atanh series 2*(f + f^3/3 + ... + f^11/11) with
/// f = (m-1)/(m+1).
///
/// Fixes vs. the previous revision:
/// * in all three exp expansions the 1/5! literal had 17 hex digits
///   (`0d` literals must have exactly 16 — ptxas rejects the module);
/// * in all three exp expansions the 1/3! term (0d3FC5555555555555) was
///   missing from the Horner chain;
/// * the ln() polynomial's 1/9 coefficient was encoded as
///   0d3FC1C71C71C71C72 (~0.1389) instead of 1/9 = 0d3FBC71C71C71C71C.
#[cfg(feature = "cuda")]
pub(crate) const MISH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %one, %two, %ex, %ep1, %sp;
    .reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
    .reg .f64 %threshold;
    // exp subroutine regs
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    // log subroutine regs
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
    .reg .pred %p, %large;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;
    mov.f64 %threshold, 0d4034000000000000;

    setp.gt.f64 %large, %x, %threshold;
    @%large bra LARGE_X;

    // === softplus: sp = ln(1 + exp(x)) ===
    // exp(x)
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;

    // ep1 = 1 + exp(x)
    add.f64 %ep1, %ex, %one;

    // ln(ep1) via argument reduction
    mov.b64 %l_xbits, %ep1;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;
    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;

    // === tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1) ===
    add.f64 %two_sp, %sp, %sp;
    fma.rn.f64 %e_nf, %two_sp, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2sp, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2sp, %e2sp, %e_nf;

    sub.f64 %e2sp_m1, %e2sp, %one;
    add.f64 %e2sp_p1, %e2sp, %one;
    div.rn.f64 %th, %e2sp_m1, %e2sp_p1;

    mul.f64 %vr, %x, %th;
    st.global.f64 [%out], %vr;
    bra DONE;

LARGE_X:
    // softplus ~ x, tanh(x) = (exp(2x)-1)/(exp(2x)+1) in f64
    add.f64 %two_sp, %x, %x;
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %two_sp, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2sp, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2sp, %e2sp, %e_nf;

    sub.f64 %e2sp_m1, %e2sp, %one;
    add.f64 %e2sp_p1, %e2sp, %one;
    div.rn.f64 %th, %e2sp_m1, %e2sp_p1;
    mul.f64 %vr, %x, %th;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
2892
/// PTX (f32) for the Mish backward pass:
/// `out[i] = grad[i] * (t + x * sigmoid(x) * (1 - t*t))`
/// where `t = tanh(softplus(x))`.
///
/// For x > 20 the kernel branches to `LARGE_X`, approximating
/// softplus(x) ~ x and sigmoid(x) ~ 1. All transcendentals use the
/// approximate hardware instructions (`ex2.approx`, `lg2.approx`,
/// `rcp.approx`).
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
    .reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
    .reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
    .reg .f32 %threshold;
    .reg .pred %p, %large;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %lg2e, 0f3FB8AA3B;
    // threshold = 20.0
    mov.f32 %threshold, 0f41A00000;

    setp.gt.f32 %large, %x, %threshold;
    @%large bra LARGE_X;

    // --- Normal path ---
    // softplus: sp = ln(1 + exp(x))
    mul.f32 %ex, %x, %lg2e;
    ex2.approx.f32 %ex, %ex;
    add.f32 %ep1, %ex, %one;
    lg2.approx.f32 %lg_ep1, %ep1;
    // ln(2) = 0x3F317218
    mul.f32 %sp, %lg_ep1, 0f3F317218;

    // t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
    add.f32 %two_sp, %sp, %sp;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %t, %e2sp_m1, %e2sp_p1;

    // sig = sigmoid(x) = 1/(1+exp(-x))
    neg.f32 %neg, %x;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %en, %neg;
    add.f32 %denom, %one, %en;
    rcp.approx.f32 %sig, %denom;

    // deriv = t + x * sig * (1 - t*t)
    mul.f32 %t2, %t, %t;
    sub.f32 %one_m_t2, %one, %t2;
    mul.f32 %x_sig_omt2, %x, %sig;
    mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
    add.f32 %deriv, %t, %x_sig_omt2;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;
    bra DONE;

LARGE_X:
    // sp ~ x, t ~ tanh(x), sig ~ 1
    // tanh(x) = (exp(2x)-1)/(exp(2x)+1)
    add.f32 %two_sp, %x, %x;
    mul.f32 %two_sp, %two_sp, %lg2e;
    ex2.approx.f32 %e2sp, %two_sp;
    sub.f32 %e2sp_m1, %e2sp, %one;
    add.f32 %e2sp_p1, %e2sp, %one;
    rcp.approx.f32 %e2sp_p1, %e2sp_p1;
    mul.f32 %t, %e2sp_m1, %e2sp_p1;

    // sig ~ 1, deriv ~ t + x*(1-t*t)
    mul.f32 %t2, %t, %t;
    sub.f32 %one_m_t2, %one, %t2;
    mul.f32 %x_sig_omt2, %x, %one_m_t2;
    add.f32 %deriv, %t, %x_sig_omt2;
    mul.f32 %result, %vg, %deriv;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3009
/// PTX (f64) for the Mish backward pass:
/// `out[i] = grad[i] * (t + x * sigmoid(x) * (1 - t*t))`
/// where `t = tanh(softplus(x))`.
///
/// For x > 20 the kernel branches to `LARGE_X` (softplus(x) ~ x,
/// sigmoid(x) ~ 1). exp() uses range reduction (here n is produced with
/// `mul` + round-to-nearest rather than the fma+floor form used by the
/// sibling kernels — both compute round-to-nearest n) plus a Horner
/// polynomial with Taylor coefficients 1/12! .. 1/3!, 1/2, 1, 1; ln() uses
/// exponent extraction plus the atanh series in f = (m-1)/(m+1).
///
/// Fixes vs. the previous revision:
/// * in all three exp expansions the 1/5! literal had 17 hex digits
///   (`0d` literals must have exactly 16 — ptxas rejects the module);
/// * in all three exp expansions the 1/4! term (0d3FA5555555555555) was
///   missing between 1/5! and 1/3! in the Horner chain;
/// * the ln() polynomial's 1/9 coefficient was encoded as
///   0d3FC1C71C71C71C72 (~0.1389) instead of 1/9 = 0d3FBC71C71C71C71C.
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry mish_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %one, %ex, %ep1, %sp;
    .reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
    .reg .f64 %neg_x, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
    .reg .f64 %threshold;
    // exp subroutine regs
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    // log subroutine regs
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;
    .reg .pred %p, %large;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %threshold, 0d4034000000000000;

    setp.gt.f64 %large, %x, %threshold;
    @%large bra LARGE_X;

    // === softplus: sp = ln(1 + exp(x)) ===
    // exp(x)
    mov.f64 %e_half, 0d3FE0000000000000;
    mul.f64 %e_nf, %x, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %ex, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ex, %ex, %e_nf;

    add.f64 %ep1, %ex, %one;

    // ln(ep1) via argument reduction
    mov.b64 %l_xbits, %ep1;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;
    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %sp, %l_nf, %l_ln2, %l_p;

    // === tanh(sp) ===
    add.f64 %two_sp, %sp, %sp;
    mul.f64 %e_nf, %two_sp, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2sp, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2sp, %e2sp, %e_nf;

    sub.f64 %e2sp_m1, %e2sp, %one;
    add.f64 %e2sp_p1, %e2sp, %one;
    div.rn.f64 %t, %e2sp_m1, %e2sp_p1;

    // === sigmoid(x) = 1/(1+exp(-x)) ===
    neg.f64 %neg_x, %x;
    mul.f64 %e_nf, %neg_x, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %en, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %en, %en, %e_nf;

    add.f64 %denom, %one, %en;
    div.rn.f64 %sig, %one, %denom;

    // deriv = t + x * sig * (1 - t*t)
    mul.f64 %t2, %t, %t;
    sub.f64 %one_m_t2, %one, %t2;
    mul.f64 %x_sig_omt2, %x, %sig;
    mul.f64 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
    add.f64 %deriv, %t, %x_sig_omt2;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;
    bra DONE;

LARGE_X:
    // sp ~ x, tanh(x) in f64, sig ~ 1
    add.f64 %two_sp, %x, %x;
    mov.f64 %e_half, 0d3FE0000000000000;
    mul.f64 %e_nf, %two_sp, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %e2sp, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e2sp, %e2sp, %e_nf;

    sub.f64 %e2sp_m1, %e2sp, %one;
    add.f64 %e2sp_p1, %e2sp, %one;
    div.rn.f64 %t, %e2sp_m1, %e2sp_p1;

    // sig ~ 1, deriv ~ t + x*(1-t*t)
    mul.f64 %t2, %t, %t;
    sub.f64 %one_m_t2, %one, %t2;
    mul.f64 %x_sig_omt2, %x, %one_m_t2;
    add.f64 %deriv, %t, %x_sig_omt2;
    mul.f64 %result, %vg, %deriv;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3231
/// PTX (f32) for element-wise clamp:
/// `out[i] = min(max(in[i], min_val), max_val)`.
///
/// Implemented as `max` followed by `min`, so values are first raised to
/// the lower bound, then capped at the upper bound.
#[cfg(feature = "cuda")]
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry clamp_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n,
    .param .f32 min_val,
    .param .f32 max_val
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %in, %out, %off;
    .reg .f32 %x, %mn, %mx, %result;
    .reg .pred %p;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.f32 %mn, [min_val];
    ld.param.f32 %mx, [max_val];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %in, %in, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %x, [%in];
    max.f32 %result, %x, %mn;
    min.f32 %result, %result, %mx;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3280
3281
/// PTX (f32) for the ReLU backward pass:
/// `out[i] = input[i] > 0 ? grad[i] : 0`.
///
/// Note the comparison is strict (`>`), so the gradient at exactly 0 is 0.
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry relu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %vi, %zero, %vr;
    .reg .pred %p, %pos;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vi, [%input];
    mov.f32 %zero, 0f00000000;
    setp.gt.f32 %pos, %vi, %zero;
    selp.f32 %vr, %vg, %zero, %pos;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
3336
3337
/// PTX (f32) for the backward pass of the sigmoid-approximated GELU
/// (gelu(x) ~ x * sigmoid(k*x)):
/// `out[i] = grad[i] * (sig + k*x * sig * (1 - sig))` with `sig = sigmoid(k*x)`.
///
/// NOTE(review): the comment below says k = 1.702, but the encoded literal
/// 0f3FDA2720 decodes to ~1.7043 — presumably a refitted constant; it
/// matches the f64 variant (0d3FFB44E400000000), so the two kernels are at
/// least mutually consistent. TODO confirm the intended k.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
    .reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    // sig = sigmoid(1.702 * x)
    mov.f32 %k, 0f3FDA2720;
    mul.f32 %kx, %k, %x;
    neg.f32 %neg_kx, %kx;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_kx, %neg_kx, %log2e;
    ex2.approx.f32 %exp_neg, %neg_kx;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %exp_neg;
    rcp.approx.f32 %sig, %denom;

    // d/dx gelu(x) = sig + k * x * sig * (1 - sig)
    sub.f32 %one_minus_sig, %one, %sig;
    mul.f32 %kx_sig_oms, %kx, %sig;
    mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f32 %dsig, %sig, %kx_sig_oms;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %dsig;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3413
/// PTX (f64) for the backward pass of the sigmoid-approximated GELU:
/// `out[i] = grad[i] * (sig + k*x * sig * (1 - sig))` with
/// `sig = sigmoid(k*x)` and k encoded as 0d3FFB44E400000000 (~1.7043,
/// matching the f32 kernel's 0f3FDA2720).
///
/// exp() uses Cody–Waite reduction plus a Horner polynomial with Taylor
/// coefficients 1/12! .. 1/3!, 1/2, 1, 1, then a 2^n scale via
/// exponent-bit construction.
///
/// Fixes vs. the previous revision:
/// * the 1/5! coefficient literal had 17 hex digits — PTX `0d` literals
///   require exactly 16, so ptxas would reject the module;
/// * the 1/3! term (0d3FC5555555555555) was missing from the Horner chain.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %k, %kx, %neg_kx, %exp_neg, %one, %denom, %sig;
    .reg .f64 %one_minus_sig, %kx_sig_oms, %dsig, %result;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %k, 0d3FFB44E400000000;
    mul.f64 %kx, %k, %x;
    neg.f64 %neg_kx, %kx;

    // --- exp(%neg_kx) via Cody-Waite + Horner polynomial ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg, %exp_neg, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %exp_neg;
    div.rn.f64 %sig, %one, %denom;

    sub.f64 %one_minus_sig, %one, %sig;
    mul.f64 %kx_sig_oms, %kx, %sig;
    mul.f64 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
    add.f64 %dsig, %sig, %kx_sig_oms;

    mul.f64 %result, %vg, %dsig;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3505
/// PTX (f32) for the backward pass of the exact (erf-based) GELU:
/// `out[i] = grad[i] * (Phi(x) + x * phi(x))`, where Phi is the standard
/// normal CDF `0.5 * (1 + erf(x/sqrt(2)))` and phi the PDF
/// `exp(-x^2/2) / sqrt(2*pi)`.
///
/// erf(|z|) is evaluated as `1 - poly(t) * exp(-z^2)` with
/// `t = 1/(1 + p*|z|)`, then sign-corrected via erf(-z) = -erf(z).
/// NOTE(review): p = 0f3EA7BA05 matches the classic Abramowitz–Stegun
/// 7.1.26 value 0.3275911, but the a1..a5 literals below do not decode to
/// the classic A&S coefficients — presumably a refitted f32 set; confirm
/// against the generator before touching them.
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f32 %d_gelu, %result;
    .reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %x, [%input];

    mov.f32 %one, 0f3F800000;
    mov.f32 %half, 0f3F000000;

    // z = x / sqrt(2) = x * 0.70710678
    mov.f32 %z, 0f3F3504F3;
    mul.f32 %z, %x, %z;

    // |z| for erf(|z|)
    abs.f32 %ax, %z;

    // t = 1 / (1 + 0.3275911 * |z|)
    mov.f32 %p, 0f3EA7BA05;
    mul.f32 %t, %p, %ax;
    add.f32 %t, %one, %t;
    rcp.approx.f32 %t, %t;

    // Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
    mov.f32 %a5, 0f3E0AAAAB;
    mov.f32 %a4, 0fBEB3A903;
    mov.f32 %a3, 0f3FB506DD;
    mov.f32 %a2, 0fBF03C1E1;
    mov.f32 %a1, 0f3EA0D6BB;

    mul.f32 %pt, %t, %a5;
    add.f32 %pt, %pt, %a4;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a3;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a2;
    mul.f32 %pt, %pt, %t;
    add.f32 %pt, %pt, %a1;
    mul.f32 %pt, %pt, %t;

    // exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
    mul.f32 %z2, %ax, %ax;
    neg.f32 %neg_z2, %z2;
    mov.f32 %log2e, 0f3FB8AA3B;
    mul.f32 %neg_z2, %neg_z2, %log2e;
    ex2.approx.f32 %exp_neg_z2, %neg_z2;

    // erf(|z|) = 1 - poly * exp(-z^2)
    mul.f32 %erf_val, %pt, %exp_neg_z2;
    sub.f32 %erf_val, %one, %erf_val;

    // erf(-z) = -erf(z), so sign-correct
    setp.lt.f32 %pred_neg, %z, 0f00000000;
    @%pred_neg neg.f32 %erf_val, %erf_val;

    // Φ(x) = 0.5 * (1 + erf(x/sqrt(2)))
    add.f32 %cdf, %one, %erf_val;
    mul.f32 %cdf, %half, %cdf;

    // φ(x) = exp(-x²/2) / sqrt(2π)
    // exp(-x²/2):
    mul.f32 %neg_x2h, %x, %x;
    mul.f32 %neg_x2h, %neg_x2h, %half;
    neg.f32 %neg_x2h, %neg_x2h;
    mul.f32 %neg_x2h, %neg_x2h, %log2e;
    ex2.approx.f32 %exp_neg_x2h, %neg_x2h;

    // 1/sqrt(2π) = 0.39894228
    mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
    mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    // d/dx gelu(x) = Φ(x) + x * φ(x)
    mul.f32 %x_pdf, %x, %pdf;
    add.f32 %d_gelu, %cdf, %x_pdf;

    // out = grad * d_gelu
    mul.f32 %result, %vg, %d_gelu;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3633
3634#[cfg(feature = "cuda")]
/// PTX for `gelu_backward_erf_f64_kernel(grad_ptr, input_ptr, out_ptr, n)`.
///
/// One thread per element: `out[i] = grad[i] * (Phi(x) + x * phi(x))`, the
/// derivative of erf-based GELU, with `x = input[i]`.
///
/// * `Phi(x)` uses an Abramowitz-Stegun-style rational polynomial for
///   `erf(x / sqrt(2))` with sign correction for negative arguments.
///   NOTE(review): the erf coefficients are f32-precision values widened to
///   f64 (trailing zero mantissa bits), so absolute accuracy is bounded by
///   that approximation — confirm this is acceptable for f64 callers.
/// * `exp(y)` is computed in f64 via range reduction: `nf = round(y * log2(e))`
///   (add 0.5 then floor), `r = y - nf * ln(2)` in two steps (hi/lo split of
///   ln 2), a degree-12 Taylor polynomial in `r` via fused multiply-adds, then
///   scaling by `2^n` built directly in the exponent bits of an f64.
///
/// Fixes relative to the previous revision (both exp sections):
/// * `0d3F811111111111111` had 17 hex digits — not a valid PTX f64 literal;
///   corrected to `0d3F81111111111111` (= 1/120 = 1/5!).
/// * The Horner chain skipped the 1/6 (= 1/3!) term, so the tail read
///   ..., 1/120, 1/24, 1/2, 1, 1 — an O(1e-3) relative error in exp().
///   Inserted `fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;` so the
///   coefficients now run ..., 1/120, 1/24, 1/6, 1/2, 1, 1.
pub(crate) const GELU_BACKWARD_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry gelu_backward_erf_f64_kernel(
    .param .u64 grad_ptr,
    .param .u64 input_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %input, %out, %off;
    .reg .f64 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
    .reg .f64 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
    .reg .f64 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
    .reg .f64 %d_gelu, %result;
    .reg .f64 %p_coef, %a1, %a2, %a3, %a4, %a5;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %pred_ge, %pred_neg;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %pred_ge, %r_tid, %n_reg;
    @%pred_ge bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %grad, %grad, %off;
    add.u64 %input, %input, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %vg, [%grad];
    ld.global.f64 %x, [%input];

    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %half, 0d3FE0000000000000;

    mov.f64 %z, 0d3FE6A09E60000000;
    mul.f64 %z, %x, %z;
    abs.f64 %ax, %z;

    mov.f64 %p_coef, 0d3FD4F740A0000000;
    mul.f64 %t, %p_coef, %ax;
    add.f64 %t, %one, %t;
    div.rn.f64 %t, %one, %t;

    mov.f64 %a5, 0d3FC1555560000000;
    mov.f64 %a4, 0dBFD6752060000000;
    mov.f64 %a3, 0d3FF6A0DBA0000000;
    mov.f64 %a2, 0dBFE0783C20000000;
    mov.f64 %a1, 0d3FD41AD760000000;

    mul.f64 %pt, %t, %a5;
    add.f64 %pt, %pt, %a4;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a3;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a2;
    mul.f64 %pt, %pt, %t;
    add.f64 %pt, %pt, %a1;
    mul.f64 %pt, %pt, %t;

    // exp(-z^2) in full f64
    mul.f64 %z2, %ax, %ax;
    neg.f64 %neg_z2, %z2;

    // --- exp(%neg_z2) ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_z2, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_z2;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_z2, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_z2, %exp_neg_z2, %e_nf;
    // --- end exp ---

    mul.f64 %erf_val, %pt, %exp_neg_z2;
    sub.f64 %erf_val, %one, %erf_val;

    setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
    @%pred_neg neg.f64 %erf_val, %erf_val;

    add.f64 %cdf, %one, %erf_val;
    mul.f64 %cdf, %half, %cdf;

    // phi(x) = exp(-x^2/2) / sqrt(2*pi)
    mul.f64 %neg_x2h, %x, %x;
    mul.f64 %neg_x2h, %neg_x2h, %half;
    neg.f64 %neg_x2h, %neg_x2h;

    // --- exp(%neg_x2h) ---
    fma.rn.f64 %e_nf, %neg_x2h, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x2h;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %one;
    fma.rn.f64 %exp_neg_x2h, %e_p, %e_r, %one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %exp_neg_x2h, %exp_neg_x2h, %e_nf;
    // --- end exp ---

    // 1/sqrt(2*pi) = 0.39894228
    mov.f64 %inv_sqrt_2pi, 0d3FD9884440000000;
    mul.f64 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;

    mul.f64 %x_pdf, %x, %pdf;
    add.f64 %d_gelu, %cdf, %x_pdf;

    mul.f64 %result, %vg, %d_gelu;
    st.global.f64 [%out], %result;

DONE:
    ret;
}
";
3793
3794#[cfg(feature = "cuda")]
/// PTX for `index_select_1d_kernel(input_ptr, indices_ptr, out_ptr, n_indices)`.
///
/// Gather: one thread per output element. Reads `indices[tid]` as an f32,
/// truncates it to a u32 (`cvt.rzi.u32.f32`), loads `input[idx]` (f32) and
/// stores it to `out[tid]`.
///
/// NOTE(review): indices travel as f32, so they are only exact while they fit
/// in an f32 mantissa (< 2^24). There is no bounds check on `idx`; the
/// launcher must guarantee every index is within `input` — confirm callers do.
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry index_select_1d_kernel(
    .param .u64 input_ptr,
    .param .u64 indices_ptr,
    .param .u64 out_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %input, %indices, %out, %off, %addr;
    .reg .f32 %idx_f, %val;
    .reg .pred %p;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Read input[idx]
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %input, %addr;
    ld.global.f32 %val, [%addr];

    // Write output[tid]
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %val;

DONE:
    ret;
}
";
3853
3854
3855#[cfg(feature = "cuda")]
/// PTX for `scatter_add_1d_kernel(grad_output_ptr, indices_ptr, grad_input_ptr, n_indices)`.
///
/// Backward of index-select: one thread per gradient element. Reads
/// `grad_output[tid]`, converts `indices[tid]` from f32 to u32, then performs
/// `atom.global.add.f32` into `grad_input[idx]` so duplicate indices
/// accumulate correctly across threads. The atomic's returned old value is
/// discarded into `%dummy`.
///
/// NOTE(review): as with the forward gather, indices are f32 (exact only
/// below 2^24) and `idx` is not bounds-checked — the launcher must guarantee
/// validity. `grad_input` is presumably zero-initialized by the caller
/// (e.g. via an alloc_zeros helper) — confirm at the call site.
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry scatter_add_1d_kernel(
    .param .u64 grad_output_ptr,
    .param .u64 indices_ptr,
    .param .u64 grad_input_ptr,
    .param .u32 n_indices
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
    .reg .u64 %go, %indices, %gi, %off, %addr;
    .reg .f32 %idx_f, %grad_val, %dummy;
    .reg .pred %p;

    ld.param.u64 %go, [grad_output_ptr];
    ld.param.u64 %indices, [indices_ptr];
    ld.param.u64 %gi, [grad_input_ptr];
    ld.param.u32 %n_reg, [n_indices];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Byte offset for thread
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    // Read grad_output[tid]
    add.u64 %addr, %go, %off;
    ld.global.f32 %grad_val, [%addr];

    // Read indices[tid] (f32 -> u32)
    add.u64 %addr, %indices, %off;
    ld.global.f32 %idx_f, [%addr];
    cvt.rzi.u32.f32 %idx, %idx_f;

    // Atomic add: grad_input[idx] += grad_val
    cvt.u64.u32 %addr, %idx;
    shl.b64 %addr, %addr, 2;
    add.u64 %addr, %gi, %addr;
    atom.global.add.f32 %dummy, [%addr], %grad_val;

DONE:
    ret;
}
";
3916
3917
3918#[cfg(feature = "cuda")]
/// PTX for `masked_fill_kernel(input_ptr, mask_ptr, out_ptr, fill_value, n)`.
///
/// Elementwise: `out[i] = mask[i] >= 0.5 ? fill_value : input[i]`.
/// The mask is stored as f32 and thresholded at 0.5 (`setp.ge.f32`), so any
/// mask value >= 0.5 selects the fill constant via branchless `selp.f32`.
/// `fill_value` is passed as an f32 kernel parameter.
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_fill_kernel(
    .param .u64 input_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .f32 fill_value,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %input, %mask, %out, %off;
    .reg .f32 %in_val, %mask_val, %fill, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %input, [input_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %fill, [fill_value];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %input, %input, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %in_val, [%input];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %fill, %in_val, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
3974
3975
3976#[cfg(feature = "cuda")]
/// PTX for `masked_zero_kernel(grad_ptr, mask_ptr, out_ptr, n)`.
///
/// Backward of masked-fill: `out[i] = mask[i] >= 0.5 ? 0.0 : grad[i]`.
/// Gradient is zeroed wherever the forward pass substituted the fill value
/// (positions that did not depend on the input). Same 0.5 f32-mask threshold
/// and branchless `selp.f32` selection as `MASKED_FILL_PTX`.
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry masked_zero_kernel(
    .param .u64 grad_ptr,
    .param .u64 mask_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %mask, %out, %off;
    .reg .f32 %vg, %mask_val, %zero, %result, %half;
    .reg .pred %p, %pmask;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %mask, [mask_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %mask, %mask, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %mask_val, [%mask];
    mov.f32 %zero, 0f00000000;
    mov.f32 %half, 0f3F000000;
    setp.ge.f32 %pmask, %mask_val, %half;
    selp.f32 %result, %zero, %vg, %pmask;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4031
4032
4033#[cfg(feature = "cuda")]
/// PTX for `sigmoid_backward_kernel(grad_ptr, output_ptr, out_ptr, n)`.
///
/// Elementwise: `out[i] = grad[i] * o * (1 - o)` where `o = output[i]` is the
/// saved *forward output* of sigmoid (not the pre-activation input) — the
/// standard sigmoid derivative expressed in terms of its output.
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %one_minus_o, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    sub.f32 %one_minus_o, %one, %vo;
    mul.f32 %result, %vo, %one_minus_o;
    mul.f32 %result, %vg, %result;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4086
4087
4088#[cfg(feature = "cuda")]
/// PTX for `tanh_backward_kernel(grad_ptr, output_ptr, out_ptr, n)`.
///
/// Elementwise: `out[i] = grad[i] * (1 - o^2)` where `o = output[i]` is the
/// saved *forward output* of tanh — the tanh derivative expressed in terms
/// of its output. Mirrors the structure of `SIGMOID_BACKWARD_PTX`.
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_backward_kernel(
    .param .u64 grad_ptr,
    .param .u64 output_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %grad, %output, %out, %off;
    .reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
    .reg .pred %p;

    ld.param.u64 %grad, [grad_ptr];
    ld.param.u64 %output, [output_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %grad, %grad, %off;
    add.u64 %output, %output, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %vg, [%grad];
    ld.global.f32 %vo, [%output];
    mov.f32 %one, 0f3F800000;
    mul.f32 %o_sq, %vo, %vo;
    sub.f32 %one_minus_sq, %one, %o_sq;
    mul.f32 %result, %vg, %one_minus_sq;
    st.global.f32 [%out], %result;

DONE:
    ret;
}
";
4141
4142
4143#[cfg(feature = "cuda")]
/// PTX for `softmax_backward_kernel(grad_ptr, output_ptr, out_ptr, rows, cols)`.
///
/// Row-wise softmax backward: one CTA per row. Two phases per row:
/// 1. Each thread accumulates a partial `dot = sum(grad[j] * output[j])` over
///    a strided subset of columns (`j += blockDim.x`), partials are combined
///    by a shared-memory tree reduction, and thread 0's slot (`sdata[0]`) is
///    broadcast back to all threads.
/// 2. Each thread writes `out[j] = output[j] * (grad[j] - dot)` over the same
///    strided subset.
///
/// NOTE(review): the halving reduction (`%half = bdim; half >>= 1; ...`) and
/// the 256-slot `.shared` array mean the launcher must use a power-of-two
/// blockDim.x that is <= 256 — confirm the launch configuration guarantees
/// this, otherwise partials are dropped / shared memory overruns.
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
    mov.f32 %dot, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra DOT_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    fma.rn.f32 %dot, %vg, %vo, %dot;\n\
    add.u32 %j, %j, %bdim;\n\
    bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
    // Store partial dot into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %dot;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra DOT_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra DOT_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %dot, [%saddr];\n\
    add.f32 %dot, %dot, %other_val;\n\
    st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
    // Broadcast dot to all threads\n\
    ld.shared.f32 %dot, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    sub.f32 %diff, %vg, %dot;\n\
    mul.f32 %result, %vo, %diff;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4270
4271
4272#[cfg(feature = "cuda")]
/// PTX for `log_softmax_kernel(input_ptr, output_ptr, rows, cols)` (f32).
///
/// Row-wise numerically-stable log-softmax: one CTA per row, three phases:
/// 1. Row max via strided per-thread scan + shared-memory tree reduction.
/// 2. Partial sums of `exp(x[j] - max)` (exp via `ex2.approx` with the
///    log2(e) scale factor), reduced the same way.
/// 3. `out[j] = x[j] - (max + log(sum))`, with log computed from
///    `lg2.approx` scaled by ln(2).
///
/// NOTE(review): like the other row-reduction kernels here, the halving
/// reduction and the 256-slot `.shared` array require the launcher to use a
/// power-of-two blockDim.x <= 256 — confirm the launch configuration.
pub(crate) const LOG_SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: find max across row (grid-stride over columns)\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    // Shared-memory tree reduction for max\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    // Broadcast max to all threads\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: compute partial sum of exp(x[j] - max)\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    // exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    // Shared-memory tree reduction for sum\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    // log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
    // ln(2) = 0x3F317218\n\
    lg2.approx.f32 %log_sum_exp, %sum_val;\n\
    mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
    add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    // Phase 3: out[j] = x[j] - log_sum_exp\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %val, [%saddr];\n\
    sub.f32 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4457
4458#[cfg(feature = "cuda")]
/// PTX for `log_softmax_f64_kernel(input_ptr, output_ptr, rows, cols)` (f64).
///
/// Row-wise numerically-stable log-softmax: one CTA per row, three phases
/// (row max, sum of `exp(x - max)`, `out = x - (max + log(sum))`), each with
/// a strided per-thread scan plus shared-memory tree reduction, mirroring the
/// f32 `LOG_SOFTMAX_PTX`.
///
/// * `exp(y)` is computed in f64 via range reduction: `nf = round(y*log2 e)`,
///   `r = y - nf*ln 2` (hi/lo split), degree-12 Taylor polynomial in `r`,
///   then scaling by `2^n` built in the exponent bits of an f64.
/// * `log(sum)` extracts the exponent/mantissa bits (sum >= 1 here, since the
///   `exp(0) = 1` term of the shifted row is always included), then uses the
///   atanh-style series `ln(m) = 2f(1 + f^2/3 + f^4/5 + ...)` with
///   `f = (m-1)/(m+1)`, and recombines with `n * ln 2`.
///   NOTE(review): the coefficient `0d3FC1C71C71C71C72` (~0.1389) sits at
///   the 1/9 slot of that series (1/9 would be `0d3FBC71C71C71C71C`); its
///   error contribution is ~1e-8 at worst — left unchanged, verify intent.
/// * Reduction assumes a power-of-two blockDim.x <= 256 (`sdata[256]`),
///   which the launcher must guarantee.
///
/// Fixes relative to the previous revision:
/// * `0d3F811111111111111` had 17 hex digits — not a valid PTX f64 literal;
///   corrected to `0d3F81111111111111` (= 1/120 = 1/5!).
/// * The exp() Horner chain jumped from 1/120 to 1/6, skipping the
///   1/24 (= 1/4!) term; inserted
///   `fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;` so the coefficients
///   run ..., 1/120, 1/24, 1/6, 1/2, 1, 1.
pub(crate) const LOG_SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
    .reg .u64 %l_xbits, %l_mbits, %l_bias;\n\
    .reg .s64 %l_exp64;\n\
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.b64 %l_xbits, %sum_val;\n\
    shr.u64 %l_exp64, %l_xbits, 52;\n\
    and.b64 %l_exp64, %l_exp64, 2047;\n\
    sub.s64 %l_exp64, %l_exp64, 1023;\n\
    cvt.rn.f64.s64 %l_nf, %l_exp64;\n\
    mov.u64 %l_bias, 0x3FF0000000000000;\n\
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;\n\
    or.b64 %l_mbits, %l_mbits, %l_bias;\n\
    mov.b64 %l_m, %l_mbits;\n\
    sub.f64 %l_f, %l_m, %e_one;\n\
    add.f64 %l_s, %l_m, %e_one;\n\
    div.rn.f64 %l_f, %l_f, %l_s;\n\
    mul.f64 %l_f2, %l_f, %l_f;\n\
    mov.f64 %l_p, 0d3FB745D1745D1746;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;\n\
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;\n\
    mul.f64 %l_p, %l_p, %l_f;\n\
    add.f64 %l_p, %l_p, %l_p;\n\
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;\n\
    fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2, %l_p;\n\
    add.f64 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
    mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_OUTPUT_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %in, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %val, [%saddr];\n\
    sub.f64 %result, %val, %log_sum_exp;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4672
4673#[cfg(feature = "cuda")]
/// PTX for the f32 log-softmax backward pass (one thread block per row).
///
/// Computes `out[i][j] = grad[i][j] - exp(output[i][j]) * sum_j grad[i][j]`,
/// where `output` holds the forward log-softmax result, so
/// `exp(output[i][j])` (via `ex2.approx` scaled by log2(e)) recovers the
/// softmax probability.
///
/// Phase 1: each thread strides over the row's columns accumulating a partial
/// sum of `grad`, then the block tree-reduces the partials in the 256-slot
/// `sdata` shared array and broadcasts slot 0. Phase 2: each thread writes
/// its strided slice of the output row.
///
/// NOTE(review): the tree reduction halves `%ntid.x` down to 0 and `sdata`
/// has 256 slots, so the launch must use a power-of-two block size <= 256 —
/// confirm against the host-side launch configuration.
pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    // row_off = bid * cols * 4 (byte offset)\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    // Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
    mov.f32 %sum_grad, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    // Store partial sum into shared memory and reduce\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_grad, [%saddr];\n\
    add.f32 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f32 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    // Broadcast sum_grad to all threads\n\
    ld.shared.f32 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    // Phase 2: out[j] = grad[j] - exp(output[j]) * sum_grad\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f32 %vo, [%saddr];\n\
    // exp(log_softmax_output) = softmax probability\n\
    mul.f32 %vo, %vo, 0f3FB8AA3B;\n\
    ex2.approx.f32 %softmax_j, %vo;\n\
    // out[j] = grad[j] - softmax[j] * sum_grad\n\
    mul.f32 %result, %softmax_j, %sum_grad;\n\
    sub.f32 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f32 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4802
4803#[cfg(feature = "cuda")]
/// PTX for the f64 log-softmax backward pass (one thread block per row).
///
/// Computes `out[i][j] = grad[i][j] - exp(output[i][j]) * sum_j grad[i][j]`,
/// where `output` holds the forward log-softmax result, so `exp(output)`
/// recovers the softmax probability. sm_52 has no f64 `ex2.approx`, so exp is
/// inlined: range-reduce by n = round(x * log2(e)), evaluate a Taylor
/// polynomial on the remainder, then scale by 2^n built via `(n + 1023) << 52`.
///
/// Fixes in this revision:
/// - the 1/120 Taylor coefficient was written with 17 hex digits
///   (`0d3F811111111111111`); a PTX `0d` immediate requires exactly 16, so
///   the module would not assemble as intended. Corrected to
///   `0d3F81111111111111`.
/// - the exp argument is clamped to -700 first: for inputs below ~-708 (or
///   a -inf log-softmax entry) the `(n + 1023) << 52` scaling underflows the
///   exponent field and yields garbage instead of ~0. exp(-700) ~ 1e-304,
///   indistinguishable from 0 here, and log-softmax outputs are <= 0 so the
///   clamp never affects in-range values.
///
/// NOTE(review): the tree reduction halves `%ntid.x` down to 0 and `sdata`
/// has 256 slots, so the launch must use a power-of-two block size <= 256 —
/// confirm against the host-side launch configuration.
pub(crate) const LOG_SOFTMAX_BACKWARD_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_f64_kernel(\n\
    .param .u64 grad_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u64 out_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
    .reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
    .reg .pred %p, %loop_p, %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
\n\
    ld.param.u64 %grad, [grad_ptr];\n\
    ld.param.u64 %output, [output_ptr];\n\
    ld.param.u64 %out, [out_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %sum_grad, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %vg;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_grad, [%saddr];\n\
    add.f64 %sum_grad, %sum_grad, %other_val;\n\
    st.shared.f64 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_grad, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra WRITE_LOOP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %grad, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vg, [%saddr];\n\
    add.u64 %saddr, %output, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    ld.global.f64 %vo, [%saddr];\n\
    // exp(log_softmax_output) — inline f64 exp\n\
    // clamp: below ~-708 the (n+1023)<<52 scaling underflows; exp(-700) ~= 0\n\
    max.f64 %vo, %vo, 0dC085E00000000000;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %vo, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %vo;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %softmax_j, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %softmax_j, %softmax_j, %e_nf;\n\
    mul.f64 %result, %softmax_j, %sum_grad;\n\
    sub.f64 %result, %vg, %result;\n\
    add.u64 %saddr, %out, %off;\n\
    add.u64 %saddr, %saddr, %row_off;\n\
    st.global.f64 [%saddr], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
4943
4944#[cfg(feature = "cuda")]
/// PTX for a grid-stride f32 sum reduction (one partial result per block).
///
/// Each thread accumulates `in[idx], in[idx+stride], ...` with
/// `stride = gridDim.x * blockDim.x`, the block tree-reduces the 256 partials
/// in shared memory, and thread 0 writes the block's sum to `out[bid]`.
/// The host side is expected to finish the reduction over the per-block
/// results (or launch this kernel again on them).
///
/// NOTE(review): the tree reduction starts at a hard-coded 128, so the
/// kernel assumes exactly 256 threads per block; a smaller block would read
/// uninitialized shared-memory slots — confirm the launch configuration.
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];

.visible .entry reduce_sum_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
    .reg .u64 %in, %out, %off;
    .reg .f32 %sum, %other;
    .reg .pred %p, %ptid;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %tid, %tid.x;
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %gdim, %nctaid.x;

    // Grid-stride accumulation: each thread sums multiple elements.
    // idx = bid * bdim + tid; stride = bdim * gdim
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %sum, 0f00000000;

GRID_LOOP:
    setp.ge.u32 %p, %idx, %n_reg;
    @%p bra GRID_DONE;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %other, [%off];
    add.f32 %sum, %sum, %other;
    add.u32 %idx, %idx, %stride;
    bra GRID_LOOP;

GRID_DONE:
    // Write thread's partial sum to shared memory.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    st.shared.f32 [sdata + %off], %sum;
    bar.sync 0;

    // Tree reduction in shared memory.
    mov.u32 %half, 128;
TREE_LOOP:
    setp.lt.u32 %p, %half, 1;
    @%p bra TREE_DONE;

    setp.ge.u32 %ptid, %tid, %half;
    @%ptid bra TREE_SKIP;

    // Load partner's value from sdata[tid + half].
    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    ld.shared.f32 %other, [sdata + %off];
    // Load own value.
    cvt.u64.u32 %off, %tid;
    shl.b64 %off, %off, 2;
    ld.shared.f32 %sum, [sdata + %off];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [sdata + %off], %sum;

TREE_SKIP:
    bar.sync 0;
    shr.u32 %half, %half, 1;
    bra TREE_LOOP;

TREE_DONE:
    // Thread 0 writes block result.
    setp.ne.u32 %ptid, %tid, 0;
    @%ptid bra END;

    ld.shared.f32 %sum, [sdata];
    cvt.u64.u32 %off, %bid;
    shl.b64 %off, %off, 2;
    add.u64 %out, %out, %off;
    st.global.f32 [%out], %sum;

END:
    ret;
}
";
5051
5052
5053#[cfg(feature = "cuda")]
/// PTX for an f32 sum along one axis of a contiguous tensor viewed as
/// `(outer, axis, inner)`.
///
/// One thread per output element: thread `t` maps to
/// `(outer_idx, inner_idx) = (t / inner_size, t % inner_size)` and serially
/// sums the `axis_size` strided elements
/// `in[outer_idx*axis_size*inner_size + k*inner_size + inner_idx]`,
/// writing the result to `out[t]`. Threads past `total_output` exit early.
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sum_axis_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 axis_size,
    .param .u32 inner_size,
    .param .u32 total_output
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %sum;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %axis_sz, [axis_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total_output];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // outer_idx = r_tid / inner_size
    div.u32 %outer_idx, %r_tid, %inner_sz;
    // inner_idx = r_tid % inner_size
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * axis_size * inner_size + inner_idx
    mul.lo.u32 %tmp, %outer_idx, %axis_sz;
    mul.lo.u32 %tmp, %tmp, %inner_sz;
    add.u32 %tmp, %tmp, %inner_idx;

    mov.f32 %sum, 0f00000000;
    mov.u32 %k, 0;
SUM_LOOP:
    setp.ge.u32 %lp, %k, %axis_sz;
    @%lp bra SUM_LOOP_DONE;

    // addr = in + (tmp + k * inner_size) * 4
    mul.lo.u32 %inner_idx, %k, %inner_sz;
    add.u32 %inner_idx, %tmp, %inner_idx;
    cvt.u64.u32 %off, %inner_idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];
    add.f32 %sum, %sum, %val;

    add.u32 %k, %k, 1;
    bra SUM_LOOP;
SUM_LOOP_DONE:

    // output[r_tid] = sum
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %sum;

DONE:
    ret;
}
";
5131
5132#[cfg(feature = "cuda")]
/// PTX for an f32 inclusive cumulative sum along one dimension of a tensor
/// viewed as `(outer, dim, inner)`.
///
/// One thread per (outer, inner) lane: thread `t` maps to
/// `(t / inner_size, t % inner_size)` and walks its `dim_size` strided
/// elements serially, storing the running prefix sum at every step.
/// The `total` param is loaded into `%n_reg` but the bound actually checked
/// is `outer_size * inner_size` (the lane count).
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumsum_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    // total threads = outer * inner
    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    // base = outer_idx * dim_size * inner_size + inner_idx
    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    // idx = base + k * inner_size
    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    add.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5216
5217
5218#[cfg(feature = "cuda")]
/// PTX for an f32 inclusive cumulative product along one dimension of a
/// tensor viewed as `(outer, dim, inner)`.
///
/// Same lane layout as the cumsum kernel (one thread per (outer, inner)
/// pair), but the accumulator starts at 1.0 and multiplies instead of adds,
/// writing the running prefix product at every step.
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cumprod_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = 1.0
    mov.f32 %acc, 0f3F800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    mul.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5293
5294
5295#[cfg(feature = "cuda")]
/// PTX for an f32 inclusive cumulative max with argmax indices, along one
/// dimension of a tensor viewed as `(outer, dim, inner)`.
///
/// One thread per (outer, inner) lane. The running max starts at -inf
/// (bit pattern 0xFF800000); each step writes the prefix max to `out` and
/// the position of the current max — encoded as an f32 — to `indices`.
///
/// NOTE(review): indices go through `cvt.rn.f32.u32`, which loses precision
/// for dim indices above 2^24 — confirm that dim sizes stay below that.
/// NOTE(review): `setp.gt.f32` is false for NaN inputs, so a NaN element
/// never updates `best_k`; verify the intended NaN semantics against the
/// CPU reference.
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummax_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_max;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0xFF800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.gt.f32 %is_new_max, %val, %acc;
    @%is_new_max mov.u32 %best_k, %k;
    max.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5380
5381
5382#[cfg(feature = "cuda")]
/// PTX for an f32 inclusive cumulative min with argmin indices, along one
/// dimension of a tensor viewed as `(outer, dim, inner)`.
///
/// Mirror image of the cummax kernel: the running min starts at +inf
/// (bit pattern 0x7F800000), each step writes the prefix min to `out` and
/// the position of the current min — encoded as an f32 — to `indices`.
///
/// NOTE(review): indices go through `cvt.rn.f32.u32`, which loses precision
/// for dim indices above 2^24 — confirm that dim sizes stay below that.
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry cummin_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 indices_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
    .reg .u64 %in, %out, %ind, %off, %addr;
    .reg .f32 %val, %acc, %best_k_f;
    .reg .pred %p, %lp, %is_new_min;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %ind, [indices_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    mov.b32 %acc, 0x7F800000;
    mov.u32 %best_k, 0;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    setp.lt.f32 %is_new_min, %val, %acc;
    @%is_new_min mov.u32 %best_k, %k;
    min.f32 %acc, %acc, %val;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    cvt.rn.f32.u32 %best_k_f, %best_k;
    add.u64 %addr, %ind, %off;
    st.global.f32 [%addr], %best_k_f;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5465
5466
5467#[cfg(feature = "cuda")]
/// PTX for an f32 inclusive log-cumsum-exp along one dimension of a tensor
/// viewed as `(outer, dim, inner)`; one thread per (outer, inner) lane.
///
/// Uses the numerically stable update
/// `acc = m + ln(exp(acc - m) + exp(x - m))` with `m = max(acc, x)`;
/// exp and ln are built from `ex2.approx` / `lg2.approx` with the usual
/// log2(e) / ln(2) scaling. The -inf initial accumulator is harmless here
/// because `ex2.approx(-inf)` evaluates to 0 (unlike the inlined f64 exp in
/// the f64 variant, which needs a clamp).
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f32 %val, %acc, %m, %ea, %ev, %s, %ls, %log2e, %ln2;
    .reg .pred %p, %lp;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    // log2(e) = 1.4426950408... -> 0x3FB8AA3B
    mov.b32 %log2e, 0x3FB8AA3B;
    // ln(2) = 0.6931471805... -> 0x3F317218
    mov.b32 %ln2, 0x3F317218;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b32 %acc, 0xFF800000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    ld.global.f32 %val, [%addr];

    // Numerically stable: m = max(acc, x)
    max.f32 %m, %acc, %val;
    // exp(acc - m): (acc - m) * log2(e) -> ex2
    sub.f32 %ea, %acc, %m;
    mul.f32 %ea, %ea, %log2e;
    ex2.approx.f32 %ea, %ea;
    // exp(x - m): (x - m) * log2(e) -> ex2
    sub.f32 %ev, %val, %m;
    mul.f32 %ev, %ev, %log2e;
    ex2.approx.f32 %ev, %ev;
    // sum
    add.f32 %s, %ea, %ev;
    // log(sum) = lg2(sum) * ln(2)
    lg2.approx.f32 %ls, %s;
    mul.f32 %ls, %ls, %ln2;
    // acc = m + log(sum)
    add.f32 %acc, %m, %ls;

    add.u64 %addr, %out, %off;
    st.global.f32 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5567
5568#[cfg(feature = "cuda")]
/// PTX for the f64 inclusive log-cumsum-exp along one dimension of a tensor
/// viewed as `(outer, dim, inner)`; one thread per (outer, inner) lane.
///
/// Uses the stable update `acc = m + ln(exp(acc - m) + exp(val - m))` with
/// `m = max(acc, val)`. sm_52 has no f64 ex2/lg2, so exp and ln are inlined
/// (range reduction + polynomial + exponent-field bit manipulation).
///
/// Fixes in this revision:
/// - the exp arguments are clamped to -700. The accumulator starts at -inf,
///   so on the very first scan step `acc - m` is -inf and the unclamped
///   range reduction computed `fma(-inf, -ln2_hi, -inf)` = NaN, corrupting
///   every lane's output from element 0 on. With the clamp, exp(-inf)
///   evaluates as exp(-700) ~ 1e-304 ~ 0, and the `(n + 1023) << 52`
///   scaling stays inside the exponent-field range. All-(-inf) lanes still
///   yield -inf because `max.f64` drops the NaN operand of `-inf - -inf`.
/// - the 1/120 Taylor coefficient was written with 17 hex digits
///   (`0d3F811111111111111`); a PTX `0d` immediate requires exactly 16, so
///   both occurrences are corrected to `0d3F81111111111111`.
pub(crate) const LOGCUMSUMEXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry logcumsumexp_f64_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 outer_size,
    .param .u32 dim_size,
    .param .u32 inner_size,
    .param .u32 total
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
    .reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
    .reg .u64 %in, %out, %off, %addr;
    .reg .f64 %val, %acc, %m, %ea, %ev, %s, %ls;
    .reg .pred %p, %lp;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %outer_sz, [outer_size];
    ld.param.u32 %dim_sz, [dim_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    mul.lo.u32 %tmp, %outer_sz, %inner_sz;
    setp.ge.u32 %p, %r_tid, %tmp;
    @%p bra DONE;

    div.u32 %outer_idx, %r_tid, %inner_sz;
    rem.u32 %inner_idx, %r_tid, %inner_sz;

    mul.lo.u32 %base, %outer_idx, %dim_sz;
    mul.lo.u32 %base, %base, %inner_sz;
    add.u32 %base, %base, %inner_idx;

    // acc = -inf
    mov.b64 %acc, 0xFFF0000000000000;
    mov.u32 %k, 0;
SCAN_LOOP:
    setp.ge.u32 %lp, %k, %dim_sz;
    @%lp bra SCAN_DONE;

    mul.lo.u32 %idx, %k, %inner_sz;
    add.u32 %idx, %base, %idx;

    cvt.u64.u32 %off, %idx;
    shl.b64 %off, %off, 3;
    add.u64 %addr, %in, %off;
    ld.global.f64 %val, [%addr];

    max.f64 %m, %acc, %val;
    mov.f64 %e_one, 0d3FF0000000000000;
    mov.f64 %e_half, 0d3FE0000000000000;
    // --- inline exp(acc - m) -> %ea ---
    sub.f64 %ea, %acc, %m;
    // clamp: (acc - m) is -inf on the first step; unclamped, the range
    // reduction below computes inf - inf = NaN. exp(-700) ~= 0 here.
    max.f64 %ea, %ea, 0dC085E00000000000;
    mul.f64 %e_nf, %ea, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ea;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
    fma.rn.f64 %ea, %e_p, %e_r, %e_one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ea, %ea, %e_nf;
    // --- inline exp(val - m) -> %ev ---
    sub.f64 %ev, %val, %m;
    // clamp for the same reason as %ea (also covers val = -inf inputs)
    max.f64 %ev, %ev, 0dC085E00000000000;
    mul.f64 %e_nf, %ev, 0d3FF71547652B82FE;
    cvt.rni.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ev;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
    fma.rn.f64 %ev, %e_p, %e_r, %e_one;
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %ev, %ev, %e_nf;
    add.f64 %s, %ea, %ev;
    // --- inline ln(%s) -> %ls ---
    mov.b64 %l_xbits, %s;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;
    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;
    sub.f64 %l_f, %l_m, %e_one;
    add.f64 %l_s, %l_m, %e_one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;
    mov.f64 %l_p, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %ls, %l_nf, %l_ln2, %l_p;
    add.f64 %acc, %m, %ls;

    add.u64 %addr, %out, %off;
    st.global.f64 [%addr], %acc;

    add.u32 %k, %k, 1;
    bra SCAN_LOOP;
SCAN_DONE:

DONE:
    ret;
}
";
5722
5723#[cfg(feature = "cuda")]
/// PTX for f32 layer normalization (one thread block per row).
///
/// Per row: threads stride over the columns to accumulate the sum,
/// tree-reduce it in the 256-slot `sdata` shared array to obtain the mean,
/// repeat the same pattern for the variance (sum of squared deviations),
/// then each thread writes its strided slice of
/// `w[j] * (x[j] - mean) / sqrt(var + eps) + b[j]`.
///
/// Uses `div.approx` / `sqrt.approx` / `rcp.approx`, so results can differ
/// from a CPU reference in the last ulps.
/// NOTE(review): both tree reductions halve `%ntid.x` down to 0 and `sdata`
/// has 256 slots, so the launch must use a power-of-two block size <= 256 —
/// confirm against the host-side launch configuration.
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u64 b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra SM;
SMD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
MRS:
    bar.sync 0;
    bra MR;
MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra SV;
SVD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
VRS:
    bar.sync 0;
    bra VR;
VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    sub.f32 %normed, %val, %mean;
    mul.f32 %normed, %normed, %inv_std;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %b, %off;
    ld.global.f32 %bv, [%off];
    fma.rn.f32 %result, %wv, %normed, %bv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
5906
5907
5908#[cfg(feature = "cuda")]
/// PTX for the LayerNorm backward pass (f32).
///
/// Launch contract (inferred from the PTX — confirm against the host
/// launcher): one thread block per row (`%ctaid.x` is the row index) and the
/// threads of the block stride over `cols`. The in-block reductions halve
/// `%ntid.x` each step, so blockDim.x is presumably a power of two <= 256
/// (`sdata` holds 256 floats) — TODO confirm the launch config.
///
/// `grad_w`/`grad_b` are accumulated across rows with `atom.global.add.f32`,
/// so those buffers presumably must be zero-initialized by the caller before
/// launch. Per row the kernel recomputes mean/variance (phases 1-2), reduces
/// sum1 = sum(dL/dx_hat) and sum2 = sum(dL/dx_hat * x_hat) (phase 3), then
/// writes grad_in = inv_std * (dL/dx_hat - mean1 - x_hat * mean2) (phase 4).
///
/// NOTE(review): `ptx_f32_to_f64` rewrites `shl.b64 %off, %off, 2` but this
/// kernel also scales `%row_off` by element size — verify whether this
/// constant is ever run through the f64 converter.
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry layernorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u64 grad_b_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
    .reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u64 %gb, [grad_b_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra LNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute mean =====
    mov.f32 %mean, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    add.f32 %mean, %mean, %val;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SM;
LNB_SMD:
    // Shared memory reduce for mean
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %mean;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_MR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_MRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_MRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %mean, [%saddr];
    add.f32 %mean, %mean, %other_val;
    st.shared.f32 [%saddr], %mean;
LNB_MRS:
    bar.sync 0;
    bra LNB_MR;
LNB_MRD:
    ld.shared.f32 %mean, [%sbase];
    div.approx.f32 %mean, %mean, %n_f;
    bar.sync 0;

    // ===== Phase 2: Compute variance =====
    mov.f32 %var, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_SV:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_SVD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %diff, %val, %mean;
    fma.rn.f32 %var, %diff, %diff, %var;
    add.u32 %j, %j, %r_bdim;
    bra LNB_SV;
LNB_SVD:
    // Shared memory reduce for variance
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %var;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_VR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_VRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_VRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %var, [%saddr];
    add.f32 %var, %var, %other_val;
    st.shared.f32 [%saddr], %var;
LNB_VRS:
    bar.sync 0;
    bra LNB_VR;
LNB_VRD:
    ld.shared.f32 %var, [%sbase];
    div.approx.f32 %var, %var, %n_f;
    add.f32 %var, %var, %eps_r;
    sqrt.approx.f32 %inv_std, %var;
    rcp.approx.f32 %inv_std, %inv_std;
    bar.sync 0;

    // ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
    // Also accumulate grad_weight and grad_bias via atomicAdd
    mov.f32 %sum1, 0f00000000;
    mov.f32 %sum2, 0f00000000;
    mov.u32 %j, %r_tid;
LNB_S12:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_S12D;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // x_hat = (val - mean) * inv_std
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dl_dx_hat = grad_output * weight
    mul.f32 %dl_dx_hat, %gov, %wv;
    // Accumulate sums
    add.f32 %sum1, %sum1, %dl_dx_hat;
    fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
    // atomicAdd grad_weight[j] += grad_output * x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %x_hat;
    atom.global.add.f32 %result, [%addr], %result;
    // atomicAdd grad_bias[j] += grad_output
    add.u64 %addr, %gb, %off;
    atom.global.add.f32 %result, [%addr], %gov;
    add.u32 %j, %j, %r_bdim;
    bra LNB_S12;
LNB_S12D:
    // Reduce sum1 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum1;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R1:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R1D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R1S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum1, [%saddr];
    add.f32 %sum1, %sum1, %other_val;
    st.shared.f32 [%saddr], %sum1;
LNB_R1S:
    bar.sync 0;
    bra LNB_R1;
LNB_R1D:
    ld.shared.f32 %sum1, [%sbase];
    // mean1 = sum1 / n
    div.approx.f32 %mean1, %sum1, %n_f;
    bar.sync 0;

    // Reduce sum2 in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sum2;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
LNB_R2:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra LNB_R2D;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra LNB_R2S;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sum2, [%saddr];
    add.f32 %sum2, %sum2, %other_val;
    st.shared.f32 [%saddr], %sum2;
LNB_R2S:
    bar.sync 0;
    bra LNB_R2;
LNB_R2D:
    ld.shared.f32 %sum2, [%sbase];
    // mean2 = sum2 / n
    div.approx.f32 %mean2, %sum2, %n_f;
    bar.sync 0;

    // ===== Phase 4: Compute grad_input =====
    // grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
    mov.u32 %j, %r_tid;
LNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra LNB_GID;
    // Reload input to recompute x_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    sub.f32 %x_hat, %val, %mean;
    mul.f32 %x_hat, %x_hat, %inv_std;
    // Reload grad_output and weight to recompute dl_dx_hat
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    mul.f32 %dl_dx_hat, %gov, %wv;
    // result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
    sub.f32 %result, %dl_dx_hat, %mean1;
    mul.f32 %diff, %x_hat, %mean2;
    sub.f32 %result, %result, %diff;
    mul.f32 %result, %inv_std, %result;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra LNB_GI;
LNB_GID:

LNB_DONE:
    ret;
}
";
6236
6237
6238#[cfg(feature = "cuda")]
/// PTX for the RMSNorm forward pass (f32):
/// `out[r, j] = x[r, j] * w[j] / sqrt(mean(x[r, :]^2) + eps)`.
///
/// Launch contract (inferred from the PTX — confirm against the host
/// launcher): one thread block per row (`%ctaid.x` selects the row), threads
/// stride over `cols`. The sum-of-squares reduction stores one partial per
/// thread into `sdata` (256 floats) and halves `%ntid.x` each step, so
/// blockDim.x is presumably a power of two <= 256 — TODO confirm.
///
/// Uses `div.approx` / `sqrt.approx` / `rcp.approx`, so results are
/// approximate-precision f32, matching the sibling LayerNorm kernel.
pub(crate) const RMSNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_kernel(
    .param .u64 in_ptr,
    .param .u64 out_ptr,
    .param .u64 w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra DONE;

    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra SS;
SSD:
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
SRS:
    bar.sync 0;
    bra SR;
SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Normalize and scale =====
    // out[j] = x[j] * inv_rms * weight[j]
    mov.u32 %j, %r_tid;
NM:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra NMD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    add.u64 %off, %off, %row_off;
    ld.global.f32 %val, [%off];
    mul.f32 %result, %val, %inv_rms;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %w, %off;
    ld.global.f32 %wv, [%off];
    mul.f32 %result, %result, %wv;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    add.u64 %off, %off, %row_off;
    st.global.f32 [%off], %result;
    add.u32 %j, %j, %r_bdim;
    bra NM;
NMD:

DONE:
    ret;
}
";
6373
6374
6375#[cfg(feature = "cuda")]
/// PTX for the RMSNorm backward pass (f32).
///
/// Launch contract (inferred from the PTX — confirm against the host
/// launcher): one thread block per row, threads stride over `cols`; the
/// reductions halve `%ntid.x` each step, so blockDim.x is presumably a power
/// of two <= 256 (`sdata` holds 256 floats) — TODO confirm.
///
/// Per row: phase 1 recomputes inv_rms = 1/sqrt(mean(x^2) + eps) and its
/// cube; phase 2 reduces dot = sum(go * x * w) while accumulating
/// `grad_w[j] += go * x * inv_rms` across rows with `atom.global.add.f32`
/// (so `grad_w` presumably must be zeroed by the caller); phase 3 writes
/// `grad_in[j] = inv_rms * w[j] * go[j] - x[j] * (dot * inv_rms^3 / n)`.
pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.shared .align 4 .f32 sdata[256];

.visible .entry rmsnorm_backward_kernel(
    .param .u64 in_ptr,
    .param .u64 grad_out_ptr,
    .param .u64 w_ptr,
    .param .u64 grad_in_ptr,
    .param .u64 grad_w_ptr,
    .param .u32 rows,
    .param .u32 cols,
    .param .f32 eps
) {
    .reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
    .reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
    .reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
    .reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
    .reg .pred %p, %lp, %rp;

    ld.param.u64 %in, [in_ptr];
    ld.param.u64 %go, [grad_out_ptr];
    ld.param.u64 %w, [w_ptr];
    ld.param.u64 %gi, [grad_in_ptr];
    ld.param.u64 %gw, [grad_w_ptr];
    ld.param.u32 %rows_reg, [rows];
    ld.param.u32 %cols_reg, [cols];
    ld.param.f32 %eps_r, [eps];

    mov.u64 %sbase, sdata;

    mov.u32 %r_bid, %ctaid.x;
    mov.u32 %r_bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;

    setp.ge.u32 %p, %r_bid, %rows_reg;
    @%p bra RNB_DONE;

    // row_off = bid * cols * 4 (byte offset for this row)
    cvt.u64.u32 %row_off, %r_bid;
    cvt.u64.u32 %off, %cols_reg;
    mul.lo.u64 %row_off, %row_off, %off;
    shl.b64 %row_off, %row_off, 2;
    cvt.rn.f32.u32 %n_f, %cols_reg;

    // ===== Phase 1: Compute sum(x^2) -> inv_rms =====
    mov.f32 %sq_sum, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_SS:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_SSD;
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
    add.u32 %j, %j, %r_bdim;
    bra RNB_SS;
RNB_SSD:
    // Shared memory reduce for sum(x^2)
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %sq_sum;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_SR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_SRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_SRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %sq_sum, [%saddr];
    add.f32 %sq_sum, %sq_sum, %other_val;
    st.shared.f32 [%saddr], %sq_sum;
RNB_SRS:
    bar.sync 0;
    bra RNB_SR;
RNB_SRD:
    ld.shared.f32 %sq_sum, [%sbase];
    div.approx.f32 %sq_sum, %sq_sum, %n_f;
    add.f32 %sq_sum, %sq_sum, %eps_r;
    sqrt.approx.f32 %inv_rms, %sq_sum;
    rcp.approx.f32 %inv_rms, %inv_rms;
    // inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
    mul.f32 %inv_rms3, %inv_rms, %inv_rms;
    mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
    bar.sync 0;

    // ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
    // Also accumulate grad_weight via atomicAdd
    mov.f32 %dot, 0f00000000;
    mov.u32 %j, %r_tid;
RNB_DOT:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_DOTD;
    // Load input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Load grad_output[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    // Load weight[j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // dot += go * x * w
    mul.f32 %tmp, %gov, %val;
    fma.rn.f32 %dot, %tmp, %wv, %dot;
    // atomicAdd grad_weight[j] += go * x * inv_rms
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gw, %off;
    mul.f32 %result, %gov, %val;
    mul.f32 %result, %result, %inv_rms;
    atom.global.add.f32 %result, [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_DOT;
RNB_DOTD:
    // Reduce dot in shared memory
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    st.shared.f32 [%saddr], %dot;
    bar.sync 0;
    mov.u32 %half, %r_bdim;
RNB_DR:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %rp, %half, 0;
    @%rp bra RNB_DRD;
    setp.ge.u32 %rp, %r_tid, %half;
    @%rp bra RNB_DRS;
    add.u32 %r_otid, %r_tid, %half;
    cvt.u64.u32 %off, %r_otid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %other_val, [%saddr];
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %saddr, %sbase, %off;
    ld.shared.f32 %dot, [%saddr];
    add.f32 %dot, %dot, %other_val;
    st.shared.f32 [%saddr], %dot;
RNB_DRS:
    bar.sync 0;
    bra RNB_DR;
RNB_DRD:
    ld.shared.f32 %dot, [%sbase];
    // coeff = dot * inv_rms3 / n
    mul.f32 %coeff, %dot, %inv_rms3;
    div.approx.f32 %coeff, %coeff, %n_f;
    bar.sync 0;

    // ===== Phase 3: Compute grad_input =====
    // grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
    mov.u32 %j, %r_tid;
RNB_GI:
    setp.ge.u32 %lp, %j, %cols_reg;
    @%lp bra RNB_GID;
    // Reload input
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %in, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %val, [%addr];
    // Reload grad_output and weight
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %go, %off;
    add.u64 %addr, %addr, %row_off;
    ld.global.f32 %gov, [%addr];
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %w, %off;
    ld.global.f32 %wv, [%addr];
    // result = inv_rms * w * go - x * coeff
    mul.f32 %result, %inv_rms, %wv;
    mul.f32 %result, %result, %gov;
    mul.f32 %tmp, %val, %coeff;
    sub.f32 %result, %result, %tmp;
    // Store grad_input[row, j]
    cvt.u64.u32 %off, %j;
    shl.b64 %off, %off, 2;
    add.u64 %addr, %gi, %off;
    add.u64 %addr, %addr, %row_off;
    st.global.f32 [%addr], %result;
    add.u32 %j, %j, %r_bdim;
    bra RNB_GI;
RNB_GID:

RNB_DONE:
    ret;
}
";
6612
6613
6614#[cfg(feature = "cuda")]
/// PTX for BatchNorm2d forward (f32, NCHW). One thread block per channel
/// (`%ctaid.x` is the channel index); the block's threads stride over the
/// `total_per_ch` (= batch * spatial) elements of that channel.
///
/// Pass 1 reduces sum and sum-of-squares in shared memory (the tree
/// reduction halves `%ntid.x` each step, so blockDim.x is presumably a
/// power of two <= 256 — TODO confirm the launch config). Thread 0 derives
/// mean and var = E[x^2] - mean^2. In training mode (`training != 0`) it
/// folds them into the running stats with `momentum` and normalizes with the
/// batch stats; otherwise the running mean/var are used for normalization.
/// `save_mean`/`save_invstd` always receive the statistics actually used —
/// NOTE(review): confirm the host only consumes these after training-mode
/// launches. Pass 2 writes out = (x - mean) * invstd * gamma + beta.
///
/// This replaces a known-broken draft: the old Pass 1 clobbered its
/// grid-stride counter, variance retained a spurious `+ mean^2` term,
/// shared-memory operands used `[symbol + register]` addressing (not valid
/// PTX), and Pass 2 plus the running-stat update were missing entirely.
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

// Shared memory for block reduction
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];

.visible .entry batchnorm_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u64 weight_ptr,
    .param .u64 bias_ptr,
    .param .u64 rmean_ptr,
    .param .u64 rvar_ptr,
    .param .u64 save_mean_ptr,
    .param .u64 save_invstd_ptr,
    .param .u32 channels,
    .param .u32 spatial,
    .param .f32 eps,
    .param .f32 momentum,
    .param .u32 total_per_ch,
    .param .u32 training
) {
    .reg .u32 %tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %idx, %train;
    .reg .u32 %lin, %sidx, %half;
    .reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
    .reg .u64 %ssum, %ssq, %addr_o, %addr_t;
    .reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
    .reg .f32 %gamma, %beta, %eps_reg, %mom, %other, %run_v, %one_m_mom;
    .reg .f32 %n_f, %one, %normalized;
    .reg .pred %p, %ptrain, %ptid0;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u64 %w, [weight_ptr];
    ld.param.u64 %b, [bias_ptr];
    ld.param.u64 %rm, [rmean_ptr];
    ld.param.u64 %rv, [rvar_ptr];
    ld.param.u64 %sm, [save_mean_ptr];
    ld.param.u64 %si, [save_invstd_ptr];
    ld.param.u32 %n_ch, [channels];
    ld.param.u32 %sp, [spatial];
    ld.param.f32 %eps_reg, [eps];
    ld.param.f32 %mom, [momentum];
    ld.param.u32 %tpc, [total_per_ch];
    ld.param.u32 %train, [training];

    // Shared arrays must be addressed through registers ([sym + reg] is not
    // a legal PTX address expression).
    mov.u64 %ssum, smem_sum;
    mov.u64 %ssq, smem_sq;

    mov.u32 %bid, %ctaid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %ch, %bid;
    mov.f32 %one, 0f3F800000;

    setp.ge.u32 %p, %ch, %n_ch;
    @%p bra END;

    setp.ne.u32 %ptrain, %train, 0;

    // ---- Pass 1: per-thread partial sum and sum-of-squares ----
    mov.f32 %sum, 0f00000000;
    mov.f32 %sqsum, 0f00000000;
    mov.u32 %idx, %tid;
PASS1_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra PASS1_DONE;
    // lin = (idx / spatial) * channels * spatial + ch * spatial + idx % spatial
    // (%idx itself is preserved so the stride below stays correct)
    div.u32 %lin, %idx, %sp;
    mul.lo.u32 %lin, %lin, %n_ch;
    add.u32 %lin, %lin, %ch;
    mul.lo.u32 %lin, %lin, %sp;
    rem.u32 %sidx, %idx, %sp;
    add.u32 %lin, %lin, %sidx;
    cvt.u64.u32 %off64, %lin;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %val, [%tmp64];
    add.f32 %sum, %sum, %val;
    fma.rn.f32 %sqsum, %val, %val, %sqsum;
    add.u32 %idx, %idx, %bdim;
    bra PASS1_LOOP;
PASS1_DONE:
    // Store partials to shared memory for the block reduction
    cvt.u64.u32 %off64, %tid;
    shl.b64 %off64, %off64, 2;
    add.u64 %addr_t, %ssum, %off64;
    st.shared.f32 [%addr_t], %sum;
    add.u64 %addr_t, %ssq, %off64;
    st.shared.f32 [%addr_t], %sqsum;
    bar.sync 0;

    // Tree reduction over blockDim.x partials (assumes power-of-two bdim)
    mov.u32 %half, %bdim;
REDUCE_LOOP:
    shr.u32 %half, %half, 1;
    setp.eq.u32 %p, %half, 0;
    @%p bra REDUCE_DONE;
    setp.ge.u32 %p, %tid, %half;
    @%p bra REDUCE_SKIP;
    add.u32 %idx, %tid, %half;
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    cvt.u64.u32 %tmp64, %tid;
    shl.b64 %tmp64, %tmp64, 2;
    add.u64 %addr_o, %ssum, %off64;
    ld.shared.f32 %other, [%addr_o];
    add.u64 %addr_t, %ssum, %tmp64;
    ld.shared.f32 %sum, [%addr_t];
    add.f32 %sum, %sum, %other;
    st.shared.f32 [%addr_t], %sum;
    add.u64 %addr_o, %ssq, %off64;
    ld.shared.f32 %other, [%addr_o];
    add.u64 %addr_t, %ssq, %tmp64;
    ld.shared.f32 %sqsum, [%addr_t];
    add.f32 %sqsum, %sqsum, %other;
    st.shared.f32 [%addr_t], %sqsum;
REDUCE_SKIP:
    bar.sync 0;
    bra REDUCE_LOOP;
REDUCE_DONE:
    // Thread 0 computes the statistics used for normalization
    setp.ne.u32 %ptid0, %tid, 0;
    @%ptid0 bra WAIT_STATS;

    ld.shared.f32 %sum, [%ssum];
    ld.shared.f32 %sqsum, [%ssq];
    cvt.rn.f32.u32 %n_f, %tpc;
    div.rn.f32 %mean, %sum, %n_f;
    // var = E[x^2] - mean^2
    div.rn.f32 %var, %sqsum, %n_f;
    neg.f32 %other, %mean;
    fma.rn.f32 %var, %other, %mean, %var;

    @!%ptrain bra USE_RUNNING;
    // Training: running = (1 - momentum) * running + momentum * batch_stat
    sub.f32 %one_m_mom, %one, %mom;
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %rm, %off64;
    ld.global.f32 %run_v, [%tmp64];
    mul.f32 %run_v, %run_v, %one_m_mom;
    fma.rn.f32 %run_v, %mom, %mean, %run_v;
    st.global.f32 [%tmp64], %run_v;
    add.u64 %tmp64, %rv, %off64;
    ld.global.f32 %run_v, [%tmp64];
    mul.f32 %run_v, %run_v, %one_m_mom;
    fma.rn.f32 %run_v, %mom, %var, %run_v;
    st.global.f32 [%tmp64], %run_v;
    bra STATS_READY;
USE_RUNNING:
    // Inference: normalize with the stored running statistics
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %rm, %off64;
    ld.global.f32 %mean, [%tmp64];
    add.u64 %tmp64, %rv, %off64;
    ld.global.f32 %var, [%tmp64];
STATS_READY:
    // invstd = 1 / sqrt(var + eps)
    add.f32 %other, %var, %eps_reg;
    sqrt.rn.f32 %other, %other;
    div.rn.f32 %invstd, %one, %other;

    // Save the statistics used (consumed by the backward pass)
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %sm, %off64;
    st.global.f32 [%tmp64], %mean;
    add.u64 %tmp64, %si, %off64;
    st.global.f32 [%tmp64], %invstd;

    // Broadcast mean/invstd to the rest of the block via shared memory
    st.shared.f32 [%ssum], %mean;
    st.shared.f32 [%ssq], %invstd;

WAIT_STATS:
    bar.sync 0;
    ld.shared.f32 %mean, [%ssum];
    ld.shared.f32 %invstd, [%ssq];

    // Per-channel affine parameters
    cvt.u64.u32 %off64, %ch;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %w, %off64;
    ld.global.f32 %gamma, [%tmp64];
    add.u64 %tmp64, %b, %off64;
    ld.global.f32 %beta, [%tmp64];

    // ---- Pass 2: out = (x - mean) * invstd * gamma + beta ----
    // Same indexing as Pass 1, same stride.
    mov.u32 %idx, %tid;
PASS2_LOOP:
    setp.ge.u32 %p, %idx, %tpc;
    @%p bra END;
    div.u32 %lin, %idx, %sp;
    mul.lo.u32 %lin, %lin, %n_ch;
    add.u32 %lin, %lin, %ch;
    mul.lo.u32 %lin, %lin, %sp;
    rem.u32 %sidx, %idx, %sp;
    add.u32 %lin, %lin, %sidx;
    cvt.u64.u32 %off64, %lin;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %val, [%tmp64];
    sub.f32 %normalized, %val, %mean;
    mul.f32 %normalized, %normalized, %invstd;
    fma.rn.f32 %normalized, %gamma, %normalized, %beta;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %normalized;
    add.u32 %idx, %idx, %bdim;
    bra PASS2_LOOP;

END:
    ret;
}
";
6832
6833
6834#[cfg(feature = "cuda")]
/// PTX for MaxPool2d forward (f32, NCHW). Grid-stride loop over all `total`
/// output elements; each thread decomposes its flat index into (b, c, oh,
/// ow), scans its kh x kw window, skips out-of-bounds (padded) taps, and
/// writes the window maximum (initialized to -inf, 0fFF800000).
///
/// `ih = oh*sh + i - ph` is computed in unsigned arithmetic, so a negative
/// coordinate wraps to a huge value and the single `>= h_in` / `>= w_in`
/// test rejects both under- and overflow.
///
/// Fix vs. the previous revision: removed a dead `div.u32 %b_idx, %rem,
/// %ch_reg` before the right-to-left decomposition (its result was
/// unconditionally overwritten) and the unused `%p_gt` predicate.
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry maxpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %max_val, %cur_val, %neg_inf;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

    // -inf for max initialization
    mov.f32 %neg_inf, 0fFF800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow):
    // idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow,
    // so peel the components off from the right.
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %max_val, %neg_inf;

    // Slide the kernel window
    mov.u32 %i, 0;
KH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra KH_DONE;

    mov.u32 %j, 0;
KW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra KW_DONE;

    // ih = oh * sh + i - ph, iw = ow * sw + j - pw
    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    // Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
    // Since unsigned, just check < h_in and < w_in
    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra KW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra KW_NEXT;

    // input_offset = b * C * H * W + c * H * W + ih * W + iw
    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    max.f32 %max_val, %max_val, %cur_val;

KW_NEXT:
    add.u32 %j, %j, 1;
    bra KW_LOOP;

KW_DONE:
    add.u32 %i, %i, 1;
    bra KH_LOOP;

KH_DONE:
    // Store output
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %max_val;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
6975
6976
6977#[cfg(feature = "cuda")]
/// PTX for AvgPool2d forward (f32, NCHW). Grid-stride loop over all `total`
/// output elements; each thread decomposes its flat index into (b, c, oh,
/// ow), sums the in-bounds taps of its kh x kw window, and divides by the
/// number of taps actually summed — i.e. `count_include_pad = false`
/// semantics (padded positions neither contribute nor count).
///
/// As in the MaxPool2d kernel, `ih`/`iw` are computed in unsigned
/// arithmetic so padding underflow wraps and a single `>=` test rejects it.
///
/// NOTE(review): a window lying entirely in padding would make `count` 0 and
/// divide by zero; presumably the host guarantees padding < kernel size so
/// every window covers at least one input element — confirm at the call
/// site.
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry avgpool2d_forward_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 batch,
    .param .u32 channels,
    .param .u32 h_in,
    .param .u32 w_in,
    .param .u32 h_out,
    .param .u32 w_out,
    .param .u32 kh,
    .param .u32 kw,
    .param .u32 sh,
    .param .u32 sw,
    .param .u32 ph,
    .param .u32 pw,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
    .reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
    .reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
    .reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
    .reg .u32 %batch_reg, %ch_reg;
    .reg .u64 %in, %out, %off64, %tmp64;
    .reg .f32 %sum_val, %cur_val, %count_f, %avg;
    .reg .pred %p, %p_bounds;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %batch_reg, [batch];
    ld.param.u32 %ch_reg, [channels];
    ld.param.u32 %h_in_reg, [h_in];
    ld.param.u32 %w_in_reg, [w_in];
    ld.param.u32 %h_out_reg, [h_out];
    ld.param.u32 %w_out_reg, [w_out];
    ld.param.u32 %kh_reg, [kh];
    ld.param.u32 %kw_reg, [kw];
    ld.param.u32 %sh_reg, [sh];
    ld.param.u32 %sw_reg, [sw];
    ld.param.u32 %ph_reg, [ph];
    ld.param.u32 %pw_reg, [pw];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // Decompose idx into (b, c, oh, ow) — same as MaxPool2d
    mov.u32 %rem, %idx;
    rem.u32 %ow, %rem, %w_out_reg;
    div.u32 %rem, %rem, %w_out_reg;
    rem.u32 %oh, %rem, %h_out_reg;
    div.u32 %rem, %rem, %h_out_reg;
    rem.u32 %c_idx, %rem, %ch_reg;
    div.u32 %b_idx, %rem, %ch_reg;

    mov.f32 %sum_val, 0f00000000;
    mov.u32 %count, 0;

    mov.u32 %i, 0;
AKH_LOOP:
    setp.ge.u32 %p, %i, %kh_reg;
    @%p bra AKH_DONE;

    mov.u32 %j, 0;
AKW_LOOP:
    setp.ge.u32 %p, %j, %kw_reg;
    @%p bra AKW_DONE;

    mad.lo.u32 %ih, %oh, %sh_reg, %i;
    sub.u32 %ih, %ih, %ph_reg;
    mad.lo.u32 %iw, %ow, %sw_reg, %j;
    sub.u32 %iw, %iw, %pw_reg;

    setp.ge.u32 %p_bounds, %ih, %h_in_reg;
    @%p_bounds bra AKW_NEXT;
    setp.ge.u32 %p_bounds, %iw, %w_in_reg;
    @%p_bounds bra AKW_NEXT;

    mul.lo.u32 %tmp, %b_idx, %ch_reg;
    add.u32 %tmp, %tmp, %c_idx;
    mul.lo.u32 %tmp, %tmp, %h_in_reg;
    add.u32 %tmp, %tmp, %ih;
    mul.lo.u32 %tmp, %tmp, %w_in_reg;
    add.u32 %tmp, %tmp, %iw;

    cvt.u64.u32 %off64, %tmp;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %in, %off64;
    ld.global.f32 %cur_val, [%tmp64];

    add.f32 %sum_val, %sum_val, %cur_val;
    add.u32 %count, %count, 1;

AKW_NEXT:
    add.u32 %j, %j, 1;
    bra AKW_LOOP;

AKW_DONE:
    add.u32 %i, %i, 1;
    bra AKH_LOOP;

AKH_DONE:
    // avg = sum / count (count_include_pad = false behavior)
    cvt.rn.f32.u32 %count_f, %count;
    div.rn.f32 %avg, %sum_val, %count_f;

    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %out, %off64;
    st.global.f32 [%tmp64], %avg;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
7112
7113
/// PTX for row-wise softmax (f32).
///
/// `softmax_kernel` assigns one thread block per row (`ctaid.x` = row index).
/// Each thread strides over the row's columns to compute a partial max, the
/// partials are combined with a shared-memory tree reduction (`sdata[256]`,
/// halving `%half` from `ntid.x` with `bar.sync` between rounds), then
/// exp(x - max) is computed as `ex2.approx(x * log2(e))` (0f3FB8AA3B =
/// log2(e)), summed with a second identical reduction, and finally each
/// stored exponential is scaled by `rcp.approx(sum)`.
///
/// NOTE(review): the halving reduction assumes a power-of-two block size
/// (lanes >= 2*half are never folded otherwise) — confirm at the launch site.
///
/// Fixes vs. the previous revision (PTX semantics unchanged — PTX is
/// whitespace-insensitive and the removed instructions were no-ops):
/// - two lines were missing the `\n\` escape/continuation, embedding a raw
///   newline plus Rust-source indentation into the literal;
/// - both reduction loops recomputed `%saddr = %sbase + %off` immediately
///   before the store even though neither operand had changed since the
///   identical `add.u64` two instructions earlier (the f64 variant already
///   omits this dead instruction).
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f32 %other_val;\n\
    .reg .pred %reduce_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 2;\n\
\n\
    mov.f32 %max_val, 0fFF800000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    max.f32 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %max_val, [%saddr];\n\
    max.f32 %max_val, %max_val, %other_val;\n\
    st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f32 %sum_val, 0f00000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    sub.f32 %val, %val, %max_val;\n\
    mul.f32 %val, %val, 0f3FB8AA3B;\n\
    ex2.approx.f32 %exp_val, %val;\n\
    add.f32 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f32 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f32 %sum_val, [%saddr];\n\
    add.f32 %sum_val, %sum_val, %other_val;\n\
    st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f32 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    rcp.approx.f32 %sum_val, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f32 %val, [%off];\n\
    mul.f32 %result, %val, %sum_val;\n\
    st.global.f32 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7278
/// PTX for row-wise softmax (f64). One thread block per row; per-thread
/// strided max/sum passes combined with shared-memory tree reductions, as in
/// the f32 `SOFTMAX_PTX`.
///
/// Since PTX has no f64 `ex2.approx`, exp(r) is computed inline:
/// Cody–Waite reduction `n = round(x * log2(e))` (0d3FF71547652B82FE),
/// `r = x - n*ln2` via a hi/lo split (0dBFE62E42FEFA3800 / 0dBD2EF35793C76730
/// are -ln2_hi / -ln2_lo), a Horner Taylor polynomial with coefficients
/// 1/12! .. 1/2!, and a final scale by 2^n built directly in the exponent
/// bits (`n + 1023` shifted into bits 62..52).
///
/// NOTE(review): the 2^n scaling has no underflow clamp — for
/// `x - max` below roughly -708 the biased exponent `n + 1023` goes negative
/// and the shift produces garbage bits rather than 0; confirm callers keep
/// inputs in range or add a clamp on `%e_ni`.
///
/// Bug fixes vs. the previous revision:
/// - the 1/5! coefficient was written `0d3F811111111111111` (17 hex digits;
///   PTX `0d` literals require exactly 16) — corrected to
///   `0d3F81111111111111` (= 1/120);
/// - the Horner chain skipped the 1/4! term entirely (it jumped from 1/5!
///   straight to 1/3! = 0d3FC5555555555555), introducing an error of order
///   r^4/24 — the missing `fma` with 0d3FA5555555555555 (= 1/24) is restored.
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry softmax_f64_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 rows,\n\
    .param .u32 cols\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
    .reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
    .reg .f64 %val, %max_val, %sum_val, %exp_val, %result, %one;\n\
    .reg .pred %p, %loop_p;\n\
    .reg .u32 %half, %other_tid;\n\
    .reg .f64 %other_val;\n\
    .reg .pred %reduce_p;\n\
    .reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
    .reg .s32 %e_ni;\n\
    .reg .s64 %e_ni64, %e_bits;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %rows_reg, [rows];\n\
    ld.param.u32 %cols_reg, [cols];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mov.u64 %sbase, sdata;\n\
    mov.f64 %one, 0d3FF0000000000000;\n\
\n\
    setp.ge.u32 %p, %bid, %rows_reg;\n\
    @%p bra DONE;\n\
\n\
    cvt.u64.u32 %row_off, %bid;\n\
    cvt.u64.u32 %off, %cols_reg;\n\
    mul.lo.u64 %row_off, %row_off, %off;\n\
    shl.b64 %row_off, %row_off, 3;\n\
\n\
    mov.f64 %max_val, 0dFFF0000000000000;\n\
    mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra FIND_MAX_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    max.f64 %max_val, %max_val, %val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %max_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra MAX_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra MAX_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %max_val, [%saddr];\n\
    max.f64 %max_val, %max_val, %other_val;\n\
    st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %max_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    mov.f64 %sum_val, 0d0000000000000000;\n\
    mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra SUM_EXP_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %in, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    sub.f64 %val, %val, %max_val;\n\
    mov.f64 %e_one, 0d3FF0000000000000;\n\
    mov.f64 %e_half, 0d3FE0000000000000;\n\
    mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
    cvt.rni.f64.f64 %e_nf, %e_nf;\n\
    cvt.rni.s32.f64 %e_ni, %e_nf;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
    mov.f64 %e_p, 0d3E21EED8EFF8D898;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
    fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
    fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
    cvt.s64.s32 %e_ni64, %e_ni;\n\
    add.s64 %e_ni64, %e_ni64, 1023;\n\
    shl.b64 %e_bits, %e_ni64, 52;\n\
    mov.b64 %e_nf, %e_bits;\n\
    mul.f64 %exp_val, %exp_val, %e_nf;\n\
    add.f64 %sum_val, %sum_val, %exp_val;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    st.global.f64 [%off], %exp_val;\n\
    add.u32 %j, %j, %bdim;\n\
    bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
    bar.sync 0;\n\
\n\
    mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
    shr.u32 %half, %half, 1;\n\
    setp.eq.u32 %reduce_p, %half, 0;\n\
    @%reduce_p bra SUM_REDUCE_DONE;\n\
    setp.ge.u32 %reduce_p, %r_tid, %half;\n\
    @%reduce_p bra SUM_REDUCE_SKIP;\n\
    add.u32 %other_tid, %r_tid, %half;\n\
    cvt.u64.u32 %off, %other_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %other_val, [%saddr];\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %saddr, %sbase, %off;\n\
    ld.shared.f64 %sum_val, [%saddr];\n\
    add.f64 %sum_val, %sum_val, %other_val;\n\
    st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
    bar.sync 0;\n\
    bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
    ld.shared.f64 %sum_val, [sdata];\n\
    bar.sync 0;\n\
\n\
    div.rn.f64 %sum_val, %one, %sum_val;\n\
    mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
    setp.ge.u32 %loop_p, %j, %cols_reg;\n\
    @%loop_p bra NORMALIZE_DONE;\n\
    cvt.u64.u32 %off, %j;\n\
    shl.b64 %off, %off, 3;\n\
    add.u64 %off, %out, %off;\n\
    add.u64 %off, %off, %row_off;\n\
    ld.global.f64 %val, [%off];\n\
    mul.f64 %result, %val, %sum_val;\n\
    st.global.f64 [%off], %result;\n\
    add.u32 %j, %j, %bdim;\n\
    bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7468
/// PTX for elementwise dropout (f32), one thread per element.
///
/// `dropout_kernel` derives a per-element pseudo-random u32 by hashing the
/// thread index (multiply by 2654435761, Knuth's multiplicative constant),
/// xor-ing in `seed`, then three xorshift rounds (13/17/5). The element is
/// zeroed when `rng < threshold` (unsigned compare), otherwise multiplied by
/// `scale`.
///
/// NOTE(review): `threshold` is presumably `drop_prob * 2^32` and `scale`
/// presumably `1 / (1 - drop_prob)` computed host-side — confirm at the
/// call site; the kernel itself only does the raw compare and multiply.
#[cfg(feature = "cuda")]
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
    .param .u64 input_ptr,\n\
    .param .u64 output_ptr,\n\
    .param .u32 n,\n\
    .param .u32 threshold,\n\
    .param .f32 scale,\n\
    .param .u32 seed\n\
) {\n\
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
    .reg .u64 %in, %out, %off;\n\
    .reg .f32 %val, %scale_reg, %zero;\n\
    .reg .pred %p, %drop_p;\n\
\n\
    ld.param.u64 %in, [input_ptr];\n\
    ld.param.u64 %out, [output_ptr];\n\
    ld.param.u32 %n_reg, [n];\n\
    ld.param.u32 %thresh, [threshold];\n\
    ld.param.f32 %scale_reg, [scale];\n\
    ld.param.u32 %seed_reg, [seed];\n\
\n\
    mov.u32 %bid, %ctaid.x;\n\
    mov.u32 %bdim, %ntid.x;\n\
    mov.u32 %r_tid, %tid.x;\n\
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
    setp.ge.u32 %p, %r_tid, %n_reg;\n\
    @%p bra DONE;\n\
\n\
    mul.lo.u32 %rng, %r_tid, 2654435761;\n\
    xor.b32 %rng, %rng, %seed_reg;\n\
    shl.b32 %tmp, %rng, 13;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shr.b32 %tmp, %rng, 17;\n\
    xor.b32 %rng, %rng, %tmp;\n\
    shl.b32 %tmp, %rng, 5;\n\
    xor.b32 %rng, %rng, %tmp;\n\
\n\
    cvt.u64.u32 %off, %r_tid;\n\
    shl.b64 %off, %off, 2;\n\
    add.u64 %in, %in, %off;\n\
    add.u64 %out, %out, %off;\n\
    ld.global.f32 %val, [%in];\n\
\n\
    setp.lo.u32 %drop_p, %rng, %thresh;\n\
    mov.f32 %zero, 0f00000000;\n\
    @%drop_p mov.f32 %val, %zero;\n\
    @!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
    st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
    ret;\n\
}\n\
";
7533
7534
/// PTX for broadcasting elementwise addition (f32), one thread per output
/// element.
///
/// `broadcast_add_kernel` decomposes the flat output index into N-d
/// coordinates against `out_shape` (last dimension first) and accumulates
/// per-input element indices via `a_strides` / `b_strides`; the host can
/// express a broadcast dimension by passing stride 0 for it — TODO confirm
/// against the launch code. Result is stored contiguously at `out[tid]`.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_add_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    // Global thread index.
    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // Decompose flat index into N-d coordinates and compute A/B indices.
    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;

LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;

    sub.u32 %d, %d, 1;

    // Byte offset for dimension d: d * 4.
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;

    // Load out_shape[d].
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];

    // Load a_strides[d] and b_strides[d].
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];

    // coord = remaining % shape_d; remaining /= shape_d.
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;

    // a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;

    bra LOOP;
END_LOOP:

    // Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];

    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    // Operation: add.
    add.f32 %vr, %va, %vb;

    // Store to out[tid].
    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;

DONE:
    ret;
}
";
7660
7661
/// PTX for broadcasting elementwise subtraction (f32): `out[tid] = a - b`.
/// Index decomposition is identical to `BROADCAST_ADD_PTX` — flat output
/// index split against `out_shape`, per-input indices accumulated from
/// `a_strides` / `b_strides`; only the arithmetic op (`sub.f32`) differs.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_sub_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    sub.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
7745
7746
/// PTX for broadcasting elementwise multiplication (f32):
/// `out[tid] = a * b`. Same stride-driven index decomposition as
/// `BROADCAST_ADD_PTX`; only the arithmetic op (`mul.f32`) differs.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_mul_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    mul.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
7830
7831
/// PTX for broadcasting elementwise division (f32): `out[tid] = a / b`.
/// Same stride-driven index decomposition as `BROADCAST_ADD_PTX`.
///
/// NOTE(review): this uses plain `div.f32` (no rounding/approx modifier),
/// a legacy PTX form, whereas the contiguous `DIV_PTX` kernel uses
/// `div.rn.f32` — consider aligning the two; confirm ptxas accepts the
/// unmodified form for the targeted PTX version.
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry broadcast_div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u64 a_strides_ptr,
    .param .u64 b_strides_ptr,
    .param .u64 out_shape_ptr,
    .param .u32 n,
    .param .u32 ndim
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
    .reg .u32 %remaining, %a_idx, %b_idx, %d;
    .reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
    .reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
    .reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p, %loop_p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u64 %a_str, [a_strides_ptr];
    ld.param.u64 %b_str, [b_strides_ptr];
    ld.param.u64 %oshape, [out_shape_ptr];
    ld.param.u32 %n_reg, [n];
    ld.param.u32 %ndim_reg, [ndim];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %remaining, %r_tid;
    mov.u32 %a_idx, 0;
    mov.u32 %b_idx, 0;
    mov.u32 %d, %ndim_reg;
LOOP:
    setp.eq.u32 %loop_p, %d, 0;
    @%loop_p bra END_LOOP;
    sub.u32 %d, %d, 1;
    cvt.u64.u32 %d64, %d;
    shl.b64 %d64, %d64, 2;
    add.u64 %tmp, %oshape, %d64;
    ld.global.u32 %shape_d, [%tmp];
    add.u64 %tmp, %a_str, %d64;
    ld.global.u32 %a_str_d, [%tmp];
    add.u64 %tmp, %b_str, %d64;
    ld.global.u32 %b_str_d, [%tmp];
    rem.u32 %coord, %remaining, %shape_d;
    div.u32 %remaining, %remaining, %shape_d;
    mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
    mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
    bra LOOP;
END_LOOP:

    cvt.u64.u32 %off_a, %a_idx;
    shl.b64 %off_a, %off_a, 2;
    add.u64 %off_a, %a, %off_a;
    ld.global.f32 %va, [%off_a];
    cvt.u64.u32 %off_b, %b_idx;
    shl.b64 %off_b, %off_b, 2;
    add.u64 %off_b, %b, %off_b;
    ld.global.f32 %vb, [%off_b];

    div.f32 %vr, %va, %vb;

    cvt.u64.u32 %off_out, %r_tid;
    shl.b64 %off_out, %off_out, 2;
    add.u64 %off_out, %out, %off_out;
    st.global.f32 [%off_out], %vr;
DONE:
    ret;
}
";
7916
7917
/// PTX for extracting one chunk of a split along an axis (f32, gather).
///
/// `strided_split_kernel` writes `n` contiguous output elements; each thread
/// maps its output index back to the source as
/// `src = outer * total_along_axis * inner_size
///        + split_offset * inner_size + within`,
/// where `outer = tid / (split_size * inner_size)` and `within` is the
/// remainder. I.e. it gathers the `[split_offset, split_offset + split_size)`
/// slice of the split axis for every outer index.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_split_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 split_offset,
    .param .u32 split_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %sp_off, [split_offset];
    ld.param.u32 %sp_sz, [split_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = split_size * inner_size
    mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = split_offset * inner_size
    mul.lo.u32 %base_off, %sp_off, %inner_sz;

    // src_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %src_idx, %outer_idx, %total_ax;
    mul.lo.u32 %src_idx, %src_idx, %inner_sz;
    add.u32 %src_idx, %src_idx, %base_off;
    add.u32 %src_idx, %src_idx, %within;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
7997
7998
/// PTX for inserting one part into a concatenation along an axis
/// (f32, scatter) — the exact inverse of `strided_split_kernel`.
///
/// `strided_cat_kernel` reads `n` contiguous input elements (one source
/// tensor of the cat) and scatters each to
/// `dst = outer * total_along_axis * inner_size
///        + cat_offset * inner_size + within`,
/// where `outer = tid / (part_size * inner_size)` and `within` is the
/// remainder. The host launches this once per input tensor with the
/// appropriate `cat_offset`.
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_cat_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 total_along_axis,
    .param .u32 cat_offset,
    .param .u32 part_size,
    .param .u32 inner_size,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
    .reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %total_ax, [total_along_axis];
    ld.param.u32 %cat_off, [cat_offset];
    ld.param.u32 %part_sz, [part_size];
    ld.param.u32 %inner_sz, [inner_size];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    // chunk_stride = part_size * inner_size
    mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;

    // outer_idx = r_tid / chunk_stride
    div.u32 %outer_idx, %r_tid, %chunk_stride;

    // within = r_tid % chunk_stride
    rem.u32 %within, %r_tid, %chunk_stride;

    // base_off = cat_offset * inner_size
    mul.lo.u32 %base_off, %cat_off, %inner_sz;

    // dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
    mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
    mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
    add.u32 %dst_idx, %dst_idx, %base_off;
    add.u32 %dst_idx, %dst_idx, %within;

    // Load from in[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[dst_idx]
    cvt.u64.u32 %off, %dst_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
8079
8080
/// PTX for gathering a strided view into a contiguous buffer
/// (f32, up to 8 dimensions, fully unrolled).
///
/// `strided_copy_kernel` converts each contiguous output index into per-dim
/// coordinates by dividing by `os0..os7` and accumulates the source element
/// index from `ss0..ss7`, starting at `src_offset_base`; then copies
/// `in[src_idx]` to `out[tid]`.
///
/// NOTE(review): `osK` must be the row-major (contiguous) stride of output
/// dim K — the coordinate is recovered as `coord = flat / osK;
/// flat -= coord * osK` — and unused trailing dims are presumably padded by
/// the host with os=1 / ss=0 so their stages are no-ops; confirm at the
/// launch site (os=0 would fault on the `div.u32`).
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_COPY_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry strided_copy_kernel(
    .param .u64 input_ptr,
    .param .u64 output_ptr,
    .param .u32 src_offset_base,
    .param .u32 n,
    .param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
    .param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
    .param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
    .param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
    .reg .u64 %in, %out, %off;
    .reg .f32 %val;
    .reg .pred %p;

    ld.param.u64 %in, [input_ptr];
    ld.param.u64 %out, [output_ptr];
    ld.param.u32 %src_idx, [src_offset_base];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    mov.u32 %flat, %r_tid;

    // Dim 0
    ld.param.u32 %os, [os0];
    ld.param.u32 %ss, [ss0];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 1
    ld.param.u32 %os, [os1];
    ld.param.u32 %ss, [ss1];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 2
    ld.param.u32 %os, [os2];
    ld.param.u32 %ss, [ss2];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 3
    ld.param.u32 %os, [os3];
    ld.param.u32 %ss, [ss3];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 4
    ld.param.u32 %os, [os4];
    ld.param.u32 %ss, [ss4];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 5
    ld.param.u32 %os, [os5];
    ld.param.u32 %ss, [ss5];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 6
    ld.param.u32 %os, [os6];
    ld.param.u32 %ss, [ss6];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Dim 7
    ld.param.u32 %os, [os7];
    ld.param.u32 %ss, [ss7];
    div.u32 %coord, %flat, %os;
    mul.lo.u32 %tmp, %coord, %os;
    sub.u32 %flat, %flat, %tmp;
    mul.lo.u32 %tmp, %coord, %ss;
    add.u32 %src_idx, %src_idx, %tmp;

    // Load from in[src_idx]
    cvt.u64.u32 %off, %src_idx;
    shl.b64 %off, %off, 2;
    add.u64 %off, %in, %off;
    ld.global.f32 %val, [%off];

    // Store to out[r_tid]
    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;
    add.u64 %off, %out, %off;
    st.global.f32 [%off], %val;

DONE:
    ret;
}
";
8225
8226
8227#[cfg(feature = "cuda")]
/// PTX source for `div_kernel`: elementwise f32 division,
/// `out[i] = a[i] / b[i]` (IEEE round-to-nearest via `div.rn.f32`).
/// One thread per element; threads with global id >= `n` branch to `DONE`.
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry div_kernel(
    .param .u64 a_ptr,
    .param .u64 b_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %b, %out, %off;
    .reg .f32 %va, %vb, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %b, [b_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %b, %b, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    ld.global.f32 %vb, [%b];
    div.rn.f32 %vr, %va, %vb;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8274
8275
8276#[cfg(feature = "cuda")]
/// PTX source for `exp_kernel`: elementwise f32 exponential.
/// Uses the hardware `ex2.approx` (2^x) unit with the identity
/// exp(x) = 2^(x * log2(e)); accuracy is therefore that of the
/// approximate special-function unit, not a correctly-rounded exp.
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
    // log2(e) = 1.4426950408889634
    mul.f32 %va, %va, 0f3FB8AA3B;
    ex2.approx.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8322
8323#[cfg(feature = "cuda")]
/// PTX source for `exp_f64_kernel`: elementwise f64 exponential.
///
/// exp(x) is computed by argument reduction: n = round(x*log2(e)),
/// r = x - n*ln(2) (Cody-Waite split so |r| <= ln(2)/2 stays accurate),
/// exp(r) via a degree-12 Taylor/Horner polynomial, then scaled by 2^n
/// built directly as an IEEE-754 bit pattern.
///
/// Fixes over the previous revision:
/// - `%ln2_hi`/`%ln2_lo` were written by `mov.f64` without ever being
///   declared in a `.reg` directive (a PTX JIT error); the movs were also
///   dead because the reduction fmas use inline literals, so they are removed.
/// - `0d3F811111111111111` had 17 hex digits (f64 literals need exactly 16).
/// - The Horner chain was missing the 1/4! coefficient.
///
/// NOTE(review): n is clamped nowhere, so inputs with |x| large enough to
/// overflow the exponent-bit construction produce garbage rather than
/// inf/0 — confirm callers only pass moderate magnitudes.
pub(crate) const EXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry exp_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %x, %vr;
    .reg .f64 %log2e, %nf, %r;
    .reg .f64 %p, %one, %half;
    .reg .s32 %ni;
    .reg .s64 %ni64, %exp_bits;
    .reg .pred %p_tid;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];

    // Constants
    mov.f64 %log2e, 0d3FF71547652B82FE; // log2(e) = 1.4426950408889634
    mov.f64 %one, 0d3FF0000000000000;   // 1.0
    mov.f64 %half, 0d3FE0000000000000;  // 0.5

    // n = round(x * log2(e))
    mul.f64 %nf, %x, %log2e;
    cvt.rni.f64.f64 %nf, %nf; // round to nearest integer
    cvt.rni.s32.f64 %ni, %nf; // integer n

    // r = x - n * ln2 (Cody-Waite two-step for precision)
    fma.rn.f64 %r, %nf, 0dBFE62E42FEFA3800, %x; // r = x - n*ln2_hi
    fma.rn.f64 %r, %nf, 0dBD2EF35793C76730, %r; // r -= n*ln2_lo

    // Horner polynomial: p accumulates 1/12! down to 1/2!, so that
    // exp(r) = 1 + r*(1 + r*p)
    mov.f64 %p, 0d3E21EED8EFF8D898;            // 1/12! = 2.088e-9
    fma.rn.f64 %p, %p, %r, 0d3E5AE64567F544E4; // 1/11! = 2.505e-8
    fma.rn.f64 %p, %p, %r, 0d3E927E4FB7789F5C; // 1/10! = 2.756e-7
    fma.rn.f64 %p, %p, %r, 0d3EC71DE3A556C734; // 1/9!  = 2.756e-6
    fma.rn.f64 %p, %p, %r, 0d3EFA01A01A01A01A; // 1/8!  = 2.480e-5
    fma.rn.f64 %p, %p, %r, 0d3F2A01A01A01A01A; // 1/7!  = 1.984e-4
    fma.rn.f64 %p, %p, %r, 0d3F56C16C16C16C17; // 1/6!  = 1.389e-3
    fma.rn.f64 %p, %p, %r, 0d3F81111111111111; // 1/5!  = 8.333e-3
    fma.rn.f64 %p, %p, %r, 0d3FA5555555555555; // 1/4!  = 4.167e-2
    fma.rn.f64 %p, %p, %r, 0d3FC5555555555555; // 1/3!  = 1.667e-1
    fma.rn.f64 %p, %p, %r, %half;              // 1/2!  = 5.000e-1

    // exp(r) = 1 + r + r^2 * p => 1 + r*(1 + r*p)
    fma.rn.f64 %p, %p, %r, %one;  // p = r*p + 1
    fma.rn.f64 %vr, %p, %r, %one; // vr = p*r + 1 = exp(r)

    // Scale by 2^n: multiply by constructing the f64 bit pattern for 2^n.
    // IEEE 754 f64: 2^n has exponent field = n + 1023, no mantissa bits.
    // Bit pattern: (n + 1023) << 52.
    cvt.s64.s32 %ni64, %ni;
    add.s64 %ni64, %ni64, 1023;
    shl.b64 %exp_bits, %ni64, 52;
    mov.b64 %nf, %exp_bits; // reinterpret as f64 = 2^n
    mul.f64 %vr, %vr, %nf;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8421
8422#[cfg(feature = "cuda")]
/// PTX source for `log_kernel`: elementwise f32 natural logarithm.
/// Uses the hardware `lg2.approx` (log2) unit with ln(x) = log2(x) * ln(2);
/// accuracy is that of the approximate special-function unit.
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
    // 1/log2(e) = ln(2) = 0.6931471805599453
    lg2.approx.f32 %vr, %va;
    mul.f32 %vr, %vr, 0f3F317218;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8468
8469#[cfg(feature = "cuda")]
/// PTX source for `log_f64_kernel`: elementwise f64 natural logarithm.
///
/// Decomposes x = 2^n * m with m in [1, 2) by bit manipulation of the IEEE-754
/// representation, evaluates ln(m) with the atanh-style series
/// ln(m) = 2f*(1 + f^2/3 + f^4/5 + ...) where f = (m-1)/(m+1),
/// and recombines as ln(x) = n*ln(2) + ln(m).
///
/// NOTE(review): `%two`, `%exp_i` and `%ln2_lo` are declared/initialised but
/// never used; harmless. There is no special-casing visible for x <= 0,
/// NaN, or subnormal inputs — confirm callers guarantee positive normals.
pub(crate) const LOG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry log_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .u64 %xbits, %mantissa_bits, %bias_bits;
    .reg .f64 %x, %vr, %m, %f, %f2, %s, %p;
    .reg .f64 %ln2_hi, %ln2_lo, %one, %two;
    .reg .s32 %exp_i;
    .reg .s64 %exp64;
    .reg .f64 %nf;
    .reg .pred %p_tid;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_tid, %r_tid, %n_reg;
    @%p_tid bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;
    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %x, [%a];

    mov.f64 %ln2_hi, 0d3FE62E42FEFA39EF; // ln(2) = 0.6931471805599453
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;

    // Extract exponent: n = exponent_field - 1023
    mov.b64 %xbits, %x;
    shr.u64 %exp64, %xbits, 52;
    and.b64 %exp64, %exp64, 2047; // 11-bit exponent field
    sub.s64 %exp64, %exp64, 1023;
    cvt.rn.f64.s64 %nf, %exp64; // n as f64

    // Extract mantissa m: set exponent to 1023 (so m is in [1, 2))
    mov.u64 %bias_bits, 0x3FF0000000000000; // exponent = 1023
    and.b64 %mantissa_bits, %xbits, 0x000FFFFFFFFFFFFF; // mantissa bits
    or.b64 %mantissa_bits, %mantissa_bits, %bias_bits;
    mov.b64 %m, %mantissa_bits; // m in [1.0, 2.0)

    // f = (m - 1) / (m + 1) — maps [1,2) to [0, 1/3)
    sub.f64 %f, %m, %one;
    add.f64 %s, %m, %one;
    div.rn.f64 %f, %f, %s;

    // ln(m) = 2*f + 2*f^3/3 + 2*f^5/5 + 2*f^7/7 + 2*f^9/9 + 2*f^11/11
    // Horner: ln(m) = 2*f*(1 + f^2*(1/3 + f^2*(1/5 + f^2*(1/7 + f^2*(1/9 + f^2/11)))))
    mul.f64 %f2, %f, %f;

    // p = 1/11
    mov.f64 %p, 0d3FB745D1745D1746;
    // p = p*f2 + 1/9
    fma.rn.f64 %p, %p, %f2, 0d3FC1C71C71C71C72;
    // p = p*f2 + 1/7
    fma.rn.f64 %p, %p, %f2, 0d3FC2492492492492;
    // p = p*f2 + 1/5
    fma.rn.f64 %p, %p, %f2, 0d3FC999999999999A;
    // p = p*f2 + 1/3
    fma.rn.f64 %p, %p, %f2, 0d3FD5555555555555;
    // p = p*f2 + 1
    fma.rn.f64 %p, %p, %f2, %one;

    // ln(m) = 2*f*p
    mul.f64 %p, %p, %f;
    add.f64 %p, %p, %p; // * 2

    // ln(x) = n*ln(2) + ln(m)
    fma.rn.f64 %vr, %nf, %ln2_hi, %p;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8572
8573#[cfg(feature = "cuda")]
/// PTX source for `sqrt_kernel`: elementwise f32 square root using the
/// correctly-rounded `sqrt.rn.f32`; one thread per element.
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sqrt_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    sqrt.rn.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8616
8617
8618#[cfg(feature = "cuda")]
/// PTX source for `pow_kernel`: elementwise f32 power with a scalar exponent,
/// `out[i] = a[i]^e`, computed as 2^(e * log2(a[i])) with the approximate
/// lg2/ex2 units.
///
/// NOTE(review): lg2 of a non-positive base does not produce a finite value,
/// so negative inputs yield NaN even for integer exponents — confirm callers
/// only pass non-negative bases.
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f32 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %exp, %lg;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f32 %exp, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // x^e = 2^(e * log2(x))
    lg2.approx.f32 %lg, %va;
    mul.f32 %lg, %lg, %exp;
    ex2.approx.f32 %vr, %lg;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8667
8668#[cfg(feature = "cuda")]
/// PTX source for `pow_f64_kernel`: elementwise f64 power with a scalar
/// exponent, computed as exp(e * ln(x)).
///
/// ln(x) uses the same exponent/mantissa decomposition and atanh-style series
/// as `LOG_F64_PTX`; exp uses the same Cody-Waite reduction plus Horner
/// polynomial as `EXP_F64_PTX`.
///
/// Fixes over the previous revision:
/// - `0d3F811111111111111` had 17 hex digits (f64 literals need exactly 16).
/// - The exp Horner chain was missing the 1/3! coefficient.
/// - Removed the dead `%two` register.
///
/// NOTE(review): as with the f32 variant, non-positive bases go through
/// ln(x) and produce garbage/NaN — confirm callers only pass positive bases.
pub(crate) const POW_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry pow_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .f64 exponent,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %exp64, %one;
    // log registers
    .reg .u64 %l_xbits, %l_mbits, %l_bias;
    .reg .s64 %l_exp64;
    .reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf, %l_ln2, %l_lnx;
    // exp registers
    .reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.f64 %exp64, [exponent];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // === ln(va) via argument reduction ===
    // Decompose va = 2^n * m, m in [1,2), ln(va) = n*ln(2) + ln(m)
    mov.b64 %l_xbits, %va;
    shr.u64 %l_exp64, %l_xbits, 52;
    and.b64 %l_exp64, %l_exp64, 2047;
    sub.s64 %l_exp64, %l_exp64, 1023;
    cvt.rn.f64.s64 %l_nf, %l_exp64;

    mov.u64 %l_bias, 0x3FF0000000000000;
    and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
    or.b64 %l_mbits, %l_mbits, %l_bias;
    mov.b64 %l_m, %l_mbits;

    // f = (m-1)/(m+1)
    sub.f64 %l_f, %l_m, %one;
    add.f64 %l_s, %l_m, %one;
    div.rn.f64 %l_f, %l_f, %l_s;
    mul.f64 %l_f2, %l_f, %l_f;

    // Horner: p = 1/11 + f2*(1/9 + f2*(1/7 + f2*(1/5 + f2*(1/3 + f2*1))))
    mov.f64 %l_p, 0d3FB745D1745D1746;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC1C71C71C71C72;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;
    fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;
    fma.rn.f64 %l_p, %l_p, %l_f2, %one;

    // ln(m) = 2*f*p
    mul.f64 %l_p, %l_p, %l_f;
    add.f64 %l_p, %l_p, %l_p;

    // ln(x) = n*ln(2) + ln(m)
    mov.f64 %l_ln2, 0d3FE62E42FEFA39EF;
    fma.rn.f64 %l_lnx, %l_nf, %l_ln2, %l_p;

    // === exp(exponent * ln(x)) ===
    mul.f64 %e_z, %exp64, %l_lnx;

    // n = floor(z*log2(e) + 0.5); r = z - n*ln(2) (Cody-Waite split)
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    // Horner polynomial: coefficients 1/12! down to 1/2!, then 1/1! and 1
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4; // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C; // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734; // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A; // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A; // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17; // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111; // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555; // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555; // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;            // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;               // 1/1!
    fma.rn.f64 %vr, %e_p, %e_r, %one;                // exp(r)
    // Scale by 2^n via direct exponent-bit construction
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %vr, %vr, %e_nf;

    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8788
8789#[cfg(feature = "cuda")]
/// PTX source for `abs_kernel`: elementwise f32 absolute value via `abs.f32`;
/// one thread per element.
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry abs_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    abs.f32 %vr, %va;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8832
8833
8834#[cfg(feature = "cuda")]
/// PTX source for `sigmoid_kernel`: elementwise f32 logistic sigmoid,
/// sigmoid(x) = 1 / (1 + exp(-x)), with exp(-x) computed via the
/// approximate `ex2` unit (2^(-x * log2(e))).
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f32 %neg, %va;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg, %neg, %lg2e;
    ex2.approx.f32 %e, %neg;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %vr, %one, %denom;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
8884
8885#[cfg(feature = "cuda")]
/// PTX source for `sigmoid_f64_kernel`: elementwise f64 logistic sigmoid,
/// sigmoid(x) = 1 / (1 + exp(-x)). exp(-x) uses the same Cody-Waite
/// reduction + Horner polynomial + 2^n bit-construction as `EXP_F64_PTX`.
///
/// Fixes over the previous revision:
/// - `0d3F811111111111111` had 17 hex digits (f64 literals need exactly 16).
/// - The exp Horner chain was missing the 1/3! coefficient.
pub(crate) const SIGMOID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry sigmoid_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %denom, %one, %neg_x;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;

    // sigmoid(x) = 1 / (1 + exp(-x))
    neg.f64 %neg_x, %va;

    // --- exp(%neg_x) via Cody-Waite reduction + degree-12 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4; // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C; // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734; // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A; // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A; // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17; // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111; // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555; // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555; // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;            // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;               // 1/1!
    fma.rn.f64 %e64, %e_p, %e_r, %one;               // exp(r)
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e64, %e64, %e_nf;
    // --- end exp ---

    add.f64 %denom, %one, %e64;
    div.rn.f64 %vr, %one, %denom;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
8965
8966#[cfg(feature = "cuda")]
/// PTX source for `tanh_kernel`: elementwise f32 hyperbolic tangent via the
/// identity tanh(x) = 2*sigmoid(2x) - 1, with exp(-2x) computed on the
/// approximate `ex2` unit.
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f32 %va, [%a];
    // tanh(x) = 2*sigmoid(2x) - 1
    mov.f32 %two, 0f40000000;
    mul.f32 %neg2x, %va, %two;
    neg.f32 %neg2x, %neg2x;
    mov.f32 %lg2e, 0f3FB8AA3B;
    mul.f32 %neg2x, %neg2x, %lg2e;
    ex2.approx.f32 %e, %neg2x;
    mov.f32 %one, 0f3F800000;
    add.f32 %denom, %one, %e;
    div.rn.f32 %sig, %one, %denom;
    mul.f32 %vr, %two, %sig;
    sub.f32 %vr, %vr, %one;
    st.global.f32 [%out], %vr;

DONE:
    ret;
}
";
9021
9022#[cfg(feature = "cuda")]
/// PTX source for `tanh_f64_kernel`: elementwise f64 hyperbolic tangent via
/// tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x)). exp(-2x) uses the same
/// Cody-Waite reduction + Horner polynomial as `EXP_F64_PTX`.
///
/// Fixes over the previous revision:
/// - `0d3F811111111111111` had 17 hex digits (f64 literals need exactly 16).
/// - The exp Horner chain was missing the 1/3! coefficient.
pub(crate) const TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry tanh_f64_kernel(
    .param .u64 a_ptr,
    .param .u64 out_ptr,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %a, %out, %off;
    .reg .f64 %va, %vr, %e64, %num, %denom, %one, %two, %neg2x;
    .reg .f64 %e_nf, %e_r, %e_p, %e_half;
    .reg .s32 %e_ni;
    .reg .s64 %e_ni64, %e_bits;
    .reg .pred %p;

    ld.param.u64 %a, [a_ptr];
    ld.param.u64 %out, [out_ptr];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p, %r_tid, %n_reg;
    @%p bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 3;

    add.u64 %a, %a, %off;
    add.u64 %out, %out, %off;

    ld.global.f64 %va, [%a];
    mov.f64 %one, 0d3FF0000000000000;
    mov.f64 %two, 0d4000000000000000;

    // tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
    mul.f64 %neg2x, %va, %two;
    neg.f64 %neg2x, %neg2x;

    // --- exp(%neg2x) via Cody-Waite reduction + degree-12 Horner ---
    mov.f64 %e_half, 0d3FE0000000000000;
    fma.rn.f64 %e_nf, %neg2x, 0d3FF71547652B82FE, %e_half;
    cvt.rmi.f64.f64 %e_nf, %e_nf;
    cvt.rni.s32.f64 %e_ni, %e_nf;
    fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg2x;
    fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
    mov.f64 %e_p, 0d3E21EED8EFF8D898;                // 1/12!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E5AE64567F544E4; // 1/11!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C; // 1/10!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734; // 1/9!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A; // 1/8!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A; // 1/7!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17; // 1/6!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111; // 1/5!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555; // 1/4!
    fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555; // 1/3!
    fma.rn.f64 %e_p, %e_p, %e_r, %e_half;            // 1/2!
    fma.rn.f64 %e_p, %e_p, %e_r, %one;               // 1/1!
    fma.rn.f64 %e64, %e_p, %e_r, %one;               // exp(r)
    cvt.s64.s32 %e_ni64, %e_ni;
    add.s64 %e_ni64, %e_ni64, 1023;
    shl.b64 %e_bits, %e_ni64, 52;
    mov.b64 %e_nf, %e_bits;
    mul.f64 %e64, %e64, %e_nf;
    // --- end exp ---

    sub.f64 %num, %one, %e64;
    add.f64 %denom, %one, %e64;
    div.rn.f64 %vr, %num, %denom;
    st.global.f64 [%out], %vr;

DONE:
    ret;
}
";
9105
9106#[cfg(feature = "cuda")]
/// PTX source for `fused_adam_kernel`: one fused Adam optimizer step per
/// element. For each i < n:
///   g    += wd * p                      (only when weight_decay > 0)
///   m     = beta1*m + (1-beta1)*g
///   v     = beta2*v + (1-beta2)*g^2
///   p    -= lr * (m/bc1) / (sqrt(v/bc2) + eps)
/// and stores p, m, v back in place.
///
/// NOTE(review): `bc1`/`bc2` are host-supplied bias-correction divisors —
/// presumably 1 - beta^t; confirm at the call site.
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_adam_kernel(
    .param .u64 param_ptr,
    .param .u64 grad_ptr,
    .param .u64 exp_avg_ptr,
    .param .u64 exp_avg_sq_ptr,
    .param .f32 beta1,
    .param .f32 beta2,
    .param .f32 lr,
    .param .f32 eps,
    .param .f32 bc1,
    .param .f32 bc2,
    .param .f32 weight_decay,
    .param .u32 n
) {
    .reg .u32 %r_tid, %bid, %bdim, %n_reg;
    .reg .u64 %p, %g, %m, %v, %off;
    .reg .f32 %vp, %vg, %vm, %vv;
    .reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
    .reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
    .reg .f32 %one;
    .reg .pred %p_bound, %p_wd;

    ld.param.u64 %p, [param_ptr];
    ld.param.u64 %g, [grad_ptr];
    ld.param.u64 %m, [exp_avg_ptr];
    ld.param.u64 %v, [exp_avg_sq_ptr];
    ld.param.f32 %b1, [beta1];
    ld.param.f32 %b2, [beta2];
    ld.param.f32 %f_lr, [lr];
    ld.param.f32 %f_eps, [eps];
    ld.param.f32 %f_bc1, [bc1];
    ld.param.f32 %f_bc2, [bc2];
    ld.param.f32 %f_wd, [weight_decay];
    ld.param.u32 %n_reg, [n];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %r_tid, %tid.x;
    mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;

    setp.ge.u32 %p_bound, %r_tid, %n_reg;
    @%p_bound bra DONE;

    cvt.u64.u32 %off, %r_tid;
    shl.b64 %off, %off, 2;

    add.u64 %p, %p, %off;
    add.u64 %g, %g, %off;
    add.u64 %m, %m, %off;
    add.u64 %v, %v, %off;

    ld.global.f32 %vp, [%p];
    ld.global.f32 %vg, [%g];
    ld.global.f32 %vm, [%m];
    ld.global.f32 %vv, [%v];

    // L2 weight decay: g = g + wd * p
    mov.f32 %one, 0f00000000;
    setp.gt.f32 %p_wd, %f_wd, %one;
    @%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;

    // exp_avg = beta1 * exp_avg + (1 - beta1) * g
    mov.f32 %one, 0f3F800000;
    sub.f32 %t1, %one, %b1;
    mul.f32 %vm, %vm, %b1;
    fma.rn.f32 %vm, %t1, %vg, %vm;

    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    sub.f32 %t2, %one, %b2;
    mul.f32 %vv, %vv, %b2;
    mul.f32 %t1, %vg, %vg;
    fma.rn.f32 %vv, %t2, %t1, %vv;

    // m_hat = exp_avg / bc1
    div.rn.f32 %m_hat, %vm, %f_bc1;

    // v_hat = exp_avg_sq / bc2
    div.rn.f32 %v_hat, %vv, %f_bc2;

    // denom = sqrt(v_hat) + eps
    sqrt.rn.f32 %denom, %v_hat;
    add.f32 %denom, %denom, %f_eps;

    // param = param - lr * m_hat / denom
    div.rn.f32 %update, %m_hat, %denom;
    mul.f32 %update, %update, %f_lr;
    sub.f32 %vp, %vp, %update;

    st.global.f32 [%p], %vp;
    st.global.f32 [%m], %vm;
    st.global.f32 [%v], %vv;

DONE:
    ret;
}
";
9217
9218#[cfg(feature = "cuda")]
/// PTX source for `fused_gru_forward_kernel`: fused pointwise stage of a GRU
/// cell forward pass, run after the two GEMMs that produce `input_gates`
/// (x @ W_ih^T) and `hidden_gates` (h @ W_hh^T), both laid out [B, 3*H]
/// in r, z, n gate order. Uses a grid-stride loop (`idx += ntid*nctaid`)
/// over `total` = B*H elements. Per element:
///   r  = sigmoid(ir + hr + b1r + b2r)
///   z  = sigmoid(ii + hi + b1i + b2i)
///   n  = tanh(in + b1n + r*(hn + b2n))     (tanh via 2*sigmoid(2x)-1)
///   hy = n + z*(hx - n)
/// and saves [r, z, n, hx, hn+b2n] into a [B, 5*H] workspace —
/// presumably consumed by the backward pass; confirm at the call site.
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64

.visible .entry fused_gru_forward_kernel(
    .param .u64 input_gates_ptr,
    .param .u64 hidden_gates_ptr,
    .param .u64 bias_ih_ptr,
    .param .u64 bias_hh_ptr,
    .param .u64 hx_ptr,
    .param .u64 hy_ptr,
    .param .u64 workspace_ptr,
    .param .u32 hsz,
    .param .u32 total
) {
    .reg .u32 %tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
    .reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
    .reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
    .reg .u64 %off64, %tmp64;
    .reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
    .reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
    .reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
    .reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
    .reg .pred %p;

    ld.param.u64 %ig, [input_gates_ptr];
    ld.param.u64 %hg, [hidden_gates_ptr];
    ld.param.u64 %b1, [bias_ih_ptr];
    ld.param.u64 %b2, [bias_hh_ptr];
    ld.param.u64 %hx, [hx_ptr];
    ld.param.u64 %hy, [hy_ptr];
    ld.param.u64 %ws, [workspace_ptr];
    ld.param.u32 %hsz_reg, [hsz];
    ld.param.u32 %total_reg, [total];

    mov.u32 %bid, %ctaid.x;
    mov.u32 %bdim, %ntid.x;
    mov.u32 %tid, %tid.x;
    mov.u32 %gdim, %nctaid.x;
    mad.lo.u32 %idx, %bid, %bdim, %tid;
    mul.lo.u32 %stride, %bdim, %gdim;
    mov.f32 %one, 0f3F800000;

LOOP:
    setp.ge.u32 %p, %idx, %total_reg;
    @%p bra END;

    // offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
    div.u32 %batch_idx, %idx, %hsz_reg;
    rem.u32 %hmod, %idx, %hsz_reg;
    mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset3, %offset3, 3;
    add.u32 %offset3, %offset3, %hmod;

    // Load input gate components: ir, ii, in
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ig, %off64;
    ld.global.f32 %ir, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %ii, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %in, [%tmp64];

    // Load hidden gate components: hr, hi, hn
    cvt.u64.u32 %off64, %offset3;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hg, %off64;
    ld.global.f32 %hr, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hi, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %hn, [%tmp64];

    // Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b1, %off64;
    ld.global.f32 %b1r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b1n, [%tmp64];

    cvt.u64.u32 %off64, %hmod;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %b2, %off64;
    ld.global.f32 %b2r, [%tmp64];
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2i, [%tmp64];
    add.u64 %tmp64, %tmp64, %off64;
    ld.global.f32 %b2n, [%tmp64];

    // Load hx[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hx, %off64;
    ld.global.f32 %hx_val, [%tmp64];

    // r = sigmoid(ir + hr + b1r + b2r)
    add.f32 %rg, %ir, %hr;
    add.f32 %rg, %rg, %b1r;
    add.f32 %rg, %rg, %b2r;
    neg.f32 %tmp, %rg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %rg, %one, %denom;

    // z = sigmoid(ii + hi + b1i + b2i)
    add.f32 %zg, %ii, %hi;
    add.f32 %zg, %zg, %b1i;
    add.f32 %zg, %zg, %b2i;
    neg.f32 %tmp, %zg;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %zg, %one, %denom;

    // n = tanh(in + b1n + r*(hn + b2n))
    add.f32 %tmp, %hn, %b2n;
    fma.rn.f32 %ng, %rg, %tmp, %in;
    add.f32 %ng, %ng, %b1n;
    // tanh via 2*sigmoid(2x)-1
    mul.f32 %tmp, %ng, 0f40000000;
    neg.f32 %tmp, %tmp;
    mul.f32 %tmp, %tmp, 0f3FB8AA3B;
    ex2.approx.f32 %exp_val, %tmp;
    add.f32 %denom, %one, %exp_val;
    div.rn.f32 %ng, %one, %denom;
    mul.f32 %ng, %ng, 0f40000000;
    sub.f32 %ng, %ng, %one;

    // hy = n + z * (hx - n)
    sub.f32 %tmp, %hx_val, %ng;
    fma.rn.f32 %hy_val, %zg, %tmp, %ng;

    // Store hy[idx]
    cvt.u64.u32 %off64, %idx;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %hy, %off64;
    st.global.f32 [%tmp64], %hy_val;

    // Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
    mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
    mul.lo.u32 %offset5, %offset5, 5;
    add.u32 %offset5, %offset5, %hmod;

    cvt.u64.u32 %off64, %offset5;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %ws, %off64;
    st.global.f32 [%tmp64], %rg;
    cvt.u64.u32 %off64, %hsz_reg;
    shl.b64 %off64, %off64, 2;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %zg;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %ng;
    add.u64 %tmp64, %tmp64, %off64;
    st.global.f32 [%tmp64], %hx_val;
    add.u64 %tmp64, %tmp64, %off64;
    add.f32 %tmp, %hn, %b2n;
    st.global.f32 [%tmp64], %tmp;

    add.u32 %idx, %idx, %stride;
    bra LOOP;

END:
    ret;
}
";
9410
/// Build a 1-D launch configuration for an `n`-element kernel.
///
/// Uses 256-thread blocks and a ceiling-divided grid (at least one block,
/// so `n == 0` still launches validly). `n` must fit in `u32` because the
/// kernels take a `u32` element count; larger sizes are rejected instead
/// of silently truncated.
///
/// # Errors
/// `GpuError::ShapeMismatch` when `n > u32::MAX`.
#[cfg(feature = "cuda")]
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "kernel_launch",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    const BLOCK: u32 = 256;
    // Ceiling division in u64. The previous `saturating_add(BLOCK - 1)`
    // clamped at u32::MAX for n near u32::MAX, undercounting the grid and
    // leaving tail elements unprocessed; widening to u64 cannot overflow.
    let grid = ((n as u64 + (BLOCK as u64 - 1)) / BLOCK as u64) as u32;
    Ok(LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0,
    })
}
9441
/// Check that both f32 buffers live on `device` and have equal lengths.
///
/// # Errors
/// `DeviceMismatch` if either buffer belongs to another device ordinal;
/// `LengthMismatch` if the element counts differ.
#[cfg(feature = "cuda")]
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    for buf in [a, b] {
        if buf.device_ordinal() != device.ordinal() {
            return Err(GpuError::DeviceMismatch {
                expected: buf.device_ordinal(),
                got: device.ordinal(),
            });
        }
    }
    if a.len() == b.len() {
        Ok(())
    } else {
        Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        })
    }
}
9469
/// Ensure `a` resides on `device`.
///
/// # Errors
/// `DeviceMismatch` if the buffer's device ordinal differs.
#[cfg(feature = "cuda")]
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
    let (have, want) = (a.device_ordinal(), device.ordinal());
    if have == want {
        Ok(())
    } else {
        Err(GpuError::DeviceMismatch {
            expected: have,
            got: want,
        })
    }
}
9481
/// Device-residency check for a buffer of any element type.
///
/// # Errors
/// `DeviceMismatch` if `a` lives on a different device ordinal.
#[cfg(feature = "cuda")]
fn validate_device<T>(a: &CudaBuffer<T>, device: &GpuDevice) -> GpuResult<()> {
    match (a.device_ordinal(), device.ordinal()) {
        (have, want) if have == want => Ok(()),
        (have, want) => Err(GpuError::DeviceMismatch {
            expected: have,
            got: want,
        }),
    }
}
9493
/// Attempt to run a two-input / one-output element-wise f32 kernel.
///
/// Compiles (or fetches from the module cache) `kernel_name` from `ptx_src`
/// and launches it over `a.len()` elements into a freshly zeroed buffer.
///
/// Returns:
/// * `Ok(Some(out))` — kernel launched successfully.
/// * `Ok(None)` — PTX compile/load failed; caller should use a CPU fallback.
/// * `Err(_)` — allocation, launch-config, or launch failure.
///
/// Device residency and `a`/`b` length equality are the caller's job
/// (see `validate_binary`); they are not re-checked here.
#[cfg(feature = "cuda")]
fn try_launch_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    // A compile failure means "kernel unavailable", not a hard error, so
    // callers can degrade gracefully.
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: the argument order (a, b, out, n) is assumed to match the
    // PTX kernel's parameter list — consistent with every binary kernel
    // call site in this file; confirm against the PTX when adding kernels.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9547
/// Attempt to run a float4-vectorized two-input element-wise kernel.
///
/// The kernel processes `n / 4` float4 lanes, so it only covers every
/// element when `a.len()` is a multiple of 4. Returns `Ok(None)` when the
/// length is not vec4-friendly or the PTX is unavailable, so the caller
/// falls back to the scalar kernel.
#[cfg(feature = "cuda")]
fn try_launch_binary_vec4(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // Guard: with n % 4 != 0 the trailing remainder would never be written
    // (the output is zero-initialized), silently corrupting results.
    // Decline so the scalar path handles such lengths. Current callers
    // already check this, so behavior is unchanged for them.
    if n % 4 != 0 {
        return Ok(None);
    }
    let n4 = (n / 4) as u32;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n4 as usize)?;

    // SAFETY: argument order (a, b, out, n4) is assumed to match the vec4
    // PTX kernel; n4 counts float4 lanes, not scalars.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n4)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9592
/// Attempt to run a one-input / one-output element-wise f32 kernel.
///
/// Returns `Ok(Some(out))` on a successful launch, `Ok(None)` when the PTX
/// could not be compiled/loaded (caller falls back to CPU), or `Err` for
/// allocation/launch failures. Device residency is the caller's job
/// (see `validate_unary`).
#[cfg(feature = "cuda")]
fn try_launch_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    // Compile failure => "kernel unavailable", not a hard error.
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (input, output, n) is assumed to match the
    // unary PTX kernels used throughout this file.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9636
/// Launch a two-input element-wise kernel writing into a caller-provided
/// output buffer.
///
/// Returns `Ok(true)` on a successful launch, `Ok(false)` when the kernel
/// is unavailable (caller should fall back). `a`/`b` validation is the
/// caller's responsibility, but `out` is checked here because a short
/// output would be an out-of-bounds device write.
///
/// # Errors
/// `LengthMismatch` if `out.len() != a.len()`; allocation/launch errors.
#[cfg(feature = "cuda")]
fn try_launch_binary_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // The kernel writes n elements into `out`; reject mismatched sizes
    // before launching rather than corrupting device memory.
    if out.len() != n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: out.len(),
        });
    }
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (a, b, out, n) assumed to match the PTX.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
9683
/// Launch a one-input element-wise kernel writing into a caller-provided
/// output buffer.
///
/// Returns `Ok(true)` on a successful launch, `Ok(false)` when the kernel
/// is unavailable (caller should fall back).
///
/// # Errors
/// `LengthMismatch` if `out.len() != a.len()` (prevents an out-of-bounds
/// device write); allocation/launch errors.
#[cfg(feature = "cuda")]
fn try_launch_unary_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<bool> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    // The kernel writes n elements into `out`; check before launching.
    if out.len() != n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: out.len(),
        });
    }
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(false),
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (input, output, n) assumed to match the PTX.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(true)
}
9724
/// f64 twin of `try_launch_binary`: compile-or-fetch `kernel_name` from
/// `ptx_src` (typically produced by `ptx_f32_to_f64`) and run it
/// element-wise over `a`/`b`. Returns `Ok(None)` when the PTX is
/// unavailable so the caller can use `cpu_fallback_binary_f64`.
#[cfg(feature = "cuda")]
fn try_launch_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, ptx_src, kernel_name, device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (a, b, out, n) assumed to match the f64 PTX.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(Some(out))
}
9766
/// f64 twin of `try_launch_unary`. Returns `Ok(None)` when the PTX is
/// unavailable so the caller can use `cpu_fallback_unary_f64`.
#[cfg(feature = "cuda")]
fn try_launch_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, ptx_src, kernel_name, device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (input, output, n) assumed to match the PTX.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(Some(out))
}
9802
/// Host fallback for f64 binary ops: download both operands, apply `op`
/// element-wise on the CPU, and upload the result to a new device buffer.
#[cfg(feature = "cuda")]
fn cpu_fallback_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let xs = gpu_to_cpu(a, device)?;
    let ys = gpu_to_cpu(b, device)?;
    let mut out = Vec::with_capacity(xs.len());
    for (x, y) in xs.iter().zip(ys.iter()) {
        out.push(op(*x, *y));
    }
    cpu_to_gpu(&out, device)
}
9816
/// Host fallback for f64 unary ops: download, map `op`, and re-upload.
#[cfg(feature = "cuda")]
fn cpu_fallback_unary_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
    op: fn(f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let host = gpu_to_cpu(a, device)?;
    let mut out = Vec::with_capacity(host.len());
    for x in host.iter() {
        out.push(op(*x));
    }
    cpu_to_gpu(&out, device)
}
9828
/// Attempt to run an f64 broadcasting binary kernel.
///
/// `a_strides`/`b_strides` map output coordinates to input indices
/// (0-stride dims broadcast; see `broadcast_strides`). The stride and
/// shape arrays are uploaded as device buffers so the kernel can decode
/// flat output indices. Returns `Ok(None)` when the PTX is unavailable.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f64>>> {
    use cudarc::driver::PushKernelArg;

    let ndim = out_shape.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    // Per-call upload of the small stride/shape metadata arrays.
    let a_str_buf = cpu_to_gpu(a_strides, device)?;
    let b_str_buf = cpu_to_gpu(b_strides, device)?;
    let shape_buf = cpu_to_gpu(out_shape, device)?;

    let mut out = alloc_zeros_f64(out_numel, device)?;
    let cfg = launch_cfg(out_numel)?;
    let n_u32 = out_numel as u32;
    let ndim_u32 = ndim as u32;

    // SAFETY: argument order (a, b, out, a_strides, b_strides, shape, n,
    // ndim) is assumed to match the broadcast PTX kernels.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(a_str_buf.inner())
            .arg(b_str_buf.inner())
            .arg(shape_buf.inner())
            .arg(&n_u32)
            .arg(&ndim_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9887
/// Host fallback for broadcasting f64 binary ops: decompose each flat
/// output index into coordinates (rightmost dim fastest), project through
/// the broadcast strides of each input, and apply `op`.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f64, f64) -> f64,
) -> GpuResult<CudaBuffer<f64>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let numel: usize = out_shape.iter().product();

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);

    let mut out = Vec::with_capacity(numel);
    for flat in 0..numel {
        let mut rest = flat;
        let mut ia = 0usize;
        let mut ib = 0usize;
        for d in (0..out_shape.len()).rev() {
            let coord = rest % out_shape[d];
            rest /= out_shape[d];
            ia += coord * strides_a[d] as usize;
            ib += coord * strides_b[d] as usize;
        }
        out.push(op(lhs[ia], rhs[ib]));
    }
    cpu_to_gpu(&out, device)
}
9921
/// Attempt to run an f32 broadcasting binary kernel.
///
/// `a_strides`/`b_strides` map output coordinates to input indices
/// (0-stride dims broadcast; see `broadcast_strides`). The stride and
/// shape arrays are uploaded as device buffers so the kernel can decode
/// flat output indices. Returns `Ok(None)` when the PTX is unavailable.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_strides: &[u32],
    b_strides: &[u32],
    out_shape: &[u32],
    out_numel: usize,
    device: &GpuDevice,
    ptx_src: &'static str,
    kernel_name: &'static str,
) -> GpuResult<Option<CudaBuffer<f32>>> {
    use cudarc::driver::PushKernelArg;

    let ndim = out_shape.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx_src,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Ok(None),
    };

    // Per-call upload of the small stride/shape metadata arrays.
    let a_str_buf = cpu_to_gpu(a_strides, device)?;
    let b_str_buf = cpu_to_gpu(b_strides, device)?;
    let shape_buf = cpu_to_gpu(out_shape, device)?;

    let mut out = alloc_zeros_f32(out_numel, device)?;
    let cfg = launch_cfg(out_numel)?;
    let n_u32 = out_numel as u32;
    let ndim_u32 = ndim as u32;

    // SAFETY: argument order (a, b, out, a_strides, b_strides, shape, n,
    // ndim) is assumed to match the broadcast PTX kernels.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(a_str_buf.inner())
            .arg(b_str_buf.inner())
            .arg(shape_buf.inner())
            .arg(&n_u32)
            .arg(&ndim_u32)
            .launch(cfg)?;
    }

    Ok(Some(out))
}
9985
/// Compute element strides (elements, not bytes) that map a coordinate in
/// `out_shape` to an index into a row-major array of `in_shape`, following
/// NumPy-style broadcasting: trailing dimensions are aligned, and both
/// missing leading dimensions and size-1 dimensions get stride 0 so the
/// same input element is reused across that axis.
#[cfg(feature = "cuda")]
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
    let ndim = out_shape.len();
    let in_ndim = in_shape.len();
    let mut strides = vec![0u32; ndim];

    let mut running: u32 = 1;
    for d in (0..ndim).rev() {
        // Output dims with no input counterpart keep stride 0.
        if d + in_ndim < ndim {
            continue;
        }
        let in_d = d + in_ndim - ndim;
        // Size-1 input dims broadcast: stride 0 repeats the element.
        strides[d] = if in_shape[in_d] == 1 { 0 } else { running };
        running *= in_shape[in_d] as u32;
    }

    strides
}
10019
/// Host fallback for f32 binary ops: download both operands, apply `op`
/// element-wise on the CPU, and upload the result to a new device buffer.
#[cfg(feature = "cuda")]
fn cpu_fallback_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let xs = gpu_to_cpu(a, device)?;
    let ys = gpu_to_cpu(b, device)?;
    let mut out = Vec::with_capacity(xs.len());
    for (x, y) in xs.iter().zip(ys.iter()) {
        out.push(op(*x, *y));
    }
    cpu_to_gpu(&out, device)
}
10042
/// Host fallback for f32 unary ops: download, map `op`, and re-upload.
#[cfg(feature = "cuda")]
fn cpu_fallback_unary(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
    op: fn(f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let host = gpu_to_cpu(a, device)?;
    let mut out = Vec::with_capacity(host.len());
    for x in host.iter() {
        out.push(op(*x));
    }
    cpu_to_gpu(&out, device)
}
10054
/// Element-wise addition of two f32 device buffers.
///
/// Kernel selection: the float4-vectorized kernel when the length is a
/// multiple of 4 and at least 16, then the scalar kernel, and finally a
/// host round-trip if neither kernel could be compiled.
///
/// # Errors
/// Device/length validation failures and CUDA launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    let len = a.len();
    if len >= 16 && len % 4 == 0 {
        if let Some(result) =
            try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel")?
        {
            return Ok(result);
        }
    }

    match try_launch_binary(a, b, device, ADD_PTX, "add_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x + y),
    }
}
10094
/// Element-wise subtraction (`a - b`) of two f32 device buffers, with a
/// CPU fallback when the kernel cannot be compiled.
///
/// # Errors
/// Device/length validation failures and CUDA launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    match try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x - y),
    }
}
10120
/// Element-wise multiplication of two f32 device buffers.
///
/// Kernel selection mirrors `gpu_add`: vec4 kernel for multiple-of-4
/// lengths >= 16, then the scalar kernel, then a host round-trip.
///
/// # Errors
/// Device/length validation failures and CUDA launch errors.
#[cfg(feature = "cuda")]
pub fn gpu_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    let len = a.len();
    if len >= 16 && len % 4 == 0 {
        if let Some(result) =
            try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel")?
        {
            return Ok(result);
        }
    }

    match try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary(a, b, device, |x, y| x * y),
    }
}
10155
/// Broadcasting element-wise addition: `a` and `b` (of `a_shape` and
/// `b_shape`) are broadcast to `out_shape`, NumPy-style. Falls back to a
/// CPU implementation when the kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    match try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        BROADCAST_ADD_PTX,
        "broadcast_add_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x + y
        }),
    }
}
10199
/// Broadcasting element-wise subtraction (`a - b`) to `out_shape`, with a
/// CPU fallback when the kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    match try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        BROADCAST_SUB_PTX,
        "broadcast_sub_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x - y
        }),
    }
}
10231
/// Broadcasting element-wise multiplication to `out_shape`, with a CPU
/// fallback when the kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    match try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        BROADCAST_MUL_PTX,
        "broadcast_mul_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x * y
        }),
    }
}
10263
/// Broadcasting element-wise division (`a / b`) to `out_shape`, with a CPU
/// fallback when the kernel cannot be compiled. Division by zero follows
/// IEEE-754 semantics (inf/NaN) on both paths.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel = out_shape.iter().product::<usize>();

    match try_launch_broadcast_binary(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims,
        numel,
        device,
        BROADCAST_DIV_PTX,
        "broadcast_div_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary(a, b, a_shape, b_shape, out_shape, device, |x, y| {
            x / y
        }),
    }
}
10295
/// Host fallback for broadcasting f32 binary ops: decompose each flat
/// output index into coordinates (rightmost dim fastest), project through
/// the broadcast strides of each input, and apply `op`.
#[cfg(feature = "cuda")]
fn cpu_fallback_broadcast_binary(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
    op: fn(f32, f32) -> f32,
) -> GpuResult<CudaBuffer<f32>> {
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let numel: usize = out_shape.iter().product();

    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);

    let mut out = Vec::with_capacity(numel);
    for flat in 0..numel {
        let mut rest = flat;
        let mut ia = 0usize;
        let mut ib = 0usize;
        for d in (0..out_shape.len()).rev() {
            let coord = rest % out_shape[d];
            rest /= out_shape[d];
            ia += coord * strides_a[d] as usize;
            ib += coord * strides_b[d] as usize;
        }
        out.push(op(lhs[ia], rhs[ib]));
    }
    cpu_to_gpu(&out, device)
}
10330
/// Element-wise negation of an f32 device buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, NEG_PTX, "neg_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| -x),
    }
}
10355
/// Element-wise ReLU (`max(x, 0)`) of an f32 device buffer, with CPU
/// fallback. The fallback uses `f32::max`, so NaN inputs map to 0.
#[cfg(feature = "cuda")]
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;

    match try_launch_unary(a, device, RELU_PTX, "relu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| x.max(0.0)),
    }
}
10376
/// ReLU backward: pass `grad` through where the forward `input` was
/// strictly positive, zero elsewhere. CPU fallback when no kernel.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        RELU_BACKWARD_PTX,
        "relu_backward_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (&g, &x) in g_host.iter().zip(x_host.iter()) {
        out.push(if x > 0.0 { g } else { 0.0 });
    }
    cpu_to_gpu(&out, device)
}
10406
/// Backward of the sigmoid-approximated GELU: with k = 1.702 and
/// s = sigmoid(k*x), d/dx [x*s] = s + k*x*s*(1 - s), scaled by `grad`.
/// CPU fallback when no kernel.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_PTX,
        "gelu_backward_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback using the same sigmoid approximation.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (&g, &x) in g_host.iter().zip(x_host.iter()) {
        let k: f32 = 1.702;
        let sig = 1.0 / (1.0 + (-k * x).exp());
        out.push(g * (sig + k * x * sig * (1.0 - sig)));
    }
    cpu_to_gpu(&out, device)
}
10441
/// Backward of the exact (erf-based) GELU: d/dx [x * CDF(x)] =
/// CDF(x) + x * pdf(x), where CDF/pdf are the standard normal CDF/density.
///
/// The CPU fallback evaluates erf with the Abramowitz & Stegun 7.1.26
/// rational approximation (|error| < 1.5e-7).
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_ERF_PTX,
        "gelu_backward_erf_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback when the kernel is unavailable.
    let grad_host = gpu_to_cpu(grad, device)?;
    let input_host = gpu_to_cpu(input, device)?;
    let inv_sqrt_2: f32 = std::f32::consts::FRAC_1_SQRT_2;
    let inv_sqrt_2pi: f32 = 1.0 / (2.0 * std::f32::consts::PI).sqrt();
    let result: Vec<f32> = grad_host
        .iter()
        .zip(input_host.iter())
        .map(|(&g, &x)| {
            let z = x * inv_sqrt_2;
            let az = z.abs();
            // A&S 7.1.26: t = 1/(1 + p*|z|), p = 0.3275911;
            // erf(|z|) ~= 1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) e^{-z^2}.
            let t = 1.0 / (1.0 + 0.3275911 * az);
            // BUGFIX: a5 is 1.0614054 (A&S a5 = 1.061405429). The previous
            // code reused p (0.3275911) here, skewing erf for |z| >~ 1 and
            // therefore the GELU gradient away from zero.
            let poly = t
                * (0.2548296
                    + t * (-0.2844967 + t * (1.4214137 + t * (-1.4531520 + t * 1.0614054))));
            let erf_abs = 1.0 - poly * (-az * az).exp();
            // erf is odd; restore the sign of z.
            let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
            let cdf = 0.5 * (1.0 + erf_val);
            let pdf = inv_sqrt_2pi * (-0.5 * x * x).exp();
            g * (cdf + x * pdf)
        })
        .collect();
    cpu_to_gpu(&result, device)
}
10485
/// 1-D gather: `out[i] = input[indices[i]]`, with `indices` encoded as f32
/// values (cast to usize on the CPU path; the GPU kernel is assumed to do
/// the equivalent cast — confirm against the PTX).
///
/// NOTE(review): indices are not bounds-checked here — an out-of-range
/// index panics on the CPU path and is unchecked on the GPU path; confirm
/// callers validate. `indices`' device residency is also not checked
/// (only `input`'s, via `validate_unary`).
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d(
    input: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    // Output length follows the index list, not the input.
    let n = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        INDEX_SELECT_1D_PTX,
        "index_select_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather on the host and re-upload.
            let input_host = gpu_to_cpu(input, device)?;
            let indices_host = gpu_to_cpu(indices, device)?;
            let result: Vec<f32> = indices_host
                .iter()
                .map(|&idx_f| input_host[idx_f as usize])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (input, indices, output, n) assumed to match
    // the PTX kernel.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10543
/// 1-D scatter-add (e.g. embedding backward): accumulates
/// `out[indices[i]] += grad_output[i]` into a zeroed buffer of length
/// `input_len`. Indices are f32-encoded, cast to usize on the CPU path.
///
/// NOTE(review): indices are not bounds-checked (CPU path panics when out
/// of range); duplicate indices accumulate. The GPU kernel is assumed to
/// use atomic adds for duplicates — confirm against the PTX.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(grad_output, device)?;

    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_1D_PTX,
        "scatter_add_1d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: serial scatter-add on the host.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; input_len];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                result[idx_f as usize] += go_host[i];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    // Output sized by the original input, not the gradient.
    let mut out = alloc_zeros_f32(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (grad_output, indices, output, n) assumed to
    // match the PTX kernel.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10605
/// Masked fill: `out[i] = value` where the mask is "on", else `input[i]`.
/// The CPU fallback treats `mask[i] >= 0.5` as on; the GPU kernel is
/// assumed to use the same threshold — confirm against the PTX.
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill(
    input: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    value: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(input, mask, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        MASKED_FILL_PTX,
        "masked_fill_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback.
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f32> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m >= 0.5 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order (input, mask, output, value, n) assumed to
    // match the PTX kernel.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10666
/// Masked zero (gradient of masked fill): zero `grad` where the mask is
/// on, pass it through elsewhere. The CPU fallback treats `mask >= 0.5`
/// as on. CPU fallback when no kernel.
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero(
    grad: &CudaBuffer<f32>,
    mask: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, mask, device)?;

    if let Some(out) =
        try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")?
    {
        return Ok(out);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let m_host = gpu_to_cpu(mask, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (&g, &m) in g_host.iter().zip(m_host.iter()) {
        out.push(if m >= 0.5 { 0.0 } else { g });
    }
    cpu_to_gpu(&out, device)
}
10698
/// Sigmoid backward: `dx = g * o * (1 - o)` where `o` is the forward
/// sigmoid output. CPU fallback when no kernel.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        output,
        device,
        SIGMOID_BACKWARD_PTX,
        "sigmoid_backward_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let o_host = gpu_to_cpu(output, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (&g, &o) in g_host.iter().zip(o_host.iter()) {
        out.push(g * o * (1.0 - o));
    }
    cpu_to_gpu(&out, device)
}
10734
/// Tanh backward: `dx = g * (1 - o^2)` where `o` is the forward tanh
/// output. CPU fallback when no kernel.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, output, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        output,
        device,
        TANH_BACKWARD_PTX,
        "tanh_backward_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback.
    let g_host = gpu_to_cpu(grad, device)?;
    let o_host = gpu_to_cpu(output, device)?;
    let mut out = Vec::with_capacity(g_host.len());
    for (&g, &o) in g_host.iter().zip(o_host.iter()) {
        out.push(g * (1.0 - o * o));
    }
    cpu_to_gpu(&out, device)
}
10770
/// Backward of row-wise softmax: per row, `dx = y * (dy - dot(dy, y))`
/// where `y` is the softmax output. `grad`/`output` are flat
/// (rows x cols) arrays.
///
/// NOTE(review): assumes `cols` is non-zero and divides `grad.len()`
/// (`rows = total / cols` truncates otherwise) — confirm callers
/// guarantee this.
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_BACKWARD_PTX,
        "softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: per row, dot = sum(dy * y), then
            // dx = y * (dy - dot).
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads, with a 256-float shared buffer —
    // presumably for the per-row dot-product reduction; confirm against
    // the kernel PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order (grad, output, out, rows, cols) assumed to
    // match the PTX kernel.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10849
/// Row-wise log-softmax of a flat (rows x cols) f32 matrix:
/// `out = x - (max + ln(sum(exp(x - max))))` per row, using the max-shift
/// for numerical stability (mirrored by the CPU fallback).
///
/// NOTE(review): assumes `cols` is non-zero and divides `input.len()` —
/// confirm callers guarantee this.
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax(
    input: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = input.len();
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_PTX,
        "log_softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback with the standard max-shifted log-sum-exp.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f32; total];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f32;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads, 256-float shared buffer — presumably
    // for the row max/sum reductions; confirm against the kernel PTX.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    // SAFETY: argument order (input, output, rows, cols) assumed to match
    // the PTX kernel.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
10929
#[cfg(feature = "cuda")]
/// Backward pass of row-wise log-softmax.
///
/// Given upstream gradient `grad` and the forward result `output`
/// (log-probabilities), computes per row:
/// `result[c] = grad[c] - exp(output[c]) * sum_k grad[k]`.
/// Falls back to a host implementation if kernel compilation fails.
pub fn gpu_log_softmax_backward(
    grad: &CudaBuffer<f32>,
    output: &CudaBuffer<f32>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, output, device)?;

    let total = grad.len();
    let rows = total / cols;

    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_PTX,
        "log_softmax_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: one pass for the gradient sum, one for the update.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f32; total];
            for ((g_row, o_row), r_row) in grad_host
                .chunks(cols)
                .zip(output_host.chunks(cols))
                .zip(result.chunks_mut(cols))
                .take(rows)
            {
                let sum_grad: f32 = g_row.iter().sum();
                for ((dst, &g), &o) in r_row.iter_mut().zip(g_row).zip(o_row) {
                    *dst = g - o.exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 * 4 B of shared scratch.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11005
#[cfg(feature = "cuda")]
/// Sums every element of `a` into a single-element GPU buffer.
///
/// Strategy: one kernel launch produces up to 1024 per-block partial sums,
/// which are then folded either on the host (few partials) or by recursing
/// on the partials buffer (many partials). Empty input yields `[0.0]`.
/// If kernel compilation fails, the whole sum is done on the host.
pub fn gpu_reduce_sum(
    a: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = a.len();
    if n == 0 {
        // Sum of an empty buffer is 0.
        return cpu_to_gpu(&[0.0f32], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        REDUCE_SUM_PTX,
        "reduce_sum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: copy back and sum on the host.
            let host = gpu_to_cpu(a, device)?;
            let total: f32 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    const BLOCK: u32 = 256;
    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
    // Cap the grid at 1024 blocks. NOTE(review): this assumes the kernel
    // iterates over extra elements itself (e.g. a grid-stride loop) when
    // n > num_blocks * BLOCK — confirm against REDUCE_SUM_PTX.
    let num_blocks = num_blocks.min(1024);

    // One partial sum per block.
    let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
    let n_u32 = n as u32;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        shared_mem_bytes: 0, };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    if num_blocks <= 1 {
        // Single block: the one partial IS the answer.
        return Ok(partials);
    }

    if num_blocks <= 256 {
        // Few partials: cheaper to finish on the host than to relaunch.
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f32 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // Many partials (257..=1024): recurse; next level collapses to one block.
    gpu_reduce_sum(&partials, device)
}
11088
/// Stub used when the crate is built without the `cuda` feature: always
/// fails with [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_sum(
    _a: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
11097
#[cfg(feature = "cuda")]
/// Sums a `[outer, axis_size, inner]` tensor over its middle axis,
/// producing an `[outer, inner]` result. Falls back to a host loop when
/// kernel compilation fails.
pub fn gpu_sum_axis(
    a: &CudaBuffer<f32>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let total_output = outer * inner;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        SUM_AXIS_PTX,
        "sum_axis_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: one sequential reduction per output element.
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f32; total_output];
            for i in 0..total_output {
                let base = (i / inner) * axis_size * inner + i % inner;
                let mut acc = 0.0f32;
                for k in 0..axis_size {
                    acc += host[base + k * inner];
                }
                result[i] = acc;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11162
#[cfg(feature = "cuda")]
/// Inclusive cumulative sum along the middle axis of a
/// `[outer, dim_size, inner]` tensor. One logical scan runs per
/// `(outer, inner)` lane. Falls back to a host scan when kernel
/// compilation fails.
pub fn gpu_cumsum(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        CUMSUM_PTX,
        "cumsum_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: sequential running sum per lane.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for lane in 0..num_threads {
                let base = (lane / inner) * dim_size * inner + lane % inner;
                let mut running = 0.0f32;
                for k in 0..dim_size {
                    let pos = base + k * inner;
                    running += host[pos];
                    result[pos] = running;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11243
#[cfg(feature = "cuda")]
/// Inclusive cumulative product along the middle axis of a
/// `[outer, dim_size, inner]` tensor. Falls back to a host scan when
/// kernel compilation fails.
pub fn gpu_cumprod(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        CUMPROD_PTX,
        "cumprod_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: sequential running product per lane,
            // seeded with the multiplicative identity.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for lane in 0..num_threads {
                let base = (lane / inner) * dim_size * inner + lane % inner;
                let mut running = 1.0f32;
                for k in 0..dim_size {
                    let pos = base + k * inner;
                    running *= host[pos];
                    result[pos] = running;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11317
#[cfg(feature = "cuda")]
/// Running maximum along the middle axis of a `[outer, dim_size, inner]`
/// tensor, returning `(values, indices)` where `indices` holds the
/// f32-encoded axis position of each running maximum. Falls back to a
/// host scan when kernel compilation fails.
pub fn gpu_cummax(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        CUMMAX_PTX,
        "cummax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback. Strict `>` means the FIRST occurrence of a
            // maximum keeps its index on ties.
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for lane in 0..num_threads {
                let base = (lane / inner) * dim_size * inner + lane % inner;
                let mut running = f32::NEG_INFINITY;
                let mut arg = 0u32;
                for k in 0..dim_size {
                    let pos = base + k * inner;
                    if host[pos] > running {
                        running = host[pos];
                        arg = k as u32;
                    }
                    vals[pos] = running;
                    idxs[pos] = arg as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_idx))
}
11398
#[cfg(feature = "cuda")]
/// Running minimum along the middle axis of a `[outer, dim_size, inner]`
/// tensor, returning `(values, indices)` where `indices` holds the
/// f32-encoded axis position of each running minimum. Falls back to a
/// host scan when kernel compilation fails.
pub fn gpu_cummin(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        CUMMIN_PTX,
        "cummin_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback. Strict `<` means the FIRST occurrence of a
            // minimum keeps its index on ties.
            let host = gpu_to_cpu(input, device)?;
            let mut vals = vec![0.0f32; total];
            let mut idxs = vec![0.0f32; total];
            for lane in 0..num_threads {
                let base = (lane / inner) * dim_size * inner + lane % inner;
                let mut running = f32::INFINITY;
                let mut arg = 0u32;
                for k in 0..dim_size {
                    let pos = base + k * inner;
                    if host[pos] < running {
                        running = host[pos];
                        arg = k as u32;
                    }
                    vals[pos] = running;
                    idxs[pos] = arg as f32;
                }
            }
            return Ok((cpu_to_gpu(&vals, device)?, cpu_to_gpu(&idxs, device)?));
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let mut out_idx = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(out_idx.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, out_idx))
}
11479
#[cfg(feature = "cuda")]
/// Running log-sum-exp along the middle axis of a `[outer, dim_size, inner]`
/// tensor: `out[k] = log(sum_{j<=k} exp(x[j]))` per `(outer, inner)` lane.
///
/// Uses the streaming rescaling trick (subtract the running max) so large
/// values do not overflow `exp()`. Falls back to a host scan when kernel
/// compilation fails.
pub fn gpu_logcumsumexp(
    input: &CudaBuffer<f32>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = outer * dim_size * inner;
    let num_threads = outer * inner;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOGCUMSUMEXP_PTX,
        "logcumsumexp_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; total];
            for i in 0..num_threads {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let base = outer_idx * dim_size * inner + inner_idx;
                let mut acc = f32::NEG_INFINITY;
                for k in 0..dim_size {
                    let idx = base + k * inner;
                    let x = host[idx];
                    let m = acc.max(x);
                    // BUG FIX: when both `acc` and `x` are -inf, the generic
                    // update computes (-inf) - (-inf) = NaN inside exp();
                    // the mathematically correct running value is -inf
                    // (log of an empty/zero sum).
                    acc = if m == f32::NEG_INFINITY {
                        f32::NEG_INFINITY
                    } else {
                        m + ((acc - m).exp() + (x - m).exp()).ln()
                    };
                    result[idx] = acc;
                }
            }
            // NOTE(review): the GPU kernel path may still exhibit the NaN
            // edge case for all -inf prefixes — verify LOGCUMSUMEXP_PTX.
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(num_threads)?;
    let outer_u32 = outer as u32;
    let dim_size_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11555
#[cfg(feature = "cuda")]
/// Extracts one split of size `split_size` (starting at `split_offset`)
/// along an axis of length `total_along_axis`, where `inner_size` is the
/// product of the trailing dimensions and `n` the output element count.
///
/// Falls back to a host gather when kernel compilation fails.
pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_SPLIT_PTX,
        "strided_split_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: map each flat output index to its source index.
            // (Removed a dead `outer` local that was only silenced with
            // `let _ = outer;` — it was never used.)
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                let src_idx =
                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
                *out = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
11637
#[cfg(feature = "cuda")]
/// Scatters `input` (one part of a concatenation, `n` elements) into
/// `output` at `cat_offset` along an axis of total length
/// `total_along_axis`; `part_size` is this part's extent along that axis
/// and `inner_size` the product of trailing dimensions. Writes in place
/// into `output`.
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
    input: &CudaBuffer<f32>,
    output: &mut CudaBuffer<f32>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    // Only `input` is validated here; NOTE(review): `output` is assumed to
    // already live on `device` — confirm callers guarantee this.
    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_CAT_PTX,
        "strided_cat_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: round-trip both buffers through the host and
            // scatter there. Note this REPLACES `*output` with a freshly
            // allocated buffer rather than writing into the existing one.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
11725
11726pub const STRIDED_COPY_MAX_DIMS: usize = 8;
11733
#[cfg(feature = "cuda")]
/// Builds the fixed-size (8-slot) output-stride / source-stride arrays the
/// strided-copy kernels expect.
///
/// Output strides are the contiguous row-major strides of `out_shape`.
/// Unused trailing slots are filled with a sentinel LARGER than any flat
/// index (`n + 1`): the kernel computes `coord = flat / stride` per slot,
/// so a sentinel > flat forces `coord == 0` and the padded dimension
/// contributes nothing (and never divides by zero).
///
/// Errors (all reported as `ShapeMismatch`) when: the two slices differ in
/// length, rank exceeds [`STRIDED_COPY_MAX_DIMS`], any source stride is
/// negative, or any stride overflows `u32`.
fn pad_strided_copy_params(
    out_shape: &[usize],
    src_strides: &[isize],
    n: usize,
) -> GpuResult<([u32; STRIDED_COPY_MAX_DIMS], [u32; STRIDED_COPY_MAX_DIMS])> {
    if out_shape.len() != src_strides.len() {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![out_shape.len()],
            got: vec![src_strides.len()],
        });
    }
    if out_shape.len() > STRIDED_COPY_MAX_DIMS {
        return Err(GpuError::ShapeMismatch {
            op: "strided_copy_pad",
            expected: vec![STRIDED_COPY_MAX_DIMS],
            got: vec![out_shape.len()],
        });
    }
    // Negative strides (reversed views) are not supported by the kernel's
    // unsigned index arithmetic — reject them up front.
    for &s in src_strides {
        if s < 0 {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_pad_negative_stride",
                expected: vec![0],
                got: vec![s.unsigned_abs()],
            });
        }
    }

    let rank = out_shape.len();
    let mut out_stride = [0u32; STRIDED_COPY_MAX_DIMS];
    // Contiguous row-major strides: innermost dimension has stride 1.
    if rank > 0 {
        let mut acc: usize = 1;
        for d in (0..rank).rev() {
            if acc > u32::MAX as usize {
                return Err(GpuError::ShapeMismatch {
                    op: "strided_copy_stride_overflow",
                    expected: vec![u32::MAX as usize],
                    got: vec![acc],
                });
            }
            out_stride[d] = acc as u32;
            acc = acc.saturating_mul(out_shape[d]);
        }
    }

    // Sentinel for unused slots: strictly greater than any flat index
    // (flat < n), and at least 1 so division is always defined.
    let pad_val = (n as u32).saturating_add(1).max(1);
    for d in rank..STRIDED_COPY_MAX_DIMS {
        out_stride[d] = pad_val;
    }

    // Source strides pass through unchanged (already validated >= 0);
    // padded slots stay 0 so they add nothing to the source index.
    let mut src_stride_out = [0u32; STRIDED_COPY_MAX_DIMS];
    for d in 0..rank {
        let s = src_strides[d];
        if s as usize > u32::MAX as usize {
            return Err(GpuError::ShapeMismatch {
                op: "strided_copy_src_stride_overflow",
                expected: vec![u32::MAX as usize],
                got: vec![s as usize],
            });
        }
        src_stride_out[d] = s as u32;
    }

    Ok((out_stride, src_stride_out))
}
11821
#[cfg(feature = "cuda")]
/// Gathers an arbitrarily strided f32 view into a new contiguous buffer of
/// shape `out_shape`, reading from `input` at `src_offset` with strides
/// `src_strides` (in elements). Negative strides are rejected by
/// `pad_strided_copy_params`.
pub fn gpu_strided_copy(
    input: &CudaBuffer<f32>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        // Empty output: nothing to launch.
        return alloc_zeros_f32(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback mirroring the kernel's index math: peel the flat
            // output index into per-dimension coordinates via the contiguous
            // output strides, then dot with the source strides. Padded
            // slots are inert (out stride > flat => coord 0; src stride 0).
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f32; n];
            for i in 0..n {
                let mut flat = i as u32;
                let mut src_idx = src_offset as u32;
                for d in 0..STRIDED_COPY_MAX_DIMS {
                    let os = out_stride[d];
                    let ss = src_stride[d];
                    let coord = flat / os;
                    flat -= coord * os;
                    src_idx += coord * ss;
                }
                result[i] = host[src_idx as usize];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;

    // The kernel takes all 16 stride values as individual scalar params,
    // so the argument order below must match the PTX parameter list exactly.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }

    Ok(out)
}
11928
#[cfg(feature = "cuda")]
/// f64 twin of [`gpu_strided_copy`]: gathers a strided f64 view into a new
/// contiguous buffer. The f64 PTX is derived once from the f32 PTX via
/// `get_f64_ptx` and memoized in a process-wide `OnceLock`.
pub fn gpu_strided_copy_f64(
    input: &CudaBuffer<f64>,
    out_shape: &[usize],
    src_strides: &[isize],
    src_offset: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Cache for the rewritten (f32 -> f64) PTX source.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n: usize = out_shape.iter().product();
    let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;

    if n == 0 {
        // Empty output: nothing to launch.
        return alloc_zeros_f64(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(
        &CACHE,
        STRIDED_COPY_PTX,
        "strided_copy_kernel",
        "strided_copy_f64_kernel",
    );
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_copy_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback — identical index math to the f32 version:
            // decompose the flat index by output strides, accumulate the
            // source index with source strides.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f64; n];
            for i in 0..n {
                let mut flat = i as u32;
                let mut src_idx = src_offset as u32;
                for d in 0..STRIDED_COPY_MAX_DIMS {
                    let os = out_stride[d];
                    let ss = src_stride[d];
                    let coord = flat / os;
                    flat -= coord * os;
                    src_idx += coord * ss;
                }
                result[i] = host[src_idx as usize];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let src_offset_u32 = src_offset as u32;
    let n_u32 = n as u32;

    // Argument order must match the kernel's 20-parameter PTX signature.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&src_offset_u32)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }

    Ok(out)
}
12018
#[cfg(feature = "cuda")]
/// Multiplies every element of `a` by `scalar`, returning a new buffer.
/// Falls back to a host-side map when kernel compilation fails.
pub fn gpu_scale(
    a: &CudaBuffer<f32>,
    scalar: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: element-wise multiply on the host.
            let host = gpu_to_cpu(a, device)?;
            let scaled: Vec<f32> = host.into_iter().map(|x| x * scalar).collect();
            return cpu_to_gpu(&scaled, device);
        }
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12072
#[cfg(feature = "cuda")]
/// Row-wise softmax over a row-major `rows x cols` matrix. Falls back to a
/// numerically stable host implementation (row max subtracted before
/// exponentiating) when kernel compilation fails.
pub fn gpu_softmax(
    input: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: per row, shift by the max, exponentiate,
            // then normalize by the sum.
            let host = gpu_to_cpu(input, device)?;
            let mut probs = vec![0.0f32; host.len()];
            for r in 0..rows {
                let row = r * cols;
                let mut max_v = f32::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[row + c]);
                }
                let mut denom = 0.0f32;
                for c in 0..cols {
                    let e = (host[row + c] - max_v).exp();
                    probs[row + c] = e;
                    denom += e;
                }
                let norm = 1.0 / denom;
                for c in 0..cols {
                    probs[row + c] *= norm;
                }
            }
            return cpu_to_gpu(&probs, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 * 4 B of shared scratch.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12149
#[cfg(feature = "cuda")]
/// Dropout: zeroes elements whose per-index hash falls below `threshold`,
/// scaling survivors by `scale` (inverted dropout). `seed` perturbs the
/// hash so different calls drop different elements. Falls back to a host
/// implementation using the same hash when kernel compilation fails.
pub fn gpu_dropout(
    input: &CudaBuffer<f32>,
    threshold: u32,
    scale: f32,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        DROPOUT_PTX,
        "dropout_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            let host = gpu_to_cpu(input, device)?;
            let mut masked = Vec::with_capacity(host.len());
            for (i, &x) in host.iter().enumerate() {
                // Knuth multiplicative hash + xorshift; the constants must
                // stay in sync with the GPU kernel's RNG.
                let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                r ^= r << 13;
                r ^= r >> 17;
                r ^= r << 5;
                masked.push(if r < threshold { 0.0 } else { x * scale });
            }
            return cpu_to_gpu(&masked, device);
        }
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
12230
#[cfg(feature = "cuda")]
/// f64 twin of [`gpu_dropout`]: zeroes elements whose per-index hash is
/// below `threshold`, scaling survivors by `scale`. The f64 PTX is derived
/// from the f32 kernel once and memoized in a `OnceLock`.
pub fn gpu_dropout_f64(
    input: &CudaBuffer<f64>,
    threshold: u32,
    scale: f64,
    seed: u32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // FIX: this entry point previously skipped device validation, unlike
    // gpu_dropout (validate_unary) and gpu_strided_copy_f64
    // (validate_device). Reject mismatched buffers before touching them.
    validate_device(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, DROPOUT_PTX, "dropout_kernel", "dropout_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx, ptx, "dropout_f64_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback with the same hash as the f32 path — the
            // constants must stay in sync with the GPU kernel's RNG.
            let host = gpu_to_cpu(input, device)?;
            let result: Vec<f64> = host
                .iter()
                .enumerate()
                .map(|(i, &x)| {
                    let mut r = (i as u32).wrapping_mul(2654435761) ^ seed;
                    r ^= r << 13;
                    r ^= r >> 17;
                    r ^= r << 5;
                    if r < threshold { 0.0 } else { x * scale }
                })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .arg(&threshold)
            .arg(&scale)
            .arg(&seed)
            .launch(cfg)?;
    }

    Ok(out)
}
12287
12288#[cfg(not(feature = "cuda"))]
12289pub fn gpu_dropout_f64(_input: &CudaBuffer<f64>, _threshold: u32, _scale: f64, _seed: u32, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
12290
#[cfg(feature = "cuda")]
/// Transposes a row-major `m x n` matrix into an `n x m` result. Falls
/// back to a host transpose when kernel compilation fails.
pub fn gpu_transpose_2d(
    input: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = m * n;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: element-by-element swap of row/column roles.
            let host = gpu_to_cpu(input, device)?;
            let mut transposed = vec![0.0f32; total];
            for row in 0..m {
                for col in 0..n {
                    transposed[col * m + row] = host[row * n + col];
                }
            }
            return cpu_to_gpu(&transposed, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12350
#[cfg(feature = "cuda")]
/// Permutes a `[d0, d1, d2, d3]` tensor to axis order `(0, 2, 1, 3)`,
/// i.e. swaps the two middle axes (the common attention-head reshuffle).
/// Falls back to a host permutation when kernel compilation fails.
pub fn gpu_permute_0213(
    input: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let stream = device.stream();
    let ctx = device.context();

    let kernel = match crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => {
            // CPU fallback: walk every coordinate and move each element to
            // its (0, 2, 1, 3)-permuted position.
            let host = gpu_to_cpu(input, device)?;
            let mut permuted = vec![0.0f32; total];
            for a in 0..d0 {
                for b in 0..d1 {
                    for c in 0..d2 {
                        for d in 0..d3 {
                            let src = ((a * d1 + b) * d2 + c) * d3 + d;
                            let dst = ((a * d2 + c) * d1 + b) * d3 + d;
                            permuted[dst] = host[src];
                        }
                    }
                }
            }
            return cpu_to_gpu(&permuted, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12423
/// Multiplies an `m x k` matrix `a` by a `k x n` matrix `b` using a simple
/// one-thread-per-output kernel, producing an `m x n` result.
///
/// When the kernel cannot be compiled, the call is routed to the general
/// BLAS-backed matmul instead of a host fallback.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let out_elems = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let Ok(f) = crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    ) else {
        // Kernel unavailable: defer to the general matmul path.
        return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
    };

    let mut c = alloc_zeros_f32(out_elems, device)?;
    let cfg = launch_cfg(out_elems)?;
    let m_u32 = m as u32;
    let k_u32 = k as u32;
    let n_u32 = n as u32;
    let out_u32 = out_elems as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(b.inner())
            .arg(c.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&out_u32)
            .launch(cfg)?;
    }

    Ok(c)
}
12482
/// Batched matrix multiply dispatcher: a batch of exactly one is routed to
/// the small single-matmul kernel, any other batch size goes to the
/// BLAS-backed batched path.
#[cfg(feature = "cuda")]
pub fn gpu_small_bmm(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    batch: usize,
    m: usize,
    k: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    match batch {
        1 => gpu_small_matmul(a, b, m, k, n, device),
        _ => crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device),
    }
}
12509
/// Looks up one embedding row of width `d` from `weight`.
///
/// `idx` carries the row index encoded as an `f32`. The output buffer has
/// length `d`, so a single row is produced per call.
///
/// NOTE(review): the CPU fallback reads only `idx_host[0]`; the kernel path
/// presumably also consumes a single index (its launch covers `d` elements)
/// — confirm against the PTX source.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: decode the (single) index and slice the row out
            // of the weight matrix.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let row = idx_host[0] as usize;
            let start = row * d;
            let out = weight_host[start..start + d].to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12563
/// Writes one `d`-wide row per batch entry from `src` (layout `(n_batch, d)`)
/// into row `pos` of `dst`, where `dst` is laid out row-major as
/// `(n_batch, max_len, d)`.
///
/// On kernel-compilation failure the update happens on the host and the whole
/// destination is re-uploaded, replacing `*dst` with a fresh buffer.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_PTX,
        "slice_write_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: copy row `pos` of each batch entry, then upload
            // the entire modified destination buffer.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    // The kernel's "n" argument is the total source element count (n_batch * d).
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
12628
/// Copies the first `len` rows of each batch entry out of a `(n_batch,
/// max_len, d)` row-major buffer into a densely packed `(n_batch, len, d)`
/// result buffer.
///
/// Falls back to a host-side gather-and-upload when the kernel cannot be
/// compiled for this device.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback: repack rows, dropping the max_len padding stride.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f32; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
12691
/// GELU activation on the GPU. The CPU fallback uses the sigmoid
/// approximation `x * sigmoid(1.702 * x)`.
#[cfg(feature = "cuda")]
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_PTX, "gelu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let sig = 1.0 / (1.0 + (-1.702 * x).exp());
            x * sig
        }),
    }
}
12708
/// GELU activation (tanh approximation) on the GPU, with a matching CPU
/// fallback: `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let k: f32 = 0.7978845608; // sqrt(2 / pi)
            let c: f32 = 0.044715;
            let inner = k * (x + c * x * x * x);
            0.5 * x * (1.0 + inner.tanh())
        }),
    }
}
12726
/// Exact (erf-based) GELU on the GPU. The CPU fallback evaluates erf via the
/// Abramowitz–Stegun 7.1.26 polynomial approximation.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let u = x * std::f32::consts::FRAC_1_SQRT_2;
            let mag = u.abs();
            // Polynomial erf approximation on |u|, sign restored afterwards.
            let r = 1.0 / (1.0 + 0.3275911 * mag);
            let poly = r * (0.254829592 + r * (-0.284496736 + r * (1.421413741 + r * (-1.453152027 + r * 1.061405429))));
            let erf_mag = 1.0 - poly * (-mag * mag).exp();
            let erf_u = if u < 0.0 { -erf_mag } else { erf_mag };
            x * 0.5 * (1.0 + erf_u)
        }),
    }
}
12748
/// Backward pass of the tanh-approximated GELU: multiplies the upstream
/// gradient by d/dx [0.5 * x * (1 + tanh(u))] with
/// u = sqrt(2/pi) * (x + 0.044715 * x^3).
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;
    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        GELU_BACKWARD_TANH_PTX,
        "gelu_backward_tanh_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback.
    let gs = gpu_to_cpu(grad, device)?;
    let xs = gpu_to_cpu(input, device)?;
    let mut result = Vec::with_capacity(gs.len());
    for (&g, &x) in gs.iter().zip(xs.iter()) {
        let k: f32 = 0.7978845608; // sqrt(2 / pi)
        let c: f32 = 0.044715;
        let c3: f32 = 0.134145; // 3 * c
        let u = k * (x + c * x * x * x);
        let t = u.tanh();
        let dt = 1.0 - t * t; // sech^2(u)
        let d_inner = k * (1.0 + c3 * x * x);
        result.push(g * (0.5 * (1.0 + t) + 0.5 * x * dt * d_inner));
    }
    cpu_to_gpu(&result, device)
}
12787
/// SiLU (swish) activation `x * sigmoid(x)` on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, SILU_PTX, "silu_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            let sig = 1.0 / (1.0 + (-x).exp());
            x * sig
        }),
    }
}
12804
/// Backward pass of SiLU: `g * (sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)))`.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        SILU_BACKWARD_PTX,
        "silu_backward_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback.
    let gs = gpu_to_cpu(grad, device)?;
    let xs = gpu_to_cpu(input, device)?;
    let mut result = Vec::with_capacity(gs.len());
    for (&g, &x) in gs.iter().zip(xs.iter()) {
        let sig = 1.0 / (1.0 + (-x).exp());
        result.push(g * (sig + x * sig * (1.0 - sig)));
    }
    cpu_to_gpu(&result, device)
}
12838
/// ELU activation: `x` when `x > 0`, otherwise `alpha * (exp(x) - 1)`.
///
/// Falls back to host computation when the kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_elu(
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(f) = crate::module_cache::get_or_compile(
        ctx,
        ELU_PTX,
        "elu_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(input, device)?;
        let mapped: Vec<f32> = host
            .iter()
            .map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
            .collect();
        return cpu_to_gpu(&mapped, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(out)
}
12893
/// Backward pass of ELU: passes the gradient through unchanged for positive
/// inputs, otherwise scales it by `alpha * exp(x)`.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    alpha: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_binary(grad, input, device)?;

    let len = grad.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(f) = crate::module_cache::get_or_compile(
        ctx,
        ELU_BACKWARD_PTX,
        "elu_backward_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let gs = gpu_to_cpu(grad, device)?;
        let xs = gpu_to_cpu(input, device)?;
        let mapped: Vec<f32> = gs
            .iter()
            .zip(xs.iter())
            .map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
            .collect();
        return cpu_to_gpu(&mapped, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&alpha)
            .launch(cfg)?;
    }

    Ok(out)
}
12948
/// Mish activation `x * tanh(softplus(x))` on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(input, device)?;
    match try_launch_unary(input, device, MISH_PTX, "mish_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(input, device, |x| {
            // softplus(x), linearized for large x to avoid exp overflow
            let softplus = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
            x * softplus.tanh()
        }),
    }
}
12965
/// Backward pass of Mish, using `g * (tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2))`
/// with `sp = softplus(x)`.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward(
    grad: &CudaBuffer<f32>,
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(grad, input, device)?;

    if let Some(out) = try_launch_binary(
        grad,
        input,
        device,
        MISH_BACKWARD_PTX,
        "mish_backward_kernel",
    )? {
        return Ok(out);
    }

    // Host fallback.
    let gs = gpu_to_cpu(grad, device)?;
    let xs = gpu_to_cpu(input, device)?;
    let mut result = Vec::with_capacity(gs.len());
    for (&g, &x) in gs.iter().zip(xs.iter()) {
        // softplus(x), linearized for large x to avoid exp overflow
        let sp = if x > 20.0 { x } else { (1.0 + x.exp()).ln() };
        let t = sp.tanh();
        let sig = 1.0 / (1.0 + (-x).exp());
        result.push(g * (t + x * sig * (1.0 - t * t)));
    }
    cpu_to_gpu(&result, device)
}
13002
/// Clamps every element into the inclusive range `[min_val, max_val]`.
#[cfg(feature = "cuda")]
pub fn gpu_clamp(
    input: &CudaBuffer<f32>,
    min_val: f32,
    max_val: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let len = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(f) = crate::module_cache::get_or_compile(
        ctx,
        CLAMP_PTX,
        "clamp_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback. Written as max-then-min (rather than f32::clamp)
        // to keep the original's behavior when min_val > max_val.
        let host = gpu_to_cpu(input, device)?;
        let clamped: Vec<f32> = host
            .iter()
            .map(|&x| x.max(min_val).min(max_val))
            .collect();
        return cpu_to_gpu(&clamped, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&len_u32)
            .arg(&min_val)
            .arg(&max_val)
            .launch(cfg)?;
    }

    Ok(out)
}
13055
/// Element-wise division `a / b`. Inputs must match in length and device.
#[cfg(feature = "cuda")]
pub fn gpu_div(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    validate_binary(a, b, device)?;

    if let Some(out) = try_launch_binary(a, b, device, DIV_PTX, "div_kernel")? {
        return Ok(out);
    }

    // Host fallback.
    let lhs = gpu_to_cpu(a, device)?;
    let rhs = gpu_to_cpu(b, device)?;
    let mut quotients = Vec::with_capacity(lhs.len());
    for (&x, &y) in lhs.iter().zip(rhs.iter()) {
        quotients.push(x / y);
    }
    cpu_to_gpu(&quotients, device)
}
13083
/// Element-wise natural exponential on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, EXP_PTX, "exp_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::exp),
    }
}
13093
/// Element-wise natural logarithm on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, LOG_PTX, "log_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::ln),
    }
}
13103
/// Element-wise square root on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::sqrt),
    }
}
13113
/// Raises every element to the power `exponent`.
#[cfg(feature = "cuda")]
pub fn gpu_pow(
    a: &CudaBuffer<f32>,
    exponent: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(a, device)?;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let Ok(f) = crate::module_cache::get_or_compile(
        ctx,
        POW_PTX,
        "pow_kernel",
        device.ordinal() as u32,
    ) else {
        // Host fallback.
        let host = gpu_to_cpu(a, device)?;
        let powered: Vec<f32> = host.iter().map(|&x| x.powf(exponent)).collect();
        return cpu_to_gpu(&powered, device);
    };

    let mut out = alloc_zeros_f32(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // Note: this kernel takes the exponent before the element count.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13159
/// Element-wise absolute value on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, ABS_PTX, "abs_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::abs),
    }
}
13169
/// Element-wise logistic sigmoid on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
13179
/// Element-wise hyperbolic tangent on the GPU, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    validate_unary(a, device)?;
    match try_launch_unary(a, device, TANH_PTX, "tanh_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary(a, device, f32::tanh),
    }
}
13189
/// Element-wise addition of two f64 buffers. The f64 PTX is derived lazily
/// from the f32 kernel source and cached for the process lifetime.
#[cfg(feature = "cuda")]
pub fn gpu_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, ADD_PTX, "add_kernel", "add_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "add_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x + y),
    }
}
13211
/// Element-wise subtraction `a - b` for f64 buffers, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, SUB_PTX, "sub_kernel", "sub_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "sub_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x - y),
    }
}
13229
/// Element-wise multiplication of two f64 buffers, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, MUL_PTX, "mul_kernel", "mul_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "mul_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x * y),
    }
}
13247
/// Element-wise division `a / b` for f64 buffers, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (la, lb) = (a.len(), b.len());
    if la != lb {
        return Err(GpuError::LengthMismatch { a: la, b: lb });
    }
    let ptx = get_f64_ptx(&CACHE, DIV_PTX, "div_kernel", "div_f64_kernel");
    match try_launch_binary_f64(a, b, device, ptx, "div_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(a, b, device, |x, y| x / y),
    }
}
13265
/// Element-wise negation of an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_neg_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, NEG_PTX, "neg_kernel", "neg_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "neg_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| -x),
    }
}
13276
/// Element-wise ReLU (`max(x, 0)`) for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_relu_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, RELU_PTX, "relu_kernel", "relu_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "relu_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| x.max(0.0)),
    }
}
13287
/// Multiplies every element of an f64 buffer by `scalar`.
#[cfg(feature = "cuda")]
pub fn gpu_scale_f64(
    a: &CudaBuffer<f64>,
    scalar: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCALE_PTX, "scale_kernel", "scale_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx, ptx, "scale_f64_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(a, device)?;
            let scaled: Vec<f64> = host.iter().map(|&x| x * scalar).collect();
            return cpu_to_gpu(&scaled, device);
        }
    };

    let mut out = alloc_zeros_f64(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13326
/// Element-wise natural exponential for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_exp_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, EXP_F64_PTX, "exp_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, f64::exp),
    }
}
13335
/// Element-wise natural logarithm for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_log_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, LOG_F64_PTX, "log_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, f64::ln),
    }
}
13344
/// Element-wise square root for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_sqrt_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, SQRT_PTX, "sqrt_kernel", "sqrt_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "sqrt_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, f64::sqrt),
    }
}
13355
/// Raises every element of an f64 buffer to the power `exponent`.
#[cfg(feature = "cuda")]
pub fn gpu_pow_f64(
    a: &CudaBuffer<f64>,
    exponent: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    let len = a.len();
    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, POW_F64_PTX, "pow_f64_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Host fallback.
            let host = gpu_to_cpu(a, device)?;
            let powered: Vec<f64> = host.iter().map(|&x| x.powf(exponent)).collect();
            return cpu_to_gpu(&powered, device);
        }
    };

    let mut out = alloc_zeros_f64(len, device)?;
    let cfg = launch_cfg(len)?;
    let len_u32 = len as u32;

    // Note: this kernel takes the exponent before the element count.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&exponent)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13392
/// Element-wise absolute value for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_abs_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, ABS_PTX, "abs_kernel", "abs_f64_kernel");
    match try_launch_unary_f64(a, device, ptx, "abs_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, f64::abs),
    }
}
13403
/// Element-wise logistic sigmoid for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, SIGMOID_F64_PTX, "sigmoid_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, |x| 1.0 / (1.0 + (-x).exp())),
    }
}
13412
/// Element-wise hyperbolic tangent for an f64 buffer, with CPU fallback.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(a, device, TANH_F64_PTX, "tanh_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_unary_f64(a, device, f64::tanh),
    }
}
13421
/// Backward pass of ReLU for f64: the gradient passes through where the
/// forward input was positive and is zeroed elsewhere.
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (lg, li) = (grad.len(), input.len());
    if lg != li {
        return Err(GpuError::LengthMismatch { a: lg, b: li });
    }
    let ptx = get_f64_ptx(&CACHE, RELU_BACKWARD_PTX, "relu_backward_kernel", "relu_backward_f64_kernel");
    match try_launch_binary_f64(grad, input, device, ptx, "relu_backward_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, input, device, |g, x| if x > 0.0 { g } else { 0.0 }),
    }
}
13449
/// Backward pass of sigmoid for f64, computed from the forward *output*:
/// `g * o * (1 - o)`.
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (lg, lo) = (grad.len(), output.len());
    if lg != lo {
        return Err(GpuError::LengthMismatch { a: lg, b: lo });
    }
    let ptx = get_f64_ptx(&CACHE, SIGMOID_BACKWARD_PTX, "sigmoid_backward_kernel", "sigmoid_backward_f64_kernel");
    match try_launch_binary_f64(grad, output, device, ptx, "sigmoid_backward_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * o * (1.0 - o)),
    }
}
13473
/// Backward pass of tanh for f64, computed from the forward *output*:
/// `g * (1 - o^2)`.
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let (lg, lo) = (grad.len(), output.len());
    if lg != lo {
        return Err(GpuError::LengthMismatch { a: lg, b: lo });
    }
    let ptx = get_f64_ptx(&CACHE, TANH_BACKWARD_PTX, "tanh_backward_kernel", "tanh_backward_f64_kernel");
    match try_launch_binary_f64(grad, output, device, ptx, "tanh_backward_f64_kernel")? {
        Some(out) => Ok(out),
        None => cpu_fallback_binary_f64(grad, output, device, |g, o| g * (1.0 - o * o)),
    }
}
13497
/// Broadcasting element-wise addition for f64 buffers; shapes follow the
/// same broadcast rules as `broadcast_strides`.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Per-input strides over the output shape (broadcast axes stride 0).
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_ADD_PTX, "broadcast_add_kernel", "broadcast_add_f64_kernel");
    match try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims_u32,
        numel,
        device,
        ptx,
        "broadcast_add_f64_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x + y),
    }
}
13535
/// Broadcasting element-wise subtraction `a - b` for f64 buffers.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Per-input strides over the output shape (broadcast axes stride 0).
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_SUB_PTX, "broadcast_sub_kernel", "broadcast_sub_f64_kernel");
    match try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims_u32,
        numel,
        device,
        ptx,
        "broadcast_sub_f64_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x - y),
    }
}
13569
/// Broadcasting element-wise multiplication for f64 buffers.
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    // Per-input strides over the output shape (broadcast axes stride 0).
    let strides_a = broadcast_strides(a_shape, out_shape);
    let strides_b = broadcast_strides(b_shape, out_shape);
    let dims_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let numel: usize = out_shape.iter().product();

    let ptx = get_f64_ptx(&CACHE, BROADCAST_MUL_PTX, "broadcast_mul_kernel", "broadcast_mul_f64_kernel");
    match try_launch_broadcast_binary_f64(
        a,
        b,
        &strides_a,
        &strides_b,
        &dims_u32,
        numel,
        device,
        ptx,
        "broadcast_mul_f64_kernel",
    )? {
        Some(out) => Ok(out),
        None => cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x * y),
    }
}
13603
#[cfg(feature = "cuda")]
/// Broadcast element-wise division (`a / b`) over f64 buffers.
///
/// Same structure as the other broadcast binaries: strides are computed for
/// both inputs against `out_shape`, the cached f64 PTX is launched when
/// possible, and a CPU fallback evaluates `x / y` otherwise. Division by
/// zero follows IEEE-754 semantics on both paths (no explicit check here).
pub fn gpu_broadcast_div_f64(
    a: &CudaBuffer<f64>,
    b: &CudaBuffer<f64>,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    // Per-output-dimension input strides (computed by `broadcast_strides`).
    let a_str = broadcast_strides(a_shape, out_shape);
    let b_str = broadcast_strides(b_shape, out_shape);
    let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
    let out_numel: usize = out_shape.iter().product();

    // Cache the f32 -> f64 PTX translation once per process.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, BROADCAST_DIV_PTX, "broadcast_div_kernel", "broadcast_div_f64_kernel");
    if let Some(out) = try_launch_broadcast_binary_f64(
        a,
        b,
        &a_str,
        &b_str,
        &shape_u32,
        out_numel,
        device,
        ptx,
        "broadcast_div_f64_kernel",
    )? {
        return Ok(out);
    }

    // GPU path unavailable: same broadcast, evaluated on the CPU.
    cpu_fallback_broadcast_binary_f64(a, b, a_shape, b_shape, out_shape, device, |x, y| x / y)
}
13637
#[cfg(feature = "cuda")]
/// Full reduction: sums every element of `a` into a single-element buffer.
///
/// Strategy: launch a block-level reduction producing up to 1024 partial
/// sums, then finish either on the host (few partials) or by recursing on
/// the partials buffer. An empty input yields a buffer containing `0.0`.
/// If PTX compilation fails, the whole sum is computed on the host.
pub fn gpu_reduce_sum_f64(
    a: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let n = a.len();
    if n == 0 {
        // Empty sum is 0 by convention; still return a device buffer.
        return cpu_to_gpu(&[0.0f64], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, REDUCE_SUM_PTX, "reduce_sum_kernel", "reduce_sum_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "reduce_sum_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: fall back to a host-side sum.
            let host = gpu_to_cpu(a, device)?;
            let total: f64 = host.iter().sum();
            return cpu_to_gpu(&[total], device);
        }
    };

    const BLOCK: u32 = 256;
    // Ceiling division; grid is capped at 1024 blocks, so each block may
    // accumulate more than BLOCK elements.
    let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
    let num_blocks = num_blocks.min(1024);

    // One partial sum per block.
    let mut partials = alloc_zeros_f64(num_blocks as usize, device)?;
    let n_u32 = n as u32;

    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks.max(1), 1, 1),
        block_dim: (BLOCK, 1, 1),
        // NOTE(review): 0 dynamic shared memory — presumably the kernel
        // declares its shared scratch statically; confirm against the PTX.
        shared_mem_bytes: 0,
    };

    // SAFETY: argument order must match the kernel signature
    // (input, partials, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(partials.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    if num_blocks <= 1 {
        // Single block: its partial IS the total.
        return Ok(partials);
    }

    if num_blocks <= 256 {
        // Few partials: cheaper to finish on the host than relaunch.
        let host_partials = gpu_to_cpu(&partials, device)?;
        let total: f64 = host_partials.iter().sum();
        return cpu_to_gpu(&[total], device);
    }

    // Many partials: reduce them recursively on the device.
    gpu_reduce_sum_f64(&partials, device)
}
13708
#[cfg(feature = "cuda")]
/// Sums `a` along one axis of a logically 3-D view `[outer, axis_size, inner]`,
/// producing an `[outer, inner]` buffer (`outer * inner` elements).
///
/// The f64 PTX is translated from the f32 kernel and cached in a `OnceLock`.
/// If PTX compilation fails, the reduction is computed on the host.
pub fn gpu_sum_axis_f64(
    a: &CudaBuffer<f64>,
    outer: usize,
    axis_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total_output = outer * inner;
    // Degenerate output: nothing to compute, so return an empty buffer
    // instead of launching a kernel over zero elements (mirrors the n == 0
    // guard in gpu_reduce_sum_f64).
    if total_output == 0 {
        return alloc_zeros_f64(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SUM_AXIS_PTX, "sum_axis_kernel", "sum_axis_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "sum_axis_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: host fallback. Each output element sums
            // `axis_size` strided elements of the input.
            let host = gpu_to_cpu(a, device)?;
            let mut result = vec![0.0f64; total_output];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / inner;
                let inner_idx = i % inner;
                let mut sum = 0.0f64;
                for k in 0..axis_size {
                    sum += host[outer_idx * axis_size * inner + k * inner + inner_idx];
                }
                *out = sum;
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total_output, device)?;
    let cfg = launch_cfg(total_output)?;
    let outer_u32 = outer as u32;
    let axis_size_u32 = axis_size as u32;
    let inner_u32 = inner as u32;
    let total_u32 = total_output as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, output, outer, axis_size, inner, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&axis_size_u32)
            .arg(&inner_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13770
13771#[cfg(not(feature = "cuda"))]
13772pub fn gpu_reduce_sum_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13773#[cfg(not(feature = "cuda"))]
13774pub fn gpu_sum_axis_f64(_a: &CudaBuffer<f64>, _outer: usize, _axis_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
13775
#[cfg(feature = "cuda")]
/// Transposes a row-major `m x n` f64 matrix into `n x m`.
///
/// Validates that `input` lives on `device`, then launches the f64
/// transpose kernel (PTX translated from the f32 kernel, cached in a
/// `OnceLock`). If PTX compilation fails, the transpose is done on the host.
pub fn gpu_transpose_2d_f64(
    input: &CudaBuffer<f64>,
    m: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let total = m * n;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, TRANSPOSE_2D_PTX, "transpose_2d_kernel", "transpose_2d_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "transpose_2d_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: host fallback — out[j][i] = in[i][j].
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for i in 0..m {
                for j in 0..n {
                    out[j * m + i] = host[i * n + j];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let m_u32 = m as u32;
    let n_u32 = n as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, output, m, n, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13836
#[cfg(feature = "cuda")]
/// Permutes a `[d0, d1, d2, d3]` f64 tensor to `[d0, d2, d1, d3]`
/// (the (0, 2, 1, 3) axis permutation used for attention-head reshaping).
///
/// Launches the cached f64 kernel when available; otherwise performs the
/// permutation on the host via explicit 4-level index arithmetic.
pub fn gpu_permute_0213_f64(
    input: &CudaBuffer<f64>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, PERMUTE_0213_PTX, "permute_0213_kernel", "permute_0213_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "permute_0213_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: host fallback. Input is row-major
            // [d0,d1,d2,d3]; output swaps axes 1 and 2.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for i0 in 0..d0 {
                for i1 in 0..d1 {
                    for i2 in 0..d2 {
                        for i3 in 0..d3 {
                            let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                            let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                            out[out_idx] = host[in_idx];
                        }
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let d0_u32 = d0 as u32;
    let d1_u32 = d1 as u32;
    let d2_u32 = d2 as u32;
    let d3_u32 = d3 as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, output, d0, d1, d2, d3, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&d0_u32)
            .arg(&d1_u32)
            .arg(&d2_u32)
            .arg(&d3_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13905
#[cfg(feature = "cuda")]
/// Extracts a contiguous slice of size `split_size` (starting at
/// `split_offset`) along an axis of length `total_along_axis`, with
/// `inner_size` trailing elements per axis step. `n` is the number of
/// elements in the resulting output buffer.
///
/// Launches the cached f64 kernel when available; otherwise gathers on the
/// host using the same stride arithmetic.
pub fn gpu_strided_split_f64(
    input: &CudaBuffer<f64>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, STRIDED_SPLIT_PTX, "strided_split_kernel", "strided_split_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_split_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: host fallback. For output element i, map its
            // (outer, within-slice) coordinates back to the source index.
            let host = gpu_to_cpu(input, device)?;
            let mut result = vec![0.0f64; n];
            for (i, out) in result.iter_mut().enumerate() {
                let outer_idx = i / (split_size * inner_size);
                let within = i % (split_size * inner_size);
                let src_idx =
                    outer_idx * total_along_axis * inner_size + split_offset * inner_size + within;
                *out = host[src_idx];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = split_offset as u32;
    let split_sz_u32 = split_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, output, total_along_axis, offset, split_size, inner, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&split_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
13970
#[cfg(feature = "cuda")]
/// Inverse of `gpu_strided_split_f64`: scatters `input` (a part of size
/// `part_size` along an axis of total length `total_along_axis`, starting at
/// `cat_offset`) into `output` in place. `n` is the number of input elements
/// to copy.
///
/// NOTE(review): only `input` is device-validated here; `output` is not —
/// siblings validate their primary buffer, so this may be intentional, but
/// worth confirming.
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat_f64(
    input: &CudaBuffer<f64>,
    output: &mut CudaBuffer<f64>,
    total_along_axis: usize,
    cat_offset: usize,
    part_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    validate_device(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, STRIDED_CAT_PTX, "strided_cat_kernel", "strided_cat_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "strided_cat_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: host fallback. Round-trips the output buffer
            // through host memory, writes the part, and replaces `output`.
            let host_in = gpu_to_cpu(input, device)?;
            let mut host_out = gpu_to_cpu(output, device)?;
            for (i, &val) in host_in.iter().enumerate().take(n) {
                let outer_idx = i / (part_size * inner_size);
                let within = i % (part_size * inner_size);
                let dst_idx =
                    outer_idx * total_along_axis * inner_size + cat_offset * inner_size + within;
                host_out[dst_idx] = val;
            }
            *output = cpu_to_gpu(&host_out, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let total_ax_u32 = total_along_axis as u32;
    let offset_u32 = cat_offset as u32;
    let part_sz_u32 = part_size as u32;
    let inner_u32 = inner_size as u32;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, output, total_along_axis, offset, part_size, inner, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(output.inner_mut())
            .arg(&total_ax_u32)
            .arg(&offset_u32)
            .arg(&part_sz_u32)
            .arg(&inner_u32)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14037
#[cfg(feature = "cuda")]
/// Gathers `input[indices[i]]` for every index, producing a buffer the same
/// length as `indices`. Indices are stored as f32 and truncated to usize.
///
/// NOTE(review): the host fallback panics on an out-of-range index (direct
/// slice indexing); the GPU kernel's bounds behavior is not visible here.
pub fn gpu_index_select_1d_f64(
    input: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n = indices.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, INDEX_SELECT_1D_PTX, "index_select_1d_kernel", "index_select_1d_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "index_select_1d_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: gather on the host.
            let input_host = gpu_to_cpu(input, device)?;
            let indices_host = gpu_to_cpu(indices, device)?;
            let result: Vec<f64> = indices_host
                .iter()
                .map(|&idx_f| input_host[idx_f as usize])
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, indices, output, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14093
#[cfg(feature = "cuda")]
/// Backward of `gpu_index_select_1d_f64`: accumulates `grad_output[i]` into
/// position `indices[i]` of a zeroed buffer of length `input_len`.
/// Duplicate indices accumulate (addition).
pub fn gpu_scatter_add_1d_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    input_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad_output, device)?;

    let n = grad_output.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_1D_PTX, "scatter_add_1d_kernel", "scatter_add_1d_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_1d_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: accumulate on the host.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f64; input_len];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                result[idx_f as usize] += go_host[i];
            }
            return cpu_to_gpu(&result, device);
        }
    };

    // Output starts zeroed; the kernel adds into it.
    let mut out = alloc_zeros_f64(input_len, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (grad_output, indices, output, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14148
#[cfg(feature = "cuda")]
/// Returns a copy of `input` where every element whose mask byte is nonzero
/// is replaced by `value` (elements with mask == 0 pass through).
pub fn gpu_masked_fill_f64(
    input: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    value: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    let n = input.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, MASKED_FILL_PTX, "masked_fill_kernel", "masked_fill_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_fill_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: apply the mask on the host.
            let input_host = gpu_to_cpu(input, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f64> = input_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&x, &m)| if m != 0 { value } else { x })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, mask, output, fill value, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14206
#[cfg(feature = "cuda")]
/// Backward companion of `gpu_masked_fill_f64`: zeroes every gradient
/// element whose mask byte is nonzero; unmasked gradients pass through.
pub fn gpu_masked_zero_f64(
    grad: &CudaBuffer<f64>,
    mask: &CudaBuffer<u8>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(grad, device)?;

    let n = grad.len();
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, MASKED_ZERO_PTX, "masked_zero_kernel", "masked_zero_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "masked_zero_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: apply the mask on the host.
            let grad_host = gpu_to_cpu(grad, device)?;
            let mask_host = gpu_to_cpu(mask, device)?;
            let result: Vec<f64> = grad_host
                .iter()
                .zip(mask_host.iter())
                .map(|(&g, &m)| if m != 0 { 0.0 } else { g })
                .collect();
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(n, device)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (grad, mask, output, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(mask.inner())
            .arg(out.inner_mut())
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14261
#[cfg(feature = "cuda")]
/// Writes one "time step" into a preallocated sequence buffer: for each of
/// `n_batch` rows, copies `d` values from `src` into `dst` at step `pos` of
/// a `[n_batch, max_len, d]` layout.
///
/// NOTE(review): unlike most siblings, this does not call `validate_device`
/// on its buffers — confirm whether that omission is intentional.
pub fn gpu_slice_write_f64(
    src: &CudaBuffer<f64>,
    dst: &mut CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SLICE_WRITE_PTX, "slice_write_kernel", "slice_write_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_write_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: round-trip `dst` through host memory, patch
            // the slice, and replace the device buffer.
            let src_host = gpu_to_cpu(src, device)?;
            let mut dst_host = gpu_to_cpu(dst, device)?;
            for b in 0..n_batch {
                for di in 0..d {
                    dst_host[b * max_len * d + pos * d + di] = src_host[b * d + di];
                }
            }
            let new_dst = cpu_to_gpu(&dst_host, device)?;
            *dst = new_dst;
            return Ok(());
        }
    };

    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    let pos_u32 = pos as u32;

    // SAFETY: argument order must match the kernel signature
    // (src, dst, total, d, max_len, pos).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(&pos_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14322
#[cfg(feature = "cuda")]
/// Reads the first `len` steps of a `[n_batch, max_len, d]` sequence buffer
/// into a packed `[n_batch, len, d]` buffer (inverse of `gpu_slice_write_f64`
/// for a prefix of steps).
pub fn gpu_slice_read_f64(
    src: &CudaBuffer<f64>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SLICE_READ_PTX, "slice_read_kernel", "slice_read_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "slice_read_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: copy the prefix rows on the host.
            let host = gpu_to_cpu(src, device)?;
            let mut out = vec![0.0f64; total];
            for b in 0..n_batch {
                for r in 0..len {
                    for di in 0..d {
                        out[b * len * d + r * d + di] = host[b * max_len * d + r * d + di];
                    }
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;

    // SAFETY: argument order must match the kernel signature
    // (src, out, total, d, len, max_len).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14383
#[cfg(feature = "cuda")]
/// Single-token embedding lookup: copies row `idx[0]` (width `d`) out of the
/// `weight` matrix.
///
/// NOTE(review): the host fallback reads `idx_host[0]` and will panic if
/// `idx` is empty — callers presumably guarantee a single index; confirm.
pub fn gpu_embed_lookup_f64(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, EMBED_LOOKUP_PTX, "embed_lookup_kernel", "embed_lookup_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: slice the row out on the host.
            let idx_host = gpu_to_cpu(idx, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let row = idx_host[0] as usize;
            let start = row * d;
            let out = weight_host[start..start + d].to_vec();
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(d, device)?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;

    // SAFETY: argument order must match the kernel signature
    // (idx, weight, out, d).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14436
#[cfg(feature = "cuda")]
/// Batched embedding lookup: for each of the `n` indices, copies the
/// corresponding width-`d` row of `weight` into a packed `[n, d]` output.
/// An empty batch returns an empty buffer without launching.
pub fn gpu_embed_lookup_batch_f64(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f64>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let total = n * d;
    if total == 0 {
        // Nothing to gather; avoid a zero-sized launch.
        return alloc_zeros_f64(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel", "embed_lookup_batch_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "embed_lookup_batch_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: gather rows on the host.
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            for &idx_f in &idx_host {
                let row = idx_f as usize;
                let start = row * d;
                out.extend_from_slice(&weight_host[start..start + d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel signature
    // (indices, weight, out, d, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14496
#[cfg(feature = "cuda")]
/// Backward of the batched embedding lookup: accumulates each width-`d` row
/// of `grad_output` into row `indices[i]` of a zeroed
/// `[num_embeddings, d]` gradient buffer. Duplicate indices accumulate.
pub fn gpu_scatter_add_rows_f64(
    grad_output: &CudaBuffer<f64>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        // No gradients to scatter: return the zeroed weight-shaped buffer.
        return alloc_zeros_f64(num_embeddings * d, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel", "scatter_add_rows_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "scatter_add_rows_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: accumulate rows on the host.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f64; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                let row = idx_f as usize;
                for j in 0..d {
                    result[row * d + j] += go_host[i * d + j];
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    // Output starts zeroed; the kernel adds into it.
    let mut out = alloc_zeros_f64(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel signature
    // (grad_output, indices, out, d, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
14561
#[cfg(feature = "cuda")]
/// Fused Adam/AdamW update step, performed in place on `param`, `exp_avg`
/// (first moment) and `exp_avg_sq` (second moment).
///
/// `bc1`/`bc2` are the bias-correction denominators (1 - beta^t), already
/// computed by the caller for the current step. When `weight_decay > 0` the
/// decay is added to the gradient before the moment updates (classic Adam
/// L2 style, as the CPU fallback shows). All four buffers must have equal
/// length or `GpuError::LengthMismatch` is returned. If PTX compilation
/// fails, the identical update runs on the host and the buffers are
/// replaced with the updated copies.
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
    param: &mut CudaBuffer<f32>,
    grad: &CudaBuffer<f32>,
    exp_avg: &mut CudaBuffer<f32>,
    exp_avg_sq: &mut CudaBuffer<f32>,
    beta1: f32,
    beta2: f32,
    lr: f32,
    eps: f32,
    bc1: f32,
    bc2: f32,
    weight_decay: f32,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let n = param.len();
    if grad.len() != n || exp_avg.len() != n || exp_avg_sq.len() != n {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: grad.len(),
        });
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_ADAM_PTX,
        "fused_adam_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // Compile failure: run the same update on the host.
            let mut p_host = gpu_to_cpu(param, device)?;
            let g_host = gpu_to_cpu(grad, device)?;
            let mut m_host = gpu_to_cpu(exp_avg, device)?;
            let mut v_host = gpu_to_cpu(exp_avg_sq, device)?;

            for i in 0..n {
                let mut g = g_host[i];
                // L2-style weight decay folded into the gradient.
                if weight_decay > 0.0 {
                    g += weight_decay * p_host[i];
                }
                // Exponential moving averages of g and g^2.
                m_host[i] = beta1 * m_host[i] + (1.0 - beta1) * g;
                v_host[i] = beta2 * v_host[i] + (1.0 - beta2) * g * g;
                // Bias-corrected moments, then the parameter step.
                let m_hat = m_host[i] / bc1;
                let v_hat = v_host[i] / bc2;
                p_host[i] -= lr * m_hat / (v_hat.sqrt() + eps);
            }

            *param = cpu_to_gpu(&p_host, device)?;
            *exp_avg = cpu_to_gpu(&m_host, device)?;
            *exp_avg_sq = cpu_to_gpu(&v_host, device)?;
            return Ok(());
        }
    };

    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: argument order must match the kernel signature
    // (param, grad, m, v, beta1, beta2, lr, eps, bc1, bc2, wd, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(param.inner_mut())
            .arg(grad.inner())
            .arg(exp_avg.inner_mut())
            .arg(exp_avg_sq.inner_mut())
            .arg(&beta1)
            .arg(&beta2)
            .arg(&lr)
            .arg(&eps)
            .arg(&bc1)
            .arg(&bc2)
            .arg(&weight_decay)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(())
}
14656
14657#[cfg(not(feature = "cuda"))]
14659#[allow(clippy::too_many_arguments)]
14660pub fn gpu_fused_adam(
14661 _param: &mut CudaBuffer<f32>,
14662 _grad: &CudaBuffer<f32>,
14663 _exp_avg: &mut CudaBuffer<f32>,
14664 _exp_avg_sq: &mut CudaBuffer<f32>,
14665 _beta1: f32,
14666 _beta2: f32,
14667 _lr: f32,
14668 _eps: f32,
14669 _bc1: f32,
14670 _bc2: f32,
14671 _weight_decay: f32,
14672 _device: &GpuDevice,
14673) -> GpuResult<()> {
14674 Err(GpuError::NoCudaFeature)
14675}
14676
#[cfg(feature = "cuda")]
/// Fused GRU cell forward pass.
///
/// `input_gates` / `hidden_gates` are the pre-activation gate projections,
/// `bias_ih` / `bias_hh` the corresponding biases, and `hx` the previous
/// hidden state (`batch * hsz` elements). Returns the new hidden state `hy`
/// (same length as `hx`) and a `batch * 5 * hsz` workspace buffer the kernel
/// fills (kept for the backward pass). There is no CPU fallback: a PTX
/// compile failure is reported as `GpuError::PtxCompileFailed`.
pub fn gpu_fused_gru_forward(
    input_gates: &CudaBuffer<f32>,
    hidden_gates: &CudaBuffer<f32>,
    bias_ih: &CudaBuffer<f32>,
    bias_hh: &CudaBuffer<f32>,
    hx: &CudaBuffer<f32>,
    hsz: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    let total = hx.len();
    // Guard degenerate sizes: `total / hsz` below panics on hsz == 0, and a
    // zero-element launch does no useful work. Return empty outputs instead.
    if hsz == 0 || total == 0 {
        return Ok((alloc_zeros_f32(0, device)?, alloc_zeros_f32(0, device)?));
    }
    let batch = total / hsz;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        FUSED_GRU_FORWARD_PTX,
        "fused_gru_forward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: "fused_gru_forward_kernel",
            });
        }
    };

    let mut hy = alloc_zeros_f32(total, device)?;
    // 5 scratch values per hidden unit, per batch row.
    let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;

    let cfg = launch_cfg(total)?;
    let hsz_u32 = hsz as u32;
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel signature
    // (input_gates, hidden_gates, bias_ih, bias_hh, hx, hy, workspace,
    //  hsz, total).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input_gates.inner())
            .arg(hidden_gates.inner())
            .arg(bias_ih.inner())
            .arg(bias_hh.inner())
            .arg(hx.inner())
            .arg(hy.inner_mut())
            .arg(workspace.inner_mut())
            .arg(&hsz_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((hy, workspace))
}
14750
14751#[cfg(not(feature = "cuda"))]
14753pub fn gpu_fused_gru_forward(
14754 _input_gates: &CudaBuffer<f32>,
14755 _hidden_gates: &CudaBuffer<f32>,
14756 _bias_ih: &CudaBuffer<f32>,
14757 _bias_hh: &CudaBuffer<f32>,
14758 _hx: &CudaBuffer<f32>,
14759 _hsz: usize,
14760 _device: &GpuDevice,
14761) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
14762 Err(GpuError::NoCudaFeature)
14763}
14764
#[cfg(feature = "cuda")]
/// 2-D max pooling forward over an NCHW f32 tensor.
///
/// Kernel `(kh, kw)`, stride `(sh, sw)`, padding `(ph, pw)`. Returns the
/// pooled buffer and its `[batch, channels, h_out, w_out]` shape. No CPU
/// fallback: a PTX compile failure is reported as an error.
///
/// NOTE(review): `h_in + 2 * ph - kh` underflows usize if the kernel is
/// larger than the padded input — callers presumably validate sizes first;
/// confirm.
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Standard pooling output-size formula (floor division).
    let h_out = (h_in + 2 * ph - kh) / sh + 1;
    let w_out = (w_in + 2 * pw - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, MAXPOOL2D_PTX, "maxpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "maxpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    // SAFETY: argument order must match the kernel signature
    // (input, out, batch, ch, h_in, w_in, h_out, w_out, kh, kw, sh, sw,
    //  ph, pw, total).
    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
14829
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
14841
/// GPU 2D average-pooling forward pass over an NCHW `input` of logical shape
/// `[batch, channels, h_in, w_in]`.
///
/// `kh`/`kw` are the pooling-window height/width, `sh`/`sw` the strides and
/// `ph`/`pw` the zero-padding per spatial edge. Returns the pooled buffer
/// together with its output shape `[batch, channels, h_out, w_out]`.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] for degenerate configurations
/// (zero-sized kernel or stride, or a kernel larger than the padded input)
/// and [`GpuError::PtxCompileFailed`] when the kernel cannot be compiled.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    input: &CudaBuffer<f32>,
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    kh: usize,
    kw: usize,
    sh: usize,
    sw: usize,
    ph: usize,
    pw: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    use cudarc::driver::PushKernelArg;

    // Validate up front: the previous unchecked `h_in + 2*ph - kh` wraps
    // around in release builds when the kernel exceeds the padded input,
    // and a zero stride would divide by zero below.
    let h_eff = h_in + 2 * ph;
    let w_eff = w_in + 2 * pw;
    if kh == 0 || kw == 0 || sh == 0 || sw == 0 || h_eff < kh || w_eff < kw {
        return Err(GpuError::ShapeMismatch {
            op: "avgpool2d",
            expected: vec![kh, kw, sh, sw],
            got: vec![h_in, w_in, ph, pw],
        });
    }
    let h_out = (h_eff - kh) / sh + 1;
    let w_out = (w_eff - kw) / sw + 1;
    let total = batch * channels * h_out * w_out;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx, AVGPOOL2D_PTX, "avgpool2d_forward_kernel", device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "avgpool2d_forward_kernel" }),
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;

    // Kernel parameters are u32 scalars; keep them in named locals so the
    // references passed to the launch builder stay alive until launch.
    let (batch_u32, ch_u32) = (batch as u32, channels as u32);
    let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
    let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
    let (kh_u32, kw_u32) = (kh as u32, kw as u32);
    let (sh_u32, sw_u32) = (sh as u32, sw as u32);
    let (ph_u32, pw_u32) = (ph as u32, pw as u32);
    let total_u32 = total as u32;

    unsafe {
        stream.launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&batch_u32).arg(&ch_u32)
            .arg(&h_in_u32).arg(&w_in_u32)
            .arg(&h_out_u32).arg(&w_out_u32)
            .arg(&kh_u32).arg(&kw_u32)
            .arg(&sh_u32).arg(&sw_u32)
            .arg(&ph_u32).arg(&pw_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok((out, [batch, channels, h_out, w_out]))
}
14902
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
    _input: &CudaBuffer<f32>, _batch: usize, _channels: usize,
    _h_in: usize, _w_in: usize, _kh: usize, _kw: usize,
    _sh: usize, _sw: usize, _ph: usize, _pw: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
    Err(GpuError::NoCudaFeature)
}
14914
/// GPU batch-normalization forward pass — currently an unimplemented stub.
///
/// This CUDA-side placeholder only attempts to compile (and thereby cache)
/// the `batchnorm_forward_kernel` PTX, ignores the result, and then fails
/// unconditionally. The returned [`GpuError::ShapeMismatch`] with dummy
/// shapes `[0]`/`[1]` is a sentinel meaning "not implemented", not a real
/// shape problem — presumably so callers fall back to a CPU path; confirm
/// against call sites before implementing.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    let ctx = device.context();
    // Warm the module cache; the compiled function is deliberately unused.
    let _f = crate::module_cache::get_or_compile(
        ctx,
        BATCHNORM_FORWARD_PTX,
        "batchnorm_forward_kernel",
        device.ordinal() as u32,
    );
    // Sentinel "unimplemented" error (see doc comment above).
    Err(GpuError::ShapeMismatch {
        op: "batchnorm_forward",
        expected: vec![0],
        got: vec![1],
    })
}
14952
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _bias: &CudaBuffer<f32>,
    _running_mean: &mut CudaBuffer<f32>,
    _running_var: &mut CudaBuffer<f32>,
    _channels: usize,
    _spatial: usize,
    _eps: f32,
    _momentum: f32,
    _training: bool,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
14971
/// Row-wise layer normalization: for each of `rows` rows of length `cols`,
/// normalizes to zero mean / unit variance (with `eps` added to the variance
/// before the square root) and applies the per-column affine
/// `weight[c] * x_hat + bias[c]`.
///
/// If PTX compilation fails, the PTX is dumped to `/tmp/layernorm_debug.ptx`
/// for debugging and the same result is computed on the CPU by round-tripping
/// all three buffers through host memory.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // CPU fallback: identical math to the kernel, done on the host.
            eprintln!("ferrotorch-gpu: LayerNorm PTX compilation failed ({e:?}), CPU fallback");
            std::fs::write("/tmp/layernorm_debug.ptx", LAYERNORM_PTX).ok();
            eprintln!(
                "ferrotorch-gpu: dumped PTX to /tmp/layernorm_debug.ptx ({} bytes)",
                LAYERNORM_PTX.len()
            );
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f32 = slice.iter().sum::<f32>() / cols as f32;
                let var: f32 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 floats of shared scratch.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15056
/// Backward pass for row-wise layer normalization.
///
/// Given `input` (the forward activations), `grad_output` and `weight`,
/// produces `(grad_input, grad_weight, grad_bias)` where, per row with
/// `dl = grad_output * weight` and `x_hat` the normalized input:
/// `grad_input = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat))`,
/// `grad_weight[c] = sum_r grad_output * x_hat`, and
/// `grad_bias[c] = sum_r grad_output` (formulas as implemented by the CPU
/// fallback below).
///
/// If PTX compilation fails, the gradients are computed on the CPU and
/// uploaded back to the device.
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_BACKWARD_PTX,
        "layernorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: standard LayerNorm backward, accumulated per row.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let mut grad_bias = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f32 = x_slice.iter().sum::<f32>() / n_f;
                let var: f32 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f32>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                let mut sum1 = 0.0f32;
                let mut sum2 = 0.0f32;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let mut grad_b = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 floats of shared scratch.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
15168
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
15182
/// Row-wise RMS normalization: for each of `rows` rows of length `cols`,
/// computes `out = x * weight / sqrt(mean(x^2) + eps)`.
///
/// If PTX compilation fails, the PTX is dumped to `/tmp/rmsnorm_debug.ptx`
/// for debugging and the same result is computed on the CPU by round-tripping
/// the buffers through host memory.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_PTX,
        "rmsnorm_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(e) => {
            // CPU fallback: identical math to the kernel, done on the host.
            eprintln!("ferrotorch-gpu: RMSNorm PTX compilation failed ({e:?}), CPU fallback");
            std::fs::write("/tmp/rmsnorm_debug.ptx", RMSNORM_PTX).ok();
            eprintln!(
                "ferrotorch-gpu: dumped PTX to /tmp/rmsnorm_debug.ptx ({} bytes)",
                RMSNORM_PTX.len()
            );
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f32; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f32 =
                    slice.iter().map(|&x| x * x).sum::<f32>() / cols as f32;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 floats of shared scratch.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
15265
/// Backward pass for row-wise RMS normalization.
///
/// Given `input` (the forward activations), `grad_output` and `weight`,
/// produces `(grad_input, grad_weight)` where, per row with
/// `inv_rms = 1/sqrt(mean(x^2) + eps)`:
/// `grad_input = inv_rms * weight * grad_output
///             - x * (sum(grad_output * x * weight) * inv_rms^3 / cols)`,
/// and `grad_weight[c] = sum_r grad_output * x * inv_rms` (formulas as
/// implemented by the CPU fallback below).
///
/// If PTX compilation fails, the gradients are computed on the CPU and
/// uploaded back to the device.
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward(
    input: &CudaBuffer<f32>,
    grad_output: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    use cudarc::driver::PushKernelArg;

    validate_unary(input, device)?;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        RMSNORM_BACKWARD_PTX,
        "rmsnorm_backward_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: same gradient formulas, accumulated per row.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f32; rows * cols];
            let mut grad_weight = vec![0.0f32; cols];
            let n_f = cols as f32;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f32 =
                    x_slice.iter().map(|&x| x * x).sum::<f32>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                let mut dot = 0.0f32;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f32(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row, 256 threads per block, 256 floats of shared scratch.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 4,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
15363
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm(
    _input: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
15376
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_backward(
    _input: &CudaBuffer<f32>,
    _grad_output: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _eps: f32,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}
15390
/// Elementwise addition into a caller-provided buffer: `out[i] = a[i] + b[i]`
/// for `i < a.len()`.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] when `out` cannot hold `a.len()`
/// elements, and [`GpuError::PtxCompileFailed`] when the add kernel cannot
/// be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_add_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;

    // The kernel writes `a.len()` elements; the destination must be at
    // least that large.
    if out.len() < a.len() {
        return Err(GpuError::ShapeMismatch {
            op: "add_into",
            expected: vec![a.len()],
            got: vec![out.len()],
        });
    }

    if try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "add_kernel",
        })
    }
}
15422
/// Elementwise multiplication into a caller-provided buffer:
/// `out[i] = a[i] * b[i]` for `i < a.len()`.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] when `out` cannot hold `a.len()`
/// elements, and [`GpuError::PtxCompileFailed`] when the mul kernel cannot
/// be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_mul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_binary(a, b, device)?;

    // The kernel writes `a.len()` elements; the destination must be at
    // least that large.
    if out.len() < a.len() {
        return Err(GpuError::ShapeMismatch {
            op: "mul_into",
            expected: vec![a.len()],
            got: vec![out.len()],
        });
    }

    if try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "mul_kernel",
        })
    }
}
15446
/// Scales `a` by `scalar` into a caller-provided buffer:
/// `out[i] = a[i] * scalar` for `i < a.len()`.
///
/// # Errors
/// Returns [`GpuError::ShapeMismatch`] when `out` cannot hold `a.len()`
/// elements (the kernel writes `a.len()` values into `out`; this guard
/// matches `gpu_add_into`/`gpu_mul_into`, which already had it), and
/// [`GpuError::PtxCompileFailed`] when the scale kernel cannot be compiled.
#[cfg(feature = "cuda")]
pub fn gpu_scale_into(
    a: &CudaBuffer<f32>,
    scalar: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    validate_unary(a, device)?;
    let n = a.len();
    // Destination-length guard (previously missing): without it the kernel
    // would write past the end of a too-small `out`.
    if out.len() < n {
        return Err(GpuError::ShapeMismatch {
            op: "scale_into",
            expected: vec![n],
            got: vec![out.len()],
        });
    }
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SCALE_PTX,
        "scale_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "scale_kernel",
    })?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&scalar)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15482
/// Reports whether `a` contains any non-finite value (`inf`, `-inf`, NaN).
///
/// An empty buffer trivially contains no bad values. The check currently
/// copies the whole buffer back to host memory and scans it on the CPU.
#[cfg(feature = "cuda")]
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
    if a.len() == 0 {
        return Ok(false);
    }

    validate_unary(a, device)?;

    let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
    Ok(!host.iter().all(|v| v.is_finite()))
}
15511
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
    Err(GpuError::NoCudaFeature)
}
15517
/// Applies the GELU activation elementwise, writing the result into the
/// caller-provided `out` buffer.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] when the GELU kernel cannot be
/// compiled, or any error raised by input validation / the launch itself.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_into(
    a: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    validate_unary(a, device)?;
    if try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")? {
        Ok(())
    } else {
        Err(GpuError::PtxCompileFailed {
            kernel: "gelu_kernel",
        })
    }
}
15533
/// Copies one embedding row of width `d` from `weight` into `out`, with the
/// row index read from the device buffer `idx` (stored as `f32`; presumably
/// a single element — the kernel receives only `d`, no index count; TODO
/// confirm against the kernel source).
///
/// NOTE(review): no `validate_*` call and no bounds check on the index here,
/// unlike most siblings — an out-of-range index is handled (or not) by the
/// kernel itself; verify.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_into(
    idx: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    d: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_PTX,
        "embed_lookup_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "embed_lookup_kernel",
    })?;
    let cfg = launch_cfg(d)?;
    let d_u32 = d as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(idx.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15568
/// Batched embedding lookup: gathers `n` rows of width `d` from `weight`
/// using `indices` (row indices stored as `f32`), returning an `n * d`
/// buffer in index order.
///
/// Returns an empty buffer when `n * d == 0`. If PTX compilation fails, the
/// gather is performed on the CPU instead.
///
/// NOTE(review): the CPU fallback slices `weight_host[start..start + d]`
/// without a bounds check, so an out-of-range index panics rather than
/// returning an error — confirm indices are pre-validated by callers.
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch(
    indices: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    n: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let total = n * d;
    if total == 0 {
        return alloc_zeros_f32(0, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        EMBED_LOOKUP_BATCH_PTX,
        "embed_lookup_batch_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: gather rows on the host and upload the result.
            let idx_host = gpu_to_cpu(indices, device)?;
            let weight_host = gpu_to_cpu(weight, device)?;
            let mut out = Vec::with_capacity(total);
            for &idx_f in &idx_host {
                let row = idx_f as usize;
                let start = row * d;
                out.extend_from_slice(&weight_host[start..start + d]);
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f32(total, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(indices.inner())
            .arg(weight.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
15633
/// Embedding backward: scatter-adds each `d`-wide row of `grad_output` into
/// the output row selected by the matching entry of `indices` (row indices
/// stored as `f32`), producing a `num_embeddings * d` gradient buffer.
/// Duplicate indices accumulate.
///
/// Returns an all-zeros buffer when there is no work. If PTX compilation
/// fails, the scatter-add is performed on the CPU instead.
///
/// NOTE(review): the CPU fallback indexes `result[row * d + j]` without a
/// bounds check, so an index >= `num_embeddings` panics — confirm indices
/// are pre-validated by callers.
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows(
    grad_output: &CudaBuffer<f32>,
    indices: &CudaBuffer<f32>,
    num_embeddings: usize,
    d: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    use cudarc::driver::PushKernelArg;

    let n = indices.len();
    let total = n * d;

    if total == 0 {
        return alloc_zeros_f32(num_embeddings * d, device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SCATTER_ADD_ROWS_PTX,
        "scatter_add_rows_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: accumulate rows on the host and upload the result.
            let go_host = gpu_to_cpu(grad_output, device)?;
            let idx_host = gpu_to_cpu(indices, device)?;
            let mut result = vec![0.0f32; num_embeddings * d];
            for (i, &idx_f) in idx_host.iter().enumerate() {
                let row = idx_f as usize;
                for j in 0..d {
                    result[row * d + j] += go_host[i * d + j];
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
    let cfg = launch_cfg(total)?;
    let d_u32 = d as u32;
    let total_u32 = total as u32;

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad_output.inner())
            .arg(indices.inner())
            .arg(out.inner_mut())
            .arg(&d_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
15703
/// Transposes a row-major `m x n` matrix `a` into the caller-provided `out`
/// buffer (which then holds the `n x m` transpose).
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] when the transpose kernel cannot
/// be compiled, or any error raised by the launch itself.
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_into(
    a: &CudaBuffer<f32>,
    m: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let func = crate::module_cache::get_or_compile(
        ctx,
        TRANSPOSE_2D_PTX,
        "transpose_2d_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "transpose_2d_kernel",
    })?;

    // One launch covering every element of the m*n matrix.
    let elem_count = m * n;
    let cfg = launch_cfg(elem_count)?;
    let (rows_u32, cols_u32, count_u32) = (m as u32, n as u32, elem_count as u32);

    unsafe {
        stream
            .launch_builder(&func)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&count_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15742
/// Permutes a 4-D tensor of shape `[d0, d1, d2, d3]` into `out`, reordering
/// axes `(0, 1, 2, 3) -> (0, 2, 1, 3)` per the kernel's `0213` naming
/// (the common attention head reshuffle).
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_into(
    a: &CudaBuffer<f32>,
    d0: usize,
    d1: usize,
    d2: usize,
    d3: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = d0 * d1 * d2 * d3;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        PERMUTE_0213_PTX,
        "permute_0213_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "permute_0213_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&d0u)
            .arg(&d1u)
            .arg(&d2u)
            .arg(&d3u)
            .arg(&tu)
            .launch(cfg)?;
    }
    Ok(())
}
15783
/// Row-wise softmax over a `rows x cols` matrix, written into `out`.
///
/// Launches one 256-thread block per row with `cols * 4` bytes of dynamic
/// shared memory — i.e. the kernel stages a full row in shared memory, so
/// `cols` must fit within the device's per-block shared-memory limit
/// (NOTE(review): not checked here; large `cols` will fail at launch).
#[cfg(feature = "cuda")]
pub fn gpu_softmax_into(
    a: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_PTX,
        "softmax_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "softmax_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15825
/// Row-wise layer normalization into a caller-provided `out` buffer; same
/// math as [`gpu_layernorm`] but without the CPU fallback or allocation.
///
/// NOTE(review): this launch requests `cols * 4` bytes of dynamic shared
/// memory while `gpu_layernorm` requests `256 * 4` for the very same
/// `layernorm_kernel` — one of the two sizes is presumably wrong (or the
/// kernel only uses min(256, cols) floats); confirm against the PTX.
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
    input: &CudaBuffer<f32>,
    weight: &CudaBuffer<f32>,
    bias: &CudaBuffer<f32>,
    rows: usize,
    cols: usize,
    eps: f32,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        LAYERNORM_PTX,
        "layernorm_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "layernorm_kernel",
    })?;
    let block_size = 256u32;
    let grid_size = rows as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: (cols as u32) * 4,
    };
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }
    Ok(())
}
15874
/// Reads a contiguous prefix slice out of a padded buffer into `out`.
///
/// Launches `n_batch * len * d` work items; from the argument names, `src`
/// presumably holds `n_batch` sequences padded to `max_len` positions of
/// feature width `d`, and the first `len` positions of each are copied out
/// densely — TODO confirm against the `slice_read_kernel` PTX.
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_into(
    src: &CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    len: usize,
    max_len: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_batch * len * d;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SLICE_READ_PTX,
        "slice_read_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_read_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let total_u32 = total as u32;
    let d_u32 = d as u32;
    let len_u32 = len as u32;
    let max_len_u32 = max_len as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(out.inner_mut())
            .arg(&total_u32)
            .arg(&d_u32)
            .arg(&len_u32)
            .arg(&max_len_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15918
/// Naive matrix multiply for small operands: writes the `m x n` product of
/// the row-major `m x k` matrix `a` and `k x n` matrix `b` into `out`.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] when the matmul kernel cannot be
/// compiled, or any error raised by the launch itself.
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul_into(
    a: &CudaBuffer<f32>,
    b: &CudaBuffer<f32>,
    m: usize,
    k: usize,
    n: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;

    let ctx = device.context();
    let stream = device.stream();

    let func = crate::module_cache::get_or_compile(
        ctx,
        SMALL_MATMUL_PTX,
        "small_matmul_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "small_matmul_kernel",
    })?;

    // The launch is sized by the m*n output elements.
    let out_elems = m * n;
    let cfg = launch_cfg(out_elems)?;
    let (m_u32, k_u32, n_u32) = (m as u32, k as u32, n as u32);
    let total_u32 = out_elems as u32;

    unsafe {
        stream
            .launch_builder(&func)
            .arg(a.inner())
            .arg(b.inner())
            .arg(out.inner_mut())
            .arg(&m_u32)
            .arg(&k_u32)
            .arg(&n_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
15959
/// Writes `n_batch` rows of width `d` from `src` into a padded destination
/// (`max_len` positions per batch), with the write position read on-device
/// from `pos_ptr` — used for KV-cache style appends where the position lives
/// in device memory (decode loop, no host round-trip).
///
/// NOTE(review): `n_u32` is assigned `total = n_batch * d` (the total element
/// count, not the batch count) and is passed as the third kernel argument —
/// presumably the kernel's element count; confirm the naming against the PTX.
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_indirect(
    src: &CudaBuffer<f32>,
    dst: &mut CudaBuffer<f32>,
    n_batch: usize,
    d: usize,
    max_len: usize,
    pos_ptr: &cudarc::driver::CudaSlice<u32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_batch * d;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        SLICE_WRITE_INDIRECT_PTX,
        "slice_write_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "slice_write_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let n_u32 = total as u32;
    let d_u32 = d as u32;
    let max_len_u32 = max_len as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(src.inner())
            .arg(dst.inner_mut())
            .arg(&n_u32)
            .arg(&d_u32)
            .arg(&max_len_u32)
            .arg(pos_ptr)
            .launch(cfg)?;
    }
    Ok(())
}
16006
/// Fills an attention-mask buffer of `n_head * max_pos` elements, with the
/// effective sequence length read on-device from `total_len_ptr` (so the
/// decode loop needs no host round-trip for the current position).
/// The exact mask values written are defined by `causal_mask_indirect_kernel`.
#[cfg(feature = "cuda")]
pub fn gpu_causal_mask_indirect(
    total_len_ptr: &cudarc::driver::CudaSlice<u32>,
    n_head: usize,
    max_pos: usize,
    out: &mut CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<()> {
    use cudarc::driver::PushKernelArg;
    let total = n_head * max_pos;
    let ctx = device.context();
    let stream = device.stream();
    let f = crate::module_cache::get_or_compile(
        ctx,
        CAUSAL_MASK_INDIRECT_PTX,
        "causal_mask_indirect_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "causal_mask_indirect_kernel",
    })?;
    let cfg = launch_cfg(total)?;
    let max_pos_u32 = max_pos as u32;
    let total_u32 = total as u32;
    unsafe {
        stream
            .launch_builder(&f)
            .arg(total_len_ptr)
            .arg(out.inner_mut())
            .arg(&max_pos_u32)
            .arg(&total_u32)
            .launch(cfg)?;
    }
    Ok(())
}
16045
/// Eagerly compiles (and caches) every kernel used on the token-decode hot
/// path so the first decode step pays no JIT cost.
///
/// # Errors
/// Returns [`GpuError::Driver`] wrapping the first compilation failure;
/// kernels listed after the failing one are not attempted.
#[cfg(feature = "cuda")]
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
    let ctx = device.context();
    ctx.bind_to_thread()?;
    let ord = device.ordinal() as u32;

    // (PTX source, kernel entry point) pairs, in the same order as before.
    let kernels: [(&str, &str); 16] = [
        (ADD_PTX, "add_kernel"),
        (MUL_PTX, "mul_kernel"),
        (SCALE_PTX, "scale_kernel"),
        (GELU_PTX, "gelu_kernel"),
        (SOFTMAX_PTX, "softmax_kernel"),
        (LAYERNORM_PTX, "layernorm_kernel"),
        (PERMUTE_0213_PTX, "permute_0213_kernel"),
        (EMBED_LOOKUP_PTX, "embed_lookup_kernel"),
        (EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel"),
        (SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel"),
        (SMALL_MATMUL_PTX, "small_matmul_kernel"),
        (SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel"),
        (CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel"),
        (SLICE_READ_PTX, "slice_read_kernel"),
        (RELU_BACKWARD_PTX, "relu_backward_kernel"),
        (GELU_BACKWARD_PTX, "gelu_backward_kernel"),
    ];

    for (ptx, name) in kernels {
        if let Err(e) = crate::module_cache::get_or_compile(ctx, ptx, name, ord) {
            return Err(GpuError::Driver(e));
        }
    }
    Ok(())
}
16081
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
16087
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16097
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16106
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf(
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16115
/// Stub for non-CUDA builds: always returns [`GpuError::NoCudaFeature`].
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16125
16126#[cfg(not(feature = "cuda"))]
16128pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16129 Err(GpuError::NoCudaFeature)
16130}
16131
16132#[cfg(not(feature = "cuda"))]
16134pub fn gpu_silu_backward(
16135 _grad: &CudaBuffer<f32>,
16136 _input: &CudaBuffer<f32>,
16137 _device: &GpuDevice,
16138) -> GpuResult<CudaBuffer<f32>> {
16139 Err(GpuError::NoCudaFeature)
16140}
16141
16142#[cfg(not(feature = "cuda"))]
16144pub fn gpu_elu(
16145 _input: &CudaBuffer<f32>,
16146 _alpha: f32,
16147 _device: &GpuDevice,
16148) -> GpuResult<CudaBuffer<f32>> {
16149 Err(GpuError::NoCudaFeature)
16150}
16151
16152#[cfg(not(feature = "cuda"))]
16154pub fn gpu_elu_backward(
16155 _grad: &CudaBuffer<f32>,
16156 _input: &CudaBuffer<f32>,
16157 _alpha: f32,
16158 _device: &GpuDevice,
16159) -> GpuResult<CudaBuffer<f32>> {
16160 Err(GpuError::NoCudaFeature)
16161}
16162
16163#[cfg(not(feature = "cuda"))]
16165pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16166 Err(GpuError::NoCudaFeature)
16167}
16168
16169#[cfg(not(feature = "cuda"))]
16171pub fn gpu_mish_backward(
16172 _grad: &CudaBuffer<f32>,
16173 _input: &CudaBuffer<f32>,
16174 _device: &GpuDevice,
16175) -> GpuResult<CudaBuffer<f32>> {
16176 Err(GpuError::NoCudaFeature)
16177}
16178
16179#[cfg(not(feature = "cuda"))]
16181pub fn gpu_clamp(
16182 _input: &CudaBuffer<f32>,
16183 _min_val: f32,
16184 _max_val: f32,
16185 _device: &GpuDevice,
16186) -> GpuResult<CudaBuffer<f32>> {
16187 Err(GpuError::NoCudaFeature)
16188}
16189
16190#[cfg(not(feature = "cuda"))]
16192pub fn gpu_div(
16193 _a: &CudaBuffer<f32>,
16194 _b: &CudaBuffer<f32>,
16195 _device: &GpuDevice,
16196) -> GpuResult<CudaBuffer<f32>> {
16197 Err(GpuError::NoCudaFeature)
16198}
16199
16200#[cfg(not(feature = "cuda"))]
16202pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16203 Err(GpuError::NoCudaFeature)
16204}
16205
16206#[cfg(not(feature = "cuda"))]
16208pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16209 Err(GpuError::NoCudaFeature)
16210}
16211
16212#[cfg(not(feature = "cuda"))]
16214pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16215 Err(GpuError::NoCudaFeature)
16216}
16217
16218#[cfg(not(feature = "cuda"))]
16220pub fn gpu_pow(
16221 _a: &CudaBuffer<f32>,
16222 _exponent: f32,
16223 _device: &GpuDevice,
16224) -> GpuResult<CudaBuffer<f32>> {
16225 Err(GpuError::NoCudaFeature)
16226}
16227
16228#[cfg(not(feature = "cuda"))]
16230pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16231 Err(GpuError::NoCudaFeature)
16232}
16233
16234#[cfg(not(feature = "cuda"))]
16236pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16237 Err(GpuError::NoCudaFeature)
16238}
16239
16240#[cfg(not(feature = "cuda"))]
16242pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
16243 Err(GpuError::NoCudaFeature)
16244}
16245
16246#[cfg(not(feature = "cuda"))]
16248pub fn gpu_layernorm(
16249 _input: &CudaBuffer<f32>,
16250 _weight: &CudaBuffer<f32>,
16251 _bias: &CudaBuffer<f32>,
16252 _rows: usize,
16253 _cols: usize,
16254 _eps: f32,
16255 _device: &GpuDevice,
16256) -> GpuResult<CudaBuffer<f32>> {
16257 Err(GpuError::NoCudaFeature)
16258}
16259
16260#[cfg(not(feature = "cuda"))]
16262pub fn gpu_transpose_2d(
16263 _input: &CudaBuffer<f32>,
16264 _m: usize,
16265 _n: usize,
16266 _device: &GpuDevice,
16267) -> GpuResult<CudaBuffer<f32>> {
16268 Err(GpuError::NoCudaFeature)
16269}
16270
16271#[cfg(not(feature = "cuda"))]
16273pub fn gpu_add(
16274 _a: &CudaBuffer<f32>,
16275 _b: &CudaBuffer<f32>,
16276 _device: &GpuDevice,
16277) -> GpuResult<CudaBuffer<f32>> {
16278 Err(GpuError::NoCudaFeature)
16279}
16280
16281#[cfg(not(feature = "cuda"))]
16283pub fn gpu_sub(
16284 _a: &CudaBuffer<f32>,
16285 _b: &CudaBuffer<f32>,
16286 _device: &GpuDevice,
16287) -> GpuResult<CudaBuffer<f32>> {
16288 Err(GpuError::NoCudaFeature)
16289}
16290
/// Stub: element-wise multiplication.
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: element-wise negation.
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: element-wise ReLU.
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: scalar multiplication of every element.
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale(
    _a: &CudaBuffer<f32>,
    _scalar: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: broadcasting addition over the given shapes.
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: broadcasting subtraction over the given shapes.
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: broadcasting multiplication over the given shapes.
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul(
    _a: &CudaBuffer<f32>,
    _b: &CudaBuffer<f32>,
    _a_shape: &[usize],
    _b_shape: &[usize],
    _out_shape: &[usize],
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: row-wise softmax over a `_rows x _cols` matrix.
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax(
    _input: &CudaBuffer<f32>,
    _rows: usize,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: dropout driven by `_threshold`/`_seed` with rescale `_scale`.
#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout(
    _input: &CudaBuffer<f32>,
    _threshold: u32,
    _scale: f32,
    _seed: u32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: 4-D permutation with axis order (0, 2, 1, 3).
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213(
    _input: &CudaBuffer<f32>,
    _d0: usize,
    _d1: usize,
    _d2: usize,
    _d3: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: write a slice into `_dst` at position `_pos` (KV-cache style write).
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write(
    _src: &CudaBuffer<f32>,
    _dst: &mut CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _max_len: usize,
    _pos: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: read the first `_len` positions out of a `_max_len`-strided buffer.
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read(
    _src: &CudaBuffer<f32>,
    _n_batch: usize,
    _d: usize,
    _len: usize,
    _max_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: single-index embedding lookup.
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup(
    _idx: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: batched embedding lookup for `_n` indices.
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch(
    _indices: &CudaBuffer<f32>,
    _weight: &CudaBuffer<f32>,
    _n: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: scatter-add of gradient rows into an embedding table.
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _num_embeddings: usize,
    _d: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: backward pass of ReLU.
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: backward pass of GELU (default approximation).
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward(
    _grad: &CudaBuffer<f32>,
    _input: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: 1-D gather by index.
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d(
    _input: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: 1-D scatter-add (backward of index_select).
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d(
    _grad_output: &CudaBuffer<f32>,
    _indices: &CudaBuffer<f32>,
    _input_len: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}
16500
/// Stub: replace elements selected by `_mask` with `_value`.
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill(
    _input: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _value: f32,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: zero out gradient elements selected by `_mask`.
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero(
    _grad: &CudaBuffer<f32>,
    _mask: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: backward pass of sigmoid (takes the forward `_output`).
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: backward pass of tanh (takes the forward `_output`).
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: backward pass of row-wise softmax.
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: row-wise log-softmax.
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax(
    _input: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: backward pass of log-softmax.
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward(
    _grad: &CudaBuffer<f32>,
    _output: &CudaBuffer<f32>,
    _cols: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: sum-reduction along the middle axis of an (outer, axis, inner) view.
#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis(
    _a: &CudaBuffer<f32>,
    _outer: usize,
    _axis_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: cumulative sum along the middle axis.
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: cumulative product along the middle axis.
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: cumulative max; returns (values, indices) pair.
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: cumulative min; returns (values, indices) pair.
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: numerically-stable cumulative log-sum-exp along the middle axis.
#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp(
    _input: &CudaBuffer<f32>,
    _outer: usize,
    _dim_size: usize,
    _inner: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: extract a split of size `_split_size` at `_split_offset` along a strided axis.
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split(
    _input: &CudaBuffer<f32>,
    _total_along_axis: usize,
    _split_offset: usize,
    _split_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: write one concatenation part into `_output` at `_cat_offset` along a strided axis.
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat(
    _input: &CudaBuffer<f32>,
    _output: &mut CudaBuffer<f32>,
    _total_along_axis: usize,
    _cat_offset: usize,
    _part_size: usize,
    _inner_size: usize,
    _n: usize,
    _device: &GpuDevice,
) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

/// Maximum tensor rank supported by the strided-copy kernels (mirrors the
/// CUDA-side constant so both builds agree on the limit).
#[cfg(not(feature = "cuda"))]
pub const STRIDED_COPY_MAX_DIMS: usize = 8;

/// Stub: materialize a strided (possibly negative-stride) view as contiguous f32.
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_copy(
    _input: &CudaBuffer<f32>,
    _out_shape: &[usize],
    _src_strides: &[isize],
    _src_offset: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
    Err(GpuError::NoCudaFeature)
}

/// Stub: materialize a strided view as contiguous f64.
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_copy_f64(
    _input: &CudaBuffer<f64>,
    _out_shape: &[usize],
    _src_strides: &[isize],
    _src_offset: usize,
    _device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    Err(GpuError::NoCudaFeature)
}
16703
/// Converts an f32 device buffer to raw f16 bits (`u16`) on the GPU.
///
/// Compiles/fetches the `f32_to_f16_kernel` PTX via the module cache and
/// launches it over `input.len()` elements. Returns the converted values as
/// a `CudaSlice<u16>` (each element is the IEEE-754 half bit pattern —
/// presumably; the conversion itself lives in the PTX, not visible here).
///
/// # Errors
/// - `GpuError::PtxCompileFailed` when the kernel cannot be compiled/loaded.
/// - Driver errors from allocation or launch are propagated via `?`.
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_f16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let n = input.len();
    // Empty input: skip the kernel entirely and hand back a 0-length slice.
    if n == 0 {
        let empty = device.stream().alloc_zeros::<u16>(0)?;
        return Ok(empty);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = crate::module_cache::get_or_compile(
        ctx,
        F32_TO_F16_PTX,
        "f32_to_f16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_f16_kernel",
    })?;

    let mut out = stream.alloc_zeros::<u16>(n)?;
    let cfg = launch_cfg(n)?;
    let n_u32 = n as u32;

    // SAFETY: `out` is allocated with exactly `n` elements and the kernel is
    // bounded by the `n_u32` argument; `input` outlives the launch on this
    // stream. Argument order must match the kernel signature (src, dst, n).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
16763
/// Stub: f32 -> f16 conversion requires CUDA.
/// NOTE(review): returns `GpuResult<()>` while the CUDA build returns
/// `GpuResult<CudaSlice<u16>>` — unavoidable since `CudaSlice` does not
/// exist without `cudarc`, but confirm no shared caller relies on the value.
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}
16769
/// Converts an f32 device buffer to raw bf16 bits (`u16`) on the GPU.
///
/// Loads `f32_to_bf16_kernel` through the shared module cache and runs it
/// over every element; the bit-level conversion is performed by the PTX.
///
/// # Errors
/// - `GpuError::PtxCompileFailed` when the kernel cannot be compiled/loaded.
/// - Allocation/launch driver errors are propagated via `?`.
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_bf16(
    input: &CudaBuffer<f32>,
    device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
    use cudarc::driver::PushKernelArg;

    let len = input.len();
    // Nothing to convert: return an empty device allocation without launching.
    if len == 0 {
        return Ok(device.stream().alloc_zeros::<u16>(0)?);
    }

    let stream = device.stream();
    let kernel = crate::module_cache::get_or_compile(
        device.context(),
        F32_TO_BF16_PTX,
        "f32_to_bf16_kernel",
        device.ordinal() as u32,
    )
    .map_err(|_| GpuError::PtxCompileFailed {
        kernel: "f32_to_bf16_kernel",
    })?;

    let cfg = launch_cfg(len)?;
    let mut converted = stream.alloc_zeros::<u16>(len)?;
    let len_u32 = len as u32;

    // SAFETY: `converted` holds exactly `len` elements, the kernel is bounded
    // by `len_u32`, and `input` outlives the launch on this stream. Argument
    // order matches the kernel signature (src, dst, n).
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(&mut converted)
            .arg(&len_u32)
            .launch(cfg)?;
    }

    Ok(converted)
}
16815
/// Stub: f32 -> bf16 conversion requires CUDA.
/// NOTE(review): return type differs from the CUDA build (`()` vs
/// `CudaSlice<u16>`) because `CudaSlice` is unavailable without `cudarc`.
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
    Err(GpuError::NoCudaFeature)
}

// ---------------------------------------------------------------------------
// f64 stubs for non-CUDA builds. Exact counterparts of the f32 stubs above:
// every entry point fails with `GpuError::NoCudaFeature`. Index buffers stay
// `CudaBuffer<f32>` and masks `CudaBuffer<u8>` to match the CUDA signatures.
// ---------------------------------------------------------------------------
#[cfg(not(feature = "cuda"))]
pub fn gpu_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale_f64(_a: &CudaBuffer<f64>, _scalar: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_exp_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_pow_f64(_a: &CudaBuffer<f64>, _exponent: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_div_f64(_a: &CudaBuffer<f64>, _b: &CudaBuffer<f64>, _a_shape: &[usize], _b_shape: &[usize], _out_shape: &[usize], _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d_f64(_input: &CudaBuffer<f64>, _m: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213_f64(_input: &CudaBuffer<f64>, _d0: usize, _d1: usize, _d2: usize, _d3: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split_f64(_input: &CudaBuffer<f64>, _total_along_axis: usize, _split_offset: usize, _split_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat_f64(_input: &CudaBuffer<f64>, _output: &mut CudaBuffer<f64>, _total_along_axis: usize, _cat_offset: usize, _part_size: usize, _inner_size: usize, _n: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d_f64(_input: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _input_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill_f64(_input: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _value: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero_f64(_grad: &CudaBuffer<f64>, _mask: &CudaBuffer<u8>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write_f64(_src: &CudaBuffer<f64>, _dst: &mut CudaBuffer<f64>, _n_batch: usize, _d: usize, _max_len: usize, _pos: usize, _device: &GpuDevice) -> GpuResult<()> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read_f64(_src: &CudaBuffer<f64>, _n_batch: usize, _d: usize, _len: usize, _max_len: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_f64(_idx: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch_f64(_indices: &CudaBuffer<f32>, _weight: &CudaBuffer<f64>, _n: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows_f64(_grad_output: &CudaBuffer<f64>, _indices: &CudaBuffer<f32>, _num_embeddings: usize, _d: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
16894
16895
/// GELU (sigmoid approximation, `x * sigmoid(1.702 * x)`) over an f64 buffer.
/// Uses the compiled f64 kernel when available; otherwise round-trips through
/// the host and evaluates element-wise on the CPU.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_F64_PTX, "gelu_f64_kernel")? {
        Some(result) => Ok(result),
        // Host fallback: identical formula to the kernel's sigmoid-GELU.
        None => cpu_fallback_unary_f64(input, device, |v| v * (1.0 / (1.0 + (-1.702 * v).exp()))),
    }
}
16908
/// GELU (tanh approximation) over an f64 buffer, with a CPU fallback when
/// the f64 kernel cannot be compiled for this device.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, GELU_TANH_F64_PTX, "gelu_tanh_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(input, device, |v| {
            // 0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 v^3)))
            let u = (2.0_f64 / std::f64::consts::PI).sqrt() * (v + 0.044715 * v * v * v);
            0.5 * v * (1.0 + u.tanh())
        }),
    }
}
16920
/// GELU (erf form, `x * 0.5 * (1 + erf(x / sqrt(2)))`) over an f64 buffer,
/// falling back to a host-side loop when the kernel is unavailable.
///
/// The fallback approximates erf with the Abramowitz–Stegun 7.1.26 rational
/// polynomial (max abs error ~1.5e-7) — NOTE(review): that is f32-level
/// accuracy applied to f64 data; confirm this matches the kernel's precision.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    if let Some(out) = try_launch_unary_f64(input, device, GELU_ERF_F64_PTX, "gelu_erf_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_unary_f64(input, device, |x| {
        // z = x / sqrt(2); erf evaluated on |z|, sign restored afterwards.
        let z = x * std::f64::consts::FRAC_1_SQRT_2;
        let az = z.abs();
        let t = 1.0 / (1.0 + 0.3275911 * az);
        let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
        let erf_abs = 1.0 - poly * (-az * az).exp();
        let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
        x * 0.5 * (1.0 + erf_val)
    })
}
16938
/// Backward pass of the sigmoid-approximation GELU for f64:
/// `dy * (s + 1.702 * x * s * (1 - s))` with `s = sigmoid(1.702 * x)`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (gn, xn) = (grad.len(), input.len());
    if gn != xn {
        return Err(GpuError::LengthMismatch { a: gn, b: xn });
    }
    match try_launch_binary_f64(grad, input, device, GELU_BACKWARD_F64_PTX, "gelu_backward_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(grad, input, device, |dy, x| {
            let s = 1.0 / (1.0 + (-1.702 * x).exp());
            dy * (s + 1.702 * x * s * (1.0 - s))
        }),
    }
}
16957
/// Backward pass of the tanh-approximation GELU for f64, with CPU fallback.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, GELU_BACKWARD_TANH_F64_PTX, "gelu_backward_tanh_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        // d/dx [0.5 x (1 + tanh(u))] with u = s2pi * (x + c x^3):
        //   0.5 (1 + tanh u) + 0.5 x sech^2(u) * u'
        let s2pi = (2.0_f64 / std::f64::consts::PI).sqrt();
        let c = 0.044715_f64;
        let u = s2pi * (x + c * x * x * x);
        let t = u.tanh();
        let d = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * s2pi * (1.0 + 3.0 * c * x * x);
        g * d
    })
}
16980
/// Backward pass of the erf-form GELU for f64: `dy * (Phi(x) + x * phi(x))`
/// where `Phi` is the standard-normal CDF and `phi` its density.
///
/// The CPU fallback reuses the Abramowitz–Stegun 7.1.26 erf approximation
/// (f32-level accuracy — NOTE(review): confirm acceptable for f64 grads).
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, GELU_BACKWARD_ERF_F64_PTX, "gelu_backward_erf_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        // erf(x / sqrt(2)) via the A&S rational polynomial on |z|.
        let z = x * std::f64::consts::FRAC_1_SQRT_2;
        let az = z.abs();
        let t = 1.0 / (1.0 + 0.3275911 * az);
        let poly = t * (0.254829592 + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
        let erf_abs = 1.0 - poly * (-az * az).exp();
        let erf_val = if z >= 0.0 { erf_abs } else { -erf_abs };
        // cdf = Phi(x), pdf = phi(x); derivative of x*Phi(x) is Phi + x*phi.
        let cdf = 0.5 * (1.0 + erf_val);
        let pdf = (-x * x / 2.0).exp() / (2.0 * std::f64::consts::PI).sqrt();
        g * (cdf + x * pdf)
    })
}
17006
/// SiLU (`x * sigmoid(x)`, written as `x / (1 + e^-x)`) over an f64 buffer,
/// with a host fallback when the f64 kernel is unavailable.
#[cfg(feature = "cuda")]
pub fn gpu_silu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, SILU_F64_PTX, "silu_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_unary_f64(input, device, |v| v / (1.0 + (-v).exp())),
    }
}
17015
/// Backward pass of SiLU for f64: `dy * (s + x * s * (1 - s))` with
/// `s = sigmoid(x)`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    let (gn, xn) = (grad.len(), input.len());
    if gn != xn {
        return Err(GpuError::LengthMismatch { a: gn, b: xn });
    }
    match try_launch_binary_f64(grad, input, device, SILU_BACKWARD_F64_PTX, "silu_backward_f64_kernel")? {
        Some(result) => Ok(result),
        None => cpu_fallback_binary_f64(grad, input, device, |dy, x| {
            let s = 1.0 / (1.0 + (-x).exp());
            dy * (s + x * s * (1.0 - s))
        }),
    }
}
17034
/// ELU over an f64 buffer: `x` for `x > 0`, `alpha * (e^x - 1)` otherwise.
/// Launches the f64 kernel when it compiles; on compile failure, silently
/// falls back to a host round-trip (best-effort, matches sibling ops).
#[cfg(feature = "cuda")]
pub fn gpu_elu_f64(
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    let n = input.len();
    // Empty input: return an empty device buffer without touching the kernel.
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ELU_F64_PTX, "elu_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: `out` has exactly `n` elements; the launch is bounded by
        // `n_u32`; argument order matches the kernel (src, dst, n, alpha).
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&alpha)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // CPU fallback: same formula as the kernel variant.
    let host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = host.iter().map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }).collect();
    cpu_to_gpu(&result, device)
}
17065
/// Backward pass of ELU for f64: passes `grad` through where `input > 0`,
/// otherwise scales by `alpha * e^x`. CPU fallback on kernel compile failure.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    alpha: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    let n = grad.len();
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ELU_BACKWARD_F64_PTX, "elu_backward_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: `out` has `n` elements and the launch is bounded by `n_u32`;
        // argument order matches the kernel (grad, input, dst, n, alpha).
        unsafe {
            stream.launch_builder(&f)
                .arg(grad.inner())
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&alpha)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // CPU fallback: dELU/dx = 1 for x > 0, alpha * e^x otherwise.
    let g_host = gpu_to_cpu(grad, device)?;
    let x_host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = g_host.iter().zip(x_host.iter()).map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() }).collect();
    cpu_to_gpu(&result, device)
}
17102
/// Mish (`x * tanh(softplus(x))`) over an f64 buffer, with a host fallback
/// when the f64 kernel is unavailable.
#[cfg(feature = "cuda")]
pub fn gpu_mish_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
    match try_launch_unary_f64(input, device, MISH_F64_PTX, "mish_f64_kernel")? {
        Some(result) => Ok(result),
        // softplus(v) = ln(1 + e^v)
        None => cpu_fallback_unary_f64(input, device, |v| v * (1.0_f64 + v.exp()).ln().tanh()),
    }
}
17111
/// Backward pass of Mish for f64, with CPU fallback:
/// `dy * (tanh(sp) + x * sigmoid(x) * (1 - tanh^2(sp)))`, `sp = softplus(x)`.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `input` differ in length.
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward_f64(
    grad: &CudaBuffer<f64>,
    input: &CudaBuffer<f64>,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    if grad.len() != input.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: input.len() });
    }
    if let Some(out) = try_launch_binary_f64(grad, input, device, MISH_BACKWARD_F64_PTX, "mish_backward_f64_kernel")? {
        return Ok(out);
    }
    cpu_fallback_binary_f64(grad, input, device, |g, x| {
        let sp = (1.0_f64 + x.exp()).ln();
        let t = sp.tanh();
        // d softplus/dx = sigmoid(x); chain rule through tanh gives (1 - t^2).
        let sig = 1.0 / (1.0 + (-x).exp());
        g * (t + x * sig * (1.0 - t * t))
    })
}
17132
/// Clamps every element of an f64 buffer to `[min_val, max_val]`.
///
/// The f64 PTX is derived from the f32 `CLAMP_PTX` via the f32->f64 text
/// transmutation (cached once per process in `CACHE`); on compile failure the
/// operation silently falls back to a host round-trip.
#[cfg(feature = "cuda")]
pub fn gpu_clamp_f64(
    input: &CudaBuffer<f64>,
    min_val: f64,
    max_val: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // Per-function cache for the transmuted f64 PTX string.
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let n = input.len();
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CLAMP_PTX, "clamp_kernel", "clamp_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "clamp_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let n_u32 = n as u32;
        let cfg = launch_cfg(n)?;
        // SAFETY: `out` has `n` elements and the launch is bounded by `n_u32`;
        // argument order matches the kernel (src, dst, n, min, max).
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&n_u32)
                .arg(&min_val)
                .arg(&max_val)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    // CPU fallback: max-then-min matches clamp semantics without panicking
    // on inverted bounds (unlike `f64::clamp`).
    let host = gpu_to_cpu(input, device)?;
    let result: Vec<f64> = host.iter().map(|&x| x.max(min_val).min(max_val)).collect();
    cpu_to_gpu(&result, device)
}
17167
#[cfg(feature = "cuda")]
/// Inclusive cumulative sum along the middle axis of a tensor viewed as
/// `[outer, dim_size, inner]` (row-major); the launch is sized to one unit
/// of work per `(outer, inner)` lane.
///
/// The f64 PTX is obtained from `get_f64_ptx` (cached in a `OnceLock`).
/// There is no CPU fallback: if compilation fails the call returns
/// `GpuError::PtxCompileFailed`.
pub fn gpu_cumsum_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    // `total` = number of independent scan lanes; `n` = element count.
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    let ptx = get_f64_ptx(&CACHE, CUMSUM_PTX, "cumsum_kernel", "cumsum_f64_kernel");
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, ptx, "cumsum_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: kernel launch; arguments must match the cumsum_f64_kernel
        // parameter order (in, out, outer, dim, inner, total).
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    Err(GpuError::PtxCompileFailed { kernel: "cumsum_f64_kernel" })
}
17203
#[cfg(feature = "cuda")]
/// Inclusive cumulative product along the middle axis of a tensor viewed as
/// `[outer, dim_size, inner]`. No CPU fallback: a PTX compilation failure is
/// surfaced as `GpuError::PtxCompileFailed`.
pub fn gpu_cumprod_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let lanes = outer * inner;
    let elems = outer * dim_size * inner;
    if elems == 0 {
        return cpu_to_gpu(&[], device);
    }
    let ptx = get_f64_ptx(&CACHE, CUMPROD_PTX, "cumprod_kernel", "cumprod_f64_kernel");
    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        ptx,
        "cumprod_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "cumprod_f64_kernel" }),
    };
    let mut out = alloc_zeros_f64(elems, device)?;
    let cfg = launch_cfg(lanes)?;
    let outer_u32 = outer as u32;
    let dim_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let lanes_u32 = lanes as u32;
    // SAFETY: kernel launch; arguments follow the cumprod_f64_kernel
    // parameter order (in, out, outer, dim, inner, total).
    unsafe {
        device.stream()
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_u32)
            .arg(&inner_u32)
            .arg(&lanes_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
17239
#[cfg(feature = "cuda")]
/// Cumulative maximum along the middle axis of a tensor viewed as
/// `[outer, dim_size, inner]`.
///
/// Returns `(values, indices)`; indices are stored in an f64 buffer to match
/// the value buffer type. No CPU fallback.
pub fn gpu_cummax_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let lanes = outer * inner;
    let elems = outer * dim_size * inner;
    if elems == 0 {
        let empty: &[f64] = &[];
        return Ok((cpu_to_gpu(empty, device)?, cpu_to_gpu(empty, device)?));
    }
    let ptx = get_f64_ptx(&CACHE, CUMMAX_PTX, "cummax_kernel", "cummax_f64_kernel");
    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        ptx,
        "cummax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "cummax_f64_kernel" }),
    };
    let mut values = alloc_zeros_f64(elems, device)?;
    let mut indices = alloc_zeros_f64(elems, device)?;
    let cfg = launch_cfg(lanes)?;
    let outer_u32 = outer as u32;
    let dim_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let lanes_u32 = lanes as u32;
    // SAFETY: kernel launch; arguments follow the cummax_f64_kernel
    // parameter order (in, values, indices, outer, dim, inner, total).
    unsafe {
        device.stream()
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(values.inner_mut())
            .arg(indices.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_u32)
            .arg(&inner_u32)
            .arg(&lanes_u32)
            .launch(cfg)?;
    }
    Ok((values, indices))
}
17279
#[cfg(feature = "cuda")]
/// Cumulative minimum along the middle axis of a tensor viewed as
/// `[outer, dim_size, inner]`.
///
/// Returns `(values, indices)`; indices are stored in an f64 buffer to match
/// the value buffer type. No CPU fallback.
pub fn gpu_cummin_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let lanes = outer * inner;
    let elems = outer * dim_size * inner;
    if elems == 0 {
        let empty: &[f64] = &[];
        return Ok((cpu_to_gpu(empty, device)?, cpu_to_gpu(empty, device)?));
    }
    let ptx = get_f64_ptx(&CACHE, CUMMIN_PTX, "cummin_kernel", "cummin_f64_kernel");
    let kernel = match crate::module_cache::get_or_compile(
        device.context(),
        ptx,
        "cummin_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(k) => k,
        Err(_) => return Err(GpuError::PtxCompileFailed { kernel: "cummin_f64_kernel" }),
    };
    let mut values = alloc_zeros_f64(elems, device)?;
    let mut indices = alloc_zeros_f64(elems, device)?;
    let cfg = launch_cfg(lanes)?;
    let outer_u32 = outer as u32;
    let dim_u32 = dim_size as u32;
    let inner_u32 = inner as u32;
    let lanes_u32 = lanes as u32;
    // SAFETY: kernel launch; arguments follow the cummin_f64_kernel
    // parameter order (in, values, indices, outer, dim, inner, total).
    unsafe {
        device.stream()
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(values.inner_mut())
            .arg(indices.inner_mut())
            .arg(&outer_u32)
            .arg(&dim_u32)
            .arg(&inner_u32)
            .arg(&lanes_u32)
            .launch(cfg)?;
    }
    Ok((values, indices))
}
17319
#[cfg(feature = "cuda")]
/// Cumulative log-sum-exp along the middle axis of a tensor viewed as
/// `[outer, dim_size, inner]`, using a dedicated f64 PTX module
/// (`LOGCUMSUMEXP_F64_PTX`) rather than one translated from f32.
///
/// The launch is sized to one unit of work per `(outer, inner)` lane.
/// No CPU fallback: a PTX compilation failure is returned as
/// `GpuError::PtxCompileFailed`.
pub fn gpu_logcumsumexp_f64(
    input: &CudaBuffer<f64>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    // `total` = number of independent scan lanes; `n` = element count.
    let total = outer * inner;
    let n = outer * dim_size * inner;
    if n == 0 { return cpu_to_gpu(&[], device); }
    let ctx = device.context();
    let stream = device.stream();
    if let Ok(f) = crate::module_cache::get_or_compile(ctx, LOGCUMSUMEXP_F64_PTX, "logcumsumexp_f64_kernel", device.ordinal() as u32) {
        let mut out = alloc_zeros_f64(n, device)?;
        let cfg = launch_cfg(total)?;
        let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
        // SAFETY: kernel launch; arguments must match the
        // logcumsumexp_f64_kernel parameter order (in, out, outer, dim,
        // inner, total).
        unsafe {
            stream.launch_builder(&f)
                .arg(input.inner())
                .arg(out.inner_mut())
                .arg(&o)
                .arg(&d)
                .arg(&i)
                .arg(&t)
                .launch(cfg)?;
        }
        return Ok(out);
    }
    Err(GpuError::PtxCompileFailed { kernel: "logcumsumexp_f64_kernel" })
}
17353
#[cfg(feature = "cuda")]
/// Row-wise numerically stable softmax over a dense `rows x cols` matrix
/// (flattened row-major).
///
/// Tries the dedicated f64 PTX kernel first (one block per row, 256 threads,
/// 256 * 8 bytes of shared memory); if compilation fails, falls back to a
/// CPU implementation using the max-subtraction trick.
///
/// # Errors
/// `GpuError::LengthMismatch` when `rows * cols != input.len()` — both the
/// kernel and the CPU fallback index `input` as `rows x cols`, so a
/// mismatched length would otherwise read out of bounds or panic.
pub fn gpu_softmax_f64(
    input: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    // Validate the shape against the buffer length up front (bug fix: the
    // previous version would panic in the CPU fallback on a mismatch).
    let n = rows * cols;
    if input.len() != n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if n == 0 {
        // Same empty fast path as the other element-wise entry points.
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        SOFTMAX_F64_PTX,
        "softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: subtract the row max before exponentiating so
            // large inputs do not overflow, then normalise by the row sum.
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; host.len()];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum = 0.0f64;
                for c in 0..cols {
                    let e = (host[base + c] - max_v).exp();
                    out[base + c] = e;
                    sum += e;
                }
                let inv = 1.0 / sum;
                for c in 0..cols {
                    out[base + c] *= inv;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17429
#[cfg(feature = "cuda")]
/// Backward pass of row-wise softmax: per row, `dx = y * (dy - <dy, y>)`
/// where `y` is the forward output. `rows` is derived as `grad.len() / cols`.
///
/// Falls back to a CPU implementation when the f64 PTX cannot be compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `output` differ in length, when
/// `cols == 0`, or when `cols` does not divide the buffer length.
pub fn gpu_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Bug fix: `total / cols` panicked for cols == 0 and silently dropped a
    // trailing partial row when cols did not divide total.
    if cols == 0 || total % cols != 0 {
        return Err(GpuError::LengthMismatch { a: total, b: cols });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
    let ptx = get_f64_ptx(&CACHE, SOFTMAX_BACKWARD_PTX, "softmax_backward_kernel", "softmax_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: per row compute dot = <dy, y>, then
            // dx = y * (dy - dot).
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += grad_host[base + c] * output_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] = output_host[base + c] * (grad_host[base + c] - dot);
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17503
#[cfg(feature = "cuda")]
/// Row-wise numerically stable log-softmax:
/// `out = x - (max + ln(sum(exp(x - max))))` per row. `rows` is derived as
/// `input.len() / cols`. Falls back to the CPU when the f64 PTX cannot be
/// compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `cols == 0` or `cols` does not divide
/// `input.len()`.
pub fn gpu_log_softmax_f64(
    input: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(input, device)?;

    let total = input.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Bug fix: `total / cols` panicked for cols == 0 and silently dropped a
    // trailing partial row when cols did not divide total.
    if cols == 0 || total % cols != 0 {
        return Err(GpuError::LengthMismatch { a: total, b: cols });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_F64_PTX,
        "log_softmax_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback with the max-subtraction trick:
            // log_softmax(x) = x - (max + ln(sum(exp(x - max)))).
            let host = gpu_to_cpu(input, device)?;
            let mut out = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut max_v = f64::NEG_INFINITY;
                for c in 0..cols {
                    max_v = max_v.max(host[base + c]);
                }
                let mut sum_exp = 0.0f64;
                for c in 0..cols {
                    sum_exp += (host[base + c] - max_v).exp();
                }
                let log_sum_exp = max_v + sum_exp.ln();
                for c in 0..cols {
                    out[base + c] = host[base + c] - log_sum_exp;
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17574
#[cfg(feature = "cuda")]
/// Backward pass of row-wise log-softmax: per row,
/// `dx = dy - exp(out) * sum(dy)` where `out` is the forward log-softmax
/// output. `rows` is derived as `grad.len() / cols`. Falls back to the CPU
/// when the f64 PTX cannot be compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `grad` and `output` differ in length, when
/// `cols == 0`, or when `cols` does not divide the buffer length.
pub fn gpu_log_softmax_backward_f64(
    grad: &CudaBuffer<f64>,
    output: &CudaBuffer<f64>,
    cols: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;

    validate_device(grad, device)?;
    if grad.len() != output.len() {
        return Err(GpuError::LengthMismatch { a: grad.len(), b: output.len() });
    }

    let total = grad.len();
    if total == 0 {
        return cpu_to_gpu(&[], device);
    }
    // Bug fix: `total / cols` panicked for cols == 0 and silently dropped a
    // trailing partial row when cols did not divide total.
    if cols == 0 || total % cols != 0 {
        return Err(GpuError::LengthMismatch { a: total, b: cols });
    }
    let rows = total / cols;

    let ctx = device.context();
    let stream = device.stream();

    let f = match crate::module_cache::get_or_compile(
        ctx,
        LOG_SOFTMAX_BACKWARD_F64_PTX,
        "log_softmax_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: dx = dy - exp(out) * sum(dy) per row.
            let grad_host = gpu_to_cpu(grad, device)?;
            let output_host = gpu_to_cpu(output, device)?;
            let mut result = vec![0.0f64; total];
            for r in 0..rows {
                let base = r * cols;
                let mut sum_grad = 0.0f64;
                for c in 0..cols {
                    sum_grad += grad_host[base + c];
                }
                for c in 0..cols {
                    result[base + c] =
                        grad_host[base + c] - output_host[base + c].exp() * sum_grad;
                }
            }
            return cpu_to_gpu(&result, device);
        }
    };

    let mut out = alloc_zeros_f64(total, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad.inner())
            .arg(output.inner())
            .arg(out.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .launch(cfg)?;
    }

    Ok(out)
}
17649
#[cfg(feature = "cuda")]
/// Row-wise layer normalisation with affine parameters:
/// `out = weight * (x - mean) / sqrt(var + eps) + bias` per row of a dense
/// `rows x cols` matrix. Falls back to the CPU when the f64 PTX cannot be
/// compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `input.len() != rows * cols` or when
/// `weight`/`bias` do not contain exactly `cols` elements.
pub fn gpu_layernorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    bias: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    // Bug fix: validate all buffer lengths against the declared shape; the
    // previous version panicked in the CPU fallback (slice out of range) on
    // any mismatch.
    let n = rows * cols;
    if input.len() != n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if weight.len() != cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }
    if bias.len() != cols {
        return Err(GpuError::LengthMismatch { a: bias.len(), b: cols });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_PTX, "layernorm_kernel", "layernorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: per-row mean/variance, then affine transform.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let h_b = gpu_to_cpu(bias, device)?;
            let mut out = vec![0.0f64; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let mean: f64 = slice.iter().sum::<f64>() / cols as f64;
                let var: f64 =
                    slice.iter().map(|&x| (x - mean) * (x - mean)).sum::<f64>() / cols as f64;
                let inv_std = 1.0 / (var + eps).sqrt();
                for c in 0..cols {
                    let normed = (slice[c] - mean) * inv_std;
                    out[base + c] = h_w[c] * normed + h_b[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(bias.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
17726
#[cfg(feature = "cuda")]
/// Backward pass of layer normalisation. Returns
/// `(grad_input, grad_weight, grad_bias)`; `grad_weight` and `grad_bias`
/// are accumulated over all rows. Falls back to the CPU when the f64 PTX
/// cannot be compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `input`/`grad_output` are not
/// `rows * cols` long or when `weight` does not contain exactly `cols`
/// elements.
pub fn gpu_layernorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    // Bug fix: validate buffer lengths against the declared shape; the
    // previous version panicked in the CPU fallback on any mismatch.
    let n = rows * cols;
    if input.len() != n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if grad_output.len() != n {
        return Err(GpuError::LengthMismatch { a: grad_output.len(), b: n });
    }
    if weight.len() != cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, LAYERNORM_BACKWARD_PTX, "layernorm_backward_kernel", "layernorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "layernorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback using the standard layernorm gradient:
            //   dx = inv_std * (dl - mean(dl) - x_hat * mean(dl * x_hat))
            // where dl = dy * w and x_hat is the normalised input.
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; rows * cols];
            let mut grad_weight = vec![0.0f64; cols];
            let mut grad_bias = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let mean: f64 = x_slice.iter().sum::<f64>() / n_f;
                let var: f64 = x_slice
                    .iter()
                    .map(|&x| (x - mean) * (x - mean))
                    .sum::<f64>()
                    / n_f;
                let inv_std = 1.0 / (var + eps).sqrt();
                let mut sum1 = 0.0f64;
                let mut sum2 = 0.0f64;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    sum1 += dl;
                    sum2 += dl * x_hat;
                    grad_weight[c] += go_slice[c] * x_hat;
                    grad_bias[c] += go_slice[c];
                }
                let m1 = sum1 / n_f;
                let m2 = sum2 / n_f;
                for c in 0..cols {
                    let x_hat = (x_slice[c] - mean) * inv_std;
                    let dl = go_slice[c] * h_w[c];
                    grad_input[base + c] = inv_std * (dl - m1 - x_hat * m2);
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            let gb = cpu_to_gpu(&grad_bias, device)?;
            return Ok((gi, gw, gb));
        }
    };

    let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let mut grad_b = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(grad_b.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w, grad_b))
}
17829
#[cfg(feature = "cuda")]
/// Row-wise RMS normalisation: `out = x / sqrt(mean(x^2) + eps) * weight`
/// per row of a dense `rows x cols` matrix. Falls back to the CPU when the
/// f64 PTX cannot be compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `input.len() != rows * cols` or when
/// `weight` does not contain exactly `cols` elements.
pub fn gpu_rmsnorm_f64(
    input: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    // Bug fix: validate buffer lengths against the declared shape; the
    // previous version panicked in the CPU fallback on any mismatch.
    let n = rows * cols;
    if input.len() != n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if weight.len() != cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }
    if n == 0 {
        return cpu_to_gpu(&[], device);
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_PTX, "rmsnorm_kernel", "rmsnorm_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback: per-row root-mean-square, then scale by weight.
            let h_in = gpu_to_cpu(input, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut out = vec![0.0f64; rows * cols];
            for r in 0..rows {
                let base = r * cols;
                let slice = &h_in[base..base + cols];
                let sq_mean: f64 =
                    slice.iter().map(|&x| x * x).sum::<f64>() / cols as f64;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                for c in 0..cols {
                    out[base + c] = slice[c] * inv_rms * h_w[c];
                }
            }
            return cpu_to_gpu(&out, device);
        }
    };

    let mut out = alloc_zeros_f64(rows * cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(out.inner_mut())
            .arg(weight.inner())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok(out)
}
17901
#[cfg(feature = "cuda")]
/// Backward pass of RMS normalisation. Returns `(grad_input, grad_weight)`;
/// `grad_weight` is accumulated over all rows. Falls back to the CPU when
/// the f64 PTX cannot be compiled.
///
/// # Errors
/// `GpuError::LengthMismatch` when `input`/`grad_output` are not
/// `rows * cols` long or when `weight` does not contain exactly `cols`
/// elements.
pub fn gpu_rmsnorm_backward_f64(
    input: &CudaBuffer<f64>,
    grad_output: &CudaBuffer<f64>,
    weight: &CudaBuffer<f64>,
    rows: usize,
    cols: usize,
    eps: f64,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
    use cudarc::driver::PushKernelArg;
    static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    validate_device(input, device)?;

    // Bug fix: validate buffer lengths against the declared shape; the
    // previous version panicked in the CPU fallback on any mismatch.
    let n = rows * cols;
    if input.len() != n {
        return Err(GpuError::LengthMismatch { a: input.len(), b: n });
    }
    if grad_output.len() != n {
        return Err(GpuError::LengthMismatch { a: grad_output.len(), b: n });
    }
    if weight.len() != cols {
        return Err(GpuError::LengthMismatch { a: weight.len(), b: cols });
    }

    let ctx = device.context();
    let stream = device.stream();

    let ptx = get_f64_ptx(&CACHE, RMSNORM_BACKWARD_PTX, "rmsnorm_backward_kernel", "rmsnorm_backward_f64_kernel");
    let f = match crate::module_cache::get_or_compile(
        ctx,
        ptx,
        "rmsnorm_backward_f64_kernel",
        device.ordinal() as u32,
    ) {
        Ok(f) => f,
        Err(_) => {
            // CPU fallback using the RMSNorm gradient:
            //   dx = inv_rms * w * dy - x * (<dy * x * w> * inv_rms^3 / n)
            let h_in = gpu_to_cpu(input, device)?;
            let h_go = gpu_to_cpu(grad_output, device)?;
            let h_w = gpu_to_cpu(weight, device)?;
            let mut grad_input = vec![0.0f64; rows * cols];
            let mut grad_weight = vec![0.0f64; cols];
            let n_f = cols as f64;
            for r in 0..rows {
                let base = r * cols;
                let x_slice = &h_in[base..base + cols];
                let go_slice = &h_go[base..base + cols];
                let sq_mean: f64 =
                    x_slice.iter().map(|&x| x * x).sum::<f64>() / n_f;
                let inv_rms = 1.0 / (sq_mean + eps).sqrt();
                let inv_rms3 = inv_rms * inv_rms * inv_rms;
                let mut dot = 0.0f64;
                for c in 0..cols {
                    dot += go_slice[c] * x_slice[c] * h_w[c];
                    grad_weight[c] += go_slice[c] * x_slice[c] * inv_rms;
                }
                let coeff = dot * inv_rms3 / n_f;
                for c in 0..cols {
                    grad_input[base + c] =
                        inv_rms * h_w[c] * go_slice[c] - x_slice[c] * coeff;
                }
            }
            let gi = cpu_to_gpu(&grad_input, device)?;
            let gw = cpu_to_gpu(&grad_weight, device)?;
            return Ok((gi, gw));
        }
    };

    let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
    let mut grad_w = alloc_zeros_f64(cols, device)?;
    let rows_u32 = rows as u32;
    let cols_u32 = cols as u32;

    // One block per row; 256 threads with one 8-byte (f64) shared slot each.
    let cfg = LaunchConfig {
        grid_dim: ((rows as u32).max(1), 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 256 * 8,
    };

    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(grad_output.inner())
            .arg(weight.inner())
            .arg(grad_in.inner_mut())
            .arg(grad_w.inner_mut())
            .arg(&rows_u32)
            .arg(&cols_u32)
            .arg(&eps)
            .launch(cfg)?;
    }

    Ok((grad_in, grad_w))
}
17990
// ---------------------------------------------------------------------------
// CPU-only stubs for the softmax / normalisation entry points (builds
// without the "cuda" feature). Signatures mirror the CUDA versions above;
// every call fails with `GpuError::NoCudaFeature` so callers can handle the
// absence of GPU support uniformly.
// ---------------------------------------------------------------------------
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_f64(_input: &CudaBuffer<f64>, _rows: usize, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_f64(_input: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward_f64(_grad: &CudaBuffer<f64>, _output: &CudaBuffer<f64>, _cols: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _bias: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_f64(_input: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_backward_f64(_input: &CudaBuffer<f64>, _grad_output: &CudaBuffer<f64>, _weight: &CudaBuffer<f64>, _rows: usize, _cols: usize, _eps: f64, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
18011
// ---------------------------------------------------------------------------
// CPU-only stubs for the activation / clamp / cumulative-op entry points
// (builds without the "cuda" feature); every call fails with
// `GpuError::NoCudaFeature`.
// ---------------------------------------------------------------------------
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_erf_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_f64(_input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _alpha: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_backward_f64(_grad: &CudaBuffer<f64>, _input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp_f64(_input: &CudaBuffer<f64>, _min: f64, _max: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> { Err(GpuError::NoCudaFeature) }
#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp_f64(_input: &CudaBuffer<f64>, _outer: usize, _dim_size: usize, _inner: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> { Err(GpuError::NoCudaFeature) }
18052
18053#[cfg(test)]
18058#[cfg(feature = "cuda")]
18059mod tests {
18060 use super::*;
18061
18062 fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
18064 let dev = GpuDevice::new(0).expect("CUDA device 0");
18065 let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
18066 (dev, buf)
18067 }
18068
    /// Downloads `buf` and asserts element-wise equality with `expected`
    /// within an absolute tolerance of 1e-6.
    ///
    /// NOTE(review): the tolerance is absolute, not relative — fine for the
    /// exactly-representable add/sub/mul/neg results tested here, but too
    /// strict for large-magnitude inexact results.
    fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
        let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
        assert_eq!(host.len(), expected.len(), "length mismatch");
        for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-6,
                "element {i}: got {got}, expected {exp}",
            );
        }
    }
18081
18082 #[test]
18085 fn add_basic() {
18086 let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
18087 let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
18088 let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
18089
18090 let (dev, a) = setup(&a_data);
18091 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
18092 let out = gpu_add(&a, &b, &dev).expect("gpu_add");
18093 assert_buf_eq(&out, &dev, &expected);
18094 }
18095
18096 #[test]
18097 fn add_empty() {
18098 let (dev, a) = setup(&[]);
18099 let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
18100 let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
18101 assert_eq!(out.len(), 0);
18102 }
18103
18104 #[test]
18105 fn add_large() {
18106 let n = 100_000;
18107 let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
18108 let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
18109 let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
18110
18111 let (dev, a) = setup(&a_data);
18112 let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
18113 let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
18114 assert_buf_eq(&out, &dev, &expected);
18115 }
18116
18117 #[test]
18118 fn add_length_mismatch() {
18119 let (dev, a) = setup(&[1.0, 2.0, 3.0]);
18120 let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
18121 let err = gpu_add(&a, &b, &dev).unwrap_err();
18122 match err {
18123 GpuError::LengthMismatch { a: 3, b: 2 } => {}
18124 other => panic!("unexpected error: {other}"),
18125 }
18126 }
18127
#[test]
fn sub_basic() {
    // Elementwise difference of two small vectors.
    let lhs = vec![10.0f32, 20.0, 30.0, 40.0];
    let rhs = vec![1.0f32, 2.0, 3.0, 4.0];
    let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(&x, &y)| x - y).collect();

    let (dev, a) = setup(&lhs);
    let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
    let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
    assert_buf_eq(&diff, &dev, &want);
}
18141
#[test]
fn sub_negative_result() {
    // Subtraction where every output element is negative.
    let lhs = vec![1.0f32, 2.0];
    let rhs = vec![5.0f32, 10.0];
    let want = vec![-4.0f32, -8.0];

    let (dev, a) = setup(&lhs);
    let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
    let diff = gpu_sub(&a, &b, &dev).expect("gpu_sub");
    assert_buf_eq(&diff, &dev, &want);
}
18153
#[test]
fn mul_basic() {
    // Elementwise product of two small vectors.
    let lhs = vec![2.0f32, 3.0, 4.0, 5.0];
    let rhs = vec![10.0f32, 10.0, 10.0, 10.0];
    let want: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(&x, &y)| x * y).collect();

    let (dev, a) = setup(&lhs);
    let b = cpu_to_gpu(&rhs, &dev).expect("cpu_to_gpu b");
    let prod = gpu_mul(&a, &b, &dev).expect("gpu_mul");
    assert_buf_eq(&prod, &dev, &want);
}
18167
#[test]
fn mul_by_zero() {
    // Multiplying by an all-zero operand must yield all zeros.
    let lhs = vec![1.0f32, 2.0, 3.0];
    let zeros = vec![0.0f32, 0.0, 0.0];

    let (dev, a) = setup(&lhs);
    let b = cpu_to_gpu(&zeros, &dev).expect("cpu_to_gpu b");
    let prod = gpu_mul(&a, &b, &dev).expect("gpu_mul");
    assert_buf_eq(&prod, &dev, &[0.0f32, 0.0, 0.0]);
}
18179
#[test]
fn neg_basic() {
    // Negation over a mix of positive, negative, and zero values.
    let values = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
    let want: Vec<f32> = values.iter().map(|&v| -v).collect();

    let (dev, a) = setup(&values);
    let negated = gpu_neg(&a, &dev).expect("gpu_neg");
    assert_buf_eq(&negated, &dev, &want);
}
18191
#[test]
fn neg_double_negation() {
    // Negating twice must round-trip back to the original data.
    let values = vec![1.0f32, -2.0, 3.0];
    let (dev, a) = setup(&values);
    let once = gpu_neg(&a, &dev).expect("gpu_neg 1");
    let twice = gpu_neg(&once, &dev).expect("gpu_neg 2");
    assert_buf_eq(&twice, &dev, &values);
}
18200
#[test]
fn relu_basic() {
    // Negative inputs clamp to zero; non-negative inputs pass through.
    let input = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
    let want = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];

    let (dev, a) = setup(&input);
    let activated = gpu_relu(&a, &dev).expect("gpu_relu");
    assert_buf_eq(&activated, &dev, &want);
}
18212
#[test]
fn relu_all_negative() {
    // Every input is negative, so the output must be all zeros.
    let input = vec![-5.0f32, -0.1, -100.0];

    let (dev, a) = setup(&input);
    let activated = gpu_relu(&a, &dev).expect("gpu_relu");
    assert_buf_eq(&activated, &dev, &[0.0f32, 0.0, 0.0]);
}
18222
#[test]
fn relu_all_positive() {
    // Strictly positive inputs must be returned unchanged.
    let input = vec![0.1f32, 1.0, 100.0];

    let (dev, a) = setup(&input);
    let activated = gpu_relu(&a, &dev).expect("gpu_relu");
    assert_buf_eq(&activated, &dev, &input);
}
18231
#[test]
fn relu_empty() {
    // ReLU over an empty buffer succeeds and stays empty.
    let (dev, a) = setup(&[]);
    let activated = gpu_relu(&a, &dev).expect("gpu_relu empty");
    assert_eq!(activated.len(), 0);
}
18238
#[test]
fn small_matmul_2x2() {
    // [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]] (row-major).
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
    let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
    let product = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
    assert_buf_eq(&product, &dev, &[19.0, 22.0, 43.0, 50.0]);
}
18249
#[test]
fn small_matmul_1xk_kxn() {
    // Row vector (1x3) times a 3x2 matrix gives a 1x2 result.
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
    let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
    let product = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
    assert_buf_eq(&product, &dev, &[4.0, 5.0]);
}
18260
#[test]
fn small_matmul_vs_cublas() {
    // Cross-check our small-matmul kernel against cuBLAS on a 1x64 * 64x64
    // product with deterministic pseudo-random inputs in [0, 1).
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let (m, k, n) = (1usize, 64usize, 64usize);

    let a_host: Vec<f32> = (0..m * k)
        .map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
        .collect();
    let b_host: Vec<f32> = (0..k * n)
        .map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
        .collect();

    let a = cpu_to_gpu(&a_host, &dev).unwrap();
    let b = cpu_to_gpu(&b_host, &dev).unwrap();

    // Reference result from cuBLAS.
    let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
    let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();

    // Our kernel's result.
    let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
    let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();

    assert_eq!(cublas_result.len(), our_result.len());
    // Loose tolerance: accumulation order differs between the two kernels.
    for (i, (&cb, &ours)) in cublas_result.iter().zip(&our_result).enumerate() {
        let diff = (cb - ours).abs();
        assert!(diff < 0.1, "element {i}: cuBLAS={cb}, ours={ours}, diff={}", diff);
    }
}
18298
#[test]
fn strided_copy_identity_contiguous_2d() {
    // Contiguous row-major strides for a 2x3 view: the copy is an identity.
    let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
    let (dev, input) = setup(&data);
    let copied = gpu_strided_copy(&input, &[2, 3], &[3, 1], 0, &dev)
        .expect("strided_copy identity");
    assert_buf_eq(&copied, &dev, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
}
18312
#[test]
fn strided_copy_transpose_2d() {
    // Swapped strides [1, 3] over shape [3, 2] materialize the transpose
    // of a 2x3 row-major buffer.
    let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
    let (dev, input) = setup(&data);
    let copied = gpu_strided_copy(&input, &[3, 2], &[1, 3], 0, &dev)
        .expect("strided_copy transpose");
    assert_buf_eq(&copied, &dev, &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
}
18327
#[test]
fn strided_copy_sliced_column() {
    // Offset 2 with stride 4 over a 3x4 row-major buffer extracts column 2.
    let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
    let (dev, input) = setup(&data);
    let column = gpu_strided_copy(&input, &[3], &[4], 2, &dev)
        .expect("strided_copy col slice");
    assert_buf_eq(&column, &dev, &[2.0, 6.0, 10.0]);
}
18342
#[test]
fn strided_copy_3d_permute() {
    // Shape [2, 4, 3] with strides [12, 1, 4] transposes the inner two axes
    // of each 3x4 batch in a (2, 3, 4) row-major buffer.
    let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
    let (dev, input) = setup(&data);
    let copied =
        gpu_strided_copy(&input, &[2, 4, 3], &[12, 1, 4], 0, &dev).expect("strided_copy 3d");

    // Build the expectation by decomposing each flat destination index
    // into (batch, i, j) and reading the permuted source position.
    let mut expected = vec![0.0f32; 24];
    for dst in 0..24 {
        let b = dst / 12;
        let i = (dst % 12) / 3;
        let j = dst % 3;
        expected[dst] = data[b * 12 + j * 4 + i];
    }
    assert_buf_eq(&copied, &dev, &expected);
}
18368
#[test]
fn strided_copy_4d_max_rank_supported() {
    // A rank-4 contiguous view should copy through unchanged.
    let shape = [2usize, 3, 2, 2];
    let total: usize = shape.iter().product();
    let data: Vec<f32> = (0..total).map(|i| i as f32).collect();
    let (dev, input) = setup(&data);
    let copied = gpu_strided_copy(&input, &shape, &[12, 4, 2, 1], 0, &dev)
        .expect("strided_copy 4d");
    assert_buf_eq(&copied, &dev, &data);
}
18381
#[test]
fn strided_copy_rejects_too_many_dims() {
    // A rank-9 shape exceeds the supported dimensionality and must error.
    let (dev, input) = setup(&[0.0f32; 16]);
    let shape = [1, 1, 1, 1, 1, 1, 1, 1, 16];
    let result = gpu_strided_copy(&input, &shape, &[1; 9], 0, &dev);
    assert!(result.is_err());
}
18395
#[test]
fn strided_copy_rejects_shape_stride_length_mismatch() {
    // Two shape dims but three strides: the call must be rejected.
    let (dev, input) = setup(&[0.0f32; 12]);
    let outcome = gpu_strided_copy(&input, &[3, 4], &[4, 1, 1], 0, &dev);
    assert!(outcome.is_err());
}
18402
#[test]
fn strided_copy_rejects_negative_stride() {
    // Negative strides are not supported and must be rejected.
    let (dev, input) = setup(&[0.0f32; 6]);
    let outcome = gpu_strided_copy(&input, &[2, 3], &[3, -1], 0, &dev);
    assert!(outcome.is_err());
}
18409
#[test]
fn strided_copy_empty_output() {
    // A zero-sized leading dimension yields an empty (but successful) copy.
    let (dev, input) = setup(&[1.0f32, 2.0, 3.0]);
    let copied = gpu_strided_copy(&input, &[0, 3], &[3, 1], 0, &dev)
        .expect("strided_copy empty");
    assert_eq!(copied.len(), 0);
}
18417
#[test]
fn strided_copy_f64_transpose_matches_f32() {
    // The f64 kernel must produce the same transpose layout as the f32 path
    // ([0, 3, 1, 4, 2, 5] for a 2x3 row-major input).
    let host: Vec<f64> = (0..6).map(|i| i as f64).collect();
    let dev = GpuDevice::new(0).expect("CUDA device 0");
    let input = cpu_to_gpu(&host, &dev).expect("cpu_to_gpu f64");
    let copied = gpu_strided_copy_f64(&input, &[3, 2], &[1, 3], 0, &dev)
        .expect("strided_copy_f64 transpose");
    let back = gpu_to_cpu(&copied, &dev).expect("gpu_to_cpu f64");
    assert_eq!(back, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
}
18429}