#[cfg(feature = "cuda")]
use cudarc::driver::LaunchConfig;
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
#[cfg(feature = "cuda")]
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, cpu_to_gpu, gpu_to_cpu};
#[cfg(feature = "cuda")]
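/// Rewrites an f32 PTX kernel into its f64 counterpart by textual
/// substitution: opcodes and register classes widen, immediate constants are
/// re-encoded as 64-bit bit patterns, byte-offset shifts go from `<< 2` to
/// `<< 3`, approximate ops are promoted to IEEE-rounded forms, and the
/// target is bumped to sm_60 for f64 atomics.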
pub(crate) fn ptx_f32_to_f64(
f32_ptx: &str,
f32_kernel_name: &str,
f64_kernel_name: &str,
) -> String {
f32_ptx
.replace(f32_kernel_name, f64_kernel_name)
.replace(".reg .f32", ".reg .f64")
.replace("ld.global.f32", "ld.global.f64")
.replace("st.global.f32", "st.global.f64")
.replace("ld.shared.f32", "ld.shared.f64")
.replace("st.shared.f32", "st.shared.f64")
.replace("ld.param.f32", "ld.param.f64")
.replace(".param .f32", ".param .f64")
.replace(".shared .align 4 .f32", ".shared .align 8 .f64")
.replace("add.f32", "add.f64")
.replace("sub.f32", "sub.f64")
.replace("mul.f32", "mul.f64")
.replace("div.rn.f32", "div.rn.f64")
.replace("div.f32", "div.f64")
.replace("neg.f32", "neg.f64")
.replace("abs.f32", "abs.f64")
.replace("max.f32", "max.f64")
.replace("min.f32", "min.f64")
.replace("selp.f32", "selp.f64")
.replace("sqrt.rn.f32", "sqrt.rn.f64")
.replace("sqrt.f32", "sqrt.f64")
.replace("fma.rn.f32", "fma.rn.f64")
.replace("mov.f32", "mov.f64")
.replace("setp.gt.f32", "setp.gt.f64")
.replace("setp.ge.f32", "setp.ge.f64")
.replace("setp.lt.f32", "setp.lt.f64")
.replace("setp.le.f32", "setp.le.f64")
.replace("setp.eq.f32", "setp.eq.f64")
.replace("setp.ne.f32", "setp.ne.f64")
.replace("cvt.rn.f32.u32", "cvt.rn.f64.u32")
.replace("cvt.rn.f32.s32", "cvt.rn.f64.s32")
.replace(
"mov.b32 %acc, 0xFF800000",
"mov.b64 %acc, 0xFFF0000000000000",
)
.replace(
"mov.b32 %acc, 0x7F800000",
"mov.b64 %acc, 0x7FF0000000000000",
)
.replace("mov.b32", "mov.b64")
.replace("shl.b64 %off, %off, 2", "shl.b64 %off, %off, 3")
.replace("shl.b64 %off_in, %off_in, 2", "shl.b64 %off_in, %off_in, 3")
.replace(
"shl.b64 %off_out, %off_out, 2",
"shl.b64 %off_out, %off_out, 3",
)
.replace(
"shl.b64 %off_src, %off_src, 2",
"shl.b64 %off_src, %off_src, 3",
)
.replace(
"shl.b64 %off_dst, %off_dst, 2",
"shl.b64 %off_dst, %off_dst, 3",
)
.replace("shl.b64 %off_a, %off_a, 2", "shl.b64 %off_a, %off_a, 3")
.replace("shl.b64 %off_b, %off_b, 2", "shl.b64 %off_b, %off_b, 3")
.replace(
"shl.b64 %row_off, %row_off, 2",
"shl.b64 %row_off, %row_off, 3",
)
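// Approximate f32 ops are promoted to IEEE-rounded f64 forms.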
.replace("rcp.approx.f32", "rcp.rn.f64")
.replace("div.approx.f32", "div.rn.f64")
.replace("sqrt.approx.f32", "sqrt.rn.f64")
.replace(".target sm_52", ".target sm_60")
.replace("atom.global.add.f32", "atom.global.add.f64")
.replace("0f00000000", "0d0000000000000000") .replace("0f3F800000", "0d3FF0000000000000") .replace("0fBF800000", "0dBFF0000000000000") .replace("0f40000000", "0d4000000000000000") .replace("0f3F000000", "0d3FE0000000000000") .replace("0fFF800000", "0dFFF0000000000000") .replace("0f7F800000", "0d7FF0000000000000") .replace("0f3FB8AA3B", "0d3FF71547652B82FE") .replace("0f3F317218", "0d3FE62E42FEFA39EF") }
#[cfg(feature = "cuda")]
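/// Returns the f64 rewrite of `f32_ptx`, building and caching it on first use.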
pub(crate) fn get_f64_ptx<'a>(
cache: &'a std::sync::OnceLock<String>,
f32_ptx: &str,
f32_name: &str,
f64_name: &str,
) -> &'a str {
cache.get_or_init(|| ptx_f32_to_f64(f32_ptx, f32_name, f64_name))
}
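// A minimal sanity sketch (assuming the usual `cargo test` flow) of what the
// textual rewrite produces for ADD_PTX: the entry point is renamed, opcodes
// widen, and the per-element byte stride doubles.
#[cfg(all(test, feature = "cuda"))]
mod f64_rewrite_sketch {
    use super::*;

    #[test]
    fn add_ptx_widens_to_f64() {
        let f64_ptx = ptx_f32_to_f64(ADD_PTX, "add_kernel", "add_f64_kernel");
        assert!(f64_ptx.contains(".visible .entry add_f64_kernel"));
        assert!(f64_ptx.contains("add.f64"));
        // Byte offsets switch from 4-byte to 8-byte strides.
        assert!(f64_ptx.contains("shl.b64 %off, %off, 3"));
        assert!(!f64_ptx.contains(".f32"));
    }
}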
#[cfg(feature = "cuda")]
pub(crate) const ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
add.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ADD_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_vec4_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n4
) {
.reg .u32 %r_tid, %bid, %bdim, %n4_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n4_reg, [n4];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n4_reg;
@%p bra DONE;
// Byte offset = tid * 16 (4 floats × 4 bytes)
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 4;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
add.f32 %r0, %a0, %b0;
add.f32 %r1, %a1, %b1;
add.f32 %r2, %a2, %b2;
add.f32 %r3, %a3, %b3;
st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MUL_VEC4_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_vec4_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n4
) {
.reg .u32 %r_tid, %bid, %bdim, %n4_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %a0, %a1, %a2, %a3, %b0, %b1, %b2, %b3, %r0, %r1, %r2, %r3;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n4_reg, [n4];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n4_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 4;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.v4.f32 {%a0, %a1, %a2, %a3}, [%a];
ld.global.v4.f32 {%b0, %b1, %b2, %b3}, [%b];
mul.f32 %r0, %a0, %b0;
mul.f32 %r1, %a1, %b1;
mul.f32 %r2, %a2, %b2;
mul.f32 %r3, %a3, %b3;
st.global.v4.f32 [%out], {%r0, %r1, %r2, %r3};
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sub_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
sub.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
mul.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const NEG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry neg_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
neg.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const RELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry relu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %zero;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
mov.f32 %zero, 0f00000000;
max.f32 %vr, %va, %zero;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SCALE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry scale_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f32 scalar,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %s;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %s, [scalar];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
mul.f32 %vr, %va, %s;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_2D_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_2d_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 M,\n\
.param .u32 N,\n\
.param .u32 total\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %N_reg;\n\
.reg .u32 %out_row, %out_col, %in_idx;\n\
.reg .u64 %in, %out, %off_in, %off_out;\n\
.reg .f32 %val;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %in, [in_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %M_reg, [M];\n\
ld.param.u32 %N_reg, [N];\n\
ld.param.u32 %total_reg, [total];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %total_reg;\n\
@%p bra DONE;\n\
\n\
// Output shape is [N, M]. tid = out_row * M + out_col.\n\
div.u32 %out_row, %r_tid, %M_reg;\n\
rem.u32 %out_col, %r_tid, %M_reg;\n\
// Input index: out_col * N + out_row (transposed).\n\
mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
cvt.u64.u32 %off_in, %in_idx;\n\
shl.b64 %off_in, %off_in, 2;\n\
add.u64 %off_in, %in, %off_in;\n\
ld.global.f32 %val, [%off_in];\n\
\n\
cvt.u64.u32 %off_out, %r_tid;\n\
shl.b64 %off_out, %off_out, 2;\n\
add.u64 %off_out, %out, %off_out;\n\
st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const PERMUTE_0213_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry permute_0213_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 d0,\n\
.param .u32 d1,\n\
.param .u32 d2,\n\
.param .u32 d3,\n\
.param .u32 total\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %total_reg;\n\
.reg .u32 %d0r, %d1r, %d2r, %d3r;\n\
.reg .u32 %i0, %i1, %i2, %i3, %rem, %in_idx;\n\
.reg .u32 %s_out2, %s_out1, %s_in1;\n\
.reg .u64 %in, %out, %off_in, %off_out;\n\
.reg .f32 %val;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %in, [in_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %d0r, [d0];\n\
ld.param.u32 %d1r, [d1];\n\
ld.param.u32 %d2r, [d2];\n\
ld.param.u32 %d3r, [d3];\n\
ld.param.u32 %total_reg, [total];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %total_reg;\n\
@%p bra DONE;\n\
\n\
// Output shape: [d0, d2, d1, d3]\n\
// Decompose tid into (i0, i2, i1, i3) in output layout.\n\
mul.lo.u32 %s_out2, %d1r, %d3r;\n\
mul.lo.u32 %s_out1, %s_out2, %d2r;\n\
\n\
div.u32 %i0, %r_tid, %s_out1;\n\
rem.u32 %rem, %r_tid, %s_out1;\n\
div.u32 %i2, %rem, %s_out2;\n\
rem.u32 %rem, %rem, %s_out2;\n\
div.u32 %i1, %rem, %d3r;\n\
rem.u32 %i3, %rem, %d3r;\n\
\n\
// Input index: i0 * (d1*d2*d3) + i1 * (d2*d3) + i2 * d3 + i3\n\
mul.lo.u32 %s_in1, %d2r, %d3r;\n\
mul.lo.u32 %in_idx, %i0, %d1r;\n\
add.u32 %in_idx, %in_idx, %i1;\n\
mul.lo.u32 %in_idx, %in_idx, %s_in1;\n\
mad.lo.u32 %in_idx, %i2, %d3r, %in_idx;\n\
add.u32 %in_idx, %in_idx, %i3;\n\
\n\
cvt.u64.u32 %off_in, %in_idx;\n\
shl.b64 %off_in, %off_in, 2;\n\
add.u64 %off_in, %in, %off_in;\n\
ld.global.f32 %val, [%off_in];\n\
\n\
cvt.u64.u32 %off_out, %r_tid;\n\
shl.b64 %off_out, %off_out, 2;\n\
add.u64 %off_out, %out, %off_out;\n\
st.global.f32 [%off_out], %val;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_F16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry f32_to_f16_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off_in, %off_out;
.reg .f32 %vf;
.reg .b16 %vh;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Compute input offset: i * 4 (f32 = 4 bytes)
cvt.u64.u32 %off_in, %r_tid;
shl.b64 %off_in, %off_in, 2;
add.u64 %in, %in, %off_in;
// Compute output offset: i * 2 (f16 = 2 bytes)
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %out, %out, %off_out;
// Load f32, convert to f16 (round-to-nearest-even), store as u16
ld.global.f32 %vf, [%in];
cvt.rn.f16.f32 %vh, %vf;
st.global.b16 [%out], %vh;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const F32_TO_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry f32_to_bf16_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off_in, %off_out;
.reg .f32 %vf;
.reg .u32 %bits, %round, %lsb, %result;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off_in, %r_tid;
shl.b64 %off_in, %off_in, 2;
add.u64 %in, %in, %off_in;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 1;
add.u64 %out, %out, %off_out;
// Load f32 as raw bits
ld.global.u32 %bits, [%in];
// Round-to-nearest-even: add (0x7FFF + bit[16]) then shift right 16
shr.u32 %lsb, %bits, 16;
and.b32 %lsb, %lsb, 1;
add.u32 %round, %bits, 0x7FFF;
add.u32 %round, %round, %lsb;
shr.u32 %result, %round, 16;
// Store as u16
st.global.u16 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SMALL_MATMUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry small_matmul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 c_ptr,
.param .u32 M,
.param .u32 K,
.param .u32 N,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %M_reg, %K_reg, %N_reg;
.reg .u32 %row, %col, %p, %idx;
.reg .u64 %a, %b, %c, %a_off, %b_off, %c_off;
.reg .f32 %sum, %va, %vb;
.reg .pred %bounds_p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %c, [c_ptr];
ld.param.u32 %M_reg, [M];
ld.param.u32 %K_reg, [K];
ld.param.u32 %N_reg, [N];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %bounds_p, %r_tid, %total_reg;
@%bounds_p bra DONE;
div.u32 %row, %r_tid, %N_reg;
rem.u32 %col, %r_tid, %N_reg;
mov.f32 %sum, 0f00000000;
mov.u32 %p, 0;
DOT:
setp.ge.u32 %loop_p, %p, %K_reg;
@%loop_p bra DOT_DONE;
mad.lo.u32 %idx, %row, %K_reg, %p;
cvt.u64.u32 %a_off, %idx;
shl.b64 %a_off, %a_off, 2;
add.u64 %a_off, %a, %a_off;
ld.global.f32 %va, [%a_off];
mad.lo.u32 %idx, %p, %N_reg, %col;
cvt.u64.u32 %b_off, %idx;
shl.b64 %b_off, %b_off, 2;
add.u64 %b_off, %b, %b_off;
ld.global.f32 %vb, [%b_off];
fma.rn.f32 %sum, %va, %vb, %sum;
add.u32 %p, %p, 1;
bra DOT;
DOT_DONE:
cvt.u64.u32 %c_off, %r_tid;
shl.b64 %c_off, %c_off, 2;
add.u64 %c_off, %c, %c_off;
st.global.f32 [%c_off], %sum;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry slice_write_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 n,
.param .u32 D,
.param .u32 max_len,
.param .u32 pos
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
.reg .u32 %batch_idx, %d_idx, %dst_row;
.reg .u64 %src, %dst, %src_off, %dst_off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %src, [src_ptr];
ld.param.u64 %dst, [dst_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %D_reg, [D];
ld.param.u32 %max_len_reg, [max_len];
ld.param.u32 %pos_reg, [pos];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %src_off, %r_tid;
shl.b64 %src_off, %src_off, 2;
add.u64 %src, %src, %src_off;
ld.global.f32 %val, [%src];
div.u32 %batch_idx, %r_tid, %D_reg;
rem.u32 %d_idx, %r_tid, %D_reg;
mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
add.u32 %dst_row, %dst_row, %pos_reg;
mul.lo.u32 %dst_row, %dst_row, %D_reg;
add.u32 %dst_row, %dst_row, %d_idx;
cvt.u64.u32 %dst_off, %dst_row;
shl.b64 %dst_off, %dst_off, 2;
add.u64 %dst, %dst, %dst_off;
st.global.f32 [%dst], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SLICE_WRITE_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry slice_write_indirect_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 n,
.param .u32 D,
.param .u32 max_len,
.param .u64 pos_ptr
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %D_reg, %max_len_reg, %pos_reg;
.reg .u32 %batch_idx, %d_idx, %dst_row;
.reg .u64 %src, %dst, %src_off, %dst_off, %pos_p;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %src, [src_ptr];
ld.param.u64 %dst, [dst_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %D_reg, [D];
ld.param.u32 %max_len_reg, [max_len];
ld.param.u64 %pos_p, [pos_ptr];
ld.global.u32 %pos_reg, [%pos_p];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %src_off, %r_tid;
shl.b64 %src_off, %src_off, 2;
add.u64 %src, %src, %src_off;
ld.global.f32 %val, [%src];
div.u32 %batch_idx, %r_tid, %D_reg;
rem.u32 %d_idx, %r_tid, %D_reg;
mul.lo.u32 %dst_row, %batch_idx, %max_len_reg;
add.u32 %dst_row, %dst_row, %pos_reg;
mul.lo.u32 %dst_row, %dst_row, %D_reg;
add.u32 %dst_row, %dst_row, %d_idx;
cvt.u64.u32 %dst_off, %dst_row;
shl.b64 %dst_off, %dst_off, 2;
add.u64 %dst, %dst, %dst_off;
st.global.f32 [%dst], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CAUSAL_MASK_INDIRECT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry causal_mask_indirect_kernel(
.param .u64 total_len_ptr,
.param .u64 out_ptr,
.param .u32 max_pos,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %tlen, %max_pos_reg, %col;
.reg .u64 %out, %off, %tl_p;
.reg .f32 %val;
.reg .pred %bounds_p, %mask_p;
ld.param.u64 %tl_p, [total_len_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %max_pos_reg, [max_pos];
ld.param.u32 %total_reg, [total];
ld.global.u32 %tlen, [%tl_p];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %bounds_p, %r_tid, %total_reg;
@%bounds_p bra DONE;
rem.u32 %col, %r_tid, %max_pos_reg;
setp.lt.u32 %mask_p, %col, %tlen;
@%mask_p bra WRITE_ZERO;
// 0fCE6E6B28 = -1.0e9 in IEEE 754 f32, used as a large negative mask value
// to effectively zero out masked positions after softmax.
mov.f32 %val, 0fCE6E6B28;
bra WRITE;
WRITE_ZERO:
mov.f32 %val, 0f00000000;
WRITE:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry embed_lookup_kernel(
.param .u64 idx_ptr,
.param .u64 weight_ptr,
.param .u64 out_ptr,
.param .u32 D
) {
.reg .u32 %r_tid, %bid, %bdim, %D_reg, %row, %src_idx;
.reg .u64 %idx_addr, %w, %out, %off;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %idx_addr, [idx_ptr];
ld.param.u64 %w, [weight_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %D_reg, [D];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %D_reg;
@%p bra DONE;
ld.global.f32 %idx_f, [%idx_addr];
cvt.rzi.u32.f32 %row, %idx_f;
mad.lo.u32 %src_idx, %row, %D_reg, %r_tid;
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %val, [%off];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const EMBED_LOOKUP_BATCH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry embed_lookup_batch_kernel(
.param .u64 idx_ptr,
.param .u64 weight_ptr,
.param .u64 out_ptr,
.param .u32 D,
.param .u32 total
) {
.reg .u32 %my_tid, %bid, %bdim, %D_reg, %total_reg;
.reg .u32 %row, %col, %src_idx;
.reg .u64 %idx_addr, %w, %out, %off;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %idx_addr, [idx_ptr];
ld.param.u64 %w, [weight_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %D_reg, [D];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %my_tid, %tid.x;
mad.lo.u32 %my_tid, %bid, %bdim, %my_tid;
setp.ge.u32 %p, %my_tid, %total_reg;
@%p bra DONE;
// row = tid / D, col = tid % D
div.u32 %row, %my_tid, %D_reg;
rem.u32 %col, %my_tid, %D_reg;
// Read indices[row] (f32 -> u32)
cvt.u64.u32 %off, %row;
shl.b64 %off, %off, 2;
add.u64 %off, %idx_addr, %off;
ld.global.f32 %idx_f, [%off];
cvt.rzi.u32.f32 %src_idx, %idx_f;
// src_idx = indices[row] * D + col
mad.lo.u32 %src_idx, %src_idx, %D_reg, %col;
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %val, [%off];
// Write to out[tid]
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_ROWS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry scatter_add_rows_kernel(
.param .u64 grad_output_ptr,
.param .u64 indices_ptr,
.param .u64 grad_weight_ptr,
.param .u32 D,
.param .u32 total
) {
.reg .u32 %my_tid, %bid, %bdim, %D_reg, %total_reg;
.reg .u32 %row, %col, %dst_idx;
.reg .u64 %go, %idx_addr, %gw, %off;
.reg .f32 %idx_f, %grad_val, %dummy;
.reg .pred %p;
ld.param.u64 %go, [grad_output_ptr];
ld.param.u64 %idx_addr, [indices_ptr];
ld.param.u64 %gw, [grad_weight_ptr];
ld.param.u32 %D_reg, [D];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %my_tid, %tid.x;
mad.lo.u32 %my_tid, %bid, %bdim, %my_tid;
setp.ge.u32 %p, %my_tid, %total_reg;
@%p bra DONE;
// row = tid / D, col = tid % D
div.u32 %row, %my_tid, %D_reg;
rem.u32 %col, %my_tid, %D_reg;
// Read grad_output[tid]
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %go, %off;
ld.global.f32 %grad_val, [%off];
// Read indices[row] (f32 -> u32)
cvt.u64.u32 %off, %row;
shl.b64 %off, %off, 2;
add.u64 %off, %idx_addr, %off;
ld.global.f32 %idx_f, [%off];
cvt.rzi.u32.f32 %dst_idx, %idx_f;
// dst_idx = indices[row] * D + col
mad.lo.u32 %dst_idx, %dst_idx, %D_reg, %col;
cvt.u64.u32 %off, %dst_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %gw, %off;
atom.global.add.f32 %dummy, [%off], %grad_val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SLICE_READ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry slice_read_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 total,
.param .u32 D,
.param .u32 len,
.param .u32 max_len
) {
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %D_reg, %len_reg, %max_len_reg;
.reg .u32 %batch_idx, %within, %row, %col, %src_idx;
.reg .u32 %len_d;
.reg .u64 %src, %dst, %src_off, %dst_off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %src, [src_ptr];
ld.param.u64 %dst, [dst_ptr];
ld.param.u32 %total_reg, [total];
ld.param.u32 %D_reg, [D];
ld.param.u32 %len_reg, [len];
ld.param.u32 %max_len_reg, [max_len];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %total_reg;
@%p bra DONE;
// dst index = r_tid
// batch_idx = r_tid / (len * D)
// within = r_tid % (len * D)
// row = within / D
// col = within % D
// src_idx = batch_idx * max_len * D + row * D + col
mul.lo.u32 %len_d, %len_reg, %D_reg;
div.u32 %batch_idx, %r_tid, %len_d;
rem.u32 %within, %r_tid, %len_d;
div.u32 %row, %within, %D_reg;
rem.u32 %col, %within, %D_reg;
mul.lo.u32 %src_idx, %batch_idx, %max_len_reg;
add.u32 %src_idx, %src_idx, %row;
mul.lo.u32 %src_idx, %src_idx, %D_reg;
add.u32 %src_idx, %src_idx, %col;
cvt.u64.u32 %src_off, %src_idx;
shl.b64 %src_off, %src_off, 2;
add.u64 %src_off, %src, %src_off;
ld.global.f32 %val, [%src_off];
cvt.u64.u32 %dst_off, %r_tid;
shl.b64 %dst_off, %dst_off, 2;
add.u64 %dst_off, %dst, %dst_off;
st.global.f32 [%dst_off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
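// sigmoid-GELU: out = x * sigmoid(k * x); exp(-k*x) is computed below as
// 2^(-k*x * log2(e)) using ex2.approx.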
mov.f32 %k, 0f3FDA2720;
mul.f32 %neg_kx, %k, %x;
neg.f32 %neg_kx, %neg_kx;
mul.f32 %neg_kx, %neg_kx, 0f3FB8AA3B;
ex2.approx.f32 %exp_neg, %neg_kx;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %exp_neg;
rcp.approx.f32 %sig, %denom;
mul.f32 %result, %x, %sig;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f64 %x, %neg_kx, %exp_neg, %one, %denom, %sig, %result, %k;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%in];
mov.f64 %one, 0d3FF0000000000000;
// k: sigmoid-GELU coefficient (nominally 1.702; this bit pattern decodes
// to ~1.70432, the exact widening of the f32 kernel's 0f3FDA2720)
mov.f64 %k, 0d3FFB44E400000000;
mul.f64 %neg_kx, %k, %x;
neg.f64 %neg_kx, %neg_kx;
// --- exp(%neg_kx) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg, %exp_neg, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %exp_neg;
div.rn.f64 %sig, %one, %denom;
mul.f64 %result, %x, %sig;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_tanh_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
.reg .f32 %e2y_m1, %e2y_p1, %th, %one, %half, %log2e, %result;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
// inner = sqrt(2/pi) * (x + 0.044715 * x^3)
// sqrt(2/pi) = 0.7978845608 = 0x3F4C422A
// 0.044715 = 0x3D372713
mul.f32 %x3, %x, %x;
mul.f32 %x3, %x3, %x;
mov.f32 %c, 0f3D372713;
mul.f32 %x3, %c, %x3;
add.f32 %inner, %x, %x3;
mov.f32 %sqrt2pi, 0f3F4C422A;
mul.f32 %y, %sqrt2pi, %inner;
// tanh(y) = (exp(2y) - 1) / (exp(2y) + 1)
// exp(2y) = 2^(2y * log2(e))
mov.f32 %log2e, 0f3FB8AA3B;
add.f32 %two_y, %y, %y;
mul.f32 %two_y, %two_y, %log2e;
ex2.approx.f32 %e2y, %two_y;
mov.f32 %one, 0f3F800000;
sub.f32 %e2y_m1, %e2y, %one;
add.f32 %e2y_p1, %e2y, %one;
rcp.approx.f32 %e2y_p1, %e2y_p1;
mul.f32 %th, %e2y_m1, %e2y_p1;
// out = 0.5 * x * (1 + tanh)
add.f32 %th, %one, %th;
mov.f32 %half, 0f3F000000;
mul.f32 %result, %half, %x;
mul.f32 %result, %result, %th;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_tanh_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f64 %x, %x3, %inner, %sqrt2pi, %c, %y, %two_y, %e2y;
.reg .f64 %e2y_m1, %e2y_p1, %th, %one, %half, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%in];
mov.f64 %one, 0d3FF0000000000000;
// inner = sqrt(2/pi) * (x + 0.044715 * x^3)
mul.f64 %x3, %x, %x;
mul.f64 %x3, %x3, %x;
mov.f64 %c, 0d3FA6E4E260000000;
mul.f64 %x3, %c, %x3;
add.f64 %inner, %x, %x3;
mov.f64 %sqrt2pi, 0d3FE9884540000000;
mul.f64 %y, %sqrt2pi, %inner;
// tanh(y) = (exp(2y)-1)/(exp(2y)+1), exp(2y) in full f64
add.f64 %two_y, %y, %y;
// --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2y, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2y, %e2y, %e_nf;
// --- end exp ---
sub.f64 %e2y_m1, %e2y, %one;
add.f64 %e2y_p1, %e2y, %one;
div.rn.f64 %th, %e2y_m1, %e2y_p1;
// out = 0.5 * x * (1 + tanh)
add.f64 %th, %one, %th;
mov.f64 %half, 0d3FE0000000000000;
mul.f64 %result, %half, %x;
mul.f64 %result, %result, %th;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_erf_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %z, %ax, %one, %half, %log2e;
.reg .f32 %t, %pt, %z2, %neg_z2, %exp_neg_z2, %erf_val;
.reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %result;
.reg .pred %pred_ge, %pred_neg;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
mov.f32 %one, 0f3F800000;
mov.f32 %half, 0f3F000000;
mov.f32 %log2e, 0f3FB8AA3B;
// z = x / sqrt(2) = x * 0.70710678
mov.f32 %z, 0f3F3504F3;
mul.f32 %z, %x, %z;
// |z| for erf(|z|)
abs.f32 %ax, %z;
// t = 1 / (1 + 0.3275911 * |z|)
mov.f32 %p, 0f3EA7BA05;
mul.f32 %t, %p, %ax;
add.f32 %t, %one, %t;
rcp.approx.f32 %t, %t;
// Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
// A&S 7.1.26 coefficients (|err(erf)| < 1.5e-7 over x in [0, +inf)):
// a1 = 0.254829592 -> 0f3E827906
// a2 = -0.284496736 -> 0fBE91A98E
// a3 = 1.421413741 -> 0f3FB5F0E3
// a4 = -1.453152027 -> 0fBFBA00E3
// a5 = 1.061405429 -> 0f3F87DC22
// The pre-#799 build held a corrupted set of constants here whose
// polynomial happened to alias to a different (much worse) curve;
// residual was ~1.25e-2 against PyTorch on the conformance fixture
// (well outside F32_TRANSCENDENTAL_GPU = 1e-4).
mov.f32 %a5, 0f3F87DC22;
mov.f32 %a4, 0fBFBA00E3;
mov.f32 %a3, 0f3FB5F0E3;
mov.f32 %a2, 0fBE91A98E;
mov.f32 %a1, 0f3E827906;
mul.f32 %pt, %t, %a5;
add.f32 %pt, %pt, %a4;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a3;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a2;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a1;
mul.f32 %pt, %pt, %t;
// exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
mul.f32 %z2, %ax, %ax;
neg.f32 %neg_z2, %z2;
mul.f32 %neg_z2, %neg_z2, %log2e;
ex2.approx.f32 %exp_neg_z2, %neg_z2;
// erf(|z|) = 1 - poly * exp(-z^2)
mul.f32 %erf_val, %pt, %exp_neg_z2;
sub.f32 %erf_val, %one, %erf_val;
// erf(-z) = -erf(z), so sign-correct
setp.lt.f32 %pred_neg, %z, 0f00000000;
@%pred_neg neg.f32 %erf_val, %erf_val;
// out = x * 0.5 * (1 + erf(x/sqrt(2)))
add.f32 %erf_val, %one, %erf_val;
mul.f32 %result, %half, %x;
mul.f32 %result, %result, %erf_val;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_erf_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f64 %x, %z, %ax, %one, %half, %erf_val, %result;
.reg .f64 %s, %num, %den;
.reg .f64 %z32, %dx, %arg1, %arg2, %exp1, %exp2, %r_factor;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits, %ax_bits, %mask_hi;
.reg .pred %pred_ge, %pred_neg, %pred_small, %pred_mid, %pred_far, %pred_ra;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%in];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %half, 0d3FE0000000000000;
mov.f64 %e_half, 0d3FE0000000000000;
// z = x / sqrt(2) using the full-precision 1/sqrt(2) constant
// (0d3FE6A09E667F3BCD ~ 0.7071067811865475). The previous
// truncated constant 0d3FE6A09E60000000 introduced ~3e-9 error
// in z which alone exceeded the 1e-10 transcendental gate.
mov.f64 %z, 0d3FE6A09E667F3BCD;
mul.f64 %z, %x, %z;
abs.f64 %ax, %z;
// -- Region select on ax --
// pred_small : ax < 0.84375
// pred_mid : 0.84375 <= ax < 1.25
// pred_far : ax >= 6.0
// otherwise : 1.25 <= ax < 6
setp.lt.f64 %pred_small, %ax, 0d3FEB000000000000;
setp.lt.f64 %pred_mid, %ax, 0d3FF4000000000000;
setp.ge.f64 %pred_far, %ax, 0d4018000000000000;
@%pred_small bra ERF_SMALL;
@%pred_mid bra ERF_MID;
@%pred_far bra ERF_FAR;
bra ERF_TAIL;
// ---------------------------------------------------------------------
// Small-x branch: ax < 0.84375
// r = PP0 + t*(PP1 + t*(PP2 + t*(PP3 + t*PP4))), t = z*z
// s = 1 + t*(QQ1 + t*(QQ2 + t*(QQ3 + t*(QQ4 + t*QQ5))))
// erf(z) = z + z * (r/s)
// fdlibm degenerate-near-zero branch (ax < 2^-28) is omitted because
// for our gelu_with(None) path z is bounded away from zero whenever
// it matters; the small rational is exact to ~1 ulp at z = 0 anyway.
// ---------------------------------------------------------------------
ERF_SMALL:
mul.f64 %s, %z, %z;
// PP Horner (descending): PP4, PP3, PP2, PP1, PP0
mov.f64 %num, 0dBEF8EAD6120016AC; // PP4 = -2.37630166566501626084e-05
fma.rn.f64 %num, %num, %s, 0dBF77A291236668E4; // PP3
fma.rn.f64 %num, %num, %s, 0dBF9D2A51DBD7194F; // PP2
fma.rn.f64 %num, %num, %s, 0dBFD4CD7D691CB913; // PP1
fma.rn.f64 %num, %num, %s, 0d3FC06EBA8214DB68; // PP0
// QQ Horner: QQ5, QQ4, QQ3, QQ2, QQ1, then *s + 1
mov.f64 %den, 0dBED09C4342A26120; // QQ5 = -3.96022827877536812320e-06
fma.rn.f64 %den, %den, %s, 0d3F215DC9221C1A10; // QQ4 = 1.32494738004321644526e-04
fma.rn.f64 %den, %den, %s, 0d3F74D022C4D36B0F; // QQ3 = 5.08130628187576562776e-03
fma.rn.f64 %den, %den, %s, 0d3FB0A54C5536CEBA; // QQ2 = 6.50222499887672944485e-02
fma.rn.f64 %den, %den, %s, 0d3FD97779CDDADC09; // QQ1 = 3.97917223959155352819e-01
fma.rn.f64 %den, %den, %s, %one;
div.rn.f64 %erf_val, %num, %den;
fma.rn.f64 %erf_val, %z, %erf_val, %z;
bra ERF_DONE;
// ---------------------------------------------------------------------
// Mid-x branch: 0.84375 <= ax < 1.25
// s = ax - 1
// p = PA0 + s*(PA1 + s*(PA2 + s*(PA3 + s*(PA4 + s*(PA5 + s*PA6)))))
// q = 1 + s*(QA1 + s*(QA2 + s*(QA3 + s*(QA4 + s*(QA5 + s*QA6)))))
// erf(z) = sign(z) * (ERX + p/q), ERX = 8.45062911510467529297e-01
// ---------------------------------------------------------------------
ERF_MID:
sub.f64 %s, %ax, %one;
mov.f64 %num, 0dBF61BF380A96073F; // PA6 = -2.16637559486879084300e-03
fma.rn.f64 %num, %num, %s, 0d3FA22A36599795EB; // PA5
fma.rn.f64 %num, %num, %s, 0dBFBC63983D3E28EC; // PA4
fma.rn.f64 %num, %num, %s, 0d3FD45FCA805120E4; // PA3
fma.rn.f64 %num, %num, %s, 0dBFD7D240FBB8C3F1; // PA2
fma.rn.f64 %num, %num, %s, 0d3FDA8D00AD92B34D; // PA1
fma.rn.f64 %num, %num, %s, 0dBF6359B8BEF77538; // PA0
mov.f64 %den, 0d3F888B545735151D; // QA6 = 1.19844998467991074170e-02
fma.rn.f64 %den, %den, %s, 0d3F8BEDC26B51DD1C; // QA5
fma.rn.f64 %den, %den, %s, 0d3FC02660E763351F; // QA4
fma.rn.f64 %den, %den, %s, 0d3FB2635CD99FE9A7; // QA3
fma.rn.f64 %den, %den, %s, 0d3FE14AF092EB6F33; // QA2
fma.rn.f64 %den, %den, %s, 0d3FBB3E6618EEE323; // QA1
fma.rn.f64 %den, %den, %s, %one;
div.rn.f64 %erf_val, %num, %den;
add.f64 %erf_val, %erf_val, 0d3FEB0AC160000000; // ERX (truncated form)
// Note: SunPro fdlibm spells ERX as 8.45062911510467529297e-01,
// implemented bit-exactly as 0d3FEB0AC160000000 (the trailing
// 32 mantissa bits are zero by design, mirroring the z32
// truncation used in the tail). The low-bit residual is folded
// into the rational p/q.
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
bra ERF_DONE;
// ---------------------------------------------------------------------
// Far tail: ax >= 6.0 -- erf saturates to +-1 within 1 ulp.
// ---------------------------------------------------------------------
ERF_FAR:
mov.f64 %erf_val, %one;
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
bra ERF_DONE;
// ---------------------------------------------------------------------
// Tail: 1.25 <= ax < 6
// t = 1 / (ax * ax)
// pred_ra = (ax < 1/0.35 ~= 2.857142857)
// if pred_ra: use RA/SA else: use RB/SB
// z32 = ax with low 32 mantissa bits zeroed (bit-truncate)
// factor = exp(-z32*z32 - 0.5625)
// * exp(-(ax - z32) * (ax + z32) + R/S) / ax
// erf(z) = sign(z) * (1 - factor)
// ---------------------------------------------------------------------
ERF_TAIL:
mul.f64 %s, %ax, %ax;
div.rn.f64 %s, %one, %s;
setp.lt.f64 %pred_ra, %ax, 0d4006DB6DB6DB6DB7; // 1 / 0.35 ~= 2.857142857
@%pred_ra bra ERF_TAIL_RA;
// RB / SB (1/0.35 <= ax < 28; we cap earlier at 6)
mov.f64 %num, 0dC07E384E9BDC383F; // RB6 = -4.83519191608651397019e+02
fma.rn.f64 %num, %num, %s, 0dC09004616A2E5992; // RB5
fma.rn.f64 %num, %num, %s, 0dC083EC881375F228; // RB4
fma.rn.f64 %num, %num, %s, 0dC064145D43C5ED98; // RB3
fma.rn.f64 %num, %num, %s, 0dC031C209555F995A; // RB2
fma.rn.f64 %num, %num, %s, 0dBFE993BA70C285DE; // RB1
fma.rn.f64 %num, %num, %s, 0dBF84341239E86F4A; // RB0
mov.f64 %den, 0dC03670E242712D62; // SB7 = -2.24409524465858183362e+01
fma.rn.f64 %den, %den, %s, 0d407DA874E79FE763; // SB6
fma.rn.f64 %den, %den, %s, 0d40A3F219CEDF3BE6; // SB5
fma.rn.f64 %den, %den, %s, 0d40A8FFB7688C246A; // SB4
fma.rn.f64 %den, %den, %s, 0d409802EB189D5118; // SB3
fma.rn.f64 %den, %den, %s, 0d40745CAE221B9F0A; // SB2
fma.rn.f64 %den, %den, %s, 0d403E568B261D5190; // SB1
fma.rn.f64 %den, %den, %s, %one;
bra ERF_TAIL_RS_DONE;
ERF_TAIL_RA:
// RA / SA (1.25 <= ax < 1/0.35)
mov.f64 %num, 0dC023A0EFC69AC25C; // RA7 = -9.81432934416914548592
fma.rn.f64 %num, %num, %s, 0dC054526557E4D2F2; // RA6 = -8.12874355063065934246e+01
fma.rn.f64 %num, %num, %s, 0dC067135CEBCCABB2; // RA5
fma.rn.f64 %num, %num, %s, 0dC0644CB184282266; // RA4
fma.rn.f64 %num, %num, %s, 0dC04F300AE4CBA38D; // RA3
fma.rn.f64 %num, %num, %s, 0dC0251E0441B0E726; // RA2 = -1.05586262253232909814e+01
fma.rn.f64 %num, %num, %s, 0dBFE63416E4BA7360; // RA1
fma.rn.f64 %num, %num, %s, 0dBF843412600D6435; // RA0
mov.f64 %den, 0dBFAEEFF2EE749A62; // SA8 = -6.04244152148580987438e-02
fma.rn.f64 %den, %den, %s, 0d401A47EF8E484A93; // SA7
fma.rn.f64 %den, %den, %s, 0d405B28A3EE48AE2C; // SA6 = 1.08635005541779435134e+02
fma.rn.f64 %den, %den, %s, 0d407AD02157700314; // SA5
fma.rn.f64 %den, %den, %s, 0d40842B1921EC2868; // SA4
fma.rn.f64 %den, %den, %s, 0d407B290DD58A1A71; // SA3
fma.rn.f64 %den, %den, %s, 0d4061350C526AE721; // SA2
fma.rn.f64 %den, %den, %s, 0d4033A6B9BD707687; // SA1 = 1.96512716674392571292e+01
fma.rn.f64 %den, %den, %s, %one;
// fall through
ERF_TAIL_RS_DONE:
div.rn.f64 %s, %num, %den; // R/S
// z32 = bits(ax) & 0xFFFFFFFF00000000 (truncate low 32 mantissa bits)
mov.b64 %ax_bits, %ax;
mov.s64 %mask_hi, -4294967296; // 0xFFFFFFFF00000000 reinterpreted as i64
and.b64 %ax_bits, %ax_bits, %mask_hi;
mov.b64 %z32, %ax_bits;
// arg1 = -z32*z32 - 0.5625
mul.f64 %arg1, %z32, %z32;
neg.f64 %arg1, %arg1;
sub.f64 %arg1, %arg1, 0d3FE2000000000000; // 0.5625
// arg2 = -(ax - z32) * (ax + z32) + R/S
sub.f64 %dx, %ax, %z32;
add.f64 %arg2, %ax, %z32;
mul.f64 %dx, %dx, %arg2;
neg.f64 %dx, %dx;
add.f64 %arg2, %dx, %s;
// -- exp(arg1) via Cody-Waite + degree-11 Horner --
fma.rn.f64 %e_nf, %arg1, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %arg1;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp1, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp1, %exp1, %e_nf;
// -- exp(arg2) via Cody-Waite + degree-11 Horner --
fma.rn.f64 %e_nf, %arg2, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %arg2;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp2, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp2, %exp2, %e_nf;
mul.f64 %r_factor, %exp1, %exp2;
div.rn.f64 %r_factor, %r_factor, %ax;
sub.f64 %erf_val, %one, %r_factor;
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
// fall through
ERF_DONE:
add.f64 %erf_val, %one, %erf_val;
mul.f64 %result, %half, %x;
mul.f64 %result, %result, %erf_val;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_tanh_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
.reg .f32 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half, %log2e;
.reg .f32 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %one, 0f3F800000;
mov.f32 %half, 0f3F000000;
mov.f32 %log2e, 0f3FB8AA3B;
mov.f32 %sqrt2pi, 0f3F4C422A;
mov.f32 %c, 0f3D372713;
// 3 * 0.044715 = 0.134145 = 0x3E096B8C
mov.f32 %c3, 0f3E096B8C;
// y = sqrt(2/π) * (x + 0.044715 * x³)
mul.f32 %x2, %x, %x;
mul.f32 %x3, %x2, %x;
mul.f32 %x3, %c, %x3;
add.f32 %inner, %x, %x3;
mul.f32 %y, %sqrt2pi, %inner;
// tanh(y) via exp
add.f32 %two_y, %y, %y;
mul.f32 %two_y, %two_y, %log2e;
ex2.approx.f32 %e2y, %two_y;
sub.f32 %e2y_m1, %e2y, %one;
add.f32 %e2y_p1, %e2y, %one;
rcp.approx.f32 %e2y_p1, %e2y_p1;
mul.f32 %th, %e2y_m1, %e2y_p1;
// d/dx = 0.5*(1+tanh) + 0.5*x*(1-tanh²)*sqrt(2/π)*(1+3*0.044715*x²)
// term1 = 0.5 * (1 + th)
add.f32 %term1, %one, %th;
mul.f32 %term1, %half, %term1;
// (1 - th²)
mul.f32 %th2, %th, %th;
sub.f32 %one_m_th2, %one, %th2;
// d_inner = sqrt(2/π) * (1 + 3*0.044715*x²)
mul.f32 %d_inner, %c3, %x2;
add.f32 %d_inner, %one, %d_inner;
mul.f32 %d_inner, %sqrt2pi, %d_inner;
// term2 = 0.5 * x * (1-th²) * d_inner
mul.f32 %term2, %half, %x;
mul.f32 %term2, %term2, %one_m_th2;
mul.f32 %term2, %term2, %d_inner;
add.f32 %d_gelu, %term1, %term2;
mul.f32 %result, %vg, %d_gelu;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_tanh_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %x2, %x3, %inner, %sqrt2pi, %c, %c3, %y;
.reg .f64 %two_y, %e2y, %e2y_m1, %e2y_p1, %th, %one, %half;
.reg .f64 %th2, %one_m_th2, %d_inner, %term1, %term2, %d_gelu, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %half, 0d3FE0000000000000;
mov.f64 %sqrt2pi, 0d3FE9884540000000;
mov.f64 %c, 0d3FA6E4E260000000;
// 3 * 0.044715 = 0.134145
mov.f64 %c3, 0d3FC12D7180000000;
mul.f64 %x2, %x, %x;
mul.f64 %x3, %x2, %x;
mul.f64 %x3, %c, %x3;
add.f64 %inner, %x, %x3;
mul.f64 %y, %sqrt2pi, %inner;
// tanh(y) = (exp(2y)-1)/(exp(2y)+1) in full f64
add.f64 %two_y, %y, %y;
// --- exp(%two_y) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %two_y, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_y;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2y, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2y, %e2y, %e_nf;
// --- end exp ---
sub.f64 %e2y_m1, %e2y, %one;
add.f64 %e2y_p1, %e2y, %one;
div.rn.f64 %th, %e2y_m1, %e2y_p1;
add.f64 %term1, %one, %th;
mul.f64 %term1, %half, %term1;
mul.f64 %th2, %th, %th;
sub.f64 %one_m_th2, %one, %th2;
mul.f64 %d_inner, %c3, %x2;
add.f64 %d_inner, %one, %d_inner;
mul.f64 %d_inner, %sqrt2pi, %d_inner;
mul.f64 %term2, %half, %x;
mul.f64 %term2, %term2, %one_m_th2;
mul.f64 %term2, %term2, %d_inner;
add.f64 %d_gelu, %term1, %term2;
mul.f64 %result, %vg, %d_gelu;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SILU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %x, %neg, %e, %denom, %sig, %vr, %one, %lg2e;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%a];
// sigmoid(x) = 1 / (1 + exp(-x))
// exp(-x) = 2^(-x * log2(e))
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
neg.f32 %neg, %x;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %e, %neg;
add.f32 %denom, %one, %e;
rcp.approx.f32 %sig, %denom;
// silu(x) = x * sigmoid(x)
mul.f32 %vr, %x, %sig;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SILU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %neg_x, %e, %denom, %sig, %vr, %one;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %one, 0d3FF0000000000000;
neg.f64 %neg_x, %x;
// --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e, %e, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %e;
div.rn.f64 %sig, %one, %denom;
mul.f64 %vr, %x, %sig;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %neg, %e, %denom, %sig, %one, %lg2e;
.reg .f32 %one_m_sig, %x_sig_omsig, %deriv, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
// sig = sigmoid(x) = 1 / (1 + exp(-x))
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
neg.f32 %neg, %x;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %e, %neg;
add.f32 %denom, %one, %e;
rcp.approx.f32 %sig, %denom;
// deriv = sig + x * sig * (1 - sig)
sub.f32 %one_m_sig, %one, %sig;
mul.f32 %x_sig_omsig, %x, %sig;
mul.f32 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
add.f32 %deriv, %sig, %x_sig_omsig;
mul.f32 %result, %vg, %deriv;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SILU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry silu_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %neg_x, %e, %denom, %sig, %one;
.reg .f64 %one_m_sig, %x_sig_omsig, %deriv, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
neg.f64 %neg_x, %x;
// --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e, %e, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %e;
div.rn.f64 %sig, %one, %denom;
sub.f64 %one_m_sig, %one, %sig;
mul.f64 %x_sig_omsig, %x, %sig;
mul.f64 %x_sig_omsig, %x_sig_omsig, %one_m_sig;
add.f64 %deriv, %sig, %x_sig_omsig;
mul.f64 %result, %vg, %deriv;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ELU_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f32 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %x, %alpha_r, %lg2e, %one, %ex, %em1, %neg_branch, %vr;
.reg .pred %p, %pos;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f32 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%a];
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
// exp(x) = 2^(x * log2(e))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
sub.f32 %em1, %ex, %one;
mul.f32 %neg_branch, %alpha_r, %em1;
// x > 0 ? x : alpha*(exp(x)-1)
mov.f32 %vr, 0f00000000;
setp.gt.f32 %pos, %x, %vr;
selp.f32 %vr, %x, %neg_branch, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ELU_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f64 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %alpha_r, %one, %ex, %em1, %neg_branch, %vr;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p, %pos;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f64 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %one, 0d3FF0000000000000;
// --- exp(%x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
// --- end exp ---
sub.f64 %em1, %ex, %one;
mul.f64 %neg_branch, %alpha_r, %em1;
mov.f64 %vr, 0d0000000000000000;
setp.gt.f64 %pos, %x, %vr;
selp.f64 %vr, %x, %neg_branch, %pos;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f32 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %alpha_r, %lg2e, %ex, %neg_branch, %vr, %zero;
.reg .pred %p, %pos;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f32 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %lg2e, 0f3FB8AA3B;
mov.f32 %zero, 0f00000000;
// exp(x) = 2^(x * log2(e))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
// negative branch: grad * alpha * exp(x)
mul.f32 %neg_branch, %vg, %alpha_r;
mul.f32 %neg_branch, %neg_branch, %ex;
// x > 0 ? grad : grad * alpha * exp(x)
setp.gt.f32 %pos, %x, %zero;
selp.f32 %vr, %vg, %neg_branch, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry elu_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f64 alpha
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %alpha_r, %ex, %neg_branch, %vr, %zero, %one;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p, %pos;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f64 %alpha_r, [alpha];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %zero, 0d0000000000000000;
mov.f64 %one, 0d3FF0000000000000;
// --- exp(%x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
// --- end exp ---
mul.f64 %neg_branch, %vg, %alpha_r;
mul.f64 %neg_branch, %neg_branch, %ex;
setp.gt.f64 %pos, %x, %zero;
selp.f64 %vr, %vg, %neg_branch, %pos;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MISH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
.reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
.reg .f32 %threshold;
.reg .pred %p, %large;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%a];
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
// threshold = 20.0 = 0x41A00000
mov.f32 %threshold, 0f41A00000;
// softplus(x) = ln(1 + exp(x))
// For large x (> 20), softplus ~ x to avoid overflow
setp.gt.f32 %large, %x, %threshold;
@%large bra LARGE_X;
// exp(x) = 2^(x * log2(e))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
add.f32 %ep1, %ex, %one;
// ln(1+exp(x)) = log2(1+exp(x)) / log2(e)
lg2.approx.f32 %lg_ep1, %ep1;
// 1/log2(e) = ln(2) = 0.6931472 = 0x3F317218
mul.f32 %sp, %lg_ep1, 0f3F317218;
// tanh(sp) = (exp(2*sp) - 1) / (exp(2*sp) + 1)
add.f32 %two_sp, %sp, %sp;
mul.f32 %two_sp, %two_sp, %lg2e;
ex2.approx.f32 %e2sp, %two_sp;
sub.f32 %e2sp_m1, %e2sp, %one;
add.f32 %e2sp_p1, %e2sp, %one;
rcp.approx.f32 %e2sp_p1, %e2sp_p1;
mul.f32 %th, %e2sp_m1, %e2sp_p1;
mul.f32 %vr, %x, %th;
st.global.f32 [%out], %vr;
bra DONE;
LARGE_X:
// For x > 20, tanh(softplus(x)) == 1 to f32 precision, so mish(x) = x.
// Computing tanh via exp(2x) here overflowed to inf/inf = NaN for x > ~44.
st.global.f32 [%out], %x;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MISH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %one, %two, %ex, %ep1, %sp;
.reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %th, %vr;
.reg .f64 %threshold;
// exp subroutine regs
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
// log subroutine regs
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf;
.reg .f64 %l_ln2_hi, %l_ln2_lo, %l_sqrt2, %l_half_const;
.reg .pred %p, %large, %p_shift;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
mov.f64 %threshold, 0d4034000000000000;
setp.gt.f64 %large, %x, %threshold;
@%large bra LARGE_X;
// === softplus: sp = ln(1 + exp(x)) ===
// exp(x)
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
// ep1 = 1 + exp(x)
add.f64 %ep1, %ex, %one;
// ln(ep1) via half-step argument reduction (#783 cluster fix)
mov.b64 %l_xbits, %ep1;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
// Half-step: m > sqrt(2) -> m/=2, n+=1
mov.f64 %l_sqrt2, 0d3FF6A09E667F3BCD;
mov.f64 %l_half_const, 0d3FE0000000000000;
setp.gt.f64 %p_shift, %l_m, %l_sqrt2;
@%p_shift mul.f64 %l_m, %l_m, %l_half_const;
@%p_shift add.f64 %l_nf, %l_nf, %one;
sub.f64 %l_f, %l_m, %one;
add.f64 %l_s, %l_m, %one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
// Degree-7 odd-power Horner (was degree-5; #783 cluster fix).
mov.f64 %l_p, 0d3FB1111111111111; // 1/15
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14; // 1/13
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746; // 1/11
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C; // 1/9
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492; // 1/7
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A; // 1/5
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555; // 1/3
fma.rn.f64 %l_p, %l_p, %l_f2, %one;
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
mov.f64 %l_ln2_hi, 0d3FE62E42FEFA3800;
mov.f64 %l_ln2_lo, 0d3D2EF35793C76730;
fma.rn.f64 %sp, %l_nf, %l_ln2_hi, %l_p;
fma.rn.f64 %sp, %l_nf, %l_ln2_lo, %sp;
// === tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1) ===
add.f64 %two_sp, %sp, %sp;
fma.rn.f64 %e_nf, %two_sp, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2sp, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2sp, %e2sp, %e_nf;
sub.f64 %e2sp_m1, %e2sp, %one;
add.f64 %e2sp_p1, %e2sp, %one;
div.rn.f64 %th, %e2sp_m1, %e2sp_p1;
mul.f64 %vr, %x, %th;
st.global.f64 [%out], %vr;
bra DONE;
LARGE_X:
// For x > 20, 1 - tanh(softplus(x)) ~ 2*exp(-2x) < 1e-17, below f64
// precision, so mish(x) = x. The exp(2x)-based tanh misbehaved past x ~ 354,
// where the inline exponent construction wraps.
st.global.f64 [%out], %x;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %lg2e, %one, %ex, %ep1, %sp, %lg_ep1;
.reg .f32 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
.reg .f32 %neg, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
.reg .f32 %threshold;
.reg .pred %p, %large;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %one, 0f3F800000;
mov.f32 %lg2e, 0f3FB8AA3B;
// threshold = 20.0
mov.f32 %threshold, 0f41A00000;
setp.gt.f32 %large, %x, %threshold;
@%large bra LARGE_X;
// --- Normal path ---
// softplus: sp = ln(1 + exp(x))
mul.f32 %ex, %x, %lg2e;
ex2.approx.f32 %ex, %ex;
add.f32 %ep1, %ex, %one;
lg2.approx.f32 %lg_ep1, %ep1;
// ln(2) = 0x3F317218
mul.f32 %sp, %lg_ep1, 0f3F317218;
// t = tanh(sp) = (exp(2*sp)-1)/(exp(2*sp)+1)
add.f32 %two_sp, %sp, %sp;
mul.f32 %two_sp, %two_sp, %lg2e;
ex2.approx.f32 %e2sp, %two_sp;
sub.f32 %e2sp_m1, %e2sp, %one;
add.f32 %e2sp_p1, %e2sp, %one;
rcp.approx.f32 %e2sp_p1, %e2sp_p1;
mul.f32 %t, %e2sp_m1, %e2sp_p1;
// sig = sigmoid(x) = 1/(1+exp(-x))
neg.f32 %neg, %x;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %en, %neg;
add.f32 %denom, %one, %en;
rcp.approx.f32 %sig, %denom;
// deriv = t + x * sig * (1 - t*t)
mul.f32 %t2, %t, %t;
sub.f32 %one_m_t2, %one, %t2;
mul.f32 %x_sig_omt2, %x, %sig;
mul.f32 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
add.f32 %deriv, %t, %x_sig_omt2;
mul.f32 %result, %vg, %deriv;
st.global.f32 [%out], %result;
bra DONE;
LARGE_X:
// For x > 20: tanh(sp) ~ 1, sigmoid(x) ~ 1, and the correction term
// x * (1 - tanh(sp)^2) ~ 4x*exp(-2x) is far below f32 precision, so the
// derivative is 1 and grad passes through unchanged.
st.global.f32 [%out], %vg;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MISH_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mish_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %one, %ex, %ep1, %sp;
.reg .f64 %two_sp, %e2sp, %e2sp_m1, %e2sp_p1, %t, %t2, %one_m_t2;
.reg .f64 %neg_x, %en, %denom, %sig, %x_sig_omt2, %deriv, %result;
.reg .f64 %threshold;
// exp subroutine regs
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
// log subroutine regs
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf;
.reg .f64 %l_ln2_hi, %l_ln2_lo, %l_sqrt2, %l_half_const;
.reg .pred %p, %large, %p_shift;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %threshold, 0d4034000000000000;
setp.gt.f64 %large, %x, %threshold;
@%large bra LARGE_X;
// === softplus: sp = ln(1 + exp(x)) ===
// exp(x)
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %ex, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ex, %ex, %e_nf;
add.f64 %ep1, %ex, %one;
// ln(ep1) via half-step argument reduction (#783 cluster fix)
mov.b64 %l_xbits, %ep1;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
mov.f64 %l_sqrt2, 0d3FF6A09E667F3BCD;
mov.f64 %l_half_const, 0d3FE0000000000000;
setp.gt.f64 %p_shift, %l_m, %l_sqrt2;
@%p_shift mul.f64 %l_m, %l_m, %l_half_const;
@%p_shift add.f64 %l_nf, %l_nf, %one;
sub.f64 %l_f, %l_m, %one;
add.f64 %l_s, %l_m, %one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
// Degree-7 odd-power Horner (was degree-5; #783 cluster fix).
mov.f64 %l_p, 0d3FB1111111111111; // 1/15
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14; // 1/13
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746; // 1/11
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C; // 1/9
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492; // 1/7
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A; // 1/5
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555; // 1/3
fma.rn.f64 %l_p, %l_p, %l_f2, %one;
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
mov.f64 %l_ln2_hi, 0d3FE62E42FEFA3800;
mov.f64 %l_ln2_lo, 0d3D2EF35793C76730;
fma.rn.f64 %sp, %l_nf, %l_ln2_hi, %l_p;
fma.rn.f64 %sp, %l_nf, %l_ln2_lo, %sp;
// === tanh(sp) ===
add.f64 %two_sp, %sp, %sp;
fma.rn.f64 %e_nf, %two_sp, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %two_sp;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e2sp, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e2sp, %e2sp, %e_nf;
sub.f64 %e2sp_m1, %e2sp, %one;
add.f64 %e2sp_p1, %e2sp, %one;
div.rn.f64 %t, %e2sp_m1, %e2sp_p1;
// === sigmoid(x) = 1/(1+exp(-x)) ===
neg.f64 %neg_x, %x;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %en, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %en, %en, %e_nf;
add.f64 %denom, %one, %en;
div.rn.f64 %sig, %one, %denom;
// deriv = t + x * sig * (1 - t*t)
mul.f64 %t2, %t, %t;
sub.f64 %one_m_t2, %one, %t2;
mul.f64 %x_sig_omt2, %x, %sig;
mul.f64 %x_sig_omt2, %x_sig_omt2, %one_m_t2;
add.f64 %deriv, %t, %x_sig_omt2;
mul.f64 %result, %vg, %deriv;
st.global.f64 [%out], %result;
bra DONE;
LARGE_X:
// For x > 20 the derivative t + x * sig * (1 - t*t) equals 1 to within a few
// f64 ulps (the correction term is ~4x*exp(-2x) < 4e-16), so grad passes
// through rather than recomputing tanh via exp(2x), which wrapped for huge x.
st.global.f64 [%out], %vg;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CLAMP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry clamp_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .f32 min_val,
.param .f32 max_val
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %in, %out, %off;
.reg .f32 %x, %mn, %mx, %result;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
ld.param.f32 %mn, [min_val];
ld.param.f32 %mx, [max_val];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %in, %in, %off;
add.u64 %out, %out, %off;
ld.global.f32 %x, [%in];
max.f32 %result, %x, %mn;
min.f32 %result, %result, %mx;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const FILL_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry fill_f32_kernel(
.param .u64 out_ptr,
.param .f32 scalar,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %out, %off;
.reg .f32 %v;
.reg .pred %p;
ld.param.u64 %out, [out_ptr];
ld.param.f32 %v, [scalar];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %v;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ABS_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry abs_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %vi, %zero, %neg_vg, %tmp, %vr;
.reg .pred %p, %pos, %neg;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vi, [%input];
mov.f32 %zero, 0f00000000;
neg.f32 %neg_vg, %vg;
// tmp = (vi < 0) ? -vg : 0
setp.lt.f32 %neg, %vi, %zero;
selp.f32 %tmp, %neg_vg, %zero, %neg;
// vr = (vi > 0) ? vg : tmp
setp.gt.f32 %pos, %vi, %zero;
selp.f32 %vr, %vg, %tmp, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const RELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry relu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %vi, %zero, %vr;
.reg .pred %p, %pos;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vi, [%input];
mov.f32 %zero, 0f00000000;
setp.gt.f32 %pos, %vi, %zero;
selp.f32 %vr, %vg, %zero, %pos;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %k, %kx, %neg_kx, %log2e, %exp_neg, %one, %denom, %sig;
.reg .f32 %one_minus_sig, %kx_sig_oms, %dsig, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
// sig = sigmoid(1.702 * x)
// k = 1.702 is 0f3FD9DB23; the previous literal 0f3FDA2720 encoded 1.70432.
mov.f32 %k, 0f3FD9DB23;
mul.f32 %kx, %k, %x;
neg.f32 %neg_kx, %kx;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %neg_kx, %neg_kx, %log2e;
ex2.approx.f32 %exp_neg, %neg_kx;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %exp_neg;
rcp.approx.f32 %sig, %denom;
// d/dx gelu(x) = sig + k * x * sig * (1 - sig)
sub.f32 %one_minus_sig, %one, %sig;
mul.f32 %kx_sig_oms, %kx, %sig;
mul.f32 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
add.f32 %dsig, %sig, %kx_sig_oms;
// out = grad * d_gelu
mul.f32 %result, %vg, %dsig;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %k, %kx, %neg_kx, %exp_neg, %one, %denom, %sig;
.reg .f64 %one_minus_sig, %kx_sig_oms, %dsig, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
// k = 1.702 (0d3FFB3B645A1CAC08), matching the f32 kernel; the previous
// literal 0d3FFB44E400000000 encoded 1.70432.
mov.f64 %k, 0d3FFB3B645A1CAC08;
mul.f64 %kx, %k, %x;
neg.f64 %neg_kx, %kx;
// --- exp(%neg_kx) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_kx, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_kx;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg, %exp_neg, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %exp_neg;
div.rn.f64 %sig, %one, %denom;
sub.f64 %one_minus_sig, %one, %sig;
mul.f64 %kx_sig_oms, %kx, %sig;
mul.f64 %kx_sig_oms, %kx_sig_oms, %one_minus_sig;
add.f64 %dsig, %sig, %kx_sig_oms;
mul.f64 %result, %vg, %dsig;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_erf_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f32 %vg, %x, %ax, %z, %z2, %neg_z2, %exp_neg_z2;
.reg .f32 %t, %pt, %one, %half, %erf_val, %cdf, %pdf;
.reg .f32 %neg_x2h, %exp_neg_x2h, %inv_sqrt_2pi, %x_pdf;
.reg .f32 %d_gelu, %result;
.reg .f32 %p, %a1, %a2, %a3, %a4, %a5, %log2e;
.reg .pred %pred_ge, %pred_neg;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %x, [%input];
mov.f32 %one, 0f3F800000;
mov.f32 %half, 0f3F000000;
// z = x / sqrt(2) = x * 0.70710678
mov.f32 %z, 0f3F3504F3;
mul.f32 %z, %x, %z;
// |z| for erf(|z|)
abs.f32 %ax, %z;
// t = 1 / (1 + 0.3275911 * |z|)
mov.f32 %p, 0f3EA7BA05;
mul.f32 %t, %p, %ax;
add.f32 %t, %one, %t;
rcp.approx.f32 %t, %t;
// Horner: poly = t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))))
// A&S 7.1.26 (restored under #799 from corrupted constants).
mov.f32 %a5, 0f3F87DC22;
mov.f32 %a4, 0fBFBA00E3;
mov.f32 %a3, 0f3FB5F0E3;
mov.f32 %a2, 0fBE91A98E;
mov.f32 %a1, 0f3E827906;
mul.f32 %pt, %t, %a5;
add.f32 %pt, %pt, %a4;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a3;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a2;
mul.f32 %pt, %pt, %t;
add.f32 %pt, %pt, %a1;
mul.f32 %pt, %pt, %t;
// exp(-z^2) via ex2.approx: exp(y) = 2^(y * log2(e))
mul.f32 %z2, %ax, %ax;
neg.f32 %neg_z2, %z2;
mov.f32 %log2e, 0f3FB8AA3B;
mul.f32 %neg_z2, %neg_z2, %log2e;
ex2.approx.f32 %exp_neg_z2, %neg_z2;
// erf(|z|) = 1 - poly * exp(-z^2)
mul.f32 %erf_val, %pt, %exp_neg_z2;
sub.f32 %erf_val, %one, %erf_val;
// erf(-z) = -erf(z), so sign-correct
setp.lt.f32 %pred_neg, %z, 0f00000000;
@%pred_neg neg.f32 %erf_val, %erf_val;
// Phi(x) = 0.5 * (1 + erf(x/sqrt(2)))
add.f32 %cdf, %one, %erf_val;
mul.f32 %cdf, %half, %cdf;
// phi(x) = exp(-x^2/2) / sqrt(2pi)
// exp(-x^2/2):
mul.f32 %neg_x2h, %x, %x;
mul.f32 %neg_x2h, %neg_x2h, %half;
neg.f32 %neg_x2h, %neg_x2h;
mul.f32 %neg_x2h, %neg_x2h, %log2e;
ex2.approx.f32 %exp_neg_x2h, %neg_x2h;
// 1/sqrt(2pi) = 0.39894228
mov.f32 %inv_sqrt_2pi, 0f3ECC4220;
mul.f32 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;
// d/dx gelu(x) = Phi(x) + x * phi(x)
mul.f32 %x_pdf, %x, %pdf;
add.f32 %d_gelu, %cdf, %x_pdf;
// out = grad * d_gelu
mul.f32 %result, %vg, %d_gelu;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const GELU_BACKWARD_ERF_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry gelu_backward_erf_f64_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %input, %out, %off;
.reg .f64 %vg, %x, %z, %ax, %one, %half, %erf_val;
.reg .f64 %s, %num, %den;
.reg .f64 %z32, %dx, %arg1, %arg2, %exp1, %exp2, %r_factor;
.reg .f64 %cdf, %pdf, %neg_x2h, %exp_neg_x2h;
.reg .f64 %inv_sqrt_2pi, %x_pdf, %d_gelu, %result;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits, %ax_bits, %mask_hi;
.reg .pred %pred_ge, %pred_neg, %pred_small, %pred_mid, %pred_far, %pred_ra;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %input, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %pred_ge, %r_tid, %n_reg;
@%pred_ge bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %grad, %grad, %off;
add.u64 %input, %input, %off;
add.u64 %out, %out, %off;
ld.global.f64 %vg, [%grad];
ld.global.f64 %x, [%input];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %half, 0d3FE0000000000000;
mov.f64 %e_half, 0d3FE0000000000000;
// z = x / sqrt(2) -- full-precision constant 0d3FE6A09E667F3BCD.
mov.f64 %z, 0d3FE6A09E667F3BCD;
mul.f64 %z, %x, %z;
abs.f64 %ax, %z;
setp.lt.f64 %pred_small, %ax, 0d3FEB000000000000; // 0.84375
setp.lt.f64 %pred_mid, %ax, 0d3FF4000000000000; // 1.25
setp.ge.f64 %pred_far, %ax, 0d4018000000000000; // 6.0
@%pred_small bra B_ERF_SMALL;
@%pred_mid bra B_ERF_MID;
@%pred_far bra B_ERF_FAR;
bra B_ERF_TAIL;
B_ERF_SMALL:
mul.f64 %s, %z, %z;
mov.f64 %num, 0dBEF8EAD6120016AC; // PP4
fma.rn.f64 %num, %num, %s, 0dBF77A291236668E4; // PP3
fma.rn.f64 %num, %num, %s, 0dBF9D2A51DBD7194F; // PP2
fma.rn.f64 %num, %num, %s, 0dBFD4CD7D691CB913; // PP1
fma.rn.f64 %num, %num, %s, 0d3FC06EBA8214DB68; // PP0
mov.f64 %den, 0dBED09C4342A26120; // QQ5
fma.rn.f64 %den, %den, %s, 0d3F215DC9221C1A10; // QQ4
fma.rn.f64 %den, %den, %s, 0d3F74D022C4D36B0F; // QQ3
fma.rn.f64 %den, %den, %s, 0d3FB0A54C5536CEBA; // QQ2
fma.rn.f64 %den, %den, %s, 0d3FD97779CDDADC09; // QQ1
fma.rn.f64 %den, %den, %s, %one;
div.rn.f64 %erf_val, %num, %den;
fma.rn.f64 %erf_val, %z, %erf_val, %z;
bra B_ERF_DONE;
B_ERF_MID:
sub.f64 %s, %ax, %one;
mov.f64 %num, 0dBF61BF380A96073F; // PA6
fma.rn.f64 %num, %num, %s, 0d3FA22A36599795EB; // PA5
fma.rn.f64 %num, %num, %s, 0dBFBC63983D3E28EC; // PA4
fma.rn.f64 %num, %num, %s, 0d3FD45FCA805120E4; // PA3
fma.rn.f64 %num, %num, %s, 0dBFD7D240FBB8C3F1; // PA2
fma.rn.f64 %num, %num, %s, 0d3FDA8D00AD92B34D; // PA1
fma.rn.f64 %num, %num, %s, 0dBF6359B8BEF77538; // PA0
mov.f64 %den, 0d3F888B545735151D; // QA6
fma.rn.f64 %den, %den, %s, 0d3F8BEDC26B51DD1C; // QA5
fma.rn.f64 %den, %den, %s, 0d3FC02660E763351F; // QA4
fma.rn.f64 %den, %den, %s, 0d3FB2635CD99FE9A7; // QA3
fma.rn.f64 %den, %den, %s, 0d3FE14AF092EB6F33; // QA2
fma.rn.f64 %den, %den, %s, 0d3FBB3E6618EEE323; // QA1
fma.rn.f64 %den, %den, %s, %one;
div.rn.f64 %erf_val, %num, %den;
add.f64 %erf_val, %erf_val, 0d3FEB0AC160000000; // ERX
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
bra B_ERF_DONE;
B_ERF_FAR:
mov.f64 %erf_val, %one;
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
bra B_ERF_DONE;
B_ERF_TAIL:
mul.f64 %s, %ax, %ax;
div.rn.f64 %s, %one, %s;
setp.lt.f64 %pred_ra, %ax, 0d4006DB6DB6DB6DB7; // 1/0.35
@%pred_ra bra B_ERF_TAIL_RA;
// RB / SB
mov.f64 %num, 0dC07E384E9BDC383F; // RB6
fma.rn.f64 %num, %num, %s, 0dC09004616A2E5992; // RB5
fma.rn.f64 %num, %num, %s, 0dC083EC881375F228; // RB4
fma.rn.f64 %num, %num, %s, 0dC064145D43C5ED98; // RB3
fma.rn.f64 %num, %num, %s, 0dC031C209555F995A; // RB2
fma.rn.f64 %num, %num, %s, 0dBFE993BA70C285DE; // RB1
fma.rn.f64 %num, %num, %s, 0dBF84341239E86F4A; // RB0
mov.f64 %den, 0dC03670E242712D62; // SB7
fma.rn.f64 %den, %den, %s, 0d407DA874E79FE763; // SB6
fma.rn.f64 %den, %den, %s, 0d40A3F219CEDF3BE6; // SB5
fma.rn.f64 %den, %den, %s, 0d40A8FFB7688C246A; // SB4
fma.rn.f64 %den, %den, %s, 0d409802EB189D5118; // SB3
fma.rn.f64 %den, %den, %s, 0d40745CAE221B9F0A; // SB2
fma.rn.f64 %den, %den, %s, 0d403E568B261D5190; // SB1
fma.rn.f64 %den, %den, %s, %one;
bra B_ERF_TAIL_RS_DONE;
B_ERF_TAIL_RA:
// RA / SA
mov.f64 %num, 0dC023A0EFC69AC25C; // RA7
fma.rn.f64 %num, %num, %s, 0dC054526557E4D2F2; // RA6
fma.rn.f64 %num, %num, %s, 0dC067135CEBCCABB2; // RA5
fma.rn.f64 %num, %num, %s, 0dC0644CB184282266; // RA4
fma.rn.f64 %num, %num, %s, 0dC04F300AE4CBA38D; // RA3
fma.rn.f64 %num, %num, %s, 0dC0251E0441B0E726; // RA2
fma.rn.f64 %num, %num, %s, 0dBFE63416E4BA7360; // RA1
fma.rn.f64 %num, %num, %s, 0dBF843412600D6435; // RA0
mov.f64 %den, 0dBFAEEFF2EE749A62; // SA8
fma.rn.f64 %den, %den, %s, 0d401A47EF8E484A93; // SA7
fma.rn.f64 %den, %den, %s, 0d405B28A3EE48AE2C; // SA6
fma.rn.f64 %den, %den, %s, 0d407AD02157700314; // SA5
fma.rn.f64 %den, %den, %s, 0d40842B1921EC2868; // SA4
fma.rn.f64 %den, %den, %s, 0d407B290DD58A1A71; // SA3
fma.rn.f64 %den, %den, %s, 0d4061350C526AE721; // SA2
fma.rn.f64 %den, %den, %s, 0d4033A6B9BD707687; // SA1
fma.rn.f64 %den, %den, %s, %one;
// fall through
B_ERF_TAIL_RS_DONE:
div.rn.f64 %s, %num, %den;
// z32 = bit-truncated ax (low 32 bits zeroed)
mov.b64 %ax_bits, %ax;
mov.s64 %mask_hi, -4294967296;
and.b64 %ax_bits, %ax_bits, %mask_hi;
mov.b64 %z32, %ax_bits;
mul.f64 %arg1, %z32, %z32;
neg.f64 %arg1, %arg1;
sub.f64 %arg1, %arg1, 0d3FE2000000000000; // 0.5625
sub.f64 %dx, %ax, %z32;
add.f64 %arg2, %ax, %z32;
mul.f64 %dx, %dx, %arg2;
neg.f64 %dx, %dx;
add.f64 %arg2, %dx, %s;
// exp(arg1)
fma.rn.f64 %e_nf, %arg1, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %arg1;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp1, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp1, %exp1, %e_nf;
// exp(arg2)
fma.rn.f64 %e_nf, %arg2, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %arg2;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp2, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp2, %exp2, %e_nf;
mul.f64 %r_factor, %exp1, %exp2;
div.rn.f64 %r_factor, %r_factor, %ax;
sub.f64 %erf_val, %one, %r_factor;
setp.lt.f64 %pred_neg, %z, 0d0000000000000000;
@%pred_neg neg.f64 %erf_val, %erf_val;
// fall through
B_ERF_DONE:
add.f64 %cdf, %one, %erf_val;
mul.f64 %cdf, %half, %cdf;
// phi(x) = exp(-x^2/2) / sqrt(2*pi)
mul.f64 %neg_x2h, %x, %x;
mul.f64 %neg_x2h, %neg_x2h, %half;
neg.f64 %neg_x2h, %neg_x2h;
fma.rn.f64 %e_nf, %neg_x2h, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x2h;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %exp_neg_x2h, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %exp_neg_x2h, %exp_neg_x2h, %e_nf;
// 1/sqrt(2*pi) = 0.3989422804014327 = 0d3FD9884533D43651
mov.f64 %inv_sqrt_2pi, 0d3FD9884533D43651;
mul.f64 %pdf, %exp_neg_x2h, %inv_sqrt_2pi;
mul.f64 %x_pdf, %x, %pdf;
add.f64 %d_gelu, %cdf, %x_pdf;
mul.f64 %result, %vg, %d_gelu;
st.global.f64 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry index_select_1d_kernel(
.param .u64 input_ptr,
.param .u64 indices_ptr,
.param .u64 out_ptr,
.param .u32 n_indices
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
.reg .u64 %input, %indices, %out, %off, %addr;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %input, [input_ptr];
ld.param.u64 %indices, [indices_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n_indices];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Byte offset for thread
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
// Read indices[tid] (f32 -> u32)
add.u64 %addr, %indices, %off;
ld.global.f32 %idx_f, [%addr];
cvt.rzi.u32.f32 %idx, %idx_f;
// Read input[idx]
cvt.u64.u32 %addr, %idx;
shl.b64 %addr, %addr, 2;
add.u64 %addr, %input, %addr;
ld.global.f32 %val, [%addr];
// Write output[tid]
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_DIM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry index_select_dim_kernel(
.param .u64 input_ptr,
.param .u64 indices_ptr,
.param .u64 out_ptr,
.param .u32 outer,
.param .u32 in_dim_size,
.param .u32 out_dim_size,
.param .u32 inner,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim;
.reg .u32 %outer_r, %in_dim_r, %out_dim_r, %inner_r, %total_r;
.reg .u32 %slab, %o, %rem, %i, %k, %idx;
.reg .u32 %src_flat, %tmp;
.reg .u64 %input, %indices, %out, %off, %addr;
.reg .f32 %idx_f, %val;
.reg .pred %p;
ld.param.u64 %input, [input_ptr];
ld.param.u64 %indices, [indices_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %outer_r, [outer];
ld.param.u32 %in_dim_r, [in_dim_size];
ld.param.u32 %out_dim_r, [out_dim_size];
ld.param.u32 %inner_r, [inner];
ld.param.u32 %total_r, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %total_r;
@%p bra DONE;
// slab = out_dim_size * inner
mul.lo.u32 %slab, %out_dim_r, %inner_r;
// o = t / slab
div.u32 %o, %r_tid, %slab;
// rem = t % slab
rem.u32 %rem, %r_tid, %slab;
// i = rem / inner
div.u32 %i, %rem, %inner_r;
// k = rem % inner
rem.u32 %k, %rem, %inner_r;
// idx = indices[i] (read f32, convert toward zero to u32)
cvt.u64.u32 %off, %i;
shl.b64 %off, %off, 2;
add.u64 %addr, %indices, %off;
ld.global.f32 %idx_f, [%addr];
cvt.rzi.u32.f32 %idx, %idx_f;
// src_flat = o * (in_dim_size * inner) + idx * inner + k
mul.lo.u32 %tmp, %in_dim_r, %inner_r;
mul.lo.u32 %src_flat, %o, %tmp;
mad.lo.u32 %src_flat, %idx, %inner_r, %src_flat;
add.u32 %src_flat, %src_flat, %k;
cvt.u64.u32 %off, %src_flat;
shl.b64 %off, %off, 2;
add.u64 %addr, %input, %off;
ld.global.f32 %val, [%addr];
// Write output[t]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const INDEX_SELECT_DIM_F64_PTX: &str = "\
.version 7.0
.target sm_60
.address_size 64
.visible .entry index_select_dim_f64_kernel(
.param .u64 input_ptr,
.param .u64 indices_ptr,
.param .u64 out_ptr,
.param .u32 outer,
.param .u32 in_dim_size,
.param .u32 out_dim_size,
.param .u32 inner,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim;
.reg .u32 %outer_r, %in_dim_r, %out_dim_r, %inner_r, %total_r;
.reg .u32 %slab, %o, %rem, %i, %k, %idx;
.reg .u32 %src_flat, %tmp;
.reg .u64 %input, %indices, %out, %off_d, %off_i, %addr;
.reg .f32 %idx_f;
.reg .f64 %val;
.reg .pred %p;
ld.param.u64 %input, [input_ptr];
ld.param.u64 %indices, [indices_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %outer_r, [outer];
ld.param.u32 %in_dim_r, [in_dim_size];
ld.param.u32 %out_dim_r, [out_dim_size];
ld.param.u32 %inner_r, [inner];
ld.param.u32 %total_r, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %total_r;
@%p bra DONE;
mul.lo.u32 %slab, %out_dim_r, %inner_r;
div.u32 %o, %r_tid, %slab;
rem.u32 %rem, %r_tid, %slab;
div.u32 %i, %rem, %inner_r;
rem.u32 %k, %rem, %inner_r;
// idx = indices[i] (f32-encoded, 4-byte stride)
cvt.u64.u32 %off_i, %i;
shl.b64 %off_i, %off_i, 2;
add.u64 %addr, %indices, %off_i;
ld.global.f32 %idx_f, [%addr];
cvt.rzi.u32.f32 %idx, %idx_f;
// src_flat = o * (in_dim_size * inner) + idx * inner + k
mul.lo.u32 %tmp, %in_dim_r, %inner_r;
mul.lo.u32 %src_flat, %o, %tmp;
mad.lo.u32 %src_flat, %idx, %inner_r, %src_flat;
add.u32 %src_flat, %src_flat, %k;
// input[src_flat] (f64, 8-byte stride)
cvt.u64.u32 %off_d, %src_flat;
shl.b64 %off_d, %off_d, 3;
add.u64 %addr, %input, %off_d;
ld.global.f64 %val, [%addr];
// output[t] (f64, 8-byte stride)
cvt.u64.u32 %off_d, %r_tid;
shl.b64 %off_d, %off_d, 3;
add.u64 %addr, %out, %off_d;
st.global.f64 [%addr], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SCATTER_ADD_1D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry scatter_add_1d_kernel(
.param .u64 grad_output_ptr,
.param .u64 indices_ptr,
.param .u64 grad_input_ptr,
.param .u32 n_indices
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %idx;
.reg .u64 %go, %indices, %gi, %off, %addr;
.reg .f32 %idx_f, %grad_val, %dummy;
.reg .pred %p;
ld.param.u64 %go, [grad_output_ptr];
ld.param.u64 %indices, [indices_ptr];
ld.param.u64 %gi, [grad_input_ptr];
ld.param.u32 %n_reg, [n_indices];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Byte offset for thread
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
// Read grad_output[tid]
add.u64 %addr, %go, %off;
ld.global.f32 %grad_val, [%addr];
// Read indices[tid] (f32 -> u32)
add.u64 %addr, %indices, %off;
ld.global.f32 %idx_f, [%addr];
cvt.rzi.u32.f32 %idx, %idx_f;
// Atomic add: grad_input[idx] += grad_val
cvt.u64.u32 %addr, %idx;
shl.b64 %addr, %addr, 2;
add.u64 %addr, %gi, %addr;
atom.global.add.f32 %dummy, [%addr], %grad_val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MASKED_FILL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_fill_kernel(
.param .u64 input_ptr,
.param .u64 mask_ptr,
.param .u64 out_ptr,
.param .f32 fill_value,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %input, %mask, %out, %off;
.reg .f32 %in_val, %mask_val, %fill, %result, %half;
.reg .pred %p, %pmask;
ld.param.u64 %input, [input_ptr];
ld.param.u64 %mask, [mask_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %fill, [fill_value];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %input, %input, %off;
add.u64 %mask, %mask, %off;
add.u64 %out, %out, %off;
ld.global.f32 %in_val, [%input];
ld.global.f32 %mask_val, [%mask];
mov.f32 %half, 0f3F000000;
setp.ge.f32 %pmask, %mask_val, %half;
selp.f32 %result, %fill, %in_val, %pmask;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MASKED_ZERO_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_zero_kernel(
.param .u64 grad_ptr,
.param .u64 mask_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %mask, %out, %off;
.reg .f32 %vg, %mask_val, %zero, %result, %half;
.reg .pred %p, %pmask;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %mask, [mask_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %mask, %mask, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %mask_val, [%mask];
mov.f32 %zero, 0f00000000;
mov.f32 %half, 0f3F000000;
setp.ge.f32 %pmask, %mask_val, %half;
selp.f32 %result, %zero, %vg, %pmask;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sigmoid_backward_kernel(
.param .u64 grad_ptr,
.param .u64 output_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %output, %out, %off;
.reg .f32 %vg, %vo, %one, %one_minus_o, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %output, [output_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %output, %output, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vo, [%output];
mov.f32 %one, 0f3F800000;
sub.f32 %one_minus_o, %one, %vo;
mul.f32 %result, %vo, %one_minus_o;
mul.f32 %result, %vg, %result;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const TANH_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry tanh_backward_kernel(
.param .u64 grad_ptr,
.param .u64 output_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %grad, %output, %out, %off;
.reg .f32 %vg, %vo, %one, %o_sq, %one_minus_sq, %result;
.reg .pred %p;
ld.param.u64 %grad, [grad_ptr];
ld.param.u64 %output, [output_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %grad, %grad, %off;
add.u64 %output, %output, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%grad];
ld.global.f32 %vo, [%output];
mov.f32 %one, 0f3F800000;
mul.f32 %o_sq, %vo, %vo;
sub.f32 %one_minus_sq, %one, %o_sq;
mul.f32 %result, %vg, %one_minus_sq;
st.global.f32 [%out], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_backward_kernel(\n\
.param .u64 grad_ptr,\n\
.param .u64 output_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
.reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f32 %vg, %vo, %dot, %other_val, %diff, %result;\n\
.reg .pred %p, %loop_p, %reduce_p;\n\
\n\
ld.param.u64 %grad, [grad_ptr];\n\
ld.param.u64 %output, [output_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
// row_off = bid * cols * 4 (byte offset)\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 2;\n\
\n\
// Phase 1: compute partial dot = sum(grad[j] * output[j]) for this thread's elements\n\
mov.f32 %dot, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
DOT_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra DOT_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vg, [%saddr];\n\
add.u64 %saddr, %output, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vo, [%saddr];\n\
fma.rn.f32 %dot, %vg, %vo, %dot;\n\
add.u32 %j, %j, %bdim;\n\
bra DOT_LOOP;\n\
DOT_LOOP_DONE:\n\
\n\
// Store partial dot into shared memory and reduce\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %dot;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
DOT_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra DOT_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra DOT_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %dot, [%saddr];\n\
add.f32 %dot, %dot, %other_val;\n\
st.shared.f32 [%saddr], %dot;\n\
DOT_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra DOT_REDUCE;\n\
DOT_REDUCE_DONE:\n\
\n\
// Broadcast dot to all threads\n\
ld.shared.f32 %dot, [sdata];\n\
bar.sync 0;\n\
\n\
// Phase 2: out[j] = output[j] * (grad[j] - dot)\n\
mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra WRITE_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vg, [%saddr];\n\
add.u64 %saddr, %output, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vo, [%saddr];\n\
sub.f32 %diff, %vg, %dot;\n\
mul.f32 %result, %vo, %diff;\n\
add.u64 %saddr, %out, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
st.global.f32 [%saddr], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f32 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f32 %other_val;\n\
.reg .pred %reduce_p;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
// row_off = bid * cols * 4 (byte offset)\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 2;\n\
\n\
// Phase 1: find max across row (threads stride over columns by blockDim)\n\
mov.f32 %max_val, 0fFF800000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
max.f32 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
// Shared-memory tree reduction for max\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %max_val, [%saddr];\n\
max.f32 %max_val, %max_val, %other_val;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
// Broadcast max to all threads\n\
ld.shared.f32 %max_val, [sdata];\n\
bar.sync 0;\n\
\n\
// Phase 2: compute partial sum of exp(x[j] - max)\n\
mov.f32 %sum_val, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
sub.f32 %val, %val, %max_val;\n\
// exp(x) = exp2(x * log2(e)), log2(e) = 0x3FB8AA3B\n\
mul.f32 %val, %val, 0f3FB8AA3B;\n\
ex2.approx.f32 %exp_val, %val;\n\
add.f32 %sum_val, %sum_val, %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
// Shared-memory tree reduction for sum\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %sum_val, [%saddr];\n\
add.f32 %sum_val, %sum_val, %other_val;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
// Broadcast sum to all threads, compute log_sum_exp = max + log(sum)\n\
ld.shared.f32 %sum_val, [sdata];\n\
bar.sync 0;\n\
// log(x) = log2(x) / log2(e) = log2(x) * ln(2)\n\
// ln(2) = 0x3F317218\n\
lg2.approx.f32 %log_sum_exp, %sum_val;\n\
mul.f32 %log_sum_exp, %log_sum_exp, 0f3F317218;\n\
add.f32 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
// Phase 3: out[j] = x[j] - log_sum_exp\n\
mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra WRITE_OUTPUT_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %in, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %val, [%saddr];\n\
sub.f32 %result, %val, %log_sum_exp;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %out, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
st.global.f32 [%saddr], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_f64_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f64 %val, %max_val, %sum_val, %exp_val, %log_sum_exp, %result;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f64 %other_val;\n\
.reg .pred %reduce_p;\n\
.reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
.reg .s32 %e_ni;\n\
.reg .s64 %e_ni64, %e_bits;\n\
.reg .u64 %l_xbits, %l_mbits, %l_bias;\n\
.reg .s64 %l_exp64;\n\
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf;\n\
.reg .f64 %l_ln2_hi, %l_ln2_lo, %l_sqrt2, %l_half_const;\n\
.reg .pred %p_shift;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 3;\n\
\n\
mov.f64 %max_val, 0dFFF0000000000000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
max.f64 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %max_val, [%saddr];\n\
max.f64 %max_val, %max_val, %other_val;\n\
st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
ld.shared.f64 %max_val, [sdata];\n\
bar.sync 0;\n\
\n\
mov.f64 %sum_val, 0d0000000000000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
sub.f64 %val, %val, %max_val;\n\
mov.f64 %e_one, 0d3FF0000000000000;\n\
mov.f64 %e_half, 0d3FE0000000000000;\n\
// inline exp(%val): n = round(x * log2(e)), r = x - n*ln2 (hi+lo split),\n\
// exp(x) = 2^n * P(r) with a factorial-coefficient Horner polynomial\n\
mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
cvt.rni.f64.f64 %e_nf, %e_nf;\n\
cvt.rni.s32.f64 %e_ni, %e_nf;\n\
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
mov.f64 %e_p, 0d3E5AE64567F544E4;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
// scale by 2^n, assembling its bits as (n + 1023) << 52\n\
cvt.s64.s32 %e_ni64, %e_ni;\n\
add.s64 %e_ni64, %e_ni64, 1023;\n\
shl.b64 %e_bits, %e_ni64, 52;\n\
mov.b64 %e_nf, %e_bits;\n\
mul.f64 %exp_val, %exp_val, %e_nf;\n\
add.f64 %sum_val, %sum_val, %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %sum_val, [%saddr];\n\
add.f64 %sum_val, %sum_val, %other_val;\n\
st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
ld.shared.f64 %sum_val, [sdata];\n\
bar.sync 0;\n\
mov.f64 %e_one, 0d3FF0000000000000;\n\
// inline ln(sum): split sum = 2^e * m, ln = e*ln2 + 2*atanh((m-1)/(m+1))\n\
mov.b64 %l_xbits, %sum_val;\n\
shr.u64 %l_exp64, %l_xbits, 52;\n\
and.b64 %l_exp64, %l_exp64, 2047;\n\
sub.s64 %l_exp64, %l_exp64, 1023;\n\
cvt.rn.f64.s64 %l_nf, %l_exp64;\n\
mov.u64 %l_bias, 0x3FF0000000000000;\n\
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;\n\
or.b64 %l_mbits, %l_mbits, %l_bias;\n\
mov.b64 %l_m, %l_mbits;\n\
mov.f64 %l_sqrt2, 0d3FF6A09E667F3BCD;\n\
mov.f64 %l_half_const, 0d3FE0000000000000;\n\
setp.gt.f64 %p_shift, %l_m, %l_sqrt2;\n\
@%p_shift mul.f64 %l_m, %l_m, %l_half_const;\n\
@%p_shift add.f64 %l_nf, %l_nf, %e_one;\n\
sub.f64 %l_f, %l_m, %e_one;\n\
add.f64 %l_s, %l_m, %e_one;\n\
div.rn.f64 %l_f, %l_f, %l_s;\n\
mul.f64 %l_f2, %l_f, %l_f;\n\
mov.f64 %l_p, 0d3FB1111111111111;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555;\n\
fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;\n\
mul.f64 %l_p, %l_p, %l_f;\n\
add.f64 %l_p, %l_p, %l_p;\n\
mov.f64 %l_ln2_hi, 0d3FE62E42FEFA3800;\n\
mov.f64 %l_ln2_lo, 0d3D2EF35793C76730;\n\
fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2_hi, %l_p;\n\
fma.rn.f64 %log_sum_exp, %l_nf, %l_ln2_lo, %log_sum_exp;\n\
add.f64 %log_sum_exp, %max_val, %log_sum_exp;\n\
\n\
mov.u32 %j, %r_tid;\n\
WRITE_OUTPUT:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra WRITE_OUTPUT_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %in, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f64 %val, [%saddr];\n\
sub.f64 %result, %val, %log_sum_exp;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %out, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
st.global.f64 [%saddr], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra WRITE_OUTPUT;\n\
WRITE_OUTPUT_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_BACKWARD_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_kernel(\n\
.param .u64 grad_ptr,\n\
.param .u64 output_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
.reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f32 %vg, %vo, %sum_grad, %other_val, %softmax_j, %result;\n\
.reg .pred %p, %loop_p, %reduce_p;\n\
\n\
ld.param.u64 %grad, [grad_ptr];\n\
ld.param.u64 %output, [output_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
// row_off = bid * cols * 4 (byte offset)\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 2;\n\
\n\
// Phase 1: compute partial sum_grad = sum(grad[j]) for this thread's elements\n\
mov.f32 %sum_grad, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vg, [%saddr];\n\
add.f32 %sum_grad, %sum_grad, %vg;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
// Store partial sum into shared memory and reduce\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_grad;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %sum_grad, [%saddr];\n\
add.f32 %sum_grad, %sum_grad, %other_val;\n\
st.shared.f32 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
// Broadcast sum_grad to all threads\n\
ld.shared.f32 %sum_grad, [sdata];\n\
bar.sync 0;\n\
\n\
// Phase 2: out[j] = grad[j] - softmax[j] * sum_grad\n\
// `output_ptr` already holds softmax probabilities (the host computed\n\
// exp(log_softmax) at forward time and saved it as softmax_output);\n\
// load directly. Re-applying exp() here was the #798 algebra bug.\n\
mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra WRITE_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %vg, [%saddr];\n\
add.u64 %saddr, %output, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f32 %softmax_j, [%saddr];\n\
// out[j] = grad[j] - softmax[j] * sum_grad\n\
mul.f32 %result, %softmax_j, %sum_grad;\n\
sub.f32 %result, %vg, %result;\n\
add.u64 %saddr, %out, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
st.global.f32 [%saddr], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const LOG_SOFTMAX_BACKWARD_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry log_softmax_backward_f64_kernel(\n\
.param .u64 grad_ptr,\n\
.param .u64 output_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j, %half, %other_tid;\n\
.reg .u64 %grad, %output, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f64 %vg, %sum_grad, %other_val, %softmax_j, %result;\n\
.reg .pred %p, %loop_p, %reduce_p;\n\
\n\
ld.param.u64 %grad, [grad_ptr];\n\
ld.param.u64 %output, [output_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 3;\n\
\n\
mov.f64 %sum_grad, 0d0000000000000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f64 %vg, [%saddr];\n\
add.f64 %sum_grad, %sum_grad, %vg;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_LOOP;\n\
SUM_LOOP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %sum_grad;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %sum_grad, [%saddr];\n\
add.f64 %sum_grad, %sum_grad, %other_val;\n\
st.shared.f64 [%saddr], %sum_grad;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
ld.shared.f64 %sum_grad, [sdata];\n\
bar.sync 0;\n\
\n\
// Phase 2: out[j] = grad[j] - softmax[j] * sum_grad\n\
// `output_ptr` already holds softmax probabilities (the host computed\n\
// exp(log_softmax) at forward time and saved it as softmax_output);\n\
// load directly. Re-applying exp() here was the #820 algebra bug\n\
// (f64 sibling of #798's f32 bug; see commit 2fbb23d8).\n\
mov.u32 %j, %r_tid;\n\
WRITE_LOOP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra WRITE_LOOP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %grad, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f64 %vg, [%saddr];\n\
add.u64 %saddr, %output, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
ld.global.f64 %softmax_j, [%saddr];\n\
// out[j] = grad[j] - softmax[j] * sum_grad\n\
mul.f64 %result, %softmax_j, %sum_grad;\n\
sub.f64 %result, %vg, %result;\n\
add.u64 %saddr, %out, %off;\n\
add.u64 %saddr, %saddr, %row_off;\n\
st.global.f64 [%saddr], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra WRITE_LOOP;\n\
WRITE_LOOP_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_SUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
// Shared memory for intra-block reduction (256 floats = 1024 bytes).
.shared .align 4 .f32 sdata[256];
.visible .entry reduce_sum_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
.reg .u64 %in, %out, %off, %sbase, %saddr;
.reg .f32 %sum, %other;
.reg .pred %p, %ptid;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %my_tid, %tid.x;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gdim, %nctaid.x;
mov.u64 %sbase, sdata;
// Grid-stride accumulation: each thread sums multiple elements.
// idx = bid * bdim + tid; stride = bdim * gdim
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
mov.f32 %sum, 0f00000000;
GRID_LOOP:
setp.ge.u32 %p, %idx, %n_reg;
@%p bra GRID_DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %other, [%off];
add.f32 %sum, %sum, %other;
add.u32 %idx, %idx, %stride;
bra GRID_LOOP;
GRID_DONE:
// Write thread's partial sum to shared memory.
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum;
bar.sync 0;
// Tree reduction in shared memory (assumes blockDim.x == 256, a power of two).
mov.u32 %half, 128;
TREE_LOOP:
setp.lt.u32 %p, %half, 1;
@%p bra TREE_DONE;
setp.ge.u32 %ptid, %my_tid, %half;
@%ptid bra TREE_SKIP;
// Load partner's value from sdata[tid + half].
add.u32 %idx, %my_tid, %half;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other, [%saddr];
// Load own value.
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum, [%saddr];
add.f32 %sum, %sum, %other;
st.shared.f32 [%saddr], %sum;
TREE_SKIP:
bar.sync 0;
shr.u32 %half, %half, 1;
bra TREE_LOOP;
TREE_DONE:
// Thread 0 writes block result.
setp.ne.u32 %ptid, %my_tid, 0;
@%ptid bra END;
ld.shared.f32 %sum, [sdata];
cvt.u64.u32 %off, %bid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %sum;
END:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_PROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry reduce_prod_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
.reg .u64 %in, %out, %off, %saddr;
.reg .f32 %acc, %other;
.reg .pred %p, %ptid;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %my_tid, %tid.x;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
mov.f32 %acc, 0f3F800000; // 1.0 (multiplicative identity)
GRID_LOOP_PROD:
setp.ge.u32 %p, %idx, %n_reg;
@%p bra GRID_DONE_PROD;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %other, [%off];
mul.f32 %acc, %acc, %other;
add.u32 %idx, %idx, %stride;
bra GRID_LOOP_PROD;
GRID_DONE_PROD:
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
bar.sync 0;
mov.u32 %half, 128; // assumes blockDim.x == 256
TREE_LOOP_PROD:
setp.lt.u32 %p, %half, 1;
@%p bra TREE_DONE_PROD;
setp.ge.u32 %ptid, %my_tid, %half;
@%ptid bra TREE_SKIP_PROD;
add.u32 %idx, %my_tid, %half;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %acc, [%saddr];
mul.f32 %acc, %acc, %other;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
TREE_SKIP_PROD:
bar.sync 0;
shr.u32 %half, %half, 1;
bra TREE_LOOP_PROD;
TREE_DONE_PROD:
setp.ne.u32 %ptid, %my_tid, 0;
@%ptid bra END_PROD;
mov.u64 %saddr, sdata;
ld.shared.f32 %acc, [%saddr];
cvt.u64.u32 %off, %bid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %acc;
END_PROD:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const PROD_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry prod_backward_kernel(
.param .u64 input_ptr,
.param .u64 grad_out_ptr,
.param .u64 grad_in_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %i, %j, %n_reg;
.reg .u64 %in, %go, %gi, %off, %addr;
.reg .f32 %prefix, %suffix, %val, %go_val, %result;
.reg .pred %p, %lp, %skip;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %go, [grad_out_ptr];
ld.param.u64 %gi, [grad_in_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %i, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %i, %n_reg;
@%p bra DONE;
// Read scalar grad_output once.
ld.global.f32 %go_val, [%go];
// prefix = prod(input[0..i])
mov.f32 %prefix, 0f3F800000;
mov.u32 %j, 0;
PREFIX_LOOP:
setp.ge.u32 %lp, %j, %i;
@%lp bra PREFIX_DONE;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
mul.f32 %prefix, %prefix, %val;
add.u32 %j, %j, 1;
bra PREFIX_LOOP;
PREFIX_DONE:
// suffix = prod(input[i+1..n])
mov.f32 %suffix, 0f3F800000;
add.u32 %j, %i, 1;
SUFFIX_LOOP:
setp.ge.u32 %lp, %j, %n_reg;
@%lp bra SUFFIX_DONE;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
mul.f32 %suffix, %suffix, %val;
add.u32 %j, %j, 1;
bra SUFFIX_LOOP;
SUFFIX_DONE:
// grad_input[i] = grad_output * prefix * suffix
mul.f32 %result, %prefix, %suffix;
mul.f32 %result, %result, %go_val;
cvt.u64.u32 %off, %i;
shl.b64 %off, %off, 2;
add.u64 %addr, %gi, %off;
st.global.f32 [%addr], %result;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_MIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry reduce_min_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
.reg .u64 %in, %out, %off, %saddr;
.reg .f32 %acc, %other;
.reg .pred %p, %ptid;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %my_tid, %tid.x;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
// accumulator init = +inf
mov.f32 %acc, 0f7F800000;
GRID_LOOP_MIN:
setp.ge.u32 %p, %idx, %n_reg;
@%p bra GRID_DONE_MIN;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %other, [%off];
min.f32 %acc, %acc, %other;
add.u32 %idx, %idx, %stride;
bra GRID_LOOP_MIN;
GRID_DONE_MIN:
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
bar.sync 0;
mov.u32 %half, 128; // assumes blockDim.x == 256
TREE_LOOP_MIN:
setp.lt.u32 %p, %half, 1;
@%p bra TREE_DONE_MIN;
setp.ge.u32 %ptid, %my_tid, %half;
@%ptid bra TREE_SKIP_MIN;
add.u32 %idx, %my_tid, %half;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %acc, [%saddr];
min.f32 %acc, %acc, %other;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
TREE_SKIP_MIN:
bar.sync 0;
shr.u32 %half, %half, 1;
bra TREE_LOOP_MIN;
TREE_DONE_MIN:
setp.ne.u32 %ptid, %my_tid, 0;
@%ptid bra END_MIN;
mov.u64 %saddr, sdata;
ld.shared.f32 %acc, [%saddr];
cvt.u64.u32 %off, %bid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %acc;
END_MIN:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const REDUCE_MAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry reduce_max_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
.reg .u64 %in, %out, %off, %saddr;
.reg .f32 %acc, %other;
.reg .pred %p, %ptid;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %my_tid, %tid.x;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
// accumulator init = -inf
mov.f32 %acc, 0fFF800000;
GRID_LOOP_MAX:
setp.ge.u32 %p, %idx, %n_reg;
@%p bra GRID_DONE_MAX;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %other, [%off];
max.f32 %acc, %acc, %other;
add.u32 %idx, %idx, %stride;
bra GRID_LOOP_MAX;
GRID_DONE_MAX:
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
bar.sync 0;
mov.u32 %half, 128; // assumes blockDim.x == 256
TREE_LOOP_MAX:
setp.lt.u32 %p, %half, 1;
@%p bra TREE_DONE_MAX;
setp.ge.u32 %ptid, %my_tid, %half;
@%ptid bra TREE_SKIP_MAX;
add.u32 %idx, %my_tid, %half;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %other, [%saddr];
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %acc, [%saddr];
max.f32 %acc, %acc, %other;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
TREE_SKIP_MAX:
bar.sync 0;
shr.u32 %half, %half, 1;
bra TREE_LOOP_MAX;
TREE_DONE_MAX:
setp.ne.u32 %ptid, %my_tid, 0;
@%ptid bra END_MAX;
mov.u64 %saddr, sdata;
ld.shared.f32 %acc, [%saddr];
cvt.u64.u32 %off, %bid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %acc;
END_MAX:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const HAS_INF_NAN_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry has_inf_nan_f32_kernel(
.param .u64 a_ptr,
.param .u32 n,
.param .u64 flag_ptr
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %bits, %exp_mask, %old;
.reg .u64 %a, %flag, %off;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u64 %flag, [flag_ptr];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %a, %off;
ld.global.u32 %bits, [%off];
and.b32 %exp_mask, %bits, 0x7F800000;
// Inf/NaN have an all-ones exponent field; anything else is finite.
setp.ne.u32 %p, %exp_mask, 0x7F800000;
@%p bra DONE;
// Sticky flag: the returned old value %old is unused.
atom.global.or.b32 %old, [%flag], 1;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MASKED_REDUCE_MIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry masked_reduce_min_kernel(
.param .u64 data_ptr,
.param .u64 mask_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
.reg .u64 %dat, %msk, %out, %off, %saddr;
.reg .f32 %acc, %d, %m, %sentinel, %val;
.reg .pred %p, %ptid, %p_valid;
ld.param.u64 %dat, [data_ptr];
ld.param.u64 %msk, [mask_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %my_tid, %tid.x;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
mov.f32 %acc, 0f7F800000; // +inf
mov.f32 %sentinel, 0f7F800000;
GRID_LOOP_MMIN:
setp.ge.u32 %p, %idx, %n_reg;
@%p bra GRID_DONE_MMIN;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %dat, %off;
ld.global.f32 %d, [%off];
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %msk, %off;
ld.global.f32 %m, [%off];
// val = (m != 0) ? d : +inf
setp.ne.f32 %p_valid, %m, 0f00000000;
selp.f32 %val, %d, %sentinel, %p_valid;
min.f32 %acc, %acc, %val;
add.u32 %idx, %idx, %stride;
bra GRID_LOOP_MMIN;
GRID_DONE_MMIN:
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
bar.sync 0;
mov.u32 %half, 128; // assumes blockDim.x == 256
TREE_LOOP_MMIN:
setp.lt.u32 %p, %half, 1;
@%p bra TREE_DONE_MMIN;
setp.ge.u32 %ptid, %my_tid, %half;
@%ptid bra TREE_SKIP_MMIN;
add.u32 %idx, %my_tid, %half;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %val, [%saddr];
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %acc, [%saddr];
min.f32 %acc, %acc, %val;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
TREE_SKIP_MMIN:
bar.sync 0;
shr.u32 %half, %half, 1;
bra TREE_LOOP_MMIN;
TREE_DONE_MMIN:
setp.ne.u32 %ptid, %my_tid, 0;
@%ptid bra END_MMIN;
mov.u64 %saddr, sdata;
ld.shared.f32 %acc, [%saddr];
cvt.u64.u32 %off, %bid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %acc;
END_MMIN:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MASKED_REDUCE_MAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry masked_reduce_max_kernel(
.param .u64 data_ptr,
.param .u64 mask_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %n_reg, %idx, %stride, %half;
.reg .u64 %dat, %msk, %out, %off, %saddr;
.reg .f32 %acc, %d, %m, %sentinel, %val;
.reg .pred %p, %ptid, %p_valid;
ld.param.u64 %dat, [data_ptr];
ld.param.u64 %msk, [mask_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %my_tid, %tid.x;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
mov.f32 %acc, 0fFF800000; // -inf
mov.f32 %sentinel, 0fFF800000;
GRID_LOOP_MMAX:
setp.ge.u32 %p, %idx, %n_reg;
@%p bra GRID_DONE_MMAX;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %dat, %off;
ld.global.f32 %d, [%off];
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %off, %msk, %off;
ld.global.f32 %m, [%off];
// val = (m != 0) ? d : -inf
setp.ne.f32 %p_valid, %m, 0f00000000;
selp.f32 %val, %d, %sentinel, %p_valid;
max.f32 %acc, %acc, %val;
add.u32 %idx, %idx, %stride;
bra GRID_LOOP_MMAX;
GRID_DONE_MMAX:
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
bar.sync 0;
mov.u32 %half, 128; // assumes blockDim.x == 256
TREE_LOOP_MMAX:
setp.lt.u32 %p, %half, 1;
@%p bra TREE_DONE_MMAX;
setp.ge.u32 %ptid, %my_tid, %half;
@%ptid bra TREE_SKIP_MMAX;
add.u32 %idx, %my_tid, %half;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %val, [%saddr];
cvt.u64.u32 %off, %my_tid;
shl.b64 %off, %off, 2;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
ld.shared.f32 %acc, [%saddr];
max.f32 %acc, %acc, %val;
mov.u64 %saddr, sdata;
add.u64 %saddr, %saddr, %off;
st.shared.f32 [%saddr], %acc;
TREE_SKIP_MMAX:
bar.sync 0;
shr.u32 %half, %half, 1;
bra TREE_LOOP_MMAX;
TREE_DONE_MMAX:
setp.ne.u32 %ptid, %my_tid, 0;
@%ptid bra END_MMAX;
mov.u64 %saddr, sdata;
ld.shared.f32 %acc, [%saddr];
cvt.u64.u32 %off, %bid;
shl.b64 %off, %off, 2;
add.u64 %out, %out, %off;
st.global.f32 [%out], %acc;
END_MMAX:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const REPEAT_ALONG_DIM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry repeat_along_dim_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 outer,
.param .u32 repeat_count,
.param .u32 inner
) {
.reg .u32 %r_tid, %bid, %bdim, %t, %total, %o, %r, %i, %tmp_ri, %tmp_ri2, %ri_extent;
.reg .u32 %src_idx, %re_x_in;
.reg .u64 %inp, %out, %off;
.reg .f32 %v;
.reg .pred %p;
ld.param.u64 %inp, [in_ptr];
ld.param.u64 %out, [out_ptr];
// %tmp_ri and %tmp_ri2 hold the outer/repeat_count params and are
// reused as scratch once the sizes have been combined below.
ld.param.u32 %tmp_ri, [outer];
ld.param.u32 %tmp_ri2, [repeat_count];
ld.param.u32 %ri_extent, [inner];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %t, %bid, %bdim, %r_tid;
// total = outer * repeat_count * inner
mul.lo.u32 %re_x_in, %tmp_ri2, %ri_extent;
mul.lo.u32 %total, %tmp_ri, %re_x_in;
setp.ge.u32 %p, %t, %total;
@%p bra DONE_RAD;
// o = t / (repeat_count * inner)
div.u32 %o, %t, %re_x_in;
// tmp = t % (repeat_count * inner)
rem.u32 %tmp_ri, %t, %re_x_in;
// i = tmp % inner; the repeat index tmp / inner is not needed, since
// every repeat reads the same source element.
rem.u32 %i, %tmp_ri, %ri_extent;
rem.u32 %i, %tmp_ri, %ri_extent;
// src_idx = o * inner + i
mad.lo.u32 %src_idx, %o, %ri_extent, %i;
// Load src
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %inp, %off;
ld.global.f32 %v, [%off];
// Store to dst[t]
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %v;
DONE_RAD:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CLAMP_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry clamp_backward_kernel(
.param .u64 grad_ptr,
.param .u64 input_ptr,
.param .u64 out_ptr,
.param .f32 min_val,
.param .f32 max_val,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %g, %x, %out, %off;
.reg .f32 %vg, %vx, %vmin, %vmax, %vr;
.reg .pred %p, %plo, %phi, %pin;
ld.param.u64 %g, [grad_ptr];
ld.param.u64 %x, [input_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %vmin, [min_val];
ld.param.f32 %vmax, [max_val];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE_CB;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %g, %g, %off;
add.u64 %x, %x, %off;
add.u64 %out, %out, %off;
ld.global.f32 %vg, [%g];
ld.global.f32 %vx, [%x];
setp.ge.f32 %plo, %vx, %vmin;
setp.le.f32 %phi, %vx, %vmax;
and.pred %pin, %plo, %phi;
mov.f32 %vr, 0f00000000;
@%pin mov.f32 %vr, %vg;
st.global.f32 [%out], %vr;
DONE_CB:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const PAD_TRUNCATE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry pad_truncate_kernel(
.param .u64 src_ptr,
.param .u64 dst_ptr,
.param .u32 batch,
.param .u32 src_n,
.param .u32 dst_n
) {
.reg .u32 %r_tid, %bid, %bdim, %total, %b_idx, %k_idx, %src_offset, %dst_offset;
.reg .u32 %tmp32, %tmp32b;
.reg .u64 %src_base, %dst_base, %off_src, %off_dst;
.reg .f32 %re, %im;
.reg .pred %p_oob, %p_pad;
ld.param.u64 %src_base, [src_ptr];
ld.param.u64 %dst_base, [dst_ptr];
ld.param.u32 %tmp32, [batch];
ld.param.u32 %tmp32b, [dst_n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
// total = batch * dst_n
mul.lo.u32 %total, %tmp32, %tmp32b;
setp.ge.u32 %p_oob, %r_tid, %total;
@%p_oob bra DONE_PT;
// b_idx = r_tid / dst_n
// k_idx = r_tid % dst_n
div.u32 %b_idx, %r_tid, %tmp32b;
rem.u32 %k_idx, %r_tid, %tmp32b;
// dst_offset = (b_idx * dst_n + k_idx) * 2
mad.lo.u32 %dst_offset, %b_idx, %tmp32b, %k_idx;
shl.b32 %dst_offset, %dst_offset, 1;
// Compare k_idx vs src_n.
ld.param.u32 %tmp32, [src_n];
setp.ge.u32 %p_pad, %k_idx, %tmp32;
@%p_pad bra PAD;
// Copy from src[b_idx, k_idx, :].
// src_offset = (b_idx * src_n + k_idx) * 2
mad.lo.u32 %src_offset, %b_idx, %tmp32, %k_idx;
shl.b32 %src_offset, %src_offset, 1;
cvt.u64.u32 %off_src, %src_offset;
shl.b64 %off_src, %off_src, 2;
add.u64 %off_src, %src_base, %off_src;
cvt.u64.u32 %off_dst, %dst_offset;
shl.b64 %off_dst, %off_dst, 2;
add.u64 %off_dst, %dst_base, %off_dst;
ld.global.f32 %re, [%off_src];
ld.global.f32 %im, [%off_src + 4];
st.global.f32 [%off_dst], %re;
st.global.f32 [%off_dst + 4], %im;
bra DONE_PT;
PAD:
cvt.u64.u32 %off_dst, %dst_offset;
shl.b64 %off_dst, %off_dst, 2;
add.u64 %off_dst, %dst_base, %off_dst;
mov.f32 %re, 0f00000000;
st.global.f32 [%off_dst], %re;
st.global.f32 [%off_dst + 4], %re;
DONE_PT:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SUM_AXIS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sum_axis_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 outer_size,
.param .u32 axis_size,
.param .u32 inner_size,
.param .u32 total_output
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %tmp;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %val, %sum;
.reg .pred %p, %lp;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %axis_sz, [axis_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total_output];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// outer_idx = r_tid / inner_size
div.u32 %outer_idx, %r_tid, %inner_sz;
// inner_idx = r_tid % inner_size
rem.u32 %inner_idx, %r_tid, %inner_sz;
// base = outer_idx * axis_size * inner_size + inner_idx
mul.lo.u32 %tmp, %outer_idx, %axis_sz;
mul.lo.u32 %tmp, %tmp, %inner_sz;
add.u32 %tmp, %tmp, %inner_idx;
mov.f32 %sum, 0f00000000;
mov.u32 %k, 0;
SUM_LOOP:
setp.ge.u32 %lp, %k, %axis_sz;
@%lp bra SUM_LOOP_DONE;
// addr = in + (tmp + k * inner_size) * 4
mul.lo.u32 %inner_idx, %k, %inner_sz;
add.u32 %inner_idx, %tmp, %inner_idx;
cvt.u64.u32 %off, %inner_idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
add.f32 %sum, %sum, %val;
add.u32 %k, %k, 1;
bra SUM_LOOP;
SUM_LOOP_DONE:
// output[r_tid] = sum
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %sum;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CUMSUM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cumsum_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %val, %acc;
.reg .pred %p, %lp;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
// total threads = outer * inner
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
// base = outer_idx * dim_size * inner_size + inner_idx
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
mov.f32 %acc, 0f00000000;
mov.u32 %k, 0;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
// idx = base + k * inner_size
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
add.f32 %acc, %acc, %val;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %acc;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CUMPROD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cumprod_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %val, %acc;
.reg .pred %p, %lp;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
// acc = 1.0
mov.f32 %acc, 0f3F800000;
mov.u32 %k, 0;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
mul.f32 %acc, %acc, %val;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %acc;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CUMMAX_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cummax_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u64 indices_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
.reg .u64 %in, %out, %ind, %off, %addr;
.reg .f32 %val, %acc, %best_k_f;
.reg .pred %p, %lp, %is_new_max;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u64 %ind, [indices_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
mov.b32 %acc, 0xFF800000; // acc = -inf (running-max identity)
mov.u32 %best_k, 0;
mov.u32 %k, 0;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
setp.gt.f32 %is_new_max, %val, %acc;
@%is_new_max mov.u32 %best_k, %k;
max.f32 %acc, %acc, %val;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %acc;
cvt.rn.f32.u32 %best_k_f, %best_k;
add.u64 %addr, %ind, %off;
st.global.f32 [%addr], %best_k_f;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const CUMMIN_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry cummin_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u64 indices_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp, %best_k;
.reg .u64 %in, %out, %ind, %off, %addr;
.reg .f32 %val, %acc, %best_k_f;
.reg .pred %p, %lp, %is_new_min;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u64 %ind, [indices_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
mov.b32 %acc, 0x7F800000; // acc = +inf (running-min identity)
mov.u32 %best_k, 0;
mov.u32 %k, 0;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
setp.lt.f32 %is_new_min, %val, %acc;
@%is_new_min mov.u32 %best_k, %k;
min.f32 %acc, %acc, %val;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %acc;
cvt.rn.f32.u32 %best_k_f, %best_k;
add.u64 %addr, %ind, %off;
st.global.f32 [%addr], %best_k_f;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const LOGCUMSUMEXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry logcumsumexp_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %val, %acc, %m, %ea, %ev, %s, %ls, %log2e, %ln2;
.reg .pred %p, %lp;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
// log2(e) = 1.4426950408... -> 0x3FB8AA3B
mov.b32 %log2e, 0x3FB8AA3B;
// ln(2) = 0.6931471805... -> 0x3F317218
mov.b32 %ln2, 0x3F317218;
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
// acc = -inf
mov.b32 %acc, 0xFF800000;
mov.u32 %k, 0;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %val, [%addr];
// Numerically stable: m = max(acc, x)
max.f32 %m, %acc, %val;
// exp(acc - m): (acc - m) * log2(e) -> ex2
sub.f32 %ea, %acc, %m;
mul.f32 %ea, %ea, %log2e;
ex2.approx.f32 %ea, %ea;
// exp(x - m): (x - m) * log2(e) -> ex2
sub.f32 %ev, %val, %m;
mul.f32 %ev, %ev, %log2e;
ex2.approx.f32 %ev, %ev;
// sum
add.f32 %s, %ea, %ev;
// log(sum) = lg2(sum) * ln(2)
lg2.approx.f32 %ls, %s;
mul.f32 %ls, %ls, %ln2;
// acc = m + log(sum)
add.f32 %acc, %m, %ls;
add.u64 %addr, %out, %off;
st.global.f32 [%addr], %acc;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const LOGCUMSUMEXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry logcumsumexp_f64_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 outer_size,
.param .u32 dim_size,
.param .u32 inner_size,
.param .u32 total
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %dim_sz, %inner_sz;
.reg .u32 %outer_idx, %inner_idx, %k, %base, %idx, %tmp;
.reg .u64 %in, %out, %off, %addr;
.reg .f64 %val, %acc, %m, %ea, %ev, %s, %ls;
.reg .pred %p, %lp;
.reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf;
.reg .f64 %l_ln2_hi, %l_ln2_lo, %l_sqrt2, %l_half_const;
.reg .pred %p_shift;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %outer_sz, [outer_size];
ld.param.u32 %dim_sz, [dim_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
mul.lo.u32 %tmp, %outer_sz, %inner_sz;
setp.ge.u32 %p, %r_tid, %tmp;
@%p bra DONE;
div.u32 %outer_idx, %r_tid, %inner_sz;
rem.u32 %inner_idx, %r_tid, %inner_sz;
mul.lo.u32 %base, %outer_idx, %dim_sz;
mul.lo.u32 %base, %base, %inner_sz;
add.u32 %base, %base, %inner_idx;
// Seed `acc` from input[0] (#786). The previous `acc = -inf`
// seed reached the inline exp/ln poly with `r = NaN` because
// `cvt.rni.s32.f64(-inf)` saturates to INT_MIN and the
// `(INT_MIN + 1023) << 52` bit-shift trick yields the f64 bits
// for 1.0; the poly drift then accumulated ~+710.19 per element.
// Reading the seed from `input[0]` makes the first iteration's
// exp(acc - m) = exp(0) = 1, well-defined for the polynomial.
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %acc, [%addr];
add.u64 %addr, %out, %off;
st.global.f64 [%addr], %acc;
mov.u32 %k, 1;
SCAN_LOOP:
setp.ge.u32 %lp, %k, %dim_sz;
@%lp bra SCAN_DONE;
mul.lo.u32 %idx, %k, %inner_sz;
add.u32 %idx, %base, %idx;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %val, [%addr];
max.f64 %m, %acc, %val;
mov.f64 %e_one, 0d3FF0000000000000;
mov.f64 %e_half, 0d3FE0000000000000;
// --- inline exp(acc - m) -> %ea ---
sub.f64 %ea, %acc, %m;
mul.f64 %e_nf, %ea, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ea;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
fma.rn.f64 %ea, %e_p, %e_r, %e_one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ea, %ea, %e_nf;
// --- inline exp(val - m) -> %ev ---
sub.f64 %ev, %val, %m;
mul.f64 %e_nf, %ev, 0d3FF71547652B82FE;
cvt.rni.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %ev;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;
fma.rn.f64 %ev, %e_p, %e_r, %e_one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %ev, %ev, %e_nf;
add.f64 %s, %ea, %ev;
// --- inline ln(%s) -> %ls (#783 cluster fix: half-step + degree-7 + 2-double ln2) ---
mov.b64 %l_xbits, %s;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
mov.f64 %l_sqrt2, 0d3FF6A09E667F3BCD;
mov.f64 %l_half_const, 0d3FE0000000000000;
setp.gt.f64 %p_shift, %l_m, %l_sqrt2;
@%p_shift mul.f64 %l_m, %l_m, %l_half_const;
@%p_shift add.f64 %l_nf, %l_nf, %e_one;
sub.f64 %l_f, %l_m, %e_one;
add.f64 %l_s, %l_m, %e_one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
// Degree-7 odd-power Horner (was degree-5; #783 cluster fix).
mov.f64 %l_p, 0d3FB1111111111111; // 1/15
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14; // 1/13
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746; // 1/11
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C; // 1/9
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492; // 1/7
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A; // 1/5
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555; // 1/3
fma.rn.f64 %l_p, %l_p, %l_f2, %e_one;
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
mov.f64 %l_ln2_hi, 0d3FE62E42FEFA3800;
mov.f64 %l_ln2_lo, 0d3D2EF35793C76730;
fma.rn.f64 %ls, %l_nf, %l_ln2_hi, %l_p;
fma.rn.f64 %ls, %l_nf, %l_ln2_lo, %ls;
add.f64 %acc, %m, %ls;
add.u64 %addr, %out, %off;
st.global.f64 [%addr], %acc;
add.u32 %k, %k, 1;
bra SCAN_LOOP;
SCAN_DONE:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry layernorm_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u64 w_ptr,
.param .u64 b_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %out, %w, %b, %row_off, %off, %sbase, %saddr;
.reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %normed, %wv, %bv, %result, %other_val, %n_f;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
mov.f32 %mean, 0f00000000;
mov.u32 %j, %r_tid;
SM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
add.f32 %mean, %mean, %val;
add.u32 %j, %j, %r_bdim;
bra SM;
SMD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %mean;
bar.sync 0;
mov.u32 %half, %r_bdim;
MR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra MRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra MRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %mean, [%saddr];
add.f32 %mean, %mean, %other_val;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %mean;
MRS:
bar.sync 0;
bra MR;
MRD:
ld.shared.f32 %mean, [%sbase];
div.approx.f32 %mean, %mean, %n_f;
bar.sync 0;
mov.f32 %var, 0f00000000;
mov.u32 %j, %r_tid;
SV:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SVD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
sub.f32 %diff, %val, %mean;
fma.rn.f32 %var, %diff, %diff, %var;
add.u32 %j, %j, %r_bdim;
bra SV;
SVD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
bar.sync 0;
mov.u32 %half, %r_bdim;
VR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra VRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra VRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %var, [%saddr];
add.f32 %var, %var, %other_val;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
VRS:
bar.sync 0;
bra VR;
VRD:
ld.shared.f32 %var, [%sbase];
div.approx.f32 %var, %var, %n_f;
add.f32 %var, %var, %eps_r;
sqrt.approx.f32 %inv_std, %var;
rcp.approx.f32 %inv_std, %inv_std;
bar.sync 0;
mov.u32 %j, %r_tid;
NM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra NMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
sub.f32 %normed, %val, %mean;
mul.f32 %normed, %normed, %inv_std;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %wv, [%off];
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %b, %off;
ld.global.f32 %bv, [%off];
fma.rn.f32 %result, %wv, %normed, %bv;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.f32 [%off], %result;
add.u32 %j, %j, %r_bdim;
bra NM;
NMD:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const LAYERNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry layernorm_backward_kernel(
.param .u64 in_ptr,
.param .u64 grad_out_ptr,
.param .u64 w_ptr,
.param .u64 grad_in_ptr,
.param .u64 grad_w_ptr,
.param .u64 grad_b_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %go, %w, %gi, %gw, %gb, %row_off, %off, %sbase, %saddr, %addr;
.reg .f32 %val, %mean, %var, %diff, %eps_r, %inv_std, %x_hat, %wv, %gov;
.reg .f32 %dl_dx_hat, %sum1, %sum2, %other_val, %n_f, %mean1, %mean2, %result;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %go, [grad_out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %gi, [grad_in_ptr];
ld.param.u64 %gw, [grad_w_ptr];
ld.param.u64 %gb, [grad_b_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra LNB_DONE;
// row_off = bid * cols * 4 (byte offset for this row)
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
// ===== Phase 1: Compute mean =====
mov.f32 %mean, 0f00000000;
mov.u32 %j, %r_tid;
LNB_SM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_SMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
add.f32 %mean, %mean, %val;
add.u32 %j, %j, %r_bdim;
bra LNB_SM;
LNB_SMD:
// Shared memory reduce for mean
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %mean;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_MR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_MRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_MRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %mean, [%saddr];
add.f32 %mean, %mean, %other_val;
st.shared.f32 [%saddr], %mean;
LNB_MRS:
bar.sync 0;
bra LNB_MR;
LNB_MRD:
ld.shared.f32 %mean, [%sbase];
div.approx.f32 %mean, %mean, %n_f;
bar.sync 0;
// ===== Phase 2: Compute variance =====
mov.f32 %var, 0f00000000;
mov.u32 %j, %r_tid;
LNB_SV:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_SVD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
sub.f32 %diff, %val, %mean;
fma.rn.f32 %var, %diff, %diff, %var;
add.u32 %j, %j, %r_bdim;
bra LNB_SV;
LNB_SVD:
// Shared memory reduce for variance
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %var;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_VR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_VRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_VRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %var, [%saddr];
add.f32 %var, %var, %other_val;
st.shared.f32 [%saddr], %var;
LNB_VRS:
bar.sync 0;
bra LNB_VR;
LNB_VRD:
ld.shared.f32 %var, [%sbase];
div.approx.f32 %var, %var, %n_f;
add.f32 %var, %var, %eps_r;
sqrt.approx.f32 %inv_std, %var;
rcp.approx.f32 %inv_std, %inv_std;
bar.sync 0;
// ===== Phase 3: Compute sum1 = sum(dl_dx_hat), sum2 = sum(dl_dx_hat * x_hat) =====
// Also accumulate grad_weight and grad_bias via atomicAdd
mov.f32 %sum1, 0f00000000;
mov.f32 %sum2, 0f00000000;
mov.u32 %j, %r_tid;
LNB_S12:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_S12D;
// Load input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
// x_hat = (val - mean) * inv_std
sub.f32 %x_hat, %val, %mean;
mul.f32 %x_hat, %x_hat, %inv_std;
// Load grad_output[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
// Load weight[j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
// dl_dx_hat = grad_output * weight
mul.f32 %dl_dx_hat, %gov, %wv;
// Accumulate sums
add.f32 %sum1, %sum1, %dl_dx_hat;
fma.rn.f32 %sum2, %dl_dx_hat, %x_hat, %sum2;
// atomicAdd grad_weight[j] += grad_output * x_hat
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gw, %off;
mul.f32 %result, %gov, %x_hat;
atom.global.add.f32 %result, [%addr], %result;
// atomicAdd grad_bias[j] += grad_output
add.u64 %addr, %gb, %off;
atom.global.add.f32 %result, [%addr], %gov;
add.u32 %j, %j, %r_bdim;
bra LNB_S12;
LNB_S12D:
// Reduce sum1 in shared memory
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum1;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_R1:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_R1D;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_R1S;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum1, [%saddr];
add.f32 %sum1, %sum1, %other_val;
st.shared.f32 [%saddr], %sum1;
LNB_R1S:
bar.sync 0;
bra LNB_R1;
LNB_R1D:
ld.shared.f32 %sum1, [%sbase];
// mean1 = sum1 / n
div.approx.f32 %mean1, %sum1, %n_f;
bar.sync 0;
// Reduce sum2 in shared memory
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sum2;
bar.sync 0;
mov.u32 %half, %r_bdim;
LNB_R2:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra LNB_R2D;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra LNB_R2S;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sum2, [%saddr];
add.f32 %sum2, %sum2, %other_val;
st.shared.f32 [%saddr], %sum2;
LNB_R2S:
bar.sync 0;
bra LNB_R2;
LNB_R2D:
ld.shared.f32 %sum2, [%sbase];
// mean2 = sum2 / n
div.approx.f32 %mean2, %sum2, %n_f;
bar.sync 0;
// ===== Phase 4: Compute grad_input =====
// grad_input[j] = inv_std * (dl_dx_hat[j] - mean1 - x_hat[j] * mean2)
mov.u32 %j, %r_tid;
LNB_GI:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra LNB_GID;
// Reload input to recompute x_hat
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
sub.f32 %x_hat, %val, %mean;
mul.f32 %x_hat, %x_hat, %inv_std;
// Reload grad_output and weight to recompute dl_dx_hat
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
mul.f32 %dl_dx_hat, %gov, %wv;
// result = inv_std * (dl_dx_hat - mean1 - x_hat * mean2)
sub.f32 %result, %dl_dx_hat, %mean1;
mul.f32 %diff, %x_hat, %mean2;
sub.f32 %result, %result, %diff;
mul.f32 %result, %inv_std, %result;
// Store grad_input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gi, %off;
add.u64 %addr, %addr, %row_off;
st.global.f32 [%addr], %result;
add.u32 %j, %j, %r_bdim;
bra LNB_GI;
LNB_GID:
LNB_DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry rmsnorm_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u64 w_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %out, %w, %row_off, %off, %sbase, %saddr;
.reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %wv, %result, %other_val, %n_f;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra DONE;
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
// ===== Phase 1: Compute sum(x^2) =====
mov.f32 %sq_sum, 0f00000000;
mov.u32 %j, %r_tid;
SS:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra SSD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
add.u32 %j, %j, %r_bdim;
bra SS;
SSD:
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
bar.sync 0;
mov.u32 %half, %r_bdim;
SR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra SRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra SRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sq_sum, [%saddr];
add.f32 %sq_sum, %sq_sum, %other_val;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
SRS:
bar.sync 0;
bra SR;
SRD:
ld.shared.f32 %sq_sum, [%sbase];
div.approx.f32 %sq_sum, %sq_sum, %n_f;
add.f32 %sq_sum, %sq_sum, %eps_r;
sqrt.approx.f32 %inv_rms, %sq_sum;
rcp.approx.f32 %inv_rms, %inv_rms;
bar.sync 0;
// ===== Phase 2: Normalize and scale =====
// out[j] = x[j] * inv_rms * weight[j]
mov.u32 %j, %r_tid;
NM:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra NMD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
add.u64 %off, %off, %row_off;
ld.global.f32 %val, [%off];
mul.f32 %result, %val, %inv_rms;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %w, %off;
ld.global.f32 %wv, [%off];
mul.f32 %result, %result, %wv;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
add.u64 %off, %off, %row_off;
st.global.f32 [%off], %result;
add.u32 %j, %j, %r_bdim;
bra NM;
NMD:
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const RMSNORM_BACKWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.shared .align 4 .f32 sdata[256];
.visible .entry rmsnorm_backward_kernel(
.param .u64 in_ptr,
.param .u64 grad_out_ptr,
.param .u64 w_ptr,
.param .u64 grad_in_ptr,
.param .u64 grad_w_ptr,
.param .u32 rows,
.param .u32 cols,
.param .f32 eps
) {
.reg .u32 %r_tid, %r_bid, %r_bdim, %rows_reg, %cols_reg, %j, %half, %r_otid;
.reg .u64 %in, %go, %w, %gi, %gw, %row_off, %off, %sbase, %saddr, %addr;
.reg .f32 %val, %sq_sum, %eps_r, %inv_rms, %inv_rms3, %wv, %gov;
.reg .f32 %dot, %other_val, %n_f, %coeff, %result, %tmp;
.reg .pred %p, %lp, %rp;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %go, [grad_out_ptr];
ld.param.u64 %w, [w_ptr];
ld.param.u64 %gi, [grad_in_ptr];
ld.param.u64 %gw, [grad_w_ptr];
ld.param.u32 %rows_reg, [rows];
ld.param.u32 %cols_reg, [cols];
ld.param.f32 %eps_r, [eps];
mov.u64 %sbase, sdata;
mov.u32 %r_bid, %ctaid.x;
mov.u32 %r_bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
setp.ge.u32 %p, %r_bid, %rows_reg;
@%p bra RNB_DONE;
// row_off = bid * cols * 4 (byte offset for this row)
cvt.u64.u32 %row_off, %r_bid;
cvt.u64.u32 %off, %cols_reg;
mul.lo.u64 %row_off, %row_off, %off;
shl.b64 %row_off, %row_off, 2;
cvt.rn.f32.u32 %n_f, %cols_reg;
// ===== Phase 1: Compute sum(x^2) -> inv_rms =====
mov.f32 %sq_sum, 0f00000000;
mov.u32 %j, %r_tid;
RNB_SS:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra RNB_SSD;
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
fma.rn.f32 %sq_sum, %val, %val, %sq_sum;
add.u32 %j, %j, %r_bdim;
bra RNB_SS;
RNB_SSD:
// Shared memory reduce for sum(x^2)
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %sq_sum;
bar.sync 0;
mov.u32 %half, %r_bdim;
RNB_SR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra RNB_SRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra RNB_SRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %sq_sum, [%saddr];
add.f32 %sq_sum, %sq_sum, %other_val;
st.shared.f32 [%saddr], %sq_sum;
RNB_SRS:
bar.sync 0;
bra RNB_SR;
RNB_SRD:
ld.shared.f32 %sq_sum, [%sbase];
div.approx.f32 %sq_sum, %sq_sum, %n_f;
add.f32 %sq_sum, %sq_sum, %eps_r;
sqrt.approx.f32 %inv_rms, %sq_sum;
rcp.approx.f32 %inv_rms, %inv_rms;
// inv_rms3 = inv_rms^3 = inv_rms * inv_rms * inv_rms
mul.f32 %inv_rms3, %inv_rms, %inv_rms;
mul.f32 %inv_rms3, %inv_rms3, %inv_rms;
bar.sync 0;
// ===== Phase 2: Compute dot = sum(go[j] * x[j] * w[j]) =====
// Also accumulate grad_weight via atomicAdd
mov.f32 %dot, 0f00000000;
mov.u32 %j, %r_tid;
RNB_DOT:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra RNB_DOTD;
// Load input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
// Load grad_output[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
// Load weight[j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
// dot += go * x * w
mul.f32 %tmp, %gov, %val;
fma.rn.f32 %dot, %tmp, %wv, %dot;
// atomicAdd grad_weight[j] += go * x * inv_rms
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gw, %off;
mul.f32 %result, %gov, %val;
mul.f32 %result, %result, %inv_rms;
atom.global.add.f32 %result, [%addr], %result;
add.u32 %j, %j, %r_bdim;
bra RNB_DOT;
RNB_DOTD:
// Reduce dot in shared memory
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
st.shared.f32 [%saddr], %dot;
bar.sync 0;
mov.u32 %half, %r_bdim;
RNB_DR:
shr.u32 %half, %half, 1;
setp.eq.u32 %rp, %half, 0;
@%rp bra RNB_DRD;
setp.ge.u32 %rp, %r_tid, %half;
@%rp bra RNB_DRS;
add.u32 %r_otid, %r_tid, %half;
cvt.u64.u32 %off, %r_otid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %other_val, [%saddr];
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %saddr, %sbase, %off;
ld.shared.f32 %dot, [%saddr];
add.f32 %dot, %dot, %other_val;
st.shared.f32 [%saddr], %dot;
RNB_DRS:
bar.sync 0;
bra RNB_DR;
RNB_DRD:
ld.shared.f32 %dot, [%sbase];
// coeff = dot * inv_rms3 / n
mul.f32 %coeff, %dot, %inv_rms3;
div.approx.f32 %coeff, %coeff, %n_f;
bar.sync 0;
// ===== Phase 3: Compute grad_input =====
// grad_input[j] = inv_rms * w[j] * go[j] - x[j] * coeff
mov.u32 %j, %r_tid;
RNB_GI:
setp.ge.u32 %lp, %j, %cols_reg;
@%lp bra RNB_GID;
// Reload input
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %val, [%addr];
// Reload grad_output and weight
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %go, %off;
add.u64 %addr, %addr, %row_off;
ld.global.f32 %gov, [%addr];
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %w, %off;
ld.global.f32 %wv, [%addr];
// result = inv_rms * w * go - x * coeff
mul.f32 %result, %inv_rms, %wv;
mul.f32 %result, %result, %gov;
mul.f32 %tmp, %val, %coeff;
sub.f32 %result, %result, %tmp;
// Store grad_input[row, j]
cvt.u64.u32 %off, %j;
shl.b64 %off, %off, 2;
add.u64 %addr, %gi, %off;
add.u64 %addr, %addr, %row_off;
st.global.f32 [%addr], %result;
add.u32 %j, %j, %r_bdim;
bra RNB_GI;
RNB_GID:
RNB_DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const BATCHNORM_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
// smem_sum[256] holds per-thread partial sums in pass 1, then the
// broadcast mean at index 0 after reduction; smem_sq[256] likewise holds
// partial sums of squares, then the broadcast invstd at index 0.
.shared .align 4 .f32 smem_sum[256];
.shared .align 4 .f32 smem_sq[256];
.visible .entry batchnorm_forward_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u64 weight_ptr,
.param .u64 bias_ptr,
.param .u64 rmean_ptr,
.param .u64 rvar_ptr,
.param .u64 save_mean_ptr,
.param .u64 save_invstd_ptr,
.param .u32 channels,
.param .u32 spatial,
.param .f32 eps,
.param .f32 momentum,
.param .u32 total_per_ch,
.param .u32 training
) {
// k_loop = loop counter over [0, total_per_ch) for this thread
// flat = linearised NCHW index for the current k_loop value
// bidx, sidx = batch-element index and spatial index derived from k_loop
// sbase_sum/sbase_sq = base addresses of the two smem arrays held in
// registers, so we can do [base + offset] with a register offset.
.reg .u32 %my_tid, %bid, %bdim, %ch, %n_ch, %sp, %tpc, %train;
.reg .u32 %k_loop, %flat, %bidx, %sidx;
.reg .u32 %half, %peer;
.reg .u64 %in, %out, %w, %b, %rm, %rv, %sm, %si, %off64, %tmp64;
.reg .u64 %sbase_sum, %sbase_sq, %saddr_tid, %saddr_peer, %saddr_sq_tid;
.reg .f32 %sum, %sqsum, %val, %mean, %var, %invstd;
.reg .f32 %gamma, %beta, %eps_reg, %mom, %other;
.reg .f32 %n_f, %one, %normalized, %rm_old, %rv_old, %one_m_mom;
.reg .pred %p, %ptrain, %ptid0;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u64 %w, [weight_ptr];
ld.param.u64 %b, [bias_ptr];
ld.param.u64 %rm, [rmean_ptr];
ld.param.u64 %rv, [rvar_ptr];
ld.param.u64 %sm, [save_mean_ptr];
ld.param.u64 %si, [save_invstd_ptr];
ld.param.u32 %n_ch, [channels];
ld.param.u32 %sp, [spatial];
ld.param.f32 %eps_reg, [eps];
ld.param.f32 %mom, [momentum];
ld.param.u32 %tpc, [total_per_ch];
ld.param.u32 %train, [training];
mov.u32 %bid, %ctaid.x;
mov.u32 %my_tid, %tid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %ch, %bid;
mov.f32 %one, 0f3F800000;
// Load smem array base addresses into registers.
// PTX requires [reg] or [reg + imm_offset] for shared accesses;
// [label + reg_offset] is not valid syntax. Use mov.u64 to get
// the base address into a register, then add the byte offset.
mov.u64 %sbase_sum, smem_sum;
mov.u64 %sbase_sq, smem_sq;
// Per-thread slot address: saddr_tid = sbase_sum + tid*4
cvt.u64.u32 %off64, %my_tid;
shl.b64 %off64, %off64, 2;
add.u64 %saddr_tid, %sbase_sum, %off64;
add.u64 %saddr_sq_tid, %sbase_sq, %off64;
// Out-of-range block guard
setp.ge.u32 %p, %ch, %n_ch;
@%p bra END;
setp.ne.u32 %ptrain, %train, 0;
// ===== Pass 1: accumulate per-thread sum and sum-of-squares =====
mov.f32 %sum, 0f00000000;
mov.f32 %sqsum, 0f00000000;
// Block-stride over k_loop in [0, tpc), step = bdim (one CTA per channel).
// k_loop is the loop counter; flat is recomputed fresh each iteration
// so the loop counter is never clobbered by address arithmetic.
mov.u32 %k_loop, %my_tid;
PASS1_LOOP:
setp.ge.u32 %p, %k_loop, %tpc;
@%p bra PASS1_DONE;
// Flat NCHW index:
// k_loop in [0, N*H*W) for this channel.
// bidx = k_loop / spatial (batch element)
// sidx = k_loop % spatial (spatial element)
// flat = bidx * (C * spatial) + ch * spatial + sidx
div.u32 %bidx, %k_loop, %sp;
rem.u32 %sidx, %k_loop, %sp;
mul.lo.u32 %flat, %bidx, %n_ch;
add.u32 %flat, %flat, %ch;
mul.lo.u32 %flat, %flat, %sp;
add.u32 %flat, %flat, %sidx;
cvt.u64.u32 %off64, %flat;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %val, [%tmp64];
add.f32 %sum, %sum, %val;
fma.rn.f32 %sqsum, %val, %val, %sqsum;
add.u32 %k_loop, %k_loop, %bdim;
bra PASS1_LOOP;
PASS1_DONE:
// Store partial sums to smem slots for this thread
st.shared.f32 [%saddr_tid], %sum;
st.shared.f32 [%saddr_sq_tid], %sqsum;
bar.sync 0;
// Tree reduction: start at half = bdim/2 so peer indices are always in
// [0, bdim) regardless of block size. A hardcoded start of half = 128
// caused out-of-bounds smem reads when bdim < 256 (e.g. bdim=16 =>
// half=128 reads smem[16..143], which is uninitialized, producing NaN).
shr.u32 %half, %bdim, 1;
REDUCE_LOOP:
setp.lt.u32 %p, %half, 1;
@%p bra REDUCE_DONE;
setp.ge.u32 %p, %my_tid, %half;
@%p bra REDUCE_SKIP;
// sum: load peer slot, load my slot, add, store back to my slot
add.u32 %peer, %my_tid, %half;
cvt.u64.u32 %off64, %peer;
shl.b64 %off64, %off64, 2;
add.u64 %saddr_peer, %sbase_sum, %off64;
ld.shared.f32 %other, [%saddr_peer];
ld.shared.f32 %sum, [%saddr_tid];
add.f32 %sum, %sum, %other;
st.shared.f32 [%saddr_tid], %sum;
// sqsum: same pattern using sbase_sq
add.u64 %saddr_peer, %sbase_sq, %off64;
ld.shared.f32 %other, [%saddr_peer];
ld.shared.f32 %sqsum, [%saddr_sq_tid];
add.f32 %sqsum, %sqsum, %other;
st.shared.f32 [%saddr_sq_tid], %sqsum;
REDUCE_SKIP:
bar.sync 0;
shr.u32 %half, %half, 1;
bra REDUCE_LOOP;
REDUCE_DONE:
// Thread 0 computes mean and invstd (or reads them from running stats)
setp.ne.u32 %ptid0, %my_tid, 0;
@%ptid0 bra WAIT_STATS;
// Channel byte offset for scalar per-channel arrays
cvt.u64.u32 %off64, %ch;
shl.b64 %off64, %off64, 2;
@!%ptrain bra USE_RUNNING_STATS;
// --- training mode: compute mean and biased variance from data ---
ld.shared.f32 %sum, [%sbase_sum];
ld.shared.f32 %sqsum, [%sbase_sq];
cvt.rn.f32.u32 %n_f, %tpc;
div.rn.f32 %mean, %sum, %n_f;
// biased_var = E[x^2] - mean^2 = sqsum/n - mean^2
div.rn.f32 %var, %sqsum, %n_f;
neg.f32 %other, %mean;
fma.rn.f32 %var, %other, %mean, %var;
// invstd = 1 / sqrt(var + eps)
add.f32 %other, %var, %eps_reg;
sqrt.rn.f32 %other, %other;
div.rn.f32 %invstd, %one, %other;
// Save mean and invstd for backward
add.u64 %tmp64, %sm, %off64;
st.global.f32 [%tmp64], %mean;
add.u64 %tmp64, %si, %off64;
st.global.f32 [%tmp64], %invstd;
// Update running_mean: rm = (1-mom)*rm + mom*mean
sub.f32 %one_m_mom, %one, %mom;
add.u64 %tmp64, %rm, %off64;
ld.global.f32 %rm_old, [%tmp64];
mul.f32 %other, %one_m_mom, %rm_old;
fma.rn.f32 %other, %mom, %mean, %other;
st.global.f32 [%tmp64], %other;
// Update running_var: rv = (1-mom)*rv + mom*var
add.u64 %tmp64, %rv, %off64;
ld.global.f32 %rv_old, [%tmp64];
mul.f32 %other, %one_m_mom, %rv_old;
fma.rn.f32 %other, %mom, %var, %other;
st.global.f32 [%tmp64], %other;
bra STATS_DONE;
USE_RUNNING_STATS:
// --- inference mode: read running_mean and running_var ---
add.u64 %tmp64, %rm, %off64;
ld.global.f32 %mean, [%tmp64];
add.u64 %tmp64, %rv, %off64;
ld.global.f32 %var, [%tmp64];
// invstd = 1 / sqrt(var + eps)
add.f32 %other, %var, %eps_reg;
sqrt.rn.f32 %other, %other;
div.rn.f32 %invstd, %one, %other;
STATS_DONE:
// Broadcast mean/invstd to smem slot [0] so all threads can read them
st.shared.f32 [%sbase_sum], %mean;
st.shared.f32 [%sbase_sq], %invstd;
WAIT_STATS:
bar.sync 0;
ld.shared.f32 %mean, [%sbase_sum];
ld.shared.f32 %invstd, [%sbase_sq];
// Load gamma (weight) and beta (bias) for this channel
cvt.u64.u32 %off64, %ch;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %w, %off64;
ld.global.f32 %gamma, [%tmp64];
add.u64 %tmp64, %b, %off64;
ld.global.f32 %beta, [%tmp64];
// ===== Pass 2: normalize + affine, same block-stride walk as pass 1 =====
mov.u32 %k_loop, %my_tid;
PASS2_LOOP:
setp.ge.u32 %p, %k_loop, %tpc;
@%p bra PASS2_DONE;
// Reconstruct flat index (same formula as pass 1)
div.u32 %bidx, %k_loop, %sp;
rem.u32 %sidx, %k_loop, %sp;
mul.lo.u32 %flat, %bidx, %n_ch;
add.u32 %flat, %flat, %ch;
mul.lo.u32 %flat, %flat, %sp;
add.u32 %flat, %flat, %sidx;
cvt.u64.u32 %off64, %flat;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %val, [%tmp64];
// normalized = (val - mean) * invstd
sub.f32 %normalized, %val, %mean;
mul.f32 %normalized, %normalized, %invstd;
// output = gamma * normalized + beta
fma.rn.f32 %normalized, %gamma, %normalized, %beta;
add.u64 %tmp64, %out, %off64;
st.global.f32 [%tmp64], %normalized;
add.u32 %k_loop, %k_loop, %bdim;
bra PASS2_LOOP;
PASS2_DONE:
END:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const MAXPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry maxpool2d_forward_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 batch,
.param .u32 channels,
.param .u32 h_in,
.param .u32 w_in,
.param .u32 h_out,
.param .u32 w_out,
.param .u32 kh,
.param .u32 kw,
.param .u32 sh,
.param .u32 sw,
.param .u32 ph,
.param .u32 pw,
.param .u32 total
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
.reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp;
.reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
.reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
.reg .u32 %batch_reg, %ch_reg;
.reg .u64 %in, %out, %off64, %tmp64;
.reg .f32 %max_val, %cur_val, %neg_inf;
.reg .pred %p, %p_bounds, %p_gt;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %batch_reg, [batch];
ld.param.u32 %ch_reg, [channels];
ld.param.u32 %h_in_reg, [h_in];
ld.param.u32 %w_in_reg, [w_in];
ld.param.u32 %h_out_reg, [h_out];
ld.param.u32 %w_out_reg, [w_out];
ld.param.u32 %kh_reg, [kh];
ld.param.u32 %kw_reg, [kw];
ld.param.u32 %sh_reg, [sh];
ld.param.u32 %sw_reg, [sw];
ld.param.u32 %ph_reg, [ph];
ld.param.u32 %pw_reg, [pw];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %my_tid, %tid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
// -inf for max initialization
mov.f32 %neg_inf, 0fFF800000;
LOOP:
setp.ge.u32 %p, %idx, %total_reg;
@%p bra END;
// Decompose idx into (b, c, oh, ow):
// idx = b * C * H_out * W_out + c * H_out * W_out + oh * W_out + ow,
// so peel each component off from the right.
mov.u32 %rem, %idx;
rem.u32 %ow, %rem, %w_out_reg;
div.u32 %rem, %rem, %w_out_reg;
rem.u32 %oh, %rem, %h_out_reg;
div.u32 %rem, %rem, %h_out_reg;
rem.u32 %c_idx, %rem, %ch_reg;
div.u32 %b_idx, %rem, %ch_reg;
mov.f32 %max_val, %neg_inf;
// Slide the kernel window
mov.u32 %i, 0;
KH_LOOP:
setp.ge.u32 %p, %i, %kh_reg;
@%p bra KH_DONE;
mov.u32 %j, 0;
KW_LOOP:
setp.ge.u32 %p, %j, %kw_reg;
@%p bra KW_DONE;
// ih = oh * sh + i - ph, iw = ow * sw + j - pw
mad.lo.u32 %ih, %oh, %sh_reg, %i;
sub.u32 %ih, %ih, %ph_reg;
mad.lo.u32 %iw, %ow, %sw_reg, %j;
sub.u32 %iw, %iw, %pw_reg;
// Bounds check: 0 <= ih < h_in && 0 <= iw < w_in
// Since unsigned, just check < h_in and < w_in
setp.ge.u32 %p_bounds, %ih, %h_in_reg;
@%p_bounds bra KW_NEXT;
setp.ge.u32 %p_bounds, %iw, %w_in_reg;
@%p_bounds bra KW_NEXT;
// input_offset = b * C * H * W + c * H * W + ih * W + iw
mul.lo.u32 %tmp, %b_idx, %ch_reg;
add.u32 %tmp, %tmp, %c_idx;
mul.lo.u32 %tmp, %tmp, %h_in_reg;
add.u32 %tmp, %tmp, %ih;
mul.lo.u32 %tmp, %tmp, %w_in_reg;
add.u32 %tmp, %tmp, %iw;
cvt.u64.u32 %off64, %tmp;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %cur_val, [%tmp64];
max.f32 %max_val, %max_val, %cur_val;
KW_NEXT:
add.u32 %j, %j, 1;
bra KW_LOOP;
KW_DONE:
add.u32 %i, %i, 1;
bra KH_LOOP;
KH_DONE:
// Store output
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %out, %off64;
st.global.f32 [%tmp64], %max_val;
add.u32 %idx, %idx, %stride;
bra LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const AVGPOOL2D_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry avgpool2d_forward_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 batch,
.param .u32 channels,
.param .u32 h_in,
.param .u32 w_in,
.param .u32 h_out,
.param .u32 w_out,
.param .u32 kh,
.param .u32 kw,
.param .u32 sh,
.param .u32 sw,
.param .u32 ph,
.param .u32 pw,
.param .u32 total
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %idx, %stride, %total_reg;
.reg .u32 %b_idx, %c_idx, %oh, %ow, %rem, %ih, %iw, %tmp, %count;
.reg .u32 %i, %j, %h_in_reg, %w_in_reg, %kh_reg, %kw_reg;
.reg .u32 %sh_reg, %sw_reg, %ph_reg, %pw_reg, %h_out_reg, %w_out_reg;
.reg .u32 %batch_reg, %ch_reg;
.reg .u64 %in, %out, %off64, %tmp64;
.reg .f32 %sum_val, %cur_val, %count_f, %avg;
.reg .pred %p, %p_bounds;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %batch_reg, [batch];
ld.param.u32 %ch_reg, [channels];
ld.param.u32 %h_in_reg, [h_in];
ld.param.u32 %w_in_reg, [w_in];
ld.param.u32 %h_out_reg, [h_out];
ld.param.u32 %w_out_reg, [w_out];
ld.param.u32 %kh_reg, [kh];
ld.param.u32 %kw_reg, [kw];
ld.param.u32 %sh_reg, [sh];
ld.param.u32 %sw_reg, [sw];
ld.param.u32 %ph_reg, [ph];
ld.param.u32 %pw_reg, [pw];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %my_tid, %tid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
LOOP:
setp.ge.u32 %p, %idx, %total_reg;
@%p bra END;
// Decompose idx into (b, c, oh, ow) -- same as MaxPool2d
mov.u32 %rem, %idx;
rem.u32 %ow, %rem, %w_out_reg;
div.u32 %rem, %rem, %w_out_reg;
rem.u32 %oh, %rem, %h_out_reg;
div.u32 %rem, %rem, %h_out_reg;
rem.u32 %c_idx, %rem, %ch_reg;
div.u32 %b_idx, %rem, %ch_reg;
mov.f32 %sum_val, 0f00000000;
mov.u32 %count, 0;
mov.u32 %i, 0;
AKH_LOOP:
setp.ge.u32 %p, %i, %kh_reg;
@%p bra AKH_DONE;
mov.u32 %j, 0;
AKW_LOOP:
setp.ge.u32 %p, %j, %kw_reg;
@%p bra AKW_DONE;
mad.lo.u32 %ih, %oh, %sh_reg, %i;
sub.u32 %ih, %ih, %ph_reg;
mad.lo.u32 %iw, %ow, %sw_reg, %j;
sub.u32 %iw, %iw, %pw_reg;
setp.ge.u32 %p_bounds, %ih, %h_in_reg;
@%p_bounds bra AKW_NEXT;
setp.ge.u32 %p_bounds, %iw, %w_in_reg;
@%p_bounds bra AKW_NEXT;
mul.lo.u32 %tmp, %b_idx, %ch_reg;
add.u32 %tmp, %tmp, %c_idx;
mul.lo.u32 %tmp, %tmp, %h_in_reg;
add.u32 %tmp, %tmp, %ih;
mul.lo.u32 %tmp, %tmp, %w_in_reg;
add.u32 %tmp, %tmp, %iw;
cvt.u64.u32 %off64, %tmp;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %in, %off64;
ld.global.f32 %cur_val, [%tmp64];
add.f32 %sum_val, %sum_val, %cur_val;
add.u32 %count, %count, 1;
AKW_NEXT:
add.u32 %j, %j, 1;
bra AKW_LOOP;
AKW_DONE:
add.u32 %i, %i, 1;
bra AKH_LOOP;
AKH_DONE:
// avg = sum / count (count_include_pad = false behavior)
cvt.rn.f32.u32 %count_f, %count;
div.rn.f32 %avg, %sum_val, %count_f;
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %out, %off64;
st.global.f32 [%tmp64], %avg;
add.u32 %idx, %idx, %stride;
bra LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata[256];\n\
\n\
.visible .entry softmax_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f32 %other_val;\n\
.reg .pred %reduce_p;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 2;\n\
\n\
mov.f32 %max_val, 0fFF800000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
max.f32 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %max_val, [%saddr];\n\
max.f32 %max_val, %max_val, %other_val;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
ld.shared.f32 %max_val, [sdata];\n\
bar.sync 0;\n\
\n\
mov.f32 %sum_val, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
sub.f32 %val, %val, %max_val;\n\
mul.f32 %val, %val, 0f3FB8AA3B;\n\
ex2.approx.f32 %exp_val, %val;\n\
add.f32 %sum_val, %sum_val, %exp_val;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
st.global.f32 [%off], %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %sum_val, [%saddr];\n\
add.f32 %sum_val, %sum_val, %other_val;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
ld.shared.f32 %sum_val, [sdata];\n\
bar.sync 0;\n\
\n\
rcp.approx.f32 %sum_val, %sum_val;\n\
mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra NORMALIZE_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f32 %val, [%off];\n\
mul.f32 %result, %val, %sum_val;\n\
st.global.f32 [%off], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_F64_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 8 .f64 sdata[256];\n\
\n\
.visible .entry softmax_f64_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off, %off, %sbase, %saddr;\n\
.reg .f64 %val, %max_val, %sum_val, %exp_val, %result, %one;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f64 %other_val;\n\
.reg .pred %reduce_p;\n\
.reg .f64 %e_nf, %e_r, %e_p, %e_half, %e_one;\n\
.reg .s32 %e_ni;\n\
.reg .s64 %e_ni64, %e_bits;\n\
.reg .f64 %neg_inf;\n\
.reg .pred %p_underflow;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata;\n\
mov.f64 %one, 0d3FF0000000000000;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off, %row_off, %off;\n\
shl.b64 %row_off, %row_off, 3;\n\
\n\
mov.f64 %max_val, 0dFFF0000000000000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
max.f64 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX;\n\
FIND_MAX_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %max_val, [%saddr];\n\
max.f64 %max_val, %max_val, %other_val;\n\
st.shared.f64 [%saddr], %max_val;\n\
MAX_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE;\n\
MAX_REDUCE_DONE:\n\
\n\
ld.shared.f64 %max_val, [sdata];\n\
bar.sync 0;\n\
\n\
mov.f64 %sum_val, 0d0000000000000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
sub.f64 %val, %val, %max_val;\n\
// Underflow guard: when `val == -inf` (e.g. from a `-inf` mask\n\
// addend used by causal/block_diag/empty_mask attention masks), the\n\
// inlined `2^x` polynomial below converts -inf -> s32 via\n\
// `cvt.rni.s32.f64`, which produces an out-of-range value that the\n\
// subsequent `add+shl+bitcast` reinterprets as a NaN-bit-pattern,\n\
// poisoning the row sum. Detect val <= -inf (an ordered compare that\n\
// can only fire when val == -inf) and short-circuit to exp_val = 0.0;\n\
// this matches the hardware `ex2.approx.f32` semantics that the f32\n\
// softmax kernel gets for free.\n\
mov.f64 %neg_inf, 0dFFF0000000000000;\n\
setp.le.f64 %p_underflow, %val, %neg_inf;\n\
@%p_underflow bra SUM_EXP_UNDERFLOW;\n\
mov.f64 %e_one, 0d3FF0000000000000;\n\
mov.f64 %e_half, 0d3FE0000000000000;\n\
mul.f64 %e_nf, %val, 0d3FF71547652B82FE;\n\
cvt.rni.f64.f64 %e_nf, %e_nf;\n\
cvt.rni.s32.f64 %e_ni, %e_nf;\n\
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %val;\n\
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;\n\
mov.f64 %e_p, 0d3E5AE64567F544E4;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;\n\
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;\n\
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;\n\
fma.rn.f64 %e_p, %e_p, %e_r, %e_one;\n\
fma.rn.f64 %exp_val, %e_p, %e_r, %e_one;\n\
cvt.s64.s32 %e_ni64, %e_ni;\n\
add.s64 %e_ni64, %e_ni64, 1023;\n\
shl.b64 %e_bits, %e_ni64, 52;\n\
mov.b64 %e_nf, %e_bits;\n\
mul.f64 %exp_val, %exp_val, %e_nf;\n\
bra SUM_EXP_STORE;\n\
SUM_EXP_UNDERFLOW:\n\
mov.f64 %exp_val, 0d0000000000000000;\n\
SUM_EXP_STORE:\n\
add.f64 %sum_val, %sum_val, %exp_val;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
st.global.f64 [%off], %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP;\n\
SUM_EXP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f64 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f64 %sum_val, [%saddr];\n\
add.f64 %sum_val, %sum_val, %other_val;\n\
st.shared.f64 [%saddr], %sum_val;\n\
SUM_REDUCE_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE;\n\
SUM_REDUCE_DONE:\n\
\n\
ld.shared.f64 %sum_val, [sdata];\n\
bar.sync 0;\n\
\n\
div.rn.f64 %sum_val, %one, %sum_val;\n\
mov.u32 %j, %r_tid;\n\
NORMALIZE:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra NORMALIZE_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 3;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off;\n\
ld.global.f64 %val, [%off];\n\
mul.f64 %result, %val, %sum_val;\n\
st.global.f64 [%off], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra NORMALIZE;\n\
NORMALIZE_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const SOFTMAX_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.shared .align 4 .f32 sdata_bf16sm[256];\n\
\n\
.visible .entry softmax_bf16_f32_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 rows,\n\
.param .u32 cols\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %rows_reg, %cols_reg, %j;\n\
.reg .u64 %in, %out, %row_off_in, %row_off_out, %off, %sbase, %saddr;\n\
.reg .u16 %bf16_raw;\n\
.reg .u32 %bf16_bits;\n\
.reg .f32 %val, %max_val, %sum_val, %exp_val, %result;\n\
.reg .pred %p, %loop_p;\n\
.reg .u32 %half, %other_tid;\n\
.reg .f32 %other_val;\n\
.reg .pred %reduce_p;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %rows_reg, [rows];\n\
ld.param.u32 %cols_reg, [cols];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mov.u64 %sbase, sdata_bf16sm;\n\
\n\
setp.ge.u32 %p, %bid, %rows_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %row_off_in, %bid;\n\
cvt.u64.u32 %off, %cols_reg;\n\
mul.lo.u64 %row_off_in, %row_off_in, %off;\n\
shl.b64 %row_off_in, %row_off_in, 1;\n\
\n\
cvt.u64.u32 %row_off_out, %bid;\n\
mul.lo.u64 %row_off_out, %row_off_out, %off;\n\
shl.b64 %row_off_out, %row_off_out, 2;\n\
\n\
mov.f32 %max_val, 0fFF800000;\n\
mov.u32 %j, %r_tid;\n\
FIND_MAX_BF16:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra FIND_MAX_BF16_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off_in;\n\
ld.global.u16 %bf16_raw, [%off];\n\
mov.b32 %bf16_bits, 0;\n\
cvt.u32.u16 %bf16_bits, %bf16_raw;\n\
shl.b32 %bf16_bits, %bf16_bits, 16;\n\
mov.b32 %val, %bf16_bits;\n\
max.f32 %max_val, %max_val, %val;\n\
add.u32 %j, %j, %bdim;\n\
bra FIND_MAX_BF16;\n\
FIND_MAX_BF16_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %max_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
MAX_REDUCE_BF16:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra MAX_REDUCE_BF16_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra MAX_REDUCE_BF16_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %max_val, [%saddr];\n\
max.f32 %max_val, %max_val, %other_val;\n\
st.shared.f32 [%saddr], %max_val;\n\
MAX_REDUCE_BF16_SKIP:\n\
bar.sync 0;\n\
bra MAX_REDUCE_BF16;\n\
MAX_REDUCE_BF16_DONE:\n\
\n\
ld.shared.f32 %max_val, [sdata_bf16sm];\n\
bar.sync 0;\n\
\n\
mov.f32 %sum_val, 0f00000000;\n\
mov.u32 %j, %r_tid;\n\
SUM_EXP_BF16:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra SUM_EXP_BF16_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %off, %in, %off;\n\
add.u64 %off, %off, %row_off_in;\n\
ld.global.u16 %bf16_raw, [%off];\n\
cvt.u32.u16 %bf16_bits, %bf16_raw;\n\
shl.b32 %bf16_bits, %bf16_bits, 16;\n\
mov.b32 %val, %bf16_bits;\n\
sub.f32 %val, %val, %max_val;\n\
mul.f32 %val, %val, 0f3FB8AA3B;\n\
ex2.approx.f32 %exp_val, %val;\n\
add.f32 %sum_val, %sum_val, %exp_val;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off_out;\n\
st.global.f32 [%off], %exp_val;\n\
add.u32 %j, %j, %bdim;\n\
bra SUM_EXP_BF16;\n\
SUM_EXP_BF16_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
st.shared.f32 [%saddr], %sum_val;\n\
bar.sync 0;\n\
\n\
mov.u32 %half, %bdim;\n\
SUM_REDUCE_BF16:\n\
shr.u32 %half, %half, 1;\n\
setp.eq.u32 %reduce_p, %half, 0;\n\
@%reduce_p bra SUM_REDUCE_BF16_DONE;\n\
setp.ge.u32 %reduce_p, %r_tid, %half;\n\
@%reduce_p bra SUM_REDUCE_BF16_SKIP;\n\
add.u32 %other_tid, %r_tid, %half;\n\
cvt.u64.u32 %off, %other_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %other_val, [%saddr];\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %saddr, %sbase, %off;\n\
ld.shared.f32 %sum_val, [%saddr];\n\
add.f32 %sum_val, %sum_val, %other_val;\n\
st.shared.f32 [%saddr], %sum_val;\n\
SUM_REDUCE_BF16_SKIP:\n\
bar.sync 0;\n\
bra SUM_REDUCE_BF16;\n\
SUM_REDUCE_BF16_DONE:\n\
\n\
ld.shared.f32 %sum_val, [sdata_bf16sm];\n\
bar.sync 0;\n\
\n\
rcp.approx.f32 %sum_val, %sum_val;\n\
mov.u32 %j, %r_tid;\n\
NORMALIZE_BF16:\n\
setp.ge.u32 %loop_p, %j, %cols_reg;\n\
@%loop_p bra NORMALIZE_BF16_DONE;\n\
cvt.u64.u32 %off, %j;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %off, %out, %off;\n\
add.u64 %off, %off, %row_off_out;\n\
ld.global.f32 %val, [%off];\n\
mul.f32 %result, %val, %sum_val;\n\
st.global.f32 [%off], %result;\n\
add.u32 %j, %j, %bdim;\n\
bra NORMALIZE_BF16;\n\
NORMALIZE_BF16_DONE:\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const ADD_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry add_bf16_f32_kernel(\n\
.param .u64 a_ptr,\n\
.param .u64 b_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 n\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg;\n\
.reg .u64 %pa, %pb, %pout, %off, %addr;\n\
.reg .u16 %ra, %rb;\n\
.reg .u32 %ba, %bb;\n\
.reg .f32 %fa, %fb, %res;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %pa, [a_ptr];\n\
ld.param.u64 %pb, [b_ptr];\n\
ld.param.u64 %pout, [out_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %pa, %off;\n\
ld.global.u16 %ra, [%addr];\n\
add.u64 %addr, %pb, %off;\n\
ld.global.u16 %rb, [%addr];\n\
\n\
mov.b32 %ba, 0;\n\
cvt.u32.u16 %ba, %ra;\n\
shl.b32 %ba, %ba, 16;\n\
mov.b32 %fa, %ba;\n\
\n\
mov.b32 %bb, 0;\n\
cvt.u32.u16 %bb, %rb;\n\
shl.b32 %bb, %bb, 16;\n\
mov.b32 %fb, %bb;\n\
\n\
add.f32 %res, %fa, %fb;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %pout, %off;\n\
st.global.f32 [%addr], %res;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const SUB_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry sub_bf16_f32_kernel(\n\
.param .u64 a_ptr,\n\
.param .u64 b_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 n\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg;\n\
.reg .u64 %pa, %pb, %pout, %off, %addr;\n\
.reg .u16 %ra, %rb;\n\
.reg .u32 %ba, %bb;\n\
.reg .f32 %fa, %fb, %res;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %pa, [a_ptr];\n\
ld.param.u64 %pb, [b_ptr];\n\
ld.param.u64 %pout, [out_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %pa, %off;\n\
ld.global.u16 %ra, [%addr];\n\
add.u64 %addr, %pb, %off;\n\
ld.global.u16 %rb, [%addr];\n\
\n\
mov.b32 %ba, 0;\n\
cvt.u32.u16 %ba, %ra;\n\
shl.b32 %ba, %ba, 16;\n\
mov.b32 %fa, %ba;\n\
\n\
mov.b32 %bb, 0;\n\
cvt.u32.u16 %bb, %rb;\n\
shl.b32 %bb, %bb, 16;\n\
mov.b32 %fb, %bb;\n\
\n\
sub.f32 %res, %fa, %fb;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %pout, %off;\n\
st.global.f32 [%addr], %res;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const MUL_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry mul_bf16_f32_kernel(\n\
.param .u64 a_ptr,\n\
.param .u64 b_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 n\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg;\n\
.reg .u64 %pa, %pb, %pout, %off, %addr;\n\
.reg .u16 %ra, %rb;\n\
.reg .u32 %ba, %bb;\n\
.reg .f32 %fa, %fb, %res;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %pa, [a_ptr];\n\
ld.param.u64 %pb, [b_ptr];\n\
ld.param.u64 %pout, [out_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %pa, %off;\n\
ld.global.u16 %ra, [%addr];\n\
add.u64 %addr, %pb, %off;\n\
ld.global.u16 %rb, [%addr];\n\
\n\
mov.b32 %ba, 0;\n\
cvt.u32.u16 %ba, %ra;\n\
shl.b32 %ba, %ba, 16;\n\
mov.b32 %fa, %ba;\n\
\n\
mov.b32 %bb, 0;\n\
cvt.u32.u16 %bb, %rb;\n\
shl.b32 %bb, %bb, 16;\n\
mov.b32 %fb, %bb;\n\
\n\
mul.f32 %res, %fa, %fb;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %pout, %off;\n\
st.global.f32 [%addr], %res;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const DIV_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry div_bf16_f32_kernel(\n\
.param .u64 a_ptr,\n\
.param .u64 b_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 n\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg;\n\
.reg .u64 %pa, %pb, %pout, %off, %addr;\n\
.reg .u16 %ra, %rb;\n\
.reg .u32 %ba, %bb;\n\
.reg .f32 %fa, %fb, %res;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %pa, [a_ptr];\n\
ld.param.u64 %pb, [b_ptr];\n\
ld.param.u64 %pout, [out_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %pa, %off;\n\
ld.global.u16 %ra, [%addr];\n\
add.u64 %addr, %pb, %off;\n\
ld.global.u16 %rb, [%addr];\n\
\n\
mov.b32 %ba, 0;\n\
cvt.u32.u16 %ba, %ra;\n\
shl.b32 %ba, %ba, 16;\n\
mov.b32 %fa, %ba;\n\
\n\
mov.b32 %bb, 0;\n\
cvt.u32.u16 %bb, %rb;\n\
shl.b32 %bb, %bb, 16;\n\
mov.b32 %fb, %bb;\n\
\n\
div.approx.f32 %res, %fa, %fb;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %pout, %off;\n\
st.global.f32 [%addr], %res;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const SUM_AXIS_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry sum_axis_bf16_f32_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 outer_size,\n\
.param .u32 axis_size,\n\
.param .u32 inner_size,\n\
.param .u32 total_output\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;\n\
.reg .u32 %outer_idx, %inner_idx, %k, %tmp, %elem_idx;\n\
.reg .u64 %in, %out, %off, %addr;\n\
.reg .u16 %raw;\n\
.reg .u32 %bits;\n\
.reg .f32 %val, %sum;\n\
.reg .pred %p, %lp;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %outer_sz, [outer_size];\n\
ld.param.u32 %axis_sz, [axis_size];\n\
ld.param.u32 %inner_sz, [inner_size];\n\
ld.param.u32 %n_reg, [total_output];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
div.u32 %outer_idx, %r_tid, %inner_sz;\n\
rem.u32 %inner_idx, %r_tid, %inner_sz;\n\
\n\
mul.lo.u32 %tmp, %outer_idx, %axis_sz;\n\
mul.lo.u32 %tmp, %tmp, %inner_sz;\n\
add.u32 %tmp, %tmp, %inner_idx;\n\
\n\
mov.f32 %sum, 0f00000000;\n\
mov.u32 %k, 0;\n\
SUM_BF16_LOOP:\n\
setp.ge.u32 %lp, %k, %axis_sz;\n\
@%lp bra SUM_BF16_LOOP_DONE;\n\
\n\
mul.lo.u32 %elem_idx, %k, %inner_sz;\n\
add.u32 %elem_idx, %tmp, %elem_idx;\n\
cvt.u64.u32 %off, %elem_idx;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %in, %off;\n\
ld.global.u16 %raw, [%addr];\n\
mov.b32 %bits, 0;\n\
cvt.u32.u16 %bits, %raw;\n\
shl.b32 %bits, %bits, 16;\n\
mov.b32 %val, %bits;\n\
add.f32 %sum, %sum, %val;\n\
\n\
add.u32 %k, %k, 1;\n\
bra SUM_BF16_LOOP;\n\
SUM_BF16_LOOP_DONE:\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %out, %off;\n\
st.global.f32 [%addr], %sum;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const MEAN_AXIS_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry mean_axis_bf16_f32_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 outer_size,\n\
.param .u32 axis_size,\n\
.param .u32 inner_size,\n\
.param .u32 total_output\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %outer_sz, %axis_sz, %inner_sz;\n\
.reg .u32 %outer_idx, %inner_idx, %k, %tmp, %elem_idx;\n\
.reg .u64 %in, %out, %off, %addr;\n\
.reg .u16 %raw;\n\
.reg .u32 %bits;\n\
.reg .f32 %val, %sum, %cnt, %mean;\n\
.reg .pred %p, %lp;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %outer_sz, [outer_size];\n\
ld.param.u32 %axis_sz, [axis_size];\n\
ld.param.u32 %inner_sz, [inner_size];\n\
ld.param.u32 %n_reg, [total_output];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
div.u32 %outer_idx, %r_tid, %inner_sz;\n\
rem.u32 %inner_idx, %r_tid, %inner_sz;\n\
\n\
mul.lo.u32 %tmp, %outer_idx, %axis_sz;\n\
mul.lo.u32 %tmp, %tmp, %inner_sz;\n\
add.u32 %tmp, %tmp, %inner_idx;\n\
\n\
mov.f32 %sum, 0f00000000;\n\
mov.u32 %k, 0;\n\
MEAN_BF16_LOOP:\n\
setp.ge.u32 %lp, %k, %axis_sz;\n\
@%lp bra MEAN_BF16_LOOP_DONE;\n\
\n\
mul.lo.u32 %elem_idx, %k, %inner_sz;\n\
add.u32 %elem_idx, %tmp, %elem_idx;\n\
cvt.u64.u32 %off, %elem_idx;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %in, %off;\n\
ld.global.u16 %raw, [%addr];\n\
mov.b32 %bits, 0;\n\
cvt.u32.u16 %bits, %raw;\n\
shl.b32 %bits, %bits, 16;\n\
mov.b32 %val, %bits;\n\
add.f32 %sum, %sum, %val;\n\
\n\
add.u32 %k, %k, 1;\n\
bra MEAN_BF16_LOOP;\n\
MEAN_BF16_LOOP_DONE:\n\
\n\
cvt.rn.f32.u32 %cnt, %axis_sz;\n\
div.approx.f32 %mean, %sum, %cnt;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %out, %off;\n\
st.global.f32 [%addr], %mean;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const RELU_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry relu_bf16_f32_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 n\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %bits;\n\
.reg .u64 %pin, %pout, %off, %addr;\n\
.reg .u16 %raw;\n\
.reg .f32 %val, %zero, %res;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %pin, [in_ptr];\n\
ld.param.u64 %pout, [out_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %pin, %off;\n\
ld.global.u16 %raw, [%addr];\n\
mov.b32 %bits, 0;\n\
cvt.u32.u16 %bits, %raw;\n\
shl.b32 %bits, %bits, 16;\n\
mov.b32 %val, %bits;\n\
\n\
mov.f32 %zero, 0f00000000;\n\
max.f32 %res, %val, %zero;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %pout, %off;\n\
st.global.f32 [%addr], %res;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_BF16_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry sigmoid_bf16_f32_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 n\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %bits;\n\
.reg .u64 %pin, %pout, %off, %addr;\n\
.reg .u16 %raw;\n\
.reg .f32 %val, %neg, %exp_arg, %ex2, %denom, %res;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %pin, [in_ptr];\n\
ld.param.u64 %pout, [out_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 1;\n\
add.u64 %addr, %pin, %off;\n\
ld.global.u16 %raw, [%addr];\n\
mov.b32 %bits, 0;\n\
cvt.u32.u16 %bits, %raw;\n\
shl.b32 %bits, %bits, 16;\n\
mov.b32 %val, %bits;\n\
\n\
neg.f32 %neg, %val;\n\
mul.f32 %exp_arg, %neg, 0f3FB8AA3B;\n\
ex2.approx.f32 %ex2, %exp_arg;\n\
add.f32 %denom, 0f3F800000, %ex2;\n\
rcp.approx.f32 %res, %denom;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %addr, %pout, %off;\n\
st.global.f32 [%addr], %res;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const DROPOUT_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry dropout_kernel(\n\
.param .u64 input_ptr,\n\
.param .u64 output_ptr,\n\
.param .u32 n,\n\
.param .u32 threshold,\n\
.param .f32 scale,\n\
.param .u32 seed\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %thresh, %seed_reg, %rng, %tmp;\n\
.reg .u64 %in, %out, %off;\n\
.reg .f32 %val, %scale_reg, %zero;\n\
.reg .pred %p, %drop_p;\n\
\n\
ld.param.u64 %in, [input_ptr];\n\
ld.param.u64 %out, [output_ptr];\n\
ld.param.u32 %n_reg, [n];\n\
ld.param.u32 %thresh, [threshold];\n\
ld.param.f32 %scale_reg, [scale];\n\
ld.param.u32 %seed_reg, [seed];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %n_reg;\n\
@%p bra DONE;\n\
\n\
mul.lo.u32 %rng, %r_tid, 2654435761;\n\
xor.b32 %rng, %rng, %seed_reg;\n\
shl.b32 %tmp, %rng, 13;\n\
xor.b32 %rng, %rng, %tmp;\n\
shr.b32 %tmp, %rng, 17;\n\
xor.b32 %rng, %rng, %tmp;\n\
shl.b32 %tmp, %rng, 5;\n\
xor.b32 %rng, %rng, %tmp;\n\
\n\
cvt.u64.u32 %off, %r_tid;\n\
shl.b64 %off, %off, 2;\n\
add.u64 %in, %in, %off;\n\
add.u64 %out, %out, %off;\n\
ld.global.f32 %val, [%in];\n\
\n\
setp.lo.u32 %drop_p, %rng, %thresh;\n\
mov.f32 %zero, 0f00000000;\n\
@%drop_p mov.f32 %val, %zero;\n\
@!%drop_p mul.f32 %val, %val, %scale_reg;\n\
\n\
st.global.f32 [%out], %val;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_ADD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_add_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
// Global thread index.
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// Decompose flat index into N-d coordinates and compute A/B indices.
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
// Byte offset for dimension d: d * 4.
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
// Load out_shape[d].
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
// Load a_strides[d] and b_strides[d].
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
// coord = remaining % shape_d; remaining /= shape_d.
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
// a_idx += coord * a_stride[d]; b_idx += coord * b_stride[d].
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
// Load a[a_idx] and b[b_idx] (f32 = 4 bytes).
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
// Operation: add.
add.f32 %vr, %va, %vb;
// Store to out[tid].
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_SUB_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_sub_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
sub.f32 %vr, %va, %vb;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_MUL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_mul_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
mul.f32 %vr, %va, %vb;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const BROADCAST_DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry broadcast_div_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u64 a_strides_ptr,
.param .u64 b_strides_ptr,
.param .u64 out_shape_ptr,
.param .u32 n,
.param .u32 ndim
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg, %ndim_reg;
.reg .u32 %remaining, %a_idx, %b_idx, %d;
.reg .u32 %shape_d, %a_str_d, %b_str_d, %coord;
.reg .u64 %a, %b, %out, %a_str, %b_str, %oshape;
.reg .u64 %off_a, %off_b, %off_out, %d64, %tmp;
.reg .f32 %va, %vb, %vr;
.reg .pred %p, %loop_p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u64 %a_str, [a_strides_ptr];
ld.param.u64 %b_str, [b_strides_ptr];
ld.param.u64 %oshape, [out_shape_ptr];
ld.param.u32 %n_reg, [n];
ld.param.u32 %ndim_reg, [ndim];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %remaining, %r_tid;
mov.u32 %a_idx, 0;
mov.u32 %b_idx, 0;
mov.u32 %d, %ndim_reg;
LOOP:
setp.eq.u32 %loop_p, %d, 0;
@%loop_p bra END_LOOP;
sub.u32 %d, %d, 1;
cvt.u64.u32 %d64, %d;
shl.b64 %d64, %d64, 2;
add.u64 %tmp, %oshape, %d64;
ld.global.u32 %shape_d, [%tmp];
add.u64 %tmp, %a_str, %d64;
ld.global.u32 %a_str_d, [%tmp];
add.u64 %tmp, %b_str, %d64;
ld.global.u32 %b_str_d, [%tmp];
rem.u32 %coord, %remaining, %shape_d;
div.u32 %remaining, %remaining, %shape_d;
mad.lo.u32 %a_idx, %coord, %a_str_d, %a_idx;
mad.lo.u32 %b_idx, %coord, %b_str_d, %b_idx;
bra LOOP;
END_LOOP:
cvt.u64.u32 %off_a, %a_idx;
shl.b64 %off_a, %off_a, 2;
add.u64 %off_a, %a, %off_a;
ld.global.f32 %va, [%off_a];
cvt.u64.u32 %off_b, %b_idx;
shl.b64 %off_b, %off_b, 2;
add.u64 %off_b, %b, %off_b;
ld.global.f32 %vb, [%off_b];
div.rn.f32 %vr, %va, %vb;
cvt.u64.u32 %off_out, %r_tid;
shl.b64 %off_out, %off_out, 2;
add.u64 %off_out, %out, %off_out;
st.global.f32 [%off_out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SPLIT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_split_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 total_along_axis,
.param .u32 split_offset,
.param .u32 split_size,
.param .u32 inner_size,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %total_ax, %sp_off, %sp_sz, %inner_sz;
.reg .u32 %outer_idx, %within, %chunk_stride, %src_idx, %base_off, %tmp;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %total_ax, [total_along_axis];
ld.param.u32 %sp_off, [split_offset];
ld.param.u32 %sp_sz, [split_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// chunk_stride = split_size * inner_size
mul.lo.u32 %chunk_stride, %sp_sz, %inner_sz;
// outer_idx = r_tid / chunk_stride
div.u32 %outer_idx, %r_tid, %chunk_stride;
// within = r_tid % chunk_stride
rem.u32 %within, %r_tid, %chunk_stride;
// base_off = split_offset * inner_size
mul.lo.u32 %base_off, %sp_off, %inner_sz;
// src_idx = outer_idx * total_along_axis * inner_size + base_off + within
mul.lo.u32 %src_idx, %outer_idx, %total_ax;
mul.lo.u32 %src_idx, %src_idx, %inner_sz;
add.u32 %src_idx, %src_idx, %base_off;
add.u32 %src_idx, %src_idx, %within;
// Load from in[src_idx]
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_CAT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_cat_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 total_along_axis,
.param .u32 cat_offset,
.param .u32 part_size,
.param .u32 inner_size,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %total_ax, %cat_off, %part_sz, %inner_sz;
.reg .u32 %outer_idx, %within, %chunk_stride, %dst_idx, %base_off;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %total_ax, [total_along_axis];
ld.param.u32 %cat_off, [cat_offset];
ld.param.u32 %part_sz, [part_size];
ld.param.u32 %inner_sz, [inner_size];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
// chunk_stride = part_size * inner_size
mul.lo.u32 %chunk_stride, %part_sz, %inner_sz;
// outer_idx = r_tid / chunk_stride
div.u32 %outer_idx, %r_tid, %chunk_stride;
// within = r_tid % chunk_stride
rem.u32 %within, %r_tid, %chunk_stride;
// base_off = cat_offset * inner_size
mul.lo.u32 %base_off, %cat_off, %inner_sz;
// dst_idx = outer_idx * total_along_axis * inner_size + base_off + within
mul.lo.u32 %dst_idx, %outer_idx, %total_ax;
mul.lo.u32 %dst_idx, %dst_idx, %inner_sz;
add.u32 %dst_idx, %dst_idx, %base_off;
add.u32 %dst_idx, %dst_idx, %within;
// Load from in[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[dst_idx]
cvt.u64.u32 %off, %dst_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_COPY_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_copy_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 src_offset_base,
.param .u32 n,
.param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
.param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
.param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
.param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %src_idx, [src_offset_base];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %flat, %r_tid;
// Dim 0
ld.param.u32 %os, [os0];
ld.param.u32 %ss, [ss0];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 1
ld.param.u32 %os, [os1];
ld.param.u32 %ss, [ss1];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 2
ld.param.u32 %os, [os2];
ld.param.u32 %ss, [ss2];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 3
ld.param.u32 %os, [os3];
ld.param.u32 %ss, [ss3];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 4
ld.param.u32 %os, [os4];
ld.param.u32 %ss, [ss4];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 5
ld.param.u32 %os, [os5];
ld.param.u32 %ss, [ss5];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 6
ld.param.u32 %os, [os6];
ld.param.u32 %ss, [ss6];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 7
ld.param.u32 %os, [os7];
ld.param.u32 %ss, [ss7];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Load from in[src_idx]
cvt.u64.u32 %off, %src_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SCATTER_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_scatter_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 dst_offset_base,
.param .u32 n,
.param .u32 is0, .param .u32 is1, .param .u32 is2, .param .u32 is3,
.param .u32 is4, .param .u32 is5, .param .u32 is6, .param .u32 is7,
.param .u32 ds0, .param .u32 ds1, .param .u32 ds2, .param .u32 ds3,
.param .u32 ds4, .param .u32 ds5, .param .u32 ds6, .param .u32 ds7
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %flat, %dst_idx, %coord, %tmp, %is, %ds;
.reg .u64 %in, %out, %off;
.reg .f32 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %dst_idx, [dst_offset_base];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %flat, %r_tid;
// Dim 0
ld.param.u32 %is, [is0];
ld.param.u32 %ds, [ds0];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 1
ld.param.u32 %is, [is1];
ld.param.u32 %ds, [ds1];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 2
ld.param.u32 %is, [is2];
ld.param.u32 %ds, [ds2];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 3
ld.param.u32 %is, [is3];
ld.param.u32 %ds, [ds3];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 4
ld.param.u32 %is, [is4];
ld.param.u32 %ds, [ds4];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 5
ld.param.u32 %is, [is5];
ld.param.u32 %ds, [ds5];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 6
ld.param.u32 %is, [is6];
ld.param.u32 %ds, [ds6];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Dim 7
ld.param.u32 %is, [is7];
ld.param.u32 %ds, [ds7];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
// Load from in[r_tid]
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %off, %in, %off;
ld.global.f32 %val, [%off];
// Store to out[dst_idx]
cvt.u64.u32 %off, %dst_idx;
shl.b64 %off, %off, 2;
add.u64 %off, %out, %off;
st.global.f32 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const STRIDED_SCATTER_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry strided_scatter_f64_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 dst_offset_base,
.param .u32 n,
.param .u32 is0, .param .u32 is1, .param .u32 is2, .param .u32 is3,
.param .u32 is4, .param .u32 is5, .param .u32 is6, .param .u32 is7,
.param .u32 ds0, .param .u32 ds1, .param .u32 ds2, .param .u32 ds3,
.param .u32 ds4, .param .u32 ds5, .param .u32 ds6, .param .u32 ds7
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %flat, %dst_idx, %coord, %tmp, %is, %ds;
.reg .u64 %in, %out, %off;
.reg .f64 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %dst_idx, [dst_offset_base];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %flat, %r_tid;
ld.param.u32 %is, [is0];
ld.param.u32 %ds, [ds0];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is1];
ld.param.u32 %ds, [ds1];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is2];
ld.param.u32 %ds, [ds2];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is3];
ld.param.u32 %ds, [ds3];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is4];
ld.param.u32 %ds, [ds4];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is5];
ld.param.u32 %ds, [ds5];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is6];
ld.param.u32 %ds, [ds6];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
ld.param.u32 %is, [is7];
ld.param.u32 %ds, [ds7];
div.u32 %coord, %flat, %is;
mul.lo.u32 %tmp, %coord, %is;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ds;
add.u32 %dst_idx, %dst_idx, %tmp;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %off, %in, %off;
ld.global.f64 %val, [%off];
cvt.u64.u32 %off, %dst_idx;
shl.b64 %off, %off, 3;
add.u64 %off, %out, %off;
st.global.f64 [%off], %val;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const DIV_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry div_kernel(
.param .u64 a_ptr,
.param .u64 b_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %b, %out, %off;
.reg .f32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
ld.global.f32 %vb, [%b];
div.rn.f32 %vr, %va, %vb;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const EXP_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry exp_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// PTX ex2.approx computes 2^x; use the identity exp(x) = 2^(x * log2(e))
// log2(e) = 1.4426950408889634
mul.f32 %va, %va, 0f3FB8AA3B;
ex2.approx.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const EXP_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry exp_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %x, %vr;
.reg .f64 %log2e, %nf, %r;
.reg .f64 %p, %one, %half;
.reg .s32 %ni;
.reg .s64 %ni64, %exp_bits;
.reg .pred %p_bounds, %p_tid;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p_tid, %r_tid, %n_reg;
@%p_tid bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
// Constants
mov.f64 %log2e, 0d3FF71547652B82FE; // log2(e) = 1.4426950408889634
mov.f64 %one, 0d3FF0000000000000; // 1.0
mov.f64 %half, 0d3FE0000000000000; // 0.5
// n = round(x * log2(e))
mul.f64 %nf, %x, %log2e;
cvt.rni.f64.f64 %nf, %nf; // round to nearest integer
cvt.rni.s32.f64 %ni, %nf; // integer n
// r = x - n * ln2 (Cody-Waite two-step for precision); the
// hi/lo split of ln(2) is `0d3FE62E42FEFA3800` (hi) and
// `0d3D2EF35793C76730` (lo), negated and inlined directly into
// the FMA instructions below rather than staged in registers.
fma.rn.f64 %r, %nf, 0dBFE62E42FEFA3800, %x; // r = x - n*ln2_hi
fma.rn.f64 %r, %nf, 0dBD2EF35793C76730, %r; // r -= n*ln2_lo
// Horner polynomial for exp(r) - 1 - r = r^2 * (1/2! + r*(1/3! + r*(1/4! + ...)))
// p starts at 1/11!, accumulates down to 1/2!
mov.f64 %p, 0d3E5AE64567F544E4; // 1/11! = 2.505e-8
fma.rn.f64 %p, %p, %r, 0d3E927E4FB7789F5C; // 1/10! = 2.756e-7
fma.rn.f64 %p, %p, %r, 0d3EC71DE3A556C734; // 1/9! = 2.756e-6
fma.rn.f64 %p, %p, %r, 0d3EFA01A01A01A01A; // 1/8! = 2.480e-5
fma.rn.f64 %p, %p, %r, 0d3F2A01A01A01A01A; // 1/7! = 1.984e-4
fma.rn.f64 %p, %p, %r, 0d3F56C16C16C16C17; // 1/6! = 1.389e-3
fma.rn.f64 %p, %p, %r, 0d3F81111111111111; // 1/5! = 8.333e-3
fma.rn.f64 %p, %p, %r, 0d3FA5555555555555; // 1/4! = 4.167e-2
fma.rn.f64 %p, %p, %r, 0d3FC5555555555555; // 1/3! = 1.667e-1
fma.rn.f64 %p, %p, %r, %half; // 1/2! = 5.000e-1
// exp(r) = 1 + r + r^2 * p => 1 + r*(1 + r*p)
fma.rn.f64 %p, %p, %r, %one; // p = r*p + 1
fma.rn.f64 %vr, %p, %r, %one; // vr = p*r + 1 = exp(r)
// Scale by 2^n: multiply by constructing the f64 bit pattern for 2^n.
// IEEE 754 f64: 2^n has exponent field = n + 1023, no mantissa bits.
// Bit pattern: (n + 1023) << 52.
cvt.s64.s32 %ni64, %ni;
add.s64 %ni64, %ni64, 1023;
shl.b64 %exp_bits, %ni64, 52;
mov.b64 %nf, %exp_bits; // reinterpret as f64 = 2^n
mul.f64 %vr, %vr, %nf;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const LOG_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry log_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// PTX lg2.approx computes log2(x); use the identity ln(x) = log2(x) / log2(e)
// 1/log2(e) = ln(2) = 0.6931471805599453
lg2.approx.f32 %vr, %va;
mul.f32 %vr, %vr, 0f3F317218;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const LOG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry log_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .u64 %xbits, %mantissa_bits, %bias_bits;
.reg .f64 %x, %vr, %m, %f, %f2, %s, %p;
.reg .f64 %ln2_hi, %ln2_lo, %one, %two, %sqrt2, %half_const;
.reg .s32 %exp_i;
.reg .s64 %exp64;
.reg .f64 %nf;
.reg .pred %p_tid, %p_shift;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p_tid, %r_tid, %n_reg;
@%p_tid bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %x, [%a];
mov.f64 %ln2_hi, 0d3FE62E42FEFA3800; // hi of ln(2) Cody-Waite split
mov.f64 %ln2_lo, 0d3D2EF35793C76730; // lo of ln(2) Cody-Waite split
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
// Extract exponent: n = exponent_field - 1023
mov.b64 %xbits, %x;
shr.u64 %exp64, %xbits, 52;
and.b64 %exp64, %exp64, 2047; // 11-bit exponent field
sub.s64 %exp64, %exp64, 1023;
cvt.rn.f64.s64 %nf, %exp64; // n as f64
// Extract mantissa m: set exponent to 1023 (so m is in [1, 2))
mov.u64 %bias_bits, 0x3FF0000000000000; // exponent = 1023
and.b64 %mantissa_bits, %xbits, 0x000FFFFFFFFFFFFF; // mantissa bits
or.b64 %mantissa_bits, %mantissa_bits, %bias_bits;
mov.b64 %m, %mantissa_bits; // m in [1.0, 2.0)
// Half-step: if m > sqrt(2), halve m and bump n by 1 to restrict
// m to [sqrt(2)/2, sqrt(2)) and |f| <= ~0.172.
mov.f64 %sqrt2, 0d3FF6A09E667F3BCD;
mov.f64 %half_const, 0d3FE0000000000000;
setp.gt.f64 %p_shift, %m, %sqrt2;
@%p_shift mul.f64 %m, %m, %half_const;
@%p_shift add.f64 %nf, %nf, %one;
// f = (m - 1) / (m + 1), |f| <= ~0.172 after half-step
sub.f64 %f, %m, %one;
add.f64 %s, %m, %one;
div.rn.f64 %f, %f, %s;
// Horner over f^2, degree 7: odd-power atanh series through f^15/15
// (adds the 1/13 and 1/15 terms vs the degree-5 variant).
mul.f64 %f2, %f, %f;
mov.f64 %p, 0d3FB1111111111111; // 1/15
fma.rn.f64 %p, %p, %f2, 0d3FB3B13B13B13B14; // 1/13
fma.rn.f64 %p, %p, %f2, 0d3FB745D1745D1746; // 1/11
fma.rn.f64 %p, %p, %f2, 0d3FBC71C71C71C71C; // 1/9
fma.rn.f64 %p, %p, %f2, 0d3FC2492492492492; // 1/7
fma.rn.f64 %p, %p, %f2, 0d3FC999999999999A; // 1/5
fma.rn.f64 %p, %p, %f2, 0d3FD5555555555555; // 1/3
fma.rn.f64 %p, %p, %f2, %one;
// ln(m) = 2*f*p
mul.f64 %p, %p, %f;
add.f64 %p, %p, %p; // * 2
// ln(x) = n*ln(2) + ln(m), with 2-double ln(2) for the n term.
fma.rn.f64 %vr, %nf, %ln2_hi, %p;
fma.rn.f64 %vr, %nf, %ln2_lo, %vr;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SQRT_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sqrt_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
sqrt.rn.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const POW_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry pow_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f32 exponent,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %exp, %lg;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f32 %exp, [exponent];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// x^e = 2^(e * log2(x))
lg2.approx.f32 %lg, %va;
mul.f32 %lg, %lg, %exp;
ex2.approx.f32 %vr, %lg;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const POW_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry pow_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .f64 exponent,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %va, %vr, %exp64, %one, %two;
// log registers
.reg .u64 %l_xbits, %l_mbits, %l_bias;
.reg .s64 %l_exp64;
.reg .f64 %l_m, %l_f, %l_f2, %l_s, %l_p, %l_nf;
.reg .f64 %l_ln2_hi, %l_ln2_lo, %l_lnx, %l_sqrt2, %l_half_const;
.reg .pred %p_shift;
// exp registers
.reg .f64 %e_z, %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.f64 %exp64, [exponent];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %va, [%a];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
// === ln(va) via argument reduction ===
// Decompose va = 2^n * m, m in [1,2). Then half-step reduce: if
// m > sqrt(2), set m <- m/2 and n <- n+1, so m ends up in
// [sqrt(2)/2, sqrt(2)) and |f| = |(m-1)/(m+1)| <= ~0.172.
mov.b64 %l_xbits, %va;
shr.u64 %l_exp64, %l_xbits, 52;
and.b64 %l_exp64, %l_exp64, 2047;
sub.s64 %l_exp64, %l_exp64, 1023;
cvt.rn.f64.s64 %l_nf, %l_exp64;
mov.u64 %l_bias, 0x3FF0000000000000;
and.b64 %l_mbits, %l_xbits, 0x000FFFFFFFFFFFFF;
or.b64 %l_mbits, %l_mbits, %l_bias;
mov.b64 %l_m, %l_mbits;
// Half-step: if m > sqrt(2), halve m and bump n by 1 (in the f64
// count register; the integer exp64 is no longer used after this).
mov.f64 %l_sqrt2, 0d3FF6A09E667F3BCD;
mov.f64 %l_half_const, 0d3FE0000000000000;
setp.gt.f64 %p_shift, %l_m, %l_sqrt2;
@%p_shift mul.f64 %l_m, %l_m, %l_half_const;
@%p_shift add.f64 %l_nf, %l_nf, %one;
// f = (m-1)/(m+1), |f| <= ~0.172 after half-step reduction
sub.f64 %l_f, %l_m, %one;
add.f64 %l_s, %l_m, %one;
div.rn.f64 %l_f, %l_f, %l_s;
mul.f64 %l_f2, %l_f, %l_f;
// Horner over f^2, degree 7: p = sum_{k=0..7} f^(2k) / (2k+1).
// Terms beyond the degree-5 variant are 1/13 and 1/15.
// p = 1 + f2*(1/3 + f2*(1/5 + f2*(1/7 + f2*(1/9 + f2*(1/11 + f2*(1/13 + f2/15))))))
mov.f64 %l_p, 0d3FB1111111111111; // 1/15
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB3B13B13B13B14; // 1/13
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FB745D1745D1746; // 1/11
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FBC71C71C71C71C; // 1/9
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC2492492492492; // 1/7
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FC999999999999A; // 1/5
fma.rn.f64 %l_p, %l_p, %l_f2, 0d3FD5555555555555; // 1/3
fma.rn.f64 %l_p, %l_p, %l_f2, %one;
// ln(m) = 2*f*p
mul.f64 %l_p, %l_p, %l_f;
add.f64 %l_p, %l_p, %l_p;
// ln(x) = n*ln(2) + ln(m), where ln(2) is split into hi+lo
// (Cody-Waite 2-double form) so the n*ln(2) reconstruction is
// accurate to ~106 bits before adding ln(m). The hi/lo split is
// the same as used inside the exp argument reduction below.
mov.f64 %l_ln2_hi, 0d3FE62E42FEFA3800;
mov.f64 %l_ln2_lo, 0d3D2EF35793C76730;
fma.rn.f64 %l_lnx, %l_nf, %l_ln2_hi, %l_p;
fma.rn.f64 %l_lnx, %l_nf, %l_ln2_lo, %l_lnx;
// === exp(exponent * ln(x)) ===
mul.f64 %e_z, %exp64, %l_lnx;
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %e_z, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %e_z;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %vr, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %vr, %vr, %e_nf;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const ABS_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry abs_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
abs.f32 %vr, %va;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sigmoid_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %neg, %e, %denom, %one, %lg2e;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// sigmoid(x) = 1 / (1 + exp(-x))
neg.f32 %neg, %va;
mov.f32 %lg2e, 0f3FB8AA3B;
mul.f32 %neg, %neg, %lg2e;
ex2.approx.f32 %e, %neg;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %e;
div.rn.f32 %vr, %one, %denom;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const SIGMOID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sigmoid_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %va, %vr, %e64, %denom, %one, %neg_x;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %va, [%a];
mov.f64 %one, 0d3FF0000000000000;
// sigmoid(x) = 1 / (1 + exp(-x))
neg.f64 %neg_x, %va;
// --- exp(%neg_x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg_x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg_x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e64, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e64, %e64, %e_nf;
// --- end exp ---
add.f64 %denom, %one, %e64;
div.rn.f64 %vr, %one, %denom;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const TANH_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry tanh_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f32 %va, %vr, %neg2x, %e, %denom, %sig, %one, %two, %lg2e;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f32 %va, [%a];
// tanh(x) = 2*sigmoid(2x) - 1
mov.f32 %two, 0f40000000;
mul.f32 %neg2x, %va, %two;
neg.f32 %neg2x, %neg2x;
mov.f32 %lg2e, 0f3FB8AA3B;
mul.f32 %neg2x, %neg2x, %lg2e;
ex2.approx.f32 %e, %neg2x;
mov.f32 %one, 0f3F800000;
add.f32 %denom, %one, %e;
div.rn.f32 %sig, %one, %denom;
mul.f32 %vr, %two, %sig;
sub.f32 %vr, %vr, %one;
st.global.f32 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const TANH_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry tanh_f64_kernel(
.param .u64 a_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %a, %out, %off;
.reg .f64 %va, %vr, %e64, %num, %denom, %one, %two, %neg2x;
.reg .f64 %e_nf, %e_r, %e_p, %e_half;
.reg .s32 %e_ni;
.reg .s64 %e_ni64, %e_bits;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.f64 %va, [%a];
mov.f64 %one, 0d3FF0000000000000;
mov.f64 %two, 0d4000000000000000;
// tanh(x) = (1 - exp(-2x)) / (1 + exp(-2x))
mul.f64 %neg2x, %va, %two;
neg.f64 %neg2x, %neg2x;
// --- exp(%neg2x) via Cody-Waite + degree-11 Horner ---
mov.f64 %e_half, 0d3FE0000000000000;
fma.rn.f64 %e_nf, %neg2x, 0d3FF71547652B82FE, %e_half;
cvt.rmi.f64.f64 %e_nf, %e_nf;
cvt.rni.s32.f64 %e_ni, %e_nf;
fma.rn.f64 %e_r, %e_nf, 0dBFE62E42FEFA3800, %neg2x;
fma.rn.f64 %e_r, %e_nf, 0dBD2EF35793C76730, %e_r;
mov.f64 %e_p, 0d3E5AE64567F544E4;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3E927E4FB7789F5C;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EC71DE3A556C734;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3EFA01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F2A01A01A01A01A;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F56C16C16C16C17;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3F81111111111111;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FA5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, 0d3FC5555555555555;
fma.rn.f64 %e_p, %e_p, %e_r, %e_half;
fma.rn.f64 %e_p, %e_p, %e_r, %one;
fma.rn.f64 %e64, %e_p, %e_r, %one;
cvt.s64.s32 %e_ni64, %e_ni;
add.s64 %e_ni64, %e_ni64, 1023;
shl.b64 %e_bits, %e_ni64, 52;
mov.b64 %e_nf, %e_bits;
mul.f64 %e64, %e64, %e_nf;
// --- end exp ---
sub.f64 %num, %one, %e64;
add.f64 %denom, %one, %e64;
div.rn.f64 %vr, %num, %denom;
st.global.f64 [%out], %vr;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const FUSED_ADAM_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry fused_adam_kernel(
.param .u64 param_ptr,
.param .u64 grad_ptr,
.param .u64 exp_avg_ptr,
.param .u64 exp_avg_sq_ptr,
.param .f32 beta1,
.param .f32 beta2,
.param .f32 lr,
.param .f32 eps,
.param .f32 bc1,
.param .f32 bc2,
.param .f32 weight_decay,
.param .u32 n
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %p, %g, %m, %v, %off;
.reg .f32 %vp, %vg, %vm, %vv;
.reg .f32 %b1, %b2, %f_lr, %f_eps, %f_bc1, %f_bc2, %f_wd;
.reg .f32 %t1, %t2, %m_hat, %v_hat, %denom, %update;
.reg .f32 %one;
.reg .pred %p_bound, %p_wd;
ld.param.u64 %p, [param_ptr];
ld.param.u64 %g, [grad_ptr];
ld.param.u64 %m, [exp_avg_ptr];
ld.param.u64 %v, [exp_avg_sq_ptr];
ld.param.f32 %b1, [beta1];
ld.param.f32 %b2, [beta2];
ld.param.f32 %f_lr, [lr];
ld.param.f32 %f_eps, [eps];
ld.param.f32 %f_bc1, [bc1];
ld.param.f32 %f_bc2, [bc2];
ld.param.f32 %f_wd, [weight_decay];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p_bound, %r_tid, %n_reg;
@%p_bound bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 2;
add.u64 %p, %p, %off;
add.u64 %g, %g, %off;
add.u64 %m, %m, %off;
add.u64 %v, %v, %off;
ld.global.f32 %vp, [%p];
ld.global.f32 %vg, [%g];
ld.global.f32 %vm, [%m];
ld.global.f32 %vv, [%v];
// L2 weight decay: g = g + wd * p, applied only when wd > 0.
// (%one is briefly reused as the 0.0 constant for this test before
// being set to 1.0 below.)
mov.f32 %one, 0f00000000;
setp.gt.f32 %p_wd, %f_wd, %one;
@%p_wd fma.rn.f32 %vg, %f_wd, %vp, %vg;
// exp_avg = beta1 * exp_avg + (1 - beta1) * g
mov.f32 %one, 0f3F800000;
sub.f32 %t1, %one, %b1;
mul.f32 %vm, %vm, %b1;
fma.rn.f32 %vm, %t1, %vg, %vm;
// exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
sub.f32 %t2, %one, %b2;
mul.f32 %vv, %vv, %b2;
mul.f32 %t1, %vg, %vg;
fma.rn.f32 %vv, %t2, %t1, %vv;
// m_hat = exp_avg / bc1
div.rn.f32 %m_hat, %vm, %f_bc1;
// v_hat = exp_avg_sq / bc2
div.rn.f32 %v_hat, %vv, %f_bc2;
// denom = sqrt(v_hat) + eps
sqrt.rn.f32 %denom, %v_hat;
add.f32 %denom, %denom, %f_eps;
// param = param - lr * m_hat / denom
div.rn.f32 %update, %m_hat, %denom;
mul.f32 %update, %update, %f_lr;
sub.f32 %vp, %vp, %update;
st.global.f32 [%p], %vp;
st.global.f32 [%m], %vm;
st.global.f32 [%v], %vv;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub(crate) const FUSED_GRU_FORWARD_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry fused_gru_forward_kernel(
.param .u64 input_gates_ptr,
.param .u64 hidden_gates_ptr,
.param .u64 bias_ih_ptr,
.param .u64 bias_hh_ptr,
.param .u64 hx_ptr,
.param .u64 hy_ptr,
.param .u64 workspace_ptr,
.param .u32 hsz,
.param .u32 total
) {
.reg .u32 %my_tid, %bid, %bdim, %gdim, %total_reg, %hsz_reg;
.reg .u32 %idx, %stride, %offset3, %offset5, %hmod, %batch_idx;
.reg .u64 %ig, %hg, %b1, %b2, %hx, %hy, %ws;
.reg .u64 %off64, %tmp64;
.reg .f32 %ir, %ii, %in, %hr, %hi, %hn;
.reg .f32 %b1r, %b1i, %b1n, %b2r, %b2i, %b2n;
.reg .f32 %hx_val, %rg, %zg, %ng, %hy_val;
.reg .f32 %one, %neg_one, %exp_val, %denom, %tmp;
.reg .pred %p;
ld.param.u64 %ig, [input_gates_ptr];
ld.param.u64 %hg, [hidden_gates_ptr];
ld.param.u64 %b1, [bias_ih_ptr];
ld.param.u64 %b2, [bias_hh_ptr];
ld.param.u64 %hx, [hx_ptr];
ld.param.u64 %hy, [hy_ptr];
ld.param.u64 %ws, [workspace_ptr];
ld.param.u32 %hsz_reg, [hsz];
ld.param.u32 %total_reg, [total];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %my_tid, %tid.x;
mov.u32 %gdim, %nctaid.x;
mad.lo.u32 %idx, %bid, %bdim, %my_tid;
mul.lo.u32 %stride, %bdim, %gdim;
mov.f32 %one, 0f3F800000;
LOOP:
setp.ge.u32 %p, %idx, %total_reg;
@%p bra END;
// offset3 = (idx/hsz)*3*hsz + idx%hsz (into [B, 3*H] gates tensor)
div.u32 %batch_idx, %idx, %hsz_reg;
rem.u32 %hmod, %idx, %hsz_reg;
mul.lo.u32 %offset3, %batch_idx, %hsz_reg;
mul.lo.u32 %offset3, %offset3, 3;
add.u32 %offset3, %offset3, %hmod;
// Load input gate components: ir, ii, in
cvt.u64.u32 %off64, %offset3;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %ig, %off64;
ld.global.f32 %ir, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %ii, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %in, [%tmp64];
// Load hidden gate components: hr, hi, hn
cvt.u64.u32 %off64, %offset3;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %hg, %off64;
ld.global.f32 %hr, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %hi, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %hn, [%tmp64];
// Load biases (indexed by hmod, hmod+hsz, hmod+2*hsz)
cvt.u64.u32 %off64, %hmod;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %b1, %off64;
ld.global.f32 %b1r, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b1i, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b1n, [%tmp64];
cvt.u64.u32 %off64, %hmod;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %b2, %off64;
ld.global.f32 %b2r, [%tmp64];
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b2i, [%tmp64];
add.u64 %tmp64, %tmp64, %off64;
ld.global.f32 %b2n, [%tmp64];
// Load hx[idx]
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %hx, %off64;
ld.global.f32 %hx_val, [%tmp64];
// r = sigmoid(ir + hr + b1r + b2r)
add.f32 %rg, %ir, %hr;
add.f32 %rg, %rg, %b1r;
add.f32 %rg, %rg, %b2r;
neg.f32 %tmp, %rg;
mul.f32 %tmp, %tmp, 0f3FB8AA3B;
ex2.approx.f32 %exp_val, %tmp;
add.f32 %denom, %one, %exp_val;
div.rn.f32 %rg, %one, %denom;
// z = sigmoid(ii + hi + b1i + b2i)
add.f32 %zg, %ii, %hi;
add.f32 %zg, %zg, %b1i;
add.f32 %zg, %zg, %b2i;
neg.f32 %tmp, %zg;
mul.f32 %tmp, %tmp, 0f3FB8AA3B;
ex2.approx.f32 %exp_val, %tmp;
add.f32 %denom, %one, %exp_val;
div.rn.f32 %zg, %one, %denom;
// n = tanh(in + b1n + r*(hn + b2n))
add.f32 %tmp, %hn, %b2n;
fma.rn.f32 %ng, %rg, %tmp, %in;
add.f32 %ng, %ng, %b1n;
// tanh via 2*sigmoid(2x)-1
mul.f32 %tmp, %ng, 0f40000000;
neg.f32 %tmp, %tmp;
mul.f32 %tmp, %tmp, 0f3FB8AA3B;
ex2.approx.f32 %exp_val, %tmp;
add.f32 %denom, %one, %exp_val;
div.rn.f32 %ng, %one, %denom;
mul.f32 %ng, %ng, 0f40000000;
sub.f32 %ng, %ng, %one;
// hy = n + z * (hx - n)
sub.f32 %tmp, %hx_val, %ng;
fma.rn.f32 %hy_val, %zg, %tmp, %ng;
// Store hy[idx]
cvt.u64.u32 %off64, %idx;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %hy, %off64;
st.global.f32 [%tmp64], %hy_val;
// Store workspace: [r, z, n, hx, hn+b2n] at offset5 = (idx/hsz)*5*hsz + idx%hsz
mul.lo.u32 %offset5, %batch_idx, %hsz_reg;
mul.lo.u32 %offset5, %offset5, 5;
add.u32 %offset5, %offset5, %hmod;
cvt.u64.u32 %off64, %offset5;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %ws, %off64;
st.global.f32 [%tmp64], %rg;
cvt.u64.u32 %off64, %hsz_reg;
shl.b64 %off64, %off64, 2;
add.u64 %tmp64, %tmp64, %off64;
st.global.f32 [%tmp64], %zg;
add.u64 %tmp64, %tmp64, %off64;
st.global.f32 [%tmp64], %ng;
add.u64 %tmp64, %tmp64, %off64;
st.global.f32 [%tmp64], %hx_val;
add.u64 %tmp64, %tmp64, %off64;
add.f32 %tmp, %hn, %b2n;
st.global.f32 [%tmp64], %tmp;
add.u32 %idx, %idx, %stride;
bra LOOP;
END:
ret;
}
";
#[cfg(feature = "cuda")]
fn launch_cfg(n: usize) -> GpuResult<LaunchConfig> {
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "kernel_launch",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
const BLOCK: u32 = 256;
// Round up so a partial final block still covers the tail.
let grid = (n as u32).div_ceil(BLOCK);
Ok(LaunchConfig {
grid_dim: (grid.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
})
}
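// A quick shape check of the launch math (hypothetical test, illustrating the
// rounding): 1_000_000 elements -> ceil(1_000_000 / 256) = 3907 blocks, and
// the kernels' `setp.ge.u32` bound guard retires the 192 surplus threads.
#[cfg(all(test, feature = "cuda"))]
#[test]
fn launch_cfg_rounds_up() {
    let cfg = launch_cfg(1_000_000).unwrap();
    assert_eq!(cfg.grid_dim, (3907, 1, 1));
    assert_eq!(cfg.block_dim, (256, 1, 1));
    assert_eq!(cfg.shared_mem_bytes, 0);
}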
#[cfg(feature = "cuda")]
fn validate_binary(a: &CudaBuffer<f32>, b: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: a.device_ordinal(),
got: device.ordinal(),
});
}
if b.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: b.device_ordinal(),
got: device.ordinal(),
});
}
if a.len() != b.len() {
return Err(GpuError::LengthMismatch {
a: a.len(),
b: b.len(),
});
}
Ok(())
}
#[cfg(feature = "cuda")]
fn validate_unary(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<()> {
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: a.device_ordinal(),
got: device.ordinal(),
});
}
Ok(())
}
#[cfg(feature = "cuda")]
fn validate_device<T>(a: &CudaBuffer<T>, device: &GpuDevice) -> GpuResult<()> {
if a.device_ordinal() != device.ordinal() {
return Err(GpuError::DeviceMismatch {
expected: a.device_ordinal(),
got: device.ordinal(),
});
}
Ok(())
}
#[cfg(feature = "cuda")]
fn try_launch_binary(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
fn try_launch_binary_vec4(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let n4 = (n / 4) as u32;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n4 as usize)?;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n4)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
fn try_launch_unary(
a: &CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
fn try_launch_binary_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
fn try_launch_unary_into(
a: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
fn try_launch_binary_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
fn try_launch_unary_f64(
a: &CudaBuffer<f64>,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
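/// Launches a broadcasting binary kernel: the per-input strides and the
/// output shape are uploaded as small device buffers so the kernel can
/// decode each linear output index (a stride of 0 repeats a broadcast
/// dimension). `try_launch_broadcast_binary` below is the f32 twin.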
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_strides: &[u32],
b_strides: &[u32],
out_shape: &[u32],
out_numel: usize,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let ndim = out_shape.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let a_str_buf = cpu_to_gpu(a_strides, device)?;
let b_str_buf = cpu_to_gpu(b_strides, device)?;
let shape_buf = cpu_to_gpu(out_shape, device)?;
let mut out = alloc_zeros_f64(out_numel, device)?;
let cfg = launch_cfg(out_numel)?;
let n_u32 = out_numel as u32;
let ndim_u32 = ndim as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(a_str_buf.inner())
.arg(b_str_buf.inner())
.arg(shape_buf.inner())
.arg(&n_u32)
.arg(&ndim_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
fn try_launch_broadcast_binary(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_strides: &[u32],
b_strides: &[u32],
out_shape: &[u32],
out_numel: usize,
device: &GpuDevice,
ptx_src: &'static str,
kernel_name: &'static str,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let ndim = out_shape.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx_src,
kernel_name,
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: kernel_name,
source: e,
});
}
};
let a_str_buf = cpu_to_gpu(a_strides, device)?;
let b_str_buf = cpu_to_gpu(b_strides, device)?;
let shape_buf = cpu_to_gpu(out_shape, device)?;
let mut out = alloc_zeros_f32(out_numel, device)?;
let cfg = launch_cfg(out_numel)?;
let n_u32 = out_numel as u32;
let ndim_u32 = ndim as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(a_str_buf.inner())
.arg(b_str_buf.inner())
.arg(shape_buf.inner())
.arg(&n_u32)
.arg(&ndim_u32)
.launch(cfg)?;
}
Ok(out)
}
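/// Computes per-dimension element strides for broadcasting `in_shape`
/// against `out_shape`, with shapes aligned on the right (NumPy-style).
/// Missing leading dimensions and size-1 dimensions get stride 0 so the
/// same input element is reused across them.
///
/// For example, broadcasting `[3, 1]` into `[2, 3, 4]` yields `[0, 1, 0]`.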
#[cfg(feature = "cuda")]
fn broadcast_strides(in_shape: &[usize], out_shape: &[usize]) -> Vec<u32> {
let ndim = out_shape.len();
let in_ndim = in_shape.len();
let mut strides = vec![0u32; ndim];
let mut stride: u32 = 1;
for d in (0..ndim).rev() {
let in_d = if d + in_ndim >= ndim {
d + in_ndim - ndim
} else {
strides[d] = 0;
continue;
};
if in_shape[in_d] == 1 {
strides[d] = 0;
} else {
strides[d] = stride;
}
stride *= in_shape[in_d] as u32;
}
strides
}
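/// Elementwise `a + b`. For lengths that are multiples of 4 (and at least
/// 16) the vectorized kernel is tried first; if its PTX fails to compile on
/// the current device, the scalar kernel is used instead.
///
/// Sketch of a round trip (assumes these items are visible from the call
/// site; adjust paths to this crate's layout):
///
/// ```ignore
/// let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &device)?;
/// let b = cpu_to_gpu(&[10.0f32, 20.0, 30.0, 40.0], &device)?;
/// let sum = gpu_add(&a, &b, &device)?;
/// assert_eq!(gpu_to_cpu(&sum, &device)?, vec![11.0, 22.0, 33.0, 44.0]);
/// ```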
#[cfg(feature = "cuda")]
pub fn gpu_add(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(a, b, device)?;
let n = a.len();
if n >= 16 && n % 4 == 0 {
match try_launch_binary_vec4(a, b, device, ADD_VEC4_PTX, "add_vec4_kernel") {
Ok(out) => return Ok(out),
Err(GpuError::PtxCompileFailed { .. }) => {}
Err(e) => return Err(e),
}
}
try_launch_binary(a, b, device, ADD_PTX, "add_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sub(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(a, b, device)?;
try_launch_binary(a, b, device, SUB_PTX, "sub_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_mul(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(a, b, device)?;
let n = a.len();
if n >= 16 && n % 4 == 0 {
match try_launch_binary_vec4(a, b, device, MUL_VEC4_PTX, "mul_vec4_kernel") {
Ok(out) => return Ok(out),
Err(GpuError::PtxCompileFailed { .. }) => {}
Err(e) => return Err(e),
}
}
try_launch_binary(a, b, device, MUL_PTX, "mul_kernel")
}
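/// Broadcasting `a + b`: per-input strides come from `broadcast_strides`,
/// so size-1 and missing dimensions of either operand repeat across
/// `out_shape`. The caller is responsible for passing a valid broadcast
/// output shape.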
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
try_launch_broadcast_binary(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
BROADCAST_ADD_PTX,
"broadcast_add_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
try_launch_broadcast_binary(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
BROADCAST_SUB_PTX,
"broadcast_sub_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
try_launch_broadcast_binary(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
BROADCAST_MUL_PTX,
"broadcast_mul_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
try_launch_broadcast_binary(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
BROADCAST_DIV_PTX,
"broadcast_div_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_neg(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, NEG_PTX, "neg_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_relu(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, RELU_PTX, "relu_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(
grad,
input,
device,
RELU_BACKWARD_PTX,
"relu_backward_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_abs_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(grad, input, device, ABS_BACKWARD_PTX, "abs_backward_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(
grad,
input,
device,
GELU_BACKWARD_PTX,
"gelu_backward_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(
grad,
input,
device,
GELU_BACKWARD_ERF_PTX,
"gelu_backward_erf_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d(
input: &CudaBuffer<f32>,
indices: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = indices.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
INDEX_SELECT_1D_PTX,
"index_select_1d_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "index_select_1d_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d(
grad_output: &CudaBuffer<f32>,
indices: &CudaBuffer<f32>,
input_len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(grad_output, device)?;
let n = grad_output.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SCATTER_ADD_1D_PTX,
"scatter_add_1d_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "scatter_add_1d_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(input_len, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad_output.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_index_select_dim(
input: &CudaBuffer<f32>,
indices: &CudaBuffer<f32>,
outer: usize,
in_dim_size: usize,
out_dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
validate_device(indices, device)?;
let total = outer
.checked_mul(out_dim_size)
.and_then(|t| t.checked_mul(inner))
.ok_or(GpuError::ShapeMismatch {
op: "index_select_dim",
expected: vec![usize::MAX],
got: vec![outer, out_dim_size, inner],
})?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
INDEX_SELECT_DIM_PTX,
"index_select_dim_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "index_select_dim_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let outer_u32 = outer as u32;
let in_dim_u32 = in_dim_size as u32;
let out_dim_u32 = out_dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&in_dim_u32)
.arg(&out_dim_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_index_select_dim_f64(
input: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
outer: usize,
in_dim_size: usize,
out_dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
validate_device(input, device)?;
validate_device(indices, device)?;
let total = outer
.checked_mul(out_dim_size)
.and_then(|t| t.checked_mul(inner))
.ok_or(GpuError::ShapeMismatch {
op: "index_select_dim_f64",
expected: vec![usize::MAX],
got: vec![outer, out_dim_size, inner],
})?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
INDEX_SELECT_DIM_F64_PTX,
"index_select_dim_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "index_select_dim_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let outer_u32 = outer as u32;
let in_dim_u32 = in_dim_size as u32;
let out_dim_u32 = out_dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&in_dim_u32)
.arg(&out_dim_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill(
input: &CudaBuffer<f32>,
mask: &CudaBuffer<f32>,
value: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_binary(input, mask, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
MASKED_FILL_PTX,
"masked_fill_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_fill_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(mask.inner())
.arg(out.inner_mut())
.arg(&value)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero(
grad: &CudaBuffer<f32>,
mask: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, mask, device)?;
try_launch_binary(grad, mask, device, MASKED_ZERO_PTX, "masked_zero_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward(
grad: &CudaBuffer<f32>,
output: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, output, device)?;
try_launch_binary(
grad,
output,
device,
SIGMOID_BACKWARD_PTX,
"sigmoid_backward_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward(
grad: &CudaBuffer<f32>,
output: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, output, device)?;
try_launch_binary(
grad,
output,
device,
TANH_BACKWARD_PTX,
"tanh_backward_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward(
grad: &CudaBuffer<f32>,
output: &CudaBuffer<f32>,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_binary(grad, output, device)?;
let total = grad.len();
if cols == 0 {
return Err(GpuError::ShapeMismatch {
op: "gpu_softmax_backward",
expected: vec![1],
got: vec![cols],
});
}
let rows = total / cols;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SOFTMAX_BACKWARD_PTX,
"softmax_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "softmax_backward_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(output.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax(
input: &CudaBuffer<f32>,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = input.len();
if cols == 0 {
return Err(GpuError::ShapeMismatch {
op: "gpu_log_softmax",
expected: vec![1],
got: vec![cols],
});
}
let rows = total / cols;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LOG_SOFTMAX_PTX,
"log_softmax_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "log_softmax_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward(
grad: &CudaBuffer<f32>,
output: &CudaBuffer<f32>,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_binary(grad, output, device)?;
let total = grad.len();
if cols == 0 {
return Err(GpuError::ShapeMismatch {
op: "gpu_log_softmax_backward",
expected: vec![1],
got: vec![cols],
});
}
let rows = total / cols;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LOG_SOFTMAX_BACKWARD_PTX,
"log_softmax_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "log_softmax_backward_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(output.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
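/// Sums all elements into a single-element buffer via a two-stage
/// reduction: one pass writes at most 1024 per-block partials (the cap
/// implies each block accumulates a grid-strided slice), then up to 256
/// partials are summed on the host; larger partial sets recurse. E.g.
/// `n = 1_000_000` produces 1024 partials, which reduce to 4, which are
/// summed on the host.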
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_sum",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[0.0f32], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
REDUCE_SUM_PTX,
"reduce_sum_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_sum_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = alloc_zeros_f32(num_blocks as usize, device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total: f32 = host_partials.iter().sum();
return cpu_to_gpu(&[total], device);
}
gpu_reduce_sum(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_sum(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_reduce_prod(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_prod",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[1.0_f32], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
REDUCE_PROD_PTX,
"reduce_prod_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_prod_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![1.0_f32; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host = gpu_to_cpu(&partials, device)?;
let total: f32 = host.iter().product();
return cpu_to_gpu(&[total], device);
}
gpu_reduce_prod(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_prod(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_prod_backward_f32(
input: &CudaBuffer<f32>,
grad_output: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = input.len();
if n == 0 {
let e: &[f32] = &[];
return cpu_to_gpu(e, device);
}
if grad_output.len() != 1 {
return Err(GpuError::ShapeMismatch {
op: "prod_backward",
expected: vec![1],
got: vec![grad_output.len()],
});
}
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
PROD_BACKWARD_PTX,
"prod_backward_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "prod_backward_kernel",
source: e,
})?;
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_prod_backward_f32(
_input: &CudaBuffer<f32>,
_grad_output: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_prod_backward_f64(
input: &CudaBuffer<f64>,
grad_output: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let n = input.len();
if n == 0 {
let e: &[f64] = &[];
return cpu_to_gpu(e, device);
}
if grad_output.len() != 1 {
return Err(GpuError::ShapeMismatch {
op: "prod_backward",
expected: vec![1],
got: vec![grad_output.len()],
});
}
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
PROD_BACKWARD_PTX,
"prod_backward_kernel",
"prod_backward_f64_kernel",
);
let f = crate::module_cache::get_or_compile(
ctx,
ptx,
"prod_backward_f64_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "prod_backward_kernel",
source: e,
})?;
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_prod_backward_f64(
_input: &CudaBuffer<f64>,
_grad_output: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_reduce_min(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_min",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f32::INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
REDUCE_MIN_PTX,
"reduce_min_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_min_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f32::INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total = host_partials.iter().copied().fold(f32::INFINITY, f32::min);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_min(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_min(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_reduce_max(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_max",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f32::NEG_INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
REDUCE_MAX_PTX,
"reduce_max_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_max_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f32::NEG_INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total = host_partials
.iter()
.copied()
.fold(f32::NEG_INFINITY, f32::max);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_max(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_max(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_reduce_min(
data: &CudaBuffer<f32>,
mask_f: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if data.len() != mask_f.len() {
return Err(GpuError::LengthMismatch {
a: data.len(),
b: mask_f.len(),
});
}
let n = data.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_masked_reduce_min",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f32::INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
MASKED_REDUCE_MIN_PTX,
"masked_reduce_min_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_reduce_min_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f32::INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(data.inner())
.arg(mask_f.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host = gpu_to_cpu(&partials, device)?;
let total = host.iter().copied().fold(f32::INFINITY, f32::min);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_min(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_reduce_min(
_data: &CudaBuffer<f32>,
_mask_f: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_reduce_max(
data: &CudaBuffer<f32>,
mask_f: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if data.len() != mask_f.len() {
return Err(GpuError::LengthMismatch {
a: data.len(),
b: mask_f.len(),
});
}
let n = data.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_masked_reduce_max",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f32::NEG_INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
MASKED_REDUCE_MAX_PTX,
"masked_reduce_max_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_reduce_max_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f32::NEG_INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(data.inner())
.arg(mask_f.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host = gpu_to_cpu(&partials, device)?;
let total = host.iter().copied().fold(f32::NEG_INFINITY, f32::max);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_max(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_reduce_max(
_data: &CudaBuffer<f32>,
_mask_f: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
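/// Pads or truncates batched interleaved complex data: per batch row, the
/// first `min(src_n, dst_n)` (re, im) pairs are copied (hence the factor
/// of 2 in the length check) and any tail of the zero-initialized output
/// stays zero.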
#[cfg(feature = "cuda")]
pub fn gpu_pad_truncate_complex_f32(
src: &CudaBuffer<f32>,
batch: usize,
src_n: usize,
dst_n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if src.len() != batch * src_n * 2 {
return Err(GpuError::ShapeMismatch {
op: "gpu_pad_truncate_complex_f32",
expected: vec![batch * src_n * 2],
got: vec![src.len()],
});
}
let total_pairs = batch * dst_n;
if total_pairs == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
PAD_TRUNCATE_PTX,
"pad_truncate_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "pad_truncate_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(batch * dst_n * 2, device)?;
let cfg = launch_cfg(total_pairs)?;
let batch_u32 = batch as u32;
let src_n_u32 = src_n as u32;
let dst_n_u32 = dst_n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(out.inner_mut())
.arg(&batch_u32)
.arg(&src_n_u32)
.arg(&dst_n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_pad_truncate_complex_f32(
_src: &CudaBuffer<f32>,
_batch: usize,
_src_n: usize,
_dst_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis(
a: &CudaBuffer<f32>,
outer: usize,
axis_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(a, device)?;
let total_output = outer * inner;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SUM_AXIS_PTX,
"sum_axis_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "sum_axis_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total_output, device)?;
let cfg = launch_cfg(total_output)?;
let outer_u32 = outer as u32;
let axis_size_u32 = axis_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total_output as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&axis_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
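/// Cumulative sum along the middle dimension of an `[outer, dim_size,
/// inner]` layout: one thread per `(outer, inner)` lane walks `dim_size`
/// sequentially, which is why the launch covers `outer * inner` threads
/// rather than `total` elements. The cumprod/cummax/cummin/logcumsumexp
/// variants below follow the same scheme.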
#[cfg(feature = "cuda")]
pub fn gpu_cumsum(
input: &CudaBuffer<f32>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = outer * dim_size * inner;
let num_threads = outer * inner;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CUMSUM_PTX,
"cumsum_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "cumsum_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(num_threads)?;
let outer_u32 = outer as u32;
let dim_size_u32 = dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&dim_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_cumprod(
input: &CudaBuffer<f32>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = outer * dim_size * inner;
let num_threads = outer * inner;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CUMPROD_PTX,
"cumprod_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "cumprod_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(num_threads)?;
let outer_u32 = outer as u32;
let dim_size_u32 = dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&dim_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_cummax(
input: &CudaBuffer<f32>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = outer * dim_size * inner;
let num_threads = outer * inner;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CUMMAX_PTX,
"cummax_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "cummax_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let mut out_idx = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(num_threads)?;
let outer_u32 = outer as u32;
let dim_size_u32 = dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(out_idx.inner_mut())
.arg(&outer_u32)
.arg(&dim_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok((out, out_idx))
}
#[cfg(feature = "cuda")]
pub fn gpu_cummin(
input: &CudaBuffer<f32>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = outer * dim_size * inner;
let num_threads = outer * inner;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CUMMIN_PTX,
"cummin_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "cummin_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let mut out_idx = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(num_threads)?;
let outer_u32 = outer as u32;
let dim_size_u32 = dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(out_idx.inner_mut())
.arg(&outer_u32)
.arg(&dim_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok((out, out_idx))
}
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp(
input: &CudaBuffer<f32>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = outer * dim_size * inner;
let num_threads = outer * inner;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LOGCUMSUMEXP_PTX,
"logcumsumexp_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "logcumsumexp_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(num_threads)?;
let outer_u32 = outer as u32;
let dim_size_u32 = dim_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&dim_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
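/// Extracts one split along an axis: `split_size` of `total_along_axis`
/// positions starting at `split_offset`, each carrying `inner_size`
/// trailing elements; `n` is the resulting output element count.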
#[cfg(feature = "cuda")]
pub fn gpu_strided_split(
input: &CudaBuffer<f32>,
total_along_axis: usize,
split_offset: usize,
split_size: usize,
inner_size: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
STRIDED_SPLIT_PTX,
"strided_split_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_split_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let total_ax_u32 = total_along_axis as u32;
let offset_u32 = split_offset as u32;
let split_sz_u32 = split_size as u32;
let inner_u32 = inner_size as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&total_ax_u32)
.arg(&offset_u32)
.arg(&split_sz_u32)
.arg(&inner_u32)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat(
input: &CudaBuffer<f32>,
output: &mut CudaBuffer<f32>,
total_along_axis: usize,
cat_offset: usize,
part_size: usize,
inner_size: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
STRIDED_CAT_PTX,
"strided_cat_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_cat_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let total_ax_u32 = total_along_axis as u32;
let offset_u32 = cat_offset as u32;
let part_sz_u32 = part_size as u32;
let inner_u32 = inner_size as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(output.inner_mut())
.arg(&total_ax_u32)
.arg(&offset_u32)
.arg(&part_sz_u32)
.arg(&inner_u32)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
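/// Maximum rank supported by the strided copy/scatter kernels, whose
/// per-dimension strides are passed as eight individual scalar arguments.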
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
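/// Prepares fixed-size stride arrays for the strided copy/scatter kernels:
/// row-major strides for the dense side and the caller's strides for the
/// strided side (negative strides are rejected). Unused trailing
/// dimensions are padded with a value greater than `n`, so `index / pad`
/// is 0 for every valid linear index and those dimensions contribute no
/// offset.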
#[cfg(feature = "cuda")]
fn pad_strided_copy_params(
out_shape: &[usize],
src_strides: &[isize],
n: usize,
) -> GpuResult<([u32; STRIDED_COPY_MAX_DIMS], [u32; STRIDED_COPY_MAX_DIMS])> {
if out_shape.len() != src_strides.len() {
return Err(GpuError::ShapeMismatch {
op: "strided_copy_pad",
expected: vec![out_shape.len()],
got: vec![src_strides.len()],
});
}
if out_shape.len() > STRIDED_COPY_MAX_DIMS {
return Err(GpuError::ShapeMismatch {
op: "strided_copy_pad",
expected: vec![STRIDED_COPY_MAX_DIMS],
got: vec![out_shape.len()],
});
}
for &s in src_strides {
if s < 0 {
return Err(GpuError::ShapeMismatch {
op: "strided_copy_pad_negative_stride",
expected: vec![0],
got: vec![s.unsigned_abs()],
});
}
}
let rank = out_shape.len();
let mut out_stride = [0u32; STRIDED_COPY_MAX_DIMS];
if rank > 0 {
let mut acc: usize = 1;
for d in (0..rank).rev() {
if acc > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "strided_copy_stride_overflow",
expected: vec![u32::MAX as usize],
got: vec![acc],
});
}
out_stride[d] = acc as u32;
acc = acc.saturating_mul(out_shape[d]);
}
}
let pad_val = (n as u32).saturating_add(1).max(1);
out_stride[rank..STRIDED_COPY_MAX_DIMS].fill(pad_val);
let mut src_stride_out = [0u32; STRIDED_COPY_MAX_DIMS];
for d in 0..rank {
let s = src_strides[d];
if s as usize > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "strided_copy_src_stride_overflow",
expected: vec![u32::MAX as usize],
got: vec![s as usize],
});
}
src_stride_out[d] = s as u32;
}
Ok((out_stride, src_stride_out))
}
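/// Gathers a (possibly non-contiguous) view described by `src_strides` and
/// `src_offset` into a freshly allocated contiguous buffer of `out_shape`.
/// Both stride arrays are padded to `STRIDED_COPY_MAX_DIMS` and passed as
/// individual scalar arguments.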
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy(
input: &CudaBuffer<f32>,
out_shape: &[usize],
src_strides: &[isize],
src_offset: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n: usize = out_shape.iter().product();
let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
STRIDED_COPY_PTX,
"strided_copy_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_copy_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let src_offset_u32 = src_offset as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&src_offset_u32)
.arg(&n_u32)
.arg(&out_stride[0])
.arg(&out_stride[1])
.arg(&out_stride[2])
.arg(&out_stride[3])
.arg(&out_stride[4])
.arg(&out_stride[5])
.arg(&out_stride[6])
.arg(&out_stride[7])
.arg(&src_stride[0])
.arg(&src_stride[1])
.arg(&src_stride[2])
.arg(&src_stride[3])
.arg(&src_stride[4])
.arg(&src_stride[5])
.arg(&src_stride[6])
.arg(&src_stride[7])
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_strided_copy_f64(
input: &CudaBuffer<f64>,
out_shape: &[usize],
src_strides: &[isize],
src_offset: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let n: usize = out_shape.iter().product();
let (out_stride, src_stride) = pad_strided_copy_params(out_shape, src_strides, n)?;
if n == 0 {
return alloc_zeros_f64(0, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
STRIDED_COPY_PTX,
"strided_copy_kernel",
"strided_copy_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"strided_copy_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_copy_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let src_offset_u32 = src_offset as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&src_offset_u32)
.arg(&n_u32)
.arg(&out_stride[0])
.arg(&out_stride[1])
.arg(&out_stride[2])
.arg(&out_stride[3])
.arg(&out_stride[4])
.arg(&out_stride[5])
.arg(&out_stride[6])
.arg(&out_stride[7])
.arg(&src_stride[0])
.arg(&src_stride[1])
.arg(&src_stride[2])
.arg(&src_stride[3])
.arg(&src_stride[4])
.arg(&src_stride[5])
.arg(&src_stride[6])
.arg(&src_stride[7])
.launch(cfg)?;
}
Ok(out)
}
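/// Inverse of `gpu_strided_copy`: writes the contiguous `src` into the
/// strided view of `dst` described by `view_shape`, `dst_strides`, and
/// `dst_offset`.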
#[cfg(feature = "cuda")]
pub fn gpu_strided_scatter(
src: &CudaBuffer<f32>,
dst: &mut CudaBuffer<f32>,
view_shape: &[usize],
dst_strides: &[isize],
dst_offset: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
validate_device(src, device)?;
validate_device(dst, device)?;
let n: usize = view_shape.iter().product();
if n == 0 {
return Ok(());
}
let (in_decode_stride, dst_stride_padded) =
pad_strided_copy_params(view_shape, dst_strides, n)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
STRIDED_SCATTER_PTX,
"strided_scatter_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_scatter_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let dst_offset_u32 = dst_offset as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(dst.inner_mut())
.arg(&dst_offset_u32)
.arg(&n_u32)
.arg(&in_decode_stride[0])
.arg(&in_decode_stride[1])
.arg(&in_decode_stride[2])
.arg(&in_decode_stride[3])
.arg(&in_decode_stride[4])
.arg(&in_decode_stride[5])
.arg(&in_decode_stride[6])
.arg(&in_decode_stride[7])
.arg(&dst_stride_padded[0])
.arg(&dst_stride_padded[1])
.arg(&dst_stride_padded[2])
.arg(&dst_stride_padded[3])
.arg(&dst_stride_padded[4])
.arg(&dst_stride_padded[5])
.arg(&dst_stride_padded[6])
.arg(&dst_stride_padded[7])
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_strided_scatter_f64(
src: &CudaBuffer<f64>,
dst: &mut CudaBuffer<f64>,
view_shape: &[usize],
dst_strides: &[isize],
dst_offset: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
validate_device(src, device)?;
validate_device(dst, device)?;
let n: usize = view_shape.iter().product();
if n == 0 {
return Ok(());
}
let (in_decode_stride, dst_stride_padded) =
pad_strided_copy_params(view_shape, dst_strides, n)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
STRIDED_SCATTER_F64_PTX,
"strided_scatter_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_scatter_f64_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let dst_offset_u32 = dst_offset as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(dst.inner_mut())
.arg(&dst_offset_u32)
.arg(&n_u32)
.arg(&in_decode_stride[0])
.arg(&in_decode_stride[1])
.arg(&in_decode_stride[2])
.arg(&in_decode_stride[3])
.arg(&in_decode_stride[4])
.arg(&in_decode_stride[5])
.arg(&in_decode_stride[6])
.arg(&in_decode_stride[7])
.arg(&dst_stride_padded[0])
.arg(&dst_stride_padded[1])
.arg(&dst_stride_padded[2])
.arg(&dst_stride_padded[3])
.arg(&dst_stride_padded[4])
.arg(&dst_stride_padded[5])
.arg(&dst_stride_padded[6])
.arg(&dst_stride_padded[7])
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_scale(
a: &CudaBuffer<f32>,
scalar: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(a, device)?;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SCALE_PTX,
"scale_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "scale_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&scalar)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_softmax(
input: &CudaBuffer<f32>,
rows: usize,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SOFTMAX_PTX,
"softmax_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "softmax_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
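/// Dropout forward pass. The kernel is expected to derive a per-element
/// pseudo-random value from `seed` and the element index, zero elements
/// whose value falls below `threshold`, and scale survivors by `scale`
/// (typically `1 / (1 - p)`); the exact RNG lives in `DROPOUT_PTX`.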
#[cfg(feature = "cuda")]
pub fn gpu_dropout(
input: &CudaBuffer<f32>,
threshold: u32,
scale: f32,
seed: u32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
DROPOUT_PTX,
"dropout_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "dropout_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&threshold)
.arg(&scale)
.arg(&seed)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_dropout_f64(
input: &CudaBuffer<f64>,
threshold: u32,
scale: f64,
seed: u32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, DROPOUT_PTX, "dropout_kernel", "dropout_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"dropout_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "dropout_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&threshold)
.arg(&scale)
.arg(&seed)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout_f64(
_input: &CudaBuffer<f64>,
_threshold: u32,
_scale: f64,
_seed: u32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d(
input: &CudaBuffer<f32>,
m: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
TRANSPOSE_2D_PTX,
"transpose_2d_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "transpose_2d_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&m_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213(
input: &CudaBuffer<f32>,
d0: usize,
d1: usize,
d2: usize,
d3: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let total = d0 * d1 * d2 * d3;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
PERMUTE_0213_PTX,
"permute_0213_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "permute_0213_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let d0_u32 = d0 as u32;
let d1_u32 = d1 as u32;
let d2_u32 = d2 as u32;
let d3_u32 = d3 as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&d0_u32)
.arg(&d1_u32)
.arg(&d2_u32)
.arg(&d3_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
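/// Naive matmul for small shapes (one thread per output element, judging
/// by the `m * n` launch size). If the PTX fails to compile on this
/// device, it silently falls back to the `crate::blas` matmul path.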
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SMALL_MATMUL_PTX,
"small_matmul_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(_) => {
return crate::blas::gpu_matmul_f32(a, b, m, k, n, device);
}
};
let mut c = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let k_u32 = k as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(c.inner_mut())
.arg(&m_u32)
.arg(&k_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(c)
}
#[cfg(feature = "cuda")]
pub fn gpu_small_bmm(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
batch: usize,
m: usize,
k: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
if batch == 1 {
return gpu_small_matmul(a, b, m, k, n, device);
}
crate::blas::gpu_bmm_f32(a, b, batch, m, k, n, device)
}
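/// Single-row embedding lookup: `idx` holds one index (stored as f32, like
/// the other index buffers in this module) and the kernel is expected to
/// copy the `d` elements of that `weight` row into the output.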
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup(
idx: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
EMBED_LOOKUP_PTX,
"embed_lookup_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "embed_lookup_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(d, device)?;
let cfg = launch_cfg(d)?;
let d_u32 = d as u32;
unsafe {
stream
.launch_builder(&f)
.arg(idx.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.launch(cfg)?;
}
Ok(out)
}
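/// Writes an `[n_batch, d]` slice at time step `pos` into a cache buffer
/// that the kernel presumably indexes as `[n_batch, max_len, d]` (a
/// KV-cache-style layout); `gpu_slice_read` below reads `len` steps back
/// out.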
#[cfg(feature = "cuda")]
pub fn gpu_slice_write(
src: &CudaBuffer<f32>,
dst: &mut CudaBuffer<f32>,
n_batch: usize,
d: usize,
max_len: usize,
pos: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = n_batch * d;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SLICE_WRITE_PTX,
"slice_write_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "slice_write_kernel",
source: e,
});
}
};
let cfg = launch_cfg(total)?;
let n_u32 = total as u32;
let d_u32 = d as u32;
let max_len_u32 = max_len as u32;
let pos_u32 = pos as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(dst.inner_mut())
.arg(&n_u32)
.arg(&d_u32)
.arg(&max_len_u32)
.arg(&pos_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_slice_read(
src: &CudaBuffer<f32>,
n_batch: usize,
d: usize,
len: usize,
max_len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total = n_batch * len * d;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SLICE_READ_PTX,
"slice_read_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "slice_read_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let total_u32 = total as u32;
let d_u32 = d as u32;
let len_u32 = len as u32;
let max_len_u32 = max_len as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(out.inner_mut())
.arg(&total_u32)
.arg(&d_u32)
.arg(&len_u32)
.arg(&max_len_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(input, device)?;
try_launch_unary(input, device, GELU_PTX, "gelu_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(input, device)?;
try_launch_unary(input, device, GELU_TANH_PTX, "gelu_tanh_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(input, device)?;
try_launch_unary(input, device, GELU_ERF_PTX, "gelu_erf_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(
grad,
input,
device,
GELU_BACKWARD_TANH_PTX,
"gelu_backward_tanh_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_silu(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(input, device)?;
try_launch_unary(input, device, SILU_PTX, "silu_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(
grad,
input,
device,
SILU_BACKWARD_PTX,
"silu_backward_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_elu(
input: &CudaBuffer<f32>,
alpha: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ELU_PTX,
"elu_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "elu_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&alpha)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
alpha: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_binary(grad, input, device)?;
let n = grad.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ELU_BACKWARD_PTX,
"elu_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "elu_backward_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&alpha)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_mish(input: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(input, device)?;
try_launch_unary(input, device, MISH_PTX, "mish_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(grad, input, device)?;
try_launch_binary(
grad,
input,
device,
MISH_BACKWARD_PTX,
"mish_backward_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_clamp(
input: &CudaBuffer<f32>,
min_val: f32,
max_val: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CLAMP_PTX,
"clamp_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "clamp_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&min_val)
.arg(&max_val)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_clamp_backward(
grad: &CudaBuffer<f32>,
input: &CudaBuffer<f32>,
min_val: f32,
max_val: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CLAMP_BACKWARD_PTX,
"clamp_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "clamp_backward_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(input.inner())
.arg(out.inner_mut())
.arg(&min_val)
.arg(&max_val)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_min_val: f32,
_max_val: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
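/// Repeats the middle axis of an input viewed as `[outer, inner]`, producing
/// `[outer, repeat_count, inner]`. For example, with `outer = 2`,
/// `inner = 3`, and `repeat_count = 2`, the rows `[a b c; d e f]` come out
/// flattened as `[a b c, a b c, d e f, d e f]`.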
#[cfg(feature = "cuda")]
pub fn gpu_repeat_along_dim(
input: &CudaBuffer<f32>,
outer: usize,
repeat_count: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if input.len() != outer * inner {
return Err(GpuError::ShapeMismatch {
op: "gpu_repeat_along_dim",
expected: vec![outer * inner],
got: vec![input.len()],
});
}
let total = outer * repeat_count * inner;
if total == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
REPEAT_ALONG_DIM_PTX,
"repeat_along_dim_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "repeat_along_dim_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let outer_u32 = outer as u32;
let rep_u32 = repeat_count as u32;
let inner_u32 = inner as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&rep_u32)
.arg(&inner_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_repeat_along_dim_f64(
input: &CudaBuffer<f64>,
outer: usize,
repeat_count: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if input.len() != outer * inner {
return Err(GpuError::ShapeMismatch {
op: "gpu_repeat_along_dim_f64",
expected: vec![outer * inner],
got: vec![input.len()],
});
}
let total = outer * repeat_count * inner;
if total == 0 {
return alloc_zeros_f64(0, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
REPEAT_ALONG_DIM_PTX,
"repeat_along_dim_kernel",
"repeat_along_dim_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"repeat_along_dim_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "repeat_along_dim_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let outer_u32 = outer as u32;
let rep_u32 = repeat_count as u32;
let inner_u32 = inner as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&rep_u32)
.arg(&inner_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_repeat_along_dim(
_input: &CudaBuffer<f32>,
_outer: usize,
_rep: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_repeat_along_dim_f64(
_input: &CudaBuffer<f64>,
_outer: usize,
_rep: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_div(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
validate_binary(a, b, device)?;
try_launch_binary(a, b, device, DIV_PTX, "div_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_exp(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, EXP_PTX, "exp_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_log(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, LOG_PTX, "log_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sqrt(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, SQRT_PTX, "sqrt_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_pow(
a: &CudaBuffer<f32>,
exponent: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(a, device)?;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
POW_PTX,
"pow_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "pow_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&exponent)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_abs(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, ABS_PTX, "abs_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, SIGMOID_PTX, "sigmoid_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_tanh(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
validate_unary(a, device)?;
try_launch_unary(a, device, TANH_PTX, "tanh_kernel")
}
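/// Elementwise f64 addition. Like the other `_f64` entry points in this
/// module, it rewrites the f32 PTX on first use (cached in a per-function
/// `static OnceLock<String>`, so the string substitution runs once per
/// process) and then relies on `module_cache` to cache the compiled module
/// per device ordinal.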
#[cfg(feature = "cuda")]
pub fn gpu_add_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if a.len() != b.len() {
return Err(GpuError::LengthMismatch {
a: a.len(),
b: b.len(),
});
}
let ptx = get_f64_ptx(&CACHE, ADD_PTX, "add_kernel", "add_f64_kernel");
try_launch_binary_f64(a, b, device, ptx, "add_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sub_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if a.len() != b.len() {
return Err(GpuError::LengthMismatch {
a: a.len(),
b: b.len(),
});
}
let ptx = get_f64_ptx(&CACHE, SUB_PTX, "sub_kernel", "sub_f64_kernel");
try_launch_binary_f64(a, b, device, ptx, "sub_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_mul_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if a.len() != b.len() {
return Err(GpuError::LengthMismatch {
a: a.len(),
b: b.len(),
});
}
let ptx = get_f64_ptx(&CACHE, MUL_PTX, "mul_kernel", "mul_f64_kernel");
try_launch_binary_f64(a, b, device, ptx, "mul_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_div_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if a.len() != b.len() {
return Err(GpuError::LengthMismatch {
a: a.len(),
b: b.len(),
});
}
let ptx = get_f64_ptx(&CACHE, DIV_PTX, "div_kernel", "div_f64_kernel");
try_launch_binary_f64(a, b, device, ptx, "div_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_neg_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(&CACHE, NEG_PTX, "neg_kernel", "neg_f64_kernel");
try_launch_unary_f64(a, device, ptx, "neg_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_relu_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(&CACHE, RELU_PTX, "relu_kernel", "relu_f64_kernel");
try_launch_unary_f64(a, device, ptx, "relu_f64_kernel")
}
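/// Multiplies every element by `scalar`. If PTX compilation fails and the
/// `FERROTORCH_ENABLE_GPU_FALLBACK` environment variable is set, the op is
/// retried on the host (with a warning) instead of returning
/// `GpuError::PtxCompileFailed`.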
#[cfg(feature = "cuda")]
pub fn gpu_scale_f64(
a: &CudaBuffer<f64>,
scalar: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, SCALE_PTX, "scale_kernel", "scale_f64_kernel");
match crate::module_cache::get_or_compile(ctx, ptx, "scale_f64_kernel", device.ordinal() as u32)
{
Ok(f) => {
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&scalar)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "scale_f64_kernel",
error = %e,
"PTX compile failed; falling back to CPU. Unset \
FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let a_host = gpu_to_cpu(a, device)?;
let result: Vec<f64> = a_host.iter().map(|&x| x * scalar).collect();
return cpu_to_gpu(&result, device);
}
Err(GpuError::PtxCompileFailed {
kernel: "scale_f64_kernel",
source: e,
})
}
}
}
#[cfg(feature = "cuda")]
pub fn gpu_exp_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(a, device, EXP_F64_PTX, "exp_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_log_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(a, device, LOG_F64_PTX, "log_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sqrt_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(&CACHE, SQRT_PTX, "sqrt_kernel", "sqrt_f64_kernel");
try_launch_unary_f64(a, device, ptx, "sqrt_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_pow_f64(
a: &CudaBuffer<f64>,
exponent: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
match crate::module_cache::get_or_compile(
ctx,
POW_F64_PTX,
"pow_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => {
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&exponent)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "pow_f64_kernel",
error = %e,
"PTX compile failed; falling back to CPU. Unset \
FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let a_host = gpu_to_cpu(a, device)?;
let result: Vec<f64> = a_host.iter().map(|&x| x.powf(exponent)).collect();
return cpu_to_gpu(&result, device);
}
Err(GpuError::PtxCompileFailed {
kernel: "pow_f64_kernel",
source: e,
})
}
}
}
#[cfg(feature = "cuda")]
pub fn gpu_abs_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(&CACHE, ABS_PTX, "abs_kernel", "abs_f64_kernel");
try_launch_unary_f64(a, device, ptx, "abs_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(a, device, SIGMOID_F64_PTX, "sigmoid_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_tanh_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(a, device, TANH_F64_PTX, "tanh_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_relu_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
let ptx = get_f64_ptx(
&CACHE,
RELU_BACKWARD_PTX,
"relu_backward_kernel",
"relu_backward_f64_kernel",
);
try_launch_binary_f64(grad, input, device, ptx, "relu_backward_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_abs_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
let ptx = get_f64_ptx(
&CACHE,
ABS_BACKWARD_PTX,
"abs_backward_kernel",
"abs_backward_f64_kernel",
);
try_launch_binary_f64(grad, input, device, ptx, "abs_backward_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_backward_f64(
grad: &CudaBuffer<f64>,
output: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if grad.len() != output.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: output.len(),
});
}
let ptx = get_f64_ptx(
&CACHE,
SIGMOID_BACKWARD_PTX,
"sigmoid_backward_kernel",
"sigmoid_backward_f64_kernel",
);
try_launch_binary_f64(grad, output, device, ptx, "sigmoid_backward_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_tanh_backward_f64(
grad: &CudaBuffer<f64>,
output: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if grad.len() != output.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: output.len(),
});
}
let ptx = get_f64_ptx(
&CACHE,
TANH_BACKWARD_PTX,
"tanh_backward_kernel",
"tanh_backward_f64_kernel",
);
try_launch_binary_f64(grad, output, device, ptx, "tanh_backward_f64_kernel")
}
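/// Broadcasting binary ops: both inputs are addressed through per-dimension
/// strides computed by `broadcast_strides` against `out_shape`, the usual
/// trick where a size-1 input dimension gets stride 0 so its single element
/// is reused across that output dimension (assuming the standard
/// NumPy/PyTorch broadcasting rules).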
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_add_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(
&CACHE,
BROADCAST_ADD_PTX,
"broadcast_add_kernel",
"broadcast_add_f64_kernel",
);
try_launch_broadcast_binary_f64(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
ptx,
"broadcast_add_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_sub_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(
&CACHE,
BROADCAST_SUB_PTX,
"broadcast_sub_kernel",
"broadcast_sub_f64_kernel",
);
try_launch_broadcast_binary_f64(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
ptx,
"broadcast_sub_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_mul_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(
&CACHE,
BROADCAST_MUL_PTX,
"broadcast_mul_kernel",
"broadcast_mul_f64_kernel",
);
try_launch_broadcast_binary_f64(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
ptx,
"broadcast_mul_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_broadcast_div_f64(
a: &CudaBuffer<f64>,
b: &CudaBuffer<f64>,
a_shape: &[usize],
b_shape: &[usize],
out_shape: &[usize],
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let a_str = broadcast_strides(a_shape, out_shape);
let b_str = broadcast_strides(b_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();
let out_numel: usize = out_shape.iter().product();
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(
&CACHE,
BROADCAST_DIV_PTX,
"broadcast_div_kernel",
"broadcast_div_f64_kernel",
);
try_launch_broadcast_binary_f64(
a,
b,
&a_str,
&b_str,
&shape_u32,
out_numel,
device,
ptx,
"broadcast_div_f64_kernel",
)
}
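/// Full reduction to a single-element buffer, done in stages: one launch of
/// at most 1024 blocks x 256 threads writes one partial sum per block; up to
/// 256 partials are then summed on the host, and anything larger recurses on
/// the partials. E.g. `n = 1_000_000` yields 1024 partials, the recursive
/// call reduces those to 4, and the host adds the final 4.
///
/// ```ignore
/// let x = cpu_to_gpu(&vec![1.0f64; 1_000_000], &device)?;
/// let s = gpu_reduce_sum_f64(&x, &device)?; // one-element buffer
/// assert_eq!(gpu_to_cpu(&s, &device)?[0], 1.0e6);
/// ```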
#[cfg(feature = "cuda")]
pub fn gpu_reduce_sum_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_sum_f64",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[0.0f64], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
REDUCE_SUM_PTX,
"reduce_sum_kernel",
"reduce_sum_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"reduce_sum_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_sum_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = alloc_zeros_f64(num_blocks as usize, device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total: f64 = host_partials.iter().sum();
return cpu_to_gpu(&[total], device);
}
gpu_reduce_sum_f64(&partials, device)
}
#[cfg(feature = "cuda")]
pub fn gpu_reduce_min_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_min_f64",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f64::INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
REDUCE_MIN_PTX,
"reduce_min_kernel",
"reduce_min_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"reduce_min_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_min_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f64::INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total = host_partials.iter().copied().fold(f64::INFINITY, f64::min);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_min_f64(&partials, device)
}
#[cfg(feature = "cuda")]
pub fn gpu_reduce_max_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_max_f64",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f64::NEG_INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
REDUCE_MAX_PTX,
"reduce_max_kernel",
"reduce_max_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"reduce_max_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_max_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f64::NEG_INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host_partials = gpu_to_cpu(&partials, device)?;
let total = host_partials
.iter()
.copied()
.fold(f64::NEG_INFINITY, f64::max);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_max_f64(&partials, device)
}
#[cfg(feature = "cuda")]
pub fn gpu_reduce_prod_f64(a: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = a.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_reduce_prod_f64",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[1.0_f64], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
REDUCE_PROD_PTX,
"reduce_prod_kernel",
"reduce_prod_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"reduce_prod_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "reduce_prod_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![1.0_f64; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host = gpu_to_cpu(&partials, device)?;
let total: f64 = host.iter().product();
return cpu_to_gpu(&[total], device);
}
gpu_reduce_prod_f64(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_prod_f64(
_a: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_min_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_max_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
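/// Masked variant of [`gpu_reduce_min_f64`]: `mask_f` is an f64 mask of the
/// same length as `data` that selects which elements participate in the
/// reduction (the exact mask convention is defined by the PTX kernel). The
/// per-block partials are plain minima, so the recursive tail reuses the
/// unmasked reduction.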
#[cfg(feature = "cuda")]
pub fn gpu_masked_reduce_min_f64(
data: &CudaBuffer<f64>,
mask_f: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if data.len() != mask_f.len() {
return Err(GpuError::LengthMismatch {
a: data.len(),
b: mask_f.len(),
});
}
let n = data.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_masked_reduce_min_f64",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f64::INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
MASKED_REDUCE_MIN_PTX,
"masked_reduce_min_kernel",
"masked_reduce_min_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"masked_reduce_min_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_reduce_min_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f64::INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(data.inner())
.arg(mask_f.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host = gpu_to_cpu(&partials, device)?;
let total = host.iter().copied().fold(f64::INFINITY, f64::min);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_min_f64(&partials, device)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_reduce_max_f64(
data: &CudaBuffer<f64>,
mask_f: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if data.len() != mask_f.len() {
return Err(GpuError::LengthMismatch {
a: data.len(),
b: mask_f.len(),
});
}
let n = data.len();
if n > u32::MAX as usize {
return Err(GpuError::ShapeMismatch {
op: "gpu_masked_reduce_max_f64",
expected: vec![u32::MAX as usize],
got: vec![n],
});
}
if n == 0 {
return cpu_to_gpu(&[f64::NEG_INFINITY], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
MASKED_REDUCE_MAX_PTX,
"masked_reduce_max_kernel",
"masked_reduce_max_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"masked_reduce_max_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_reduce_max_kernel",
source: e,
});
}
};
const BLOCK: u32 = 256;
let num_blocks = ((n as u32).saturating_add(BLOCK - 1)) / BLOCK;
let num_blocks = num_blocks.min(1024);
let mut partials = cpu_to_gpu(&vec![f64::NEG_INFINITY; num_blocks as usize], device)?;
let n_u32 = n as u32;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks.max(1), 1, 1),
block_dim: (BLOCK, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
stream
.launch_builder(&f)
.arg(data.inner())
.arg(mask_f.inner())
.arg(partials.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
if num_blocks <= 1 {
return Ok(partials);
}
if num_blocks <= 256 {
let host = gpu_to_cpu(&partials, device)?;
let total = host.iter().copied().fold(f64::NEG_INFINITY, f64::max);
return cpu_to_gpu(&[total], device);
}
gpu_reduce_max_f64(&partials, device)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_reduce_min(
_d: &CudaBuffer<f32>,
_m: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_reduce_max(
_d: &CudaBuffer<f32>,
_m: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_reduce_min_f64(
_d: &CudaBuffer<f64>,
_m: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_reduce_max_f64(
_d: &CudaBuffer<f64>,
_m: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
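/// Pads or truncates a batch of interleaved-complex (`re, im`) f64 signals
/// from `src_n` to `dst_n` samples per row. The output is zero-initialized,
/// so padded tails come out as zeros. The cached PTX is the f32 kernel run
/// through the generic f64 rewrite plus two extra patches that move the
/// imaginary-part byte offset from `+ 4` to `+ 8`, which the generic rewrite
/// does not cover.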
#[cfg(feature = "cuda")]
pub fn gpu_pad_truncate_complex_f64(
src: &CudaBuffer<f64>,
batch: usize,
src_n: usize,
dst_n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if src.len() != batch * src_n * 2 {
return Err(GpuError::ShapeMismatch {
op: "gpu_pad_truncate_complex_f64",
expected: vec![batch * src_n * 2],
got: vec![src.len()],
});
}
let total_pairs = batch * dst_n;
if total_pairs == 0 {
return alloc_zeros_f64(0, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = CACHE
.get_or_init(|| {
let mut s = ptx_f32_to_f64(
PAD_TRUNCATE_PTX,
"pad_truncate_kernel",
"pad_truncate_f64_kernel",
);
s = s.replace("[%off_src + 4]", "[%off_src + 8]");
s = s.replace("[%off_dst + 4]", "[%off_dst + 8]");
s
})
.as_str();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"pad_truncate_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "pad_truncate_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(batch * dst_n * 2, device)?;
let cfg = launch_cfg(total_pairs)?;
let batch_u32 = batch as u32;
let src_n_u32 = src_n as u32;
let dst_n_u32 = dst_n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(out.inner_mut())
.arg(&batch_u32)
.arg(&src_n_u32)
.arg(&dst_n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_pad_truncate_complex_f64(
_src: &CudaBuffer<f64>,
_batch: usize,
_src_n: usize,
_dst_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis_f64(
a: &CudaBuffer<f64>,
outer: usize,
axis_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total_output = outer * inner;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
SUM_AXIS_PTX,
"sum_axis_kernel",
"sum_axis_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"sum_axis_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "sum_axis_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total_output, device)?;
let cfg = launch_cfg(total_output)?;
let outer_u32 = outer as u32;
let axis_size_u32 = axis_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total_output as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&axis_size_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_reduce_sum_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis_f64(
_a: &CudaBuffer<f64>,
_outer: usize,
_axis_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
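/// PTX for transposing an `N x N` interleaved-complex f32 matrix: one thread
/// per complex output element, each moving an 8-byte (`re`, `im`) pair from
/// `in[out_col][out_row]` to `out[out_row][out_col]`.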
#[cfg(feature = "cuda")]
pub(crate) const TRANSPOSE_COMPLEX_F32_PTX: &str = "\
.version 7.0\n\
.target sm_52\n\
.address_size 64\n\
\n\
.visible .entry transpose_complex_f32_kernel(\n\
.param .u64 in_ptr,\n\
.param .u64 out_ptr,\n\
.param .u32 N,\n\
.param .u32 total\n\
) {\n\
.reg .u32 %r_tid, %bid, %bdim, %total_reg, %N_reg;\n\
.reg .u32 %out_row, %out_col, %in_idx;\n\
.reg .u32 %in_off, %out_off;\n\
.reg .u64 %in, %out;\n\
.reg .u64 %p_in_re, %p_in_im, %p_out_re, %p_out_im;\n\
.reg .f32 %re, %im;\n\
.reg .pred %p;\n\
\n\
ld.param.u64 %in, [in_ptr];\n\
ld.param.u64 %out, [out_ptr];\n\
ld.param.u32 %N_reg, [N];\n\
ld.param.u32 %total_reg, [total];\n\
\n\
mov.u32 %bid, %ctaid.x;\n\
mov.u32 %bdim, %ntid.x;\n\
mov.u32 %r_tid, %tid.x;\n\
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;\n\
\n\
setp.ge.u32 %p, %r_tid, %total_reg;\n\
@%p bra DONE;\n\
\n\
// Output element: out_row = tid / N, out_col = tid % N.\n\
div.u32 %out_row, %r_tid, %N_reg;\n\
rem.u32 %out_col, %r_tid, %N_reg;\n\
// Input element at (out_col, out_row) in row-major order: out_col*N + out_row.\n\
mad.lo.u32 %in_idx, %out_col, %N_reg, %out_row;\n\
\n\
// Byte offsets: each complex element is 2 f32 = 8 bytes.\n\
shl.b32 %in_off, %in_idx, 3;\n\
shl.b32 %out_off, %r_tid, 3;\n\
\n\
cvt.u64.u32 %p_in_re, %in_off;\n\
add.u64 %p_in_re, %in, %p_in_re;\n\
add.u64 %p_in_im, %p_in_re, 4;\n\
\n\
cvt.u64.u32 %p_out_re, %out_off;\n\
add.u64 %p_out_re, %out, %p_out_re;\n\
add.u64 %p_out_im, %p_out_re, 4;\n\
\n\
ld.global.f32 %re, [%p_in_re];\n\
ld.global.f32 %im, [%p_in_im];\n\
st.global.f32 [%p_out_re], %re;\n\
st.global.f32 [%p_out_im], %im;\n\
\n\
DONE:\n\
ret;\n\
}\n\
";
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_f64(
input: &CudaBuffer<f64>,
m: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
TRANSPOSE_2D_PTX,
"transpose_2d_kernel",
"transpose_2d_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"transpose_2d_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "transpose_2d_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&m_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_transpose_complex_f32(
input: &CudaBuffer<f32>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n_elems = n * n;
if input.len() != 2 * n_elems {
return Err(GpuError::ShapeMismatch {
op: "gpu_transpose_complex_f32",
expected: vec![2 * n_elems],
got: vec![input.len()],
});
}
if n == 0 {
return crate::transfer::alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
TRANSPOSE_COMPLEX_F32_PTX,
"transpose_complex_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "transpose_complex_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(2 * n_elems, device)?;
let cfg = launch_cfg(n_elems)?;
let n_u32 = n as u32;
let total_u32 = n_elems as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_transpose_complex_f64(
input: &CudaBuffer<f64>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n_elems = n * n;
if input.len() != 2 * n_elems {
return Err(GpuError::ShapeMismatch {
op: "gpu_transpose_complex_f64",
expected: vec![2 * n_elems],
got: vec![input.len()],
});
}
if n == 0 {
return crate::transfer::alloc_zeros_f64(0, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = CACHE
.get_or_init(|| {
// The generic f32 -> f64 rewrite does not touch the complex-element
// byte offsets: a complex f64 is 16 bytes (shift by 4, not 3) and the
// imaginary part sits 8 bytes after the real part, not 4.
let mut s = ptx_f32_to_f64(
TRANSPOSE_COMPLEX_F32_PTX,
"transpose_complex_f32_kernel",
"transpose_complex_f64_kernel",
);
s = s.replace("shl.b32 %in_off, %in_idx, 3", "shl.b32 %in_off, %in_idx, 4");
s = s.replace("shl.b32 %out_off, %r_tid, 3", "shl.b32 %out_off, %r_tid, 4");
s = s.replace("add.u64 %p_in_im, %p_in_re, 4", "add.u64 %p_in_im, %p_in_re, 8");
s = s.replace("add.u64 %p_out_im, %p_out_re, 4", "add.u64 %p_out_im, %p_out_re, 8");
s
})
.as_str();
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"transpose_complex_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "transpose_complex_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(2 * n_elems, device)?;
let cfg = launch_cfg(n_elems)?;
let n_u32 = n as u32;
let total_u32 = n_elems as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
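/// Permutes a `[d0, d1, d2, d3]` tensor to `[d0, d2, d1, d3]` (axes 0, 2, 1,
/// 3), the shuffle commonly used to move between `[batch, seq, heads, dim]`
/// and `[batch, heads, seq, dim]` in attention layers. A sketch of the index
/// mapping:
///
/// ```ignore
/// out[((b * d2 + c) * d1 + a) * d3 + e] = in[((b * d1 + a) * d2 + c) * d3 + e];
/// ```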
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_f64(
input: &CudaBuffer<f64>,
d0: usize,
d1: usize,
d2: usize,
d3: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let total = d0 * d1 * d2 * d3;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
PERMUTE_0213_PTX,
"permute_0213_kernel",
"permute_0213_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"permute_0213_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "permute_0213_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let d0_u32 = d0 as u32;
let d1_u32 = d1 as u32;
let d2_u32 = d2 as u32;
let d3_u32 = d3 as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&d0_u32)
.arg(&d1_u32)
.arg(&d2_u32)
.arg(&d3_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
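/// Extracts a sub-range along one axis: viewing the input as
/// `[outer, total_along_axis, inner_size]`, copies the middle-axis slice
/// `[split_offset, split_offset + split_size)` into an output of `n`
/// elements (so `n` should equal `outer * split_size * inner_size`).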
#[cfg(feature = "cuda")]
pub fn gpu_strided_split_f64(
input: &CudaBuffer<f64>,
total_along_axis: usize,
split_offset: usize,
split_size: usize,
inner_size: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
STRIDED_SPLIT_PTX,
"strided_split_kernel",
"strided_split_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"strided_split_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_split_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let total_ax_u32 = total_along_axis as u32;
let offset_u32 = split_offset as u32;
let split_sz_u32 = split_size as u32;
let inner_u32 = inner_size as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&total_ax_u32)
.arg(&offset_u32)
.arg(&split_sz_u32)
.arg(&inner_u32)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_strided_cat_f64(
input: &CudaBuffer<f64>,
output: &mut CudaBuffer<f64>,
total_along_axis: usize,
cat_offset: usize,
part_size: usize,
inner_size: usize,
n: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
STRIDED_CAT_PTX,
"strided_cat_kernel",
"strided_cat_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"strided_cat_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "strided_cat_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let total_ax_u32 = total_along_axis as u32;
let offset_u32 = cat_offset as u32;
let part_sz_u32 = part_size as u32;
let inner_u32 = inner_size as u32;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(output.inner_mut())
.arg(&total_ax_u32)
.arg(&offset_u32)
.arg(&part_sz_u32)
.arg(&inner_u32)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_index_select_1d_f64(
input: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let n = indices.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
INDEX_SELECT_1D_PTX,
"index_select_1d_kernel",
"index_select_1d_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"index_select_1d_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "index_select_1d_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_1d_f64(
grad_output: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
input_len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(grad_output, device)?;
let n = grad_output.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
SCATTER_ADD_1D_PTX,
"scatter_add_1d_kernel",
"scatter_add_1d_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"scatter_add_1d_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "scatter_add_1d_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(input_len, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad_output.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_fill_f64(
input: &CudaBuffer<f64>,
mask: &CudaBuffer<u8>,
value: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let n = input.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
MASKED_FILL_PTX,
"masked_fill_kernel",
"masked_fill_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"masked_fill_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_fill_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(mask.inner())
.arg(out.inner_mut())
.arg(&value)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_masked_zero_f64(
grad: &CudaBuffer<f64>,
mask: &CudaBuffer<u8>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(grad, device)?;
let n = grad.len();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
MASKED_ZERO_PTX,
"masked_zero_kernel",
"masked_zero_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"masked_zero_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "masked_zero_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(mask.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_f64(
src: &CudaBuffer<f64>,
dst: &mut CudaBuffer<f64>,
n_batch: usize,
d: usize,
max_len: usize,
pos: usize,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total = n_batch * d;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
SLICE_WRITE_PTX,
"slice_write_kernel",
"slice_write_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"slice_write_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "slice_write_kernel",
source: e,
});
}
};
let cfg = launch_cfg(total)?;
let n_u32 = total as u32;
let d_u32 = d as u32;
let max_len_u32 = max_len as u32;
let pos_u32 = pos as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(dst.inner_mut())
.arg(&n_u32)
.arg(&d_u32)
.arg(&max_len_u32)
.arg(&pos_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_f64(
src: &CudaBuffer<f64>,
n_batch: usize,
d: usize,
len: usize,
max_len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total = n_batch * len * d;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
SLICE_READ_PTX,
"slice_read_kernel",
"slice_read_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"slice_read_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "slice_read_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let total_u32 = total as u32;
let d_u32 = d as u32;
let len_u32 = len as u32;
let max_len_u32 = max_len as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(out.inner_mut())
.arg(&total_u32)
.arg(&d_u32)
.arg(&len_u32)
.arg(&max_len_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_f64(
idx: &CudaBuffer<f32>,
weight: &CudaBuffer<f64>,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
EMBED_LOOKUP_PTX,
"embed_lookup_kernel",
"embed_lookup_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"embed_lookup_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "embed_lookup_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(d, device)?;
let cfg = launch_cfg(d)?;
let d_u32 = d as u32;
unsafe {
stream
.launch_builder(&f)
.arg(idx.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch_f64(
indices: &CudaBuffer<f32>,
weight: &CudaBuffer<f64>,
n: usize,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total = n * d;
if total == 0 {
return alloc_zeros_f64(0, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
EMBED_LOOKUP_BATCH_PTX,
"embed_lookup_batch_kernel",
"embed_lookup_batch_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"embed_lookup_batch_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "embed_lookup_batch_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let cfg = launch_cfg(total)?;
let d_u32 = d as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(indices.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows_f64(
grad_output: &CudaBuffer<f64>,
indices: &CudaBuffer<f32>,
num_embeddings: usize,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = indices.len();
let total = n * d;
if total == 0 {
return alloc_zeros_f64(num_embeddings * d, device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
SCATTER_ADD_ROWS_PTX,
"scatter_add_rows_kernel",
"scatter_add_rows_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"scatter_add_rows_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "scatter_add_rows_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(num_embeddings * d, device)?;
let cfg = launch_cfg(total)?;
let d_u32 = d as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad_output.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
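/// Single fused Adam step updating `param`, `exp_avg` (first moment) and
/// `exp_avg_sq` (second moment) in place. `bc1` and `bc2` are host-side
/// bias-correction terms, presumably `1 - beta1^t` and `1 - beta2^t` as in
/// the standard update `p -= lr * (m / bc1) / (sqrt(v / bc2) + eps)`; the
/// exact treatment of `weight_decay` is defined by the PTX kernel.
///
/// A hedged usage sketch for step `t = 1`:
///
/// ```ignore
/// let (beta1, beta2) = (0.9f32, 0.999);
/// let (bc1, bc2) = (1.0 - beta1.powi(1), 1.0 - beta2.powi(1));
/// gpu_fused_adam(
///     &mut p, &g, &mut m, &mut v,
///     beta1, beta2, 1e-3, 1e-8, bc1, bc2, 0.0, &device,
/// )?;
/// ```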
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
param: &mut CudaBuffer<f32>,
grad: &CudaBuffer<f32>,
exp_avg: &mut CudaBuffer<f32>,
exp_avg_sq: &mut CudaBuffer<f32>,
beta1: f32,
beta2: f32,
lr: f32,
eps: f32,
bc1: f32,
bc2: f32,
weight_decay: f32,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let n = param.len();
for other in [grad.len(), exp_avg.len(), exp_avg_sq.len()] {
if other != n {
// Report the first buffer whose length disagrees with `param`.
return Err(GpuError::LengthMismatch { a: n, b: other });
}
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
FUSED_ADAM_PTX,
"fused_adam_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "fused_adam_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(param.inner_mut())
.arg(grad.inner())
.arg(exp_avg.inner_mut())
.arg(exp_avg_sq.inner_mut())
.arg(&beta1)
.arg(&beta2)
.arg(&lr)
.arg(&eps)
.arg(&bc1)
.arg(&bc2)
.arg(&weight_decay)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_fused_adam(
_param: &mut CudaBuffer<f32>,
_grad: &CudaBuffer<f32>,
_exp_avg: &mut CudaBuffer<f32>,
_exp_avg_sq: &mut CudaBuffer<f32>,
_beta1: f32,
_beta2: f32,
_lr: f32,
_eps: f32,
_bc1: f32,
_bc2: f32,
_weight_decay: f32,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
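/// Fused pointwise stage of a GRU cell: takes the precomputed input-side and
/// hidden-side gate pre-activations (each presumably `[batch, 3 * hsz]` for
/// the reset/update/new gates), adds the biases, applies the gate
/// nonlinearities, and produces the new hidden state `hy` along with a
/// `[batch, 5 * hsz]` workspace of intermediates kept for the backward pass
/// (its exact layout is defined by the PTX kernel).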
#[cfg(feature = "cuda")]
pub fn gpu_fused_gru_forward(
input_gates: &CudaBuffer<f32>,
hidden_gates: &CudaBuffer<f32>,
bias_ih: &CudaBuffer<f32>,
bias_hh: &CudaBuffer<f32>,
hx: &CudaBuffer<f32>,
hsz: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
let total = hx.len();
let batch = total / hsz;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
FUSED_GRU_FORWARD_PTX,
"fused_gru_forward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "fused_gru_forward_kernel",
source: e,
});
}
};
let mut hy = alloc_zeros_f32(total, device)?;
let mut workspace = alloc_zeros_f32(batch * 5 * hsz, device)?;
let cfg = launch_cfg(total)?;
let hsz_u32 = hsz as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input_gates.inner())
.arg(hidden_gates.inner())
.arg(bias_ih.inner())
.arg(bias_hh.inner())
.arg(hx.inner())
.arg(hy.inner_mut())
.arg(workspace.inner_mut())
.arg(&hsz_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok((hy, workspace))
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_fused_gru_forward(
_input_gates: &CudaBuffer<f32>,
_hidden_gates: &CudaBuffer<f32>,
_bias_ih: &CudaBuffer<f32>,
_bias_hh: &CudaBuffer<f32>,
_hx: &CudaBuffer<f32>,
_hsz: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
Err(GpuError::NoCudaFeature)
}
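/// 2D max pooling over an NCHW-ordered input (implied by the argument
/// order), returning the pooled buffer and its
/// `[batch, channels, h_out, w_out]` shape with
/// `h_out = (h_in + 2*ph - kh) / sh + 1` (and likewise for width). The
/// subtraction is unchecked, so callers must ensure the window fits:
/// `kh <= h_in + 2*ph` and `kw <= w_in + 2*pw`.
///
/// ```ignore
/// // 2x2 window, stride 2, no padding over a 1x1x4x4 input:
/// let (out, shape) = gpu_maxpool2d(&x, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0, &device)?;
/// assert_eq!(shape, [1, 1, 2, 2]);
/// ```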
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
input: &CudaBuffer<f32>,
batch: usize,
channels: usize,
h_in: usize,
w_in: usize,
kh: usize,
kw: usize,
sh: usize,
sw: usize,
ph: usize,
pw: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
use cudarc::driver::PushKernelArg;
let h_out = (h_in + 2 * ph - kh) / sh + 1;
let w_out = (w_in + 2 * pw - kw) / sw + 1;
let total = batch * channels * h_out * w_out;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
MAXPOOL2D_PTX,
"maxpool2d_forward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "maxpool2d_forward_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let (batch_u32, ch_u32) = (batch as u32, channels as u32);
let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
let (kh_u32, kw_u32) = (kh as u32, kw as u32);
let (sh_u32, sw_u32) = (sh as u32, sw as u32);
let (ph_u32, pw_u32) = (ph as u32, pw as u32);
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&batch_u32)
.arg(&ch_u32)
.arg(&h_in_u32)
.arg(&w_in_u32)
.arg(&h_out_u32)
.arg(&w_out_u32)
.arg(&kh_u32)
.arg(&kw_u32)
.arg(&sh_u32)
.arg(&sw_u32)
.arg(&ph_u32)
.arg(&pw_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok((out, [batch, channels, h_out, w_out]))
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_maxpool2d(
_input: &CudaBuffer<f32>,
_batch: usize,
_channels: usize,
_h_in: usize,
_w_in: usize,
_kh: usize,
_kw: usize,
_sh: usize,
_sw: usize,
_ph: usize,
_pw: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
Err(GpuError::NoCudaFeature)
}
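/// 2D average pooling over an NCHW tensor; same output-shape arithmetic and
/// caller-side preconditions as `gpu_maxpool2d` above.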
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
input: &CudaBuffer<f32>,
batch: usize,
channels: usize,
h_in: usize,
w_in: usize,
kh: usize,
kw: usize,
sh: usize,
sw: usize,
ph: usize,
pw: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
use cudarc::driver::PushKernelArg;
let h_out = (h_in + 2 * ph - kh) / sh + 1;
let w_out = (w_in + 2 * pw - kw) / sw + 1;
let total = batch * channels * h_out * w_out;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
AVGPOOL2D_PTX,
"avgpool2d_forward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "avgpool2d_forward_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let (batch_u32, ch_u32) = (batch as u32, channels as u32);
let (h_in_u32, w_in_u32) = (h_in as u32, w_in as u32);
let (h_out_u32, w_out_u32) = (h_out as u32, w_out as u32);
let (kh_u32, kw_u32) = (kh as u32, kw as u32);
let (sh_u32, sw_u32) = (sh as u32, sw as u32);
let (ph_u32, pw_u32) = (ph as u32, pw as u32);
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&batch_u32)
.arg(&ch_u32)
.arg(&h_in_u32)
.arg(&w_in_u32)
.arg(&h_out_u32)
.arg(&w_out_u32)
.arg(&kh_u32)
.arg(&kw_u32)
.arg(&sh_u32)
.arg(&sw_u32)
.arg(&ph_u32)
.arg(&pw_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok((out, [batch, channels, h_out, w_out]))
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_avgpool2d(
_input: &CudaBuffer<f32>,
_batch: usize,
_channels: usize,
_h_in: usize,
_w_in: usize,
_kh: usize,
_kw: usize,
_sh: usize,
_sw: usize,
_ph: usize,
_pw: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, [usize; 4])> {
Err(GpuError::NoCudaFeature)
}
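/// BatchNorm forward with one block per channel (up to 256 threads). The
/// shared-memory size below (`256 * 4 * 2`) suggests two 256-float reduction
/// buffers, presumably one for the mean and one for the variance. In training
/// mode the kernel also updates `running_mean`/`running_var` using `momentum`;
/// returns `(output, save_mean, save_invstd)` for the backward pass.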
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
input: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
bias: &CudaBuffer<f32>,
running_mean: &mut CudaBuffer<f32>,
running_var: &mut CudaBuffer<f32>,
channels: usize,
spatial: usize,
eps: f32,
momentum: f32,
training: bool,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
BATCHNORM_FORWARD_PTX,
"batchnorm_forward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "batchnorm_forward_kernel",
source: e,
});
}
};
    let total = input.len();
    let total_per_ch = total / channels;
let mut output = alloc_zeros_f32(total, device)?;
let mut save_mean = alloc_zeros_f32(channels, device)?;
let mut save_invstd = alloc_zeros_f32(channels, device)?;
let threads = (total_per_ch as u32).clamp(1, 256);
let cfg = LaunchConfig {
grid_dim: (channels as u32, 1, 1),
block_dim: (threads, 1, 1),
shared_mem_bytes: 256 * 4 * 2,
};
let channels_u32 = channels as u32;
let spatial_u32 = spatial as u32;
let total_per_ch_u32 = total_per_ch as u32;
let training_u32 = u32::from(training);
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(output.inner_mut())
.arg(weight.inner())
.arg(bias.inner())
.arg(running_mean.inner_mut())
.arg(running_var.inner_mut())
.arg(save_mean.inner_mut())
.arg(save_invstd.inner_mut())
.arg(&channels_u32)
.arg(&spatial_u32)
.arg(&eps)
.arg(&momentum)
.arg(&total_per_ch_u32)
.arg(&training_u32)
.launch(cfg)?;
}
Ok((output, save_mean, save_invstd))
}
#[cfg(not(feature = "cuda"))]
#[allow(clippy::too_many_arguments)]
pub fn gpu_batchnorm_forward(
_input: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_bias: &CudaBuffer<f32>,
_running_mean: &mut CudaBuffer<f32>,
_running_var: &mut CudaBuffer<f32>,
_channels: usize,
_spatial: usize,
_eps: f32,
_momentum: f32,
_training: bool,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
Err(GpuError::NoCudaFeature)
}
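/// Row-wise LayerNorm: one block of 256 threads per row, with a 256-float
/// shared-memory buffer for the reduction.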
#[cfg(feature = "cuda")]
pub fn gpu_layernorm(
input: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
bias: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LAYERNORM_PTX,
"layernorm_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "layernorm_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(weight.inner())
.arg(bias.inner())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward(
input: &CudaBuffer<f32>,
grad_output: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LAYERNORM_BACKWARD_PTX,
"layernorm_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "layernorm_backward_kernel",
source: e,
});
}
};
let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
let mut grad_w = alloc_zeros_f32(cols, device)?;
let mut grad_b = alloc_zeros_f32(cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(weight.inner())
.arg(grad_in.inner_mut())
.arg(grad_w.inner_mut())
.arg(grad_b.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok((grad_in, grad_w, grad_b))
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward(
_input: &CudaBuffer<f32>,
_grad_output: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_rows: usize,
_cols: usize,
_eps: f32,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>, CudaBuffer<f32>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm(
input: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
RMSNORM_PTX,
"rmsnorm_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "rmsnorm_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(weight.inner())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward(
input: &CudaBuffer<f32>,
grad_output: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
use cudarc::driver::PushKernelArg;
validate_unary(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
RMSNORM_BACKWARD_PTX,
"rmsnorm_backward_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "rmsnorm_backward_kernel",
source: e,
});
}
};
let mut grad_in = alloc_zeros_f32(rows * cols, device)?;
let mut grad_w = alloc_zeros_f32(cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(weight.inner())
.arg(grad_in.inner_mut())
.arg(grad_w.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok((grad_in, grad_w))
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm(
_input: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_rows: usize,
_cols: usize,
_eps: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_backward(
_input: &CudaBuffer<f32>,
_grad_output: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_rows: usize,
_cols: usize,
_eps: f32,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_add_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
validate_binary(a, b, device)?;
if out.len() < a.len() {
return Err(GpuError::ShapeMismatch {
op: "add_into",
expected: vec![a.len()],
got: vec![out.len()],
});
}
try_launch_binary_into(a, b, out, device, ADD_PTX, "add_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_add_into_on_stream(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
stream: &std::sync::Arc<cudarc::driver::CudaStream>,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
validate_binary(a, b, device)?;
if out.len() < a.len() {
return Err(GpuError::ShapeMismatch {
op: "add_into_on_stream",
expected: vec![a.len()],
got: vec![out.len()],
});
}
let n = a.len();
let ctx = device.context();
let f = match crate::module_cache::get_or_compile(
ctx,
ADD_PTX,
"add_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "add_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_mul_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
validate_binary(a, b, device)?;
if out.len() < a.len() {
return Err(GpuError::ShapeMismatch {
op: "mul_into",
expected: vec![a.len()],
got: vec![out.len()],
});
}
try_launch_binary_into(a, b, out, device, MUL_PTX, "mul_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_scale_into(
a: &CudaBuffer<f32>,
scalar: f32,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
validate_unary(a, device)?;
let n = a.len();
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
SCALE_PTX,
"scale_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "scale_kernel",
source: e,
})?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&scalar)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_fill_f32(n: usize, scalar: f32, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
FILL_F32_PTX,
"fill_f32_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "fill_f32_kernel",
source: e,
})?;
let mut out = alloc_zeros_f32(n, device)?;
if n == 0 {
return Ok(out);
}
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(out.inner_mut())
.arg(&scalar)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
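/// f64 fill. Like the other `_f64` launchers below, the PTX is derived from
/// the f32 kernel source by textual rewrite (`get_f64_ptx`) and cached in a
/// process-wide `OnceLock`, so the rewrite runs at most once.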
#[cfg(feature = "cuda")]
pub fn gpu_fill_f64(n: usize, scalar: f64, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, FILL_F32_PTX, "fill_f32_kernel", "fill_f64_kernel");
let f =
crate::module_cache::get_or_compile(ctx, ptx, "fill_f64_kernel", device.ordinal() as u32)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "fill_f64_kernel",
source: e,
})?;
let mut out = alloc_zeros_f64(n, device)?;
if n == 0 {
return Ok(out);
}
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(out.inner_mut())
.arg(&scalar)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
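/// Device-side finiteness check: the kernel sets a single `u32` flag if any
/// element is Inf or NaN, and only that flag is copied back to the host. If
/// PTX compilation fails and `FERROTORCH_ENABLE_GPU_FALLBACK` is set, the
/// whole buffer is read back and checked on the CPU instead.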
#[cfg(feature = "cuda")]
pub fn gpu_has_inf_nan(a: &CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<bool> {
use cudarc::driver::PushKernelArg;
let n = a.len();
if n == 0 {
return Ok(false);
}
validate_unary(a, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
HAS_INF_NAN_F32_PTX,
"has_inf_nan_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "has_inf_nan_f32_kernel",
error = %e,
"PTX compile failed; falling back to host-readback. \
Unset FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let host: Vec<f32> = crate::transfer::gpu_to_cpu(a, device)?;
return Ok(host.iter().any(|v| !v.is_finite()));
}
return Err(GpuError::PtxCompileFailed {
kernel: "has_inf_nan_f32_kernel",
source: e,
});
}
};
let mut flag = crate::transfer::alloc_zeros::<u32>(1, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(&n_u32)
.arg(flag.inner_mut())
.launch(cfg)?;
}
let host: Vec<u32> = crate::transfer::gpu_to_cpu(&flag, device)?;
Ok(host[0] != 0)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_has_inf_nan(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<bool> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_into(
a: &CudaBuffer<f32>,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
validate_unary(a, device)?;
try_launch_unary_into(a, out, device, GELU_PTX, "gelu_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_into(
idx: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
d: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
EMBED_LOOKUP_PTX,
"embed_lookup_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "embed_lookup_kernel",
source: e,
})?;
let cfg = launch_cfg(d)?;
let d_u32 = d as u32;
unsafe {
stream
.launch_builder(&f)
.arg(idx.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_embed_lookup_batch(
indices: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
n: usize,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total = n * d;
if total == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
EMBED_LOOKUP_BATCH_PTX,
"embed_lookup_batch_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "embed_lookup_batch_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total, device)?;
let cfg = launch_cfg(total)?;
let d_u32 = d as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(indices.inner())
.arg(weight.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
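/// Embedding backward: accumulates rows of `grad_output` into a fresh
/// `num_embeddings x d` gradient buffer at the positions given by `indices`
/// (stored as f32, so index values must be exactly representable). Duplicate
/// indices presumably accumulate via atomic adds in the kernel.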
#[cfg(feature = "cuda")]
pub fn gpu_scatter_add_rows(
grad_output: &CudaBuffer<f32>,
indices: &CudaBuffer<f32>,
num_embeddings: usize,
d: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let n = indices.len();
let total = n * d;
if total == 0 {
return alloc_zeros_f32(num_embeddings * d, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SCATTER_ADD_ROWS_PTX,
"scatter_add_rows_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "scatter_add_rows_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(num_embeddings * d, device)?;
let cfg = launch_cfg(total)?;
let d_u32 = d as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(grad_output.inner())
.arg(indices.inner())
.arg(out.inner_mut())
.arg(&d_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_transpose_2d_into(
a: &CudaBuffer<f32>,
m: usize,
n: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
TRANSPOSE_2D_PTX,
"transpose_2d_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "transpose_2d_kernel",
source: e,
})?;
let cfg = launch_cfg(total)?;
let m_u32 = m as u32;
let n_u32 = n as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&m_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(())
}
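/// Permutes a `[d0, d1, d2, d3]` tensor to `[d0, d2, d1, d3]` (axes 1 and 2
/// swapped), writing into a preallocated buffer; this is the layout shuffle
/// commonly needed between `[batch, seq, heads, head_dim]` and
/// `[batch, heads, seq, head_dim]` in attention.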
#[cfg(feature = "cuda")]
pub fn gpu_permute_0213_into(
a: &CudaBuffer<f32>,
d0: usize,
d1: usize,
d2: usize,
d3: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = d0 * d1 * d2 * d3;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
PERMUTE_0213_PTX,
"permute_0213_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "permute_0213_kernel",
source: e,
})?;
let cfg = launch_cfg(total)?;
let (d0u, d1u, d2u, d3u, tu) = (d0 as u32, d1 as u32, d2 as u32, d3 as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&d0u)
.arg(&d1u)
.arg(&d2u)
.arg(&d3u)
.arg(&tu)
.launch(cfg)?;
}
Ok(())
}
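/// Row-wise softmax into a preallocated buffer: one 256-thread block per row.
/// Shared memory is sized `cols * 4` bytes, so `cols` is bounded by the
/// device's per-block shared-memory limit.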
#[cfg(feature = "cuda")]
pub fn gpu_softmax_into(
a: &CudaBuffer<f32>,
rows: usize,
cols: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
SOFTMAX_PTX,
"softmax_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "softmax_kernel",
source: e,
})?;
let block_size = 256u32;
    let grid_size = (rows as u32).max(1);
let cfg = LaunchConfig {
grid_dim: (grid_size, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: (cols as u32) * 4,
};
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
#[allow(clippy::too_many_arguments)]
pub fn gpu_layernorm_into(
input: &CudaBuffer<f32>,
weight: &CudaBuffer<f32>,
bias: &CudaBuffer<f32>,
rows: usize,
cols: usize,
eps: f32,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
LAYERNORM_PTX,
"layernorm_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "layernorm_kernel",
source: e,
})?;
let block_size = 256u32;
    let grid_size = (rows as u32).max(1);
let cfg = LaunchConfig {
grid_dim: (grid_size, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: (cols as u32) * 4,
};
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(weight.inner())
.arg(bias.inner())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_slice_read_into(
src: &CudaBuffer<f32>,
n_batch: usize,
d: usize,
len: usize,
max_len: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = n_batch * len * d;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
SLICE_READ_PTX,
"slice_read_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "slice_read_kernel",
source: e,
})?;
let cfg = launch_cfg(total)?;
let total_u32 = total as u32;
let d_u32 = d as u32;
let len_u32 = len as u32;
let max_len_u32 = max_len as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(out.inner_mut())
.arg(&total_u32)
.arg(&d_u32)
.arg(&len_u32)
.arg(&max_len_u32)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_small_matmul_into(
a: &CudaBuffer<f32>,
b: &CudaBuffer<f32>,
m: usize,
k: usize,
n: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = m * n;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
SMALL_MATMUL_PTX,
"small_matmul_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "small_matmul_kernel",
source: e,
})?;
let cfg = launch_cfg(total)?;
let (m_u32, k_u32, n_u32, total_u32) = (m as u32, k as u32, n as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(a.inner())
.arg(b.inner())
.arg(out.inner_mut())
.arg(&m_u32)
.arg(&k_u32)
.arg(&n_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(())
}
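/// Slice write where the destination position is read from device memory
/// (`pos_ptr`) rather than passed from the host, so a decode loop can advance
/// the position on-device without a synchronizing readback (e.g. for a KV
/// cache).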
#[cfg(feature = "cuda")]
pub fn gpu_slice_write_indirect(
src: &CudaBuffer<f32>,
dst: &mut CudaBuffer<f32>,
n_batch: usize,
d: usize,
max_len: usize,
pos_ptr: &cudarc::driver::CudaSlice<u32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = n_batch * d;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
SLICE_WRITE_INDIRECT_PTX,
"slice_write_indirect_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "slice_write_indirect_kernel",
source: e,
})?;
let cfg = launch_cfg(total)?;
let n_u32 = total as u32;
let d_u32 = d as u32;
let max_len_u32 = max_len as u32;
unsafe {
stream
.launch_builder(&f)
.arg(src.inner())
.arg(dst.inner_mut())
.arg(&n_u32)
.arg(&d_u32)
.arg(&max_len_u32)
.arg(pos_ptr)
.launch(cfg)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn gpu_causal_mask_indirect(
total_len_ptr: &cudarc::driver::CudaSlice<u32>,
n_head: usize,
max_pos: usize,
out: &mut CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<()> {
use cudarc::driver::PushKernelArg;
let total = n_head * max_pos;
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
CAUSAL_MASK_INDIRECT_PTX,
"causal_mask_indirect_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "causal_mask_indirect_kernel",
source: e,
})?;
let cfg = launch_cfg(total)?;
let max_pos_u32 = max_pos as u32;
let total_u32 = total as u32;
unsafe {
stream
.launch_builder(&f)
.arg(total_len_ptr)
.arg(out.inner_mut())
.arg(&max_pos_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(())
}
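/// Eagerly compiles every kernel used on the decode path so the first decode
/// step does not pay JIT latency; results land in the shared module cache,
/// making later launches cache hits.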
#[cfg(feature = "cuda")]
pub fn precompile_decode_kernels(device: &GpuDevice) -> GpuResult<()> {
let ctx = device.context();
ctx.bind_to_thread()?;
let ord = device.ordinal() as u32;
    let compile = |ptx: &'static str, name: &'static str| -> GpuResult<()> {
        crate::module_cache::get_or_compile(ctx, ptx, name, ord)
            .map(|_| ())
            .map_err(|e| GpuError::PtxCompileFailed { kernel: name, source: e })
    };
compile(ADD_PTX, "add_kernel")?;
compile(MUL_PTX, "mul_kernel")?;
compile(SCALE_PTX, "scale_kernel")?;
compile(GELU_PTX, "gelu_kernel")?;
compile(SOFTMAX_PTX, "softmax_kernel")?;
compile(LAYERNORM_PTX, "layernorm_kernel")?;
compile(PERMUTE_0213_PTX, "permute_0213_kernel")?;
compile(EMBED_LOOKUP_PTX, "embed_lookup_kernel")?;
compile(EMBED_LOOKUP_BATCH_PTX, "embed_lookup_batch_kernel")?;
compile(SCATTER_ADD_ROWS_PTX, "scatter_add_rows_kernel")?;
compile(SMALL_MATMUL_PTX, "small_matmul_kernel")?;
compile(SLICE_WRITE_INDIRECT_PTX, "slice_write_indirect_kernel")?;
compile(CAUSAL_MASK_INDIRECT_PTX, "causal_mask_indirect_kernel")?;
compile(SLICE_READ_PTX, "slice_read_kernel")?;
compile(RELU_BACKWARD_PTX, "relu_backward_kernel")?;
compile(GELU_BACKWARD_PTX, "gelu_backward_kernel")?;
Ok(())
}
#[cfg(not(feature = "cuda"))]
pub fn precompile_decode_kernels(_device: &GpuDevice) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu(
_input: &CudaBuffer<f32>,
_alpha: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_alpha: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp(
_input: &CudaBuffer<f32>,
_min_val: f32,
_max_val: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_div(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_exp(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_log(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_pow(
_a: &CudaBuffer<f32>,
_exponent: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm(
_input: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_bias: &CudaBuffer<f32>,
_rows: usize,
_cols: usize,
_eps: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d(
_input: &CudaBuffer<f32>,
_m: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_add(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu(_a: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale(
_a: &CudaBuffer<f32>,
_scalar: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul(
_a: &CudaBuffer<f32>,
_b: &CudaBuffer<f32>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax(
_input: &CudaBuffer<f32>,
_rows: usize,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_dropout(
_input: &CudaBuffer<f32>,
_threshold: u32,
_scale: f32,
_seed: u32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213(
_input: &CudaBuffer<f32>,
_d0: usize,
_d1: usize,
_d2: usize,
_d3: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write(
_src: &CudaBuffer<f32>,
_dst: &mut CudaBuffer<f32>,
_n_batch: usize,
_d: usize,
_max_len: usize,
_pos: usize,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read(
_src: &CudaBuffer<f32>,
_n_batch: usize,
_d: usize,
_len: usize,
_max_len: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup(
_idx: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_d: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch(
_indices: &CudaBuffer<f32>,
_weight: &CudaBuffer<f32>,
_n: usize,
_d: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows(
_grad_output: &CudaBuffer<f32>,
_indices: &CudaBuffer<f32>,
_num_embeddings: usize,
_d: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_fill_f32(_n: usize, _scalar: f32, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_fill_f64(_n: usize, _scalar: f64, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward(
_grad: &CudaBuffer<f32>,
_input: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d(
_input: &CudaBuffer<f32>,
_indices: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d(
_grad_output: &CudaBuffer<f32>,
_indices: &CudaBuffer<f32>,
_input_len: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_dim(
_input: &CudaBuffer<f32>,
_indices: &CudaBuffer<f32>,
_outer: usize,
_in_dim_size: usize,
_out_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_dim_f64(
_input: &CudaBuffer<f64>,
_indices: &CudaBuffer<f32>,
_outer: usize,
_in_dim_size: usize,
_out_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill(
_input: &CudaBuffer<f32>,
_mask: &CudaBuffer<f32>,
_value: f32,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero(
_grad: &CudaBuffer<f32>,
_mask: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward(
_grad: &CudaBuffer<f32>,
_output: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward(
_grad: &CudaBuffer<f32>,
_output: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward(
_grad: &CudaBuffer<f32>,
_output: &CudaBuffer<f32>,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax(
_input: &CudaBuffer<f32>,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward(
_grad: &CudaBuffer<f32>,
_output: &CudaBuffer<f32>,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis(
_a: &CudaBuffer<f32>,
_outer: usize,
_axis_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum(
_input: &CudaBuffer<f32>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod(
_input: &CudaBuffer<f32>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax(
_input: &CudaBuffer<f32>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin(
_input: &CudaBuffer<f32>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, CudaBuffer<f32>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp(
_input: &CudaBuffer<f32>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split(
_input: &CudaBuffer<f32>,
_total_along_axis: usize,
_split_offset: usize,
_split_size: usize,
_inner_size: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat(
_input: &CudaBuffer<f32>,
_output: &mut CudaBuffer<f32>,
_total_along_axis: usize,
_cat_offset: usize,
_part_size: usize,
_inner_size: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub const STRIDED_COPY_MAX_DIMS: usize = 8;
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_copy(
_input: &CudaBuffer<f32>,
_out_shape: &[usize],
_src_strides: &[isize],
_src_offset: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_copy_f64(
_input: &CudaBuffer<f64>,
_out_shape: &[usize],
_src_strides: &[isize],
_src_offset: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_f16(
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
use cudarc::driver::PushKernelArg;
let n = input.len();
if n == 0 {
let empty = device.stream().alloc_zeros::<u16>(0)?;
return Ok(empty);
}
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
F32_TO_F16_PTX,
"f32_to_f16_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "f32_to_f16_kernel",
source: e,
})?;
let mut out = stream.alloc_zeros::<u16>(n)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(&mut out)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
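/// Without the `cuda` feature the `cudarc` types are unavailable, so this
/// stub cannot mirror the real return type and returns `GpuResult<()>`.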
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_f16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub(crate) fn gpu_f32_to_bf16(
input: &CudaBuffer<f32>,
device: &GpuDevice,
) -> GpuResult<cudarc::driver::CudaSlice<u16>> {
use cudarc::driver::PushKernelArg;
let n = input.len();
if n == 0 {
let empty = device.stream().alloc_zeros::<u16>(0)?;
return Ok(empty);
}
let ctx = device.context();
let stream = device.stream();
let f = crate::module_cache::get_or_compile(
ctx,
F32_TO_BF16_PTX,
"f32_to_bf16_kernel",
device.ordinal() as u32,
)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "f32_to_bf16_kernel",
source: e,
})?;
let mut out = stream.alloc_zeros::<u16>(n)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(&mut out)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub(crate) fn gpu_f32_to_bf16(_input: &CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_add_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_div_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_neg_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_scale_f64(
_a: &CudaBuffer<f64>,
_scalar: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_exp_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sqrt_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_pow_f64(
_a: &CudaBuffer<f64>,
_exponent: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_f64(_a: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_abs_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_backward_f64(
_grad: &CudaBuffer<f64>,
_output: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_tanh_backward_f64(
_grad: &CudaBuffer<f64>,
_output: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_add_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_sub_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_mul_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_broadcast_div_f64(
_a: &CudaBuffer<f64>,
_b: &CudaBuffer<f64>,
_a_shape: &[usize],
_b_shape: &[usize],
_out_shape: &[usize],
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_2d_f64(
_input: &CudaBuffer<f64>,
_m: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_permute_0213_f64(
_input: &CudaBuffer<f64>,
_d0: usize,
_d1: usize,
_d2: usize,
_d3: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_complex_f32(
_input: &CudaBuffer<f32>,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_transpose_complex_f64(
_input: &CudaBuffer<f64>,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_split_f64(
_input: &CudaBuffer<f64>,
_total_along_axis: usize,
_split_offset: usize,
_split_size: usize,
_inner_size: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_strided_cat_f64(
_input: &CudaBuffer<f64>,
_output: &mut CudaBuffer<f64>,
_total_along_axis: usize,
_cat_offset: usize,
_part_size: usize,
_inner_size: usize,
_n: usize,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_index_select_1d_f64(
_input: &CudaBuffer<f64>,
_indices: &CudaBuffer<f32>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_1d_f64(
_grad_output: &CudaBuffer<f64>,
_indices: &CudaBuffer<f32>,
_input_len: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_fill_f64(
_input: &CudaBuffer<f64>,
_mask: &CudaBuffer<u8>,
_value: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_masked_zero_f64(
_grad: &CudaBuffer<f64>,
_mask: &CudaBuffer<u8>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_write_f64(
_src: &CudaBuffer<f64>,
_dst: &mut CudaBuffer<f64>,
_n_batch: usize,
_d: usize,
_max_len: usize,
_pos: usize,
_device: &GpuDevice,
) -> GpuResult<()> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_slice_read_f64(
_src: &CudaBuffer<f64>,
_n_batch: usize,
_d: usize,
_len: usize,
_max_len: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_f64(
_idx: &CudaBuffer<f32>,
_weight: &CudaBuffer<f64>,
_d: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_embed_lookup_batch_f64(
_indices: &CudaBuffer<f32>,
_weight: &CudaBuffer<f64>,
_n: usize,
_d: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_scatter_add_rows_f64(
_grad_output: &CudaBuffer<f64>,
_indices: &CudaBuffer<f32>,
_num_embeddings: usize,
_d: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(input, device, GELU_F64_PTX, "gelu_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_tanh_f64(
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(input, device, GELU_TANH_F64_PTX, "gelu_tanh_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_erf_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(input, device, GELU_ERF_F64_PTX, "gelu_erf_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
try_launch_binary_f64(
grad,
input,
device,
GELU_BACKWARD_F64_PTX,
"gelu_backward_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_tanh_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
try_launch_binary_f64(
grad,
input,
device,
GELU_BACKWARD_TANH_F64_PTX,
"gelu_backward_tanh_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_gelu_backward_erf_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
try_launch_binary_f64(
grad,
input,
device,
GELU_BACKWARD_ERF_F64_PTX,
"gelu_backward_erf_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_silu_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(input, device, SILU_F64_PTX, "silu_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_silu_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
try_launch_binary_f64(
grad,
input,
device,
SILU_BACKWARD_F64_PTX,
"silu_backward_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_elu_f64(
input: &CudaBuffer<f64>,
alpha: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let n = input.len();
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
match crate::module_cache::get_or_compile(
ctx,
ELU_F64_PTX,
"elu_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => {
let mut out = alloc_zeros_f64(n, device)?;
let n_u32 = n as u32;
let cfg = launch_cfg(n)?;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&alpha)
.launch(cfg)?;
}
Ok(out)
}
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "elu_f64_kernel",
error = %e,
"PTX compile failed; falling back to CPU. Unset \
FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let host = gpu_to_cpu(input, device)?;
let result: Vec<f64> = host
.iter()
.map(|&x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
.collect();
return cpu_to_gpu(&result, device);
}
Err(GpuError::PtxCompileFailed {
kernel: "elu_f64_kernel",
source: e,
})
}
}
}
#[cfg(feature = "cuda")]
pub fn gpu_elu_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
alpha: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
let n = grad.len();
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
match crate::module_cache::get_or_compile(
ctx,
ELU_BACKWARD_F64_PTX,
"elu_backward_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => {
let mut out = alloc_zeros_f64(n, device)?;
let n_u32 = n as u32;
let cfg = launch_cfg(n)?;
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&alpha)
.launch(cfg)?;
}
Ok(out)
}
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "elu_backward_f64_kernel",
error = %e,
"PTX compile failed; falling back to CPU. Unset \
FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let g_host = gpu_to_cpu(grad, device)?;
let x_host = gpu_to_cpu(input, device)?;
let result: Vec<f64> = g_host
.iter()
.zip(x_host.iter())
.map(|(&g, &x)| if x > 0.0 { g } else { g * alpha * x.exp() })
.collect();
return cpu_to_gpu(&result, device);
}
Err(GpuError::PtxCompileFailed {
kernel: "elu_backward_f64_kernel",
source: e,
})
}
}
}
#[cfg(feature = "cuda")]
pub fn gpu_mish_f64(input: &CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
try_launch_unary_f64(input, device, MISH_F64_PTX, "mish_f64_kernel")
}
#[cfg(feature = "cuda")]
pub fn gpu_mish_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
try_launch_binary_f64(
grad,
input,
device,
MISH_BACKWARD_F64_PTX,
"mish_backward_f64_kernel",
)
}
#[cfg(feature = "cuda")]
pub fn gpu_clamp_f64(
input: &CudaBuffer<f64>,
min_val: f64,
max_val: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let n = input.len();
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, CLAMP_PTX, "clamp_kernel", "clamp_f64_kernel");
match crate::module_cache::get_or_compile(ctx, ptx, "clamp_f64_kernel", device.ordinal() as u32)
{
Ok(f) => {
let mut out = alloc_zeros_f64(n, device)?;
let n_u32 = n as u32;
let cfg = launch_cfg(n)?;
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&n_u32)
.arg(&min_val)
.arg(&max_val)
.launch(cfg)?;
}
Ok(out)
}
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "clamp_f64_kernel",
error = %e,
"PTX compile failed; falling back to CPU. Unset \
FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let host = gpu_to_cpu(input, device)?;
let result: Vec<f64> = host.iter().map(|&x| x.max(min_val).min(max_val)).collect();
return cpu_to_gpu(&result, device);
}
Err(GpuError::PtxCompileFailed {
kernel: "clamp_f64_kernel",
source: e,
})
}
}
}
#[cfg(feature = "cuda")]
pub fn gpu_clamp_backward_f64(
grad: &CudaBuffer<f64>,
input: &CudaBuffer<f64>,
min_val: f64,
max_val: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
if grad.len() != input.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: input.len(),
});
}
let n = input.len();
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
CLAMP_BACKWARD_PTX,
"clamp_backward_kernel",
"clamp_backward_f64_kernel",
);
match crate::module_cache::get_or_compile(
ctx,
ptx,
"clamp_backward_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => {
let mut out = alloc_zeros_f64(n, device)?;
let n_u32 = n as u32;
let cfg = launch_cfg(n)?;
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(input.inner())
.arg(out.inner_mut())
.arg(&min_val)
.arg(&max_val)
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
Err(e) => {
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
kernel = "clamp_backward_f64_kernel",
error = %e,
"PTX compile failed; falling back to CPU. Unset \
FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
let g = gpu_to_cpu(grad, device)?;
let x = gpu_to_cpu(input, device)?;
let out: Vec<f64> = g
.iter()
.zip(x.iter())
.map(|(&gi, &xi)| {
if xi >= min_val && xi <= max_val {
gi
} else {
0.0
}
})
.collect();
return cpu_to_gpu(&out, device);
}
Err(GpuError::PtxCompileFailed {
kernel: "clamp_backward_f64_kernel",
source: e,
})
}
}
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_min: f64,
_max: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_cumsum_f64(
input: &CudaBuffer<f64>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total = outer * inner;
let n = outer * dim_size * inner;
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, CUMSUM_PTX, "cumsum_kernel", "cumsum_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"cumsum_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "cumsum_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(total)?;
let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&o)
.arg(&d)
.arg(&i)
.arg(&t)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_cumprod_f64(
input: &CudaBuffer<f64>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let total = outer * inner;
let n = outer * dim_size * inner;
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, CUMPROD_PTX, "cumprod_kernel", "cumprod_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"cumprod_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "cumprod_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(total)?;
let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&o)
.arg(&d)
.arg(&i)
.arg(&t)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_cummax_f64(
input: &CudaBuffer<f64>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
use cudarc::driver::PushKernelArg;
let total = outer * inner;
let n = outer * dim_size * inner;
if n == 0 {
let e: &[f64] = &[];
return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
}
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, CUMMAX_PTX, "cummax_kernel", "cummax_f64_kernel");
let f =
crate::module_cache::get_or_compile(ctx, ptx, "cummax_f64_kernel", device.ordinal() as u32)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "cummax_kernel",
source: e,
})?;
let mut out = alloc_zeros_f64(n, device)?;
let mut ind = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(total)?;
let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(ind.inner_mut())
.arg(&o)
.arg(&d)
.arg(&i)
.arg(&t)
.launch(cfg)?;
}
Ok((out, ind))
}
#[cfg(feature = "cuda")]
pub fn gpu_cummin_f64(
input: &CudaBuffer<f64>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
use cudarc::driver::PushKernelArg;
let total = outer * inner;
let n = outer * dim_size * inner;
if n == 0 {
let e: &[f64] = &[];
return Ok((cpu_to_gpu(e, device)?, cpu_to_gpu(e, device)?));
}
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, CUMMIN_PTX, "cummin_kernel", "cummin_f64_kernel");
let f =
crate::module_cache::get_or_compile(ctx, ptx, "cummin_f64_kernel", device.ordinal() as u32)
.map_err(|e| GpuError::PtxCompileFailed {
kernel: "cummin_kernel",
source: e,
})?;
let mut out = alloc_zeros_f64(n, device)?;
let mut ind = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(total)?;
let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(ind.inner_mut())
.arg(&o)
.arg(&d)
.arg(&i)
.arg(&t)
.launch(cfg)?;
}
Ok((out, ind))
}
#[cfg(feature = "cuda")]
pub fn gpu_logcumsumexp_f64(
input: &CudaBuffer<f64>,
outer: usize,
dim_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
let total = outer * inner;
let n = outer * dim_size * inner;
if n == 0 {
return cpu_to_gpu(&[], device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LOGCUMSUMEXP_F64_PTX,
"logcumsumexp_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "logcumsumexp_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(n, device)?;
let cfg = launch_cfg(total)?;
let (o, d, i, t) = (outer as u32, dim_size as u32, inner as u32, total as u32);
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&o)
.arg(&d)
.arg(&i)
.arg(&t)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_softmax_f64(
input: &CudaBuffer<f64>,
rows: usize,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SOFTMAX_F64_PTX,
"softmax_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "softmax_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_softmax_bf16_f32(
input: &cudarc::driver::CudaSlice<u16>,
rows: usize,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if rows == 0 || cols == 0 {
return alloc_zeros_f32(rows * cols, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SOFTMAX_BF16_F32_PTX,
"softmax_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "softmax_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 4,
};
unsafe {
stream
.launch_builder(&f)
.arg(input)
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_bf16_f32(
_input: &(),
_rows: usize,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_softmax_backward_f64(
grad: &CudaBuffer<f64>,
output: &CudaBuffer<f64>,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
validate_device(grad, device)?;
if grad.len() != output.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: output.len(),
});
}
let total = grad.len();
let rows = total / cols;
let ctx = device.context();
let stream = device.stream();
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(
&CACHE,
SOFTMAX_BACKWARD_PTX,
"softmax_backward_kernel",
"softmax_backward_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"softmax_backward_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "softmax_backward_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(output.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_f64(
input: &CudaBuffer<f64>,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
validate_device(input, device)?;
let total = input.len();
let rows = total / cols;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LOG_SOFTMAX_F64_PTX,
"log_softmax_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "log_softmax_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_log_softmax_backward_f64(
grad: &CudaBuffer<f64>,
output: &CudaBuffer<f64>,
cols: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
validate_device(grad, device)?;
if grad.len() != output.len() {
return Err(GpuError::LengthMismatch {
a: grad.len(),
b: output.len(),
});
}
let total = grad.len();
let rows = total / cols;
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
LOG_SOFTMAX_BACKWARD_F64_PTX,
"log_softmax_backward_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "log_softmax_backward_f64_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(total, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(grad.inner())
.arg(output.inner())
.arg(out.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_f64(
input: &CudaBuffer<f64>,
weight: &CudaBuffer<f64>,
bias: &CudaBuffer<f64>,
rows: usize,
cols: usize,
eps: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
LAYERNORM_PTX,
"layernorm_kernel",
"layernorm_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"layernorm_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "layernorm_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(weight.inner())
.arg(bias.inner())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_layernorm_backward_f64(
input: &CudaBuffer<f64>,
grad_output: &CudaBuffer<f64>,
weight: &CudaBuffer<f64>,
rows: usize,
cols: usize,
eps: f64,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
LAYERNORM_BACKWARD_PTX,
"layernorm_backward_kernel",
"layernorm_backward_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"layernorm_backward_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "layernorm_backward_kernel",
source: e,
});
}
};
let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
let mut grad_w = alloc_zeros_f64(cols, device)?;
let mut grad_b = alloc_zeros_f64(cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(weight.inner())
.arg(grad_in.inner_mut())
.arg(grad_w.inner_mut())
.arg(grad_b.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok((grad_in, grad_w, grad_b))
}
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_f64(
input: &CudaBuffer<f64>,
weight: &CudaBuffer<f64>,
rows: usize,
cols: usize,
eps: f64,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(&CACHE, RMSNORM_PTX, "rmsnorm_kernel", "rmsnorm_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"rmsnorm_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "rmsnorm_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f64(rows * cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(out.inner_mut())
.arg(weight.inner())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(feature = "cuda")]
pub fn gpu_rmsnorm_backward_f64(
input: &CudaBuffer<f64>,
grad_output: &CudaBuffer<f64>,
weight: &CudaBuffer<f64>,
rows: usize,
cols: usize,
eps: f64,
device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
use cudarc::driver::PushKernelArg;
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
validate_device(input, device)?;
let ctx = device.context();
let stream = device.stream();
let ptx = get_f64_ptx(
&CACHE,
RMSNORM_BACKWARD_PTX,
"rmsnorm_backward_kernel",
"rmsnorm_backward_f64_kernel",
);
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"rmsnorm_backward_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "rmsnorm_backward_kernel",
source: e,
});
}
};
let mut grad_in = alloc_zeros_f64(rows * cols, device)?;
let mut grad_w = alloc_zeros_f64(cols, device)?;
let rows_u32 = rows as u32;
let cols_u32 = cols as u32;
let cfg = LaunchConfig {
grid_dim: ((rows as u32).max(1), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 256 * 8,
};
unsafe {
stream
.launch_builder(&f)
.arg(input.inner())
.arg(grad_output.inner())
.arg(weight.inner())
.arg(grad_in.inner_mut())
.arg(grad_w.inner_mut())
.arg(&rows_u32)
.arg(&cols_u32)
.arg(&eps)
.launch(cfg)?;
}
Ok((grad_in, grad_w))
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_f64(
_input: &CudaBuffer<f64>,
_rows: usize,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_softmax_backward_f64(
_grad: &CudaBuffer<f64>,
_output: &CudaBuffer<f64>,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_f64(
_input: &CudaBuffer<f64>,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_log_softmax_backward_f64(
_grad: &CudaBuffer<f64>,
_output: &CudaBuffer<f64>,
_cols: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_f64(
_input: &CudaBuffer<f64>,
_weight: &CudaBuffer<f64>,
_bias: &CudaBuffer<f64>,
_rows: usize,
_cols: usize,
_eps: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_layernorm_backward_f64(
_input: &CudaBuffer<f64>,
_grad_output: &CudaBuffer<f64>,
_weight: &CudaBuffer<f64>,
_rows: usize,
_cols: usize,
_eps: f64,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>, CudaBuffer<f64>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_f64(
_input: &CudaBuffer<f64>,
_weight: &CudaBuffer<f64>,
_rows: usize,
_cols: usize,
_eps: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_rmsnorm_backward_f64(
_input: &CudaBuffer<f64>,
_grad_output: &CudaBuffer<f64>,
_weight: &CudaBuffer<f64>,
_rows: usize,
_cols: usize,
_eps: f64,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_tanh_f64(
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_erf_f64(
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_tanh_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_gelu_backward_erf_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_silu_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_f64(
_input: &CudaBuffer<f64>,
_alpha: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_elu_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_alpha: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_f64(_input: &CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mish_backward_f64(
_grad: &CudaBuffer<f64>,
_input: &CudaBuffer<f64>,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_clamp_f64(
_input: &CudaBuffer<f64>,
_min: f64,
_max: f64,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumsum_f64(
_input: &CudaBuffer<f64>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cumprod_f64(
_input: &CudaBuffer<f64>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummax_f64(
_input: &CudaBuffer<f64>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_cummin_f64(
_input: &CudaBuffer<f64>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, CudaBuffer<f64>)> {
Err(GpuError::NoCudaFeature)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_logcumsumexp_f64(
_input: &CudaBuffer<f64>,
_outer: usize,
_dim_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub(crate) const CONJ_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry conj_kernel(
.param .u64 buf_ptr,
.param .u32 n_pairs
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u64 %buf, %off;
.reg .f32 %vim;
.reg .pred %p;
ld.param.u64 %buf, [buf_ptr];
ld.param.u32 %n_reg, [n_pairs];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
cvt.u64.u32 %off, %r_tid;
shl.b64 %off, %off, 1;
add.u64 %off, %off, 1;
shl.b64 %off, %off, 2;
add.u64 %buf, %buf, %off;
ld.global.f32 %vim, [%buf];
neg.f32 %vim, %vim;
st.global.f32 [%buf], %vim;
DONE:
ret;
}
";
#[cfg(feature = "cuda")]
pub fn gpu_conj_f32(mut buf: CudaBuffer<f32>, device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if buf.len() % 2 != 0 {
return Err(GpuError::ShapeMismatch {
op: "gpu_conj_f32",
expected: vec![2],
got: vec![buf.len()],
});
}
let n_pairs = buf.len() / 2;
if n_pairs == 0 {
return Ok(buf);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
CONJ_PTX,
"conj_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "conj_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n_pairs)?;
let n_pairs_u32 = n_pairs as u32;
unsafe {
stream
.launch_builder(&f)
.arg(buf.inner_mut())
.arg(&n_pairs_u32)
.launch(cfg)?;
}
Ok(buf)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_conj_f32(_buf: CudaBuffer<f32>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_conj_f64(mut buf: CudaBuffer<f64>, device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
use cudarc::driver::PushKernelArg;
if buf.len() % 2 != 0 {
return Err(GpuError::ShapeMismatch {
op: "gpu_conj_f64",
expected: vec![2],
got: vec![buf.len()],
});
}
let n_pairs = buf.len() / 2;
if n_pairs == 0 {
return Ok(buf);
}
let ctx = device.context();
let stream = device.stream();
static CACHE: std::sync::OnceLock<String> = std::sync::OnceLock::new();
let ptx = get_f64_ptx(&CACHE, CONJ_PTX, "conj_kernel", "conj_f64_kernel");
let f = match crate::module_cache::get_or_compile(
ctx,
ptx,
"conj_f64_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "conj_f64_kernel",
source: e,
});
}
};
let cfg = launch_cfg(n_pairs)?;
let n_pairs_u32 = n_pairs as u32;
unsafe {
stream
.launch_builder(&f)
.arg(buf.inner_mut())
.arg(&n_pairs_u32)
.launch(cfg)?;
}
Ok(buf)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_conj_f64(_buf: CudaBuffer<f64>, _device: &GpuDevice) -> GpuResult<CudaBuffer<f64>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_add_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
ADD_BF16_F32_PTX,
"add_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "add_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(b)
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_add_bf16_f32(
_a: &(),
_b: &(),
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_sub_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SUB_BF16_F32_PTX,
"sub_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "sub_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(b)
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sub_bf16_f32(
_a: &(),
_b: &(),
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_mul_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
MUL_BF16_F32_PTX,
"mul_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "mul_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(b)
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mul_bf16_f32(
_a: &(),
_b: &(),
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_div_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
b: &cudarc::driver::CudaSlice<u16>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
DIV_BF16_F32_PTX,
"div_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "div_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(b)
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_div_bf16_f32(
_a: &(),
_b: &(),
_n: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_sum_axis_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
outer: usize,
axis_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total_output = outer * inner;
if total_output == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SUM_AXIS_BF16_F32_PTX,
"sum_axis_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "sum_axis_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total_output, device)?;
let cfg = launch_cfg(total_output)?;
let outer_u32 = outer as u32;
let axis_u32 = axis_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total_output as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&axis_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sum_axis_bf16_f32(
_a: &(),
_outer: usize,
_axis_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_mean_axis_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
outer: usize,
axis_size: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
let total_output = outer * inner;
if total_output == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
MEAN_AXIS_BF16_F32_PTX,
"mean_axis_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "mean_axis_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(total_output, device)?;
let cfg = launch_cfg(total_output)?;
let outer_u32 = outer as u32;
let axis_u32 = axis_size as u32;
let inner_u32 = inner as u32;
let total_u32 = total_output as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(out.inner_mut())
.arg(&outer_u32)
.arg(&axis_u32)
.arg(&inner_u32)
.arg(&total_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_mean_axis_bf16_f32(
_a: &(),
_outer: usize,
_axis_size: usize,
_inner: usize,
_device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_relu_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
RELU_BF16_F32_PTX,
"relu_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "relu_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_relu_bf16_f32(_a: &(), _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(feature = "cuda")]
pub fn gpu_sigmoid_bf16_f32(
a: &cudarc::driver::CudaSlice<u16>,
n: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
use cudarc::driver::PushKernelArg;
if n == 0 {
return alloc_zeros_f32(0, device);
}
let ctx = device.context();
let stream = device.stream();
let f = match crate::module_cache::get_or_compile(
ctx,
SIGMOID_BF16_F32_PTX,
"sigmoid_bf16_f32_kernel",
device.ordinal() as u32,
) {
Ok(f) => f,
Err(e) => {
return Err(GpuError::PtxCompileFailed {
kernel: "sigmoid_bf16_f32_kernel",
source: e,
});
}
};
let mut out = alloc_zeros_f32(n, device)?;
let cfg = launch_cfg(n)?;
let n_u32 = n as u32;
unsafe {
stream
.launch_builder(&f)
.arg(a)
.arg(out.inner_mut())
.arg(&n_u32)
.launch(cfg)?;
}
Ok(out)
}
#[cfg(not(feature = "cuda"))]
pub fn gpu_sigmoid_bf16_f32(_a: &(), _n: usize, _device: &GpuDevice) -> GpuResult<CudaBuffer<f32>> {
Err(GpuError::NoCudaFeature)
}
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests {
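// All tests below require a real CUDA device at ordinal 0; GpuDevice::new(0) is
// unwrapped with expect, so they fail outright when no GPU is available.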
use super::*;
fn setup(data: &[f32]) -> (GpuDevice, CudaBuffer<f32>) {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let buf = cpu_to_gpu(data, &dev).expect("cpu_to_gpu");
(dev, buf)
}
fn assert_buf_eq(buf: &CudaBuffer<f32>, device: &GpuDevice, expected: &[f32]) {
let host = gpu_to_cpu(buf, device).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < 1e-6,
"element {i}: got {got}, expected {exp}",
);
}
}
#[test]
fn add_basic() {
let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
let b_data = vec![10.0f32, 20.0, 30.0, 40.0];
let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
let (dev, a) = setup(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let out = gpu_add(&a, &b, &dev).expect("gpu_add");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn add_empty() {
let (dev, a) = setup(&[]);
let b = cpu_to_gpu::<f32>(&[], &dev).expect("cpu_to_gpu b");
let out = gpu_add(&a, &b, &dev).expect("gpu_add empty");
assert_eq!(out.len(), 0);
}
#[test]
fn add_large() {
let n = 100_000;
let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
let b_data: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5).collect();
let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x + y).collect();
let (dev, a) = setup(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let out = gpu_add(&a, &b, &dev).expect("gpu_add large");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn add_length_mismatch() {
let (dev, a) = setup(&[1.0, 2.0, 3.0]);
let b = cpu_to_gpu(&[1.0, 2.0], &dev).expect("cpu_to_gpu b");
let err = gpu_add(&a, &b, &dev).unwrap_err();
match err {
GpuError::LengthMismatch { a: 3, b: 2 } => {}
other => panic!("unexpected error: {other}"),
}
}
#[test]
fn sub_basic() {
let a_data = vec![10.0f32, 20.0, 30.0, 40.0];
let b_data = vec![1.0f32, 2.0, 3.0, 4.0];
let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x - y).collect();
let (dev, a) = setup(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn sub_negative_result() {
let a_data = vec![1.0f32, 2.0];
let b_data = vec![5.0f32, 10.0];
let expected: Vec<f32> = vec![-4.0, -8.0];
let (dev, a) = setup(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let out = gpu_sub(&a, &b, &dev).expect("gpu_sub");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn mul_basic() {
let a_data = vec![2.0f32, 3.0, 4.0, 5.0];
let b_data = vec![10.0f32, 10.0, 10.0, 10.0];
let expected: Vec<f32> = a_data.iter().zip(&b_data).map(|(x, y)| x * y).collect();
let (dev, a) = setup(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn mul_by_zero() {
let a_data = vec![1.0f32, 2.0, 3.0];
let b_data = vec![0.0f32, 0.0, 0.0];
let expected = vec![0.0f32, 0.0, 0.0];
let (dev, a) = setup(&a_data);
let b = cpu_to_gpu(&b_data, &dev).expect("cpu_to_gpu b");
let out = gpu_mul(&a, &b, &dev).expect("gpu_mul");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn neg_basic() {
let a_data = vec![1.0f32, -2.0, 3.0, 0.0, -5.5];
let expected: Vec<f32> = a_data.iter().map(|x| -x).collect();
let (dev, a) = setup(&a_data);
let out = gpu_neg(&a, &dev).expect("gpu_neg");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn neg_double_negation() {
let a_data = vec![1.0f32, -2.0, 3.0];
let (dev, a) = setup(&a_data);
let neg1 = gpu_neg(&a, &dev).expect("gpu_neg 1");
let neg2 = gpu_neg(&neg1, &dev).expect("gpu_neg 2");
assert_buf_eq(&neg2, &dev, &a_data);
}
#[test]
fn relu_basic() {
let a_data = vec![-3.0f32, -1.0, 0.0, 1.0, 3.0];
let expected = vec![0.0f32, 0.0, 0.0, 1.0, 3.0];
let (dev, a) = setup(&a_data);
let out = gpu_relu(&a, &dev).expect("gpu_relu");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn relu_all_negative() {
let a_data = vec![-5.0f32, -0.1, -100.0];
let expected = vec![0.0f32, 0.0, 0.0];
let (dev, a) = setup(&a_data);
let out = gpu_relu(&a, &dev).expect("gpu_relu");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn relu_all_positive() {
let a_data = vec![0.1f32, 1.0, 100.0];
let (dev, a) = setup(&a_data);
let out = gpu_relu(&a, &dev).expect("gpu_relu");
assert_buf_eq(&out, &dev, &a_data);
}
#[test]
fn relu_empty() {
let (dev, a) = setup(&[]);
let out = gpu_relu(&a, &dev).expect("gpu_relu empty");
assert_eq!(out.len(), 0);
}
#[test]
fn small_matmul_2x2() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0], &dev).unwrap();
let b = cpu_to_gpu(&[5.0f32, 6.0, 7.0, 8.0], &dev).unwrap();
let c = gpu_small_matmul(&a, &b, 2, 2, 2, &dev).unwrap();
assert_buf_eq(&c, &dev, &[19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn small_matmul_1xk_kxn() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let a = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &dev).unwrap();
let b = cpu_to_gpu(&[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &dev).unwrap();
let c = gpu_small_matmul(&a, &b, 1, 3, 2, &dev).unwrap();
assert_buf_eq(&c, &dev, &[4.0, 5.0]);
}
#[test]
fn small_matmul_vs_cublas() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let m = 1;
let k = 64;
let n = 64;
let a_data: Vec<f32> = (0..m * k)
.map(|i| ((i * 7 + 3) % 100) as f32 / 100.0)
.collect();
let b_data: Vec<f32> = (0..k * n)
.map(|i| ((i * 11 + 5) % 100) as f32 / 100.0)
.collect();
let a = cpu_to_gpu(&a_data, &dev).unwrap();
let b = cpu_to_gpu(&b_data, &dev).unwrap();
let c_cublas = crate::blas::gpu_matmul_f32(&a, &b, m, k, n, &dev).unwrap();
let cublas_result = gpu_to_cpu(&c_cublas, &dev).unwrap();
let c_ours = gpu_small_matmul(&a, &b, m, k, n, &dev).unwrap();
let our_result = gpu_to_cpu(&c_ours, &dev).unwrap();
assert_eq!(cublas_result.len(), our_result.len());
for (i, (&cb, &ours)) in cublas_result.iter().zip(our_result.iter()).enumerate() {
assert!(
(cb - ours).abs() < 0.1,
"element {i}: cuBLAS={cb}, ours={ours}, diff={}",
(cb - ours).abs()
);
}
}
#[test]
fn strided_copy_identity_contiguous_2d() {
let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
let (dev, input) = setup(&data);
let out =
gpu_strided_copy(&input, &[2, 3], &[3, 1], 0, &dev).expect("strided_copy identity");
assert_buf_eq(&out, &dev, &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn strided_copy_transpose_2d() {
let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
let (dev, input) = setup(&data);
let out =
gpu_strided_copy(&input, &[3, 2], &[1, 3], 0, &dev).expect("strided_copy transpose");
assert_buf_eq(&out, &dev, &[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
}
#[test]
fn strided_copy_sliced_column() {
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let (dev, input) = setup(&data);
let out = gpu_strided_copy(&input, &[3], &[4], 2, &dev).expect("strided_copy col slice");
assert_buf_eq(&out, &dev, &[2.0, 6.0, 10.0]);
}
#[test]
fn strided_copy_3d_permute() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let (dev, input) = setup(&data);
let out =
gpu_strided_copy(&input, &[2, 4, 3], &[12, 1, 4], 0, &dev).expect("strided_copy 3d");
let mut expected = vec![0.0f32; 24];
for b in 0..2 {
for i in 0..4 {
for j in 0..3 {
let dst = b * 12 + i * 3 + j;
let src = b * 12 + j * 4 + i;
expected[dst] = data[src];
}
}
}
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn strided_copy_4d_max_rank_supported() {
let shape = [2usize, 3, 2, 2];
let n: usize = shape.iter().product();
let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
let (dev, input) = setup(&data);
let out =
gpu_strided_copy(&input, &shape, &[12, 4, 2, 1], 0, &dev).expect("strided_copy 4d");
assert_buf_eq(&out, &dev, &data);
}
#[test]
fn strided_copy_rejects_too_many_dims() {
let (dev, input) = setup(&[0.0f32; 16]);
let result = gpu_strided_copy(&input, &[1, 1, 1, 1, 1, 1, 1, 1, 16], &[1; 9], 0, &dev);
assert!(result.is_err());
}
#[test]
fn strided_copy_rejects_shape_stride_length_mismatch() {
let (dev, input) = setup(&[0.0f32; 12]);
let result = gpu_strided_copy(&input, &[3, 4], &[4, 1, 1], 0, &dev);
assert!(result.is_err());
}
#[test]
fn strided_copy_rejects_negative_stride() {
let (dev, input) = setup(&[0.0f32; 6]);
let result = gpu_strided_copy(&input, &[2, 3], &[3, -1], 0, &dev);
assert!(result.is_err());
}
#[test]
fn strided_copy_empty_output() {
let (dev, input) = setup(&[1.0f32, 2.0, 3.0]);
let out = gpu_strided_copy(&input, &[0, 3], &[3, 1], 0, &dev).expect("strided_copy empty");
assert_eq!(out.len(), 0);
}
#[test]
fn strided_copy_f64_transpose_matches_f32() {
let data: Vec<f64> = (0..6).map(|i| i as f64).collect();
let dev = GpuDevice::new(0).expect("CUDA device 0");
let input = cpu_to_gpu(&data, &dev).expect("cpu_to_gpu f64");
let out = gpu_strided_copy_f64(&input, &[3, 2], &[1, 3], 0, &dev)
.expect("strided_copy_f64 transpose");
let host = gpu_to_cpu(&out, &dev).expect("gpu_to_cpu f64");
assert_eq!(host, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);
}
#[test]
fn has_inf_nan_f32_all_finite_returns_false() {
let host = vec![1.0_f32, 2.0, 3.0, -4.5, 0.0];
let (dev, buf) = setup(&host);
assert!(!gpu_has_inf_nan(&buf, &dev).expect("kernel must succeed"));
}
#[test]
fn has_inf_nan_f32_with_nan_returns_true() {
let host = vec![1.0_f32, f32::NAN, 3.0];
let (dev, buf) = setup(&host);
assert!(gpu_has_inf_nan(&buf, &dev).expect("kernel must succeed"));
}
#[test]
fn has_inf_nan_f32_with_inf_returns_true() {
let host = vec![1.0_f32, f32::INFINITY, 3.0];
let (dev, buf) = setup(&host);
assert!(gpu_has_inf_nan(&buf, &dev).expect("kernel must succeed"));
}
#[test]
fn has_inf_nan_f32_with_neg_inf_returns_true() {
let host = vec![1.0_f32, f32::NEG_INFINITY, 3.0];
let (dev, buf) = setup(&host);
assert!(gpu_has_inf_nan(&buf, &dev).expect("kernel must succeed"));
}
#[test]
fn has_inf_nan_f32_empty_buffer_returns_false() {
let dev = GpuDevice::new(0).expect("CUDA device 0");
let buf = crate::transfer::alloc_zeros_f32(0, &dev).expect("alloc empty");
assert!(!gpu_has_inf_nan(&buf, &dev).expect("must short-circuit"));
}
#[test]
fn has_inf_nan_f32_large_finite_returns_false() {
let host: Vec<f32> = (0..10_000).map(|i| i as f32 * 0.5).collect();
let (dev, buf) = setup(&host);
assert!(!gpu_has_inf_nan(&buf, &dev).expect("kernel must succeed"));
}
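// CPU reference implementation of index_select along the middle dimension of an
// (outer, in_dim_size, inner) view, used to cross-check the GPU kernels below.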
fn cpu_ref_index_select_dim<T: Copy + Default>(
input: &[T],
indices: &[usize],
outer: usize,
in_dim_size: usize,
inner: usize,
) -> Vec<T> {
let out_dim_size = indices.len();
let total = outer * out_dim_size * inner;
let mut out = vec![T::default(); total];
for o in 0..outer {
for i in 0..out_dim_size {
let src_i = indices[i];
let in_base = o * in_dim_size * inner + src_i * inner;
let out_base = o * out_dim_size * inner + i * inner;
out[out_base..out_base + inner]
.copy_from_slice(&input[in_base..in_base + inner]);
}
}
out
}
#[test]
fn gpu_index_select_dim_f32_basic() {
let outer = 2usize;
let in_dim_size = 3usize;
let inner = 4usize;
let indices_usize: Vec<usize> = vec![2, 0, 1];
let indices_f32: Vec<f32> = indices_usize.iter().map(|&u| u as f32).collect();
let numel = outer * in_dim_size * inner;
let input: Vec<f32> = (0..numel).map(|i| i as f32).collect();
let expected =
cpu_ref_index_select_dim(&input, &indices_usize, outer, in_dim_size, inner);
let (dev, in_buf) = setup(&input);
let idx_buf = cpu_to_gpu(&indices_f32, &dev).expect("cpu_to_gpu indices");
let out = gpu_index_select_dim(
&in_buf,
&idx_buf,
outer,
in_dim_size,
indices_usize.len(),
inner,
&dev,
)
.expect("gpu_index_select_dim");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn gpu_index_select_dim_f64_basic() {
let outer = 2usize;
let in_dim_size = 3usize;
let inner = 4usize;
let indices_usize: Vec<usize> = vec![2, 0, 1];
let indices_f32: Vec<f32> = indices_usize.iter().map(|&u| u as f32).collect();
let numel = outer * in_dim_size * inner;
let input: Vec<f64> = (0..numel).map(|i| i as f64 + 0.25).collect();
let expected =
cpu_ref_index_select_dim(&input, &indices_usize, outer, in_dim_size, inner);
let dev = GpuDevice::new(0).expect("CUDA device 0");
let in_buf = cpu_to_gpu(&input, &dev).expect("cpu_to_gpu input f64");
let idx_buf = cpu_to_gpu(&indices_f32, &dev).expect("cpu_to_gpu indices");
let out = gpu_index_select_dim_f64(
&in_buf,
&idx_buf,
outer,
in_dim_size,
indices_usize.len(),
inner,
&dev,
)
.expect("gpu_index_select_dim_f64");
let host = gpu_to_cpu(&out, &dev).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < 1e-12,
"f64 element {i}: got {got}, expected {exp}",
);
}
}
#[test]
fn gpu_index_select_dim_f32_dim_0() {
let outer = 1usize;
let in_dim_size = 4usize;
let inner = 3usize;
let indices_usize: Vec<usize> = vec![3, 1, 0];
let indices_f32: Vec<f32> = indices_usize.iter().map(|&u| u as f32).collect();
let input: Vec<f32> = vec![
10.0, 11.0, 12.0, 20.0, 21.0, 22.0, 30.0, 31.0, 32.0, 40.0, 41.0, 42.0,
];
let expected =
cpu_ref_index_select_dim(&input, &indices_usize, outer, in_dim_size, inner);
let (dev, in_buf) = setup(&input);
let idx_buf = cpu_to_gpu(&indices_f32, &dev).expect("cpu_to_gpu indices");
let out = gpu_index_select_dim(
&in_buf,
&idx_buf,
outer,
in_dim_size,
indices_usize.len(),
inner,
&dev,
)
.expect("gpu_index_select_dim dim_0");
assert_buf_eq(
&out,
&dev,
&[40.0, 41.0, 42.0, 20.0, 21.0, 22.0, 10.0, 11.0, 12.0],
);
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn gpu_index_select_dim_f32_dim_last() {
let outer = 2usize;
let in_dim_size = 4usize;
let inner = 1usize;
let indices_usize: Vec<usize> = vec![3, 2, 1, 0];
let indices_f32: Vec<f32> = indices_usize.iter().map(|&u| u as f32).collect();
let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let expected =
cpu_ref_index_select_dim(&input, &indices_usize, outer, in_dim_size, inner);
assert_eq!(expected, vec![4.0, 3.0, 2.0, 1.0, 8.0, 7.0, 6.0, 5.0]);
let (dev, in_buf) = setup(&input);
let idx_buf = cpu_to_gpu(&indices_f32, &dev).expect("cpu_to_gpu indices");
let out = gpu_index_select_dim(
&in_buf,
&idx_buf,
outer,
in_dim_size,
indices_usize.len(),
inner,
&dev,
)
.expect("gpu_index_select_dim dim_last");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn gpu_index_select_dim_f32_repeated_indices() {
let outer = 1usize;
let in_dim_size = 3usize;
let inner = 1usize;
let indices_usize: Vec<usize> = vec![0, 0, 1, 1];
let indices_f32: Vec<f32> = indices_usize.iter().map(|&u| u as f32).collect();
let input: Vec<f32> = vec![7.5, -3.0, 11.0];
let expected =
cpu_ref_index_select_dim(&input, &indices_usize, outer, in_dim_size, inner);
assert_eq!(expected, vec![7.5, 7.5, -3.0, -3.0]);
let (dev, in_buf) = setup(&input);
let idx_buf = cpu_to_gpu(&indices_f32, &dev).expect("cpu_to_gpu indices");
let out = gpu_index_select_dim(
&in_buf,
&idx_buf,
outer,
in_dim_size,
indices_usize.len(),
inner,
&dev,
)
.expect("gpu_index_select_dim repeated");
assert_buf_eq(&out, &dev, &expected);
}
#[test]
fn gpu_index_select_dim_f32_random_permutation() {
let outer = 2usize;
let in_dim_size = 7usize;
let inner = 5usize;
let indices_usize: Vec<usize> = vec![4, 0, 6, 2, 5, 1, 3];
let indices_f32: Vec<f32> = indices_usize.iter().map(|&u| u as f32).collect();
let numel = outer * in_dim_size * inner;
let input: Vec<f32> = (0..numel).map(|i| (i as f32) * 0.5 - 1.25).collect();
let expected =
cpu_ref_index_select_dim(&input, &indices_usize, outer, in_dim_size, inner);
let (dev, in_buf) = setup(&input);
let idx_buf = cpu_to_gpu(&indices_f32, &dev).expect("cpu_to_gpu indices");
let out = gpu_index_select_dim(
&in_buf,
&idx_buf,
outer,
in_dim_size,
indices_usize.len(),
inner,
&dev,
)
.expect("gpu_index_select_dim random permutation");
let host = gpu_to_cpu(&out, &dev).expect("gpu_to_cpu");
assert_eq!(host.len(), expected.len(), "length mismatch");
for (i, (&got, &exp)) in host.iter().zip(expected.iter()).enumerate() {
assert!(
(got - exp).abs() < 1e-9,
"permutation element {i}: got {got}, expected {exp}",
);
}
assert_eq!(out.device_ordinal(), in_buf.device_ordinal());
}
}